跳到主要内容

工具系统

工具是 AI Agent 与外部世界交互的桥梁。一个设计良好的工具系统能够让 Agent 突破 LLM 自身的限制,完成各种复杂任务。本章将深入介绍工具的设计原则、实现方式和最佳实践。

工具的本质

工具本质上是一个可以被 LLM 调用的函数。它接收结构化的参数输入,执行特定操作,返回结果。LLM 根据用户的请求和工具的描述,决定何时调用哪个工具,以及传递什么参数。

一个完整的工具定义包含以下要素:

  • 名称:工具的唯一标识符
  • 描述:说明工具的功能和使用场景
  • 参数模式:定义工具接受的参数类型和结构
  • 执行逻辑:实际执行操作的代码

工具设计原则

单一职责

每个工具应该只做一件事,并且把它做好。这有助于 LLM 正确理解工具的用途,也便于维护和测试。

@tool
def search_web(query: str) -> str:
"""搜索互联网获取信息"""
pass

@tool
def get_weather(city: str) -> str:
"""获取指定城市的天气信息"""
pass

@tool
def send_email(to: str, subject: str, body: str) -> str:
"""发送电子邮件"""
pass

清晰的描述

工具的描述直接决定了 LLM 能否正确使用它。描述应该包含:

  • 工具的功能说明
  • 适用场景
  • 参数的含义和格式要求
@tool
def query_database(
query: str,
database: str = "main"
) -> str:
"""
执行 SQL 查询并返回结果。

适用场景:
- 需要从数据库获取数据时
- 需要统计或分析数据时

参数说明:
- query: SQL 查询语句,只支持 SELECT 语句
- database: 数据库名称,默认为 main

注意:不支持 INSERT、UPDATE、DELETE 等修改操作
"""
pass

健壮的错误处理

工具应该优雅地处理各种错误情况,返回有意义的错误信息:

@tool
def divide(a: float, b: float) -> str:
"""执行除法运算"""
try:
if b == 0:
return "错误:除数不能为零"
result = a / b
return f"结果:{result}"
except TypeError:
return "错误:参数必须是数字"
except Exception as e:
return f"计算错误:{str(e)}"

LangChain 工具开发

LangChain 提供了多种定义工具的方式。

使用 @tool 装饰器

最简单的方式是使用 @tool 装饰器:

from langchain_core.tools import tool

@tool
def multiply(a: int, b: int) -> int:
"""将两个数字相乘"""
return a * b

@tool
def get_word_length(word: str) -> int:
"""返回单词的长度"""
return len(word)

tools = [multiply, get_word_length]

LangChain 会自动从函数签名和 docstring 中提取工具的描述和参数模式。

使用 StructuredTool

对于更复杂的场景,可以使用 StructuredTool

from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field

class CalculatorInput(BaseModel):
"""计算器输入参数"""
expression: str = Field(description="数学表达式,如 '2 + 3 * 4'")
precision: int = Field(default=2, description="结果的小数位数")

def calculator_func(expression: str, precision: int = 2) -> str:
"""计算数学表达式"""
try:
result = eval(expression)
return f"{result:.{precision}f}"
except Exception as e:
return f"计算错误:{e}"

calculator = StructuredTool(
name="calculator",
description="计算数学表达式,支持加减乘除等基本运算",
func=calculator_func,
args_schema=CalculatorInput
)

异步工具

对于 I/O 密集型操作,可以使用异步工具:

@tool
async def async_search(query: str) -> str:
"""异步搜索互联网"""
import aiohttp

async with aiohttp.ClientSession() as session:
async with session.get(f"https://api.example.com/search?q={query}") as response:
data = await response.json()
return data.get("result", "未找到结果")

常用工具实现

网络搜索工具

from langchain_core.tools import tool
import requests

@tool
def web_search(query: str, num_results: int = 5) -> str:
"""
搜索互联网获取信息。

参数:
- query: 搜索关键词
- num_results: 返回结果数量,默认5条

返回搜索结果的摘要列表。
"""
try:
api_key = "YOUR_API_KEY"
search_engine_id = "YOUR_SEARCH_ENGINE_ID"

url = "https://www.googleapis.com/customsearch/v1"
params = {
"key": api_key,
"cx": search_engine_id,
"q": query,
"num": num_results
}

response = requests.get(url, params=params)
data = response.json()

results = []
for item in data.get("items", []):
results.append({
"title": item.get("title"),
"link": item.get("link"),
"snippet": item.get("snippet")
})

return str(results)
except Exception as e:
return f"搜索失败:{str(e)}"

数据库查询工具

from langchain_core.tools import tool
import sqlite3

@tool
def query_sqlite(query: str, db_path: str = "database.db") -> str:
"""
执行 SQLite 数据库查询。

参数:
- query: SQL 查询语句(仅支持 SELECT)
- db_path: 数据库文件路径

返回查询结果的 JSON 字符串。
"""
if not query.strip().upper().startswith("SELECT"):
return "错误:只支持 SELECT 查询"

try:
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
cursor.execute(query)

columns = [description[0] for description in cursor.description]
rows = cursor.fetchall()

results = [dict(zip(columns, row)) for row in rows]
conn.close()

return str(results)
except Exception as e:
return f"查询错误:{str(e)}"

文件操作工具

from langchain_core.tools import tool
import os

@tool
def read_file(file_path: str) -> str:
"""
读取文件内容。

参数:
- file_path: 文件的完整路径

返回文件内容字符串。
"""
try:
if not os.path.exists(file_path):
return f"错误:文件不存在 {file_path}"

with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()

return content[:10000]
except Exception as e:
return f"读取错误:{str(e)}"

@tool
def write_file(file_path: str, content: str) -> str:
"""
写入文件内容。

参数:
- file_path: 文件的完整路径
- content: 要写入的内容

返回操作结果。
"""
try:
os.makedirs(os.path.dirname(file_path), exist_ok=True)

with open(file_path, 'w', encoding='utf-8') as f:
f.write(content)

return f"成功写入文件:{file_path}"
except Exception as e:
return f"写入错误:{str(e)}"

@tool
def list_directory(dir_path: str) -> str:
"""
列出目录内容。

参数:
- dir_path: 目录路径

返回目录中的文件和子目录列表。
"""
try:
if not os.path.exists(dir_path):
return f"错误:目录不存在 {dir_path}"

items = os.listdir(dir_path)
result = []

for item in items:
full_path = os.path.join(dir_path, item)
if os.path.isdir(full_path):
result.append(f"[目录] {item}")
else:
size = os.path.getsize(full_path)
result.append(f"[文件] {item} ({size} bytes)")

return "\n".join(result)
except Exception as e:
return f"列出目录错误:{str(e)}"

HTTP 请求工具

from langchain_core.tools import tool
import requests
import json

@tool
def http_get(url: str, params: str = "{}") -> str:
"""
发送 HTTP GET 请求。

参数:
- url: 请求的 URL
- params: JSON 格式的查询参数,如 '{"key": "value"}'

返回响应内容。
"""
try:
params_dict = json.loads(params) if params else {}
response = requests.get(url, params=params_dict, timeout=30)
response.raise_for_status()
return response.text[:5000]
except json.JSONDecodeError:
return "错误:params 参数必须是有效的 JSON 格式"
except requests.exceptions.RequestException as e:
return f"请求错误:{str(e)}"

@tool
def http_post(url: str, data: str = "{}", headers: str = "{}") -> str:
"""
发送 HTTP POST 请求。

参数:
- url: 请求的 URL
- data: JSON 格式的请求体
- headers: JSON 格式的请求头

返回响应内容。
"""
try:
data_dict = json.loads(data) if data else {}
headers_dict = json.loads(headers) if headers else {}

response = requests.post(
url,
json=data_dict,
headers=headers_dict,
timeout=30
)
response.raise_for_status()
return response.text[:5000]
except json.JSONDecodeError:
return "错误:data 或 headers 参数必须是有效的 JSON 格式"
except requests.exceptions.RequestException as e:
return f"请求错误:{str(e)}"

工具链与组合

顺序工具链

将多个工具按顺序执行:

from langchain_core.tools import tool

@tool
def fetch_data(url: str) -> str:
"""从 URL 获取数据"""
import requests
response = requests.get(url)
return response.text

@tool
def parse_json(json_str: str) -> str:
"""解析 JSON 字符串"""
import json
data = json.loads(json_str)
return str(data)

@tool
def extract_field(data_str: str, field: str) -> str:
"""从数据中提取指定字段"""
import ast
data = ast.literal_eval(data_str)
return str(data.get(field, "字段不存在"))

def tool_chain(url: str, field: str) -> str:
"""工具链示例:获取数据 -> 解析 -> 提取字段"""
json_data = fetch_data.invoke({"url": url})
parsed_data = parse_json.invoke({"json_str": json_data})
result = extract_field.invoke({"data_str": parsed_data, "field": field})
return result

条件工具选择

根据输入动态选择工具:

from typing import Literal
from langchain_core.tools import tool

@tool
def search_local(query: str) -> str:
"""在本地知识库中搜索"""
return f"本地搜索结果:{query}"

@tool
def search_web(query: str) -> str:
"""在互联网上搜索"""
return f"网络搜索结果:{query}"

def smart_search(query: str, source: Literal["local", "web"] = "web") -> str:
"""智能选择搜索源"""
if source == "local":
return search_local.invoke({"query": query})
else:
return search_web.invoke({"query": query})

工具安全

输入验证

使用 Pydantic 进行严格的输入验证:

from pydantic import BaseModel, Field, validator
from langchain_core.tools import StructuredTool
import re

class EmailInput(BaseModel):
"""邮件输入参数"""
to: str = Field(description="收件人邮箱地址")
subject: str = Field(description="邮件主题")
body: str = Field(description="邮件正文")

@validator('to')
def validate_email(cls, v):
pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
if not re.match(pattern, v):
raise ValueError('无效的邮箱地址')
return v

@validator('subject')
def validate_subject(cls, v):
if len(v) > 100:
raise ValueError('邮件主题不能超过100个字符')
return v

def send_email(to: str, subject: str, body: str) -> str:
return f"邮件已发送至 {to}"

email_tool = StructuredTool(
name="send_email",
description="发送电子邮件",
func=send_email,
args_schema=EmailInput
)

权限控制

限制工具的访问权限:

from functools import wraps

def require_permission(permission: str):
"""权限检查装饰器"""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
current_user = kwargs.get("user", {})
user_permissions = current_user.get("permissions", [])

if permission not in user_permissions:
return f"错误:没有执行此操作的权限(需要 {permission})"

return func(*args, **kwargs)
return wrapper
return decorator

@tool
@require_permission("file:write")
def delete_file(file_path: str, user: dict = None) -> str:
"""删除文件(需要 file:write 权限)"""
import os
try:
os.remove(file_path)
return f"已删除文件:{file_path}"
except Exception as e:
return f"删除失败:{str(e)}"

敏感信息过滤

过滤工具返回中的敏感信息:

import re

def sanitize_output(output: str) -> str:
"""过滤敏感信息"""
patterns = [
(r'\b\d{16,19}\b', '[银行卡号]'),
(r'\b\d{17}[\dXx]\b', '[身份证号]'),
(r'\b[\w\.-]+@[\w\.-]+\.\w+\b', '[邮箱]'),
(r'\b1[3-9]\d{9}\b', '[手机号]'),
(r'(password|pwd|token|key|secret)["\']?\s*[:=]\s*["\']?[^\s"\']+', '[敏感信息]'),
]

for pattern, replacement in patterns:
output = re.sub(pattern, replacement, output, flags=re.IGNORECASE)

return output

@tool
def query_user_info(user_id: str) -> str:
"""查询用户信息"""
user_data = {
"id": user_id,
"name": "张三",
"phone": "13812345678",
"email": "[email protected]",
"id_card": "110101199001011234"
}
return sanitize_output(str(user_data))

工具测试

为工具编写单元测试:

import pytest
from unittest.mock import patch, MagicMock

def test_web_search():
with patch('requests.get') as mock_get:
mock_response = MagicMock()
mock_response.json.return_value = {
"items": [
{"title": "测试结果", "link": "http://example.com", "snippet": "摘要"}
]
}
mock_get.return_value = mock_response

result = web_search.invoke({"query": "测试"})
assert "测试结果" in result

def test_calculator():
result = calculator.invoke({"expression": "2 + 3 * 4"})
assert "14" in result

def test_calculator_division_by_zero():
result = divide.invoke({"a": 10, "b": 0})
assert "错误" in result

小结

工具系统是 AI Agent 的核心组件,设计良好的工具能够极大地扩展 Agent 的能力:

  • 遵循单一职责原则,每个工具只做一件事
  • 提供清晰的描述,帮助 LLM 正确使用工具
  • 实现健壮的错误处理,返回有意义的错误信息
  • 注意安全性,验证输入、控制权限、过滤敏感信息
  • 为工具编写测试,确保功能正确可靠

下一章我们将学习记忆系统的设计与实现。

参考资料