Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled
404 lines
11 KiB
Python
404 lines
11 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
.. _middleware:
|
||
|
||
中间件
|
||
===========================
|
||
|
||
AgentScope 提供了灵活的中间件系统,允许开发者拦截和修改各种操作的执行。
|
||
目前,中间件支持已在 ``Toolkit`` 类中实现,用于**工具执行**。
|
||
|
||
中间件系统遵循**洋葱模型**,每个中间件包裹在前一个中间件之外,形成层次结构。
|
||
这使得开发者可以:
|
||
|
||
- 在操作前进行**预处理**
|
||
- 在执行过程中**拦截和修改**响应
|
||
- 在操作完成后进行**后处理**
|
||
- 根据条件**跳过**操作执行
|
||
|
||
.. tip:: 未来版本的 AgentScope 将扩展中间件支持到其他组件,如智能体和模型。
|
||
|
||
"""
|
||
import asyncio
|
||
from typing import AsyncGenerator, Callable
|
||
|
||
from agentscope.message import TextBlock, ToolUseBlock
|
||
from agentscope.tool import ToolResponse, Toolkit
|
||
|
||
|
||
# %%
|
||
# 工具执行中间件
|
||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||
#
|
||
# ``Toolkit`` 类通过 ``register_middleware`` 方法支持工具执行的中间件。
|
||
# 每个中间件可以拦截工具调用并修改输入或输出。
|
||
#
|
||
# 中间件签名
|
||
# ------------------------------
|
||
#
|
||
# 中间件函数应具有以下签名:
|
||
#
|
||
# .. code-block:: python
|
||
#
|
||
# async def middleware(
|
||
# kwargs: dict,
|
||
# next_handler: Callable,
|
||
# ) -> AsyncGenerator[ToolResponse, None]:
|
||
# # 从 kwargs 访问参数
|
||
# tool_call = kwargs["tool_call"]
|
||
#
|
||
# # 预处理
|
||
# # ...
|
||
#
|
||
# # 调用下一个中间件或工具函数
|
||
# async for response in await next_handler(**kwargs):
|
||
# # 后处理
|
||
# yield response
|
||
#
|
||
# .. list-table:: 中间件参数
|
||
# :header-rows: 1
|
||
#
|
||
# * - 参数
|
||
# - 类型
|
||
# - 描述
|
||
# * - ``kwargs``
|
||
# - ``dict``
|
||
# - 上下文参数。当前包含 ``tool_call`` (ToolUseBlock)。未来版本可能包含更多参数。
|
||
# * - ``next_handler``
|
||
# - ``Callable``
|
||
# - 一个可调用对象,接受 kwargs dict 并返回产生 AsyncGenerator[ToolResponse] 的协程
|
||
# * - **返回值**
|
||
# - ``AsyncGenerator[ToolResponse, None]``
|
||
# - 产生 ToolResponse 对象的异步生成器
|
||
#
|
||
# 基本示例
|
||
# ------------------------------
|
||
#
|
||
# 以下是一个记录工具调用的简单中间件:
|
||
#
|
||
|
||
|
||
async def logging_middleware(
|
||
kwargs: dict,
|
||
next_handler: Callable,
|
||
) -> AsyncGenerator[ToolResponse, None]:
|
||
"""记录工具执行的中间件。"""
|
||
# 从 kwargs 访问工具调用
|
||
tool_call = kwargs["tool_call"]
|
||
|
||
# 预处理:在工具执行前记录日志
|
||
print(f"[中间件] 调用工具:{tool_call['name']}")
|
||
print(f"[中间件] 输入:{tool_call['input']}")
|
||
|
||
# 调用下一个处理器(另一个中间件或实际工具)
|
||
async for response in await next_handler(**kwargs):
|
||
# 后处理:记录响应
|
||
print(f"[中间件] 响应:{response.content[0]['text']}")
|
||
yield response
|
||
|
||
# 在所有响应产生后执行
|
||
print(f"[中间件] 工具 {tool_call['name']} 完成")
|
||
|
||
|
||
# %%
|
||
# 让我们将这个中间件注册到工具包并测试它:
|
||
#
|
||
|
||
|
||
async def search_tool(query: str) -> ToolResponse:
|
||
"""一个简单的搜索工具。
|
||
|
||
Args:
|
||
query (`str`):
|
||
搜索查询。
|
||
|
||
Returns:
|
||
`ToolResponse`:
|
||
搜索结果。
|
||
"""
|
||
return ToolResponse(
|
||
content=[
|
||
TextBlock(
|
||
type="text",
|
||
text=f"'{query}' 的搜索结果",
|
||
),
|
||
],
|
||
)
|
||
|
||
|
||
async def example_logging_middleware() -> None:
|
||
"""使用日志中间件的示例。"""
|
||
# 创建工具包并注册工具
|
||
toolkit = Toolkit()
|
||
toolkit.register_tool_function(search_tool)
|
||
|
||
# 注册中间件
|
||
toolkit.register_middleware(logging_middleware)
|
||
|
||
# 调用工具
|
||
result = await toolkit.call_tool_function(
|
||
ToolUseBlock(
|
||
type="tool_use",
|
||
id="1",
|
||
name="search_tool",
|
||
input={"query": "AgentScope"},
|
||
),
|
||
)
|
||
|
||
async for response in result:
|
||
print(f"\n[最终] {response.content[0]['text']}\n")
|
||
|
||
|
||
print("=" * 60)
|
||
print("示例 1:日志中间件")
|
||
print("=" * 60)
|
||
asyncio.run(example_logging_middleware())
|
||
|
||
# %%
|
||
# 修改输入和输出
|
||
# ------------------------------
|
||
#
|
||
# 中间件还可以修改工具调用的输入和响应内容:
|
||
#
|
||
|
||
|
||
async def transform_middleware(
|
||
kwargs: dict,
|
||
next_handler: Callable,
|
||
) -> AsyncGenerator[ToolResponse, None]:
|
||
"""转换输入和输出的中间件。"""
|
||
# 从 kwargs 访问工具调用
|
||
tool_call = kwargs["tool_call"]
|
||
|
||
# 预处理:修改输入
|
||
original_query = tool_call["input"]["query"]
|
||
tool_call["input"]["query"] = f"[已转换] {original_query}"
|
||
|
||
async for response in await next_handler(**kwargs):
|
||
# 后处理:修改响应
|
||
original_text = response.content[0]["text"]
|
||
response.content[0]["text"] = f"{original_text} [已修改]"
|
||
yield response
|
||
|
||
|
||
async def example_transform_middleware() -> None:
|
||
"""转换中间件的示例。"""
|
||
toolkit = Toolkit()
|
||
toolkit.register_tool_function(search_tool)
|
||
toolkit.register_middleware(transform_middleware)
|
||
|
||
result = await toolkit.call_tool_function(
|
||
ToolUseBlock(
|
||
type="tool_use",
|
||
id="2",
|
||
name="search_tool",
|
||
input={"query": "中间件"},
|
||
),
|
||
)
|
||
|
||
async for response in result:
|
||
print(f"结果:{response.content[0]['text']}")
|
||
|
||
|
||
print("\n" + "=" * 60)
|
||
print("示例 2:转换中间件")
|
||
print("=" * 60)
|
||
asyncio.run(example_transform_middleware())
|
||
|
||
# %%
|
||
# 授权中间件
|
||
# ------------------------------
|
||
#
|
||
# 可以使用中间件实现授权检查,如果未授权则跳过工具执行:
|
||
#
|
||
|
||
|
||
async def authorization_middleware(
|
||
kwargs: dict,
|
||
next_handler: Callable,
|
||
) -> AsyncGenerator[ToolResponse, None]:
|
||
"""检查授权的中间件。"""
|
||
# 从 kwargs 访问工具调用
|
||
tool_call = kwargs["tool_call"]
|
||
|
||
# 检查工具是否已授权(简单示例)
|
||
authorized_tools = {"search_tool"}
|
||
|
||
if tool_call["name"] not in authorized_tools:
|
||
# 跳过执行并直接返回错误
|
||
print(f"[授权] 工具 {tool_call['name']} 未授权")
|
||
yield ToolResponse(
|
||
content=[
|
||
TextBlock(
|
||
type="text",
|
||
text=f"错误:工具 '{tool_call['name']}' 未授权",
|
||
),
|
||
],
|
||
)
|
||
return
|
||
|
||
# 工具已授权,继续执行
|
||
print(f"[授权] 工具 {tool_call['name']} 已授权")
|
||
async for response in await next_handler(**kwargs):
|
||
yield response
|
||
|
||
|
||
async def unauthorized_tool(data: str) -> ToolResponse:
|
||
"""一个未授权的工具。
|
||
|
||
Args:
|
||
data (`str`):
|
||
一些数据。
|
||
|
||
Returns:
|
||
`ToolResponse`:
|
||
结果。
|
||
"""
|
||
return ToolResponse(
|
||
content=[TextBlock(type="text", text=f"处理 {data}")],
|
||
)
|
||
|
||
|
||
async def example_authorization_middleware() -> None:
|
||
"""授权中间件的示例。"""
|
||
toolkit = Toolkit()
|
||
toolkit.register_tool_function(search_tool)
|
||
toolkit.register_tool_function(unauthorized_tool)
|
||
toolkit.register_middleware(authorization_middleware)
|
||
|
||
# 尝试授权的工具
|
||
print("\n调用已授权的工具:")
|
||
result = await toolkit.call_tool_function(
|
||
ToolUseBlock(
|
||
type="tool_use",
|
||
id="3",
|
||
name="search_tool",
|
||
input={"query": "测试"},
|
||
),
|
||
)
|
||
async for response in result:
|
||
print(f"结果:{response.content[0]['text']}")
|
||
|
||
# 尝试未授权的工具
|
||
print("\n调用未授权的工具:")
|
||
result = await toolkit.call_tool_function(
|
||
ToolUseBlock(
|
||
type="tool_use",
|
||
id="4",
|
||
name="unauthorized_tool",
|
||
input={"data": "测试"},
|
||
),
|
||
)
|
||
async for response in result:
|
||
print(f"结果:{response.content[0]['text']}")
|
||
|
||
|
||
print("\n" + "=" * 60)
|
||
print("示例 3:授权中间件")
|
||
print("=" * 60)
|
||
asyncio.run(example_authorization_middleware())
|
||
|
||
# %%
|
||
# 多个中间件(洋葱模型)
|
||
# ------------------------------
|
||
#
|
||
# 当注册多个中间件时,它们形成类似洋葱的结构。
|
||
# 执行顺序遵循洋葱模型:
|
||
#
|
||
# - **预处理**:按照中间件注册的顺序执行
|
||
# - **后处理**:按相反顺序执行(从内到外)
|
||
#
|
||
# 这是因为实际的工具响应对象会通过中间件链传递,
|
||
# 每个中间件都会原地修改它。
|
||
#
|
||
|
||
|
||
async def middleware_1(
|
||
kwargs: dict,
|
||
next_handler: Callable,
|
||
) -> AsyncGenerator[ToolResponse, None]:
|
||
"""第一个中间件。"""
|
||
# 从 kwargs 访问工具调用
|
||
tool_call = kwargs["tool_call"]
|
||
|
||
# 预处理
|
||
print("[M1] 预处理")
|
||
tool_call["input"]["query"] += " [M1]"
|
||
|
||
async for response in await next_handler(**kwargs):
|
||
# 后处理
|
||
response.content[0]["text"] += " [M1]"
|
||
print("[M1] 后处理")
|
||
yield response
|
||
|
||
|
||
async def middleware_2(
|
||
kwargs: dict,
|
||
next_handler: Callable,
|
||
) -> AsyncGenerator[ToolResponse, None]:
|
||
"""第二个中间件。"""
|
||
# 从 kwargs 访问工具调用
|
||
tool_call = kwargs["tool_call"]
|
||
|
||
# 预处理
|
||
print("[M2] 预处理")
|
||
tool_call["input"]["query"] += " [M2]"
|
||
|
||
async for response in await next_handler(**kwargs):
|
||
# 后处理
|
||
response.content[0]["text"] += " [M2]"
|
||
print("[M2] 后处理")
|
||
yield response
|
||
|
||
|
||
async def example_multiple_middleware() -> None:
|
||
"""多个中间件的示例。"""
|
||
toolkit = Toolkit()
|
||
toolkit.register_tool_function(search_tool)
|
||
|
||
# 按顺序注册中间件
|
||
toolkit.register_middleware(middleware_1)
|
||
toolkit.register_middleware(middleware_2)
|
||
|
||
result = await toolkit.call_tool_function(
|
||
ToolUseBlock(
|
||
type="tool_use",
|
||
id="5",
|
||
name="search_tool",
|
||
input={"query": "测试"},
|
||
),
|
||
)
|
||
|
||
async for response in result:
|
||
print(f"\n最终结果:{response.content[0]['text']}")
|
||
|
||
|
||
print("\n" + "=" * 60)
|
||
print("示例 4:多个中间件(洋葱模型)")
|
||
print("=" * 60)
|
||
print("\n执行流程:")
|
||
print("M1 预处理 → M2 预处理 → 工具 → M2 后处理 → M1 后处理")
|
||
print()
|
||
asyncio.run(example_multiple_middleware())
|
||
|
||
# %%
|
||
# 使用场景
|
||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||
#
|
||
# 中间件系统适用于各种场景:
|
||
#
|
||
# - **日志和监控**:跟踪工具使用情况和性能
|
||
# - **授权**:控制对特定工具的访问
|
||
# - **速率限制**:限制工具调用的频率
|
||
# - **缓存**:缓存重复调用的工具响应
|
||
# - **错误处理**:添加重试逻辑或优雅降级
|
||
# - **输入验证**:验证和清理工具输入
|
||
# - **输出转换**:格式化或过滤工具输出
|
||
# - **指标收集**:收集有关工具使用情况的统计信息
|
||
#
|
||
# .. note::
|
||
# - 中间件按注册顺序应用
|
||
# - 同一个 ``ToolResponse`` 对象通过中间件链传递并原地修改
|
||
# - 中间件可以通过不调用 ``next_handler`` 来完全跳过工具执行
|
||
# - 所有中间件必须是产生 ``ToolResponse`` 对象的异步生成器函数
|