Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled
227 lines
6.6 KiB
Python
227 lines
6.6 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""The tool functions used in the planner example."""
|
|
import asyncio
|
|
import json
|
|
import os
|
|
from collections import OrderedDict
|
|
from typing import AsyncGenerator
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
from agentscope.agent import ReActAgent
|
|
from agentscope.formatter import DashScopeChatFormatter
|
|
from agentscope.mcp import HttpStatelessClient, StdIOStatefulClient
|
|
from agentscope.message import Msg, TextBlock
|
|
from agentscope.model import DashScopeChatModel
|
|
from agentscope.pipeline import stream_printing_messages
|
|
from agentscope.tool import (
|
|
ToolResponse,
|
|
Toolkit,
|
|
write_text_file,
|
|
insert_text_file,
|
|
view_text_file,
|
|
)
|
|
|
|
|
|
class ResultModel(BaseModel):
    """
    The result model used for the sub worker to summarize the task result.
    """

    # NOTE(review): pydantic presumably copies the class docstring and the
    # field descriptions into the JSON schema handed to the model, so their
    # wording is intentionally kept exactly as it was.

    # Outcome flag: did the sub-worker complete its task?
    success: bool = Field(
        description="Whether the task was successful or not.",
    )

    # Free-text outcome report: generated file paths, deviations, and any
    # error messages encountered while working on the task.
    message: str = Field(
        description="The specific task result, should include necessary "
        "details, e.g. the file path if any file is generated, the "
        "deviation, and the error message if any.",
    )
|
|
|
|
|
|
def _convert_to_text_block(msgs: list[Msg]) -> list[TextBlock]:
    """Flatten the content of ``msgs`` into a list of text blocks.

    Text blocks are passed through unchanged. Each ``tool_use`` block is
    replaced with a short "Calling tool <name> ..." text block so that the
    sub-worker's progress can be streamed as plain text. Blocks of any other
    type are dropped.

    Args:
        msgs (`list[Msg]`):
            The messages whose content blocks should be converted.

    Returns:
        `list[TextBlock]`:
            The resulting text blocks, in message order.
    """
    converted: list = []
    for message in msgs:
        for content_block in message.get_content_blocks():
            block_type = content_block["type"]
            if block_type == "tool_use":
                # Surface tool invocations as readable progress text
                converted.append(
                    TextBlock(
                        type="text",
                        text=f"Calling tool {content_block['name']} ...",
                    ),
                )
            elif block_type == "text":
                converted.append(content_block)
    return converted
|
|
|
|
|
|
async def _register_gaode_tools(toolkit: Toolkit) -> None:
    """Register the Gaode (AMap) MCP tools when ``GAODE_API_KEY`` is set.

    Args:
        toolkit (`Toolkit`):
            The toolkit that receives the ``amap_tools`` group. When the
            environment variable is missing, a warning is printed and the
            toolkit is left unchanged.
    """
    api_key = os.getenv("GAODE_API_KEY")
    if not api_key:
        print(
            "Warning: GAODE_API_KEY not set in environment, skipping Gaode "
            "MCP client registration.",
        )
        return

    toolkit.create_tool_group(
        group_name="amap_tools",
        description="Map-related tools, including geocoding, routing, and "
        "place search.",
    )
    client = HttpStatelessClient(
        name="amap_mcp",
        transport="streamable_http",
        url=f"https://mcp.amap.com/mcp?key={api_key}",
    )
    await toolkit.register_mcp_client(client, group_name="amap_tools")


async def _register_github_tools(toolkit: Toolkit) -> None:
    """Register the GitHub MCP tools when ``GITHUB_TOKEN`` is set.

    Args:
        toolkit (`Toolkit`):
            The toolkit that receives the ``github_tools`` group. When the
            environment variable is missing, a warning is printed and the
            toolkit is left unchanged.
    """
    token = os.getenv("GITHUB_TOKEN")
    if not token:
        print(
            "Warning: GITHUB_TOKEN not set in environment, skipping GitHub "
            "MCP client registration.",
        )
        return

    toolkit.create_tool_group(
        group_name="github_tools",
        description="GitHub related tools, including repository "
        "search and code file retrieval.",
    )
    github_client = HttpStatelessClient(
        name="github",
        transport="streamable_http",
        url="https://api.githubcopilot.com/mcp/",
        headers={"Authorization": f"Bearer {token}"},
    )
    await toolkit.register_mcp_client(
        github_client,
        group_name="github_tools",
    )


def _create_sub_agent(toolkit: Toolkit) -> ReActAgent:
    """Build the ReAct "Worker" agent with console output disabled.

    Args:
        toolkit (`Toolkit`):
            The fully populated toolkit the worker may use.

    Returns:
        `ReActAgent`:
            The configured worker agent.

    Raises:
        `KeyError`:
            If ``DASHSCOPE_API_KEY`` is not set in the environment.
    """
    sub_agent = ReActAgent(
        name="Worker",
        sys_prompt=f"""You're an agent named Worker.

## Your Target
Your target is to finish the given task with your tools.

## IMPORTANT
You MUST use the `{ReActAgent.finish_function_name}` to generate the final answer after finishing the task.
""",  # noqa: E501  # pylint: disable=C0301
        model=DashScopeChatModel(
            model_name="qwen3-max",
            api_key=os.environ["DASHSCOPE_API_KEY"],
        ),
        enable_meta_tool=True,
        formatter=DashScopeChatFormatter(),
        toolkit=toolkit,
        max_iters=20,
    )

    # The worker's progress is streamed back to the caller as ToolResponse
    # objects, so its own console printing is turned off.
    sub_agent.set_console_output_enabled(False)
    return sub_agent


async def create_worker(
    task_description: str,
) -> AsyncGenerator[ToolResponse, None]:
    """Create a sub-worker to finish the given task.

    Args:
        task_description (`str`):
            The description of the task to be done by the sub-worker, should
            contain all the necessary information.

    Returns:
        `AsyncGenerator[ToolResponse, None]`:
            An async generator yielding ToolResponse objects.
    """
    # NOTE(review): if this function is registered as a tool, the docstring
    # above is likely turned into its schema, so its wording is unchanged.
    # Every ToolResponse yielded carries all progress collected so far
    # (stream=True). The last one (is_last=True) also appends the sub-agent's
    # structured result as JSON. Raises asyncio.CancelledError when a
    # streamed message is flagged as interrupted.
    toolkit = Toolkit()

    # Optional map tools (only when GAODE_API_KEY is configured)
    await _register_gaode_tools(toolkit)

    # Browser tools run in a local `npx @playwright/mcp` process, which has
    # to be closed explicitly once the worker is done.
    toolkit.create_tool_group(
        group_name="browser_tools",
        description="Web browsing related tools.",
    )
    browser_client = StdIOStatefulClient(
        name="playwright-mcp",
        command="npx",
        args=["@playwright/mcp@latest"],
    )
    await browser_client.connect()

    # Everything after a successful connect runs under try/finally. The
    # browser client is therefore closed in four cases: normal completion,
    # errors (e.g. a failing MCP registration or a missing API key), the
    # interruption re-raise below, and a consumer that stops iterating this
    # generator early (aclose()).
    try:
        # Optional GitHub tools (only when GITHUB_TOKEN is configured)
        await _register_github_tools(toolkit)

        # Basic read/write tools
        toolkit.register_tool_function(write_text_file)
        toolkit.register_tool_function(insert_text_file)
        toolkit.register_tool_function(view_text_file)

        sub_agent = _create_sub_agent(toolkit)

        # Streamed messages keyed by id. Re-assigning an existing id replaces
        # the stored message but keeps its original position (streaming
        # presumably re-emits the same id as a message grows).
        msgs = OrderedDict()

        # Receives the sub-agent's final reply. The reply's metadata
        # presumably holds the ResultModel fields.
        result = []

        async def call_sub_agent() -> None:
            msg_res = await sub_agent(
                Msg(
                    "user",
                    content=task_description,
                    role="user",
                ),
                structured_model=ResultModel,
            )
            result.append(msg_res)

        # Use stream_printing_messages to get the streaming response as the
        # sub-agent works
        async for msg, _ in stream_printing_messages(
            agents=[sub_agent],
            coroutine_task=call_sub_agent(),
        ):
            msgs[msg.id] = msg

            yield ToolResponse(
                content=_convert_to_text_block(list(msgs.values())),
                stream=True,
                is_last=False,
            )

            # Expose the interruption signal to the caller
            if msg.metadata and msg.metadata.get("_is_interrupted", False):
                raise asyncio.CancelledError()

        # Final response: the accumulated progress plus the structured result
        if result:
            yield ToolResponse(
                content=[
                    *_convert_to_text_block(list(msgs.values())),
                    TextBlock(
                        type="text",
                        text=json.dumps(
                            result[0].metadata,
                            indent=2,
                            ensure_ascii=False,
                        ),
                    ),
                ],
                stream=True,
                is_last=True,
            )
    finally:
        await browser_client.close()
|