中间件集成工具使用
LangChain 1.0 提供了丰富的内置中间件,用于解决常见的 Agent 开发问题。
before_model (模型调用前)
在模型调用前执行,主要用于输入预处理。
| 中间件 | 作用 |
|---|---|
| SummarizationMiddleware | 上下文压缩,防止 Token 超限 |
| PIIMiddleware | PII 信息脱敏,保护用户隐私 |
| ModelCallLimitMiddleware | 模型调用限制,防止死循环 |
SummarizationMiddleware
自动压缩历史会话,减少 token 使用。
# ==================== SummarizationMiddleware 完整实现 ====================
from langchain.agents import create_agent
from langchain.agents.middleware import SummarizationMiddleware
from langchain_core.tools import tool
from langchain_core.messages import HumanMessage
from langchain_core.runnables import ensure_config
from pydantic import BaseModel, Field
from typing import Optional
from dotenv import load_dotenv
import logging
from src.utils.llm import llm
# ==================== 1. 配置日志 ====================
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
load_dotenv(override=True)
# ==================== 2. 定义工具 ====================
@tool
def search_patent(query: str) -> str:
"""搜索专利数据库"""
return f"专利搜索结果: 找到与 '{query}' 相关的 3 项专利..."
@tool
def analyze_technology(tech_desc: str) -> str:
"""分析技术可行性"""
return f"技术分析: '{tech_desc}' 的实现可行性评估完成..."
tools = [search_patent, analyze_technology]
# ==================== 3. 定义上下文 ====================
class UserContext(BaseModel):
user_id: str = Field(..., description="用户唯一标识")
department: str = Field(..., description="所属部门")
max_history_tokens: Optional[int] = Field(default=1000, description="历史消息 token 阈值")
# ==================== 4. 配置中间件 ====================
summarization_middleware = SummarizationMiddleware(
model=llm,
max_tokens_before_summary=50, # 历史消息 token 数量超过 200 时触发压缩
messages_to_keep=5, # 保留最近 5 条消息
summary_prompt="请将以下对话历史进行摘要,保留关键决策点和技术细节:\n\n{messages}\n\n摘要:" # 摘要提示词
)
# ==================== 5. 创建 Agent ====================
agent = create_agent(
model=llm,
tools=tools,
middleware=[summarization_middleware],
context_schema=UserContext,
debug=True,
)
# ==================== 6. 执行测试 ====================
def run_summarization_test():
logger.info("开始 SummarizationMiddleware 测试")
# 创建长对话历史
long_history = [HumanMessage(content=f"问题 {i+1}: 如何评估某项技术的专利风险?") for i in range(20)]
logger.info(f"创建了 {len(long_history)} 条消息")
# 执行
result = agent.invoke(
{"messages": long_history},
context=UserContext(user_id="engineer_001", department="研发部"),
config=ensure_config({"configurable": {"thread_id": "session_001"}})
)
result_messages = result.get("messages", [])
logger.info(f"执行后消息数: {len(result_messages)}")
if len(result_messages) < len(long_history):
logger.info(f"中间件已触发!压缩了 {len(long_history) - len(result_messages)} 条消息")
return result
# ==================== 7. 运行测试 ====================
result = run_summarization_test()
logger.info("测试完成")
压缩后的消息结构
- 摘要消息 (HumanMessage) - "Here is a summary..."
- 问题 16
- 问题 17
- 问题 18
- 问题 19
- 问题 20
- AI 回复 (AIMessage)
工作流程
- 触发压缩:20 条消息的 token 超过 50 阈值
- 删除旧消息:添加 RemoveMessage(id='remove_all') 标记删除
- 生成摘要:LLM 生成中文摘要,包含关键决策点
- 保留最近 5 条:messages_to_keep=5 生效(问题 16-20)
- 正常回复:基于压缩后的上下文生成回答
PIIMiddleware
核心特性
- 自动PII检测:使用
from langchain.agents.middleware import PIIMiddleware - 智能脱敏:自动识别并处理敏感信息
- 多种策略:支持 block、redact、mask、hash 四种处理策略
- 无缝集成:在模型调用前自动处理,对业务逻辑透明
在模型调用前,中间件会自动:
- 扫描消息内容,识别指定类型的PII信息
- 根据策略处理敏感信息(阻止/脱敏/遮蔽/哈希)
- 将处理后的消息传递给模型
支持的PII类型
- email:电子邮件地址
- credit_card:信用卡号
- ip:IP地址
- mac_address:MAC地址
- url:URL地址
处理策略
- block:阻止包含PII的消息
- redact:完全移除PII信息
- mask:部分遮蔽PII信息
- hash:将PII转换为哈希值
# ========================================
# LangChain 1.0 信用卡PII掩码中间件实战
# ========================================
import os
from typing import Annotated
from langchain.agents import create_agent
from langchain.agents.middleware import PIIMiddleware
from langchain_deepseek import ChatDeepSeek
from langchain_core.tools import tool, BaseTool
from langchain_core.messages import HumanMessage, AIMessage
from pydantic import BaseModel, Field
import re
from dotenv import load_dotenv
# ==================== 1. 加载环境 ====================
load_dotenv(override=True)
# ==================== 2. 定义模拟工具 ====================
@tool
def verify_credit_card(card_number: Annotated[str, "信用卡号"]) -> dict:
"""
验证信用卡号有效性(模拟工具)
注意:实际生产环境中不应接收真实卡号
"""
# 工具接收到的参数已经是掩码后的
print(f"工具接收到的卡号: {card_number}")
# 模拟验证逻辑
if len(card_number) >= 16: # 掩码后的长度也足够判断
return {
"is_valid": True,
"card_type": "Visa",
"masked_card": card_number
}
return {"is_valid": False}
@tool
def process_payment(card_number: str, amount: float) -> str:
"""
处理信用卡支付(模拟工具)
"""
print(f"支付工具接收到的卡号: {card_number}")
return f"支付成功!金额: ${amount}, 卡号: {card_number}"
@tool
def search_user_history(user_id: str) -> str:
"""查询用户历史记录"""
return f"用户 {user_id} 的历史订单:订单123, 订单456"
# 工具列表
tools: list[BaseTool] = [verify_credit_card, process_payment, search_user_history]
# ==================== 3. 定义用户上下文 ====================
class UserContext(BaseModel):
"""用户上下文 Schema"""
user_id: str = Field(..., description="用户唯一标识")
department: str = Field(..., description="所属部门")
security_level: str = Field(default="normal", description="安全级别")
# ==================== 4. 配置 PIIMiddleware ====================
# 核心配置:信用卡掩码中间件
piim_credit_card = PIIMiddleware(
"credit_card",
detector=r"\b(?:\d{4}[-\s]?){3}\d{4}\b", # 匹配格式: 1234-5678-9012-3456
strategy="mask", # 掩码策略
apply_to_input=True, # 对输入消息进行掩码
apply_to_output=False, # 不对工具输出进行掩码(工具返回的是业务结果)
)
# ==================== 5. 创建智能体 ====================
agent = create_agent(
# 主模型:用于决策和对话
model=ChatDeepSeek(model="deepseek-chat", temperature=0.2),
# 工具列表
tools=tools,
# 中间件:只启用PII掩码(生产环境可添加日志等)
middleware=[
piim_credit_card, # 信用卡掩码中间件
],
# 启用上下文
context_schema=UserContext,
# 调试模式
debug=True,
)
# ==================== 6. 测试用例与执行 ====================
def test_credit_card_masking():
"""测试信用卡掩码全流程"""
print("=" * 60)
print("测试场景:用户尝试使用信用卡支付")
print("=" * 60)
# 测试输入:包含多种信用卡格式
test_query = """
请帮我验证以下信用卡是否有效:
我的卡号是 4532-1234-5678-9010,另外备用卡是 4532123456781234。
请检查这两张卡,然后处理一笔 99.99 美元的支付。
"""
print(f"\n【原始用户输入】\n{test_query}\n")
# 执行 Agent
result = agent.invoke(
# 消息列表
{"messages": [HumanMessage(content=test_query)]},
# 上下文(必须)
context=UserContext(
user_id="user_789",
department="财务部",
security_level="high"
),
# 配置(可选)
config={"configurable": {"thread_id": "session_cc_001"}}
)
print("\n【Agent 最终返回的消息】")
final_message = result["messages"][-1]
if isinstance(final_message, AIMessage):
print(f"角色: {final_message.type}")
print(f"内容: {final_message.content}")
# 检查工具调用
if hasattr(final_message, 'tool_calls') and final_message.tool_calls:
print("\n【工具调用记录】")
for tc in final_message.tool_calls:
print(f"- 工具: {tc['name']}")
print(f" 参数: {tc['args']}")
return result
test_credit_card_masking()