Files
tw2/docs/tutorial/zh_CN/src/task_middleware.py
jimi a842f1861f
Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled
chore: initial import of standalone agentscope project
2026-03-02 18:21:40 +08:00

404 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""
.. _middleware:
中间件
===========================
AgentScope 提供了灵活的中间件系统,允许开发者拦截和修改各种操作的执行。
目前,中间件支持已在 ``Toolkit`` 类中实现,用于**工具执行**。
中间件系统遵循**洋葱模型**,每个中间件包裹在前一个中间件之外,形成层次结构。
这使得开发者可以:
- 在操作前进行**预处理**
- 在执行过程中**拦截和修改**响应
- 在操作完成后进行**后处理**
- 根据条件**跳过**操作执行
.. tip:: 未来版本的 AgentScope 将扩展中间件支持到其他组件,如智能体和模型。
"""
import asyncio
from typing import AsyncGenerator, Callable
from agentscope.message import TextBlock, ToolUseBlock
from agentscope.tool import ToolResponse, Toolkit
# %%
# 工具执行中间件
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# ``Toolkit`` 类通过 ``register_middleware`` 方法支持工具执行的中间件。
# 每个中间件可以拦截工具调用并修改输入或输出。
#
# 中间件签名
# ------------------------------
#
# 中间件函数应具有以下签名:
#
# .. code-block:: python
#
# async def middleware(
# kwargs: dict,
# next_handler: Callable,
# ) -> AsyncGenerator[ToolResponse, None]:
# # 从 kwargs 访问参数
# tool_call = kwargs["tool_call"]
#
# # 预处理
# # ...
#
# # 调用下一个中间件或工具函数
# async for response in await next_handler(**kwargs):
# # 后处理
# yield response
#
# .. list-table:: 中间件参数
# :header-rows: 1
#
# * - 参数
# - 类型
# - 描述
# * - ``kwargs``
# - ``dict``
# - 上下文参数。当前包含 ``tool_call`` (ToolUseBlock)。未来版本可能包含更多参数。
# * - ``next_handler``
# - ``Callable``
# - 一个可调用对象,接受 kwargs dict 并返回产生 AsyncGenerator[ToolResponse] 的协程
# * - **返回值**
# - ``AsyncGenerator[ToolResponse, None]``
# - 产生 ToolResponse 对象的异步生成器
#
# 基本示例
# ------------------------------
#
# 以下是一个记录工具调用的简单中间件:
#
async def logging_middleware(
kwargs: dict,
next_handler: Callable,
) -> AsyncGenerator[ToolResponse, None]:
"""记录工具执行的中间件。"""
# 从 kwargs 访问工具调用
tool_call = kwargs["tool_call"]
# 预处理:在工具执行前记录日志
print(f"[中间件] 调用工具:{tool_call['name']}")
print(f"[中间件] 输入:{tool_call['input']}")
# 调用下一个处理器(另一个中间件或实际工具)
async for response in await next_handler(**kwargs):
# 后处理:记录响应
print(f"[中间件] 响应:{response.content[0]['text']}")
yield response
# 在所有响应产生后执行
print(f"[中间件] 工具 {tool_call['name']} 完成")
# %%
# 让我们将这个中间件注册到工具包并测试它:
#
async def search_tool(query: str) -> ToolResponse:
"""一个简单的搜索工具。
Args:
query (`str`):
搜索查询。
Returns:
`ToolResponse`:
搜索结果。
"""
return ToolResponse(
content=[
TextBlock(
type="text",
text=f"'{query}' 的搜索结果",
),
],
)
async def example_logging_middleware() -> None:
"""使用日志中间件的示例。"""
# 创建工具包并注册工具
toolkit = Toolkit()
toolkit.register_tool_function(search_tool)
# 注册中间件
toolkit.register_middleware(logging_middleware)
# 调用工具
result = await toolkit.call_tool_function(
ToolUseBlock(
type="tool_use",
id="1",
name="search_tool",
input={"query": "AgentScope"},
),
)
async for response in result:
print(f"\n[最终] {response.content[0]['text']}\n")
print("=" * 60)
print("示例 1日志中间件")
print("=" * 60)
asyncio.run(example_logging_middleware())
# %%
# 修改输入和输出
# ------------------------------
#
# 中间件还可以修改工具调用的输入和响应内容:
#
async def transform_middleware(
kwargs: dict,
next_handler: Callable,
) -> AsyncGenerator[ToolResponse, None]:
"""转换输入和输出的中间件。"""
# 从 kwargs 访问工具调用
tool_call = kwargs["tool_call"]
# 预处理:修改输入
original_query = tool_call["input"]["query"]
tool_call["input"]["query"] = f"[已转换] {original_query}"
async for response in await next_handler(**kwargs):
# 后处理:修改响应
original_text = response.content[0]["text"]
response.content[0]["text"] = f"{original_text} [已修改]"
yield response
async def example_transform_middleware() -> None:
"""转换中间件的示例。"""
toolkit = Toolkit()
toolkit.register_tool_function(search_tool)
toolkit.register_middleware(transform_middleware)
result = await toolkit.call_tool_function(
ToolUseBlock(
type="tool_use",
id="2",
name="search_tool",
input={"query": "中间件"},
),
)
async for response in result:
print(f"结果:{response.content[0]['text']}")
print("\n" + "=" * 60)
print("示例 2转换中间件")
print("=" * 60)
asyncio.run(example_transform_middleware())
# %%
# 授权中间件
# ------------------------------
#
# 可以使用中间件实现授权检查,如果未授权则跳过工具执行:
#
async def authorization_middleware(
kwargs: dict,
next_handler: Callable,
) -> AsyncGenerator[ToolResponse, None]:
"""检查授权的中间件。"""
# 从 kwargs 访问工具调用
tool_call = kwargs["tool_call"]
# 检查工具是否已授权(简单示例)
authorized_tools = {"search_tool"}
if tool_call["name"] not in authorized_tools:
# 跳过执行并直接返回错误
print(f"[授权] 工具 {tool_call['name']} 未授权")
yield ToolResponse(
content=[
TextBlock(
type="text",
text=f"错误:工具 '{tool_call['name']}' 未授权",
),
],
)
return
# 工具已授权,继续执行
print(f"[授权] 工具 {tool_call['name']} 已授权")
async for response in await next_handler(**kwargs):
yield response
async def unauthorized_tool(data: str) -> ToolResponse:
"""一个未授权的工具。
Args:
data (`str`):
一些数据。
Returns:
`ToolResponse`:
结果。
"""
return ToolResponse(
content=[TextBlock(type="text", text=f"处理 {data}")],
)
async def example_authorization_middleware() -> None:
"""授权中间件的示例。"""
toolkit = Toolkit()
toolkit.register_tool_function(search_tool)
toolkit.register_tool_function(unauthorized_tool)
toolkit.register_middleware(authorization_middleware)
# 尝试授权的工具
print("\n调用已授权的工具:")
result = await toolkit.call_tool_function(
ToolUseBlock(
type="tool_use",
id="3",
name="search_tool",
input={"query": "测试"},
),
)
async for response in result:
print(f"结果:{response.content[0]['text']}")
# 尝试未授权的工具
print("\n调用未授权的工具:")
result = await toolkit.call_tool_function(
ToolUseBlock(
type="tool_use",
id="4",
name="unauthorized_tool",
input={"data": "测试"},
),
)
async for response in result:
print(f"结果:{response.content[0]['text']}")
print("\n" + "=" * 60)
print("示例 3授权中间件")
print("=" * 60)
asyncio.run(example_authorization_middleware())
# %%
# 多个中间件(洋葱模型)
# ------------------------------
#
# 当注册多个中间件时,它们形成类似洋葱的结构。
# 执行顺序遵循洋葱模型:
#
# - **预处理**:按照中间件注册的顺序执行
# - **后处理**:按相反顺序执行(从内到外)
#
# 这是因为实际的工具响应对象会通过中间件链传递,
# 每个中间件都会原地修改它。
#
async def middleware_1(
kwargs: dict,
next_handler: Callable,
) -> AsyncGenerator[ToolResponse, None]:
"""第一个中间件。"""
# 从 kwargs 访问工具调用
tool_call = kwargs["tool_call"]
# 预处理
print("[M1] 预处理")
tool_call["input"]["query"] += " [M1]"
async for response in await next_handler(**kwargs):
# 后处理
response.content[0]["text"] += " [M1]"
print("[M1] 后处理")
yield response
async def middleware_2(
kwargs: dict,
next_handler: Callable,
) -> AsyncGenerator[ToolResponse, None]:
"""第二个中间件。"""
# 从 kwargs 访问工具调用
tool_call = kwargs["tool_call"]
# 预处理
print("[M2] 预处理")
tool_call["input"]["query"] += " [M2]"
async for response in await next_handler(**kwargs):
# 后处理
response.content[0]["text"] += " [M2]"
print("[M2] 后处理")
yield response
async def example_multiple_middleware() -> None:
"""多个中间件的示例。"""
toolkit = Toolkit()
toolkit.register_tool_function(search_tool)
# 按顺序注册中间件
toolkit.register_middleware(middleware_1)
toolkit.register_middleware(middleware_2)
result = await toolkit.call_tool_function(
ToolUseBlock(
type="tool_use",
id="5",
name="search_tool",
input={"query": "测试"},
),
)
async for response in result:
print(f"\n最终结果:{response.content[0]['text']}")
print("\n" + "=" * 60)
print("示例 4多个中间件洋葱模型")
print("=" * 60)
print("\n执行流程:")
print("M1 预处理 → M2 预处理 → 工具 → M2 后处理 → M1 后处理")
print()
asyncio.run(example_multiple_middleware())
# %%
# 使用场景
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# 中间件系统适用于各种场景:
#
# - **日志和监控**:跟踪工具使用情况和性能
# - **授权**:控制对特定工具的访问
# - **速率限制**:限制工具调用的频率
# - **缓存**:缓存重复调用的工具响应
# - **错误处理**:添加重试逻辑或优雅降级
# - **输入验证**:验证和清理工具输入
# - **输出转换**:格式化或过滤工具输出
# - **指标收集**:收集有关工具使用情况的统计信息
#
# .. note::
# - 中间件按注册顺序应用
# - 同一个 ``ToolResponse`` 对象通过中间件链传递并原地修改
# - 中间件可以通过不调用 ``next_handler`` 来完全跳过工具执行
# - 所有中间件必须是产生 ``ToolResponse`` 对象的异步生成器函数