跳到主要内容

工具系统

工具是 AI Agent 与外部世界交互的桥梁。一个设计良好的工具系统能够让 Agent 突破 LLM 自身的限制,完成各种复杂任务。本章将深入介绍工具的设计原则、实现方式和最佳实践。

工具的本质

工具本质上是一个可以被 LLM 调用的函数。它接收结构化的参数输入,执行特定操作,返回结果。LLM 根据用户的请求和工具的描述,决定何时调用哪个工具,以及传递什么参数。

一个完整的工具定义包含以下要素:

  • 名称:工具的唯一标识符
  • 描述:说明工具的功能和使用场景
  • 参数模式:定义工具接受的参数类型和结构
  • 执行逻辑:实际执行操作的代码

工具设计原则

单一职责

每个工具应该只做一件事,并且把它做好。这有助于 LLM 正确理解工具的用途,也便于维护和测试。

from langchain.tools import tool

@tool
def search_web(query: str) -> str:
"""搜索互联网获取信息"""
pass

@tool
def get_weather(city: str) -> str:
"""获取指定城市的天气信息"""
pass

@tool
def send_email(to: str, subject: str, body: str) -> str:
"""发送电子邮件"""
pass

清晰的描述

工具的描述直接决定了 LLM 能否正确使用它。描述应该包含:

  • 工具的功能说明
  • 适用场景
  • 参数的含义和格式要求
@tool
def query_database(
query: str,
database: str = "main"
) -> str:
"""
执行 SQL 查询并返回结果。

适用场景:
- 需要从数据库获取数据时
- 需要统计或分析数据时

参数说明:
- query: SQL 查询语句,只支持 SELECT 语句
- database: 数据库名称,默认为 main

注意:不支持 INSERT、UPDATE、DELETE 等修改操作
"""
pass

健壮的错误处理

工具应该优雅地处理各种错误情况,返回有意义的错误信息:

@tool
def divide(a: float, b: float) -> str:
"""执行除法运算"""
try:
if b == 0:
return "错误:除数不能为零"
result = a / b
return f"结果:{result}"
except TypeError:
return "错误:参数必须是数字"
except Exception as e:
return f"计算错误:{str(e)}"

定义工具的方式

LangChain 1.0 提供了多种定义工具的方式。

使用 @tool 装饰器

最简单的方式是使用 @tool 装饰器,LangChain 会自动从函数签名和 docstring 中提取工具的描述和参数模式:

from langchain.tools import tool

@tool
def multiply(a: int, b: int) -> int:
"""将两个数字相乘"""
return a * b

@tool
def get_word_length(word: str) -> int:
"""返回单词的长度"""
return len(word)

tools = [multiply, get_word_length]

使用 Pydantic 定义参数模式

对于复杂的参数验证,可以使用 Pydantic:

from langchain.tools import tool
from pydantic import BaseModel, Field, field_validator
import re

class EmailInput(BaseModel):
"""邮件输入参数"""
to: str = Field(description="收件人邮箱地址")
subject: str = Field(description="邮件主题")
body: str = Field(description="邮件正文")

@field_validator('to')
@classmethod
def validate_email(cls, v):
pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
if not re.match(pattern, v):
raise ValueError('无效的邮箱地址')
return v

@field_validator('subject')
@classmethod
def validate_subject(cls, v):
if len(v) > 100:
raise ValueError('邮件主题不能超过100个字符')
return v

@tool(args_schema=EmailInput)
def send_email(to: str, subject: str, body: str) -> str:
"""发送电子邮件"""
# 实际发送邮件的逻辑
return f"邮件已发送至 {to}"

使用 StructuredTool

对于更复杂的场景,可以使用 StructuredTool

from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field

class CalculatorInput(BaseModel):
"""计算器输入参数"""
expression: str = Field(description="数学表达式,如 '2 + 3 * 4'")
precision: int = Field(default=2, description="结果的小数位数")

def calculator_func(expression: str, precision: int = 2) -> str:
"""计算数学表达式"""
try:
result = eval(expression)
return f"{result:.{precision}f}"
except Exception as e:
return f"计算错误:{e}"

calculator = StructuredTool(
name="calculator",
description="计算数学表达式,支持加减乘除等基本运算",
func=calculator_func,
args_schema=CalculatorInput
)

异步工具

对于 I/O 密集型操作,可以使用异步工具:

@tool
async def async_search(query: str) -> str:
"""异步搜索互联网"""
import aiohttp

async with aiohttp.ClientSession() as session:
async with session.get(f"https://api.example.com/search?q={query}") as response:
data = await response.json()
return data.get("result", "未找到结果")

工具调用流程

当 Agent 决定使用工具时,执行流程如下:

直接调用工具

工具可以像普通函数一样直接调用:

from langchain_core.tools import tool

@tool
def multiply(a: int, b: int) -> int:
"""Multiply two numbers."""
return a * b

# 方式一:普通函数调用
result = multiply.invoke({"a": 42, "b": 7})
print(result) # 输出:294

# 方式二:通过 tool_call 格式调用
tool_call = {
"type": "tool_call",
"id": "1",
"args": {"a": 42, "b": 7}
}
result = multiply.invoke(tool_call)
# 返回 ToolMessage 对象

访问上下文信息

工具有时需要访问运行时上下文,如用户 ID、会话状态等。LangChain 1.0 提供了多种方式实现这一点。

通过 context 参数访问配置

在 LangChain 1.0 中,推荐使用 context 参数传递运行时配置:

from langchain.tools import tool
from langchain.agents import create_agent

@tool
def get_user_info(user_id: str) -> str:
"""获取当前用户信息"""
return f"用户 ID: {user_id}"

agent = create_agent(
model="openai:gpt-4o-mini",
tools=[get_user_info]
)

# 调用时传入配置
result = agent.invoke(
{"messages": [{"role": "user", "content": "获取用户信息"}]},
context={"user_id": "user-123"}
)

通过 Middleware 访问状态

使用 middleware 可以更灵活地管理状态访问:

from langchain.agents import create_agent
from langchain.agents.middleware import AgentMiddleware
from langchain.tools import tool
from typing import TypedDict

class CustomState(TypedDict):
messages: list
user_name: str

@tool
def get_user_name() -> str:
"""从状态中获取用户名"""
# 工具会通过 middleware 获取状态
pass

class StateInjectionMiddleware(AgentMiddleware):
"""状态注入中间件"""

def wrap_tool_call(self, request, handler):
"""包装工具调用,注入状态"""
# 可以在这里访问状态并传递给工具
return (yield handler(request))

agent = create_agent(
model="openai:gpt-4o-mini",
tools=[get_user_name],
state_schema=CustomState,
middleware=[StateInjectionMiddleware()]
)

更新状态

在 LangChain 1.0 中,推荐使用 middleware 来管理状态更新:

from langchain.agents import create_agent, Middleware
from langchain.tools import tool
from langchain.messages import ToolMessage

class StateUpdateMiddleware(AgentMiddleware):
"""状态更新中间件"""

def after_tool_call(self, request, result, state, context):
"""工具调用后更新状态"""
if request["name"] == "update_user_name":
# 更新状态
return {"user_name": result}
return {}

通过 Store 访问长期记忆

使用 get_store() 访问长期记忆存储:

from langchain_core.tools import tool
from langchain_core.runnables import RunnableConfig
from langgraph.config import get_store
from langgraph.prebuilt import create_react_agent
from langgraph.store.memory import InMemoryStore

# 创建存储
store = InMemoryStore()

@tool
def save_memory(content: str, config: RunnableConfig) -> str:
"""保存信息到长期记忆"""
memory_store = get_store()
user_id = config.get("configurable", {}).get("user_id", "default")

memory_store.put(
("memories", user_id), # 命名空间
"memory_1", # 键
{"content": content} # 值
)
return f"已保存:{content}"

@tool
def recall_memories(config: RunnableConfig) -> str:
"""检索长期记忆"""
memory_store = get_store()
user_id = config.get("configurable", {}).get("user_id", "default")

memories = memory_store.search(("memories", user_id))
return "\n".join([m.value["content"] for m in memories])

agent = create_react_agent(
model="openai:gpt-4o-mini",
tools=[save_memory, recall_memories],
store=store
)

常用工具实现

网络搜索工具

from langchain_core.tools import tool
import requests

@tool
def web_search(query: str, num_results: int = 5) -> str:
"""
搜索互联网获取信息。

参数:
- query: 搜索关键词
- num_results: 返回结果数量,默认5条

返回搜索结果的摘要列表。
"""
try:
# 这里使用示例,实际项目中替换为真实的搜索 API
# 如 Google Custom Search、Tavily、SerpAPI 等
return f"搜索 '{query}' 的结果:找到相关信息..."
except Exception as e:
return f"搜索失败:{str(e)}"

数据库查询工具

from langchain_core.tools import tool
import sqlite3

@tool
def query_sqlite(query: str, db_path: str = "database.db") -> str:
"""
执行 SQLite 数据库查询。

参数:
- query: SQL 查询语句(仅支持 SELECT)
- db_path: 数据库文件路径

返回查询结果的 JSON 字符串。
"""
# 安全检查:只允许 SELECT 语句
if not query.strip().upper().startswith("SELECT"):
return "错误:只支持 SELECT 查询"

try:
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
cursor.execute(query)

columns = [description[0] for description in cursor.description]
rows = cursor.fetchall()

results = [dict(zip(columns, row)) for row in rows]
conn.close()

return str(results)
except Exception as e:
return f"查询错误:{str(e)}"

文件操作工具

from langchain_core.tools import tool
import os

@tool
def read_file(file_path: str) -> str:
"""
读取文件内容。

参数:
- file_path: 文件的完整路径

返回文件内容字符串。
"""
try:
if not os.path.exists(file_path):
return f"错误:文件不存在 {file_path}"

with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()

# 限制返回长度
return content[:10000] if len(content) > 10000 else content
except Exception as e:
return f"读取错误:{str(e)}"

@tool
def write_file(file_path: str, content: str) -> str:
"""
写入文件内容。

参数:
- file_path: 文件的完整路径
- content: 要写入的内容

返回操作结果。
"""
try:
# 确保目录存在
os.makedirs(os.path.dirname(file_path), exist_ok=True)

with open(file_path, 'w', encoding='utf-8') as f:
f.write(content)

return f"成功写入文件:{file_path}"
except Exception as e:
return f"写入错误:{str(e)}"

@tool
def list_directory(dir_path: str) -> str:
"""
列出目录内容。

参数:
- dir_path: 目录路径

返回目录中的文件和子目录列表。
"""
try:
if not os.path.exists(dir_path):
return f"错误:目录不存在 {dir_path}"

items = os.listdir(dir_path)
result = []

for item in items:
full_path = os.path.join(dir_path, item)
if os.path.isdir(full_path):
result.append(f"[目录] {item}")
else:
size = os.path.getsize(full_path)
result.append(f"[文件] {item} ({size} bytes)")

return "\n".join(result)
except Exception as e:
return f"列出目录错误:{str(e)}"

HTTP 请求工具

from langchain_core.tools import tool
import requests
import json

@tool
def http_get(url: str, params: str = "{}") -> str:
"""
发送 HTTP GET 请求。

参数:
- url: 请求的 URL
- params: JSON 格式的查询参数,如 '{"key": "value"}'

返回响应内容。
"""
try:
params_dict = json.loads(params) if params else {}
response = requests.get(url, params=params_dict, timeout=30)
response.raise_for_status()
return response.text[:5000]
except json.JSONDecodeError:
return "错误:params 参数必须是有效的 JSON 格式"
except requests.exceptions.RequestException as e:
return f"请求错误:{str(e)}"

@tool
def http_post(url: str, data: str = "{}", headers: str = "{}") -> str:
"""
发送 HTTP POST 请求。

参数:
- url: 请求的 URL
- data: JSON 格式的请求体
- headers: JSON 格式的请求头

返回响应内容。
"""
try:
data_dict = json.loads(data) if data else {}
headers_dict = json.loads(headers) if headers else {}

response = requests.post(
url,
json=data_dict,
headers=headers_dict,
timeout=30
)
response.raise_for_status()
return response.text[:5000]
except json.JSONDecodeError:
return "错误:data 或 headers 参数必须是有效的 JSON 格式"
except requests.exceptions.RequestException as e:
return f"请求错误:{str(e)}"

工具高级特性

直接返回(return_direct)

使用 return_direct=True 可以让工具结果直接返回给用户,不再进行后续的 LLM 处理:

from langchain_core.tools import tool

@tool(return_direct=True)
def get_current_time() -> str:
"""获取当前时间(直接返回,不经过 LLM 处理)"""
from datetime import datetime
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")

# 当调用这个工具时,结果会直接返回给用户
# 而不是再让 LLM 处理

强制工具调用

通过 tool_choice 参数可以强制模型使用特定的工具:

from langgraph.prebuilt import create_react_agent
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI

@tool
def greet(user_name: str) -> str:
"""向用户问好"""
return f"你好,{user_name}!"

llm = ChatOpenAI(model="gpt-4o-mini")
tools = [greet]

# 强制使用 greet 工具
agent = create_react_agent(
model=llm.bind_tools(tools, tool_choice={"type": "tool", "name": "greet"}),
tools=tools
)

禁用并行工具调用

某些场景下需要禁用并行工具调用:

from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent

llm = ChatOpenAI(model="gpt-4o-mini")

# 禁用并行工具调用
agent = create_react_agent(
model=llm.bind_tools(tools, parallel_tool_calls=False),
tools=tools
)

错误处理

使用 ToolNode 处理错误

ToolNode 提供了内置的错误处理机制:

from langgraph.prebuilt import ToolNode
from langchain_core.tools import tool

@tool
def risky_operation(param: str) -> str:
"""可能失败的操作"""
if param == "error":
raise ValueError("操作失败")
return f"成功:{param}"

# 默认错误处理(捕获错误并返回错误消息)
tool_node = ToolNode([risky_operation])

# 自定义错误消息
tool_node = ToolNode(
[risky_operation],
handle_tool_errors="操作失败,请检查参数后重试。"
)

# 禁用错误处理(直接抛出异常)
tool_node = ToolNode([risky_operation], handle_tool_errors=False)

自定义错误处理函数

def custom_error_handler(error: Exception, tool_call: dict) -> str:
"""自定义错误处理函数"""
tool_name = tool_call.get("name", "unknown")
return f"工具 {tool_name} 执行失败:{str(error)}。请尝试其他方法。"

tool_node = ToolNode(
[risky_operation],
handle_tool_errors=custom_error_handler
)

工具安全

输入验证

使用 Pydantic 进行严格的输入验证:

from langchain_core.tools import tool
from pydantic import BaseModel, Field, field_validator
import re

class SecureInput(BaseModel):
"""安全输入验证"""
command: str = Field(description="要执行的命令")

@field_validator('command')
@classmethod
def validate_command(cls, v):
# 禁止危险命令
dangerous_commands = ['rm', 'del', 'format', 'shutdown']
if any(cmd in v.lower() for cmd in dangerous_commands):
raise ValueError('禁止执行危险命令')
return v

@tool(args_schema=SecureInput)
def execute_safe_command(command: str) -> str:
"""执行安全命令"""
# 实际执行逻辑
return f"执行命令:{command}"

敏感信息过滤

过滤工具返回中的敏感信息:

import re

def sanitize_output(output: str) -> str:
"""过滤敏感信息"""
patterns = [
(r'\b\d{16,19}\b', '[银行卡号]'),
(r'\b\d{17}[\dXx]\b', '[身份证号]'),
(r'\b[\w\.-]+@[\w\.-]+\.\w+\b', '[邮箱]'),
(r'\b1[3-9]\d{9}\b', '[手机号]'),
(r'(password|pwd|token|key|secret)["\']?\s*[:=]\s*["\']?[^\s"\']+', '[敏感信息]'),
]

for pattern, replacement in patterns:
output = re.sub(pattern, replacement, output, flags=re.IGNORECASE)

return output

@tool
def query_user_info(user_id: str) -> str:
"""查询用户信息"""
user_data = {
"id": user_id,
"name": "张三",
"phone": "13812345678",
"email": "[email protected]",
}
return sanitize_output(str(user_data))

工具测试

为工具编写单元测试:

import pytest
from unittest.mock import patch, MagicMock

def test_web_search():
"""测试网络搜索工具"""
with patch('requests.get') as mock_get:
mock_response = MagicMock()
mock_response.json.return_value = {
"items": [
{"title": "测试结果", "link": "http://example.com", "snippet": "摘要"}
]
}
mock_get.return_value = mock_response

result = web_search.invoke({"query": "测试"})
assert "测试结果" in result

def test_calculator():
"""测试计算器工具"""
result = calculator.invoke({"expression": "2 + 3 * 4"})
assert "14" in result

def test_calculator_division_by_zero():
"""测试除零错误处理"""
result = divide.invoke({"a": 10, "b": 0})
assert "错误" in result

动态工具选择

当工具数量很多时,可以使用语义搜索动态选择相关工具:

from langchain_core.tools import tool
from langgraph.prebuilt import create_react_agent

@tool
def get_weather(city: str) -> str:
"""获取天气信息"""
return f"{city}:晴"

@tool
def get_news(topic: str) -> str:
"""获取新闻"""
return f"{topic} 相关新闻"

@tool
def calculate(expression: str) -> str:
"""计算数学表达式"""
return str(eval(expression))

@tool
def translate(text: str, target_lang: str) -> str:
"""翻译文本"""
return f"[{target_lang}] {text}"

# 所有可用工具
all_tools = [get_weather, get_news, calculate, translate]

agent = create_react_agent(
model="openai:gpt-4o-mini",
tools=all_tools
)

# 模型会根据用户输入自动选择合适的工具

使用预构建工具

LLM 提供商工具

一些 LLM 提供商提供了预构建的工具,可以直接使用:

from langgraph.prebuilt import create_react_agent

# 使用 OpenAI 的 web_search_preview 工具
agent = create_react_agent(
model="openai:gpt-4o-mini",
tools=[{"type": "web_search_preview"}]
)

result = agent.invoke({
"messages": [{"role": "user", "content": "今天有什么科技新闻?"}]
})

LangChain 工具集成

LangChain 提供了大量预构建的工具集成,包括:

  • 搜索:Tavily、SerpAPI、Bing
  • 代码执行:Python REPL
  • 数据库:SQL、MongoDB
  • Web 数据:Web scraping
  • APIs:OpenWeatherMap、NewsAPI
from langchain_community.tools import TavilySearchResults

# 使用 Tavily 搜索工具
search_tool = TavilySearchResults(max_results=3)

agent = create_react_agent(
model="openai:gpt-4o-mini",
tools=[search_tool]
)

小结

工具系统是 AI Agent 的核心组件,设计良好的工具能够极大地扩展 Agent 的能力:

  • 遵循单一职责原则,每个工具只做一件事
  • 提供清晰的描述,帮助 LLM 正确使用工具
  • 实现健壮的错误处理,返回有意义的错误信息
  • 使用 InjectedState 和 Store 访问上下文信息
  • 注意安全性,验证输入、过滤敏感信息
  • 为工具编写测试,确保功能正确可靠
  • 使用 ToolNode 的 handle_tool_errors 参数处理工具错误

下一章我们将学习记忆系统的设计与实现。

参考资料