工具系统
工具是 AI Agent 与外部世界交互的桥梁。一个设计良好的工具系统能够让 Agent 突破 LLM 自身的限制,完成各种复杂任务。本章将深入介绍工具的设计原则、实现方式和最佳实践。
工具的本质
工具本质上是一个可以被 LLM 调用的函数。它接收结构化的参数输入,执行特定操作,返回结果。LLM 根据用户的请求和工具的描述,决定何时调用哪个工具,以及传递什么参数。
一个完整的工具定义包含以下要素:
- 名称:工具的唯一标识符
- 描述:说明工具的功能和使用场景
- 参数模式:定义工具接受的参数类型和结构
- 执行逻辑:实际执行操作的代码
工具设计原则
单一职责
每个工具应该只做一件事,并且把它做好。这有助于 LLM 正确理解工具的用途,也便于维护和测试。
from langchain.tools import tool
@tool
def search_web(query: str) -> str:
"""搜索互联网获取信息"""
pass
@tool
def get_weather(city: str) -> str:
"""获取指定城市的天气信息"""
pass
@tool
def send_email(to: str, subject: str, body: str) -> str:
"""发送电子邮件"""
pass
清晰的描述
工具的描述直接决定了 LLM 能否正确使用它。描述应该包含:
- 工具的功能说明
- 适用场景
- 参数的含义和格式要求
@tool
def query_database(
query: str,
database: str = "main"
) -> str:
"""
执行 SQL 查询并返回结果。
适用场景:
- 需要从数据库获取数据时
- 需要统计或分析数据时
参数说明:
- query: SQL 查询语句,只支持 SELECT 语句
- database: 数据库名称,默认为 main
注意:不支持 INSERT、UPDATE、DELETE 等修改操作
"""
pass
健壮的错误处理
工具应该优雅地处理各种错误情况,返回有意义的错误信息:
@tool
def divide(a: float, b: float) -> str:
"""执行除法运算"""
try:
if b == 0:
return "错误:除数不能为零"
result = a / b
return f"结果:{result}"
except TypeError:
return "错误:参数必须是数字"
except Exception as e:
return f"计算错误:{str(e)}"
定义工具的方式
LangChain 1.0 提供了多种定义工具的方式。
使用 @tool 装饰器
最简单的方式是使用 @tool 装饰器,LangChain 会自动从函数签名和 docstring 中提取工具的描述和参数模式:
from langchain.tools import tool
@tool
def multiply(a: int, b: int) -> int:
"""将两个数字相乘"""
return a * b
@tool
def get_word_length(word: str) -> int:
"""返回单词的长度"""
return len(word)
tools = [multiply, get_word_length]
使用 Pydantic 定义参数模式
对于复杂的参数验证,可以使用 Pydantic:
from langchain.tools import tool
from pydantic import BaseModel, Field, field_validator
import re
class EmailInput(BaseModel):
"""邮件输入参数"""
to: str = Field(description="收件人邮箱地址")
subject: str = Field(description="邮件主题")
body: str = Field(description="邮件正文")
@field_validator('to')
@classmethod
def validate_email(cls, v):
pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
if not re.match(pattern, v):
raise ValueError('无效的邮箱地址')
return v
@field_validator('subject')
@classmethod
def validate_subject(cls, v):
if len(v) > 100:
raise ValueError('邮件主题不能超过100个字符')
return v
@tool(args_schema=EmailInput)
def send_email(to: str, subject: str, body: str) -> str:
"""发送电子邮件"""
# 实际发送邮件的逻辑
return f"邮件已发送至 {to}"
使用 StructuredTool
对于更复杂的场景,可以使用 StructuredTool:
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
class CalculatorInput(BaseModel):
"""计算器输入参数"""
expression: str = Field(description="数学表达式,如 '2 + 3 * 4'")
precision: int = Field(default=2, description="结果的小数位数")
def calculator_func(expression: str, precision: int = 2) -> str:
"""计算数学表达式"""
try:
result = eval(expression)
return f"{result:.{precision}f}"
except Exception as e:
return f"计算错误:{e}"
calculator = StructuredTool(
name="calculator",
description="计算数学表达式,支持加减乘除等基本运算",
func=calculator_func,
args_schema=CalculatorInput
)
异步工具
对于 I/O 密集型操作,可以使用异步工具:
@tool
async def async_search(query: str) -> str:
"""异步搜索互联网"""
import aiohttp
async with aiohttp.ClientSession() as session:
async with session.get(f"https://api.example.com/search?q={query}") as response:
data = await response.json()
return data.get("result", "未找到结果")
工具调用流程
当 Agent 决定使用工具时,执行流程如下:
直接调用工具
工具可以像普通函数一样直接调用:
from langchain_core.tools import tool
@tool
def multiply(a: int, b: int) -> int:
"""Multiply two numbers."""
return a * b
# 方式一:普通函数调用
result = multiply.invoke({"a": 42, "b": 7})
print(result) # 输出:294
# 方式二:通过 tool_call 格式调用
tool_call = {
"type": "tool_call",
"id": "1",
"args": {"a": 42, "b": 7}
}
result = multiply.invoke(tool_call)
# 返回 ToolMessage 对象
访问上下文信息
工具有时需要访问运行时上下文,如用户 ID、会话状态等。LangChain 1.0 提供了多种方式实现这一点。
通过 context 参数访问配置
在 LangChain 1.0 中,推荐使用 context 参数传递运行时配置:
from langchain.tools import tool
from langchain.agents import create_agent
@tool
def get_user_info(user_id: str) -> str:
"""获取当前用户信息"""
return f"用户 ID: {user_id}"
agent = create_agent(
model="openai:gpt-4o-mini",
tools=[get_user_info]
)
# 调用时传入配置
result = agent.invoke(
{"messages": [{"role": "user", "content": "获取用户信息"}]},
context={"user_id": "user-123"}
)
通过 Middleware 访问状态
使用 middleware 可以更灵活地管理状态访问:
from langchain.agents import create_agent
from langchain.agents.middleware import AgentMiddleware
from langchain.tools import tool
from typing import TypedDict
class CustomState(TypedDict):
messages: list
user_name: str
@tool
def get_user_name() -> str:
"""从状态中获取用户名"""
# 工具会通过 middleware 获取状态
pass
class StateInjectionMiddleware(AgentMiddleware):
"""状态注入中间件"""
def wrap_tool_call(self, request, handler):
"""包装工具调用,注入状态"""
# 可以在这里访问状态并传递给工具
return (yield handler(request))
agent = create_agent(
model="openai:gpt-4o-mini",
tools=[get_user_name],
state_schema=CustomState,
middleware=[StateInjectionMiddleware()]
)
更新状态
在 LangChain 1.0 中,推荐使用 middleware 来管理状态更新:
from langchain.agents import create_agent, Middleware
from langchain.tools import tool
from langchain.messages import ToolMessage
class StateUpdateMiddleware(AgentMiddleware):
"""状态更新中间件"""
def after_tool_call(self, request, result, state, context):
"""工具调用后更新状态"""
if request["name"] == "update_user_name":
# 更新状态
return {"user_name": result}
return {}
通过 Store 访问长期记忆
使用 get_store() 访问长期记忆存储:
from langchain_core.tools import tool
from langchain_core.runnables import RunnableConfig
from langgraph.config import get_store
from langgraph.prebuilt import create_react_agent
from langgraph.store.memory import InMemoryStore
# 创建存储
store = InMemoryStore()
@tool
def save_memory(content: str, config: RunnableConfig) -> str:
"""保存信息到长期记忆"""
memory_store = get_store()
user_id = config.get("configurable", {}).get("user_id", "default")
memory_store.put(
("memories", user_id), # 命名空间
"memory_1", # 键
{"content": content} # 值
)
return f"已保存:{content}"
@tool
def recall_memories(config: RunnableConfig) -> str:
"""检索长期记忆"""
memory_store = get_store()
user_id = config.get("configurable", {}).get("user_id", "default")
memories = memory_store.search(("memories", user_id))
return "\n".join([m.value["content"] for m in memories])
agent = create_react_agent(
model="openai:gpt-4o-mini",
tools=[save_memory, recall_memories],
store=store
)
常用工具实现
网络搜索工具
from langchain_core.tools import tool
import requests
@tool
def web_search(query: str, num_results: int = 5) -> str:
"""
搜索互联网获取信息。
参数:
- query: 搜索关键词
- num_results: 返回结果数量,默认5条
返回搜索结果的摘要列表。
"""
try:
# 这里使用示例,实际项目中替换为真实的搜索 API
# 如 Google Custom Search、Tavily、SerpAPI 等
return f"搜索 '{query}' 的结果:找到相关信息..."
except Exception as e:
return f"搜索失败:{str(e)}"
数据库查询工具
from langchain_core.tools import tool
import sqlite3
@tool
def query_sqlite(query: str, db_path: str = "database.db") -> str:
"""
执行 SQLite 数据库查询。
参数:
- query: SQL 查询语句(仅支持 SELECT)
- db_path: 数据库文件路径
返回查询结果的 JSON 字符串。
"""
# 安全检查:只允许 SELECT 语句
if not query.strip().upper().startswith("SELECT"):
return "错误:只支持 SELECT 查询"
try:
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
cursor.execute(query)
columns = [description[0] for description in cursor.description]
rows = cursor.fetchall()
results = [dict(zip(columns, row)) for row in rows]
conn.close()
return str(results)
except Exception as e:
return f"查询错误:{str(e)}"
文件操作工具
from langchain_core.tools import tool
import os
@tool
def read_file(file_path: str) -> str:
"""
读取文件内容。
参数:
- file_path: 文件的完整路径
返回文件内容字符串。
"""
try:
if not os.path.exists(file_path):
return f"错误:文件不存在 {file_path}"
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
# 限制返回长度
return content[:10000] if len(content) > 10000 else content
except Exception as e:
return f"读取错误:{str(e)}"
@tool
def write_file(file_path: str, content: str) -> str:
"""
写入文件内容。
参数:
- file_path: 文件的完整路径
- content: 要写入的内容
返回操作结果。
"""
try:
# 确保目录存在
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, 'w', encoding='utf-8') as f:
f.write(content)
return f"成功写入文件:{file_path}"
except Exception as e:
return f"写入错误:{str(e)}"
@tool
def list_directory(dir_path: str) -> str:
"""
列出目录内容。
参数:
- dir_path: 目录路径
返回目录中的文件和子目录列表。
"""
try:
if not os.path.exists(dir_path):
return f"错误:目录不存在 {dir_path}"
items = os.listdir(dir_path)
result = []
for item in items:
full_path = os.path.join(dir_path, item)
if os.path.isdir(full_path):
result.append(f"[目录] {item}")
else:
size = os.path.getsize(full_path)
result.append(f"[文件] {item} ({size} bytes)")
return "\n".join(result)
except Exception as e:
return f"列出目录错误:{str(e)}"
HTTP 请求工具
from langchain_core.tools import tool
import requests
import json
@tool
def http_get(url: str, params: str = "{}") -> str:
"""
发送 HTTP GET 请求。
参数:
- url: 请求的 URL
- params: JSON 格式的查询参数,如 '{"key": "value"}'
返回响应内容。
"""
try:
params_dict = json.loads(params) if params else {}
response = requests.get(url, params=params_dict, timeout=30)
response.raise_for_status()
return response.text[:5000]
except json.JSONDecodeError:
return "错误:params 参数必须是有效的 JSON 格式"
except requests.exceptions.RequestException as e:
return f"请求错误:{str(e)}"
@tool
def http_post(url: str, data: str = "{}", headers: str = "{}") -> str:
"""
发送 HTTP POST 请求。
参数:
- url: 请求的 URL
- data: JSON 格式的请求体
- headers: JSON 格式的请求头
返回响应内容。
"""
try:
data_dict = json.loads(data) if data else {}
headers_dict = json.loads(headers) if headers else {}
response = requests.post(
url,
json=data_dict,
headers=headers_dict,
timeout=30
)
response.raise_for_status()
return response.text[:5000]
except json.JSONDecodeError:
return "错误:data 或 headers 参数必须是有效的 JSON 格式"
except requests.exceptions.RequestException as e:
return f"请求错误:{str(e)}"
工具高级特性
直接返回(return_direct)
使用 return_direct=True 可以让工具结果直接返回给用户,不再进行后续的 LLM 处理:
from langchain_core.tools import tool
@tool(return_direct=True)
def get_current_time() -> str:
"""获取当前时间(直接返回,不经过 LLM 处理)"""
from datetime import datetime
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
# 当调用这个工具时,结果会直接返回给用户
# 而不是再让 LLM 处理
强制工具调用
通过 tool_choice 参数可以强制模型使用特定的工具:
from langgraph.prebuilt import create_react_agent
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
@tool
def greet(user_name: str) -> str:
"""向用户问好"""
return f"你好,{user_name}!"
llm = ChatOpenAI(model="gpt-4o-mini")
tools = [greet]
# 强制使用 greet 工具
agent = create_react_agent(
model=llm.bind_tools(tools, tool_choice={"type": "tool", "name": "greet"}),
tools=tools
)
禁用并行工具调用
某些场景下需要禁用并行工具调用:
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent
llm = ChatOpenAI(model="gpt-4o-mini")
# 禁用并行工具调用
agent = create_react_agent(
model=llm.bind_tools(tools, parallel_tool_calls=False),
tools=tools
)
错误处理
使用 ToolNode 处理错误
ToolNode 提供了内置的错误处理机制:
from langgraph.prebuilt import ToolNode
from langchain_core.tools import tool
@tool
def risky_operation(param: str) -> str:
"""可能失败的操作"""
if param == "error":
raise ValueError("操作失败")
return f"成功:{param}"
# 默认错误处理(捕获错误并返回错误消息)
tool_node = ToolNode([risky_operation])
# 自定义错误消息
tool_node = ToolNode(
[risky_operation],
handle_tool_errors="操作失败,请检查参数后重试。"
)
# 禁用错误处理(直接抛出异常)
tool_node = ToolNode([risky_operation], handle_tool_errors=False)
自定义错误处理函数
def custom_error_handler(error: Exception, tool_call: dict) -> str:
"""自定义错误处理函数"""
tool_name = tool_call.get("name", "unknown")
return f"工具 {tool_name} 执行失败:{str(error)}。请尝试其他方法。"
tool_node = ToolNode(
[risky_operation],
handle_tool_errors=custom_error_handler
)
工具安全
输入验证
使用 Pydantic 进行严格的输入验证:
from langchain_core.tools import tool
from pydantic import BaseModel, Field, field_validator
import re
class SecureInput(BaseModel):
"""安全输入验证"""
command: str = Field(description="要执行的命令")
@field_validator('command')
@classmethod
def validate_command(cls, v):
# 禁止危险命令
dangerous_commands = ['rm', 'del', 'format', 'shutdown']
if any(cmd in v.lower() for cmd in dangerous_commands):
raise ValueError('禁止执行危险命令')
return v
@tool(args_schema=SecureInput)
def execute_safe_command(command: str) -> str:
"""执行安全命令"""
# 实际执行逻辑
return f"执行命令:{command}"
敏感信息过滤
过滤工具返回中的敏感信息:
import re
def sanitize_output(output: str) -> str:
"""过滤敏感信息"""
patterns = [
(r'\b\d{16,19}\b', '[银行卡号]'),
(r'\b\d{17}[\dXx]\b', '[身份证号]'),
(r'\b[\w\.-]+@[\w\.-]+\.\w+\b', '[邮箱]'),
(r'\b1[3-9]\d{9}\b', '[手机号]'),
(r'(password|pwd|token|key|secret)["\']?\s*[:=]\s*["\']?[^\s"\']+', '[敏感信息]'),
]
for pattern, replacement in patterns:
output = re.sub(pattern, replacement, output, flags=re.IGNORECASE)
return output
@tool
def query_user_info(user_id: str) -> str:
"""查询用户信息"""
user_data = {
"id": user_id,
"name": "张三",
"phone": "13812345678",
"email": "[email protected]",
}
return sanitize_output(str(user_data))
工具测试
为工具编写单元测试:
import pytest
from unittest.mock import patch, MagicMock
def test_web_search():
"""测试网络搜索工具"""
with patch('requests.get') as mock_get:
mock_response = MagicMock()
mock_response.json.return_value = {
"items": [
{"title": "测试结果", "link": "http://example.com", "snippet": "摘要"}
]
}
mock_get.return_value = mock_response
result = web_search.invoke({"query": "测试"})
assert "测试结果" in result
def test_calculator():
"""测试计算器工具"""
result = calculator.invoke({"expression": "2 + 3 * 4"})
assert "14" in result
def test_calculator_division_by_zero():
"""测试除零错误处理"""
result = divide.invoke({"a": 10, "b": 0})
assert "错误" in result
动态工具选择
当工具数量很多时,可以使用语义搜索动态选择相关工具:
from langchain_core.tools import tool
from langgraph.prebuilt import create_react_agent
@tool
def get_weather(city: str) -> str:
"""获取天气信息"""
return f"{city}:晴"
@tool
def get_news(topic: str) -> str:
"""获取新闻"""
return f"{topic} 相关新闻"
@tool
def calculate(expression: str) -> str:
"""计算数学表达式"""
return str(eval(expression))
@tool
def translate(text: str, target_lang: str) -> str:
"""翻译文本"""
return f"[{target_lang}] {text}"
# 所有可用工具
all_tools = [get_weather, get_news, calculate, translate]
agent = create_react_agent(
model="openai:gpt-4o-mini",
tools=all_tools
)
# 模型会根据用户输入自动选择合适的工具
使用预构建工具
LLM 提供商工具
一些 LLM 提供商提供了预构建的工具,可以直接使用:
from langgraph.prebuilt import create_react_agent
# 使用 OpenAI 的 web_search_preview 工具
agent = create_react_agent(
model="openai:gpt-4o-mini",
tools=[{"type": "web_search_preview"}]
)
result = agent.invoke({
"messages": [{"role": "user", "content": "今天有什么科技新闻?"}]
})
LangChain 工具集成
LangChain 提供了大量预构建的工具集成,包括:
- 搜索:Tavily、SerpAPI、Bing
- 代码执行:Python REPL
- 数据库:SQL、MongoDB
- Web 数据:Web scraping
- APIs:OpenWeatherMap、NewsAPI
from langchain_community.tools import TavilySearchResults
# 使用 Tavily 搜索工具
search_tool = TavilySearchResults(max_results=3)
agent = create_react_agent(
model="openai:gpt-4o-mini",
tools=[search_tool]
)
小结
工具系统是 AI Agent 的核心组件,设计良好的工具能够极大地扩展 Agent 的能力:
- 遵循单一职责原则,每个工具只做一件事
- 提供清晰的描述,帮助 LLM 正确使用工具
- 实现健壮的错误处理,返回有意义的错误信息
- 使用 InjectedState 和 Store 访问上下文信息
- 注意安全性,验证输入、过滤敏感信息
- 为工具编写测试,确保功能正确可靠
- 使用 ToolNode 的 handle_tool_errors 参数处理工具错误
下一章我们将学习记忆系统的设计与实现。