chore: initialize sandbox and overwrite remote content
Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled

This commit is contained in:
codex-bot
2026-03-02 22:32:27 +08:00
commit a64378956a
584 changed files with 93604 additions and 0 deletions

182
src/agentscope/__init__.py Normal file
View File

@@ -0,0 +1,182 @@
# -*- coding: utf-8 -*-
# flake8: noqa: E402
# pylint: disable=wrong-import-position
"""The agentscope serialization module"""
import os
import warnings
from contextvars import ContextVar
from datetime import datetime
import requests
import shortuuid
from ._run_config import _ConfigCls
def _generate_random_suffix(length: int) -> str:
"""Generate a random suffix."""
return shortuuid.uuid()[:length]
# A thread and async safe global configuration instance
_config = _ConfigCls(
run_id=ContextVar("run_id", default=shortuuid.uuid()),
project=ContextVar(
"project",
default="UnnamedProject_At" + datetime.now().strftime("%Y%m%d"),
),
name=ContextVar(
"name",
default=datetime.now().strftime("%H%M%S_")
+ _generate_random_suffix(4),
),
created_at=ContextVar(
"created_at",
default=datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3],
),
trace_enabled=ContextVar(
"trace_enabled",
default=False,
),
)
from . import exception
from . import module
from . import message
from . import model
from . import tool
from . import formatter
from . import memory
from . import agent
from . import session
from . import embedding
from . import token
from . import evaluate
from . import pipeline
from . import tracing
from . import rag
from . import a2a
from . import realtime
from ._logging import (
logger,
setup_logger,
)
from .hooks import _equip_as_studio_hooks
from ._version import __version__
# Raise each warning only once
warnings.filterwarnings("once", category=DeprecationWarning)
def init(
project: str | None = None,
name: str | None = None,
run_id: str | None = None,
logging_path: str | None = None,
logging_level: str = "INFO",
studio_url: str | None = None,
tracing_url: str | None = None,
) -> None:
"""Initialize the agentscope library.
Args:
project (`str | None`, optional):
The project name.
name (`str | None`, optional):
The name of the run.
run_id (`str | None`, optional):
The identity of a running instance, which can be an agent, or a
multi-agent system. The `run_id` is used in AgentScope-Studio to
distinguish different runs.
logging_path (`str | None`, optional):
The path to saving the log file. If not provided, logs will not be
saved.
logging_level (`str | None`, optional):
The logging level. Defaults to "INFO".
studio_url (`str | None`, optional):
The URL of the AgentScope Studio to connect to.
tracing_url (`str | None`, optional):
The URL of the tracing endpoint, which can connect to third-party
OpenTelemetry tracing platforms like Arize-Phoenix and Langfuse.
If not provided and `studio_url` is provided, it will send traces
to the AgentScope Studio's tracing endpoint.
"""
if project:
_config.project = project
if name:
_config.name = name
if run_id:
_config.run_id = run_id
setup_logger(logging_level, logging_path)
if studio_url:
# Register the run
data = {
"id": _config.run_id,
"project": _config.project,
"name": _config.name,
"timestamp": _config.created_at,
"pid": os.getpid(),
"status": "running",
# Deprecated fields
"run_dir": "",
}
response = requests.post(
url=f"{studio_url}/trpc/registerRun",
json=data,
)
response.raise_for_status()
from .agent import UserAgent, StudioUserInput
UserAgent.override_class_input_method(
StudioUserInput(
studio_url=studio_url,
run_id=_config.run_id,
max_retries=3,
),
)
_equip_as_studio_hooks(studio_url)
if tracing_url:
endpoint = tracing_url
else:
endpoint = studio_url.strip("/") + "/v1/traces" if studio_url else None
if endpoint:
from .tracing import setup_tracing
setup_tracing(endpoint=endpoint)
_config.trace_enabled = True
__all__ = [
# modules
"exception",
"module",
"message",
"model",
"tool",
"formatter",
"memory",
"agent",
"session",
"logger",
"embedding",
"token",
"evaluate",
"pipeline",
"tracing",
"rag",
"a2a",
# functions
"init",
"setup_logger",
"__version__",
]

View File

@@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
"""The logger for agentscope."""
import logging
_DEFAULT_FORMAT = (
"%(asctime)s | %(levelname)-7s | "
"%(module)s:%(funcName)s:%(lineno)s - %(message)s"
)
logger = logging.getLogger("as")
def setup_logger(
level: str,
filepath: str | None = None,
) -> None:
"""Set up the agentscope logger.
Args:
level (`str`):
The logging level, chosen from "INFO", "DEBUG", "WARNING",
"ERROR", "CRITICAL".
filepath (`str | None`, optional):
The filepath to save the logging output.
"""
if level not in ["INFO", "DEBUG", "WARNING", "ERROR", "CRITICAL"]:
raise ValueError(
f"Invalid logging level: {level}. Must be one of "
f"'INFO', 'DEBUG', 'WARNING', 'ERROR', 'CRITICAL'.",
)
logger.handlers.clear()
logger.setLevel(level)
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter(_DEFAULT_FORMAT))
logger.addHandler(handler)
if filepath:
handler = logging.FileHandler(filepath)
handler.setFormatter(logging.Formatter(_DEFAULT_FORMAT))
logger.addHandler(handler)
logger.propagate = False
setup_logger("INFO")

View File

@@ -0,0 +1,73 @@
# -*- coding: utf-8 -*-
"""The run instance configuration in agentscope."""
from contextvars import ContextVar
class _ConfigCls:
"""The run instance configuration in agentscope."""
def __init__(
self,
run_id: ContextVar[str],
project: ContextVar[str],
name: ContextVar[str],
created_at: ContextVar[str],
trace_enabled: ContextVar[bool],
) -> None:
"""The constructor for _Config class."""
# Copy the default context variables
self._run_id = run_id
self._created_at = created_at
self._project = project
self._name = name
self._trace_enabled = trace_enabled
@property
def run_id(self) -> str:
"""Get the run ID."""
return self._run_id.get()
@run_id.setter
def run_id(self, value: str) -> None:
"""Set the run ID."""
self._run_id.set(value)
@property
def created_at(self) -> str:
"""Get the creation time."""
return self._created_at.get()
@created_at.setter
def created_at(self, value: str) -> None:
"""Set the creation time."""
self._created_at.set(value)
@property
def project(self) -> str:
"""Get the project name."""
return self._project.get()
@project.setter
def project(self, value: str) -> None:
"""Set the project name."""
self._project.set(value)
@property
def name(self) -> str:
"""Get the run name."""
return self._name.get()
@name.setter
def name(self, value: str) -> None:
"""Set the run name."""
self._name.set(value)
@property
def trace_enabled(self) -> bool:
"""Get whether tracing is enabled."""
return self._trace_enabled.get()
@trace_enabled.setter
def trace_enabled(self, value: bool) -> None:
"""Set whether tracing is enabled."""
self._trace_enabled.set(value)

View File

View File

@@ -0,0 +1,474 @@
# -*- coding: utf-8 -*-
"""The common utilities for agentscope library."""
import asyncio
import base64
import functools
import inspect
import json
import os
import tempfile
import types
import typing
import uuid
from datetime import datetime
from typing import Any, Callable, Type, Dict
import numpy as np
import requests
from docstring_parser import parse
from json_repair import repair_json
from pydantic import BaseModel, Field, create_model, ConfigDict
from .._logging import logger
from ..types import ToolFunction
if typing.TYPE_CHECKING:
from mcp.types import Tool
else:
Tool = "mcp.types.Tool"
def _json_loads_with_repair(
json_str: str,
) -> dict:
"""The given json_str maybe incomplete, e.g. '{"key', so we need to
repair and load it into a Python object.
.. note::
This function is currently only used for parsing the streaming output
of the argument field in `tool_use`, so the parsed result must be a
dict.
Args:
json_str (`str`):
The JSON string to parse, which may be incomplete or malformed.
Returns:
`dict`:
A dictionary parsed from the JSON string after repair attempts.
Returns an empty dict if all repair attempts fail.
"""
try:
repaired = repair_json(json_str, stream_stable=True)
result = json.loads(repaired)
if isinstance(result, dict):
return result
except Exception:
if len(json_str) > 100:
log_str = json_str[:100] + "..."
else:
log_str = json_str
logger.warning(
"Failed to load JSON dict from string: %s. Returning empty dict "
"instead.",
log_str,
)
return {}
def _is_accessible_local_file(url: str) -> bool:
"""Check if the given URL is a local URL."""
return os.path.isfile(url)
def _get_timestamp(add_random_suffix: bool = False) -> str:
"""Get the current timestamp in the format YYYY-MM-DD HH:MM:SS.sss."""
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
if add_random_suffix:
# Add a random suffix to the timestamp
timestamp += f"_{os.urandom(3).hex()}"
return timestamp
async def _is_async_func(func: Callable) -> bool:
"""Check if the given function is an async function, including
coroutine functions, async generators, and coroutine objects.
"""
return (
inspect.iscoroutinefunction(func)
or inspect.isasyncgenfunction(func)
or isinstance(func, types.CoroutineType)
or isinstance(func, types.GeneratorType)
and asyncio.iscoroutine(func)
or isinstance(func, functools.partial)
and await _is_async_func(func.func)
)
async def _execute_async_or_sync_func(
func: Callable,
*args: Any,
**kwargs: Any,
) -> Any:
"""Execute an async or sync function based on its type.
Args:
func (`Callable`):
The function to be executed, which can be either async or sync.
*args (`Any`):
Positional arguments to be passed to the function.
**kwargs (`Any`):
Keyword arguments to be passed to the function.
Returns:
`Any`:
The result of the function execution.
"""
if await _is_async_func(func):
return await func(*args, **kwargs)
return func(*args, **kwargs)
def _get_bytes_from_web_url(
url: str,
max_retries: int = 3,
) -> str:
"""Get the bytes from a given URL.
Args:
url (`str`):
The URL to fetch the bytes from.
max_retries (`int`, defaults to `3`):
The maximum number of retries.
"""
for _ in range(max_retries):
try:
response = requests.get(url)
response.raise_for_status()
return response.content.decode("utf-8")
except UnicodeDecodeError:
return base64.b64encode(response.content).decode("ascii")
except Exception as e:
logger.info(
"Failed to fetch bytes from URL %s. Error %s. Retrying...",
url,
str(e),
)
raise RuntimeError(
f"Failed to fetch bytes from URL `{url}` after {max_retries} retries.",
)
def _save_base64_data(
media_type: str,
base64_data: str,
) -> str:
"""Save the base64 data to a temp file and return the file path. The
extension is guessed from the MIME type.
Args:
media_type (`str`):
The MIME type of the data, e.g. "image/png", "audio/mpeg".
base64_data (`str):
The base64 data to be saved.
"""
extension = "." + media_type.split("/")[-1]
with tempfile.NamedTemporaryFile(
suffix=f".{extension}",
delete=False,
) as temp_file:
decoded_data = base64.b64decode(base64_data)
temp_file.write(decoded_data)
temp_file.close()
return temp_file.name
def _extract_json_schema_from_mcp_tool(tool: Tool) -> dict[str, Any]:
"""Extract JSON schema from MCP tool."""
return {
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": {
"type": "object",
"properties": tool.inputSchema.get(
"properties",
{},
),
"required": tool.inputSchema.get(
"required",
[],
),
},
},
}
def _remove_title_field(schema: dict) -> None:
"""Remove the title field from the JSON schema to avoid
misleading the LLM."""
# The top level title field
if "title" in schema:
schema.pop("title")
# properties
if "properties" in schema:
for prop in schema["properties"].values():
if isinstance(prop, dict):
_remove_title_field(prop)
# items
if "items" in schema and isinstance(schema["items"], dict):
_remove_title_field(schema["items"])
# additionalProperties
if "additionalProperties" in schema and isinstance(
schema["additionalProperties"],
dict,
):
_remove_title_field(
schema["additionalProperties"],
)
def _create_tool_from_base_model(
structured_model: Type[BaseModel],
tool_name: str = "generate_structured_output",
) -> Dict[str, Any]:
"""Create a function tool definition from a Pydantic BaseModel.
This function converts a Pydantic BaseModel class into a tool definition
that can be used with function calling API. The resulting tool
definition includes the model's JSON schema as parameters, enabling
structured output generation by forcing the model to call this function
with properly formatted data.
Args:
structured_model (`Type[BaseModel]`):
A Pydantic BaseModel class that defines the expected structure
for the tool's output.
tool_name (`str`, default `"generate_structured_output"`):
The tool name that used to force the LLM to generate structured
output by calling this function.
Returns:
`Dict[str, Any]`: A tool definition dictionary compatible with
function calling API, containing type ("function") and
function dictionary with name, description, and parameters
(JSON schema).
.. code-block:: python
:caption: Example usage
from pydantic import BaseModel
class PersonInfo(BaseModel):
name: str
age: int
email: str
tool = _create_tool_from_base_model(PersonInfo, "extract_person")
print(tool["function"]["name"]) # extract_person
print(tool["type"]) # function
.. note:: The function automatically removes the 'title' field from
the JSON schema to ensure compatibility with function calling
format. This is handled by the internal ``_remove_title_field()``
function.
"""
schema = structured_model.model_json_schema()
_remove_title_field(schema)
tool_definition = {
"type": "function",
"function": {
"name": tool_name,
"description": "Generate the required structured output with "
"this function",
"parameters": schema,
},
}
return tool_definition
def _map_text_to_uuid(text: str) -> str:
"""Map the given text to a deterministic UUID string.
Args:
text (`str`):
The input text to be mapped to a UUID.
Returns:
`str`:
A deterministic UUID string derived from the input text.
"""
return str(uuid.uuid3(uuid.NAMESPACE_DNS, text))
def _parse_tool_function(
tool_func: ToolFunction,
include_long_description: bool,
include_var_positional: bool,
include_var_keyword: bool,
) -> dict:
"""Extract JSON schema from the tool function's docstring
Args:
tool_func (`ToolFunction`):
The tool function to extract the JSON schema from.
include_long_description (`bool`):
Whether to include the long description in the JSON schema.
include_var_positional (`bool`):
Whether to include variable positional arguments in the JSON
schema.
include_var_keyword (`bool`):
Whether to include variable keyword arguments in the JSON schema.
Returns:
`dict`:
The extracted JSON schema.
"""
docstring = parse(tool_func.__doc__)
params_docstring = {_.arg_name: _.description for _ in docstring.params}
# Function description
descriptions = []
if docstring.short_description is not None:
descriptions.append(docstring.short_description)
if include_long_description and docstring.long_description is not None:
descriptions.append(docstring.long_description)
func_description = "\n".join(descriptions)
# Create a dynamic model with the function signature
fields = {}
for name, param in inspect.signature(tool_func).parameters.items():
# Skip the `self` and `cls` parameters
if name in ["self", "cls"]:
continue
# Handle `**kwargs`
if param.kind == inspect.Parameter.VAR_KEYWORD:
if not include_var_keyword:
continue
fields[name] = (
Dict[str, Any]
if param.annotation == inspect.Parameter.empty
else Dict[str, param.annotation], # type: ignore
Field(
description=params_docstring.get(
f"**{name}",
params_docstring.get(name, None),
),
default={}
if param.default is param.empty
else param.default,
),
)
elif param.kind == inspect.Parameter.VAR_POSITIONAL:
if not include_var_positional:
continue
fields[name] = (
list[Any]
if param.annotation == inspect.Parameter.empty
else list[param.annotation], # type: ignore
Field(
description=params_docstring.get(
f"*{name}",
params_docstring.get(name, None),
),
default=[]
if param.default is param.empty
else param.default,
),
)
else:
fields[name] = (
Any
if param.annotation == inspect.Parameter.empty
else param.annotation,
Field(
description=params_docstring.get(name, None),
default=...
if param.default is param.empty
else param.default,
),
)
base_model = create_model(
"_StructuredOutputDynamicClass",
__config__=ConfigDict(arbitrary_types_allowed=True),
**fields,
)
params_json_schema = base_model.model_json_schema()
# Remove the title from the json schema
_remove_title_field(params_json_schema)
func_json_schema: dict = {
"type": "function",
"function": {
"name": tool_func.__name__,
"parameters": params_json_schema,
},
}
if func_description not in [None, ""]:
func_json_schema["function"]["description"] = func_description
return func_json_schema
def _resample_pcm_delta(
pcm_base64: str,
sample_rate: int,
target_rate: int,
) -> str:
"""Resampling the input pcm base64 data into the target rate.
Args:
pcm_base64 (`str`):
The input base64 audio data in pcm format.
sample_rate (`int`):
The sampling rate of the input data.
target_rate (`int`):
The target rate of the input data.
Returns:
`str`:
The resampling base64 audio data in the required sampling
rate.
"""
pcm_data = base64.b64decode(pcm_base64)
# Into numpy array first
audio_array = np.frombuffer(pcm_data, dtype=np.int16)
# return directly if the same
if sample_rate == target_rate:
return pcm_base64
# compute the number of samples
num_samples = int(len(audio_array) * target_rate / sample_rate)
from scipy import signal
# Use scipy to resample
resampled_audio = signal.resample(audio_array, num_samples)
# Turn it back into bytes
resampled_audio = np.clip(resampled_audio, -32768, 32767).astype(np.int16)
# into base64
resampled_bytes = resampled_audio.tobytes()
resampled_base64 = base64.b64encode(resampled_bytes).decode("utf-8")
return resampled_base64

View File

@@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-
"""The mixin for agentscope."""
class DictMixin(dict):
"""The dictionary mixin that allows attribute-style access."""
__setattr__ = dict.__setitem__
__getattr__ = dict.__getitem__

View File

@@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-
"""The version of agentscope."""
__version__ = "1.0.16"

View File

@@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
"""The A2A related modules."""
from ._base import AgentCardResolverBase
from ._file_resolver import FileAgentCardResolver
from ._well_known_resolver import WellKnownAgentCardResolver
from ._nacos_resolver import NacosAgentCardResolver
__all__ = [
"AgentCardResolverBase",
"FileAgentCardResolver",
"WellKnownAgentCardResolver",
"NacosAgentCardResolver",
]

View File

@@ -0,0 +1,25 @@
# -*- coding: utf-8 -*-
"""The A2A agent card resolver base class."""
from abc import abstractmethod
from typing import Any, TYPE_CHECKING
if TYPE_CHECKING:
from a2a.types import AgentCard
else:
AgentCard = "a2a.types.AgentCard"
class AgentCardResolverBase:
"""Base class for A2A agent card resolvers, responsible for fetching
agent cards from various sources. Implementations must provide the
`get_agent_card` method to retrieve the agent card.
"""
@abstractmethod
async def get_agent_card(self, *args: Any, **kwargs: Any) -> AgentCard:
"""Get Agent Card from the configured source.
Returns:
`AgentCard`:
The resolved agent card object.
"""

View File

@@ -0,0 +1,78 @@
# -*- coding: utf-8 -*-
"""The JSON file based A2A agent card resolver."""
import json
from pathlib import Path
from typing import TYPE_CHECKING
from ._base import AgentCardResolverBase
if TYPE_CHECKING:
from a2a.types import AgentCard
else:
AgentCard = "a2a.types.AgentCard"
class FileAgentCardResolver(AgentCardResolverBase):
"""Agent card resolver that loads AgentCard from a JSON file.
The JSON file should contain an AgentCard object with the following
required fields:
- name (str): The name of the agent
- url (str): The URL of the agent
- version (str): The version of the agent
- capabilities (dict): The capabilities of the agent
- default_input_modes (list[str]): Default input modes
- default_output_modes (list[str]): Default output modes
- skills (list): List of agent skills
Example JSON file content:
.. code-block:: json
{
"name": "RemoteAgent",
"url": "http://localhost:8000",
"description": "A remote A2A agent",
"version": "1.0.0",
"capabilities": {},
"default_input_modes": ["text/plain"],
"default_output_modes": ["text/plain"],
"skills": []
}
"""
def __init__(
self,
file_path: str,
) -> None:
"""Initialize the FileAgentCardResolver with the path to the JSON file.
Args:
file_path (`str`):
The path to the JSON file containing the agent card.
"""
self._file_path = file_path
async def get_agent_card(self) -> AgentCard:
"""Get the agent card from the JSON file.
Returns:
`AgentCard`:
The agent card loaded from the file.
"""
from a2a.types import AgentCard
path = Path(self._file_path)
if not path.exists():
raise FileNotFoundError(
f"Agent card file not found: {self._file_path}",
)
if not path.is_file():
raise ValueError(f"Path is not a file: {self._file_path}")
with path.open("r", encoding="utf-8") as f:
agent_json_data = json.load(f)
return AgentCard.model_validate(agent_json_data)

View File

@@ -0,0 +1,98 @@
# -*- coding: utf-8 -*-
"""The Nacos-based A2A Agent Card resolver."""
from typing import TYPE_CHECKING
from ._base import AgentCardResolverBase
from .._logging import logger
if TYPE_CHECKING:
from a2a.types import AgentCard
from v2.nacos.common.client_config import ClientConfig
else:
AgentCard = "a2a.types.AgentCard"
ClientConfig = "v2.nacos.common.client_config.ClientConfig"
class NacosAgentCardResolver(AgentCardResolverBase):
"""Nacos-based A2A Agent Card resolver.
Nacos is a dynamic service discovery, configuration and service
management platform for building cloud native applications. This resolver
fetches the agent card from a Nacos server and subscribes to updates.
"""
def __init__(
self,
remote_agent_name: str,
nacos_client_config: ClientConfig,
version: str | None = None,
) -> None:
"""Initialize the nacos agent card resolver.
Args:
remote_agent_name (`str`):
Name of the remote agent in Nacos.
nacos_client_config (`ClientConfig | None`, optional):
Nacos client configuration, where a `server_addresses`
parameter is required.
version (`str | None`, optional):
Version of the agent card to fetch. If None, fetches the
latest version. This version is also used when subscribing
to agent card updates.
Defaults to None (latest version).
"""
if not remote_agent_name:
raise ValueError(
"The remote_agent_name cannot be empty.",
)
if not nacos_client_config:
raise ValueError(
"The nacos_client_config cannot be None.",
)
self._nacos_client_config = nacos_client_config
self._remote_agent_name = remote_agent_name
self._version = version
async def get_agent_card(self) -> AgentCard:
"""Get agent card from Nacos with lazy initialization.
Returns:
`AgentCard`:
The resolved agent card from Nacos.
"""
try:
from v2.nacos.ai.model.ai_param import GetAgentCardParam
from v2.nacos.ai.nacos_ai_service import NacosAIService
except ImportError as e:
raise ImportError(
"Please install the nacos sdk by running `pip install "
"nacos-sdk-python>=3.0.0` first.",
) from e
client = None
try:
client = await NacosAIService.create_ai_service(
self._nacos_client_config,
)
await client.start()
return await client.get_agent_card(
GetAgentCardParam(
agent_name=self._remote_agent_name,
version=self._version,
),
)
finally:
if client:
# Close the Nacos client to free resources
try:
await client.shutdown()
except Exception as e:
logger.warning(
"Failed to shutdown Nacos client: %s",
str(e),
)

View File

@@ -0,0 +1,90 @@
# -*- coding: utf-8 -*-
"""The A2A well-known agent card resolver."""
from typing import TYPE_CHECKING
from urllib.parse import urlparse
from ._base import AgentCardResolverBase
from .._logging import logger
if TYPE_CHECKING:
from a2a.types import AgentCard
else:
AgentCard = "a2a.types.AgentCard"
class WellKnownAgentCardResolver(AgentCardResolverBase):
"""Agent card resolver that loads AgentCard from a well-known URL."""
def __init__(
self,
base_url: str,
agent_card_path: str | None = None,
) -> None:
"""Initialize the WellKnownAgentCardResolver.
Args:
base_url (`str`):
The base URL to resolve the agent card from.
agent_card_path (`str | None`, optional):
The path to the agent card relative to the base URL.
Defaults to AGENT_CARD_WELL_KNOWN_PATH from a2a.utils.
"""
self._base_url = base_url
self._agent_card_path = agent_card_path
async def get_agent_card(self) -> AgentCard:
"""Get the agent card from the well-known URL.
Returns:
`AgentCard`:
The agent card loaded from the URL.
"""
import httpx
from a2a.client import A2ACardResolver
from a2a.utils import AGENT_CARD_WELL_KNOWN_PATH
try:
parsed_url = urlparse(self._base_url)
if not parsed_url.scheme or not parsed_url.netloc:
logger.error(
"[%s] Invalid URL format: %s",
self.__class__.__name__,
self._base_url,
)
raise ValueError(
f"Invalid URL format: {self._base_url}",
)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
relative_card_path = parsed_url.path
# Use default path if not specified
agent_card_path = (
self._agent_card_path
if self._agent_card_path is not None
else AGENT_CARD_WELL_KNOWN_PATH
)
# Use async context manager to ensure proper cleanup
async with httpx.AsyncClient(
timeout=httpx.Timeout(timeout=600),
) as _http_client:
resolver = A2ACardResolver(
httpx_client=_http_client,
base_url=base_url,
agent_card_path=agent_card_path,
)
return await resolver.get_agent_card(
relative_card_path=relative_card_path,
)
except Exception as e:
logger.error(
"[%s] Failed to resolve agent card from URL %s: %s",
self.__class__.__name__,
self._base_url,
e,
)
raise RuntimeError(
f"Failed to resolve AgentCard from URL "
f"{self._base_url}: {e}",
) from e

View File

@@ -0,0 +1,28 @@
# -*- coding: utf-8 -*-
"""The agent base class."""
from ._agent_base import AgentBase
from ._react_agent_base import ReActAgentBase
from ._react_agent import ReActAgent
from ._user_input import (
UserInputBase,
UserInputData,
TerminalUserInput,
StudioUserInput,
)
from ._user_agent import UserAgent
from ._a2a_agent import A2AAgent
from ._realtime_agent import RealtimeAgent
__all__ = [
"AgentBase",
"ReActAgentBase",
"ReActAgent",
"UserInputData",
"UserInputBase",
"TerminalUserInput",
"StudioUserInput",
"UserAgent",
"A2AAgent",
"RealtimeAgent",
]

View File

@@ -0,0 +1,288 @@
# -*- coding: utf-8 -*-
"""A2A agent implementation for AgentScope.
This module provides the A2A (Agent-to-Agent) protocol implementation,
enabling AgentScope agents to communicate with remote agents using the
A2A standard protocol.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Type
import httpx
from pydantic import BaseModel
from ._agent_base import AgentBase
from ..message import Msg
from ..formatter import A2AChatFormatter
if TYPE_CHECKING:
from a2a.types import AgentCard
from a2a.client import ClientConfig, Consumer
from a2a.client.client_factory import TransportProducer
else:
AgentCard = "a2a.types.AgentCard"
ClientConfig = "a2a.client.ClientConfig"
Consumer = "a2a.client.Consumer"
TransportProducer = "a2a.client.client_factory.TransportProducer"
class A2AAgent(AgentBase):
"""An A2A agent implementation in AgentScope, which supports
- Communication with remote agents using the A2A protocol
- Bidirectional message conversion between AgentScope and A2A formats
- Task lifecycle management with streaming and polling
- Artifact handling and status tracking
.. note:: Due to the limitation of A2A protocol. The A2AAgent class
- Only support chatbot-scenario (a user and an assistant) interactions.
To support multi-agent interactions requires the server side to handle
the `name` field in the A2A messages properly.
- Does not support structured output in `reply()` method due to the lack
of structured output support in A2A protocol.
- Stores observed messages locally and merges them with input messages
when `reply()` is called. Observed messages are cleared after processing.
"""
def __init__(
self,
agent_card: AgentCard,
client_config: ClientConfig | None = None,
consumers: list[Consumer] | None = None,
additional_transport_producers: dict[str, TransportProducer]
| None = None,
) -> None:
"""Initialize the A2A agent instance by the given agent card.
Args:
agent_card (`AgentCard`):
The agent card that contains the information about the remote
agent, such as its URL and capabilities.
client_config (`ClientConfig | None`, optional):
The configuration for the A2A client, including transport
preferences and streaming options.
consumers (`list[Consumer] | None`, optional):
The list of consumers for handling A2A client events.
These intercept request/response flows for logging,
metrics, and security.
additional_transport_producers (`dict[str, TransportProducer] | \
None`, optional):
Additional transport producers for creating A2A clients
with specific transport protocols.
"""
super().__init__()
from a2a.types import AgentCard
from a2a.client import ClientConfig, ClientFactory
if not isinstance(agent_card, AgentCard):
raise ValueError(
f"agent_card must be an instance of AgentCard, "
f"got {type(agent_card)}",
)
self.name: str = agent_card.name
self.agent_card = agent_card
# Create the client factory so that we can create clients later
# in reply()
self._a2a_client_factory = ClientFactory(
config=client_config
or ClientConfig(
httpx_client=httpx.AsyncClient(
timeout=httpx.Timeout(timeout=600),
),
),
consumers=consumers,
)
# Register additional transport producers if provided
if additional_transport_producers:
for label, producer in additional_transport_producers.items():
self._a2a_client_factory.register(
label,
producer,
)
# The variables to store observed messages
self._observed_msgs: list[Msg] = []
# The formatter used for message conversion
self.formatter = A2AChatFormatter()
def state_dict(self) -> dict:
"""Get the state dictionary of the A2A agent.
Returns:
`dict`:
The state dictionary containing the observed messages.
"""
return {
"_observed_msgs": [msg.to_dict() for msg in self._observed_msgs],
}
def load_state_dict(self, state_dict: dict, strict: bool = True) -> None:
"""Load the state dictionary into the module.
Args:
state_dict (`dict`):
The state dictionary to load.
strict (`bool`, defaults to `True`):
If `True`, raises an error if any key in the module is not
found in the state_dict. If `False`, skips missing keys.
Raises:
`KeyError`:
If a required key is missing in the state_dict when strict
is `True`.
"""
if "_observed_msgs" in state_dict:
self._observed_msgs = [
Msg.from_dict(d) for d in state_dict["_observed_msgs"]
]
else:
raise KeyError(
"_observed_msgs key not found in state_dict",
)
if strict:
for key in state_dict.keys():
if key != "_observed_msgs":
raise KeyError(f"Unexpected key {key} in state_dict")
async def observe(self, msg: Msg | list[Msg] | None) -> None:
"""Receive the given message(s) without generating a reply.
The observed messages are stored and will be merged with the
input messages when `reply` is called. After `reply` completes,
the stored messages will be cleared.
Args:
msg (`Msg | list[Msg] | None`):
The message(s) to be observed. If None, no action is taken.
"""
if msg is None:
return
if isinstance(msg, Msg):
self._observed_msgs.append(msg)
elif isinstance(msg, list) and all(isinstance(m, Msg) for m in msg):
self._observed_msgs.extend(msg)
else:
raise TypeError(
f"msg must be a Msg or a list of Msg, got {type(msg)}",
)
async def reply(
self,
msg: Msg | list[Msg] | None = None,
**kwargs: Any,
) -> Msg:
"""Send message(s) to the remote A2A agent and receive a response.
.. note:: This method merges any previously observed messages with the
input messages, sends them to the remote agent, and clears the
observed messages after processing.
.. note:: The A2A protocol does not support structured output, so the
`structured_model` parameter is not supported in this method.
Args:
msg (`Msg | list[Msg] | None`, optional):
The message(s) to send to the remote agent. Can be a single
Msg, a list of Msgs, or None. If None, only observed messages
will be sent. Defaults to None.
Returns:
`Msg`:
The response message from the remote agent. For tasks, this
may be either a status update message or the final artifacts
message, depending on the task state. If no messages are
provided (both msg and observed messages are empty), returns
a prompt message. If an error occurs during communication,
returns an error message.
"""
if "structured_model" in kwargs:
raise ValueError(
"structured_model is not supported in A2AAgent.reply() "
"due to the lack of structured output support in A2A "
"protocol.",
)
from a2a.types import Message as A2AMessage
# Merge observed messages with input messages
msgs_list = self._observed_msgs
if msg is not None:
if isinstance(msg, Msg):
msgs_list.append(msg)
else:
msgs_list.extend(msg)
# Create A2A client and send message
client = self._a2a_client_factory.create(
card=self.agent_card,
)
# Convert Msg objects into A2A Message object
a2a_message = await self.formatter.format([_ for _ in msgs_list if _])
response_msg = None
async for item in client.send_message(a2a_message):
if isinstance(item, A2AMessage):
response_msg = await self.formatter.format_a2a_message(
self.name,
item,
)
await self.print(response_msg)
elif isinstance(item, tuple):
task, _ = item
if task is not None:
for _ in await self.formatter.format_a2a_task(
self.name,
task,
):
await self.print(_)
response_msg = _
# Clear the observed messages after processing
self._observed_msgs.clear()
if response_msg:
return response_msg
raise ValueError(
"No response received from remote agent",
)
# pylint: disable=unused-argument
async def handle_interrupt(
self,
msg: Msg | list[Msg] | None = None,
structured_model: Type[BaseModel] | None = None,
) -> Msg:
"""The post-processing logic when the reply is interrupted by the
user or something else.
"""
response_msg = Msg(
self.name,
"I noticed that you have interrupted me. What can I "
"do for you?",
"assistant",
metadata={
# Expose this field to indicate the interruption
"_is_interrupted": True,
},
)
await self.print(response_msg, True)
# Add to observed messages for context in next reply
self._observed_msgs.append(response_msg)
return response_msg

View File

@@ -0,0 +1,736 @@
# -*- coding: utf-8 -*-
"""The agent base class in agentscope."""
import asyncio
import io
import json
import os
from asyncio import Task, Queue
from collections import OrderedDict
from copy import deepcopy
from typing import Callable, Any
import base64
import shortuuid
import numpy as np
from typing_extensions import deprecated
from ._agent_meta import _AgentMeta
from .._logging import logger
from ..module import StateModule
from ..message import (
Msg,
AudioBlock,
ToolUseBlock,
ToolResultBlock,
ImageBlock,
VideoBlock,
)
from ..types import AgentHookTypes
class AgentBase(StateModule, metaclass=_AgentMeta):
"""Base class for asynchronous agents."""
id: str
"""The agent's unique identifier, generated using shortuuid."""
supported_hook_types: list[str] = [
"pre_reply",
"post_reply",
"pre_print",
"post_print",
"pre_observe",
"post_observe",
]
"""Supported hook types for the agent base class."""
_class_pre_reply_hooks: dict[
str,
Callable[
[
"AgentBase", # self
dict[str, Any], # kwargs
],
dict[str, Any] | None, # The modified kwargs or None
],
] = OrderedDict()
"""The class-level hook functions that will be called before the reply
function, taking `self` object, the input arguments as input, and
generating the modified arguments (if needed). Then input arguments of the
reply function will be re-organized into a keyword arguments dictionary.
If the one hook returns a new dictionary, the modified arguments will be
passed to the next hook or the original reply function."""
_class_post_reply_hooks: dict[
str,
Callable[
[
"AgentBase", # self
dict[str, Any], # kwargs
Msg, # output, the output message
],
Msg | None,
],
] = OrderedDict()
"""The class-level hook functions that will be called after the reply
function, which takes the `self` object and deep copied
positional and keyword arguments (args and kwargs), and the output message
as input. If the hook returns a message, the new message will be passed
to the next hook or the original reply function. Otherwise, the original
output will be passed instead."""
_class_pre_print_hooks: dict[
str,
Callable[
[
"AgentBase", # self
dict[str, Any], # kwargs
],
dict[str, Any] | None, # The modified kwargs or None
],
] = OrderedDict()
"""The class-level hook functions that will be called before printing,
which takes the `self` object, a deep copied arguments dictionary as input,
and output the modified arguments (if needed). """
_class_post_print_hooks: dict[
str,
Callable[
[
"AgentBase", # self
dict[str, Any], # kwargs
Any, # output, `None` if no output
],
Any,
],
] = OrderedDict()
"""The class-level hook functions that will be called after the speak
function, which takes the `self` object as input."""
_class_pre_observe_hooks: dict[
str,
Callable[
[
"AgentBase", # self
dict[str, Any], # kwargs
],
dict[str, Any] | None, # The modified kwargs or None
],
] = OrderedDict()
"""The class-level hook functions that will be called before the observe
function, which takes the `self` object and a deep copied input
arguments dictionary as input. To change the input arguments, the hook
function needs to output the modified arguments dictionary, which will be
used as the input of the next hook function or the original observe
function."""
_class_post_observe_hooks: dict[
str,
Callable[
[
"AgentBase", # self
dict[str, Any], # kwargs
None, # The output, `None` if no output
],
None,
],
] = OrderedDict()
"""The class-level hook functions that will be called after the observe
function, which takes the `self` object as input."""
def __init__(self) -> None:
"""Initialize the agent."""
super().__init__()
self.id = shortuuid.uuid()
# The replying task and identify of the current replying
self._reply_task: Task | None = None
self._reply_id: str | None = None
# Initialize the instance-level hooks
self._instance_pre_print_hooks = OrderedDict()
self._instance_post_print_hooks = OrderedDict()
self._instance_pre_reply_hooks = OrderedDict()
self._instance_post_reply_hooks = OrderedDict()
self._instance_pre_observe_hooks = OrderedDict()
self._instance_post_observe_hooks = OrderedDict()
# The prefix used in streaming printing, which will save the
# accumulated text and audio streaming data for each message id.
# e.g. {"text": "xxx", "audio": (stream_obj, "{base64_data}")}
self._stream_prefix = {}
# The subscribers that will receive the reply message by their
# `observe` method. The key is the MsgHub id, and the value is the
# list of agents.
self._subscribers: dict[str, list[AgentBase]] = {}
# We add this variable in case developers want to disable the console
# output of the agent, e.g., in a production environment.
self._disable_console_output: bool = (
os.getenv(
"AGENTSCOPE_DISABLE_CONSOLE_OUTPUT",
"false",
).lower()
== "true"
)
# The streaming message queue used to export the messages as a
# generator
self._disable_msg_queue: bool = True
self.msg_queue = None
async def observe(self, msg: Msg | list[Msg] | None) -> None:
"""Receive the given message(s) without generating a reply.
Args:
msg (`Msg | list[Msg] | None`):
The message(s) to be observed.
"""
raise NotImplementedError(
f"The observe function is not implemented in"
f" {self.__class__.__name__} class.",
)
async def reply(self, *args: Any, **kwargs: Any) -> Msg:
"""The main logic of the agent, which generates a reply based on the
current state and input arguments."""
raise NotImplementedError(
"The reply function is not implemented in "
f"{self.__class__.__name__} class.",
)
async def print(
self,
msg: Msg,
last: bool = True,
speech: AudioBlock | list[AudioBlock] | None = None,
) -> None:
"""The function to display the message.
Args:
msg (`Msg`):
The message object to be printed.
last (`bool`, defaults to `True`):
Whether this is the last one in streaming messages. For
non-streaming message, this should always be `True`.
speech (`AudioBlock | list[AudioBlock] | None`, optional):
The audio content block(s) to be played along with the
message.
"""
if not self._disable_msg_queue:
await self.msg_queue.put((deepcopy(msg), last, speech))
# Yield control to the event loop, allowing consumer coroutines
# to process messages from the queue. This prevents the producer
# from monopolizing the event loop.
await asyncio.sleep(0)
if self._disable_console_output:
return
# The accumulated textual content to print, including the text blocks
# and the thinking blocks
thinking_and_text_to_print = []
for block in msg.get_content_blocks():
if block["type"] == "text":
self._print_text_block(
msg.id,
name_prefix=msg.name,
text_content=block["text"],
thinking_and_text_to_print=thinking_and_text_to_print,
)
elif block["type"] == "thinking":
self._print_text_block(
msg.id,
name_prefix=f"{msg.name}(thinking)",
text_content=block["thinking"],
thinking_and_text_to_print=thinking_and_text_to_print,
)
elif last:
self._print_last_block(block, msg)
# Play audio block if exists
if isinstance(speech, list):
for audio_block in speech:
self._process_audio_block(msg.id, audio_block)
elif isinstance(speech, dict):
self._process_audio_block(msg.id, speech)
# Clean up resources if this is the last message in streaming
if last and msg.id in self._stream_prefix:
if "audio" in self._stream_prefix[msg.id]:
player, _ = self._stream_prefix[msg.id]["audio"]
# Close the miniaudio player
player.close()
stream_prefix = self._stream_prefix.pop(msg.id)
if "text" in stream_prefix and not stream_prefix["text"].endswith(
"\n",
):
print()
def _process_audio_block(
self,
msg_id: str,
audio_block: AudioBlock,
) -> None:
"""Process audio block content.
Args:
msg_id (`str`):
The unique identifier of the message
audio_block (`AudioBlock`):
The audio content block
"""
if "source" not in audio_block:
raise ValueError(
"The audio block must contain the 'source' field.",
)
if audio_block["source"]["type"] == "url":
import urllib.request
import wave
import sounddevice as sd
url = audio_block["source"]["url"]
try:
with urllib.request.urlopen(url) as response:
audio_data = response.read()
with wave.open(io.BytesIO(audio_data), "rb") as wf:
samplerate = wf.getframerate()
n_frames = wf.getnframes()
audio_frames = wf.readframes(n_frames)
# Convert byte data to numpy array
audio_np = np.frombuffer(audio_frames, dtype=np.int16)
# Play audio
sd.play(audio_np, samplerate)
sd.wait()
except Exception as e:
logger.error(
"Failed to play audio from url %s: %s",
url,
str(e),
)
elif audio_block["source"]["type"] == "base64":
data = audio_block["source"]["data"]
if msg_id not in self._stream_prefix:
self._stream_prefix[msg_id] = {}
audio_prefix = self._stream_prefix[msg_id].get("audio", None)
import sounddevice as sd
# The player and the prefix data is cached for streaming audio
if audio_prefix:
player, audio_prefix_data = audio_prefix
else:
player = sd.OutputStream(
samplerate=24000,
channels=1,
dtype=np.float32,
blocksize=1024,
latency="low",
)
player.start()
audio_prefix_data = ""
# play the audio data
new_audio_data = data[len(audio_prefix_data) :]
if new_audio_data:
audio_bytes = base64.b64decode(new_audio_data)
audio_np = np.frombuffer(audio_bytes, dtype=np.int16)
audio_float = audio_np.astype(np.float32) / 32768.0
# Write to the audio output stream
player.write(audio_float)
# save the player and the prefix data
self._stream_prefix[msg_id]["audio"] = (
player,
data,
)
else:
raise ValueError(
"Unsupported audio source type: "
f"{audio_block['source']['type']}",
)
def _print_text_block(
self,
msg_id: str,
name_prefix: str,
text_content: str,
thinking_and_text_to_print: list[str],
) -> None:
"""Print the text block and thinking block content.
Args:
msg_id (`str`):
The unique identifier of the message
name_prefix (`str`):
The prefix for the message, e.g. "{name}: " for text block and
"{name}(thinking): " for thinking block.
text_content (`str`):
The textual content to be printed.
thinking_and_text_to_print (`list[str]`):
A list of textual content to be printed together. Here we
gather the text and thinking blocks to print them together.
"""
thinking_and_text_to_print.append(
f"{name_prefix}: {text_content}",
)
# The accumulated text and thinking blocks to print
to_print = "\n".join(thinking_and_text_to_print)
# The text prefix that has been printed
if msg_id not in self._stream_prefix:
self._stream_prefix[msg_id] = {}
text_prefix = self._stream_prefix[msg_id].get("text", "")
# Only print when there is new text content
if len(to_print) > len(text_prefix):
print(to_print[len(text_prefix) :], end="")
# Save the printed text prefix
self._stream_prefix[msg_id]["text"] = to_print
def _print_last_block(
self,
block: ToolUseBlock
| ToolResultBlock
| ImageBlock
| VideoBlock
| AudioBlock,
msg: Msg,
) -> None:
"""Process and print the last content block, and the block type
is not text, or thinking.
Args:
block (`ToolUseBlock | ToolResultBlock | ImageBlock | VideoBlock \
| AudioBlock`):
The content block to be printed
msg (`Msg`):
The message object
"""
# TODO: We should consider how to handle the multimodal blocks in the
# terminal, since the base64 data may be too long to display.
if block.get("type") in ["image", "video", "audio"]:
return
text_prefix = self._stream_prefix.get(msg.id, {}).get("text", "")
if text_prefix:
# Add a newline to separate from previous text content
print_newline = "" if text_prefix.endswith("\n") else "\n"
print(
f"{print_newline}"
f"{json.dumps(block, indent=4, ensure_ascii=False)}",
)
else:
print(
f"{msg.name}:"
f" {json.dumps(block, indent=4, ensure_ascii=False)}",
)
async def __call__(self, *args: Any, **kwargs: Any) -> Msg:
"""Call the reply function with the given arguments."""
self._reply_id = shortuuid.uuid()
reply_msg: Msg | None = None
try:
self._reply_task = asyncio.current_task()
reply_msg = await self.reply(*args, **kwargs)
# The interruption is triggered by calling the interrupt method
except asyncio.CancelledError:
reply_msg = await self.handle_interrupt(*args, **kwargs)
finally:
# Broadcast the reply message to all subscribers
if reply_msg:
await self._broadcast_to_subscribers(reply_msg)
self._reply_task = None
return reply_msg
async def _broadcast_to_subscribers(
self,
msg: Msg | list[Msg] | None,
) -> None:
"""Broadcast the message to all subscribers."""
for subscribers in self._subscribers.values():
for subscriber in subscribers:
await subscriber.observe(msg)
async def handle_interrupt(
self,
*args: Any,
**kwargs: Any,
) -> Msg:
"""The post-processing logic when the reply is interrupted by the
user or something else."""
raise NotImplementedError(
f"The handle_interrupt function is not implemented in "
f"{self.__class__.__name__}",
)
async def interrupt(self, msg: Msg | list[Msg] | None = None) -> None:
"""Interrupt the current reply process."""
if self._reply_task and not self._reply_task.done():
self._reply_task.cancel(msg)
def register_instance_hook(
self,
hook_type: AgentHookTypes,
hook_name: str,
hook: Callable,
) -> None:
"""Register a hook to the agent instance, which only takes effect
for the current instance.
Args:
hook_type (`str`):
The type of the hook, indicating where the hook is to be
triggered.
hook_name (`str`):
The name of the hook. If the name is already registered, the
hook will be overwritten.
hook (`Callable`):
The hook function.
"""
if not isinstance(self, AgentBase):
raise TypeError(
"The register_instance_hook method should be called on an "
f"instance of AsyncAgentBase, but got {self} of "
f"type {type(self)}.",
)
hooks = getattr(self, f"_instance_{hook_type}_hooks")
hooks[hook_name] = hook
def remove_instance_hook(
self,
hook_type: AgentHookTypes,
hook_name: str,
) -> None:
"""Remove an instance-level hook from the agent instance.
Args:
hook_type (`AgentHookTypes`):
The type of the hook, indicating where the hook is to be
triggered.
hook_name (`str`):
The name of the hook to remove.
"""
if not isinstance(self, AgentBase):
raise TypeError(
"The remove_instance_hook method should be called on an "
f"instance of AsyncAgentBase, but got {self} of "
f"type {type(self)}.",
)
hooks = getattr(self, f"_instance_{hook_type}_hooks")
if hook_name in hooks:
del hooks[hook_name]
else:
raise ValueError(
f"Hook '{hook_name}' not found in '{hook_type}' hooks of "
f"{self.__class__.__name__} instance.",
)
@classmethod
def register_class_hook(
cls,
hook_type: AgentHookTypes,
hook_name: str,
hook: Callable,
) -> None:
"""The universal function to register a hook to the agent class, which
will take effect for all instances of the class.
Args:
hook_type (`AgentHookTypes`):
The type of the hook, indicating where the hook is to be
triggered.
hook_name (`str`):
The name of the hook. If the name is already registered, the
hook will be overwritten.
hook (`Callable`):
The hook function.
"""
assert (
hook_type in cls.supported_hook_types
), f"Invalid hook type: {hook_type}"
hooks = getattr(cls, f"_class_{hook_type}_hooks")
hooks[hook_name] = hook
@classmethod
def remove_class_hook(
cls,
hook_type: AgentHookTypes,
hook_name: str,
) -> None:
"""Remove a class-level hook from the agent class.
Args:
hook_type (`AgentHookTypes`):
The type of the hook, indicating where the hook is to be
triggered.
hook_name (`str`):
The name of the hook to remove.
"""
assert (
hook_type in cls.supported_hook_types
), f"Invalid hook type: {hook_type}"
hooks = getattr(cls, f"_class_{hook_type}_hooks")
if hook_name in hooks:
del hooks[hook_name]
else:
raise ValueError(
f"Hook '{hook_name}' not found in '{hook_type}' hooks of "
f"{cls.__name__} class.",
)
@classmethod
def clear_class_hooks(
cls,
hook_type: AgentHookTypes | None = None,
) -> None:
"""Clear all class-level hooks.
Args:
hook_type (`AgentHookTypes`, optional):
The type of the hook to clear. If not specified, all
class-level hooks will be cleared.
"""
if hook_type is None:
for typ in cls.supported_hook_types:
hooks = getattr(cls, f"_class_{typ}_hooks")
hooks.clear()
else:
assert (
hook_type in cls.supported_hook_types
), f"Invalid hook type: {hook_type}"
hooks = getattr(cls, f"_class_{hook_type}_hooks")
hooks.clear()
def clear_instance_hooks(
self,
hook_type: AgentHookTypes | None = None,
) -> None:
"""If `hook_type` is not specified, clear all instance-level hooks.
Otherwise, clear the specified type of instance-level hooks."""
if hook_type is None:
for typ in self.supported_hook_types:
if not hasattr(self, f"_instance_{typ}_hooks"):
raise ValueError(
f"Call super().__init__() in the constructor "
f"to initialize the instance-level hooks for "
f"{self.__class__.__name__}.",
)
hooks = getattr(self, f"_instance_{typ}_hooks")
hooks.clear()
else:
assert (
hook_type in self.supported_hook_types
), f"Invalid hook type: {hook_type}"
if not hasattr(self, f"_instance_{hook_type}_hooks"):
raise ValueError(
f"Call super().__init__() in the constructor "
f"to initialize the instance-level hooks for "
f"{self.__class__.__name__}.",
)
hooks = getattr(self, f"_instance_{hook_type}_hooks")
hooks.clear()
def reset_subscribers(
self,
msghub_name: str,
subscribers: list["AgentBase"],
) -> None:
"""Reset the subscribers of the agent.
Args:
msghub_name (`str`):
The name of the MsgHub that manages the subscribers.
subscribers (`list[AgentBase]`):
A list of agents that will receive the reply message from
this agent via their `observe` method.
"""
self._subscribers[msghub_name] = [_ for _ in subscribers if _ != self]
def remove_subscribers(self, msghub_name: str) -> None:
"""Remove the msghub subscribers by the given msg hub name.
Args:
msghub_name (`str`):
The name of the MsgHub that manages the subscribers.
"""
if msghub_name not in self._subscribers:
logger.warning(
"MsgHub named '%s' not found",
msghub_name,
)
else:
self._subscribers.pop(msghub_name)
@deprecated("Please use set_console_output_enabled() instead.")
def disable_console_output(self) -> None:
"""This function will disable the console output of the agent, e.g.
in a production environment to avoid messy logs."""
self._disable_console_output = True
def set_console_output_enabled(self, enabled: bool) -> None:
"""Enable or disable the console output of the agent. E.g. in a
production environment, you may want to disable the console output to
avoid messy logs.
Args:
enabled (`bool`):
If `True`, enable the console output. If `False`, disable
the console output.
"""
self._disable_console_output = not enabled
def set_msg_queue_enabled(
self,
enabled: bool,
queue: Queue | None = None,
) -> None:
"""Enable or disable the message queue for streaming outputs.
Args:
enabled (`bool`):
If `True`, enable the message queue to allow streaming
outputs. If `False`, disable the message queue.
queue (`Queue | None`, optional):
The queue instance that will be used to initialize the
message queue when `enable` is `True`.
"""
if enabled:
if queue is None:
if self.msg_queue is None:
self.msg_queue = asyncio.Queue(maxsize=100)
else:
self.msg_queue = queue
else:
self.msg_queue = None
self._disable_msg_queue = not enabled

View File

@@ -0,0 +1,180 @@
# -*- coding: utf-8 -*-
"""The metaclass for agents in agentscope."""
import inspect
from copy import deepcopy
from functools import wraps
from typing import (
Any,
Dict,
TYPE_CHECKING,
Callable,
)
from .._utils._common import _execute_async_or_sync_func
if TYPE_CHECKING:
from ._agent_base import AgentBase
else:
AgentBase = "AgentBase"
def _normalize_to_kwargs(
func: Callable,
self: Any,
*args: Any,
**kwargs: Any,
) -> dict:
"""Normalize the provided positional and keyword arguments into a
keyword arguments dictionary that matches the function signature."""
sig = inspect.signature(func)
try:
# Bind the provided arguments to the function signature
bound = sig.bind(self, *args, **kwargs)
# Apply the default values for parameters
bound.apply_defaults()
# Return the arguments in a dictionary format
res = dict(bound.arguments)
res.pop("self")
return res
except TypeError as e:
# If failed to bind, we raise a TypeError with more context
param_names = list(sig.parameters.keys())
provided_args = len(args)
provided_kwargs = list(kwargs.keys())
raise TypeError(
f"Failed to bind parameters for function '{func.__name__}': {e}\n"
f"Expected parameters: {param_names}\n"
f"Provided {provided_args} positional args and kwargs: "
f"{provided_kwargs}",
) from e
def _wrap_with_hooks(
original_func: Callable,
) -> Callable:
"""A decorator to wrap the original async function with pre- and post-hooks
Args:
original_func (`Callable`):
The original async function to be wrapped with hooks.
"""
func_name = original_func.__name__.replace("_", "")
@wraps(original_func)
async def async_wrapper(
self: AgentBase,
*args: Any,
**kwargs: Any,
) -> Any:
"""The wrapped function, which call the pre- and post-hooks before and
after the original function."""
# Unify all positional and keyword arguments into a keyword arguments
normalized_kwargs = _normalize_to_kwargs(
original_func,
self,
*args,
**kwargs,
)
current_normalized_kwargs = normalized_kwargs
assert (
hasattr(self, f"_instance_pre_{func_name}_hooks")
and hasattr(self, f"_instance_post_{func_name}_hooks")
and hasattr(self.__class__, f"_class_pre_{func_name}_hooks")
and hasattr(self.__class__, f"_class_post_{func_name}_hooks")
), f"Hooks for {func_name} not found in {self.__class__.__name__}"
# pre-hooks
pre_hooks = list(
getattr(self, f"_instance_pre_{func_name}_hooks").values(),
) + list(
getattr(self, f"_class_pre_{func_name}_hooks").values(),
)
for pre_hook in pre_hooks:
modified_keywords = await _execute_async_or_sync_func(
pre_hook,
self,
deepcopy(current_normalized_kwargs),
)
if modified_keywords is not None:
assert isinstance(modified_keywords, dict), (
f"Pre-hook must return a dict of keyword arguments, rather"
f" than {type(modified_keywords)} from hook "
f"{pre_hook.__name__}"
)
current_normalized_kwargs = modified_keywords
# original function
# handle positional and keyword arguments specifically
args = current_normalized_kwargs.get("args", [])
kwargs = current_normalized_kwargs.get("kwargs", {})
others = {
k: v
for k, v in current_normalized_kwargs.items()
if k not in ["args", "kwargs"]
}
current_output = await original_func(
self,
*args,
**others,
**kwargs,
)
# post_hooks
post_hooks = list(
getattr(self, f"_instance_post_{func_name}_hooks").values(),
) + list(
getattr(self, f"_class_post_{func_name}_hooks").values(),
)
for post_hook in post_hooks:
modified_output = await _execute_async_or_sync_func(
post_hook,
self,
deepcopy(current_normalized_kwargs),
deepcopy(current_output),
)
if modified_output is not None:
current_output = modified_output
return current_output
return async_wrapper
class _AgentMeta(type):
"""The agent metaclass that wraps the agent's reply, observe and print
functions with pre- and post-hooks."""
def __new__(mcs, name: Any, bases: Any, attrs: Dict) -> Any:
"""Wrap the agent's functions with hooks."""
for func_name in [
"reply",
"print",
"observe",
]:
if func_name in attrs:
attrs[func_name] = _wrap_with_hooks(attrs[func_name])
return super().__new__(mcs, name, bases, attrs)
class _ReActAgentMeta(_AgentMeta):
"""The ReAct metaclass that adds pre- and post-hooks for the _reasoning
and _acting functions."""
def __new__(mcs, name: Any, bases: Any, attrs: Dict) -> Any:
"""Wrap the ReAct agent's _reasoning and _acting functions with
hooks."""
for func_name in [
"_reasoning",
"_acting",
]:
if func_name in attrs:
attrs[func_name] = _wrap_with_hooks(attrs[func_name])
return super().__new__(mcs, name, bases, attrs)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,116 @@
# -*- coding: utf-8 -*-
"""The base class for ReAct agent in agentscope."""
from abc import abstractmethod
from collections import OrderedDict
from typing import Callable, Any
from ._agent_base import AgentBase
from ._agent_meta import _ReActAgentMeta
from ..message import Msg
class ReActAgentBase(AgentBase, metaclass=_ReActAgentMeta):
"""The ReAct agent base class.
To support ReAct algorithm, this class extends the AgentBase class by
adding two abstract interfaces: reasoning and acting, while supporting
hook functions at four positions: pre-reasoning, post-reasoning,
pre-acting, and post-acting by the `_ReActAgentMeta` metaclass.
"""
supported_hook_types: list[str] = [
"pre_reply",
"post_reply",
"pre_print",
"post_print",
"pre_observe",
"post_observe",
"pre_reasoning",
"post_reasoning",
"pre_acting",
"post_acting",
]
"""Supported hook types for the agent base class."""
_class_pre_reasoning_hooks: dict[
str,
Callable[
[
"ReActAgentBase", # self
dict[str, Any], # kwargs
],
dict[str, Any] | None, # The modified kwargs or None
],
] = OrderedDict()
"""The class-level pre-reasoning hooks, taking `self` object, the input
arguments as input"""
_class_post_reasoning_hooks: dict[
str,
Callable[
[
"ReActAgentBase", # self
dict[str, Any], # kwargs
Any, # output
],
Msg | None, # the modified output message or None
],
] = OrderedDict()
"""The class-level post-reasoning hooks, taking `self` object, the input
arguments and the output message as input, and return the modified output
message or None if no modification is needed."""
_class_pre_acting_hooks: dict[
str,
Callable[
[
"ReActAgentBase", # self
dict[str, Any], # kwargs
],
dict[str, Any] | None, # The modified kwargs or None
],
] = OrderedDict()
"""The class-level pre-acting hooks, taking `self` object, the input
arguments as input, and return the modified input arguments or None if no
modification is needed."""
_class_post_acting_hooks: dict[
str,
Callable[
[
"ReActAgentBase", # self
dict[str, Any], # kwargs
Any, # output
],
Msg | None, # the modified output message or None
],
] = OrderedDict()
"""The class-level post-acting hooks, taking `self` object, the input
arguments and the output message as input, and return the modified output
message or None if no modification is needed."""
def __init__(
self,
) -> None:
"""Initialize the ReAct agent base class."""
super().__init__()
# Init reasoning and acting hooks
self._instance_pre_reasoning_hooks = OrderedDict()
self._instance_post_reasoning_hooks = OrderedDict()
self._instance_pre_acting_hooks = OrderedDict()
self._instance_post_acting_hooks = OrderedDict()
@abstractmethod
async def _reasoning(
self,
*args: Any,
**kwargs: Any,
) -> Any:
"""The reasoning process of the ReAct agent, which will be wrapped
with pre- and post-hooks."""
@abstractmethod
async def _acting(self, *args: Any, **kwargs: Any) -> Any:
"""The acting process of the ReAct agent, which will be wrapped with
pre- and post-hooks."""

View File

@@ -0,0 +1,360 @@
# -*- coding: utf-8 -*-
"""The realtime agent class."""
import asyncio
from asyncio import Queue
import shortuuid
from .._logging import logger
from .._utils._common import _resample_pcm_delta
from ..message import (
AudioBlock,
Base64Source,
TextBlock,
ImageBlock,
ToolUseBlock,
ToolResultBlock,
)
from ..module import StateModule
from ..realtime import (
ModelEvents,
RealtimeModelBase,
ServerEvents,
ClientEvents,
)
from ..tool import Toolkit
class RealtimeAgent(StateModule):
"""The realtime agent class. Different from the `AgentBase` class,
this class is designed for real-time interaction scenarios, such as
realtime chat, voice assistants, etc.
Example:
This realtime agent requires a queue to handle outgoing messages to
the frontend and other agents, and its lifecycle is managed by the
`start` and `stop` methods.
.. code-block:: python
:caption: An example of using the RealtimeAgent class.
from agentscope.agent import RealtimeAgent
from agentscope.realtime import DashScopeRealtimeModel
import asyncio
agent = RealtimeAgent(
name="Friday",
sys_prompt="You are a helpful assistant.",
model=DashScopeRealtimeModel(
model_name="qwen3-omni-flash-realtime",
api_key=os.getenv("DASHSCOPE_API_KEY"),
)
)
queue = asyncio.Queue()
await agent.start(queue)
# handle the outgoing messages from the agent in another asyncio
# task
...
await agent.stop()
"""
def __init__(
self,
name: str,
sys_prompt: str,
model: RealtimeModelBase,
toolkit: Toolkit | None = None,
) -> None:
"""Initialize the RealtimeAgent class.
Args:
name (`str`):
The name of the agent.
sys_prompt (`str`):
The system prompt of the agent.
model (`RealtimeModelBase`):
The realtime model used by the agent.
toolkit (`Toolkit | None`, optional):
A `Toolkit` object that contains the tool functions. If not
provided, a default empty `Toolkit` will be created.
"""
super().__init__()
self.id = shortuuid.uuid()
self.name = name
self.sys_prompt = sys_prompt
self.model = model
self.toolkit = toolkit
# A queue to handle the incoming events from other agents or the
# frontend.
self._incoming_queue = Queue()
self._external_event_handling_task = None
# The queue to gather model responses.
self._model_response_queue = Queue()
self._model_response_handling_task = None
async def start(self, outgoing_queue: Queue) -> None:
"""Establish a connection for real-time interaction.
Args:
outgoing_queue (`Queue`):
The queue to push messages to the frontend and other agents.
"""
# Start the realtime model connection.
await self.model.connect(
self._model_response_queue,
instructions=self.sys_prompt,
tools=self.toolkit.get_json_schemas() if self.toolkit else None,
)
# Start the forwarding loop.
self._external_event_handling_task = asyncio.create_task(
self._forward_loop(),
)
# Start the response handling loop.
self._model_response_handling_task = asyncio.create_task(
self._model_response_loop(outgoing_queue),
)
async def stop(self) -> None:
"""Close the connection."""
if not self._external_event_handling_task.done():
self._external_event_handling_task.cancel()
await self.model.disconnect()
async def _forward_loop(self) -> None:
"""The loop to forward messages from other agents or the frontend to
the realtime model for processing.
outside ==> agent ==> realtime model
"""
logger.info(
"Agent '%s' begins the loops to receive external events",
self.name,
)
while True:
event = await self._incoming_queue.get()
match event:
# Only handle the events that we need
case ServerEvents.AgentResponseAudioDeltaEvent() as event:
# Convert the sample rate to the required format by the
# model
receive_rate = event.format.rate
if self.model.input_sample_rate != receive_rate:
delta = _resample_pcm_delta(
event.delta,
receive_rate,
self.model.input_sample_rate,
)
else:
delta = event.delta
await self.model.send(
AudioBlock(
type="audio",
source=Base64Source(
type="base64",
media_type=event.format.type,
data=delta,
),
),
)
case ServerEvents.AgentResponseAudioDoneEvent():
# Send a silence audio block to indicate the end of audio
pass
case ClientEvents.ClientAudioAppendEvent() as event:
# Construct media_type from format info
# format contains: {"sample_rate": 16000, "encoding":
# "pcm16"}
# encoding = event.format.get("encoding", "pcm16")
# media_type = (
# f"audio/{encoding.replace('16', '')}"
# if "pcm" in encoding
# else "audio/pcm"
# )
await self.model.send(
AudioBlock(
type="audio",
source=Base64Source(
type="base64",
media_type=event.format.type,
data=event.audio,
),
),
)
case ClientEvents.ClientTextAppendEvent() as event:
await self.model.send(
TextBlock(
type="text",
text=event.text,
),
)
case ClientEvents.ClientImageAppendEvent() as event:
# Construct media_type from format info
media_type = event.format.get("type", "image/jpeg")
await self.model.send(
ImageBlock(
type="image",
source=Base64Source(
type="base64",
media_type=media_type,
data=event.image,
),
),
)
async def _model_response_loop(self, outgoing_queue: Queue) -> None:
"""The loop to handle model responses and forward them to the
frontend and other agents.
realtime model ==> agent ==> outside
Args:
outgoing_queue (`Queue`):
The queue to push messages to the frontend and other agents.
"""
while True:
model_event = await self._model_response_queue.get()
agent_kwargs = {"agent_id": self.id, "agent_name": self.name}
agent_event = None
match model_event:
# The events that can be converted from model events to agent
# events directly
case (
ModelEvents.ModelResponseCreatedEvent()
| ModelEvents.ModelResponseDoneEvent()
| ModelEvents.ModelResponseAudioDeltaEvent()
| ModelEvents.ModelResponseAudioDoneEvent()
| ModelEvents.ModelResponseAudioTranscriptDeltaEvent()
| ModelEvents.ModelResponseAudioTranscriptDoneEvent()
| ModelEvents.ModelResponseToolUseDeltaEvent()
| ModelEvents.ModelInputTranscriptionDeltaEvent()
| ModelEvents.ModelInputTranscriptionDoneEvent()
| ModelEvents.ModelInputStartedEvent()
| ModelEvents.ModelInputDoneEvent()
| ModelEvents.ModelErrorEvent()
) as event:
# Directly map the model event to agent event
agent_event = ServerEvents.from_model_event(
event,
**agent_kwargs,
)
# The events that need special handling
case ModelEvents.ModelSessionCreatedEvent():
# Send the agent ready event to the outside.
agent_event = ServerEvents.AgentReadyEvent(**agent_kwargs)
case ModelEvents.ModelSessionEndedEvent():
# Send the agent session ended event to the outside.
agent_event = ServerEvents.AgentEndedEvent(**agent_kwargs)
# The tool use done that requires executing the tool
# Such event may generate multiple outgoing events:
# 1. Tool use done event
# 2. Tool result event
case ModelEvents.ModelResponseToolUseDoneEvent() as event:
# Send the tool use done event immediately
done_event = ServerEvents.AgentResponseToolUseDoneEvent(
response_id=event.response_id,
item_id=event.item_id,
tool_use=event.tool_use,
**agent_kwargs,
)
# Directly put the done event to the outgoing queue
await outgoing_queue.put(done_event)
# Then execute the tool call using accumulated arguments
if self.toolkit:
# Execute the tool call asynchronously
asyncio.create_task(
self._acting(
tool_use=event.tool_use,
outgoing_queue=outgoing_queue,
),
)
case _:
logger.debug(
"Unknown model event type: %s",
type(model_event),
)
if agent_event is not None:
# Put the processed response to the outgoing queue.
await outgoing_queue.put(agent_event)
async def handle_input(
self,
event: ClientEvents.EventBase | ServerEvents.EventBase,
) -> None:
"""Handle the input message from the frontend or the other agents.
Args:
event (`ClientEvents.EventBase | ServerEvents.EventBase`):
The input event from the frontend or other agents.
"""
await self._incoming_queue.put(event)
async def _acting(
self,
tool_use: ToolUseBlock,
outgoing_queue: Queue,
) -> None:
"""Execute the tool call and send the result back to the outside (
frontend or other agents).
Args:
tool_use (`ToolUseBlock`):
The tool use block containing the tool call information.
outgoing_queue (`Queue`):
The queue to push messages to the frontend and other agents.
"""
if not self.toolkit:
return
res = await self.toolkit.call_tool_function(tool_use)
last_chunk = None
async for chunk in res:
last_chunk = chunk
if last_chunk:
tool_result_block = ToolResultBlock(
type="tool_result",
id=tool_use.get("id"),
name=tool_use.get("name"),
output=last_chunk.content,
)
# Send the tool result back to the model
await self.model.send(tool_result_block)
# Also send event to frontend/other agents
outgoing_event = ServerEvents.AgentResponseToolResultEvent(
tool_result=tool_result_block,
agent_id=self.id,
agent_name=self.name,
)
await outgoing_queue.put(outgoing_event)

View File

@@ -0,0 +1,128 @@
# -*- coding: utf-8 -*-
"""The user agent class."""
from typing import Type, Any
from pydantic import BaseModel
from ._agent_base import AgentBase
from ._user_input import UserInputBase, TerminalUserInput
from ..message import Msg
class UserAgent(AgentBase):
"""The class for user interaction, allowing developers to handle the user
input from different sources, such as web UI, cli, and other interfaces.
"""
_input_method: UserInputBase = TerminalUserInput()
"""The user input method, can be overridden by calling the
`register_instance/class_input_method` function."""
def __init__(
self,
name: str,
) -> None:
"""Initialize the user agent with a name."""
super().__init__()
self.name = name
async def reply(
self,
msg: Msg | list[Msg] | None = None,
structured_model: Type[BaseModel] | None = None,
) -> Msg:
"""Receive input message(s) and generate a reply message from the user.
Args:
msg (`Msg | list[Msg] | None`, defaults to `None`):
The message(s) to be replied. If `None`, the agent will wait
for user input.
structured_model (`Type[BaseModel] | None`, defaults to `None`):
A child class of `pydantic.BaseModel` that defines the
structured output format. If provided, the user will be
prompted to fill in the required fields.
Returns:
`Msg`:
The reply message generated by the user.
"""
# Get the input from the specified input method.
input_data = await self._input_method(
agent_id=self.id,
agent_name=self.name,
structured_model=structured_model,
)
blocks_input = input_data.blocks_input
if (
blocks_input
and len(blocks_input) == 1
and blocks_input[0].get("type") == "text"
):
# Turn blocks_input into a string if only one text block exists
blocks_input = blocks_input[0].get("text")
msg = Msg(
self.name,
content=blocks_input,
role="user",
metadata=input_data.structured_input,
)
await self.print(msg)
return msg
def override_instance_input_method(
self,
input_method: UserInputBase,
) -> None:
"""Override the input method of the current UserAgent instance.
Args:
input_method (`UserInputBase`):
The callable input method, which should be an object of a
class that inherits from `UserInputBase`.
"""
if not isinstance(input_method, UserInputBase):
raise ValueError(
f"The input method should be an instance of the child class "
f"of `UserInputBase`, but got {type(input_method)} instead.",
)
self._input_method = input_method
@classmethod
def override_class_input_method(
cls,
input_method: UserInputBase,
) -> None:
"""Override the input method of the current UserAgent class.
Args:
input_method (`UserInputBase`):
The callable input method, which should be an object of a
class that inherits from `UserInputBase`.
"""
if not isinstance(input_method, UserInputBase):
raise ValueError(
f"The input method should be an instance of the child class "
f"of `UserInputBase`, but got {type(input_method)} instead.",
)
cls._input_method = input_method
async def handle_interrupt(
self,
*args: Any,
**kwargs: Any,
) -> Msg:
"""The post-processing logic when the reply is interrupted by the
user or something else."""
raise NotImplementedError(
f"The handle_interrupt function is not implemented in "
f"{self.__class__.__name__}",
)
async def observe(self, msg: Msg | list[Msg] | None) -> None:
"""Observe the message(s) from the other agents or the environment."""

View File

@@ -0,0 +1,415 @@
# -*- coding: utf-8 -*-
"""The user input related classes."""
import json.decoder
import time
from abc import abstractmethod
from dataclasses import dataclass
from queue import Queue
from threading import Event
from typing import Any, Type, List
import jsonschema
import requests
import shortuuid
import socketio
from pydantic import BaseModel
import json5
from .. import _config
from .._logging import logger
from ..message import (
TextBlock,
VideoBlock,
AudioBlock,
ImageBlock,
)
@dataclass
class UserInputData:
"""The user input data."""
blocks_input: List[TextBlock | ImageBlock | AudioBlock | VideoBlock] = None
"""The text input from the user"""
structured_input: dict[str, Any] | None = None
"""The structured input from the user"""
class UserInputBase:
"""The base class used to handle the user input from different sources."""
@abstractmethod
async def __call__(
self,
agent_id: str,
agent_name: str,
*args: Any,
structured_model: Type[BaseModel] | None = None,
**kwargs: Any,
) -> UserInputData:
"""The user input method, which returns the user input and the
required structured data.
Args:
agent_id (`str`):
The agent identifier.
agent_name (`str`):
The agent name.
structured_model (`Type[BaseModel] | None`, optional):
A base model class that defines the structured input format.
Returns:
`UserInputData`:
The user input data.
"""
class TerminalUserInput(UserInputBase):
"""The terminal user input."""
def __init__(self, input_hint: str = "User Input: ") -> None:
"""Initialize the terminal user input with a hint."""
self.input_hint = input_hint
async def __call__(
self,
agent_id: str,
agent_name: str,
*args: Any,
structured_model: Type[BaseModel] | None = None,
**kwargs: Any,
) -> UserInputData:
"""Handle the user input from the terminal.
Args:
agent_id (`str`):
The agent identifier.
agent_name (`str`):
The agent name.
structured_model (`Type[BaseModel] | None`, optional):
A base model class that defines the structured input format.
Returns:
`UserInputData`:
The user input data.
"""
text_input = (
input(self.input_hint)
.encode("utf-8", errors="ignore")
.decode("utf-8")
)
structured_input = None
if structured_model is not None:
structured_input = {}
json_schema = structured_model.model_json_schema()
required = json_schema.get("required", [])
print("Structured input (press Enter to skip for optional):)")
for key, item in json_schema.get("properties").items():
requirements = {**item}
requirements.pop("title")
while True:
res = input(f"\t{key} ({requirements}): ")
if res == "":
if key in required:
print(f"Key {key} is required.")
continue
res = item.get("default", None)
if item.get("type").lower() == "integer":
try:
res = json5.loads(res)
except json.decoder.JSONDecodeError as e:
print(
"\033[31mInvalid input with error:\n"
"```\n"
f"{e}\n"
"```\033[0m",
)
continue
try:
jsonschema.validate(res, item)
structured_input[key] = res
break
except jsonschema.ValidationError as e:
print(
f"\033[31mValidation error:\n```\n{e}\n```\033[0m",
)
time.sleep(0.5)
return UserInputData(
blocks_input=[TextBlock(type="text", text=text_input)],
structured_input=structured_input,
)
class StudioUserInput(UserInputBase):
"""The class that host the user input on the AgentScope Studio."""
_websocket_namespace: str = "/python"
def __init__(
self,
studio_url: str,
run_id: str,
max_retries: int = 3,
reconnect_attempts: int = 3,
reconnection_delay: int = 1,
reconnection_delay_max: int = 5,
) -> None:
"""Initialize the StudioUserInput object.
Args:
studio_url (`str`):
The URL of the AgentScope Studio.
run_id (`str`):
The current run identity.
max_retries (`int`, defaults to `3`):
The maximum number of retries to get user input.
"""
self._is_connected = False
self._is_reconnecting = False
self.studio_url = studio_url
self.run_id = run_id
self.max_retries = max_retries
# Init Websocket
self.sio = socketio.Client(
reconnection=True,
reconnection_attempts=reconnect_attempts,
reconnection_delay=reconnection_delay,
reconnection_delay_max=reconnection_delay_max,
)
self.input_queues = {}
self.input_events = {}
@self.sio.on("connect", namespace=self._websocket_namespace)
def on_connect() -> None:
self._is_connected = True
logger.info(
'Connected to AgentScope Studio at "%s" with '
'run name "%s".',
self.studio_url,
run_id,
)
logger.info(
"View the run at: %s/projects/%s",
self.studio_url,
_config.project,
)
@self.sio.on("disconnect", namespace=self._websocket_namespace)
def on_disconnect() -> None:
self._is_connected = False
logger.info(
"Disconnected from AgentScope Studio at %s",
self.studio_url,
)
@self.sio.on("reconnect", namespace=self._websocket_namespace)
def on_reconnect(attempt_number: int) -> None:
self._is_connected = True
self._is_reconnecting = False
logger.info(
"Reconnected to AgentScope Studio at %s with run_id %s after "
"%d attempts",
self.studio_url,
self.run_id,
attempt_number,
)
@self.sio.on("reconnect_attempt", namespace=self._websocket_namespace)
def on_reconnect_attempt(attempt_number: int) -> None:
self._is_reconnecting = True
logger.info(
"Attempting to reconnect to AgentScope Studio at %s "
"(attempt %d)",
self.studio_url,
attempt_number,
)
@self.sio.on("reconnect_failed", namespace=self._websocket_namespace)
def on_reconnect_failed() -> None:
self._is_reconnecting = False
logger.error(
"Failed to reconnect to AgentScope Studio at %s",
self.studio_url,
)
@self.sio.on("reconnect_error", namespace=self._websocket_namespace)
def on_reconnect_error(error: Any) -> None:
logger.error(
"Error while reconnecting to AgentScope Studio at %s: %s",
self.studio_url,
str(error),
)
# The AgentScope Studio backend send the "sendUserInput" event to
# the current python run
@self.sio.on("forwardUserInput", namespace=self._websocket_namespace)
def receive_user_input(
request_id: str,
blocks_input: List[
TextBlock | ImageBlock | AudioBlock | VideoBlock
],
structured_input: dict[str, Any],
) -> None:
if request_id in self.input_queues:
self.input_queues[request_id].put(
UserInputData(
blocks_input=blocks_input,
structured_input=structured_input,
),
)
self.input_events[request_id].set()
try:
self.sio.connect(
f"{self.studio_url}",
namespaces=["/python"],
auth={"run_id": self.run_id},
)
except Exception as e:
raise RuntimeError(
f"Failed to connect to AgentScope Studio at {self.studio_url}",
) from e
def _ensure_connected(
self,
timeout: float = 30.0,
check_interval: float = 5.0,
) -> None:
"""Ensure the connection is established or wait for reconnection.
Args:
timeout (`float`):
Maximum time to wait for reconnection in seconds. Defaults
to 30.0.
check_interval (`float`):
Interval between connection checks in seconds. Defaults to 1.0.
Raises:
`RuntimeError`:
If connection cannot be established within timeout.
"""
if self._is_connected:
return
if self._is_reconnecting:
start_time = time.time()
while self._is_reconnecting:
# Check timeout
elapsed_time = time.time() - start_time
if elapsed_time > timeout:
raise RuntimeError(
f"Reconnection timeout after {elapsed_time} seconds",
)
# Log status
logger.info(
"Waiting for reconnection... (%.1fs / %.1fs)",
elapsed_time,
timeout,
)
# Wait for next check
time.sleep(check_interval)
# After reconnection attempt completed, check final status
if self._is_connected:
return
# Not connected and not reconnecting
raise RuntimeError(
f"Not connected to AgentScope Studio at {self.studio_url}.",
)
async def __call__( # type: ignore[override]
self,
agent_id: str,
agent_name: str,
*args: Any,
structured_model: Type[BaseModel] | None = None,
) -> UserInputData:
"""Get the user input from AgentScope Studio.
Args:
agent_id (`str`):
The identity of the agent.
agent_name (`str`):
The name of the agent.
structured_model (`Type[BaseModel] | None`, optional):
The base model class of the structured input.
Raises:
`RuntimeError`:
Failed to get user input from AgentScope Studio.
Returns:
`UserInputData`:
The user input.
"""
self._ensure_connected()
request_id = shortuuid.uuid()
self.input_queues[request_id] = Queue()
self.input_events[request_id] = Event()
if structured_model is None:
structured_input = None
else:
structured_input = structured_model.model_json_schema()
n_retry = 0
while True:
try:
response = requests.post(
f"{self.studio_url}/trpc/requestUserInput",
json={
"requestId": request_id,
"runId": self.run_id,
"agentId": agent_id,
"agentName": agent_name,
"structuredInput": structured_input,
},
)
response.raise_for_status()
break
except Exception as e:
if n_retry < self.max_retries:
n_retry += 1
continue
raise RuntimeError(
"Failed to get user input from AgentScope Studio",
) from e
try:
self.input_events[request_id].wait()
response_data = self.input_queues[request_id].get()
return response_data
finally:
self.input_queues.pop(request_id, None)
self.input_events.pop(request_id, None)
def __del__(self) -> None:
"""Cleanup socket connection when object it destroyed"""
try:
self.sio.disconnect()
except Exception as e:
logger.error(
"Failed to disconnect from AgentScope Studio at %s: %s",
self.studio_url,
str(e),
)

View File

@@ -0,0 +1,18 @@
# -*- coding: utf-8 -*-
"""Utils for agents in agentscope."""
from typing import Any
class _AsyncNullContext:
"""An async null context manager."""
async def __aenter__(self) -> None:
return None
async def __aexit__(
self,
exc_type: Any,
exc_val: Any,
exc_tb: Any,
) -> None:
return None

View File

@@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-
"""The embedding module in agentscope."""
from ._embedding_base import EmbeddingModelBase
from ._embedding_usage import EmbeddingUsage
from ._embedding_response import EmbeddingResponse
from ._dashscope_embedding import DashScopeTextEmbedding
from ._dashscope_multimodal_embedding import DashScopeMultiModalEmbedding
from ._openai_embedding import OpenAITextEmbedding
from ._gemini_embedding import GeminiTextEmbedding
from ._ollama_embedding import OllamaTextEmbedding
from ._cache_base import EmbeddingCacheBase
from ._file_cache import FileEmbeddingCache
__all__ = [
"EmbeddingModelBase",
"EmbeddingUsage",
"EmbeddingResponse",
"DashScopeTextEmbedding",
"DashScopeMultiModalEmbedding",
"OpenAITextEmbedding",
"GeminiTextEmbedding",
"OllamaTextEmbedding",
"EmbeddingCacheBase",
"FileEmbeddingCache",
]

View File

@@ -0,0 +1,63 @@
# -*- coding: utf-8 -*-
"""The embedding cache base class."""
from abc import abstractmethod
from typing import List, Any
from ..types import (
JSONSerializableObject,
Embedding,
)
class EmbeddingCacheBase:
"""Base class for embedding caches, which is responsible for storing and
retrieving embeddings."""
@abstractmethod
async def store(
self,
embeddings: List[Embedding],
identifier: JSONSerializableObject,
overwrite: bool = False,
**kwargs: Any,
) -> None:
"""Store the embeddings with the given identifier.
Args:
embeddings (`List[Embedding]`):
The embeddings to store.
identifier (`JSONSerializableObject`):
The identifier to distinguish the embeddings.
overwrite (`bool`, defaults to `False`):
Whether to overwrite existing embeddings with the same
identifier. If `True`, existing embeddings will be replaced.
"""
@abstractmethod
async def retrieve(
self,
identifier: JSONSerializableObject,
) -> List[Embedding] | None:
"""Retrieve the embeddings with the given identifier. If not
found, return `None`.
Args:
identifier (`JSONSerializableObject`):
The identifier to retrieve the embeddings.
"""
@abstractmethod
async def remove(
self,
identifier: JSONSerializableObject,
) -> None:
"""Remove the embeddings with the given identifier.
Args:
identifier (`JSONSerializableObject`):
The identifier to remove the embeddings.
"""
@abstractmethod
async def clear(self) -> None:
"""Clear all cached embeddings."""

View File

@@ -0,0 +1,169 @@
# -*- coding: utf-8 -*-
"""The dashscope embedding module in agentscope."""
from datetime import datetime
from typing import Any, List, Literal
from ._cache_base import EmbeddingCacheBase
from ._embedding_response import EmbeddingResponse
from ._embedding_usage import EmbeddingUsage
from ._embedding_base import EmbeddingModelBase
from .._logging import logger
from ..message import TextBlock
class DashScopeTextEmbedding(EmbeddingModelBase):
"""DashScope text embedding API class.
.. note:: From the `official documentation
<https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2712515>`_:
- The max batch size that DashScope text embedding API
supports is 10 for `text-embedding-v4` and `text-embedding-v3` models, and
25 for `text-embedding-v2` and `text-embedding-v1` models.
- The max token limit for a single input is 8192 tokens for `v4` and `v3`
models, and 2048 tokens for `v2` and `v1` models.
"""
supported_modalities: list[str] = ["text"]
"""This class only supports text input."""
def __init__(
self,
api_key: str,
model_name: str,
dimensions: int = 1024,
embedding_cache: EmbeddingCacheBase | None = None,
) -> None:
"""Initialize the DashScope text embedding model class.
Args:
api_key (`str`):
The dashscope API key.
model_name (`str`):
The name of the embedding model.
dimensions (`int`, defaults to 1024):
The dimension of the embedding vector, refer to the
`official documentation
<https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2712515>`_
for more details.
embedding_cache (`EmbeddingCacheBase`):
The embedding cache class instance, used to cache the
embedding results to avoid repeated API calls.
"""
super().__init__(model_name, dimensions)
self.api_key = api_key
self.embedding_cache = embedding_cache
self.batch_size_limit = 10
async def _call_api(self, kwargs: dict[str, Any]) -> EmbeddingResponse:
"""Call the DashScope embedding API by the given keyword arguments."""
if self.embedding_cache:
cached_embeddings = await self.embedding_cache.retrieve(
identifier=kwargs,
)
if cached_embeddings:
return EmbeddingResponse(
embeddings=cached_embeddings,
usage=EmbeddingUsage(
tokens=0,
time=0,
),
source="cache",
)
import dashscope
start_time = datetime.now()
response = dashscope.embeddings.TextEmbedding.call(
api_key=self.api_key,
**kwargs,
)
time = (datetime.now() - start_time).total_seconds()
if response.status_code != 200:
raise RuntimeError(
f"Failed to get embedding from DashScope API: {response}",
)
if self.embedding_cache:
await self.embedding_cache.store(
identifier=kwargs,
embeddings=[
_["embedding"] for _ in response.output["embeddings"]
],
)
return EmbeddingResponse(
embeddings=[_["embedding"] for _ in response.output["embeddings"]],
usage=EmbeddingUsage(
tokens=response.usage["total_tokens"],
time=time,
),
)
async def __call__(
self,
text: List[str | TextBlock],
**kwargs: Any,
) -> EmbeddingResponse:
"""Call the DashScope embedding API.
Args:
text (`List[str | TextBlock]`):
The input text to be embedded. It can be a list of strings.
"""
gather_text = []
for _ in text:
if isinstance(_, dict) and "text" in _:
gather_text.append(_["text"])
elif isinstance(_, str):
gather_text.append(_)
else:
raise ValueError(
"Input text must be a list of strings or TextBlock dicts.",
)
if len(gather_text) > self.batch_size_limit:
logger.info(
"The input texts (%d) will be embedded with %d API calls due "
f"to the batch size limit of {self.batch_size_limit} for "
f"DashScope embedding API.",
len(gather_text),
(len(gather_text) + self.batch_size_limit - 1)
// self.batch_size_limit,
)
# Handle the batch size limit for DashScope embedding API
collected_embeddings = []
collected_time = 0.0
collected_tokens = 0
collected_source: Literal["cache", "api"] = "cache"
for _ in range(0, len(gather_text), self.batch_size_limit):
batch_texts = gather_text[_ : _ + self.batch_size_limit]
batch_kwargs = {
"input": batch_texts,
"model": self.model_name,
"dimension": self.dimensions,
**kwargs,
}
res = await self._call_api(batch_kwargs)
collected_embeddings.extend(res.embeddings)
collected_time += res.usage.time
if res.usage.tokens:
collected_tokens += res.usage.tokens
if res.source == "api":
collected_source = "api"
return EmbeddingResponse(
embeddings=collected_embeddings,
usage=EmbeddingUsage(
tokens=collected_tokens,
time=collected_time,
),
source=collected_source,
)

View File

@@ -0,0 +1,244 @@
# -*- coding: utf-8 -*-
"""The dashscope multimodal embedding model in agentscope."""
from datetime import datetime
from typing import Any, Literal
from ._cache_base import EmbeddingCacheBase
from ._embedding_response import EmbeddingResponse
from ._embedding_usage import EmbeddingUsage
from ._embedding_base import EmbeddingModelBase
from ..message import (
VideoBlock,
ImageBlock,
TextBlock,
)
class DashScopeMultiModalEmbedding(EmbeddingModelBase):
"""The DashScope multimodal embedding API, supporting text, image and
video embedding."""
supported_modalities: list[str] = ["text", "image", "video"]
"""This class supports text, image and video input."""
def __init__(
self,
api_key: str,
model_name: str,
dimensions: int | None = None,
embedding_cache: EmbeddingCacheBase | None = None,
) -> None:
"""Initialize the DashScope multimodal embedding model class.
Args:
api_key (`str`):
The dashscope API key.
model_name (`str`):
The name of the embedding model, e.g. "multimodal-embedding-
v1", "tongyi-embedding-vision-plus".
dimensions (`int`, defaults to 1024):
The dimension of the embedding vector, refer to the
`official documentation
<https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2712517>`_
for more details.
embedding_cache (`EmbeddingCacheBase`):
The embedding cache class instance, used to cache the
embedding results to avoid repeated API calls.
"""
path_doc = (
"https://bailian.console.aliyun.com/?tab=api#/api/?type=model&"
"url=2712517"
)
self.batch_size_limit = 1
if model_name.startswith("tongyi-embedding-vision-plus"):
self.batch_size_limit = 8
if dimensions is None:
dimensions = 1152
elif dimensions != 1152:
raise ValueError(
f"The dimension of model {model_name} must be 1152, "
"refer to the official documentation for more details: "
f"{path_doc}",
)
if model_name.startswith("tongyi-embedding-vision-flash"):
self.batch_size_limit = 8
if dimensions is None:
dimensions = 768
elif dimensions != 768:
raise ValueError(
f"The dimension of model {model_name} must be 768, "
"refer to the official documentation for more details: "
f"{path_doc}",
)
if model_name.startswith("multimodal-embedding-v"):
if dimensions is None:
dimensions = 1024
elif dimensions != 1024:
raise ValueError(
f"The dimension of model {model_name} must be 1024, "
"refer to the official documentation for more details: "
f"{path_doc}",
)
refined_dimensions: int = 1024
if dimensions is not None:
refined_dimensions = dimensions
super().__init__(model_name, refined_dimensions)
self.api_key = api_key
self.embedding_cache = embedding_cache
async def __call__(
self,
inputs: list[TextBlock | ImageBlock | VideoBlock],
**kwargs: Any,
) -> EmbeddingResponse:
"""Call the DashScope multimodal embedding API, which accepts text,
image, and video data.
Args:
inputs (`list[TextBlock | ImageBlock | VideoBlock]`):
The input data to be embedded. It can be a list of text,
image, and video blocks.
Returns:
`EmbeddingResponse`:
The embedding response object, which contains the embeddings
and usage information.
"""
# check data type
formatted_data = []
for _ in inputs:
if (
not isinstance(_, dict)
or "type" not in _
or _["type"]
not in [
"text",
"image",
"video",
]
):
raise ValueError(
f"Invalid data : {_}. It should be a list of "
"TextBlock, ImageBlock, or VideoBlock.",
)
if (
_["type"] == "video"
and _.get("source", {}).get("type") != "url"
):
raise ValueError(
f"The multimodal embedding API only supports URL input "
f"for video data, but got {_}.",
)
if _["type"] == "text":
assert "text" in _, (
f"Invalid text block: {_}. It should contain a "
f"'text' field.",
)
formatted_data.append({"text": _["text"]})
elif _["type"] == "video":
formatted_data.append({"video": _["source"]["url"]})
elif (
_["type"] == "image"
and "source" in _
and _["source"].get("type") in ["base64", "url"]
):
typ = _["source"]["type"]
if typ == "base64":
formatted_data.append(
{
"image": f'data:{_["source"]["media_type"]};'
f'base64,{_["source"]["data"]}',
},
)
elif typ == "url":
formatted_data.append(
{"image": _["source"]["url"]},
)
else:
raise ValueError(
f"Invalid block {_}. It should be a valid TextBlock, "
f"ImageBlock, or VideoBlock.",
)
# Handle the batch size limit of the DashScope multimodal embedding API
collected_embeddings = []
collected_time = 0.0
collected_tokens = 0
collected_source: Literal["cache", "api"] = "cache"
for _ in range(0, len(formatted_data), self.batch_size_limit):
batch_data = formatted_data[_ : _ + self.batch_size_limit]
batch_kwargs = {
"input": batch_data,
"model": self.model_name,
**kwargs,
}
res = await self._call_api(batch_kwargs)
collected_embeddings.extend(res.embeddings)
collected_time += res.usage.time
if res.usage.tokens:
collected_tokens += res.usage.tokens
if res.source == "api":
collected_source = "api"
return EmbeddingResponse(
embeddings=collected_embeddings,
usage=EmbeddingUsage(
tokens=collected_tokens,
time=collected_time,
),
source=collected_source,
)
async def _call_api(self, kwargs: dict[str, Any]) -> EmbeddingResponse:
"""
Call the DashScope multimodal embedding API by the given arguments.
"""
# Search in cache first
if self.embedding_cache:
cached_embeddings = await self.embedding_cache.retrieve(
identifier=kwargs,
)
if cached_embeddings:
return EmbeddingResponse(
embeddings=cached_embeddings,
usage=EmbeddingUsage(
tokens=0,
time=0,
),
source="cache",
)
import dashscope
kwargs["api_key"] = self.api_key
start_time = datetime.now()
res = dashscope.MultiModalEmbedding.call(**kwargs)
time = (datetime.now() - start_time).total_seconds()
if res.status_code != 200:
raise RuntimeError(
f"Failed to get embedding from DashScope API: {res}",
)
return EmbeddingResponse(
embeddings=[_["embedding"] for _ in res.output["embeddings"]],
usage=EmbeddingUsage(
tokens=res.usage.get(
"image_tokens",
0,
)
+ res.usage.get(
"input_tokens",
0,
),
time=time,
),
source="api",
)

View File

@@ -0,0 +1,45 @@
# -*- coding: utf-8 -*-
"""The embedding model base class."""
from typing import Any
from ._embedding_response import EmbeddingResponse
class EmbeddingModelBase:
"""Base class for embedding models."""
model_name: str
"""The embedding model name"""
supported_modalities: list[str]
"""The supported data modalities, e.g. "text", "image", "video"."""
dimensions: int
"""The dimensions of the embedding vector."""
def __init__(
self,
model_name: str,
dimensions: int,
) -> None:
"""Initialize the embedding model base class.
Args:
model_name (`str`):
The name of the embedding model.
dimensions (`int`):
The dimension of the embedding vector.
"""
self.model_name = model_name
self.dimensions = dimensions
async def __call__(
self,
*args: Any,
**kwargs: Any,
) -> EmbeddingResponse:
"""Call the embedding API with the given arguments."""
raise NotImplementedError(
f"The {self.__class__.__name__} class does not implement "
f"the __call__ method.",
)

View File

@@ -0,0 +1,32 @@
# -*- coding: utf-8 -*-
"""The embedding response class."""
from dataclasses import dataclass, field
from typing import Literal, List
from ._embedding_usage import EmbeddingUsage
from .._utils._common import _get_timestamp
from .._utils._mixin import DictMixin
from ..types import Embedding
@dataclass
class EmbeddingResponse(DictMixin):
"""The embedding response class."""
embeddings: List[Embedding]
"""The embedding data"""
id: str = field(default_factory=lambda: _get_timestamp(True))
"""The identity of the embedding response"""
created_at: str = field(default_factory=_get_timestamp)
"""The timestamp of the embedding response creation"""
type: Literal["embedding"] = field(default_factory=lambda: "embedding")
"""The type of the response, must be `embedding`."""
usage: EmbeddingUsage | None = field(default_factory=lambda: None)
"""The usage of the embedding model API invocation, if available."""
source: Literal["cache", "api"] = field(default_factory=lambda: "api")
"""If the response comes from the cache or the API."""

View File

@@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-
"""The embedding usage class in agentscope."""
from dataclasses import dataclass, field
from typing import Literal
from .._utils._mixin import DictMixin
@dataclass
class EmbeddingUsage(DictMixin):
"""The usage of an embedding model API invocation."""
time: float
"""The time used in seconds."""
tokens: int | None = field(default_factory=lambda: None)
"""The number of tokens used, if available."""
type: Literal["embedding"] = field(default_factory=lambda: "embedding")
"""The type of the usage, must be `embedding`."""

View File

@@ -0,0 +1,187 @@
# -*- coding: utf-8 -*-
"""A file embedding cache implementation for storing and retrieving
embeddings in binary files."""
import hashlib
import json
import os
from typing import Any, List
import numpy as np
from ._cache_base import EmbeddingCacheBase
from .._logging import logger
from ..types import (
Embedding,
JSONSerializableObject,
)
class FileEmbeddingCache(EmbeddingCacheBase):
"""The embedding cache class that stores each embeddings vector in
binary files."""
def __init__(
self,
cache_dir: str = "./.cache/embeddings",
max_file_number: int | None = None,
max_cache_size: int | None = None,
) -> None:
"""Initialize the file embedding cache class.
Args:
cache_dir (`str`, defaults to `"./.cache/embeddings"`):
The directory to store the embedding files.
max_file_number (`int | None`, defaults to `None`):
The maximum number of files to keep in the cache directory. If
exceeded, the oldest files will be removed.
max_cache_size (`int | None`, defaults to `None`):
The maximum size of the cache directory in MB. If exceeded,
the oldest files will be removed until the size is within the
limit.
"""
self._cache_dir = os.path.abspath(cache_dir)
self.max_file_number = max_file_number
self.max_cache_size = max_cache_size
@property
def cache_dir(self) -> str:
"""The cache directory where the embedding files are stored."""
if not os.path.exists(self._cache_dir):
os.makedirs(self._cache_dir, exist_ok=True)
return self._cache_dir
async def store(
self,
embeddings: List[Embedding],
identifier: JSONSerializableObject,
overwrite: bool = False,
**kwargs: Any,
) -> None:
"""Store the embeddings with the given identifier.
Args:
embeddings (`List[Embedding]`):
The embeddings to store.
identifier (`JSONSerializableObject`):
The identifier to distinguish the embeddings, which will be
used to generate a hashable filename, so it should be
JSON serializable (e.g. a string, number, list, dict).
overwrite (`bool`, defaults to `False`):
Whether to overwrite existing embeddings with the same
identifier. If `True`, existing embeddings will be replaced.
"""
filename = self._get_filename(identifier)
path_file = os.path.join(self.cache_dir, filename)
if os.path.exists(path_file):
if not os.path.isfile(path_file):
raise RuntimeError(
f"Path {path_file} exists but is not a file.",
)
if overwrite:
np.save(path_file, embeddings)
await self._maintain_cache_dir()
else:
np.save(path_file, embeddings)
await self._maintain_cache_dir()
async def retrieve(
self,
identifier: JSONSerializableObject,
) -> List[Embedding] | None:
"""Retrieve the embeddings with the given identifier. If not found,
return `None`.
Args:
identifier (`JSONSerializableObject`):
The identifier to retrieve the embeddings, which will be
used to generate a hashable filename, so it should be
JSON serializable (e.g. a string, number, list, dict).
"""
filename = self._get_filename(identifier)
path_file = os.path.join(self.cache_dir, filename)
if os.path.exists(path_file):
return np.load(os.path.join(self.cache_dir, filename)).tolist()
return None
async def remove(self, identifier: JSONSerializableObject) -> None:
"""Remove the embeddings with the given identifier.
Args:
identifier (`JSONSerializableObject`):
The identifiers to remove the embeddings, which will be
used to generate a hashable filename, so it should be
JSON serializable (e.g. a string, number, list, dict).
"""
filename = self._get_filename(identifier)
path_file = os.path.join(self.cache_dir, filename)
if os.path.exists(path_file):
os.remove(path_file)
else:
raise FileNotFoundError(f"File {path_file} does not exist.")
async def clear(self) -> None:
"""Clear the cache directory by removing all files."""
for filename in os.listdir(self.cache_dir):
if filename.endswith(".npy"):
os.remove(os.path.join(self.cache_dir, filename))
def _get_cache_size(self) -> float:
"""Get the current size of the cache directory in MB."""
total_size = 0
for filename in os.listdir(self.cache_dir):
if filename.endswith(".npy"):
path_file = os.path.join(self.cache_dir, filename)
if os.path.isfile(path_file):
total_size += os.path.getsize(path_file)
return total_size / (1024.0 * 1024.0)
@staticmethod
def _get_filename(identifier: JSONSerializableObject) -> str:
"""Generate a filename based on the identifier."""
json_str = json.dumps(identifier, ensure_ascii=False)
return hashlib.sha256(json_str.encode("utf-8")).hexdigest() + ".npy"
async def _maintain_cache_dir(self) -> None:
"""Maintain the cache directory by removing old files if the number of
files exceeds the maximum limit or if the cache size exceeds the
maximum size."""
files = [
(_.name, _.stat().st_mtime)
for _ in os.scandir(self.cache_dir)
if _.is_file() and _.name.endswith(".npy")
]
files.sort(key=lambda x: x[1])
if self.max_file_number and len(files) > self.max_file_number:
for file_name, _ in files[: 0 - self.max_file_number]:
os.remove(os.path.join(self.cache_dir, file_name))
logger.info(
"Remove cached embedding file %s for limited number "
"of files (%d).",
file_name,
self.max_file_number,
)
files = files[0 - self.max_file_number :]
if (
self.max_cache_size is not None
and self._get_cache_size() > self.max_cache_size
):
removed_files = []
for filename, _ in files:
os.remove(os.path.join(self.cache_dir, filename))
removed_files.append(filename)
if self._get_cache_size() <= self.max_cache_size:
break
if removed_files:
logger.info(
"Remove %d cached embedding file(s) for limited "
"cache size (%d MB).",
len(removed_files),
self.max_cache_size,
)

View File

@@ -0,0 +1,109 @@
# -*- coding: utf-8 -*-
"""The gemini text embedding model class."""
from datetime import datetime
from typing import Any, List
from ._embedding_response import EmbeddingResponse
from ._embedding_usage import EmbeddingUsage
from ._cache_base import EmbeddingCacheBase
from ._embedding_base import EmbeddingModelBase
from ..message import TextBlock
class GeminiTextEmbedding(EmbeddingModelBase):
"""The Gemini text embedding model."""
supported_modalities: list[str] = ["text"]
"""This class only supports text input."""
def __init__(
self,
api_key: str,
model_name: str,
dimensions: int = 3072,
embedding_cache: EmbeddingCacheBase | None = None,
**kwargs: Any,
) -> None:
"""Initialize the Gemini text embedding model class.
Args:
api_key (`str`):
The Gemini API key.
model_name (`str`):
The name of the embedding model.
dimensions (`int`, defaults to 3072):
The dimension of the embedding vector, refer to the
`official documentation
<https://ai.google.dev/gemini-api/docs/embeddings?hl=zh-cn#control-embedding-size>`_
for more details.
embedding_cache (`EmbeddingCacheBase | None`, defaults to `None`):
The embedding cache class instance, used to cache the
embedding results to avoid repeated API calls.
"""
from google import genai
super().__init__(model_name, dimensions)
self.client = genai.Client(api_key=api_key, **kwargs)
self.embedding_cache = embedding_cache
async def __call__(
self,
text: List[str | TextBlock],
**kwargs: Any,
) -> EmbeddingResponse:
"""The Gemini embedding API call.
Args:
text (`List[str | TextBlock]`):
The input text to be embedded. It can be a list of strings.
# TODO: handle the batch size limit
"""
gather_text = []
for _ in text:
if isinstance(_, dict) and "text" in _:
gather_text.append(_["text"])
elif isinstance(_, str):
gather_text.append(_)
else:
raise ValueError(
"Input text must be a list of strings or TextBlock dicts.",
)
kwargs = {
"model": self.model_name,
"contents": gather_text,
"config": kwargs,
}
if self.embedding_cache:
cached_embeddings = await self.embedding_cache.retrieve(
identifier=kwargs,
)
if cached_embeddings:
return EmbeddingResponse(
embeddings=cached_embeddings,
usage=EmbeddingUsage(
tokens=0,
time=0,
),
source="cache",
)
start_time = datetime.now()
response = self.client.models.embed_content(**kwargs)
time = (datetime.now() - start_time).total_seconds()
if self.embedding_cache:
await self.embedding_cache.store(
identifier=kwargs,
embeddings=[_.values for _ in response.embeddings],
)
return EmbeddingResponse(
embeddings=[_.values for _ in response.embeddings],
usage=EmbeddingUsage(
time=time,
),
)

View File

@@ -0,0 +1,106 @@
# -*- coding: utf-8 -*-
"""The ollama text embedding model class."""
from datetime import datetime
from typing import List, Any
from ._embedding_response import EmbeddingResponse
from ._embedding_usage import EmbeddingUsage
from ._cache_base import EmbeddingCacheBase
from ..embedding import EmbeddingModelBase
from ..message import TextBlock
class OllamaTextEmbedding(EmbeddingModelBase):
"""The Ollama embedding model."""
supported_modalities: list[str] = ["text"]
"""This class only supports text input."""
def __init__(
self,
model_name: str,
dimensions: int,
host: str | None = None,
embedding_cache: EmbeddingCacheBase | None = None,
**kwargs: Any,
) -> None:
"""Initialize the Ollama text embedding model class.
Args:
model_name (`str`):
The name of the embedding model.
dimensions (`int`):
The dimension of the embedding vector, the parameter should be
provided according to the model used.
host (`str | None`, defaults to `None`):
The host URL for the Ollama API.
embedding_cache (`EmbeddingCacheBase | None`, defaults to `None`):
The embedding cache class instance, used to cache the
embedding results to avoid repeated API calls.
"""
import ollama
super().__init__(model_name, dimensions)
self.client = ollama.AsyncClient(host=host, **kwargs)
self.embedding_cache = embedding_cache
async def __call__(
self,
text: List[str | TextBlock],
**kwargs: Any,
) -> EmbeddingResponse:
"""Call the Ollama embedding API.
Args:
text (`List[str | TextBlock]`):
The input text to be embedded. It can be a list of strings.
"""
gather_text = []
for _ in text:
if isinstance(_, dict) and "text" in _:
gather_text.append(_["text"])
elif isinstance(_, str):
gather_text.append(_)
else:
raise ValueError(
"Input text must be a list of strings or TextBlock dicts.",
)
kwargs = {
"input": gather_text,
"model": self.model_name,
"dimensions": self.dimensions,
**kwargs,
}
if self.embedding_cache:
cached_embeddings = await self.embedding_cache.retrieve(
identifier=kwargs,
)
if cached_embeddings:
return EmbeddingResponse(
embeddings=cached_embeddings,
usage=EmbeddingUsage(
tokens=0,
time=0,
),
source="cache",
)
start_time = datetime.now()
response = await self.client.embed(**kwargs)
time = (datetime.now() - start_time).total_seconds()
if self.embedding_cache:
await self.embedding_cache.store(
identifier=kwargs,
embeddings=response.embeddings,
)
return EmbeddingResponse(
embeddings=response.embeddings,
usage=EmbeddingUsage(
time=time,
),
)

View File

@@ -0,0 +1,109 @@
# -*- coding: utf-8 -*-
"""The OpenAI text embedding model class."""
from datetime import datetime
from typing import Any, List
from ._embedding_response import EmbeddingResponse
from ._embedding_usage import EmbeddingUsage
from ._cache_base import EmbeddingCacheBase
from ._embedding_base import EmbeddingModelBase
from ..message import TextBlock
class OpenAITextEmbedding(EmbeddingModelBase):
"""OpenAI text embedding model class."""
supported_modalities: list[str] = ["text"]
"""This class only supports text input."""
def __init__(
self,
api_key: str,
model_name: str,
dimensions: int = 1024,
embedding_cache: EmbeddingCacheBase | None = None,
**kwargs: Any,
) -> None:
"""Initialize the OpenAI text embedding model class.
Args:
api_key (`str`):
The OpenAI API key.
model_name (`str`):
The name of the embedding model.
dimensions (`int`, defaults to 1024):
The dimension of the embedding vector.
embedding_cache (`EmbeddingCacheBase | None`, defaults to `None`):
The embedding cache class instance, used to cache the
embedding results to avoid repeated API calls.
# TODO: handle batch size limit and token limit
"""
import openai
super().__init__(model_name, dimensions)
self.client = openai.AsyncClient(api_key=api_key, **kwargs)
self.embedding_cache = embedding_cache
async def __call__(
self,
text: List[str | TextBlock],
**kwargs: Any,
) -> EmbeddingResponse:
"""Call the OpenAI embedding API.
Args:
text (`List[str | TextBlock]`):
The input text to be embedded. It can be a list of strings.
"""
gather_text = []
for _ in text:
if isinstance(_, dict) and "text" in _:
gather_text.append(_["text"])
elif isinstance(_, str):
gather_text.append(_)
else:
raise ValueError(
"Input text must be a list of strings or TextBlock dicts.",
)
kwargs = {
"input": gather_text,
"model": self.model_name,
"dimensions": self.dimensions,
"encoding_format": "float",
**kwargs,
}
if self.embedding_cache:
cached_embeddings = await self.embedding_cache.retrieve(
identifier=kwargs,
)
if cached_embeddings:
return EmbeddingResponse(
embeddings=cached_embeddings,
usage=EmbeddingUsage(
tokens=0,
time=0,
),
source="cache",
)
start_time = datetime.now()
response = await self.client.embeddings.create(**kwargs)
time = (datetime.now() - start_time).total_seconds()
if self.embedding_cache:
await self.embedding_cache.store(
identifier=kwargs,
embeddings=[_.embedding for _ in response.data],
)
return EmbeddingResponse(
embeddings=[_.embedding for _ in response.data],
usage=EmbeddingUsage(
tokens=response.usage.total_tokens,
time=time,
),
)

View File

@@ -0,0 +1,44 @@
# -*- coding: utf-8 -*-
"""The evaluation module in AgentScope."""
from ._evaluator import (
EvaluatorBase,
RayEvaluator,
GeneralEvaluator,
)
from ._metric_base import (
MetricBase,
MetricResult,
MetricType,
)
from ._task import Task
from ._solution import SolutionOutput
from ._benchmark_base import BenchmarkBase
from ._evaluator_storage import (
EvaluatorStorageBase,
FileEvaluatorStorage,
)
from ._ace_benchmark import (
ACEBenchmark,
ACEAccuracy,
ACEProcessAccuracy,
ACEPhone,
)
__all__ = [
"BenchmarkBase",
"EvaluatorBase",
"RayEvaluator",
"GeneralEvaluator",
"MetricBase",
"MetricResult",
"MetricType",
"EvaluatorStorageBase",
"FileEvaluatorStorage",
"Task",
"SolutionOutput",
"ACEBenchmark",
"ACEAccuracy",
"ACEProcessAccuracy",
"ACEPhone",
]

View File

@@ -0,0 +1,16 @@
# -*- coding: utf-8 -*-
"""The ACE benchmark related implementations in AgentScope."""
from ._ace_benchmark import ACEBenchmark
from ._ace_metric import (
ACEAccuracy,
ACEProcessAccuracy,
)
from ._ace_tools_zh import ACEPhone
__all__ = [
"ACEBenchmark",
"ACEPhone",
"ACEAccuracy",
"ACEProcessAccuracy",
]

View File

@@ -0,0 +1,240 @@
# -*- coding: utf-8 -*-
"""The ACE benchmark class in agentscope. The code is implemented with
reference to the `ACEBench <https://github.com/ACEBench/ACEBench>`_
under the MIT license."""
import json
import os
from typing import Generator
import json5
import requests
from tqdm import tqdm
from ._ace_metric import ACEAccuracy, ACEProcessAccuracy
from ._ace_tools_zh import ACEPhone
from .._benchmark_base import BenchmarkBase
from .._task import Task
class ACEBenchmark(BenchmarkBase):
"""The ACE benchmark for evaluating AI agents."""
data_dir_url: str = (
"https://raw.githubusercontent.com/ACEBench/ACEBench/main/data_all"
)
"""The URL to the data dir"""
data_subdir: list[str] = [
# "data_en", # TODO: enable English version
"data_zh",
]
ground_truth_dir: str = "possible_answer"
data_files: list[str] = [
"data_agent_multi_step.json",
"data_agent_multi_turn.json",
# "data_normal_atom_bool.json",
# "data_normal_atom_enum.json",
# "data_normal_atom_list.json",
# "data_normal_atom_number.json",
# "data_normal_atom_object_deep.json",
# "data_normal_atom_object_short.json",
#
# "data_normal_multi_turn_user_adjust.json",
# "data_normal_multi_turn_user_switch.json",
#
# "data_normal_preference.json",
# "data_normal_similar_api.json",
# "data_normal_single_turn_parallel_function.json",
# "data_normal_single_turn_single_function.json",
#
# "data_special_error_param.json",
# "data_special_incomplete.json",
# "data_special_irrelevant.json",
]
"""The data filenames"""
def __init__(
self,
data_dir: str,
) -> None:
"""Initialize the ACEBenchmark
Args:
data_dir (`str`):
The directory where the dataset is downloaded and saved.
"""
super().__init__(
name="ACEBench",
description="The ACE benchmark for evaluating AI agents.",
)
self.data_dir = os.path.abspath(data_dir)
if os.path.exists(data_dir) and not os.path.isdir(data_dir):
raise RuntimeError(
f"The data_dir `{data_dir}` is not a valid directory path.",
)
os.makedirs(data_dir, exist_ok=True)
if not self._verify_data():
self._download_data()
self.dataset = self._load_data()
def _load_data(self) -> list[dict]:
"""Load the dataset from the data directory."""
dataset = []
for subdir in self.data_subdir:
for filename in self.data_files:
file_path = os.path.join(self.data_dir, subdir, filename)
gt_path = os.path.join(
self.data_dir,
subdir,
self.ground_truth_dir,
filename,
)
gt_dataset = {}
with open(gt_path, "r", encoding="utf-8") as gt_file:
for line in gt_file:
gt_data = json5.loads(line)
gt_dataset[gt_data["id"]] = gt_data
with open(file_path, "r", encoding="utf-8") as f:
for line in f:
data = json5.loads(line)
gt = gt_dataset[data["id"]]
gt.pop("id", None)
data["ground_truth"] = gt["ground_truth"]
data["mile_stone"] = gt["mile_stone"]
data["language"] = subdir.rsplit(
"_",
maxsplit=1,
)[-1]
data["tags"] = {
"language": data["language"],
"category": filename.split(
".",
maxsplit=1,
)[0].removeprefix(
"data_",
),
}
dataset.append(data)
return dataset
def _verify_data(self) -> bool:
"""Verify the data completeness and integrity."""
for subdir in self.data_subdir:
for filename in self.data_files:
file_path = os.path.join(self.data_dir, subdir, filename)
if not os.path.exists(file_path):
return False
gt_path = os.path.join(
self.data_dir,
subdir,
self.ground_truth_dir,
filename,
)
if not os.path.exists(gt_path):
return False
return True
def _download_data(self) -> None:
"""Download the data from the URL"""
for subdir in self.data_subdir:
subdir_path = os.path.join(self.data_dir, subdir)
subdir_gt_path = os.path.join(subdir_path, self.ground_truth_dir)
os.makedirs(subdir_path, exist_ok=True)
os.makedirs(subdir_gt_path, exist_ok=True)
for filename in tqdm(
self.data_files,
desc=f"Downloading {subdir}",
):
response = requests.get(
f"{self.data_dir_url}/{subdir}/{filename}",
)
response.raise_for_status()
with open(os.path.join(subdir_path, filename), "wb") as f:
f.write(response.content)
gt_response = requests.get(
f"{self.data_dir_url}/{subdir}/"
f"{self.ground_truth_dir}/{filename}",
)
gt_response.raise_for_status()
with open(os.path.join(subdir_gt_path, filename), "wb") as f:
f.write(gt_response.content)
@staticmethod
def _data_to_task(item: dict) -> Task:
"""Convert a dataset item to a Task object."""
# Start the simulated phone and load initial configuration
ace_phone = ACEPhone()
ace_phone.load_initial_config(item["initial_config"])
# Obtain tool functions
tools: list[tuple] = []
for function_schema in item["function"]:
name = function_schema["name"]
# Handle the schema differences
formatted_schema = json.loads(
json.dumps(
function_schema,
).replace(
'"type": "dict"',
'"type": "object"',
),
)
tool_function = ace_phone.get_tool_function(name)
tools.append(
(
tool_function,
{
"type": "function",
"function": formatted_schema,
},
),
)
return Task(
id=item["id"],
input=item["question"],
ground_truth={
"state": item["ground_truth"],
"mile_stone": item.get("mile_stone", []),
},
tags=item.get("tags", {}),
metrics=[
ACEAccuracy(item["ground_truth"]),
ACEProcessAccuracy(item["mile_stone"]),
],
metadata={
# The phone is used to extract the final state after finishing
# the task.
"phone": ace_phone,
# The provided tools for this task, used to equip the agent
"tools": tools,
},
)
def __iter__(self) -> Generator[Task, None, None]:
"""Iterate over the benchmark."""
for item in self.dataset:
yield self._data_to_task(item)
def __getitem__(self, index: int) -> Task:
"""Get a task by index."""
return self._data_to_task(self.dataset[index])
def __len__(self) -> int:
"""Get the length of the benchmark."""
return len(self.dataset)

View File

@@ -0,0 +1,131 @@
# -*- coding: utf-8 -*-
"""The ACE benchmark metric implementations in AgentScope."""
from .._solution import SolutionOutput
from .._metric_base import MetricBase, MetricResult, MetricType
class ACEProcessAccuracy(MetricBase):
"""The ace benchmark process accuracy metric."""
def __init__(
self,
mile_stone: list[str],
) -> None:
"""Initialize the AceBench process accuracy metric."""
super().__init__(
name="process_accuracy",
metric_type=MetricType.NUMERICAL,
description="The AceBench Agent eval process accuracy metric.",
)
self.mile_stone = mile_stone
async def __call__(
self,
solution: SolutionOutput,
) -> MetricResult:
"""Calculate the metric result."""
# Turn the tool use block sequence into ACEBench format
# e.g. func(arg1='dfd', arg2=44)
gathered_trajectory = []
for tool_call in solution.trajectory:
if tool_call.get("type") == "tool_use":
function_name = tool_call.get("name")
kwargs = tool_call.get("input")
gathered_kwargs = []
for key, value in kwargs.items():
if isinstance(value, str):
gathered_kwargs.append(
f"{key}='{value}'",
)
else:
gathered_kwargs.append(
f"{key}={value}",
)
kwargs_str = ", ".join(gathered_kwargs)
gathered_trajectory.append(
f"[{function_name}({kwargs_str})]",
)
for stone in self.mile_stone:
if stone not in gathered_trajectory:
return MetricResult(
name=self.name,
result=0,
message=f"Error: Missing milestone '{stone}' in "
"the given trajectory.",
)
return MetricResult(
name=self.name,
result=1,
message="Success",
)
class ACEAccuracy(MetricBase):
"""The ace benchmark metric"""
def __init__(
self,
state: list[dict],
) -> None:
"""Initialize the _metric object."""
super().__init__(
"accuracy",
MetricType.NUMERICAL,
"The AceBench Agent eval accuracy metric.",
)
self.state = state
async def __call__(
self,
solution: SolutionOutput,
) -> MetricResult:
"""Calculate the metric result."""
# Check if the solution matches the ground truth
if not isinstance(solution.output, list):
raise ValueError("Ground truth state must be a list.")
# Handle the typos in ACEBench dataset
gathered_state = {}
for item in self.state:
for key, value in item.items():
if key.endswith("API"):
key = key.replace("API", "Api")
elif key.endswith("rpi"):
key = key.replace("pi", "Api")
gathered_state[key] = value
gathered_output = {}
for item in solution.output:
for key, value in item.items():
gathered_output[key] = value
if not set(gathered_state.keys()).issubset(gathered_output.keys()):
raise ValueError(
"Missing keys in solution output compared to state, "
f"ground truth keys: {gathered_state.keys()}, "
f"solution keys: {gathered_output.keys()}",
)
for key, value in gathered_state.items():
if value != gathered_output.get(key):
return MetricResult(
name=self.name,
result=0,
message=(
f"Error: Mismatch in key '{key}':"
f"\n{value}\n{gathered_output.get(key)}"
),
)
return MetricResult(
name=self.name,
result=1,
message="Success: All keys match",
)

View File

@@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
"""The ACEBench simulation tools in AgentScope."""
from ._message_api import MessageApi
from ._travel_api import TravelApi
from ._reminder_api import ReminderApi
from ._food_platform_api import FoodPlatformApi
__all__ = [
"MessageApi",
"TravelApi",
"ReminderApi",
"FoodPlatformApi",
]

View File

@@ -0,0 +1,302 @@
# -*- coding: utf-8 -*-
"""The food platform API in the ACEBench evaluation."""
from ._shared_state import SharedState
class FoodPlatformApi(SharedState):
"""The food platform Api in the ACEBench evaluation."""
tool_functions: list[str] = [
"login_food_platform",
"view_logged_in_users",
"check_balance",
"add_food_delivery_order",
"get_products",
"view_orders",
"search_orders",
]
def __init__(self, shared_state: dict) -> None:
super().__init__(shared_state)
# 设置用户和初始金额
self.users: dict = {
"Eve": {
"user_id": "U100",
"password": "password123",
"balance": 500.0,
},
"Frank": {
"user_id": "U101",
"password": "password456",
"balance": 300.0,
},
"Grace": {
"user_id": "U102",
"password": "password789",
"balance": 150.0,
},
"Helen": {
"user_id": "U103",
"password": "password321",
"balance": 800.0,
},
"Isaac": {
"user_id": "U104",
"password": "password654",
"balance": 400.0,
},
"Jack": {
"user_id": "U105",
"password": "password654",
"balance": 120.0,
},
}
# 设置六个商家及其菜单
self.merchant_list: dict[str, dict] = {
"达美乐": {
"merchant_id": "M100",
"service_type": "Pizza",
"menu": [
{"product": "玛格丽特披萨", "price": 68.0},
{"product": "超级至尊披萨", "price": 88.0},
],
},
"米村拌饭": {
"merchant_id": "M101",
"service_type": "Bibimbap",
"menu": [
{"product": "石锅拌饭", "price": 35.0},
{"product": "韩式牛肉拌饭", "price": 45.0},
],
},
"海底捞": {
"merchant_id": "M102",
"service_type": "Hotpot",
"menu": [
{"product": "牛肉卷", "price": 68.0},
{"product": "海鲜拼盘", "price": 88.0},
],
},
"喜茶": {
"merchant_id": "M103",
"service_type": "Milk Tea",
"menu": [
{"product": "芝士奶茶", "price": 25.0},
{"product": "四季春奶茶", "price": 22.0},
],
},
"盒马生鲜": {
"merchant_id": "M104",
"service_type": "Fresh Grocery",
"menu": [
{"product": "有机蔬菜包", "price": 15.0},
{"product": "生鲜大礼包", "price": 99.0},
],
},
"九田家烤肉": {
"merchant_id": "M105",
"service_type": "BBQ",
"menu": [
{"product": "韩式烤牛肉", "price": 128.0},
{"product": "烤五花肉", "price": 78.0},
],
},
}
# 设置已登录用户列表
self.logged_in_users: list[str] = []
# 订单列表
self.orders: list = []
def get_state_dict(self) -> dict:
"""Get the current state dict of the FoodPlatformApi."""
return {
"FoodPlatform": {
"logged_in_users": self.logged_in_users,
"orders": self.orders,
"users": self.users,
},
}
def login_food_platform(
self,
username: str,
password: str,
) -> dict[str, bool | str]:
"""使用用户名和密码登录外卖平台。
Args:
username (`str`):
用户的用户名。
password (`str`):
用户的密码。
"""
if not self.wifi:
return {"status": False, "message": "wifi未打开无法登录"}
if username not in self.users:
return {"status": False, "message": "用户不存在"}
if self.users[username]["password"] != password:
return {"status": False, "message": "密码错误"}
# 检查是否已经有用户登录
if username in self.logged_in_users:
return {"status": False, "message": f"{username} 已经登录"}
# 记录已登录用户
self.logged_in_users.append(username)
return {"status": True, "message": f"用户{username}登陆成功!"}
def view_logged_in_users(self) -> dict:
"""查看当前所有登录的用户。"""
if not self.logged_in_users:
return {
"status": False,
"message": "当前没有登录food platform",
}
return {"status": True, "logged_in_users": self.logged_in_users}
def check_balance(self, user_name: str) -> float:
"""查询指定用户的余额。
Args:
user_name (`str`):
用户的用户名。
"""
if user_name in self.users:
return self.users[user_name]["balance"]
else:
return 0.0
def add_food_delivery_order(
self,
username: str,
merchant_name: str,
items: list[dict[str, str | int]],
) -> dict[str, bool | str]:
"""订外卖
Args:
username (`str`):
下订单的用户姓名。
merchant_name (`str`):
下订单的商家名称。
items (`list[dict[str, str | int]]`):
订单中商品的列表,每个商品包含名称和数量。
"""
if username not in self.logged_in_users:
return {
"status": False,
"message": f"用户 {username} 未登录food platform",
}
if merchant_name not in self.merchant_list:
return {"status": False, "message": "商家不存在"}
total_price = 0.0
order_items = []
for item in items:
product_name = item.get("product")
quantity = item.get("quantity", 1)
if not isinstance(quantity, int) or quantity <= 0:
return {
"status": False,
"message": f"无效的数量 {quantity} 对于商品 {product_name}",
}
# 查找商品价格
product_found = False
for product in self.merchant_list[merchant_name]["menu"]:
if product["product"] == product_name:
total_price += product["price"] * quantity
order_items.append(
{
"product": product_name,
"quantity": quantity,
"price_per_unit": product["price"],
},
)
product_found = True
break
if not product_found:
return {
"status": False,
"message": f"商品 {product_name} 不存在于 "
f"{merchant_name} 的菜单中",
}
# 检查余额是否足够
if total_price >= self.users[username]["balance"]:
return {"status": False, "message": "余额不足,无法下单"}
# 扣除余额并创建订单
self.users[username]["balance"] -= total_price
order = {
"user_name": username,
"merchant_name": merchant_name,
"items": order_items,
"total_price": total_price,
}
self.orders.append(order)
return {
"status": True,
"message": f"外卖订单成功下单给 {merchant_name}" f"总金额为 {total_price}",
}
def get_products(
self,
merchant_name: str,
) -> list[dict[str, str | float]] | dict[str, bool | str]:
"""获取特定商家的商品列表。
Args:
merchant_name (`str`):
要获取商品的商家名称。
"""
merchant = self.merchant_list.get(merchant_name)
if merchant:
return merchant["menu"]
else:
return {
"status": False,
"message": f"商家 '{merchant_name}' 不存在",
}
def view_orders(
self,
user_name: str,
) -> dict[str, bool | str | list[dict[str, str | int | float]]]:
"""查看用户的所有订单"""
user_orders = [
order for order in self.orders if order["user_name"] == user_name
]
if not user_orders:
return {"status": False, "message": "用户没有订单记录"}
return {"status": True, "orders": user_orders}
def search_orders(
self,
keyword: str,
) -> dict[str, bool | str | list[dict[str, str | float]]]:
"""根据关键字搜索订单。"""
matched_orders = [
order
for order in self.orders
if keyword.lower() in order["merchant_name"].lower()
or any(
keyword.lower() in item.lower()
for item in order.get("items", [])
)
]
if not matched_orders:
return {"status": False, "message": "没有找到匹配的订单"}
return {"status": True, "orders": matched_orders}

View File

@@ -0,0 +1,340 @@
# -*- coding: utf-8 -*-
"""The Message API in the ACEBench evaluation."""
from datetime import datetime
from ._shared_state import SharedState
class MessageApi(SharedState):
"""The message Api in the ACEBench evaluation."""
tool_functions: list[str] = [
"send_message",
"delete_message",
"view_messages_between_users",
"search_messages",
"get_all_message_times_with_ids",
"get_latest_message_id",
"get_earliest_message_id",
]
def __init__(self, share_state: dict) -> None:
"""Initialize the MessageApi with shared state."""
super().__init__(share_state)
# 设置六个用户
self.max_capacity = 6
self.user_list: dict[str, dict[str, str | int]] = {
"Eve": {
"user_id": "USR100",
"phone_number": "123-456-7890",
"occupation": "Software Engineer",
},
"Frank": {
"user_id": "USR101",
"phone_number": "234-567-8901",
"occupation": "Data Scientist",
},
"Grace": {
"user_id": "USR102",
"phone_number": "345-678-9012",
"occupation": "Product Manager",
},
"Helen": {
"user_id": "USR103",
"phone_number": "456-789-0123",
"occupation": "UX Designer",
},
"Isaac": {
"user_id": "USR104",
"phone_number": "567-890-1234",
"occupation": "DevOps Engineer",
},
"Jack": {
"user_id": "USR105",
"phone_number": "678-901-2345",
"occupation": "Marketing Specialist",
},
}
# 设置六个用户之间的短信记录
# 信息1和reminder配合 信息2和food配合
self.inbox: dict[int, dict[str, str | int]] = {
1: {
"sender_id": "USR100",
"receiver_id": "USR101",
"message": "Hey Frank, don't forget about our meeting on "
"2024-06-11 at 4 PM in Conference Room 1.",
"time": "2024-06-09",
},
2: {
"sender_id": "USR101",
"receiver_id": "USR102",
"message": """你能帮我点一个\"玛格丽特披萨\"的外卖吗,商家是达美乐。""",
"time": "2024-03-09",
},
3: {
"sender_id": "USR102",
"receiver_id": "USR103",
"message": "帮我查一些喜茶有哪些奶茶外卖,买一杯便宜些的奶茶。"
"买完以后记得回复我,回复的内容是(已经买好了)",
"time": "2023-12-05",
},
4: {
"sender_id": "USR103",
"receiver_id": "USR102",
"message": "No problem Helen, I can assist you.",
"time": "2024-09-09",
},
5: {
"sender_id": "USR104",
"receiver_id": "USR105",
"message": "Isaac, are you available for a call?",
"time": "2024-06-06",
},
6: {
"sender_id": "USR105",
"receiver_id": "USR104",
"message": "Yes Jack, let's do it in 30 minutes.",
"time": "2024-01-15",
},
}
self.message_id_counter: int = 6
def get_state_dict(self) -> dict:
"""Get the current state dict of the MessageApi."""
# To avoid the error in ACEBench dataset
inbox_state = {}
for key, value in self.inbox.items():
inbox_state[str(key)] = value
return {
"MessageApi": {
"inbox": inbox_state,
},
}
def send_message(
self,
sender_name: str,
receiver_name: str,
message: str,
) -> dict[str, bool | str]:
"""将一条消息从一个用户发送给另一个用户。
Args:
sender_name (`str`):
发送消息的用户姓名。
receiver_name (`str`):
接收消息的用户姓名。
message (`str`):
要发送的消息内容。
"""
if not self.logged_in:
return {"status": False, "message": "device未登录无法发送短信"}
if not self.wifi:
return {"status": False, "message": "wifi关闭此时不能发送信息"}
if len(self.inbox) >= self.max_capacity:
return {
"status": False,
"message": "内存容量不够了你需要询问user删除哪一条短信。",
}
# 验证发送者和接收者是否存在
if (
sender_name not in self.user_list
or receiver_name not in self.user_list
):
return {"status": False, "message": "发送者或接收者不存在"}
sender_id = self.user_list[sender_name]["user_id"]
receiver_id = self.user_list[receiver_name]["user_id"]
# 将短信添加到inbox
self.message_id_counter += 1
self.inbox[self.message_id_counter] = {
"sender_id": sender_id,
"receiver_id": receiver_id,
"message": message,
}
return {"status": True, "message": f"短信成功发送给{receiver_name}"}
def delete_message(self, message_id: int) -> dict[str, bool | str]:
"""根据消息 ID 删除一条消息。
Args:
message_id (`int`):
要删除的消息的 ID。
"""
if not self.logged_in:
return {"status": False, "message": "device未登录无法删除短信"}
if message_id not in self.inbox:
return {"status": False, "message": "短信ID不存在"}
del self.inbox[message_id]
return {"status": True, "message": f"短信ID {message_id} 已成功删除。"}
def view_messages_between_users(
self,
sender_name: str,
receiver_name: str,
) -> dict:
"""获取特定用户发送给另一个用户的所有消息。
Args:
sender_name (`str`):
发送消息的用户姓名。
receiver_name (`str`):
接收消息的用户姓名。
"""
if not self.logged_in:
return {
"status": False,
"message": "device未登录无法查看短信信息",
}
if sender_name not in self.user_list:
return {"status": False, "message": "发送者不存在"}
if receiver_name not in self.user_list:
return {"status": False, "message": "接收者不存在"}
sender_id = self.user_list[sender_name]["user_id"]
receiver_id = self.user_list[receiver_name]["user_id"]
messages_between_users = []
# 遍历 inbox找出 sender_id 发送给 receiver_id 的短信
for msg_id, msg_data in self.inbox.items():
if (
msg_data["sender_id"] == sender_id
and msg_data["receiver_id"] == receiver_id
):
messages_between_users.append(
{
"id": msg_id,
"sender": sender_name,
"receiver": receiver_name,
"message": msg_data["message"],
},
)
if not messages_between_users:
return {"status": False, "message": "没有找到相关的短信记录"}
return {"status": True, "messages": messages_between_users}
def search_messages(
self,
user_name: str,
keyword: str,
) -> dict:
"""搜索特定用户消息中包含特定关键字的消息。
Args:
user_name (`str`):
要搜索消息的用户姓名。
keyword (`str`):
要在消息中搜索的关键字。
"""
if user_name not in self.user_list:
return {"status": False, "message": "用户不存在"}
user_id = self.user_list[user_name]["user_id"]
matched_messages = []
# 遍历 inbox找到发送或接收中包含关键词的消息
for msg_id, msg_data in self.inbox.items():
if (
user_id in (msg_data["sender_id"], msg_data["receiver_id"])
and keyword.lower() in msg_data["message"].lower()
):
matched_messages.append(
{
"id": msg_id,
"sender_id": msg_data["sender_id"],
"receiver_id": msg_data["receiver_id"],
"message": msg_data["message"],
},
)
if not matched_messages:
return {"status": False, "message": "没有找到包含关键词的短信"}
return {"status": True, "messages": matched_messages}
def get_all_message_times_with_ids(
self,
) -> dict:
"""获取所有短信的时间以及对应的短信编号。"""
if not self.logged_in:
return {
"status": False,
"message": "device未登录获取所有短信的时间以及对应的短信编号。",
}
message_times_with_ids = {
msg_id: msg_data["time"] for msg_id, msg_data in self.inbox.items()
}
return message_times_with_ids
def get_latest_message_id(self) -> dict:
"""获取最近发送的消息的 ID。"""
if not self.logged_in:
return {
"status": False,
"message": "device未登录无法获取最新发送的短信ID。",
}
if not self.inbox:
return {"status": False, "message": "短信记录为空"}
# 遍历所有短信,找出时间最新的短信
latest_message_id = None
latest_time = None
for message_id, message_data in self.inbox.items():
message_time = datetime.strptime(
str(message_data["time"]),
"%Y-%m-%d",
)
if latest_time is None or message_time > latest_time:
latest_time = message_time
latest_message_id = message_id
return {
"status": True,
"message": f"最新的短信ID是 {latest_message_id}",
"message_id": latest_message_id,
}
def get_earliest_message_id(self) -> dict:
"""获取最早发送的消息的 ID。"""
if not self.logged_in:
return {
"status": False,
"message": "device未登录无法获取最早发送的短信ID",
}
if not self.inbox:
return {"status": False, "message": "短信记录为空"}
# 遍历所有短信,找出时间最早的短信
earliest_message_id = None
earliest_time = None
for message_id, message_data in self.inbox.items():
message_time = datetime.strptime(
str(message_data["time"]),
"%Y-%m-%d",
)
if earliest_time is None or message_time < earliest_time:
earliest_time = message_time
earliest_message_id = message_id
return {
"status": True,
"message": f"最早的短信ID是 {earliest_message_id}",
"message_id": earliest_message_id,
}

View File

@@ -0,0 +1,214 @@
# -*- coding: utf-8 -*-
"""The reminder API in ACEBench simulation tools."""
from datetime import datetime
from ._shared_state import SharedState
class ReminderApi(SharedState):
"""The reminder Api in the ACEBench evaluation."""
tool_functions: list[str] = [
"view_reminder_by_title",
"add_reminder",
"delete_reminder",
"view_all_reminders",
"mark_as_notified",
"search_reminders",
]
def __init__(self, share_state: dict) -> None:
"""Initialize the Reminder Api in the ACEBench evaluation."""
super().__init__(share_state)
self.max_capacity = 6
self.reminder_list: dict[
int,
dict,
] = {
1: {
"reminder_id": 1001,
"title": "Doctor's Appointment",
"description": "Visit Dr. Smith for a checkup.",
"time": "2024-07-15 09:30",
"notified": False,
},
2: {
"reminder_id": 1002,
"title": "Team Meeting",
"description": "Monthly project review with the team.",
"time": "2024-07-17 11:00",
"notified": False,
},
3: {
"reminder_id": 1003,
"title": "To-do list",
"description": '首先帮Frank在"盒马生鲜"点外卖,'
'需要定两个"生鲜大礼包"再发短信告诉Frank'
'"购买商品的价格是()元"。要把括号换成实际金额,'
"保留一位小数。",
"time": "2024-07-16 11:00",
"notified": False,
},
}
self.reminder_id_counter: int = 3
def get_state_dict(self) -> dict:
"""Get the current state dict of the ReminderApi."""
return {
"ReminderApi": {
"reminder_list": self.reminder_list,
},
}
def _check_capacity(self) -> bool:
"""检查备忘录容量是否已满。"""
return len(self.reminder_list) >= self.max_capacity
def view_reminder_by_title(
self,
title: str,
) -> dict[str, str | bool | dict[str, str | bool | datetime]]:
"""根据提醒的标题查看特定的提醒。
Args:
title (str): 提醒的标题。
Returns:
dict[str, str | bool | dict[str, str | bool | datetime]]:
包含查找状态和提醒详情的字典。
"""
if not self.logged_in:
return {"status": False, "message": "device未登录无法查看提醒"}
for reminder in self.reminder_list.values():
if reminder["title"] == title:
return {"status": True, "reminder": reminder}
return {"status": False, "message": f"没有找到标题为 '{title}' 的提醒"}
def add_reminder(
self,
title: str,
description: str,
time: datetime,
) -> dict[str, bool | str]:
"""添加一个新的提醒。
Args:
title (str): 提醒标题。
description (str): 提醒描述。
time (datetime): 提醒时间, 一定遵循格式"YYYY-MM-DD HH:MM"
Returns:
dict[str, bool | str]: 包含添加状态和结果的字典。
"""
if not self.logged_in:
return {
"status": False,
"message": "device未登录无法添加一个新的提醒",
}
if self._check_capacity():
return {"status": False, "message": "提醒容量已满,无法添加新的提醒"}
self.reminder_id_counter += 1
reminder_id = self.reminder_id_counter
self.reminder_list[reminder_id] = {
"reminder_id": reminder_id,
"title": title,
"description": description,
"time": time,
"notified": False,
}
return {"status": True, "message": f"提醒 '{title}' 已成功添加"}
def delete_reminder(self, reminder_id: int) -> dict[str, bool | str]:
"""删除指定的提醒。
Args:
reminder_id (int): 要删除的提醒ID。
Returns:
dict[str, bool | str]: 包含删除状态和结果的字典。
"""
if not self.logged_in:
return {"status": False, "message": "device未登录无法删除指定的提醒"}
if reminder_id not in self.reminder_list:
return {"status": False, "message": "提醒ID不存在"}
del self.reminder_list[reminder_id]
return {"status": True, "message": f"提醒ID {reminder_id} 已成功删除"}
def view_all_reminders(
self,
) -> dict:
"""查看所有的提醒。
Returns:
dict:
包含所有提醒的字典列表。
"""
if not self.reminder_list:
return {"status": False, "message": "没有任何提醒"}
reminders = []
for reminder in self.reminder_list.values():
reminders.append(
{
"title": reminder["title"],
"description": reminder["description"],
"time": reminder["time"],
"notified": reminder["notified"],
},
)
return {"status": True, "reminders": reminders}
def mark_as_notified(
self,
reminder_id: int,
) -> dict[str, bool | str]:
"""标记提醒为已通知。
Args:
reminder_id (int): 要标记为已通知的提醒ID。
Returns:
dict[str, bool | str]:: 包含操作结果的字典。
"""
if reminder_id not in self.reminder_list:
return {"status": False, "message": "提醒ID不存在"}
self.reminder_list[reminder_id]["notified"] = True
return {"status": True, "message": f"提醒ID {reminder_id} 已标记为已通知"}
def search_reminders(
self,
keyword: str,
) -> dict:
"""根据关键词搜索提醒。
Args:
keyword (str): 搜索关键词。
Returns:
`dict`:
包含匹配提醒的字典列表。
"""
matched_reminders = []
for reminder in self.reminder_list.values():
if (
keyword.lower() in reminder["title"].lower()
or keyword.lower() in reminder["description"].lower()
):
matched_reminders.append(
{
"title": reminder["title"],
"description": reminder["description"],
"time": reminder["time"].strftime("%Y-%m-%d %H:%M"),
},
)
if not matched_reminders:
return {"status": False, "message": "没有找到包含该关键词的提醒"}
return {"status": True, "reminders": matched_reminders}

View File

@@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-
"""The shared state class for ACEBench simulation tools."""
class SharedState:
"""The sharing state class for ACEBench simulation tools."""
def __init__(self, shared_state: dict) -> None:
"""Initialize the shared state"""
self._shared_state = shared_state
@property
def wifi(self) -> bool:
"""The WI-FI state"""
return self._shared_state["wifi"]
@property
def logged_in(self) -> bool:
"""The logged in state"""
return self._shared_state["logged_in"]

View File

@@ -0,0 +1,834 @@
# -*- coding: utf-8 -*-
# type: ignore
# pylint: disable=too-many-lines
# pylint: disable=too-many-statements
# pylint: disable=too-many-branches
# pylint: disable=too-many-statements
# pylint: disable=too-many-return-statements
"""The travel API for the ACEBench simulation tools in AgentScope."""
from datetime import datetime, timedelta
class TravelApi:
"""旅行预订系统类。
提供航班查询、用户认证、预订管理等功能的旅行系统。
支持直飞和中转航班查询、航班预订、预订修改和取消等功能。
"""
tool_functions: list[str] = [
"get_user_details",
"get_flight_details",
"get_reservation_details",
"reserve_flight",
"cancel_reservation",
"modify_flight",
]
def __init__(self) -> None:
"""初始化旅行系统。
设置用户档案和航班信息,包含用户信息、航班数据和预订记录。
"""
# 初始化用户信息
self.users = {
"user1": {
"user_name": "Eve",
"password": "password123",
"cash_balance": 2000.0,
"bank_balance": 50000.0,
"membership_level": "regular",
},
"user2": {
"user_name": "Frank",
"password": "password456",
"cash_balance": 8000.0,
"bank_balance": 8000.0,
"membership_level": "silver",
},
"user3": {
"user_name": "Grace",
"password": "password789",
"cash_balance": 1000.0,
"bank_balance": 5000.0,
"membership_level": "gold",
},
}
# 初始化航班信息
self.flights = [
{
"flight_no": "CA1234",
"origin": "北京",
"destination": "上海",
"depart_time": "2024-07-15 08:00:00",
"arrival_time": "2024-07-15 10:30:00",
"status": "available",
"seats_available": 5,
"economy_price": 1200,
"business_price": 3000,
},
{
"flight_no": "MU5678",
"origin": "上海",
"destination": "北京",
"depart_time": "2024-07-16 09:00:00",
"arrival_time": "2024-07-16 11:30:00",
"status": "available",
"seats_available": 3,
"economy_price": 1900,
"business_price": 3000,
},
{
"flight_no": "CZ4321",
"origin": "上海",
"destination": "北京",
"depart_time": "2024-07-16 20:00:00",
"arrival_time": "2024-07-16 22:00:00",
"status": "available",
"seats_available": 8,
"economy_price": 2500,
"business_price": 4000,
},
{
"flight_no": "CZ4352",
"origin": "上海",
"destination": "北京",
"depart_time": "2024-07-17 20:00:00",
"arrival_time": "2024-07-17 22:00:00",
"status": "available",
"seats_available": 8,
"economy_price": 1600,
"business_price": 2500,
},
{
"flight_no": "MU3561",
"origin": "北京",
"destination": "南京",
"depart_time": "2024-07-18 08:00:00",
"arrival_time": "2024-07-18 10:00:00",
"status": "available",
"seats_available": 8,
"economy_price": 1500,
"business_price": 4000,
},
{
"flight_no": "MU1566",
"origin": "北京",
"destination": "南京",
"depart_time": "2024-07-18 20:00:00",
"arrival_time": "2024-07-18 22:00:00",
"status": "available",
"seats_available": 8,
"economy_price": 1500,
"business_price": 4000,
},
{
"flight_no": "CZ1765",
"origin": "南京",
"destination": "深圳",
"depart_time": "2024-07-17 20:30:00",
"arrival_time": "2024-07-17 22:00:00",
"status": "available",
"seats_available": 8,
"economy_price": 1500,
"business_price": 2500,
},
{
"flight_no": "CZ1765",
"origin": "南京",
"destination": "深圳",
"depart_time": "2024-07-18 12:30:00",
"arrival_time": "2024-07-18 15:00:00",
"status": "available",
"seats_available": 8,
"economy_price": 1500,
"business_price": 2500,
},
{
"flight_no": "MH1765",
"origin": "厦门",
"destination": "成都",
"depart_time": "2024-07-17 12:30:00",
"arrival_time": "2024-07-17 15:00:00",
"status": "available",
"seats_available": 8,
"economy_price": 1500,
"business_price": 2500,
},
{
"flight_no": "MH2616",
"origin": "成都",
"destination": "厦门",
"depart_time": "2024-07-18 18:30:00",
"arrival_time": "2024-07-18 21:00:00",
"status": "available",
"seats_available": 8,
"economy_price": 1500,
"business_price": 2500,
},
{
"flight_no": "MH2616",
"origin": "成都",
"destination": "福州",
"depart_time": "2024-07-16 18:30:00",
"arrival_time": "2024-07-16 21:00:00",
"status": "available",
"seats_available": 8,
"economy_price": 1500,
"business_price": 2500,
},
]
# 初始化预订列表
self.reservations = [
{
"reservation_id": "res_1",
"user_id": "user1",
"flight_no": "CA1234",
"payment_method": "bank",
"cabin": "经济舱",
"baggage": 1,
"origin": "北京",
"destination": "上海",
},
{
"reservation_id": "res_2",
"user_id": "user1",
"flight_no": "MU5678",
"payment_method": "bank",
"cabin": "商务舱",
"baggage": 1,
"origin": "上海",
"destination": "北京",
},
{
"reservation_id": "res_3",
"user_id": "user2",
"flight_no": "MH1765",
"payment_method": "bank",
"cabin": "商务舱",
"baggage": 1,
"origin": "厦门",
"destination": "成都",
},
{
"reservation_id": "res_4",
"user_id": "user2",
"flight_no": "MU2616",
"payment_method": "bank",
"cabin": "商务舱",
"baggage": 1,
"origin": "成都",
"destination": "厦门",
},
]
def get_state_dict(self) -> dict:
"""Get the current state dict of the TravelApi."""
return {
"Travel": {
"users": self.users,
"reservations": self.reservations,
},
}
# 根据出发地和到达地查询航班
def get_flight_details(
self,
origin: str = None,
destination: str = None,
) -> list[dict] | str:
"""根据出发地和到达地查询航班的基本信息。
Args:
origin (str, optional): 出发地城市名称。默认为None。
destination (str, optional): 目的地城市名称。默认为None。
Returns:
list[dict] | str: 符合条件的航班列表或无航班的提示信息。
"""
flights = self.flights
# 过滤出发地
if origin:
flights = [
flight for flight in flights if flight["origin"] == origin
]
# 过滤到达地
if destination:
flights = [
flight
for flight in flights
if flight["destination"] == destination
]
if len(flights) == 0:
return "没有符合条件的直达航班"
# 返回查询结果
return [
{
"flight_no": flight["flight_no"],
"origin": flight["origin"],
"destination": flight["destination"],
"depart_time": flight["depart_time"],
"arrival_time": flight["arrival_time"],
"status": flight["status"],
"seats_available": flight["seats_available"],
"economy_price": flight["economy_price"],
"business_price": flight["business_price"],
}
for flight in flights
]
def get_user_details(self, user_id: str, password: str) -> dict:
"""根据用户名和密码查询用户信息。
Args:
user_id (str): 用户ID。
password (str): 用户密码。
Returns:
dict: 用户信息字典(不包含密码)或错误信息。
"""
user = self.users.get(user_id)
if user and user["password"] == password:
return {
key: value for key, value in user.items() if key != "password"
}
return {"status": "error", "message": "用户名或密码不正确"}
def get_reservation_details(
self,
reservation_id: str = None,
user_id: str = None,
) -> list[dict] | dict:
"""根据预订ID或用户ID查询预订信息包括对应航班的基本信息。
Args:
reservation_id (str, optional): 预订ID。默认为None。
user_id (str, optional): 用户ID。默认为None。
Returns:
`list[dict] | dict`:
详细预订信息列表或错误信息字典。
"""
# 根据预订ID或用户ID筛选预订信息
if reservation_id:
reservations = [
reservation
for reservation in self.reservations
if reservation["reservation_id"] == reservation_id
]
elif user_id:
reservations = [
reservation
for reservation in self.reservations
if reservation["user_id"] == user_id
]
else:
return {"status": "error", "message": "请提供有效的预订ID或用户ID"}
# 对每个预订,附加航班信息
detailed_reservations = []
for reservation in reservations:
flight_info = next(
(
flight
for flight in self.flights
if flight["flight_no"] == reservation["flight_no"]
),
None,
)
detailed_reservation = {**reservation, "flight_info": flight_info}
detailed_reservations.append(detailed_reservation)
return detailed_reservations
def authenticate_user(self, user_id: str, password: str) -> dict:
"""验证用户身份。
Args:
user_id (str): 用户ID。
password (str): 用户密码。
Returns:
`dict`:
用户信息字典或错误信息字典。
"""
user = self.users.get(user_id)
if user and user["password"] == password:
return user
return {"status": "error", "message": "用户名或密码不正确"}
def get_baggage_allowance(
self,
membership_level: str,
cabin_class: str,
) -> int:
"""获取用户基于会员等级和舱位的免费托运行李限额。
Args:
membership_level (str): 会员等级 ("regular", "silver", "gold")。
cabin_class (str): 舱位 ("基础经济舱", "经济舱", "商务舱")。
Returns:
int: 免费托运行李数量。
"""
allowance = {
"regular": {"经济舱": 1, "商务舱": 2},
"silver": {"经济舱": 2, "商务舱": 3},
"gold": {"经济舱": 3, "商务舱": 3},
}
return allowance.get(membership_level, {}).get(cabin_class, 0)
def find_transfer_flights(
self,
origin_city: str,
transfer_city: str,
destination_city: str,
) -> list[dict] | str:
"""查找从出发城市到目的地城市的中转航班。
确保第一班航班降落时间早于第二班航班起飞时间。
Args:
origin_city (str): 出发城市。
transfer_city (str): 中转城市。
destination_city (str): 到达城市。
Returns:
list[dict] | str:
满足条件的中转航班列表,每个航班包含两段航程的信息,或无航班提示。
"""
# 获取从出发城市到中转城市的航班
first_leg_flights: list[dict] = [
flight
for flight in self.flights
if flight["origin"] == origin_city
and flight["destination"] == transfer_city
and flight["status"] == "available"
]
# 获取从中转城市到目的地城市的航班
second_leg_flights = [
flight
for flight in self.flights
if flight["origin"] == transfer_city
and flight["destination"] == destination_city
and flight["status"] == "available"
]
# 存储符合条件的中转航班
transfer_flights = []
# 遍历第一段航班和第二段航班,查找符合时间条件的组合
for first_flight in first_leg_flights:
first_arrival = datetime.strptime(
first_flight["arrival_time"],
"%Y-%m-%d %H:%M:%S",
)
for second_flight in second_leg_flights:
second_departure = datetime.strptime(
str(second_flight["depart_time"]),
"%Y-%m-%d %H:%M:%S",
)
# 检查第一班航班降落时间早于第二班航班起飞时间
if first_arrival < second_departure:
transfer_flights.append(
{
"first_leg": first_flight,
"second_leg": second_flight,
},
)
# 返回符合条件的中转航班列表
if transfer_flights:
return transfer_flights
else:
return "未找到符合条件的中转航班。"
def calculate_baggage_fee(
self,
membership_level: str,
cabin_class: str,
baggage_count: int,
) -> float:
"""计算行李费用。
Args:
membership_level (str): 会员等级。
cabin_class (str): 舱位等级。
baggage_count (int): 行李数量。
Returns:
float: 额外行李费用。
"""
free_baggage = {
"regular": {"经济舱": 1, "商务舱": 2},
"silver": {"经济舱": 2, "商务舱": 3},
"gold": {"经济舱": 3, "商务舱": 3},
}
free_limit = free_baggage[membership_level][cabin_class]
additional_baggage = max(baggage_count - free_limit, 0)
return additional_baggage * 50
def update_balance(
self,
user: dict,
payment_method: str,
amount: float,
) -> bool:
"""更新用户的余额。
Args:
user (dict): 用户信息字典。
payment_method (str): 支付方式("cash""bank")。
amount (float): 更新金额(正数表示增加,负数表示减少)。
Returns:
bool: 如果余额充足且更新成功,返回 True否则返回 False。
"""
if payment_method == "cash":
if user["cash_balance"] + amount < 0:
return False # 余额不足
user["cash_balance"] += amount
elif payment_method == "bank":
if user["bank_balance"] + amount < 0:
return False # 余额不足
user["bank_balance"] += amount
return True
def reserve_flight(
self,
user_id: str,
password: str,
flight_no: str,
cabin: str,
payment_method: str,
baggage_count: int,
) -> str:
"""预订航班。
Args:
user_id (str): 用户ID。
password (str): 用户密码。
flight_no (str): 航班号。
cabin (str): 舱位等级。
payment_method (str): 支付方式。
baggage_count (int): 行李数量。
Returns:
str: 预订结果信息。
"""
user = self.authenticate_user(user_id, password)
if not user:
return "认证失败请检查用户ID和密码。"
# 检查航班和座位
flight = next(
(
f
for f in self.flights
if f["flight_no"] == flight_no and f["status"] == "available"
),
None,
)
# 计算航班价格
price: int = (
flight["economy_price"]
if cabin == "经济舱"
else flight["business_price"]
)
total_cost = price
# 计算行李费用
baggage_fee = self.calculate_baggage_fee(
user["membership_level"],
cabin,
baggage_count,
)
total_cost += baggage_fee
# 检查支付方式
if payment_method not in ["cash", "bank"]:
return "支付方式无效"
# 更新预定后的余额
if payment_method == "cash":
if total_cost > self.users.get(user_id)["cash_balance"]:
return "cash余额不足请考虑换一种支付方式"
self.users.get(user_id)["cash_balance"] -= total_cost
else:
if total_cost > self.users.get(user_id)["bank_balance"]:
return "bank余额不足请考虑换一种支付方式"
self.users.get(user_id)["bank_balance"] -= total_cost
# 更新航班信息并生成预订
flight["seats_available"] -= 1
reservation_id = f"res_{len(self.reservations) + 1}"
reservation = {
"reservation_id": reservation_id,
"user_id": user_id,
"flight_no": flight_no,
"payment_method": payment_method,
"cabin": cabin,
"baggage": baggage_count,
}
self.reservations.append(reservation)
return f"预订成功,预订号:{reservation_id}" f"总费用:{total_cost}元(包含行李费用)。"
def modify_flight(
self,
user_id: str,
reservation_id: str,
new_flight_no: str = None,
new_cabin: str = None,
add_baggage: int = 0,
new_payment_method: str = None,
) -> str:
"""修改航班预订,包括更改航班、舱位和行李。
Args:
user_id (str): 用户ID。
reservation_id (str): 预订ID。
new_flight_no (str, optional): 新的航班号。默认为None。
new_cabin (str, optional): 新的舱位。默认为None。
add_baggage (int, optional): 新增托运行李的数量。默认为0。
new_payment_method (str, optional): 新的付款方式。默认为None。
Returns:
str: 修改结果信息。
"""
# 获取对应的预订
reservation = next(
(
r
for r in self.reservations
if r["reservation_id"] == reservation_id
and r["user_id"] == user_id
),
None,
)
if not reservation:
return "预订未找到或用户ID不匹配。"
# 检查当前预订的航班信息
current_flight = next(
(
f
for f in self.flights
if f["flight_no"] == reservation["flight_no"]
),
None,
)
if not current_flight:
return "航班信息未找到。"
# 获取原始支付方式或新提供的支付方式
payment_method = (
new_payment_method
if new_payment_method
else reservation["payment_method"]
)
user = self.users[user_id]
if not user:
return "用户信息未找到。"
# 存储处理结果
result_messages = []
if new_flight_no and new_flight_no != reservation["flight_no"]:
# 更新航班号(若提供)但必须匹配出发地和目的地
new_flight = next(
(f for f in self.flights if f["flight_no"] == new_flight_no),
None,
)
if (
new_flight
and new_flight["origin"] == current_flight["origin"]
and new_flight["destination"] == current_flight["destination"]
):
reservation["flight_no"] = new_flight_no
result_messages.append("航班号已更改。")
else:
return "航班更改失败:新的航班号无效或目的地不匹配。"
# 更新舱位(若提供)并计算价格差价
if new_cabin and new_cabin != reservation.get("cabin"):
price_difference = self.calculate_price_difference(
current_flight,
reservation["cabin"],
new_cabin,
)
reservation["cabin"] = new_cabin
if price_difference > 0:
# 扣除差价
if self.update_balance(
user,
payment_method,
-price_difference,
):
result_messages.append(
f"舱位更改成功。已支付差价: {price_difference}",
)
else:
result_messages.append("余额不足,无法支付舱位差价。")
elif price_difference < 0:
# 退款
self.update_balance(user, payment_method, -price_difference)
result_messages.append(f"舱位更改成功。已退款差价: {-price_difference}")
# 增加托运行李,检查免费限额和计算费用
if add_baggage > 0:
membership = user["membership_level"]
max_free_baggage = self.get_baggage_allowance(
membership,
reservation["cabin"],
)
current_baggage = reservation.get("baggage", 0)
total_baggage = current_baggage + add_baggage
extra_baggage = max(0, total_baggage - max_free_baggage)
baggage_cost = extra_baggage * 50
if baggage_cost > 0:
# 扣除行李费用
if self.update_balance(user, payment_method, -baggage_cost):
result_messages.append(
f"行李已增加。需支付额外费用: {baggage_cost}",
)
else:
result_messages.append("余额不足,无法支付额外行李费用。")
reservation["baggage"] = total_baggage
# 返回最终结果
if not result_messages:
result_messages.append("修改完成,无需额外费用。")
return " ".join(result_messages)
def cancel_reservation(
self,
user_id: str,
reservation_id: str,
reason: str,
) -> str:
"""取消预订。
Args:
user_id (str): 用户ID。
reservation_id (str): 预订ID。
reason (str): 取消原因。
Returns:
str: 取消结果信息。
"""
# 设置默认当前时间为 2024年7月14日早上6点
current_time = datetime(2024, 7, 14, 6, 0, 0)
# 验证用户和预订是否存在
user = self.users.get(user_id, None)
if not user:
return "用户ID无效。"
reservation = next(
(
r
for r in self.reservations
if r["reservation_id"] == reservation_id
and r["user_id"] == user_id
),
None,
)
if not reservation:
return "预订ID无效或与该用户无关。"
# 检查航班信息是否存在
flight = next(
(
f
for f in self.flights
if f["flight_no"] == reservation["flight_no"]
),
None,
)
if not flight:
return "航班信息无效。"
# 检查航班是否已起飞
depart_time = datetime.strptime(
flight["depart_time"],
"%Y-%m-%d %H:%M:%S",
)
if current_time > depart_time:
return "航段已使用,无法取消。"
# 计算距离出发时间
time_until_departure = depart_time - current_time
cancel_fee = 0
refund_amount = 0
# 获取航班价格
flight_price = (
flight["economy_price"]
if reservation["cabin"] == "经济舱"
else flight["business_price"]
)
# 取消政策及退款计算
if reason == "航空公司取消航班":
# 航空公司取消航班,全额退款
refund_amount = flight_price
self.process_refund(user, refund_amount)
return f"航班已取消,您的预订将被免费取消,已退款{refund_amount}元。"
elif time_until_departure > timedelta(days=1):
# 离出发时间超过24小时免费取消
refund_amount = flight_price
self.process_refund(user, refund_amount)
return f"距离出发时间超过24小时免费取消成功已退款{refund_amount}元。"
else:
# 若不符合免费取消条件,可根据需求设置取消费
cancel_fee = flight_price * 0.1 # 假设取消费为票价的10%
refund_amount = flight_price - cancel_fee
self.process_refund(user, refund_amount)
return f"距离出发时间不足24小时已扣除取消费{cancel_fee}元,退款{refund_amount}元。"
def process_refund(self, user: dict, amount: float) -> str:
"""将退款金额添加到用户的现金余额中。
Args:
user (dict): 用户信息字典。
amount (float): 退款金额。
"""
user["cash_balance"] += amount
return f"已成功处理退款,{user['user_name']}的现金余额增加了{amount}元。"
def calculate_price_difference(
self,
flight: dict,
old_cabin: str,
new_cabin: str,
) -> float:
"""计算舱位价格差异。
Args:
flight (dict): 航班信息字典。
old_cabin (str): 原舱位等级。
new_cabin (str): 新舱位等级。
Returns:
float: 价格差异(正数表示需支付差价,负数表示退款)。
"""
cabin_prices = {
"经济舱": flight["economy_price"],
"商务舱": flight["business_price"],
}
old_price = cabin_prices.get(old_cabin, 0)
new_price = cabin_prices.get(new_cabin, 0)
return new_price - old_price

View File

@@ -0,0 +1,122 @@
# -*- coding: utf-8 -*-
"""The Chinese tools for ACEBench evaluation."""
from functools import wraps
from typing import Callable, Any
from ._ace_tools_api import (
ReminderApi,
FoodPlatformApi,
TravelApi,
MessageApi,
)
from ...message import TextBlock
from ...tool import ToolResponse
def _tool_function_wrapper(get_tool_function: Callable) -> Callable:
"""Wrap the tool function result to be ToolResponse."""
@wraps(get_tool_function)
def wrapper(self: "ACEPhone", name: str) -> Callable:
"""Wrap the tool function to return ToolResponse."""
tool_function = get_tool_function(self, name)
@wraps(tool_function)
def wrapper_tool_function(*args: Any, **kwargs: Any) -> ToolResponse:
"""The wrapped tool function"""
res = tool_function(*args, **kwargs)
return ToolResponse(
content=[
TextBlock(
type="text",
text=str(res),
),
],
)
return wrapper_tool_function
return wrapper
class ACEPhone:
"""Simulate a user phone with various apps and functionalities in
ACEBench. The code is implemented with reference to the
`ACEBench <https://github.com/ACEBench/ACEBench>`_.
"""
def __init__(self) -> None:
"""Initialize the shared state and apps for the ACEPhone."""
self._state = {
"wifi": False,
"logged_in": False,
}
self._message_app = MessageApi(self._state)
self._reminder_app = ReminderApi(self._state)
self._food_platform_app = FoodPlatformApi(self._state)
self._travel = TravelApi()
def turn_on_wifi(self) -> dict[str, bool | str]:
"""开启WiFi连接。"""
self._state["wifi"] = True
return {"status": True, "message": "wifi已经打开"}
def login_device(self) -> dict[str, bool | str]:
"""登录设备。"""
self._state["logged_in"] = True
return {"status": True, "message": "设备已经登录"}
def load_initial_config(self, initial_config: dict) -> None:
"""Load the initial config from the application configuration."""
# Empty initial config
if len(initial_config) == 0:
return
# Fix the typo in ACEBench by renaming "Baspi" to "BaseApi"
if "Baspi" in initial_config:
initial_config["BaseApi"] = initial_config.pop("Baspi")
# Verify state
assert (
"BaseApi" in initial_config
and "wifi" in initial_config["BaseApi"]
and "logged_in" in initial_config["BaseApi"]
), f"Invalid initial config: {initial_config}"
self._state["wifi"] = initial_config["BaseApi"]["wifi"]
self._state["logged_in"] = initial_config["BaseApi"]["logged_in"]
def get_current_state(self) -> list[dict]:
"""Follow ACEBench to get the current state of the ACEPhone."""
return [
{"BaseApi": self._state},
self._message_app.get_state_dict(),
self._reminder_app.get_state_dict(),
self._food_platform_app.get_state_dict(),
self._travel.get_state_dict(),
]
@_tool_function_wrapper
def get_tool_function(self, name: str) -> Callable:
"""Get a tool function by name."""
if name in [
"turn_on_wifi",
"login_device",
]:
return getattr(self, name)
if name in self._message_app.tool_functions:
return getattr(self._message_app, name)
if name in self._food_platform_app.tool_functions:
return getattr(self._food_platform_app, name)
if name in self._reminder_app.tool_functions:
return getattr(self._reminder_app, name)
if name in self._travel.tool_functions:
return getattr(self._travel, name)
raise ValueError(
f"Tool function '{name}' not found in ACEPhone.",
)

View File

@@ -0,0 +1,43 @@
# -*- coding: utf-8 -*-
"""The base class for benchmark evaluation."""
from abc import ABC, abstractmethod
from typing import Generator
from ._task import Task
class BenchmarkBase(ABC):
"""The base class for benchmark evaluation."""
name: str
"""The name of the benchmark."""
description: str
"""The description of the benchmark."""
def __init__(self, name: str, description: str) -> None:
"""Initialize the benchmark.
Args:
name (`str`):
The name of the benchmark.
description (`str`):
A brief description of the benchmark.
"""
self.name = name
self.description = description
@abstractmethod
def __iter__(self) -> Generator[Task, None, None]:
"""Iterate over the benchmark."""
raise NotImplementedError("Subclasses must implement this method.")
@abstractmethod
def __len__(self) -> int:
"""Get the length of the benchmark."""
raise NotImplementedError("Subclasses must implement this method.")
@abstractmethod
def __getitem__(self, index: int) -> Task:
"""Get the task at the given index."""
raise NotImplementedError("Subclasses must implement this method.")

View File

@@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-
"""The evaluator module in AgentScope."""
from ._evaluator_base import EvaluatorBase
from ._ray_evaluator import RayEvaluator
from ._general_evaluator import GeneralEvaluator
__all__ = [
"EvaluatorBase",
"RayEvaluator",
"GeneralEvaluator",
]

View File

@@ -0,0 +1,304 @@
# -*- coding: utf-8 -*-
"""The base class for evaluator in evaluation."""
import collections
import json
from abc import abstractmethod
from dataclasses import asdict
from typing import Callable, Coroutine, Any
from collections import defaultdict
from .._solution import SolutionOutput
from .._task import Task
from .._benchmark_base import BenchmarkBase
from .._evaluator_storage import EvaluatorStorageBase
from .._metric_base import MetricType
from ..._utils._common import _get_timestamp
class EvaluatorBase:
"""The class that runs the evaluation process."""
def __init__(
self,
name: str,
benchmark: BenchmarkBase,
n_repeat: int,
storage: EvaluatorStorageBase,
) -> None:
"""Initialize the evaluator.
Args:
name (`str`):
The name of this evaluator.
benchmark: (`BenchmarkBase`):
A benchmark instance inheriting from `BenchmarkBase` that
defines the evaluation dataset.
n_repeat (`int`):
How many times to repeat the evaluation for each task.
storage (`EvaluatorStorageBase`):
A instance inheriting from the child class of
`EvaluatorStorageBase` that supports storing and loading
solution output and evaluation results.
"""
self.name = name
self.benchmark = benchmark
self.n_repeat = n_repeat
self.storage = storage
@abstractmethod
async def run(
self,
solution: Callable[
[Task, Callable],
Coroutine[Any, Any, SolutionOutput],
],
) -> None:
"""Run the evaluation and return the results.
Args:
solution (`Callable[[Task, Callable], Coroutine[Any, Any, \
SolutionOutput]]`):
A async function that takes a `Task` instance and a pre-hook
as input and returns a `SolutionOutput` instance.
"""
async def _save_evaluation_meta(self) -> None:
"""Save the evaluation meta information."""
self.storage.save_evaluation_meta(
{
"evaluation_name": self.name,
"created_at": _get_timestamp(),
"total_repeats": self.n_repeat,
"benchmark": {
"name": self.benchmark.name,
"description": self.benchmark.description,
"total_tasks": len(self.benchmark),
},
"schema_version": 1,
},
)
async def _save_task_meta(self, task: Task) -> None:
"""Save the task meta information.
Args:
task (`Task`):
The task instance.
"""
meta_info = asdict(task)
meta_info.pop("metadata")
self.storage.save_task_meta(
task.id,
meta_info,
)
# pylint: disable=too-many-branches, too-many-statements
async def aggregate(self) -> None:
"""Aggregate the evaluation results and save an overall result."""
meta_info: dict = {
"total_tasks": len(self.benchmark),
"total_repeats": self.n_repeat,
"total_stats": {
"llm": defaultdict(int),
"agent": 0,
"tool": defaultdict(int),
"embedding": defaultdict(int),
"chat_usage": {},
},
"repeats": {},
"schema_version": 1,
}
for repeat_index in range(self.n_repeat):
repeat_id = str(repeat_index)
current_repeat: dict = {
"completed_tasks": 0,
"incomplete_tasks": 0,
"metrics": {},
"completed_ids": [],
"incomplete_ids": [],
"stats": {
"llm": defaultdict(int),
"agent": 0,
"tool": defaultdict(int),
"embedding": defaultdict(int),
"chat_usage": {},
},
}
for task in self.benchmark:
current_stats = self.storage.get_solution_stats(
task.id,
repeat_id,
)
# llm
for model_name, cnt in current_stats.get("llm", {}).items():
current_repeat["stats"]["llm"][model_name] += cnt
# agent
current_repeat["stats"]["agent"] += current_stats.get(
"agent",
0,
)
# tool
for tool_name, cnt in current_stats.get("tool", {}).items():
current_repeat["stats"]["tool"][tool_name] += cnt
# embedding
for embedding_model, cnt in current_stats.get(
"embedding",
{},
).items():
current_repeat["stats"]["embedding"][
embedding_model
] += cnt
# chat usage
for model_name, usage in current_stats.get(
"chat_usage",
{},
).items():
if model_name not in current_repeat["stats"]["chat_usage"]:
current_repeat["stats"]["chat_usage"][
model_name
] = defaultdict(int)
current_repeat["stats"]["chat_usage"][model_name][
"input_tokens"
] += usage.get("input_tokens", 0)
current_repeat["stats"]["chat_usage"][model_name][
"output_tokens"
] += usage.get("output_tokens", 0)
for metric in task.metrics:
# Create a new dict in aggregated_result
if metric.name not in current_repeat["metrics"]:
current_repeat["metrics"][metric.name] = {
"type": metric.metric_type,
"involved_tasks": 0,
"completed_tasks": 0,
"incomplete_tasks": 0,
"aggregation": {},
"distribution": collections.defaultdict(list),
}
# Record the submitted task
current_repeat["metrics"][metric.name][
"involved_tasks"
] += 1
# Not finished
if not self.storage.evaluation_result_exists(
task.id,
repeat_id,
metric.name,
):
if task.id not in current_repeat["incomplete_ids"]:
current_repeat["incomplete_tasks"] += 1
current_repeat["incomplete_ids"].append(task.id)
current_repeat["metrics"][metric.name][
"incomplete_tasks"
] += 1
continue
if task.id not in current_repeat["completed_ids"]:
current_repeat["completed_tasks"] += 1
current_repeat["completed_ids"].append(task.id)
current_repeat["metrics"][metric.name][
"completed_tasks"
] += 1
# Get the evaluation result
eval_result = self.storage.get_evaluation_result(
task.id,
repeat_id,
metric.name,
)
# Record the metric result
if metric.metric_type == MetricType.CATEGORY:
current_repeat["metrics"][metric.name]["distribution"][
eval_result.result
].append(
task.id,
)
elif metric.metric_type == MetricType.NUMERICAL:
current_repeat["metrics"][metric.name]["distribution"][
task.id
] = eval_result.result
print("Repeat ID:", repeat_id)
for metric, value in current_repeat["metrics"].items():
print("\tMetric:", metric)
print("\t\tType:", value["type"])
print("\t\tInvolved tasks:", value["involved_tasks"])
print("\t\tCompleted tasks:", value["completed_tasks"])
print("\t\tIncomplete tasks:", value["incomplete_tasks"])
if value["type"] == MetricType.CATEGORY:
# Count the distribution
for category, task_ids in value["distribution"].items():
value["aggregation"][category] = (
len(task_ids) * 1.0 / value["involved_tasks"]
)
elif value["type"] == MetricType.NUMERICAL:
scores = list(value["distribution"].values())
value["aggregation"] = {
"mean": sum(scores) / value["involved_tasks"],
"max": max(scores),
"min": min(scores),
}
print(
"\t\tAggregation:",
json.dumps(
value["aggregation"],
indent=4,
ensure_ascii=False,
).replace("\n", "\n\t\t"),
)
meta_info["repeats"][repeat_id] = current_repeat
# Aggregate total stats
repeat_stats = current_repeat["stats"]
# llm
for model_name, cnt in repeat_stats.get("llm", {}).items():
meta_info["total_stats"]["llm"][model_name] += cnt
# agent
meta_info["total_stats"]["agent"] += repeat_stats.get("agent", 0)
# tool
for tool_name, cnt in repeat_stats.get("tool", {}).items():
meta_info["total_stats"]["tool"][tool_name] += cnt
# embedding
for embedding_model, cnt in repeat_stats.get(
"embedding",
{},
).items():
meta_info["total_stats"]["embedding"][embedding_model] += cnt
# chat usage
for model_name, usage in repeat_stats.get(
"chat_usage",
{},
).items():
if model_name not in meta_info["total_stats"]["chat_usage"]:
meta_info["total_stats"]["chat_usage"][
model_name
] = defaultdict(int)
meta_info["total_stats"]["chat_usage"][model_name][
"input_tokens"
] += usage.get("input_tokens", 0)
meta_info["total_stats"]["chat_usage"][model_name][
"output_tokens"
] += usage.get("output_tokens", 0)
# save
self.storage.save_aggregation_result(meta_info)

View File

@@ -0,0 +1,178 @@
# -*- coding: utf-8 -*-
"""General evaluator implementation in AgentScope, which is easy to debug
compared to the RayEvaluator."""
from typing import Callable, Awaitable, Coroutine, Any
from ._evaluator_base import EvaluatorBase
from ._in_memory_exporter import _InMemoryExporter
from .._evaluator_storage import EvaluatorStorageBase
from .._task import Task
from .._solution import SolutionOutput
from .._benchmark_base import BenchmarkBase
class GeneralEvaluator(EvaluatorBase):
"""The general evaluator that support users to debug their evaluation"""
def __init__(
self,
name: str,
benchmark: BenchmarkBase,
n_repeat: int,
storage: EvaluatorStorageBase,
n_workers: int,
) -> None:
"""Initialize the evaluator."""
super().__init__(
name=name,
benchmark=benchmark,
n_repeat=n_repeat,
storage=storage,
)
assert isinstance(benchmark, BenchmarkBase)
assert n_repeat >= 1, "n_repeat must be at least 1"
assert n_workers >= 1, "n_workers must be at least 1"
self.benchmark = benchmark
self.n_repeat = n_repeat
self.n_workers = n_workers
async def run_evaluation(
self,
task: Task,
repeat_id: str,
solution_output: SolutionOutput,
) -> None:
"""Run the evaluation for a task and solution result."""
evaluation_results = await task.evaluate(solution_output)
# store the evaluation result
for result in evaluation_results:
self.storage.save_evaluation_result(
task_id=task.id,
repeat_id=repeat_id,
evaluation=result,
)
async def run_solution(
self,
repeat_id: str,
task: Task,
solution: Callable[[Task, Callable], Awaitable[SolutionOutput]],
) -> None:
"""Generate a solution to a task and evaluate."""
if self.storage.solution_result_exists(task.id, repeat_id):
# Obtain from storage
solution_result = self.storage.get_solution_result(
task.id,
repeat_id,
)
else:
from opentelemetry import trace
from opentelemetry.context import attach, detach
from opentelemetry import baggage
tracer = trace.get_tracer(__name__)
# Set baggage
ctx = baggage.set_baggage("task_id", task.id)
ctx = baggage.set_baggage("repeat_id", repeat_id, context=ctx)
# Activate the context
token = attach(ctx)
try:
with tracer.start_as_current_span(
name=f"Solution_{task.id}_{repeat_id}",
):
from ... import _config
_config.trace_enabled = True
# Run the solution
solution_result = await solution(
task,
self.storage.get_agent_pre_print_hook(
task.id,
repeat_id,
),
)
self.storage.save_solution_result(
task.id,
repeat_id,
solution_result,
)
finally:
detach(token)
# Evaluate the solution with the
for metric in task.metrics:
if not self.storage.evaluation_result_exists(
task.id,
repeat_id,
metric.name,
):
await self.run_evaluation(
task,
repeat_id,
solution_result,
)
async def run(
self,
solution: Callable[
[Task, Callable],
Coroutine[Any, Any, SolutionOutput],
],
) -> None:
"""Run the ray-based distributed and parallel evaluation, and get the
results.
Args:
solution (`Callable[[Task, Callable], Coroutine[Any, Any, \
SolutionOutput]]`):
A async function that takes a `Task` instance and a pre-print
hook function as input, returns a `SolutionOutput` instance.
"""
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
exporter = _InMemoryExporter()
span_processor = SimpleSpanProcessor(exporter)
tracer_provider: TracerProvider = trace.get_tracer_provider()
if not isinstance(tracer_provider, TracerProvider):
# Create a new tracer provider if not exists
tracer_provider = TracerProvider()
tracer_provider.add_span_processor(span_processor)
trace.set_tracer_provider(tracer_provider)
await self._save_evaluation_meta()
for task in self.benchmark:
await self._save_task_meta(task)
for repeat_id in range(self.n_repeat):
await self.run_solution(
str(repeat_id),
task,
solution,
)
# Save the exporter data
if (
task.id in exporter.cnt
and str(repeat_id) in exporter.cnt[task.id]
):
self.storage.save_solution_stats(
task.id,
str(repeat_id),
exporter.cnt[task.id][str(repeat_id)],
)
await self.aggregate()

View File

@@ -0,0 +1,106 @@
# -*- coding: utf-8 -*-
"""An in memory exporter of OpenTelemetry traces for AgentScope evaluator, used
to record the token usage during evaluation."""
from collections import defaultdict
from typing import Sequence
from opentelemetry import baggage
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.trace.export import SpanExporter, SpanExportResult
from ...tracing._attributes import SpanAttributes, OperationNameValues
class _InMemoryExporter(SpanExporter):
"""An in memory exporter to store the token usage from the ChatModel spans
in OpenTelemetry traces."""
def __init__(self) -> None:
"""Initialize the in memory exporter."""
# Initialize the counter
self.cnt: dict = {}
self._stopped = False
def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult:
"""Exports a batch of telemetry data.
Args:
spans (`Sequence[ReadableSpan]`):
The list of `opentelemetry.trace.Span` objects to be exported
Returns:
`SpanExportResult`:
The result of the export
"""
for span in spans:
task_id = baggage.get_baggage("task_id")
repeat_id = baggage.get_baggage("repeat_id")
if task_id is None or repeat_id is None:
continue
if task_id not in self.cnt:
self.cnt[task_id] = {}
if repeat_id not in self.cnt[task_id]:
self.cnt[task_id][repeat_id] = {
"llm": defaultdict(int),
"agent": 0,
"tool": defaultdict(int),
"embedding": defaultdict(int),
"chat_usage": {},
}
span_kind = span.attributes.get(
SpanAttributes.GEN_AI_OPERATION_NAME,
)
if span_kind == OperationNameValues.CHAT:
model_name = span.attributes.get(
SpanAttributes.GEN_AI_REQUEST_MODEL,
"unknown",
)
self.cnt[task_id][repeat_id]["llm"][model_name] += 1
if (
model_name
not in self.cnt[task_id][repeat_id]["chat_usage"]
):
self.cnt[task_id][repeat_id]["chat_usage"][
model_name
] = defaultdict(int)
self.cnt[task_id][repeat_id]["chat_usage"][model_name][
"input_tokens"
] += span.attributes.get(
SpanAttributes.GEN_AI_USAGE_INPUT_TOKENS,
0,
)
self.cnt[task_id][repeat_id]["chat_usage"][model_name][
"output_tokens"
] += span.attributes.get(
SpanAttributes.GEN_AI_USAGE_OUTPUT_TOKENS,
0,
)
elif span_kind == OperationNameValues.INVOKE_AGENT:
self.cnt[task_id][repeat_id]["agent"] += 1
elif span_kind == OperationNameValues.EXECUTE_TOOL:
tool_name = span.attributes.get(
SpanAttributes.GEN_AI_TOOL_NAME,
"unknown",
)
self.cnt[task_id][repeat_id]["tool"][tool_name] += 1
elif span_kind == OperationNameValues.EMBEDDINGS:
embedding_model = span.attributes.get(
SpanAttributes.GEN_AI_REQUEST_MODEL,
"unknown",
)
self.cnt[task_id][repeat_id]["embedding"][embedding_model] += 1
return SpanExportResult.SUCCESS
def shutdown(self) -> None:
"""Shuts down the exporter."""
self._stopped = True

View File

@@ -0,0 +1,267 @@
# -*- coding: utf-8 -*-
"""The evaluator base class in agentscope."""
import asyncio
from typing import Callable, Awaitable, Coroutine, Any
from ._in_memory_exporter import _InMemoryExporter
from .._benchmark_base import BenchmarkBase
from .._evaluator._evaluator_base import EvaluatorBase
from .._solution import SolutionOutput
from .._task import Task
from .._evaluator_storage import EvaluatorStorageBase
def _check_ray_available() -> None:
"""Check if ray is available and raise ImportError if not."""
try:
import ray # noqa # pylint: disable=unused-import
except ImportError as e:
raise ImportError(
"Ray is not installed. Please install it with `pip install ray` "
"to use the RayEvaluator.",
) from e
# Create a conditional decorator for ray.remote
def _ray_remote_decorator(cls: Any) -> Any:
"""
Conditional ray.remote decorator that only applies when ray is available.
"""
try:
import ray
return ray.remote(cls)
except ImportError:
return cls
@_ray_remote_decorator
class RayEvaluationActor:
"""
Actor class for running evaluation with ray remote.
"""
@staticmethod
async def run(
storage: EvaluatorStorageBase,
task: Task,
repeat_id: str,
solution_output: SolutionOutput,
) -> None:
"""
Run the evaluation for a task and solution result.
Args:
storage (EvaluatorStorageBase): Evaluator storage.
task (Task): Task to be evaluated.
repeat_id (str): Repeat ID
solution_output (SolutionOutput): output data after execute agents.
"""
evaluation_results = await task.evaluate(solution_output)
# store the evaluation result
for result in evaluation_results:
storage.save_evaluation_result(
task_id=task.id,
repeat_id=repeat_id,
evaluation=result,
)
@_ray_remote_decorator
class RaySolutionActor:
"""
Actor class for running agent solutions with ray remote.
"""
def __init__(self, n_workers: int = 1):
self.eval_actor = RayEvaluationActor.options(
max_concurrency=n_workers,
).remote()
# Set up global exporter for this Actor
self.exporter = _InMemoryExporter()
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
span_processor = SimpleSpanProcessor(self.exporter)
tracer_provider: TracerProvider = trace.get_tracer_provider()
if not isinstance(tracer_provider, TracerProvider):
# Create a new tracer provider if not exists
tracer_provider = TracerProvider()
tracer_provider.add_span_processor(span_processor)
trace.set_tracer_provider(tracer_provider)
async def run(
self,
storage: EvaluatorStorageBase,
repeat_id: str,
task: Task,
solution: Callable[
[Task, Callable],
Coroutine[Any, Any, SolutionOutput],
],
) -> None:
"""Generate a solution to a task and evaluate.
Args:
storage (EvaluatorStorageBase): Evaluator storage.
repeat_id (str): Repeat ID.
task (Task): Task to be evaluated.
solution
(Callable[[Task, Callable], Awaitable[SolutionOutput, Any]]):
callable function to execute agents and generate results.
"""
if storage.solution_result_exists(task.id, repeat_id):
# Obtain from storage
solution_result = storage.get_solution_result(
task.id,
repeat_id,
)
else:
from opentelemetry import trace, baggage
from opentelemetry.context import attach, detach
tracer = trace.get_tracer(__name__)
# Set baggage items
ctx = baggage.set_baggage("task_id", task.id)
ctx = baggage.set_baggage("repeat_id", repeat_id, context=ctx)
# Attach the context with baggage
token = attach(ctx)
try:
with tracer.start_as_current_span(
name=f"Solution_{task.id}_{repeat_id}",
):
from ... import _config
_config.trace_enabled = True
# Run the solution
solution_result = await solution(
task,
storage.get_agent_pre_print_hook(
task.id,
repeat_id,
),
)
finally:
detach(token)
# Ensure all spans are flushed
trace.get_tracer_provider().force_flush()
storage.save_solution_stats(
task.id,
repeat_id,
self.exporter.cnt.get(task.id, {}).get(repeat_id, {}),
)
storage.save_solution_result(
task.id,
repeat_id,
solution_result,
)
# Evaluate the solution with the metrics
futures = []
for metric in task.metrics:
if not storage.evaluation_result_exists(
task.id,
repeat_id,
metric.name,
):
futures.append(
self.eval_actor.run.remote(
storage,
task,
repeat_id,
solution_result,
),
)
if futures:
await asyncio.gather(*futures)
class RayEvaluator(EvaluatorBase):
"""The ray-based evaluator that supports distributed and parallel
evaluation."""
def __init__(
self,
name: str,
benchmark: BenchmarkBase,
n_repeat: int,
storage: EvaluatorStorageBase,
n_workers: int,
) -> None:
"""Initialize the evaluator."""
super().__init__(
name=name,
benchmark=benchmark,
n_repeat=n_repeat,
storage=storage,
)
# Check ray availability early
_check_ray_available()
assert isinstance(benchmark, BenchmarkBase)
assert n_repeat >= 1, "n_repeat must be at least 1"
assert n_workers >= 1, "n_workers must be at least 1"
self.benchmark = benchmark
self.n_repeat = n_repeat
self.n_workers = n_workers
async def run(
self,
solution: Callable[
[Task, Callable],
Awaitable[SolutionOutput] | SolutionOutput,
],
) -> None:
"""Run the ray-based distributed and parallel evaluation, and get the
results.
Args:
solution (`Callable[[Task], SolutionOutput]`):
A sync or async function that takes a `Task` instance as input
and returns a `SolutionOutput` instance.
"""
await self._save_evaluation_meta()
# Create solution actors
futures = []
solution_actor = RaySolutionActor.options(
max_concurrency=self.n_workers,
).remote(n_workers=self.n_workers)
# Iterate over all tasks in the benchmark
for task in self.benchmark:
# Save the task meta information
await self._save_task_meta(task)
# Run n_repeat times
for repeat_id in range(self.n_repeat):
futures.append(
solution_actor.run.remote(
self.storage,
str(repeat_id),
task,
solution,
),
)
# Await all the futures
if futures:
await asyncio.gather(*futures)
# Aggregate the results
await self.aggregate()

View File

@@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-
"""The evaluator storage module in AgentScope."""
from ._evaluator_storage_base import EvaluatorStorageBase
from ._file_evaluator_storage import FileEvaluatorStorage
__all__ = [
"EvaluatorStorageBase",
"FileEvaluatorStorage",
]

View File

@@ -0,0 +1,254 @@
# -*- coding: utf-8 -*-
"""The evaluator storage base class for storing solution and evaluation
results."""
from abc import abstractmethod
from typing import Any, Callable
from .._metric_base import MetricResult
from .._solution import SolutionOutput
from ...agent import AgentBase
from ...types import JSONSerializableObject
class EvaluatorStorageBase:
"""Used to store the solution results and evaluation results to support
resuming the evaluation process"""
@abstractmethod
def save_solution_result(
self,
task_id: str,
repeat_id: str,
output: SolutionOutput,
**kwargs: Any,
) -> None:
"""Save the solution result.
Args:
task_id (`str`):
The task ID.
repeat_id (`str`):
The repeat ID for the task, usually the index of the repeat
evaluation.
output (`SolutionOutput`):
The solution output to be saved.
"""
@abstractmethod
def get_evaluation_result(
self,
task_id: str,
repeat_id: str,
metric_name: str,
) -> MetricResult:
"""Get the evaluation result by the given task id and repeat id
Args:
task_id (`str`):
The task ID.
repeat_id (`str`):
The repeat ID for the task, usually the index of the repeat
evaluation.
metric_name (`str`):
The metric name.
Returns:
`MetricResult`:
The evaluation result for the given task and repeat ID.
"""
@abstractmethod
def save_evaluation_result(
self,
task_id: str,
repeat_id: str,
evaluation: MetricResult,
**kwargs: Any,
) -> None:
"""Save the evaluation result.
Args:
task_id (`str`):
The task ID.
repeat_id (`str`):
The repeat ID for the task, usually the index of the repeat
evaluation.
evaluation (`MetricResult`):
The evaluation result to be saved.
"""
@abstractmethod
def get_solution_result(
self,
task_id: str,
repeat_id: str,
**kwargs: Any,
) -> SolutionOutput:
"""Get the solution result for the given task and repeat id.
Args:
task_id (`str`):
The task ID.
repeat_id (`str`):
The repeat ID for the task, usually the index of the repeat
evaluation.
Returns:
`SolutionOutput`:
The solution output for the given task and repeat ID.
"""
@abstractmethod
def solution_result_exists(self, task_id: str, repeat_id: str) -> bool:
"""Check if the solution for the given task and repeat is finished.
Args:
task_id (`str`):
The task ID.
repeat_id (`str`):
The repeat ID for the task, usually the index of the repeat
evaluation.
Returns:
`bool`:
True if the solution result file exists, False otherwise.
"""
@abstractmethod
def evaluation_result_exists(
self,
task_id: str,
repeat_id: str,
metric_name: str,
) -> bool:
"""Check if the evaluation result for the given solution and metric
is finished.
Args:
task_id (`str`):
The task ID.
repeat_id (`str`):
The repeat ID for the task, usually the index of the repeat
evaluation.
metric_name (`str`):
The name of the metric.
Returns:
`bool`:
True if the evaluation result file exists, False otherwise.
"""
@abstractmethod
def save_aggregation_result(
self,
aggregation_result: dict,
**kwargs: Any,
) -> None:
"""Save the aggregation result.
Args:
aggregation_result (`dict`):
A dictionary containing the aggregation result.
"""
@abstractmethod
def aggregation_result_exists(
self,
**kwargs: Any,
) -> bool:
"""Check if the aggregation result exists
Returns:
`bool`:
`True` if the aggregation result file exists.
"""
@abstractmethod
def save_evaluation_meta(self, meta_info: dict) -> None:
"""Save the evaluation meta information.
Args:
meta_info (`dict`):
A dictionary containing the meta information.
"""
@abstractmethod
def save_task_meta(
self,
task_id: str,
meta_info: dict[str, JSONSerializableObject],
) -> None:
"""Save the task meta information.
Args:
task_id (`str`):
The task ID.
meta_info (`dict[str, JSONSerializableObject]`):
The task meta information to be saved, which should be JSON
serializable.
"""
@abstractmethod
def save_solution_stats(
self,
task_id: str,
repeat_id: str,
stats: dict,
) -> None:
"""Save the solution statistics information for a given task and
repeat ID.
Args:
task_id (`str`):
The task ID.
repeat_id (`str`):
The repeat ID for the task, usually the index of the repeat
evaluation.
stats (`dict`):
A dictionary containing the solution statistics to be saved.
"""
@abstractmethod
def get_solution_stats(
self,
task_id: str,
repeat_id: str,
) -> dict:
"""Get the solution statistics information for a given task and
repeat ID.
Args:
task_id (`str`):
The task ID.
repeat_id (`str`):
The repeat ID for the task, usually the index of the repeat
evaluation.
Returns:
`dict`:
A dictionary containing the solution statistics for the given
task and repeat ID.
"""
@abstractmethod
def get_agent_pre_print_hook(
self,
task_id: str,
repeat_id: str,
) -> Callable[[AgentBase, dict], None]:
"""Get a pre-print hook function for the agent to save the agent
printing in the evaluation storage.
Args:
task_id (`str`):
The task ID.
repeat_id (`str`):
The repeat ID for the task, usually the index of the repeat
evaluation.
Returns:
`Callable[[AgentBase, dict], None]`:
A hook function that takes an `AgentBase` instance and a
keyword arguments dictionary as input, saving the agent's
printing Msg into the evaluation storage.
"""

View File

@@ -0,0 +1,460 @@
# -*- coding: utf-8 -*-
"""A file system based evaluator storage."""
import json
import os
from json import JSONDecodeError
from typing import Any, Callable
from ._evaluator_storage_base import EvaluatorStorageBase
from .._solution import SolutionOutput
from .._metric_base import MetricResult
from ...agent import AgentBase
from ...message import Msg
from ...types import JSONSerializableObject
class FileEvaluatorStorage(EvaluatorStorageBase):
"""File system based evaluator storage, providing methods to save and
retrieve evaluation results. So that the evaluation process can be resumed
from the last saved state.
The files are organized in a directory structure:
- save_dir/
- evaluation_result.json
- evaluation_meta.json
- {repeat_id}/
- {task_id}/
- solution.json
- evaluation/
- {metric_name}.json
"""
SOLUTION_FILE_NAME = "solution.json"
SOLUTION_STATS_FILE_NAME = "stats.json"
EVALUATION_DIR_NAME = "evaluation"
EVALUATION_RESULT_FILE = "evaluation_result.json"
EVALUATION_META_FILE = "evaluation_meta.json"
TASK_META_FILE = "task_meta.json"
AGENT_PRINTING_LOG = "logging.txt"
def __init__(self, save_dir: str) -> None:
"""Initialize the file evaluator storage."""
self.save_dir = os.path.abspath(save_dir)
def _get_save_path(
self,
task_id: str,
repeat_id: str | None,
*args: str,
) -> str:
"""Get the save path for a given task, repeat ID, and additional path
components.
Args:
task_id (`str`):
The task ID.
repeat_id (`str | None`):
The repeat ID for the task, usually the index of the repeat
evaluation. If None, it will be ignored in the path.
*args (`str`):
Additional path components to be appended.
"""
path_components = [
task_id,
repeat_id,
*args,
]
path = os.path.join(self.save_dir, *[_ for _ in path_components if _])
os.makedirs(os.path.dirname(path), exist_ok=True)
return path
def save_solution_result(
self,
task_id: str,
repeat_id: str,
output: SolutionOutput,
**kwargs: Any,
) -> None:
"""Save the solution result.
Args:
task_id (`str`):
The task ID.
repeat_id (`str`):
The repeat ID for the task, usually the index of the repeat
evaluation.
output (`SolutionOutput`):
The solution output to be saved.
"""
path_file = self._get_save_path(
task_id,
repeat_id,
self.SOLUTION_FILE_NAME,
)
with open(path_file, "w", encoding="utf-8") as f:
json.dump(output, f, ensure_ascii=False, indent=4)
def save_evaluation_result(
self,
task_id: str,
repeat_id: str,
evaluation: MetricResult,
**kwargs: Any,
) -> None:
"""Save the evaluation result.
Args:
task_id (`str`):
The task ID.
repeat_id (`str`):
The repeat ID for the task, usually the index of the repeat
evaluation.
evaluation (`MetricResult`):
The evaluation result to be saved.
"""
path_file = self._get_save_path(
task_id,
repeat_id,
self.EVALUATION_DIR_NAME,
f"{evaluation.name}.json",
)
with open(path_file, "w", encoding="utf-8") as f:
json.dump(evaluation, f, ensure_ascii=False, indent=4)
def get_evaluation_result(
self,
task_id: str,
repeat_id: str,
metric_name: str,
) -> MetricResult:
"""Get the evaluation result by the given task id and repeat id
Args:
task_id (`str`):
The task ID.
repeat_id (`str`):
The repeat ID for the task, usually the index of the repeat
evaluation.
metric_name (`str`):
The metric name.
Returns:
`MetricResult`:
The evaluation result for the given task and repeat ID.
"""
path_file = self._get_save_path(
task_id,
repeat_id,
self.EVALUATION_DIR_NAME,
f"{metric_name}.json",
)
if not os.path.exists(path_file):
raise FileNotFoundError(path_file)
with open(path_file, "r", encoding="utf-8") as f:
evaluation = json.load(f)
return MetricResult(**evaluation)
def get_solution_result(
self,
task_id: str,
repeat_id: str,
**kwargs: Any,
) -> SolutionOutput:
"""Get the solution result for the given task and repeat id from the
file system.
Args:
task_id (`str`):
The task ID.
repeat_id (`str`):
The repeat ID for the task, usually the index of the repeat
evaluation.
Raises:
`FileNotFoundError`:
If the solution result file does not exist for the given task
and repeat ID.
Returns:
`SolutionOutput`:
The solution output for the given task and repeat ID.
"""
path_file = self._get_save_path(
task_id,
repeat_id,
self.SOLUTION_FILE_NAME,
)
if not os.path.exists(path_file):
raise FileNotFoundError(
f"Solution result for task {task_id} and repeat {repeat_id} "
"not found.",
)
try:
with open(path_file, "r", encoding="utf-8") as f:
solution_data = json.load(f)
except JSONDecodeError as e:
raise JSONDecodeError(
f"Failed to load JSON from {path_file}: {e.msg}",
e.doc,
e.pos,
) from e
return SolutionOutput(**solution_data)
def solution_result_exists(self, task_id: str, repeat_id: str) -> bool:
"""Check if the solution for the given task and repeat is finished.
Args:
task_id (`str`):
The task ID.
repeat_id (`str`):
The repeat ID for the task, usually the index of the repeat
evaluation.
Returns:
`bool`:
True if the solution result file exists, False otherwise.
"""
path_file = self._get_save_path(
task_id,
repeat_id,
self.SOLUTION_FILE_NAME,
)
return os.path.exists(path_file) and os.path.getsize(path_file) > 0
def evaluation_result_exists(
self,
task_id: str,
repeat_id: str,
metric_name: str,
) -> bool:
"""Check if the evaluation result for the given solution and metric
is finished.
Args:
task_id (`str`):
The task ID.
repeat_id (`str`):
The repeat ID for the task, usually the index of the repeat
evaluation.
metric_name (`str`):
The name of the metric.
Returns:
`bool`:
True if the evaluation result file exists, False otherwise.
"""
path_file = self._get_save_path(
task_id,
repeat_id,
self.EVALUATION_DIR_NAME,
f"{metric_name}.json",
)
return os.path.exists(path_file) and os.path.getsize(path_file) > 0
def save_aggregation_result(
self,
aggregation_result: dict,
**kwargs: Any,
) -> None:
"""Save the aggregation result.
Args:
aggregation_result (`dict`):
A dictionary containing the aggregation result.
"""
path_file = os.path.join(
self.save_dir,
self.EVALUATION_RESULT_FILE,
)
os.makedirs(os.path.dirname(path_file), exist_ok=True)
with open(path_file, "w", encoding="utf-8") as f:
json.dump(aggregation_result, f, ensure_ascii=False, indent=4)
def aggregation_result_exists(
self,
**kwargs: Any,
) -> bool:
"""Check if the aggregation result exists
Returns:
`bool`:
`True` if the aggregation result file exists.
"""
path_file = os.path.join(
self.save_dir,
self.EVALUATION_RESULT_FILE,
)
return os.path.exists(path_file) and os.path.getsize(path_file) > 0
def save_evaluation_meta(self, meta_info: dict) -> None:
"""Save the evaluation meta information.
Args:
meta_info (`dict`):
A dictionary containing the meta information.
"""
path_file = os.path.join(
self.save_dir,
self.EVALUATION_META_FILE,
)
os.makedirs(os.path.dirname(path_file), exist_ok=True)
with open(path_file, "w", encoding="utf-8") as f:
json.dump(meta_info, f, ensure_ascii=False, indent=4)
def save_task_meta(
self,
task_id: str,
meta_info: dict[str, JSONSerializableObject],
) -> None:
"""Save the task meta information.
Args:
task_id (`str`):
The task ID.
meta_info (`dict[str, JSONSerializableObject]`):
The task meta information to be saved, which should be JSON
serializable.
"""
path_file = self._get_save_path(
task_id,
None,
self.TASK_META_FILE,
)
with open(path_file, "w", encoding="utf-8") as f:
json.dump(meta_info, f, ensure_ascii=False, indent=4)
def save_solution_stats(
self,
task_id: str,
repeat_id: str,
stats: dict,
) -> None:
"""Save the solution statistics information for a given task and
repeat ID.
Args:
task_id (`str`):
The task ID.
repeat_id (`str`):
The repeat ID for the task, usually the index of the repeat
evaluation.
stats (`dict`):
A dictionary containing the solution statistics to be saved.
"""
path_file = self._get_save_path(
task_id,
repeat_id,
self.SOLUTION_STATS_FILE_NAME,
)
if not os.path.exists(path_file):
with open(path_file, "w", encoding="utf-8") as f:
json.dump(stats, f, ensure_ascii=False, indent=4)
def get_solution_stats(
self,
task_id: str,
repeat_id: str,
) -> dict:
"""Get the solution statistics information for a given task and
repeat ID.
Args:
task_id (`str`):
The task ID.
repeat_id (`str`):
The repeat ID for the task, usually the index of the repeat
evaluation.
Returns:
`dict`:
A dictionary containing the solution statistics for the given
task and repeat ID.
"""
path_file = self._get_save_path(
task_id,
repeat_id,
self.SOLUTION_STATS_FILE_NAME,
)
if not os.path.exists(path_file):
raise FileNotFoundError(
f"Solution statistics for task {task_id} and repeat "
f"{repeat_id} not found.",
)
try:
with open(path_file, "r", encoding="utf-8") as f:
return json.load(f)
except JSONDecodeError as e:
raise JSONDecodeError(
f"Failed to load JSON from {path_file}: {e.msg}",
e.doc,
e.pos,
) from e
def get_agent_pre_print_hook(
self,
task_id: str,
repeat_id: str,
) -> Callable[[AgentBase, dict], None]:
"""Get a pre-print hook function for the agent to save the agent
printing in the evaluation storage.
Args:
task_id (`str`):
The task ID.
repeat_id (`str`):
The repeat ID for the task, usually the index of the repeat
evaluation.
Returns:
`Callable[[AgentBase, dict], None]`:
A hook function that takes an `AgentBase` instance and a
keyword arguments dictionary as input, saving the agent's
printing Msg into the evaluation storage.
"""
def pre_print_hook(_agent: AgentBase, kwargs: dict) -> None:
"""Hook function to save agent's printing."""
msg: Msg | None = kwargs.get("msg", None)
last: bool = kwargs.get("last", False)
if msg is None or not last:
return
# Only save the last message
printing_str = []
for block in msg.get_content_blocks():
match block["type"]:
case "text":
printing_str.append(
f"{msg.name}: {block['text']}",
)
case "thinking":
printing_str.append(
f"{msg.name} (thinking): {block['text']}",
)
case _:
block_str = json.dumps(
block,
ensure_ascii=False,
indent=4,
)
if printing_str:
printing_str.append(block_str)
else:
printing_str.append(f"{msg.name}: {block_str}")
path_file = self._get_save_path(
task_id,
repeat_id,
self.AGENT_PRINTING_LOG,
)
with open(path_file, "a", encoding="utf-8") as f:
f.write("\n".join(printing_str) + "\n")
return pre_print_hook

View File

@@ -0,0 +1,101 @@
# -*- coding: utf-8 -*-
"""The base class for _metric in evaluation."""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
from .._utils._common import _get_timestamp
from .._utils._mixin import DictMixin
from ..types import JSONSerializableObject
@dataclass
class MetricResult(DictMixin):
"""The result of a _metric."""
name: str
"""The metric name."""
result: str | float | int
"""The metric result."""
created_at: str = field(default_factory=_get_timestamp)
"""The timestamp when the metric result was created."""
message: str | None = field(default_factory=lambda: None)
"""An optional message for the metric result, can be used to provide
additional information or context about the result."""
metadata: dict[str, JSONSerializableObject] | None = field(default=None)
"""Optional metadata for the metric result, can be used to store
additional information related to the metric result."""
class MetricType(str, Enum):
"""The metric type enum."""
CATEGORY = "category"
"""The metric result is a category, e.g. "pass" or "fail"."""
NUMERICAL = "numerical"
"""The metric result is a numerical value, e.g. 0.95 or 100."""
@dataclass
class MetricBase(ABC):
"""The base class for _metric in evaluation."""
name: str
"""The name of the Metric"""
metric_type: MetricType
"""The metric type"""
description: str | None
"""The description of the metric"""
categories: list[str] | None
"""The candidate categories. If `metric_type` is "category", the
categories must be provided, otherwise it should be `None`."""
def __init__(
self,
name: str,
metric_type: MetricType,
description: str | None = None,
categories: list[str] | None = None,
) -> None:
"""Initialize the _metric object.
Args:
name (`str`):
The name of the metric.
metric_type (`MetricType`):
The type of the metric, can be either "category" or
"numerical", which will determine how to display the result.
description (`str`):
The description of the metric.
categories (`list[str] | None`, optional):
The candidate categories. If `metric_type` is "category", the
categories must be provided, otherwise it should be `None`.
"""
self.name = name
self.metric_type = metric_type
self.description = description
if metric_type == MetricType.CATEGORY and categories is None:
raise ValueError(
"Categories must be provided for category metrics.",
)
self.categories = categories
@abstractmethod
async def __call__(
self,
*args: Any,
**kwargs: Any,
) -> MetricResult:
"""The call function to calculate the _metric result"""

View File

@@ -0,0 +1,36 @@
# -*- coding: utf-8 -*-
"""Solution class for evaluation tasks."""
from dataclasses import dataclass, field
from typing import Any
from ..message import (
ToolResultBlock,
ToolUseBlock,
TextBlock,
)
from ..types._json import JSONSerializableObject
from .._utils._mixin import DictMixin
@dataclass
class SolutionOutput(DictMixin):
"""The output of a solution in evaluation task"""
success: bool
"""Indicates whether the solution is executed successfully. When the
solution raise exception, this should be set to False."""
output: JSONSerializableObject
"""The final output of the solution."""
trajectory: list[ToolUseBlock | ToolResultBlock | TextBlock]
"""The tool calls and results trajectory"""
meta: dict[str, Any] | None = field(default_factory=lambda: None)
"""Additional metadata for the solution"""
def __getstate__(self) -> dict[str, Any]:
"""Custom pickling to handle dataclass + DictMixin inheritance."""
return self.__dict__.copy()
def __setstate__(self, state: dict[str, Any]) -> None:
"""Custom unpickling to handle dataclass + DictMixin inheritance."""
self.__dict__.update(state)

View File

@@ -0,0 +1,53 @@
# -*- coding: utf-8 -*-
"""The base class for task in evaluation."""
from dataclasses import dataclass, field
from typing import Any
from ._solution import SolutionOutput
from ._metric_base import MetricBase, MetricResult
from ..types._json import JSONSerializableObject
@dataclass
class Task:
"""The base class for task in evaluation."""
id: str
"""The unique identifier for the task."""
input: JSONSerializableObject
"""The task input, which should be a JSON serializable object."""
ground_truth: JSONSerializableObject
"""The task ground truth if exists, which should be a JSON serializable
object."""
metrics: list[MetricBase]
"""The metrics to evaluate the task, which should be a list of
`MetricBase` objects."""
tags: dict[str, str] | None = field(default_factory=lambda: None)
"""Tags to categorize the task, e.g. `{"difficulty": "easy",
"cate": "math"}`."""
metadata: dict[str, Any] | None = field(
default_factory=lambda: None,
)
"""Additional metadata for the task."""
async def evaluate(self, solution: SolutionOutput) -> list[MetricResult]:
"""Evaluate the task with the given solution.
Args:
solution (`SolutionOutput`):
The solution to evaluate the task with.
Returns:
`MetricResult`:
The result of the evaluation.
"""
evaluations = []
for metric in self.metrics:
result = await metric(solution)
evaluations.append(result)
return evaluations

View File

@@ -0,0 +1,16 @@
# -*- coding: utf-8 -*-
"""The exception module in agentscope."""
from ._exception_base import AgentOrientedExceptionBase
from ._tool import (
ToolInterruptedError,
ToolNotFoundError,
ToolInvalidArgumentsError,
)
__all__ = [
"AgentOrientedExceptionBase",
"ToolInterruptedError",
"ToolNotFoundError",
"ToolInvalidArgumentsError",
]

View File

@@ -0,0 +1,18 @@
# -*- coding: utf-8 -*-
"""The base exception class in agentscope."""
class AgentOrientedExceptionBase(Exception):
"""The base class for all agent-oriented exceptions. These exceptions are
expect to the captured and exposed to the agent during runtime, so that
agents can handle the error appropriately during the runtime.
"""
def __init__(self, message: str):
"""Initialize the exception with a message."""
super().__init__(message)
self.message = message
def __str__(self) -> str:
"""Return the string representation of the exception."""
return f"{self.__class__.__name__}: {self.message}"

View File

@@ -0,0 +1,16 @@
# -*- coding: utf-8 -*-
"""The tool-related exceptions in agentscope."""
from ._exception_base import AgentOrientedExceptionBase
class ToolNotFoundError(AgentOrientedExceptionBase):
"""Exception raised when a tool was not found."""
class ToolInterruptedError(AgentOrientedExceptionBase):
"""Exception raised when a tool calling was interrupted by the user."""
class ToolInvalidArgumentsError(AgentOrientedExceptionBase):
"""Exception raised when the arguments passed to a tool are invalid."""

View File

@@ -0,0 +1,48 @@
# -*- coding: utf-8 -*-
"""The formatter module in agentscope."""
from ._formatter_base import FormatterBase
from ._truncated_formatter_base import TruncatedFormatterBase
from ._dashscope_formatter import (
DashScopeChatFormatter,
DashScopeMultiAgentFormatter,
)
from ._anthropic_formatter import (
AnthropicChatFormatter,
AnthropicMultiAgentFormatter,
)
from ._openai_formatter import (
OpenAIChatFormatter,
OpenAIMultiAgentFormatter,
)
from ._gemini_formatter import (
GeminiChatFormatter,
GeminiMultiAgentFormatter,
)
from ._ollama_formatter import (
OllamaChatFormatter,
OllamaMultiAgentFormatter,
)
from ._deepseek_formatter import (
DeepSeekChatFormatter,
DeepSeekMultiAgentFormatter,
)
from ._a2a_formatter import A2AChatFormatter
__all__ = [
"FormatterBase",
"TruncatedFormatterBase",
"DashScopeChatFormatter",
"DashScopeMultiAgentFormatter",
"OpenAIChatFormatter",
"OpenAIMultiAgentFormatter",
"AnthropicChatFormatter",
"AnthropicMultiAgentFormatter",
"GeminiChatFormatter",
"GeminiMultiAgentFormatter",
"OllamaChatFormatter",
"OllamaMultiAgentFormatter",
"DeepSeekChatFormatter",
"DeepSeekMultiAgentFormatter",
"A2AChatFormatter",
]

View File

@@ -0,0 +1,364 @@
# -*- coding: utf-8 -*-
"""The A2A message formatter class."""
import mimetypes
import uuid
from typing import Literal, TYPE_CHECKING
from .._logging import logger
from ._formatter_base import FormatterBase
from ..message import (
Msg,
TextBlock,
URLSource,
Base64Source,
ContentBlock,
)
if TYPE_CHECKING:
from a2a.types import (
Message,
Task,
Part,
)
else:
Message = "a2a.types.Message"
Task = "a2a.types.Task"
Part = "a2a.types.Part"
class A2AChatFormatter(FormatterBase):
"""A2A message formatter class, which convert AgentScope messages into
A2A message format."""
async def format(self, msgs: list[Msg]) -> Message:
"""Convert AgentScope messages into a A2A message object. Note that
A2A server only supports single request message, so the input msgs
list will be merged into a single A2A Message.
.. note:: Note the A2A protocol receives a single message per request,
so multi-message inputs will be merged into one A2A Message with role
'user'.
Args:
msgs (`list[Msg]`):
List of AgentScope Msg objects to be converted.
Returns:
`Message`:
The converted A2A Message object.
"""
from a2a.types import (
Part,
TextPart,
FilePart,
FileWithUri,
FileWithBytes,
DataPart,
Role,
Message,
)
self.assert_list_of_msgs(msgs)
parts = []
for msg in msgs:
for block in msg.get_content_blocks():
block_type = block.get("type")
if block_type == "text" and block.get("text"):
parts.append(
Part(
root=TextPart(
text=block.get("text"),
),
),
)
elif block_type == "thinking" and block.get("thinking"):
parts.append(
Part(
root=TextPart(
text=block.get("thinking"),
),
),
)
elif block_type in [
"image",
"video",
"audio",
] and block.get("source"):
source = block.get("source", {})
source_type = source.get("type")
if source_type == "url":
parts.append(
Part(
root=FilePart(
file=FileWithUri(
uri=source.get("url"),
),
),
),
)
elif source_type == "base64":
parts.append(
Part(
root=FilePart(
file=FileWithBytes(
bytes=source.get("data"),
mime_type=source.get("media_type"),
),
),
),
)
else:
raise ValueError(
f"Unsupported source type: {source_type}",
)
elif block_type in ["tool_use", "tool_result"]:
parts.append(
Part(
root=DataPart(
data=block,
),
),
)
else:
logger.error(
"Unsupported block type %s in A2AFormatter.",
block_type,
)
a2a_message = Message(
message_id=str(uuid.uuid4()),
role=Role.user,
parts=parts,
)
return a2a_message
async def format_a2a_message(self, name: str, message: Message) -> Msg:
"""Convert A2A Message object back to AgentScope Msg format.
Args:
name (`str`):
The name of the message sender.
message (`Message`):
The A2A Message object to be converted.
Returns:
`list[Msg]`:
List of converted AgentScope Msg objects.
"""
from a2a.types import Role
content = []
metadata = None
for part in message.parts:
content.append(
await self._format_a2a_part(part),
)
if message.role == Role.user:
role: Literal["user", "assistant"] = "user"
elif message.role == Role.agent:
role = "assistant"
else:
raise ValueError(
f"Unsupported role: {message.role} in A2A message.",
)
return Msg(
name=name,
role=role,
content=content,
metadata=metadata,
)
@staticmethod
def _guess_type(
uri: str | None = None,
mime_type: str | None = None,
) -> Literal["image", "video", "audio", "unknown"]:
"""Guess the content type from the uri or mime type.
Args:
uri (`str | None`, optional):
The uri of the content.
mime_type (`str | None`, optional):
The mime type of the content.
Returns:
`Literal["image", "video", "audio", "unknown"]`:
The guessed content type.
"""
if mime_type is None and uri is None:
raise ValueError(
"Either uri or mime_type must be provided to guess the"
" content type.",
)
if mime_type is None:
mime_type, _encoding = mimetypes.guess_type(uri or "")
if isinstance(mime_type, str):
if mime_type.startswith("image/"):
return "image"
if mime_type.startswith("video/"):
return "video"
if mime_type.startswith("audio/"):
return "audio"
return "unknown"
async def format_a2a_task(self, name: str, task: Task) -> list[Msg]:
"""Convert A2A Task object back to AgentScope Msg format.
Args:
name (`str`):
The name of the message sender.
task (`Task`):
The A2A Task object to be converted.
Returns:
`list[Msg]`:
Converted AgentScope Msg objects.
"""
msgs = []
if task.status and task.status.message:
msgs.append(
await self.format_a2a_message(name, task.status.message),
)
merged_msgs = []
for msg in msgs:
if merged_msgs and merged_msgs[-1].role == msg.role:
merged_msgs[-1].content.extend(msg.content)
else:
merged_msgs.append(msg)
if task.artifacts:
for artifact in task.artifacts:
artifact_content = [
await self._format_a2a_part(_) for _ in artifact.parts
]
if merged_msgs and merged_msgs[-1].role == "assistant":
merged_msgs[-1].content.extend(artifact_content)
merged_msgs[-1].metadata = artifact.metadata
else:
merged_msgs.append(
Msg(
name=name,
role="assistant",
content=artifact_content,
metadata=artifact.metadata,
),
)
return merged_msgs
async def _format_a2a_part(self, part: Part) -> ContentBlock:
"""Convert a single A2A Part object into AgentScope ContentBlock.
.. note:: We will try to convert the `DataPart` into tool use and tool
result blocks if possible.
Args:
part (`Part`):
The A2A Part object to be converted.
Returns:
`ContentBlock`:
The converted AgentScope ContentBlock.
"""
from a2a.types import (
TextPart,
FilePart,
FileWithUri,
FileWithBytes,
DataPart,
)
if isinstance(part.root, TextPart):
return TextBlock(
type="text",
text=part.root.text,
)
if isinstance(part.root, FilePart):
if isinstance(part.root.file, FileWithUri):
return { # type: ignore[return-value, misc]
"type": self._guess_type(
part.root.file.uri,
part.root.file.mime_type,
),
"source": URLSource(
type="url",
url=part.root.file.uri,
),
}
if isinstance(part.root.file, FileWithBytes):
return { # type: ignore[return-value, misc]
"type": self._guess_type(
mime_type=part.root.file.mime_type,
),
"source": Base64Source(
type="base64",
media_type=part.root.file.mime_type
or "application/octet-stream",
data=part.root.file.bytes,
),
}
raise ValueError(
f"Unsupported File type: {type(part.root.file)} in A2A"
"message.",
)
if isinstance(part.root, DataPart):
# Maybe the tool use and tool result blocks
if {
"type",
"name",
"input",
"id",
} <= part.root.data.keys() and part.root.data[
"type"
] == "tool_use":
return part.root.data
if {
"type",
"name",
"output",
"id",
} <= part.root.data.keys() and part.root.data[
"type"
] == "tool_result":
return part.root.data
# TODO: what about the other data parts?
return TextBlock(
type="text",
text=str(part.root.data),
)
raise ValueError(
f"Unsupported Part type: {type(part.root)} in A2A message"
f": {part.root}",
)

View File

@@ -0,0 +1,253 @@
# -*- coding: utf-8 -*-
# pylint: disable=too-many-branches
"""The Anthropic formatter module."""
from typing import Any
from ._truncated_formatter_base import TruncatedFormatterBase
from .._logging import logger
from ..message import Msg, TextBlock, ImageBlock, ToolUseBlock, ToolResultBlock
from ..token import TokenCounterBase
class AnthropicChatFormatter(TruncatedFormatterBase):
"""The Anthropic formatter class for chatbot scenario, where only a user
and an agent are involved. We use the `role` field to identify different
entities in the conversation.
"""
support_tools_api: bool = True
"""Whether support tools API"""
support_multiagent: bool = False
"""Whether support multi-agent conversations"""
support_vision: bool = True
"""Whether support vision data"""
supported_blocks: list[type] = [
TextBlock,
# Multimodal
ImageBlock,
# Tool use
ToolUseBlock,
ToolResultBlock,
]
"""The list of supported message blocks"""
async def _format(
self,
msgs: list[Msg],
) -> list[dict[str, Any]]:
"""Format message objects into Anthropic API format.
Args:
msgs (`list[Msg]`):
The list of message objects to format.
Returns:
`list[dict[str, Any]]`:
The formatted messages as a list of dictionaries.
.. note:: Anthropic suggests always passing all previous thinking
blocks back to the API in subsequent calls to maintain reasoning
continuity. For more details, please refer to
`Anthropic's documentation
<https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#preserving-thinking-blocks>`_.
"""
self.assert_list_of_msgs(msgs)
messages: list[dict] = []
for index, msg in enumerate(msgs):
content_blocks = []
for block in msg.get_content_blocks():
typ = block.get("type")
if typ in ["thinking", "text", "image"]:
content_blocks.append({**block})
elif typ == "tool_use":
content_blocks.append(
{
"id": block.get("id"),
"type": "tool_use",
"name": block.get("name"),
"input": block.get("input", {}),
},
)
elif typ == "tool_result":
output = block.get("output")
if output is None:
content_value = [{"type": "text", "text": None}]
elif isinstance(output, list):
content_value = output
else:
content_value = [{"type": "text", "text": str(output)}]
messages.append(
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": block.get("id"),
"content": content_value,
},
],
},
)
else:
logger.warning(
"Unsupported block type %s in the message, skipped.",
typ,
)
# Claude only allow the first message to be system message
if msg.role == "system" and index != 0:
role = "user"
else:
role = msg.role
msg_anthropic = {
"role": role,
"content": content_blocks or None,
}
# When both content and tool_calls are None, skipped
if msg_anthropic["content"] or msg_anthropic.get("tool_calls"):
messages.append(msg_anthropic)
return messages
class AnthropicMultiAgentFormatter(TruncatedFormatterBase):
"""
Anthropic formatter for multi-agent conversations, where more than
a user and an agent are involved.
"""
support_tools_api: bool = True
"""Whether support tools API"""
support_multiagent: bool = True
"""Whether support multi-agent conversations"""
support_vision: bool = True
"""Whether support vision data"""
supported_blocks: list[type] = [
TextBlock,
# Multimodal
ImageBlock,
# Tool use
ToolUseBlock,
ToolResultBlock,
]
"""The list of supported message blocks"""
def __init__(
self,
conversation_history_prompt: str = (
"# Conversation History\n"
"The content between <history></history> tags contains "
"your conversation history\n"
),
token_counter: TokenCounterBase | None = None,
max_tokens: int | None = None,
) -> None:
"""Initialize the DashScope multi-agent formatter.
Args:
conversation_history_prompt (`str`):
The prompt to use for the conversation history section.
"""
super().__init__(token_counter=token_counter, max_tokens=max_tokens)
self.conversation_history_prompt = conversation_history_prompt
async def _format_tool_sequence(
self,
msgs: list[Msg],
) -> list[dict[str, Any]]:
"""Given a sequence of tool call/result messages, format them into
the required format for the Anthropic API."""
return await AnthropicChatFormatter().format(msgs)
async def _format_agent_message(
self,
msgs: list[Msg],
is_first: bool = True,
) -> list[dict[str, Any]]:
"""Given a sequence of messages without tool calls/results, format
them into the required format for the Anthropic API."""
if is_first:
conversation_history_prompt = self.conversation_history_prompt
else:
conversation_history_prompt = ""
# Format into required Anthropic format
formatted_msgs: list[dict] = []
# Collect the multimodal files
conversation_blocks: list = []
accumulated_text = []
for msg in msgs:
for block in msg.get_content_blocks():
if block["type"] == "text":
accumulated_text.append(f"{msg.name}: {block['text']}")
elif block["type"] == "image":
# Handle the accumulated text as a single block
if accumulated_text:
conversation_blocks.append(
{
"text": "\n".join(accumulated_text),
"type": "text",
},
)
accumulated_text.clear()
conversation_blocks.append({**block})
if accumulated_text:
conversation_blocks.append(
{
"text": "\n".join(accumulated_text),
"type": "text",
},
)
if conversation_blocks:
if conversation_blocks[0].get("text"):
conversation_blocks[0]["text"] = (
conversation_history_prompt
+ "<history>\n"
+ conversation_blocks[0]["text"]
)
else:
conversation_blocks.insert(
0,
{
"type": "text",
"text": conversation_history_prompt + "<history>\n",
},
)
if conversation_blocks[-1].get("text"):
conversation_blocks[-1]["text"] += "\n</history>"
else:
conversation_blocks.append(
{"type": "text", "text": "</history>"},
)
if conversation_blocks:
formatted_msgs.append(
{
"role": "user",
"content": conversation_blocks,
},
)
return formatted_msgs

View File

@@ -0,0 +1,639 @@
# -*- coding: utf-8 -*-
# pylint: disable=too-many-branches
"""The dashscope formatter module."""
import json
import os.path
from typing import Any
from ._truncated_formatter_base import TruncatedFormatterBase
from .._logging import logger
from .._utils._common import _is_accessible_local_file
from ..message import (
Msg,
TextBlock,
ImageBlock,
AudioBlock,
VideoBlock,
ToolUseBlock,
ToolResultBlock,
URLSource,
)
from ..token import TokenCounterBase
def _format_dashscope_media_block(
block: ImageBlock | AudioBlock,
) -> dict[str, str]:
"""Format an image or audio block for DashScope API.
Args:
block (`ImageBlock` | `AudioBlock`):
The image or audio block to format.
Returns:
`dict[str, str]`:
A dictionary with "image" or "audio" key and the formatted URL or
data URI as value.
Raises:
`NotImplementedError`:
If the source type is not supported.
"""
typ = block["type"]
source = block["source"]
if source["type"] == "url":
url = source["url"]
if _is_accessible_local_file(url):
return {typ: "file://" + os.path.abspath(url)}
else:
# treat as web url
return {typ: url}
elif source["type"] == "base64":
media_type = source["media_type"]
base64_data = source["data"]
return {
typ: f"data:{media_type};base64,{base64_data}",
}
else:
raise NotImplementedError(
f"Unsupported source type '{source.get('type')}' "
f"for {typ} block.",
)
def _reformat_messages(
messages: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Reformat the content to be compatible with HuggingFaceTokenCounter.
This function processes a list of messages and converts multi-part
text content into single string content when all parts are plain text.
This is necessary for compatibility with HuggingFaceTokenCounter which
expects simple string content rather than structured content with
multiple parts.
Args:
messages (list[dict[str, Any]]):
A list of message dictionaries where each message may contain a
"content" field. The content can be either:
- A string (unchanged)
- A list of content items, where each item is a dict that may
contain "text", "type", and other fields
Returns:
list[dict[str, Any]]:
A list of reformatted messages. For messages where all content
items are plain text (have "text" field and either no "type"
field or "type" == "text"), the content list is converted to a
single newline-joined string. Other messages remain unchanged.
Example:
.. code-block:: python
# Case 1: All text content - will be converted
messages = [
{
"role": "user",
"content": [
{"text": "Hello", "type": "text"},
{"text": "World", "type": "text"}
]
}
]
result = _reformat_messages(messages)
print(result[0]["content"])
# Output: "Hello\nWorld"
# Case 2: Mixed content - will remain unchanged
messages = [
{
"role": "user",
"content": [
{"text": "Hello", "type": "text"},
{"image_url": "...", "type": "image"}
]
}
]
result = _reformat_messages(messages) # remain unchanged
print(type(result[0]["content"]))
# Output: <class 'list'>
"""
for message in messages:
content = message.get("content", [])
is_all_text = True
texts = []
for item in content:
if not isinstance(item, dict) or "text" not in item:
is_all_text = False
break
if "type" in item and item["type"] != "text":
is_all_text = False
break
if item["text"]:
texts.append(item["text"])
if is_all_text and texts:
message["content"] = "\n".join(texts)
return messages
class DashScopeChatFormatter(TruncatedFormatterBase):
"""The DashScope formatter class for chatbot scenario, where only a user
and an agent are involved. We use the `role` field to identify different
entities in the conversation.
.. warning::
Known Issues with DashScope API:
1. **Missing content field**: When messages lack the 'content' field,
qwen-vl-max models will raise ``KeyError: 'content'``.
2. **None content value**: When content is ``None``, qwen-vl-max models
will raise ``TypeError: 'NoneType' object is not iterable``.
3. **Empty text in content**: When content contains
``[{"text": None}]``, qwen3-max may repeatedly invoke tools
multiple times. Note that when qwen3-max initiates tool calls,
the returned message contains ``"content": ""``.
To avoid these issues, this formatter assigns content as an empty
list ``[]`` for messages without valid content blocks.
"""
support_tools_api: bool = True
"""Whether support tools API"""
support_multiagent: bool = False
"""Whether support multi-agent conversations"""
support_vision: bool = True
"""Whether support vision data"""
supported_blocks: list[type] = [
TextBlock,
ImageBlock,
AudioBlock,
VideoBlock,
ToolUseBlock,
ToolResultBlock,
]
def __init__(
self,
promote_tool_result_images: bool = False,
promote_tool_result_audios: bool = False,
promote_tool_result_videos: bool = False,
token_counter: TokenCounterBase | None = None,
max_tokens: int | None = None,
) -> None:
"""Initialize the DashScope chat formatter.
Args:
promote_tool_result_images (`bool`, defaults to `False`):
Whether to promote images from tool results to user messages.
Most LLM APIs don't support images in tool result blocks, but
do support them in user message blocks. When `True`, images are
extracted and appended as a separate user message with
explanatory text indicating their source.
promote_tool_result_audios (`bool`, defaults to `False`):
Whether to promote audios from tool results to user messages.
Most LLM APIs don't support audios in tool result blocks, but
do support them in user message blocks. When `True`, audios are
extracted and appended as a separate user message with
explanatory text indicating their source.
promote_tool_result_videos (`bool`, defaults to `False`):
Whether to promote videos from tool results to user messages.
Most LLM APIs don't support videos in tool result blocks, but
do support them in user message blocks. When `True`, videos are
extracted and appended as a separate user message with
explanatory text indicating their source.
token_counter (`TokenCounterBase | None`, optional):
A token counter instance used to count tokens in the messages.
If not provided, the formatter will format the messages
without considering token limits.
max_tokens (`int | None`, optional):
The maximum number of tokens allowed in the formatted
messages. If not provided, the formatter will not truncate
the messages.
"""
super().__init__(token_counter, max_tokens)
self.promote_tool_result_images = promote_tool_result_images
self.promote_tool_result_audios = promote_tool_result_audios
self.promote_tool_result_videos = promote_tool_result_videos
async def _format(
self,
msgs: list[Msg],
) -> list[dict[str, Any]]:
"""Format message objects into DashScope API format.
Args:
msgs (`list[Msg]`):
The list of message objects to format.
Returns:
`list[dict[str, Any]]`:
The formatted messages as a list of dictionaries.
"""
self.assert_list_of_msgs(msgs)
formatted_msgs: list[dict] = []
i = 0
while i < len(msgs):
msg = msgs[i]
content_blocks: list[dict[str, Any]] = []
tool_calls = []
for block in msg.get_content_blocks():
typ = block.get("type")
if typ == "text":
content_blocks.append(
{
"text": block.get("text"),
},
)
elif typ in ["image", "audio", "video"]:
content_blocks.append(
_format_dashscope_media_block(
block, # type: ignore[arg-type]
),
)
elif typ == "tool_use":
tool_calls.append(
{
"id": block.get("id"),
"type": "function",
"function": {
"name": block.get("name"),
"arguments": json.dumps(
block.get("input", {}),
ensure_ascii=False,
),
},
},
)
elif typ == "tool_result":
(
textual_output,
multimodal_data,
) = self.convert_tool_result_to_string(block["output"])
# First add the tool result message in DashScope API format
formatted_msgs.append(
{
"role": "tool",
"tool_call_id": block.get("id"),
"content": textual_output,
"name": block.get("name"),
},
)
# Then, handle the multimodal data if any
promoted_blocks: list = []
for url, multimodal_block in multimodal_data:
if (
multimodal_block["type"] == "image"
and self.promote_tool_result_images
):
promoted_blocks.extend(
[
TextBlock(
type="text",
text=f"\n- The image from '{url}': ",
),
ImageBlock(
type="image",
source=URLSource(
type="url",
url=url,
),
),
],
)
elif (
multimodal_block["type"] == "audio"
and self.promote_tool_result_audios
):
promoted_blocks.extend(
[
TextBlock(
type="text",
text=f"\n- The audio from '{url}': ",
),
AudioBlock(
type="audio",
source=URLSource(
type="url",
url=url,
),
),
],
)
elif (
multimodal_block["type"] == "video"
and self.promote_tool_result_videos
):
promoted_blocks.extend(
[
TextBlock(
type="text",
text=f"\n- The video from '{url}': ",
),
VideoBlock(
type="video",
source=URLSource(
type="url",
url=url,
),
),
],
)
if promoted_blocks:
# Insert promoted blocks as new user message(s)
promoted_blocks = [
TextBlock(
type="text",
text="<system-info>The following are "
f"the media contents from the tool "
f"result of '{block['name']}':",
),
*promoted_blocks,
TextBlock(
type="text",
text="</system-info>",
),
]
msgs.insert(
i + 1,
Msg(
name="user",
content=promoted_blocks,
role="user",
),
)
else:
logger.warning(
"Unsupported block type %s in the message, skipped.",
typ,
)
msg_dashscope = {
"role": msg.role,
"content": content_blocks,
}
if tool_calls:
msg_dashscope["tool_calls"] = tool_calls
if msg_dashscope["content"] or msg_dashscope.get("tool_calls"):
formatted_msgs.append(msg_dashscope)
# Move to next message
i += 1
return _reformat_messages(formatted_msgs)
class DashScopeMultiAgentFormatter(TruncatedFormatterBase):
"""DashScope formatter for multi-agent conversations, where more than
a user and an agent are involved.
.. note:: This formatter will combine previous messages (except tool
calls/results) into a history section in the first system message with
the conversation history prompt.
.. note:: For tool calls/results, they will be presented as separate
messages as required by the DashScope API. Therefore, the tool calls/
results messages are expected to be placed at the end of the input
messages.
.. tip:: Telling the assistant's name in the system prompt is very
important in multi-agent conversations. So that LLM can know who it
is playing as.
"""
support_tools_api: bool = True
"""Whether support tools API"""
support_multiagent: bool = True
"""Whether support multi-agent conversations"""
support_vision: bool = True
"""Whether support vision data"""
supported_blocks: list[type] = [
TextBlock,
# Multimodal
ImageBlock,
AudioBlock,
VideoBlock,
# Tool use
ToolUseBlock,
ToolResultBlock,
]
"""The list of supported message blocks"""
def __init__(
self,
conversation_history_prompt: str = (
"# Conversation History\n"
"The content between <history></history> tags contains "
"your conversation history\n"
),
promote_tool_result_images: bool = False,
promote_tool_result_audios: bool = False,
promote_tool_result_videos: bool = False,
token_counter: TokenCounterBase | None = None,
max_tokens: int | None = None,
) -> None:
"""Initialize the DashScope multi-agent formatter.
Args:
conversation_history_prompt (`str`):
The prompt to use for the conversation history section.
promote_tool_result_images (`bool`, defaults to `False`):
Whether to promote images from tool results to user messages.
Most LLM APIs don't support images in tool result blocks, but
do support them in user message blocks. When `True`, images are
extracted and appended as a separate user message with
explanatory text indicating their source.
promote_tool_result_audios (`bool`, defaults to `False`):
Whether to promote audios from tool results to user messages.
Most LLM APIs don't support audios in tool result blocks, but
do support them in user message blocks. When `True`, audios are
extracted and appended as a separate user message with
explanatory text indicating their source.
promote_tool_result_videos (`bool`, defaults to `False`):
Whether to promote videos from tool results to user messages.
Most LLM APIs don't support videos in tool result blocks, but
do support them in user message blocks. When `True`, videos are
extracted and appended as a separate user message with
explanatory text indicating their source.
token_counter (`TokenCounterBase | None`, optional):
The token counter used for truncation.
max_tokens (`int | None`, optional):
The maximum number of tokens allowed in the formatted
messages. If `None`, no truncation will be applied.
"""
super().__init__(token_counter=token_counter, max_tokens=max_tokens)
self.conversation_history_prompt = conversation_history_prompt
self.promote_tool_result_images = promote_tool_result_images
self.promote_tool_result_audios = promote_tool_result_audios
self.promote_tool_result_videos = promote_tool_result_videos
async def _format_tool_sequence(
self,
msgs: list[Msg],
) -> list[dict[str, Any]]:
"""Given a sequence of tool call/result messages, format them into
the required format for the DashScope API.
Args:
msgs (`list[Msg]`):
The list of messages containing tool calls/results to format.
Returns:
`list[dict[str, Any]]`:
A list of dictionaries formatted for the DashScope API.
"""
return await DashScopeChatFormatter(
promote_tool_result_images=self.promote_tool_result_images,
promote_tool_result_audios=self.promote_tool_result_audios,
promote_tool_result_videos=self.promote_tool_result_videos,
).format(msgs)
async def _format_agent_message(
self,
msgs: list[Msg],
is_first: bool = True,
) -> list[dict[str, Any]]:
"""Given a sequence of messages without tool calls/results, format
them into a user message with conversation history tags. For the
first agent message, it will include the conversation history prompt.
Args:
msgs (`list[Msg]`):
A list of Msg objects to be formatted.
is_first (`bool`, defaults to `True`):
Whether this is the first agent message in the conversation.
If `True`, the conversation history prompt will be included.
Returns:
`list[dict[str, Any]]`:
A list of dictionaries formatted for the DashScope API.
"""
if is_first:
conversation_history_prompt = self.conversation_history_prompt
else:
conversation_history_prompt = ""
# Format into required DashScope format
formatted_msgs: list[dict] = []
# Collect the multimodal files
conversation_blocks = []
accumulated_text = []
for msg in msgs:
for block in msg.get_content_blocks():
if block["type"] == "text":
accumulated_text.append(f"{msg.name}: {block['text']}")
elif block["type"] in ["image", "audio", "video"]:
# Handle the accumulated text as a single block
if accumulated_text:
conversation_blocks.append(
{"text": "\n".join(accumulated_text)},
)
accumulated_text.clear()
if block["source"]["type"] == "url":
url = block["source"]["url"]
if _is_accessible_local_file(url):
conversation_blocks.append(
{
block["type"]: "file://"
+ os.path.abspath(url),
},
)
else:
conversation_blocks.append({block["type"]: url})
elif block["source"]["type"] == "base64":
media_type = block["source"]["media_type"]
base64_data = block["source"]["data"]
conversation_blocks.append(
{
block[
"type"
]: f"data:{media_type};base64,{base64_data}",
},
)
else:
logger.warning(
"Unsupported block type %s in the message, "
"skipped.",
block["type"],
)
if accumulated_text:
conversation_blocks.append({"text": "\n".join(accumulated_text)})
if conversation_blocks:
if conversation_blocks[0].get("text"):
conversation_blocks[0]["text"] = (
conversation_history_prompt
+ "<history>\n"
+ conversation_blocks[0]["text"]
)
else:
conversation_blocks.insert(
0,
{
"text": conversation_history_prompt + "<history>\n",
},
)
if conversation_blocks[-1].get("text"):
conversation_blocks[-1]["text"] += "\n</history>"
else:
conversation_blocks.append({"text": "</history>"})
formatted_msgs.append(
{
"role": "user",
"content": conversation_blocks,
},
)
return _reformat_messages(formatted_msgs)
async def _format_system_message(
self,
msg: Msg,
) -> dict[str, Any]:
"""Format system message for DashScope API."""
return {
"role": "system",
"content": msg.get_text_content(),
}

View File

@@ -0,0 +1,265 @@
# -*- coding: utf-8 -*-
# pylint: disable=too-many-branches
"""The DeepSeek formatter module."""
import json
from typing import Any
from ._truncated_formatter_base import TruncatedFormatterBase
from .._logging import logger
from ..message import Msg, TextBlock, ToolUseBlock, ToolResultBlock
from ..token import TokenCounterBase
class DeepSeekChatFormatter(TruncatedFormatterBase):
"""The DeepSeek formatter class for chatbot scenario, where only a user
and an agent are involved. We use the `role` field to identify different
entities in the conversation.
"""
support_tools_api: bool = True
"""Whether support tools API"""
support_multiagent: bool = False
"""Whether support multi-agent conversations"""
support_vision: bool = False
"""Whether support vision data"""
supported_blocks: list[type] = [
TextBlock,
# Tool use
ToolUseBlock,
ToolResultBlock,
]
"""The list of supported message blocks"""
async def _format(
self,
msgs: list[Msg],
) -> list[dict[str, Any]]:
"""Format message objects into DeepSeek API format.
Args:
msgs (`list[Msg]`):
The list of message objects to format.
Returns:
`list[dict[str, Any]]`:
The formatted messages as a list of dictionaries.
"""
self.assert_list_of_msgs(msgs)
messages: list[dict] = []
for msg in msgs:
content_blocks: list = []
reasoning_content_blocks: list = []
tool_calls = []
for block in msg.get_content_blocks():
typ = block.get("type")
if typ == "text":
content_blocks.append({**block})
elif typ == "thinking":
reasoning_content_blocks.append({**block})
elif typ == "tool_use":
tool_calls.append(
{
"id": block.get("id"),
"type": "function",
"function": {
"name": block.get("name"),
"arguments": json.dumps(
block.get("input", {}),
ensure_ascii=False,
),
},
},
)
elif typ == "tool_result":
textual_output, _ = self.convert_tool_result_to_string(
block.get("output"), # type: ignore[arg-type]
)
messages.append(
{
"role": "tool",
"tool_call_id": block.get("id"),
"content": textual_output,
"name": block.get("name"),
},
)
else:
logger.warning(
"Unsupported block type %s in the message, skipped.",
typ,
)
content_msg = "\n".join(
content.get("text", "") for content in content_blocks
)
reasoning_msg = "\n".join(
reasoning.get("thinking", "")
for reasoning in reasoning_content_blocks
)
msg_deepseek = {
"role": msg.role,
"content": content_msg or None,
}
if reasoning_msg:
msg_deepseek["reasoning_content"] = reasoning_msg
if tool_calls:
msg_deepseek["tool_calls"] = tool_calls
if msg_deepseek["content"] or msg_deepseek.get("tool_calls"):
messages.append(msg_deepseek)
return messages
class DeepSeekMultiAgentFormatter(TruncatedFormatterBase):
"""
DeepSeek formatter for multi-agent conversations, where more than
a user and an agent are involved.
"""
support_tools_api: bool = True
"""Whether support tools API"""
support_multiagent: bool = True
"""Whether support multi-agent conversations"""
support_vision: bool = False
"""Whether support vision data"""
supported_blocks: list[type] = [
TextBlock,
# Tool use
ToolUseBlock,
ToolResultBlock,
]
"""The list of supported message blocks"""
def __init__(
self,
conversation_history_prompt: str = (
"# Conversation History\n"
"The content between <history></history> tags contains "
"your conversation history\n"
),
token_counter: TokenCounterBase | None = None,
max_tokens: int | None = None,
) -> None:
"""Initialize the DeepSeek multi-agent formatter.
Args:
conversation_history_prompt (`str`):
The prompt to use for the conversation history section.
token_counter (`TokenCounterBase | None`, optional):
A token counter instance used to count tokens in the messages.
If not provided, the formatter will format the messages
without considering token limits.
max_tokens (`int | None`, optional):
The maximum number of tokens allowed in the formatted
messages. If not provided, the formatter will not truncate
the messages.
"""
super().__init__(token_counter=token_counter, max_tokens=max_tokens)
self.conversation_history_prompt = conversation_history_prompt
async def _format_tool_sequence(
self,
msgs: list[Msg],
) -> list[dict[str, Any]]:
"""Given a sequence of tool call/result messages, format them into
the required format for the DeepSeek API.
Args:
msgs (`list[Msg]`):
The list of messages containing tool calls/results to format.
Returns:
`list[dict[str, Any]]`:
A list of dictionaries formatted for the DeepSeek API.
"""
return await DeepSeekChatFormatter().format(msgs)
async def _format_agent_message(
self,
msgs: list[Msg],
is_first: bool = True,
) -> list[dict[str, Any]]:
"""Given a sequence of messages without tool calls/results, format
them into the required format for the DeepSeek API.
Args:
msgs (`list[Msg]`):
A list of Msg objects to be formatted.
is_first (`bool`, defaults to `True`):
Whether this is the first agent message in the conversation.
If `True`, the conversation history prompt will be included.
Returns:
`list[dict[str, Any]]`:
A list of dictionaries formatted for the DeepSeek API.
"""
if is_first:
conversation_history_prompt = self.conversation_history_prompt
else:
conversation_history_prompt = ""
# Format into required DeepSeek format
formatted_msgs: list[dict] = []
conversation_blocks: list = []
accumulated_text = []
for msg in msgs:
for block in msg.get_content_blocks():
if block["type"] == "text":
accumulated_text.append(f"{msg.name}: {block['text']}")
if accumulated_text:
conversation_blocks.append(
{"text": "\n".join(accumulated_text)},
)
if conversation_blocks:
if conversation_blocks[0].get("text"):
conversation_blocks[0]["text"] = (
conversation_history_prompt
+ "<history>\n"
+ conversation_blocks[0]["text"]
)
else:
conversation_blocks.insert(
0,
{
"text": conversation_history_prompt + "<history>\n",
},
)
if conversation_blocks[-1].get("text"):
conversation_blocks[-1]["text"] += "\n</history>"
else:
conversation_blocks.append({"text": "</history>"})
conversation_blocks_text = "\n".join(
conversation_block.get("text", "")
for conversation_block in conversation_blocks
)
user_message = {
"role": "user",
"content": conversation_blocks_text,
}
if conversation_blocks:
formatted_msgs.append(user_message)
return formatted_msgs

View File

@@ -0,0 +1,129 @@
# -*- coding: utf-8 -*-
"""The formatter module."""
from abc import abstractmethod
from typing import Any, List, Tuple, Sequence
from .._utils._common import _save_base64_data
from ..message import Msg, AudioBlock, ImageBlock, TextBlock, VideoBlock
class FormatterBase:
"""The base class for formatters."""
@abstractmethod
async def format(self, *args: Any, **kwargs: Any) -> list[dict[str, Any]]:
"""Format the Msg objects to a list of dictionaries that satisfy the
API requirements."""
@staticmethod
def assert_list_of_msgs(msgs: list[Msg]) -> None:
"""Assert that the input is a list of Msg objects.
Args:
msgs (`list[Msg]`):
A list of Msg objects to be validated.
"""
if not isinstance(msgs, list):
raise TypeError("Input must be a list of Msg objects.")
for msg in msgs:
if not isinstance(msg, Msg):
raise TypeError(
f"Expected Msg object, got {type(msg)} instead.",
)
@staticmethod
def convert_tool_result_to_string(
output: str | List[TextBlock | ImageBlock | AudioBlock | VideoBlock],
) -> tuple[
str,
Sequence[
Tuple[
str,
ImageBlock | AudioBlock | TextBlock | VideoBlock,
]
],
]:
"""Turn the tool result list into a textual output to be compatible
with the LLM API that doesn't support multimodal data in the tool
result.
For URL-based images, the URL is included in the list. For
base64-encoded images, the local file path where the image is saved
is included in the returned list.
Args:
output (`str | List[TextBlock | ImageBlock | AudioBlock | \
VideoBlock]`):
The output of the tool response, including text and multimodal
data like images and audio.
Returns:
`tuple[str, list[Tuple[str, ImageBlock | AudioBlock | VideoBlock \
TextBlock]]]`:
A tuple containing the textual representation of the tool
result and a list of tuples. The first element of each tuple
is the local file path or URL of the multimodal data, and the
second element is the corresponding block.
"""
if isinstance(output, str):
return output, []
textual_output = []
multimodal_data = []
for block in output:
assert isinstance(block, dict) and "type" in block, (
f"Invalid block: {block}, a TextBlock, ImageBlock, "
f"AudioBlock, or VideoBlock is expected."
)
if block["type"] == "text":
textual_output.append(block["text"])
elif block["type"] in ["image", "audio", "video"]:
assert "source" in block, (
f"Invalid {block['type']} block: {block}, 'source' key "
"is required."
)
source = block["source"]
# Save the image locally and return the file path
if source["type"] == "url":
textual_output.append(
f"The returned {block['type']} can be found "
f"at: {source['url']}",
)
path_multimodal_file = source["url"]
elif source["type"] == "base64":
path_multimodal_file = _save_base64_data(
source["media_type"],
source["data"],
)
textual_output.append(
f"The returned {block['type']} can be found "
f"at: {path_multimodal_file}",
)
else:
raise ValueError(
f"Invalid image source: {block['source']}, "
"expected 'url' or 'base64'.",
)
multimodal_data.append(
(path_multimodal_file, block),
)
else:
raise ValueError(
f"Unsupported block type: {block['type']}, "
"expected 'text', 'image', 'audio', or 'video'.",
)
if len(textual_output) == 1:
return textual_output[0], multimodal_data
else:
return "\n".join("- " + _ for _ in textual_output), multimodal_data

View File

@@ -0,0 +1,507 @@
# -*- coding: utf-8 -*-
# pylint: disable=too-many-branches
"""Google gemini API formatter in agentscope."""
import base64
import os
from typing import Any
from urllib.parse import urlparse
from ._truncated_formatter_base import TruncatedFormatterBase
from .._utils._common import _get_bytes_from_web_url
from ..message import (
Msg,
TextBlock,
ImageBlock,
AudioBlock,
ToolUseBlock,
ToolResultBlock,
VideoBlock,
URLSource,
)
from .._logging import logger
from ..token import TokenCounterBase
def _format_gemini_media_block(
media_block: ImageBlock | AudioBlock | VideoBlock,
) -> dict[str, Any]:
"""Format an image/audio/video block for Gemini API.
Args:
media_block (`ImageBlock | AudioBlock | VideoBlock`):
The media block to format.
Returns:
`dict[str, Any]`:
A dictionary with "inline_data" key in Gemini format.
Raises:
`ValueError`:
If the source type is not supported.
"""
source = media_block["source"]
if source["type"] == "base64":
return {
"inline_data": {
"data": source["data"],
"mime_type": source["media_type"],
},
}
elif source["type"] == "url":
return {
"inline_data": _to_gemini_inline_data(source["url"]),
}
else:
raise ValueError(
f"Unsupported source type: {source['type']}",
)
def _to_gemini_inline_data(url: str) -> dict:
"""Convert url into the Gemini API required format."""
parsed_url = urlparse(url)
extension = url.split(".")[-1].lower()
# Pre-calculate media type from extension (image/audio/video).
typ = None
for k, v in GeminiChatFormatter.supported_extensions.items():
if extension in v:
typ = k
break
if not os.path.exists(url) and parsed_url.scheme != "":
# Web url
if typ is None:
raise TypeError(
f"Unsupported file extension: {extension}, expected "
f"{GeminiChatFormatter.supported_extensions}",
)
data = _get_bytes_from_web_url(url)
return {
"data": data,
"mime_type": f"{typ}/{extension}",
}
elif os.path.exists(url):
# Local file
if typ is None:
raise TypeError(
f"Unsupported file extension: {extension}, expected "
f"{GeminiChatFormatter.supported_extensions}",
)
with open(url, "rb") as f:
data = base64.b64encode(f.read()).decode("utf-8")
return {
"data": data,
"mime_type": f"{typ}/{extension}",
}
raise ValueError(
f"The URL `{url}` is not a valid image URL or local file.",
)
class GeminiChatFormatter(TruncatedFormatterBase):
"""The Gemini formatter class for chatbot scenario, where only a user
and an agent are involved. We use the `role` field to identify different
entities in the conversation.
"""
support_tools_api: bool = True
"""Whether support tools API"""
support_multiagent: bool = False
"""Whether support multi-agent conversations"""
support_vision: bool = True
"""Whether support vision data"""
supported_blocks: list[type] = [
TextBlock,
# Multimodal
ImageBlock,
VideoBlock,
AudioBlock,
# Tool use
ToolUseBlock,
ToolResultBlock,
]
"""The list of supported message blocks"""
supported_extensions: dict[str, list[str]] = {
"image": ["png", "jpeg", "webp", "heic", "heif"],
"video": [
"mp4",
"mpeg",
"mov",
"avi",
"x-flv",
"mpg",
"webm",
"wmv",
"3gpp",
],
"audio": ["mp3", "wav", "aiff", "aac", "ogg", "flac"],
}
def __init__(
self,
promote_tool_result_images: bool = False,
token_counter: TokenCounterBase | None = None,
max_tokens: int | None = None,
) -> None:
"""Initialize the Gemini chat formatter.
Args:
promote_tool_result_images (`bool`, defaults to `False`):
Whether to promote images from tool results to user messages.
Most LLM APIs don't support images in tool result blocks, but
do support them in user message blocks. When `True`, images are
extracted and appended as a separate user message with
explanatory text indicating their source.
token_counter (`TokenCounterBase | None`, optional):
A token counter instance used to count tokens in the messages.
If not provided, the formatter will format the messages
without considering token limits.
max_tokens (`int | None`, optional):
The maximum number of tokens allowed in the formatted
messages. If not provided, the formatter will not truncate
the messages.
"""
super().__init__(token_counter, max_tokens)
self.promote_tool_result_images = promote_tool_result_images
async def _format(
self,
msgs: list[Msg],
) -> list[dict]:
"""Format message objects into Gemini API required format."""
self.assert_list_of_msgs(msgs)
messages: list = []
i = 0
while i < len(msgs):
msg = msgs[i]
parts = []
for block in msg.get_content_blocks():
typ = block.get("type")
if typ == "text":
parts.append(
{
"text": block.get("text"),
},
)
elif typ == "tool_use":
parts.append(
{
"function_call": {
"id": None,
"name": block["name"],
"args": block["input"],
},
"thought_signature": block.get("id", None),
},
)
elif typ == "tool_result":
(
textual_output,
multimodal_data,
) = self.convert_tool_result_to_string(block["output"])
# First add the tool result message in DashScope API format
messages.append(
{
"role": "user",
"parts": [
{
"function_response": {
"id": block["id"],
"name": block["name"],
"response": {
"output": textual_output,
},
},
},
],
},
)
promoted_blocks: list = []
for url, multimodal_block in multimodal_data:
if (
multimodal_block["type"] == "image"
and self.promote_tool_result_images
):
promoted_blocks.extend(
[
TextBlock(
type="text",
text=f"\n- The image from '{url}': ",
),
ImageBlock(
type="image",
source=URLSource(
type="url",
url=url,
),
),
],
)
if promoted_blocks:
# Insert promoted blocks as new user message(s)
promoted_blocks = [
TextBlock(
type="text",
text="<system-info>The following are "
"the image contents from the tool "
f"result of '{block['name']}':",
),
*promoted_blocks,
TextBlock(
type="text",
text="</system-info>",
),
]
msgs.insert(
i + 1,
Msg(
name="user",
content=promoted_blocks,
role="user",
),
)
elif typ in ["image", "audio", "video"]:
parts.append(
_format_gemini_media_block(
block, # type: ignore[arg-type]
),
)
else:
logger.warning(
"Unsupported block type: %s in the message, skipped. ",
typ,
)
role = "model" if msg.role == "assistant" else "user"
if parts:
messages.append(
{
"role": role,
"parts": parts,
},
)
# Move to next message (including inserted messages, which will
# be processed in subsequent iterations)
i += 1
return messages
class GeminiMultiAgentFormatter(TruncatedFormatterBase):
"""The multi-agent formatter for Google Gemini API, where more than a
user and an agent are involved.
.. note:: This formatter will combine previous messages (except tool
calls/results) into a history section in the first system message with
the conversation history prompt.
.. note:: For tool calls/results, they will be presented as separate
messages as required by the Gemini API. Therefore, the tool calls/
results messages are expected to be placed at the end of the input
messages.
.. tip:: Telling the assistant's name in the system prompt is very
important in multi-agent conversations. So that LLM can know who it
is playing as.
"""
support_tools_api: bool = True
"""Whether support tools API"""
support_multiagent: bool = True
"""Whether support multi-agent conversations"""
support_vision: bool = True
"""Whether support vision data"""
supported_blocks: list[type] = [
TextBlock,
# Multimodal
ImageBlock,
VideoBlock,
AudioBlock,
# Tool use
ToolUseBlock,
ToolResultBlock,
]
"""The list of supported message blocks"""
def __init__(
self,
conversation_history_prompt: str = (
"# Conversation History\n"
"The content between <history></history> tags contains "
"your conversation history\n"
),
promote_tool_result_images: bool = False,
token_counter: TokenCounterBase | None = None,
max_tokens: int | None = None,
) -> None:
"""Initialize the Gemini multi-agent formatter.
Args:
conversation_history_prompt (`str`):
The prompt to be used for the conversation history section.
promote_tool_result_images (`bool`, defaults to `False`):
Whether to promote images from tool results to user messages.
Most LLM APIs don't support images in tool result blocks, but
do support them in user message blocks. When `True`, images are
extracted and appended as a separate user message with
explanatory text indicating their source.
token_counter (`TokenCounterBase | None`, optional):
The token counter used for truncation.
max_tokens (`int | None`, optional):
The maximum number of tokens allowed in the formatted
messages. If `None`, no truncation will be applied.
"""
super().__init__(token_counter=token_counter, max_tokens=max_tokens)
self.conversation_history_prompt = conversation_history_prompt
self.promote_tool_result_images = promote_tool_result_images
async def _format_system_message(
self,
msg: Msg,
) -> dict[str, Any]:
"""Format system message for the Gemini API."""
return {
"role": "user",
"parts": [
{
"text": msg.get_text_content(),
},
],
}
async def _format_tool_sequence(
self,
msgs: list[Msg],
) -> list[dict[str, Any]]:
"""Given a sequence of tool call/result messages, format them into
the required format for the Gemini API.
Args:
msgs (`list[Msg]`):
The list of messages containing tool calls/results to format.
Returns:
`list[dict[str, Any]]`:
A list of dictionaries formatted for the Gemini API.
"""
return await GeminiChatFormatter(
promote_tool_result_images=self.promote_tool_result_images,
).format(msgs)
async def _format_agent_message(
self,
msgs: list[Msg],
is_first: bool = True,
) -> list[dict[str, Any]]:
"""Given a sequence of messages without tool calls/results, format
them into the required format for the Gemini API.
Args:
msgs (`list[Msg]`):
A list of Msg objects to be formatted.
is_first (`bool`, defaults to `True`):
Whether this is the first agent message in the conversation.
If `True`, the conversation history prompt will be included.
Returns:
`list[dict[str, Any]]`:
A list of dictionaries formatted for the Gemini API.
"""
if is_first:
conversation_history_prompt = self.conversation_history_prompt
else:
conversation_history_prompt = ""
# Format into Gemini API required format
formatted_msgs: list = []
# Collect the multimodal files
conversation_parts: list = []
accumulated_text = []
for msg in msgs:
for block in msg.get_content_blocks():
if block["type"] == "text":
accumulated_text.append(f"{msg.name}: {block['text']}")
elif block["type"] in ["image", "video", "audio"]:
# handle the accumulated text as a single part if exists
if accumulated_text:
conversation_parts.append(
{
"text": "\n".join(accumulated_text),
},
)
accumulated_text.clear()
# handle the multimodal data
conversation_parts.append(
_format_gemini_media_block(
block, # type: ignore[arg-type]
),
)
if accumulated_text:
conversation_parts.append(
{
"text": "\n".join(accumulated_text),
},
)
# Add prompt and <history></history> tags around conversation history
if conversation_parts:
if conversation_parts[0].get("text"):
conversation_parts[0]["text"] = (
conversation_history_prompt
+ "<history>"
+ conversation_parts[0]["text"]
)
else:
conversation_parts.insert(
0,
{"text": conversation_history_prompt + "<history>"},
)
if conversation_parts[-1].get("text"):
conversation_parts[-1]["text"] += "\n</history>"
else:
conversation_parts.append(
{"text": "</history>"},
)
formatted_msgs.append(
{
"role": "user",
"parts": conversation_parts,
},
)
return formatted_msgs

View File

@@ -0,0 +1,441 @@
# -*- coding: utf-8 -*-
# pylint: disable=too-many-branches
"""The Ollama formatter module."""
import base64
import os
from typing import Any
from urllib.parse import urlparse
from ._truncated_formatter_base import TruncatedFormatterBase
from .._logging import logger
from .._utils._common import _get_bytes_from_web_url
from ..message import (
Msg,
TextBlock,
ImageBlock,
ToolUseBlock,
ToolResultBlock,
URLSource,
)
from ..token import TokenCounterBase
def _format_ollama_image_block(
image_block: ImageBlock,
) -> str:
"""Format an image block for Ollama API.
Args:
image_block (`ImageBlock`):
The image block to format.
Returns:
`str`:
Base64 encoded image data as a string.
Raises:
`ValueError`:
If the source type is not supported.
"""
source = image_block["source"]
if source["type"] == "url":
return _convert_ollama_image_url_to_base64_data(source["url"])
elif source["type"] == "base64":
return source["data"]
else:
raise ValueError(
f"Unsupported image source type: {source['type']}",
)
def _convert_ollama_image_url_to_base64_data(url: str) -> str:
"""Convert image url to base64."""
parsed_url = urlparse(url)
if not os.path.exists(url) and parsed_url.scheme != "":
# Web url
data = _get_bytes_from_web_url(url)
return data
if os.path.exists(url):
# Local file
with open(url, "rb") as f:
data = base64.b64encode(f.read()).decode("utf-8")
return data
raise ValueError(
f"The URL `{url}` is not a valid image URL or local file.",
)
class OllamaChatFormatter(TruncatedFormatterBase):
"""The Ollama formatter class for chatbot scenario, where only a user
and an agent are involved. We use the `role` field to identify different
participants in the conversation.
"""
support_tools_api: bool = True
"""Whether support tools API"""
support_multiagent: bool = False
"""Whether support multi-agent conversations"""
support_vision: bool = True
"""Whether support vision data"""
supported_blocks: list[type] = [
TextBlock,
# Multimodal
ImageBlock,
# Tool use
ToolUseBlock,
ToolResultBlock,
]
"""The list of supported message blocks"""
def __init__(
self,
promote_tool_result_images: bool = False,
token_counter: TokenCounterBase | None = None,
max_tokens: int | None = None,
) -> None:
"""Initialize the Ollama chat formatter.
Args:
promote_tool_result_images (`bool`, defaults to `False`):
Whether to promote images from tool results to user messages.
Most LLM APIs don't support images in tool result blocks, but
do support them in user message blocks. When `True`, images are
extracted and appended as a separate user message with
explanatory text indicating their source.
token_counter (`TokenCounterBase | None`, optional):
A token counter instance used to count tokens in the messages.
If not provided, the formatter will format the messages
without considering token limits.
max_tokens (`int | None`, optional):
The maximum number of tokens allowed in the formatted
messages. If not provided, the formatter will not truncate
the messages.
"""
super().__init__(token_counter, max_tokens)
self.promote_tool_result_images = promote_tool_result_images
async def _format(
self,
msgs: list[Msg],
) -> list[dict[str, Any]]:
"""Format message objects into Ollama API format.
Args:
msgs (`list[Msg]`):
The list of message objects to format.
Returns:
`list[dict[str, Any]]`:
The formatted messages as a list of dictionaries.
"""
self.assert_list_of_msgs(msgs)
messages: list = []
i = 0
while i < len(msgs):
msg = msgs[i]
content_blocks: list = []
tool_calls = []
images = []
for block in msg.get_content_blocks():
typ = block.get("type")
if typ == "text":
content_blocks.append({**block})
elif typ == "tool_use":
tool_calls.append(
{
"id": block.get("id"),
"type": "function",
"function": {
"name": block.get("name"),
"arguments": block.get("input", {}),
},
},
)
elif typ == "tool_result":
(
textual_output,
multimodal_data,
) = self.convert_tool_result_to_string(block["output"])
messages.append(
{
"role": "tool",
"tool_call_id": block.get("id"),
"content": textual_output,
"name": block.get("name"),
},
)
# Then, handle the multimodal data if any
promoted_blocks: list = []
for url, multimodal_block in multimodal_data:
if (
multimodal_block["type"] == "image"
and self.promote_tool_result_images
):
promoted_blocks.extend(
[
TextBlock(
type="text",
text=f"\n- The image from '{url}': ",
),
ImageBlock(
type="image",
source=URLSource(
type="url",
url=url,
),
),
],
)
if promoted_blocks:
# Insert promoted blocks as new user message(s)
promoted_blocks = [
TextBlock(
type="text",
text="<system-info>The following are "
"the image contents from the tool "
f"result of '{block['name']}':",
),
*promoted_blocks,
TextBlock(
type="text",
text="</system-info>",
),
]
msgs.insert(
i + 1,
Msg(
name="user",
content=promoted_blocks,
role="user",
),
)
elif typ == "image":
images.append(
_format_ollama_image_block(
block, # type: ignore[arg-type]
),
)
else:
logger.warning(
"Unsupported block type %s in the message, skipped.",
typ,
)
content_msg = "\n".join(
content.get("text", "") for content in content_blocks
)
msg_ollama = {
"role": msg.role,
"content": content_msg or None,
}
if tool_calls:
msg_ollama["tool_calls"] = tool_calls
if images:
msg_ollama["images"] = images
if (
msg_ollama["content"]
or msg_ollama.get("images")
or msg_ollama.get("tool_calls")
):
messages.append(msg_ollama)
# Move to next message
i += 1
return messages
class OllamaMultiAgentFormatter(TruncatedFormatterBase):
"""
Ollama formatter for multi-agent conversations, where more than
a user and an agent are involved.
"""
support_tools_api: bool = True
"""Whether support tools API"""
support_multiagent: bool = True
"""Whether support multi-agent conversations"""
support_vision: bool = True
"""Whether support vision data"""
supported_blocks: list[type] = [
TextBlock,
# Multimodal
ImageBlock,
# Tool use
ToolUseBlock,
ToolResultBlock,
]
"""The list of supported message blocks"""
def __init__(
self,
conversation_history_prompt: str = (
"# Conversation History\n"
"The content between <history></history> tags contains "
"your conversation history\n"
),
promote_tool_result_images: bool = False,
token_counter: TokenCounterBase | None = None,
max_tokens: int | None = None,
) -> None:
"""Initialize the Ollama multi-agent formatter.
Args:
conversation_history_prompt (`str`):
The prompt to use for the conversation history section.
promote_tool_result_images (`bool`, defaults to `False`):
Whether to promote images from tool results to user messages.
Most LLM APIs don't support images in tool result blocks, but
do support them in user message blocks. When `True`, images are
extracted and appended as a separate user message with
explanatory text indicating their source.
token_counter (`TokenCounterBase | None`, optional):
The token counter used for truncation.
max_tokens (`int | None`, optional):
The maximum number of tokens allowed in the formatted
messages. If `None`, no truncation will be applied.
"""
super().__init__(token_counter=token_counter, max_tokens=max_tokens)
self.conversation_history_prompt = conversation_history_prompt
self.promote_tool_result_images = promote_tool_result_images
async def _format_system_message(
self,
msg: Msg,
) -> dict[str, Any]:
"""Format system message for the Ollama API."""
return {
"role": "system",
"content": msg.get_text_content(),
}
async def _format_tool_sequence(
self,
msgs: list[Msg],
) -> list[dict[str, Any]]:
"""Given a sequence of tool call/result messages, format them into
the required format for the Ollama API.
Args:
msgs (`list[Msg]`):
The list of messages containing tool calls/results to format.
Returns:
`list[dict[str, Any]]`:
A list of dictionaries formatted for the Ollama API.
"""
return await OllamaChatFormatter(
promote_tool_result_images=self.promote_tool_result_images,
).format(msgs)
async def _format_agent_message(
self,
msgs: list[Msg],
is_first: bool = True,
) -> list[dict[str, Any]]:
"""Given a sequence of messages without tool calls/results, format
them into the required format for the Ollama API.
Args:
msgs (`list[Msg]`):
A list of Msg objects to be formatted.
is_first (`bool`, defaults to `True`):
Whether this is the first agent message in the conversation.
If `True`, the conversation history prompt will be included.
Returns:
`list[dict[str, Any]]`:
A list of dictionaries formatted for the ollama API.
"""
if is_first:
conversation_history_prompt = self.conversation_history_prompt
else:
conversation_history_prompt = ""
# Format into required Ollama format
formatted_msgs: list[dict] = []
# Collect the multimodal files
conversation_blocks: list = []
accumulated_text = []
images = []
for msg in msgs:
for block in msg.get_content_blocks():
if block["type"] == "text":
accumulated_text.append(f"{msg.name}: {block['text']}")
elif block["type"] == "image":
# Handle the accumulated text as a single block
if accumulated_text:
conversation_blocks.append(
{"text": "\n".join(accumulated_text)},
)
accumulated_text.clear()
images.append(_format_ollama_image_block(block))
conversation_blocks.append({**block})
if accumulated_text:
conversation_blocks.append(
{"text": "\n".join(accumulated_text)},
)
if conversation_blocks:
if conversation_blocks[0].get("text"):
conversation_blocks[0]["text"] = (
conversation_history_prompt
+ "<history>\n"
+ conversation_blocks[0]["text"]
)
else:
conversation_blocks.insert(
0,
{
"text": conversation_history_prompt + "<history>\n",
},
)
if conversation_blocks[-1].get("text"):
conversation_blocks[-1]["text"] += "\n</history>"
else:
conversation_blocks.append({"text": "</history>"})
conversation_blocks_text = "\n".join(
conversation_block.get("text", "")
for conversation_block in conversation_blocks
)
user_message = {
"role": "user",
"content": conversation_blocks_text,
}
if images:
user_message["images"] = images
if conversation_blocks:
formatted_msgs.append(user_message)
return formatted_msgs

View File

@@ -0,0 +1,530 @@
# -*- coding: utf-8 -*-
# pylint: disable=too-many-branches, too-many-nested-blocks
"""The OpenAI formatter for agentscope."""
import base64
import json
import os
from typing import Any
from urllib.parse import urlparse
import requests
from ._truncated_formatter_base import TruncatedFormatterBase
from .._logging import logger
from ..message import (
Msg,
URLSource,
TextBlock,
ImageBlock,
AudioBlock,
Base64Source,
ToolUseBlock,
ToolResultBlock,
)
from ..token import TokenCounterBase
def _format_openai_image_block(
image_block: ImageBlock,
) -> dict[str, Any]:
"""Format an image block for OpenAI API.
Args:
image_block (`ImageBlock`):
The image block to format.
Returns:
`dict[str, Any]`:
A dictionary with "type" and "image_url" keys in OpenAI format.
Raises:
`ValueError`:
If the source type is not supported.
"""
source = image_block["source"]
if source["type"] == "url":
url = _to_openai_image_url(source["url"])
elif source["type"] == "base64":
data = source["data"]
media_type = source["media_type"]
url = f"data:{media_type};base64,{data}"
else:
raise ValueError(
f"Unsupported image source type: {source['type']}",
)
return {
"type": "image_url",
"image_url": {
"url": url,
},
}
def _to_openai_image_url(url: str) -> str:
"""Convert an image url to openai format. If the given url is a local
file, it will be converted to base64 format. Otherwise, it will be
returned directly.
Args:
url (`str`):
The local or public url of the image.
"""
# See https://platform.openai.com/docs/guides/vision for details of
# support image extensions.
support_image_extensions = (
".png",
".jpg",
".jpeg",
".gif",
".webp",
)
parsed_url = urlparse(url)
lower_url = url.lower()
# Web url
if not os.path.exists(url) and parsed_url.scheme != "":
path_lower = parsed_url.path if parsed_url.path else parsed_url.netloc
if any(path_lower.endswith(_) for _ in support_image_extensions):
return url
# Check if it is a local file
elif os.path.exists(url) and os.path.isfile(url):
if any(lower_url.endswith(_) for _ in support_image_extensions):
with open(url, "rb") as image_file:
base64_image = base64.b64encode(image_file.read()).decode(
"utf-8",
)
extension = parsed_url.path.lower().split(".")[-1]
mime_type = f"image/{extension}"
return f"data:{mime_type};base64,{base64_image}"
raise TypeError(f'"{url}" should end with {support_image_extensions}.')
def _to_openai_audio_data(source: URLSource | Base64Source) -> dict:
"""Covert an audio source to OpenAI format."""
if source["type"] == "url":
extension = source["url"].split(".")[-1].lower()
if extension not in ["wav", "mp3"]:
raise TypeError(
f"Unsupported audio file extension: {extension}, "
"wav and mp3 are supported.",
)
parsed_url = urlparse(source["url"])
if os.path.exists(source["url"]):
with open(source["url"], "rb") as audio_file:
data = base64.b64encode(audio_file.read()).decode("utf-8")
# web url
elif parsed_url.scheme != "":
response = requests.get(source["url"])
response.raise_for_status()
data = base64.b64encode(response.content).decode("utf-8")
else:
raise ValueError(
f"Unsupported audio source: {source['url']}, "
"it should be a local file or a web URL.",
)
return {
"data": data,
"format": extension,
}
if source["type"] == "base64":
data = source["data"]
media_type = source["media_type"]
if media_type not in ["audio/wav", "audio/mp3"]:
raise TypeError(
f"Unsupported audio media type: {media_type}, "
"only audio/wav and audio/mp3 are supported.",
)
return {
"data": data,
"format": media_type.split("/")[-1],
}
raise TypeError(f"Unsupported audio source: {source['type']}.")
class OpenAIChatFormatter(TruncatedFormatterBase):
"""The OpenAI formatter class for chatbot scenario, where only a user
and an agent are involved. We use the `name` field in OpenAI API to
identify different entities in the conversation.
"""
support_tools_api: bool = True
"""Whether support tools API"""
support_multiagent: bool = True
"""Whether support multi-agent conversation"""
support_vision: bool = True
"""Whether support vision models"""
supported_blocks: list[type] = [
TextBlock,
ImageBlock,
AudioBlock,
ToolUseBlock,
ToolResultBlock,
]
"""Supported message blocks for OpenAI API"""
def __init__(
self,
promote_tool_result_images: bool = False,
token_counter: TokenCounterBase | None = None,
max_tokens: int | None = None,
) -> None:
"""Initialize the OpenAI chat formatter.
Args:
promote_tool_result_images (`bool`, defaults to `False`):
Whether to promote images from tool results to user messages.
Most LLM APIs don't support images in tool result blocks, but
do support them in user message blocks. When `True`, images are
extracted and appended as a separate user message with
explanatory text indicating their source.
token_counter (`TokenCounterBase | None`, optional):
A token counter instance used to count tokens in the messages.
If not provided, the formatter will format the messages
without considering token limits.
max_tokens (`int | None`, optional):
The maximum number of tokens allowed in the formatted
messages. If not provided, the formatter will not truncate
the messages.
"""
super().__init__(token_counter=token_counter, max_tokens=max_tokens)
self.promote_tool_result_images = promote_tool_result_images
async def _format(
self,
msgs: list[Msg],
) -> list[dict[str, Any]]:
"""Format message objects into OpenAI API required format.
Args:
msgs (`list[Msg]`):
The list of Msg objects to format.
Returns:
`list[dict[str, Any]]`:
A list of dictionaries, where each dictionary has "name",
"role", and "content" keys.
"""
self.assert_list_of_msgs(msgs)
messages: list[dict] = []
i = 0
while i < len(msgs):
msg = msgs[i]
content_blocks = []
tool_calls = []
for block in msg.get_content_blocks():
typ = block.get("type")
if typ == "text":
content_blocks.append({**block})
elif typ == "tool_use":
tool_calls.append(
{
"id": block.get("id"),
"type": "function",
"function": {
"name": block.get("name"),
"arguments": json.dumps(
block.get("input", {}),
ensure_ascii=False,
),
},
},
)
elif typ == "tool_result":
(
textual_output,
multimodal_data,
) = self.convert_tool_result_to_string(block["output"])
messages.append(
{
"role": "tool",
"tool_call_id": block.get("id"),
"content": ( # type: ignore[arg-type]
textual_output
),
"name": block.get("name"),
},
)
# Then, handle the multimodal data if any
promoted_blocks: list = []
for url, multimodal_block in multimodal_data:
if (
multimodal_block["type"] == "image"
and self.promote_tool_result_images
):
promoted_blocks.extend(
[
TextBlock(
type="text",
text=f"\n- The image from '{url}': ",
),
ImageBlock(
type="image",
source=URLSource(
type="url",
url=url,
),
),
],
)
if promoted_blocks:
# Insert promoted blocks as new user message(s)
promoted_blocks = [
TextBlock(
type="text",
text="<system-info>The following are "
"the image contents from the tool "
f"result of '{block['name']}':",
),
*promoted_blocks,
TextBlock(
type="text",
text="</system-info>",
),
]
msgs.insert(
i + 1,
Msg(
name="user",
content=promoted_blocks,
role="user",
),
)
elif typ == "image":
content_blocks.append(
_format_openai_image_block(
block, # type: ignore[arg-type]
),
)
elif typ == "audio":
# Filter out audio content when the multimodal model
# outputs both text and audio, to prevent errors in
# subsequent model calls
if msg.role == "assistant":
continue
input_audio = _to_openai_audio_data(block["source"])
content_blocks.append(
{
"type": "input_audio",
"input_audio": input_audio,
},
)
else:
logger.warning(
"Unsupported block type %s in the message, skipped.",
typ,
)
msg_openai = {
"role": msg.role,
"name": msg.name,
"content": content_blocks or None,
}
if tool_calls:
msg_openai["tool_calls"] = tool_calls
# When both content and tool_calls are None, skipped
if msg_openai["content"] or msg_openai.get("tool_calls"):
messages.append(msg_openai)
# Move to next message
i += 1
return messages
class OpenAIMultiAgentFormatter(TruncatedFormatterBase):
"""
OpenAI formatter for multi-agent conversations, where more than
a user and an agent are involved.
.. tip:: This formatter is compatible with OpenAI API and
OpenAI-compatible services like vLLM, Azure OpenAI, and others.
"""
support_tools_api: bool = True
"""Whether support tools API"""
support_multiagent: bool = True
"""Whether support multi-agent conversation"""
support_vision: bool = True
"""Whether support vision models"""
supported_blocks: list[type] = [
TextBlock,
ImageBlock,
AudioBlock,
ToolUseBlock,
ToolResultBlock,
]
"""Supported message blocks for OpenAI API"""
def __init__(
self,
conversation_history_prompt: str = (
"# Conversation History\n"
"The content between <history></history> tags contains "
"your conversation history\n"
),
promote_tool_result_images: bool = False,
token_counter: TokenCounterBase | None = None,
max_tokens: int | None = None,
) -> None:
"""Initialize the OpenAI multi-agent formatter.
Args:
conversation_history_prompt (`str`):
The prompt to use for the conversation history section.
promote_tool_result_images (`bool`, defaults to `False`):
Whether to promote images from tool results to user messages.
Most LLM APIs don't support images in tool result blocks, but
do support them in user message blocks. When `True`, images are
extracted and appended as a separate user message with
explanatory text indicating their source.
token_counter (`TokenCounterBase | None`, optional):
A token counter instance used to count tokens in the messages.
If not provided, the formatter will format the messages
without considering token limits.
max_tokens (`int | None`, optional):
The maximum number of tokens allowed in the formatted
messages. If not provided, the formatter will not truncate
the messages.
"""
super().__init__(token_counter=token_counter, max_tokens=max_tokens)
self.conversation_history_prompt = conversation_history_prompt
self.promote_tool_result_images = promote_tool_result_images
async def _format_tool_sequence(
self,
msgs: list[Msg],
) -> list[dict[str, Any]]:
"""Given a sequence of tool call/result messages, format them into
the required format for the OpenAI API."""
return await OpenAIChatFormatter(
promote_tool_result_images=self.promote_tool_result_images,
).format(msgs)
async def _format_agent_message(
self,
msgs: list[Msg],
is_first: bool = True,
) -> list[dict[str, Any]]:
"""Given a sequence of messages without tool calls/results, format
them into the required format for the OpenAI API."""
if is_first:
conversation_history_prompt = self.conversation_history_prompt
else:
conversation_history_prompt = ""
# Format into required OpenAI format
formatted_msgs: list[dict] = []
conversation_blocks: list = []
accumulated_text = []
images = []
audios = []
for msg in msgs:
for block in msg.get_content_blocks():
if block["type"] == "text":
accumulated_text.append(f"{msg.name}: {block['text']}")
elif block["type"] == "image":
images.append(_format_openai_image_block(block))
elif block["type"] == "audio":
# Filter out audio content when the multimodal model
# outputs both text and audio, to prevent errors in
# subsequent model calls
if msg.role == "assistant":
continue
input_audio = _to_openai_audio_data(block["source"])
audios.append(
{
"type": "input_audio",
"input_audio": input_audio,
},
)
if accumulated_text:
conversation_blocks.append(
{"text": "\n".join(accumulated_text)},
)
if conversation_blocks:
if conversation_blocks[0].get("text"):
conversation_blocks[0]["text"] = (
conversation_history_prompt
+ "<history>\n"
+ conversation_blocks[0]["text"]
)
else:
conversation_blocks.insert(
0,
{
"text": conversation_history_prompt + "<history>\n",
},
)
if conversation_blocks[-1].get("text"):
conversation_blocks[-1]["text"] += "\n</history>"
else:
conversation_blocks.append({"text": "</history>"})
conversation_blocks_text = "\n".join(
conversation_block.get("text", "")
for conversation_block in conversation_blocks
)
content_list: list[dict[str, Any]] = []
if conversation_blocks_text:
content_list.append(
{
"type": "text",
"text": conversation_blocks_text,
},
)
if images:
content_list.extend(images)
if audios:
content_list.extend(audios)
user_message = {
"role": "user",
"content": content_list,
}
if content_list:
formatted_msgs.append(user_message)
return formatted_msgs

View File

@@ -0,0 +1,297 @@
# -*- coding: utf-8 -*-
"""The truncated formatter base class, which allows to truncate the input
messages."""
from abc import ABC
from copy import deepcopy
from typing import (
Any,
Tuple,
Literal,
AsyncGenerator,
)
from ._formatter_base import FormatterBase
from ..message import Msg
from ..token import TokenCounterBase
from ..tracing import trace_format
class TruncatedFormatterBase(FormatterBase, ABC):
"""Base class for truncated formatters, which formats input messages into
required formats with tokens under a specified limit."""
def __init__(
self,
token_counter: TokenCounterBase | None = None,
max_tokens: int | None = None,
) -> None:
"""Initialize the TruncatedFormatterBase.
Args:
token_counter (`TokenCounterBase | None`, optional):
A token counter instance used to count tokens in the messages.
If not provided, the formatter will format the messages
without considering token limits.
max_tokens (`int | None`, optional):
The maximum number of tokens allowed in the formatted
messages. If not provided, the formatter will not truncate
the messages.
"""
self.token_counter = token_counter
assert (
max_tokens is None or 0 < max_tokens
), "max_tokens must be greater than 0"
self.max_tokens = max_tokens
@trace_format
async def format(
self,
msgs: list[Msg],
**kwargs: Any,
) -> list[dict[str, Any]]:
"""Format the input messages into the required format. If token
counter and max token limit are provided, the messages will be
truncated to fit the limit.
Args:
msgs (`list[Msg]`):
The input messages to be formatted.
Returns:
`list[dict[str, Any]]`:
The formatted messages in the required format.
"""
# Check if the input messages are valid
self.assert_list_of_msgs(msgs)
msgs = deepcopy(msgs)
while True:
formatted_msgs = await self._format(msgs)
n_tokens = await self._count(formatted_msgs)
if (
n_tokens is None
or self.max_tokens is None
or n_tokens <= self.max_tokens
):
return formatted_msgs
# truncate the input messages
msgs = await self._truncate(msgs)
async def _format(self, msgs: list[Msg]) -> list[dict[str, Any]]:
"""Format the input messages into the required format. This method
should be implemented by the subclasses."""
formatted_msgs = []
start_index = 0
if len(msgs) > 0 and msgs[0].role == "system":
formatted_msgs.append(
await self._format_system_message(msgs[0]),
)
start_index = 1
is_first_agent_message = True
async for typ, group in self._group_messages(msgs[start_index:]):
match typ:
case "tool_sequence":
formatted_msgs.extend(
await self._format_tool_sequence(group),
)
case "agent_message":
formatted_msgs.extend(
await self._format_agent_message(
group,
is_first_agent_message,
),
)
is_first_agent_message = False
return formatted_msgs
async def _format_system_message(
self,
msg: Msg,
) -> dict[str, Any]:
"""Format system message for the LLM API.
.. note:: This is the default implementation. For certain LLM APIs
with specific requirements, you may need to implement a custom
formatting function to accommodate those particular needs.
"""
return {
"role": "system",
"content": msg.get_content_blocks("text"),
}
async def _format_tool_sequence(
self,
msgs: list[Msg],
) -> list[dict[str, Any]]:
"""Given a sequence of tool call/result messages, format them into
the required format for the LLM API."""
raise NotImplementedError(
"_format_tool_sequence is not implemented",
)
async def _format_agent_message(
self,
msgs: list[Msg],
is_first: bool = True,
) -> list[dict[str, Any]]:
"""Given a sequence of messages without tool calls/results, format
them into the required format for the LLM API."""
raise NotImplementedError(
"_format_agent_message is not implemented",
)
async def _truncate(self, msgs: list[Msg]) -> list[Msg]:
"""Truncate the input messages, so that it can fit the token limit.
This function is called only when
- both `token_counter` and `max_tokens` are provided,
- the formatted output of the input messages exceeds the token limit.
.. tip:: This function only provides a simple strategy, and developers
can override this method to implement more sophisticated
truncation strategies.
.. note:: The tool call message should be truncated together with
its corresponding tool result message to satisfy the LLM API
requirements.
Args:
msgs (`list[Msg]`):
The input messages to be truncated.
Raises:
`ValueError`:
If the system prompt message already exceeds the token limit,
or if there are tool calls without corresponding tool results.
Returns:
`list[Msg]`:
The truncated messages.
"""
start_index = 0
if len(msgs) > 0 and msgs[0].role == "system":
if len(msgs) == 1:
# If the system prompt already exceeds the token limit, we
# raise an error.
raise ValueError(
f"The system prompt message already exceeds the token "
f"limit ({self.max_tokens} tokens).",
)
start_index = 1
# Create a tool call IDs queues to delete the corresponding tool
# result message
tool_call_ids = set()
for i in range(start_index, len(msgs)):
msg = msgs[i]
for block in msg.get_content_blocks("tool_use"):
tool_call_ids.add(block["id"])
for block in msg.get_content_blocks("tool_result"):
try:
tool_call_ids.remove(block["id"])
except KeyError:
pass
# We can stop truncating if the queue is empty
if len(tool_call_ids) == 0:
return msgs[:start_index] + msgs[i + 1 :]
if len(tool_call_ids) > 0:
raise ValueError(
"The input messages contains tool call(s) that do not have "
f"the corresponding tool result(s): {tool_call_ids}. ",
)
return msgs[:start_index]
async def _count(self, msgs: list[dict[str, Any]]) -> int | None:
"""Count the number of tokens in the input messages. If token counter
is not provided, `None` will be returned.
Args:
msgs (`list[Msg]`):
The input messages to count tokens for.
"""
if self.token_counter is None:
return None
return await self.token_counter.count(msgs)
@staticmethod
async def _group_messages(
msgs: list[Msg],
) -> AsyncGenerator[
Tuple[Literal["tool_sequence", "agent_message"], list[Msg]],
None,
]:
"""Group the input messages into two types and yield them as a
generator. The two types are:
- agent message that doesn't contain tool calls/results, and
- tool sequence that consisted of a sequence of tool calls/results
.. note:: The group operation is used in multi-agent scenario, where
multiple entities are involved in the input messages. So that to be
compatible with tools API, we have to group the messages and format
them with different strategies.
Args:
msgs (`list[Msg]`):
The input messages to be grouped, where the system prompt
message shouldn't be included.
Yields:
`AsyncGenerator[Tuple[str, list[Msg]], None]`:
A generator that yields tuples of group type and the list of
messages in that group. The group type can be either
"tool_sequence" or "agent_message".
"""
group_type: Literal["tool_sequence", "agent_message"] | None = None
group = []
for msg in msgs:
if group_type is None:
if msg.has_content_blocks(
"tool_use",
) or msg.has_content_blocks("tool_result"):
group_type = "tool_sequence"
else:
group_type = "agent_message"
group.append(msg)
continue
# determine if this msg has the same type as the current group
if group_type == "tool_sequence":
if msg.has_content_blocks(
"tool_use",
) or msg.has_content_blocks("tool_result"):
group.append(msg)
else:
yield group_type, group
group = [msg]
group_type = "agent_message"
elif group_type == "agent_message":
if msg.has_content_blocks(
"tool_use",
) or msg.has_content_blocks("tool_result"):
yield group_type, group
group = [msg]
group_type = "tool_sequence"
else:
group.append(msg)
if group_type:
yield group_type, group

View File

@@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
"""The built-in hook functions in agentscope."""
from functools import partial
from ._studio_hooks import (
as_studio_forward_message_pre_print_hook,
)
from .. import _config
from ..agent import AgentBase
__all__ = [
"as_studio_forward_message_pre_print_hook",
]
def _equip_as_studio_hooks(
studio_url: str,
) -> None:
"""Connect to the agentscope studio."""
AgentBase.register_class_hook(
"pre_print",
"as_studio_forward_message_pre_print_hook",
partial(
as_studio_forward_message_pre_print_hook,
studio_url=studio_url,
run_id=_config.run_id,
),
)

View File

@@ -0,0 +1,53 @@
# -*- coding: utf-8 -*-
"""The studio related hook functions in agentscope."""
from typing import Any
import requests
import shortuuid
from ..agent import AgentBase, UserAgent
def as_studio_forward_message_pre_print_hook(
self: AgentBase,
kwargs: dict[str, Any],
studio_url: str,
run_id: str,
) -> None:
"""The pre-speak hook to forward messages to the studio."""
# Disable console output if needed
if self._disable_console_output: # pylint: disable=protected-access
return
msg = kwargs["msg"]
message_data = msg.to_dict()
if hasattr(self, "_reply_id"):
reply_id = getattr(self, "_reply_id")
else:
reply_id = shortuuid.uuid()
n_retry = 0
while True:
try:
res = requests.post(
f"{studio_url}/trpc/pushMessage",
json={
"runId": run_id,
"replyId": reply_id,
"replyName": getattr(self, "name", msg.name),
"replyRole": "user"
if isinstance(self, UserAgent)
else "assistant",
"msg": message_data,
},
)
res.raise_for_status()
break
except Exception as e:
if n_retry < 3:
n_retry += 1
continue
raise e from None

View File

@@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-
"""The MCP module in AgentScope, that provides fine-grained control over
the MCP servers."""
from ._client_base import MCPClientBase
from ._mcp_function import MCPToolFunction
from ._stateful_client_base import StatefulClientBase
from ._stdio_stateful_client import StdIOStatefulClient
from ._http_stateless_client import HttpStatelessClient
from ._http_stateful_client import HttpStatefulClient
__all__ = [
"MCPToolFunction",
"MCPClientBase",
"StatefulClientBase",
"StdIOStatefulClient",
"HttpStatelessClient",
"HttpStatefulClient",
]

View File

@@ -0,0 +1,101 @@
# -*- coding: utf-8 -*-
"""The base class for MCP clients in AgentScope."""
from abc import abstractmethod
from typing import Callable, List
import mcp.types
from .._logging import logger
from ..message import (
ImageBlock,
Base64Source,
AudioBlock,
TextBlock,
VideoBlock,
)
class MCPClientBase:
"""Base class for MCP clients."""
def __init__(self, name: str) -> None:
"""Initialize the MCP client with a name.
Args:
name (`str`):
The name to identify the MCP server, which should be unique
across the MCP servers.
"""
self.name = name
@abstractmethod
async def get_callable_function(
self,
func_name: str,
wrap_tool_result: bool = True,
) -> Callable:
"""Get a tool function by its name."""
@staticmethod
def _convert_mcp_content_to_as_blocks(
mcp_content_blocks: list,
) -> List[TextBlock | ImageBlock | AudioBlock | VideoBlock]:
"""Convert MCP content to AgentScope blocks."""
as_content: list = []
for content in mcp_content_blocks:
if isinstance(content, mcp.types.TextContent):
as_content.append(
TextBlock(
type="text",
text=content.text,
),
)
elif isinstance(content, mcp.types.ImageContent):
as_content.append(
ImageBlock(
type="image",
source=Base64Source(
type="base64",
media_type=content.mimeType,
data=content.data,
),
),
)
elif isinstance(content, mcp.types.AudioContent):
as_content.append(
AudioBlock(
type="audio",
source=Base64Source(
type="base64",
media_type=content.mimeType,
data=content.data,
),
),
)
elif isinstance(content, mcp.types.EmbeddedResource):
if isinstance(
content.resource,
mcp.types.TextResourceContents,
):
as_content.append(
TextBlock(
type="text",
text=content.resource.model_dump_json(indent=2),
),
)
else:
# TODO: support the BlobResourceContents in the future,
# which is a base64-encoded string representing the
# binary data
logger.error(
"Unsupported EmbeddedResource content type: %s. "
"Skipping this content.",
type(content.resource),
)
else:
logger.warning(
"Unsupported content type: %s. Skipping this content.",
type(content),
)
return as_content

View File

@@ -0,0 +1,84 @@
# -*- coding: utf-8 -*-
"""The MCP stateful HTTP client module in AgentScope."""
from typing import Any, Literal
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client
from ._stateful_client_base import StatefulClientBase
class HttpStatefulClient(StatefulClientBase):
"""The stateful sse/streamable HTTP MCP client implementation in
AgentScope.
.. tip:: The stateful client is recommended for MCP servers that need to
maintain session states, e.g. web browsers or other interactive
MCP servers.
.. note:: The stateful client will maintain one session across multiple
tool calls, until the client is closed by explicitly calling the
`close()` method.
.. note:: When multiple HttpStatefulClient instances are connected,
they should be closed following the Last In First Out (LIFO) principle
to avoid potential errors. Always close the most recently registered
client first, then work backwards to the first one.
For more details, please refer to this `issue
<https://github.com/modelcontextprotocol/python-sdk/issues/577>`_.
"""
def __init__(
self,
name: str,
transport: Literal["streamable_http", "sse"],
url: str,
headers: dict[str, str] | None = None,
timeout: float = 30,
sse_read_timeout: float = 60 * 5,
**client_kwargs: Any,
) -> None:
"""Initialize the streamable HTTP MCP client.
Args:
name (`str`):
The name to identify the MCP server, which should be unique
across the MCP servers.
transport (`Literal["streamable_http", "sse"]`):
The transport type of MCP server. Generally, the URL of sse
transport should end with `/sse`, while the streamable HTTP
URL ends with `/mcp`.
url (`str`):
The URL to the MCP server.
headers (`dict[str, str] | None`, optional):
Additional headers to include in the HTTP request.
timeout (`float`, optional):
The timeout for the HTTP request in seconds. Defaults to 30.
sse_read_timeout (`float`, optional):
The timeout for reading Server-Sent Events (SSE) in seconds.
Defaults to 300 (5 minutes).
**client_kwargs (`Any`):
The additional keyword arguments to pass to the streamable
HTTP client.
"""
super().__init__(name=name)
assert transport in ["streamable_http", "sse"]
self.transport = transport
if self.transport == "streamable_http":
self.client = streamablehttp_client(
url=url,
headers=headers,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
**client_kwargs,
)
else:
self.client = sse_client(
url=url,
headers=headers,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
**client_kwargs,
)

View File

@@ -0,0 +1,152 @@
# -*- coding: utf-8 -*-
"""The MCP streamable HTTP server."""
from contextlib import _AsyncGeneratorContextManager
from typing import Any, Callable, Awaitable, Literal, List
import mcp.types
from mcp import ClientSession
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client
from . import MCPToolFunction
from ._client_base import MCPClientBase
from ..tool import ToolResponse
class HttpStatelessClient(MCPClientBase):
"""The sse/streamable HTTP MCP client implementation in AgentScope.
.. note:: Note this client is stateless, meaning it won't maintain the
session state across multiple tool calls. Each tool call will start a
new session and close it after the call is done.
"""
stateful: bool = False
"""Whether the MCP server is stateful, meaning it will maintain the
session state across multiple tool calls, or stateless, meaning it
will start a new session for each tool call."""
def __init__(
self,
name: str,
transport: Literal["streamable_http", "sse"],
url: str,
headers: dict[str, str] | None = None,
timeout: float = 30,
sse_read_timeout: float = 60 * 5,
**client_kwargs: Any,
) -> None:
"""Initialize the streamable HTTP MCP server.
Args:
name (`str`):
The name to identify the MCP server, which should be unique
across the MCP servers.
transport (`Literal["streamable_http", "sse"]`):
The transport type of MCP server. Generally, the URL of sse
transport should end with `/sse`, while the streamable HTTP
URL ends with `/mcp`.
url (`str`):
The URL of the MCP server.
headers (`dict[str, str] | None`, optional):
Additional headers to include in the HTTP request.
timeout (`float`, optional):
The timeout for the HTTP request in seconds. Defaults to 30.
sse_read_timeout (`float`, optional):
The timeout for reading Server-Sent Events (SSE) in seconds.
Defaults to 300 (5 minutes).
**client_kwargs (`Any`):
The additional keyword arguments to pass to the streamable
HTTP client.
"""
super().__init__(name=name)
assert transport in ["streamable_http", "sse"]
self.transport = transport
self.client_config = {
"url": url,
"headers": headers or {},
"timeout": timeout,
"sse_read_timeout": sse_read_timeout,
**client_kwargs,
}
self._tools = None
def get_client(self) -> _AsyncGeneratorContextManager[Any]:
"""The disposable MCP client object, which is a context manager."""
if self.transport == "sse":
return sse_client(**self.client_config)
if self.transport == "streamable_http":
return streamablehttp_client(**self.client_config)
raise ValueError(
f"Unsupported transport type: {self.transport}. "
"Supported types are 'sse' and 'streamable_http'.",
)
async def get_callable_function(
self,
func_name: str,
wrap_tool_result: bool = True,
execution_timeout: float | None = None,
) -> Callable[..., Awaitable[mcp.types.CallToolResult | ToolResponse]]:
"""Get a tool function by its name.
Args:
func_name (`str`):
The name of the tool function.
wrap_tool_result (`bool`, defaults to `True`):
Whether to wrap the tool result into agentscope's
`ToolResponse` object. If `False`, the raw result type
`mcp.types.CallToolResult` will be returned.
execution_timeout (`float | None`, optional):
The preset timeout in seconds for calling the tool function.
Returns:
`Callable[..., Awaitable[mcp.types.CallToolResult | \
ToolResponse]]`:
An async tool function that returns either
`mcp.types.CallToolResult` or `ToolResponse` when called.
"""
if self._tools is None:
await self.list_tools()
target_tool = None
for tool in self._tools:
if tool.name == func_name:
target_tool = tool
break
if target_tool is None:
raise ValueError(
f"Tool '{func_name}' not found in the MCP server ",
)
return MCPToolFunction(
mcp_name=self.name,
tool=target_tool,
wrap_tool_result=wrap_tool_result,
client_gen=self.get_client,
timeout=execution_timeout,
)
async def list_tools(self) -> List[mcp.types.Tool]:
"""List all tools available on the MCP server.
Returns:
`mcp.types.ListToolsResult`:
The result containing the list of tools.
"""
async with self.get_client() as cli:
read_stream, write_stream = cli[0], cli[1]
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
res = await session.list_tools()
self._tools = res.tools
return res.tools

View File

@@ -0,0 +1,115 @@
# -*- coding: utf-8 -*-
"""The MCP tool function class in AgentScope."""
from contextlib import _AsyncGeneratorContextManager
from datetime import timedelta
from typing import Any, Callable
import mcp
from mcp import ClientSession
from ._client_base import MCPClientBase
from .._utils._common import _extract_json_schema_from_mcp_tool
from ..tool import ToolResponse
class MCPToolFunction:
"""An MCP tool function class that can be called directly."""
name: str
"""The name of the tool function."""
description: str
"""The description of the tool function."""
json_schema: dict[str, Any]
"""JSON schema of the tool function"""
def __init__(
self,
mcp_name: str,
tool: mcp.types.Tool,
wrap_tool_result: bool,
client_gen: Callable[..., _AsyncGeneratorContextManager[Any]]
| None = None,
session: ClientSession | None = None,
timeout: float | None = None,
) -> None:
"""Initialize the MCP function.
Args:
mcp_name (`str`):
The name of the MCP instance.
tool (`mcp.types.Tool`):
The MCP tool definition.
wrap_tool_result (`bool`):
Whether to wrap the tool result into `ToolResponse` in
AgentScope.
client_gen (`Callable[..., _AsyncGeneratorContextManager[Any]] | \
None`, *optional*):
The MCP client generator function. Either this or `session`
must be provided.
session (`ClientSession | None`, *optional*):
The MCP client session. Either this or `client_gen` must be
provided.
timeout (`float | None`, *optional*):
The timeout in seconds for tool execution. If not provided,
no timeout will be set.
"""
self.mcp_name = mcp_name
self.name = tool.name
self.description = tool.description
self.json_schema = _extract_json_schema_from_mcp_tool(tool)
self.wrap_tool_result = wrap_tool_result
if timeout:
self.timeout = timedelta(seconds=timeout)
else:
self.timeout = None
# Cannot be None at the same time
if (
client_gen is None
and session is None
or (client_gen is not None and session is not None)
):
raise ValueError(
"Either client or session must be provided, but not both.",
)
self.client_gen = client_gen
self.session = session
async def __call__(
self,
**kwargs: Any,
) -> mcp.types.CallToolResult | ToolResponse:
"""Call the MCP tool function with the given arguments, and return
the result."""
if self.client_gen:
async with self.client_gen() as cli:
read_stream, write_stream = cli[0], cli[1]
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
res = await session.call_tool(
self.name,
arguments=kwargs,
read_timeout_seconds=self.timeout,
)
else:
res = await self.session.call_tool(
self.name,
arguments=kwargs,
read_timeout_seconds=self.timeout,
)
if self.wrap_tool_result:
as_content = MCPClientBase._convert_mcp_content_to_as_blocks(
res.content,
)
return ToolResponse(
content=as_content,
metadata=res.meta,
)
return res

View File

@@ -0,0 +1,176 @@
# -*- coding: utf-8 -*-
"""The base MCP stateful client class in AgentScope, that provides basic
functionality for stateful MCP clients."""
from abc import ABC
from contextlib import AsyncExitStack
from typing import List
import mcp
from mcp import ClientSession
from ._client_base import MCPClientBase
from ._mcp_function import MCPToolFunction
from .._logging import logger
class StatefulClientBase(MCPClientBase, ABC):
"""The base class for stateful MCP clients in AgentScope, which maintains
the session state across multiple tool calls.
The developers should use `connect()` and `close()` methods to manage
the client lifecycle.
"""
is_connected: bool
"""If connected to the MCP server"""
def __init__(self, name: str) -> None:
"""Initialize the stateful MCP client.
Args:
name (`str`):
The name to identify the MCP server, which should be unique
across the MCP servers.
"""
super().__init__(name=name)
self.client = None
self.stack = None
self.session = None
self.is_connected = False
# Cache the tools to avoid fetching them multiple times
self._cached_tools = None
async def connect(self) -> None:
"""Connect to MCP server."""
if self.is_connected:
raise RuntimeError(
"The MCP server is already connected. Call close() "
"before connecting again.",
)
self.stack = AsyncExitStack()
try:
context = await self.stack.enter_async_context(
self.client,
)
read_stream, write_stream = context[0], context[1]
self.session = ClientSession(read_stream, write_stream)
await self.stack.enter_async_context(self.session)
await self.session.initialize()
self.is_connected = True
logger.info("MCP client connected.")
except Exception:
await self.stack.aclose()
self.stack = None
raise
async def close(self, ignore_errors: bool = True) -> None:
"""Clean up the MCP client resources. You must call this method when
your application is done.
Args:
ignore_errors (`bool`):
Whether to ignore errors during cleanup. Defaults to `True`.
"""
if not self.is_connected:
raise RuntimeError(
"The MCP server is not connected. Call connect() before "
"closing.",
)
try:
await self.stack.aclose()
except Exception as e:
if not ignore_errors:
raise e
logger.warning("Error during MCP client cleanup: %s", e)
finally:
self.stack = None
self.session = None
self.is_connected = False
async def list_tools(self) -> List[mcp.types.Tool]:
"""Get all available tools from the server.
Returns:
`mcp.types.ListToolsResult`:
A list of available MCP tools.
"""
self._validate_connection()
res = await self.session.list_tools()
# Cache the tools for later use
self._cached_tools = res.tools
return res.tools
async def get_callable_function(
self,
func_name: str,
wrap_tool_result: bool = True,
execution_timeout: float | None = None,
) -> MCPToolFunction:
"""Get an async tool function from the MCP server by its name, so
that you can call it directly, wrap it into your own function, or
anyway you like.
.. note:: Currently, only the text, image, and audio results are
supported in this function.
Args:
func_name (`str`):
The name of the tool function to get.
wrap_tool_result (`bool`):
Whether to wrap the tool result into agentscope's
`ToolResponse` object. If `False`, the raw result type
`mcp.types.CallToolResult` will be returned.
execution_timeout (`float | None`, optional):
The preset timeout in seconds for calling the tool function.
Returns:
`MCPToolFunction`:
A callable async function that returns either
`mcp.types.CallToolResult` or `ToolResponse` when called.
"""
self._validate_connection()
if self._cached_tools is None:
await self.list_tools()
target_tool = None
for tool in self._cached_tools:
if tool.name == func_name:
target_tool = tool
break
if target_tool is None:
raise ValueError(
f"Tool '{func_name}' not found in the MCP server",
)
return MCPToolFunction(
mcp_name=self.name,
tool=target_tool,
wrap_tool_result=wrap_tool_result,
session=self.session,
timeout=execution_timeout,
)
def _validate_connection(self) -> None:
"""Validate the connection to the MCP server."""
if not self.is_connected:
raise RuntimeError(
"The connection is not established. Call connect() "
"before using the client.",
)
if not self.session:
raise RuntimeError(
"The session is not initialized. Call connect() "
"before using the client.",
)

View File

@@ -0,0 +1,77 @@
# -*- coding: utf-8 -*-
"""The StdIO MCP server implementation in AgentScope, which provides
function-level fine-grained control over the MCP servers using standard IO."""
from typing import Literal
from mcp import stdio_client, StdioServerParameters
from ._stateful_client_base import StatefulClientBase
class StdIOStatefulClient(StatefulClientBase):
"""A client class that sets up and manage StdIO MCP server connections, and
provides function-level fine-grained control over the MCP servers.
.. tip:: The stateful client is recommended for MCP servers that need to
maintain session states, e.g. web browsers or other interactive
MCP servers.
.. note:: The stateful client will maintain one session across multiple
tool calls, until the client is closed by explicitly calling the
`close()` method.
.. note:: When multiple StdIOStatefulClient instances are connected,
they should be closed following the Last In First Out (LIFO) principle
to avoid potential errors. Always close the most recently registered
client first, then work backwards to the first one.
For more details, please refer to this `issue
<https://github.com/modelcontextprotocol/python-sdk/issues/577>`_.
"""
def __init__(
self,
name: str,
command: str,
args: list[str] | None = None,
env: dict[str, str] | None = None,
cwd: str | None = None,
encoding: str = "utf-8",
encoding_error_handler: Literal[
"strict",
"ignore",
"replace",
] = "strict",
) -> None:
"""Initialize the MCP server with std IO.
Args:
name (`str`):
The name to identify the MCP server, which should be unique
across the MCP servers.
command (`str`):
The executable to run to start the server.
args (`list[str] | None`, optional):
Command line arguments to pass to the executable.
env (`dict[str, str] | None`, optional):
The environment to use when spawning the process.
cwd (`str | None`, optional):
The working directory to use when spawning the process.
encoding (`str`, optional):
The text encoding used when sending/receiving messages to the
server. Defaults to "utf-8".
encoding_error_handler (`Literal["strict", "ignore", "replace"]`, \
defaults to "strict"):
The text encoding error handler.
"""
super().__init__(name=name)
self.client = stdio_client(
StdioServerParameters(
command=command,
args=args or [],
env=env,
cwd=cwd,
encoding=encoding,
encoding_error_handler=encoding_error_handler,
),
)

View File

@@ -0,0 +1,31 @@
# -*- coding: utf-8 -*-
"""The memory module."""
from ._working_memory import (
MemoryBase,
InMemoryMemory,
RedisMemory,
AsyncSQLAlchemyMemory,
)
from ._long_term_memory import (
LongTermMemoryBase,
Mem0LongTermMemory,
ReMePersonalLongTermMemory,
ReMeTaskLongTermMemory,
ReMeToolLongTermMemory,
)
__all__ = [
# Working memory
"MemoryBase",
"InMemoryMemory",
"RedisMemory",
"AsyncSQLAlchemyMemory",
# Long-term memory
"LongTermMemoryBase",
"Mem0LongTermMemory",
"ReMePersonalLongTermMemory",
"ReMeTaskLongTermMemory",
"ReMeToolLongTermMemory",
]

View File

@@ -0,0 +1,18 @@
# -*- coding: utf-8 -*-
"""The long-term memory module for AgentScope."""
from ._long_term_memory_base import LongTermMemoryBase
from ._mem0 import Mem0LongTermMemory
from ._reme import (
ReMePersonalLongTermMemory,
ReMeTaskLongTermMemory,
ReMeToolLongTermMemory,
)
__all__ = [
"LongTermMemoryBase",
"Mem0LongTermMemory",
"ReMePersonalLongTermMemory",
"ReMeTaskLongTermMemory",
"ReMeToolLongTermMemory",
]

View File

@@ -0,0 +1,94 @@
# -*- coding: utf-8 -*-
"""The long-term memory base class."""
from typing import Any
from agentscope.message import Msg
from agentscope.module import StateModule
from agentscope.tool import ToolResponse
class LongTermMemoryBase(StateModule):
"""The long-term memory base class, which should be a time-series
memory management system.
The `record_to_memory` and `retrieve_from_memory` methods are two tool
functions for agent to manage the long-term memory voluntarily. You can
choose not to implement these two functions.
The `record` and `retrieve` methods are for developers to use. For example,
retrieving/recording memory at the beginning of each reply, and adding
the retrieved memory to the system prompt.
"""
async def record(
self,
msgs: list[Msg | None],
**kwargs: Any,
) -> Any:
"""A developer-designed method to record information from the given
input message(s) to the long-term memory."""
raise NotImplementedError(
"The `record` method is not implemented. ",
)
async def retrieve(
self,
msg: Msg | list[Msg] | None,
limit: int = 5,
**kwargs: Any,
) -> str:
"""A developer-designed method to retrieve information from the
long-term memory based on the given input message(s). The retrieved
information will be added to the system prompt of the agent."""
raise NotImplementedError(
"The `retrieve` method is not implemented. ",
)
async def record_to_memory(
self,
thinking: str,
content: list[str],
**kwargs: Any,
) -> ToolResponse:
"""Use this function to record important information that you may
need later. The target content should be specific and concise, e.g.
who, when, where, do what, why, how, etc.
Args:
thinking (`str`):
Your thinking and reasoning about what to record
content (`list[str]`):
The content to remember, which is a list of strings.
"""
raise NotImplementedError(
"The `record_to_memory` method is not implemented. "
"You can implement it in your own long-term memory class.",
)
async def retrieve_from_memory(
self,
keywords: list[str],
limit: int = 5,
**kwargs: Any,
) -> ToolResponse:
"""Retrieve the memory based on the given keywords.
Args:
keywords (`list[str]`):
The keywords to search for in the memory, which should be
specific and concise, e.g. the person's name, the date, the
location, etc.
limit (`int`, optional):
The maximum number of memories to retrieve per search, i.e.,
the number of memories to retrieve for each keyword. Defaults
to 5.
Returns:
`list[Msg]`:
A list of messages that match the keywords.
"""
raise NotImplementedError(
"The `retrieve_from_memory` method is not implemented. "
"You can implement it in your own long-term memory class.",
)

View File

@@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-
"""The Mem0 long-term memory module for AgentScope."""
from ._mem0_long_term_memory import Mem0LongTermMemory
__all__ = [
"Mem0LongTermMemory",
]

View File

@@ -0,0 +1,746 @@
# -*- coding: utf-8 -*-
"""Long-term memory implementation using mem0 library.
This module provides a long-term memory implementation that integrates
with the mem0 library to provide persistent memory storage and retrieval
capabilities for AgentScope agents.
"""
import asyncio
import json
from typing import Any, TYPE_CHECKING
from importlib import metadata
from pydantic import field_validator
from ....embedding import EmbeddingModelBase
from .._long_term_memory_base import LongTermMemoryBase
from ....message import Msg, TextBlock
from ....model import ChatModelBase
from ....tool import ToolResponse
if TYPE_CHECKING:
from mem0.configs.base import MemoryConfig
from mem0.vector_stores.configs import VectorStoreConfig
else:
MemoryConfig = Any
VectorStoreConfig = Any
def _create_agentscope_config_classes() -> tuple:
"""Create custom config classes for agentscope providers."""
from mem0.embeddings.configs import EmbedderConfig
from mem0.llms.configs import LlmConfig
class _ASLlmConfig(LlmConfig):
"""Custom LLM config class that updates the validate_config method.
Attention: in mem0, the validate_config hardcodes the provider, so we
need to override the validate_config method to support the agentscope
providers. We will follow up with the mem0 to improve this.
"""
@field_validator("config")
@classmethod
def validate_config(cls, v: Any, values: Any) -> Any:
"""Validate the LLM configuration."""
from mem0.utils.factory import LlmFactory
provider = values.data.get("provider")
if provider in LlmFactory.provider_to_class:
return v
raise ValueError(f"Unsupported LLM provider: {provider}")
class _ASEmbedderConfig(EmbedderConfig):
"""Custom embedder config class that updates the validate_config
method."""
@field_validator("config")
@classmethod
def validate_config(cls, v: Any, values: Any) -> Any:
"""Validate the embedder configuration."""
from mem0.utils.factory import EmbedderFactory
provider = values.data.get("provider")
if provider in EmbedderFactory.provider_to_class:
return v
raise ValueError(f"Unsupported Embedder provider: {provider}")
return _ASLlmConfig, _ASEmbedderConfig
class Mem0LongTermMemory(LongTermMemoryBase):
"""A class that implements the LongTermMemoryBase interface using mem0."""
@staticmethod
def _setup_mem0_logging(suppress_mem0_logging: bool) -> None:
"""Suppress mem0 logging if requested.
Args:
suppress_mem0_logging (`bool`):
Whether to suppress mem0 logging. See class docstring for
details on QDRANT validation errors when using mem0 1.0.3.
"""
if suppress_mem0_logging:
import logging
logging.getLogger("mem0").setLevel(logging.CRITICAL)
logging.getLogger("mem0.memory").setLevel(logging.CRITICAL)
logging.getLogger("mem0.memory.main").setLevel(logging.CRITICAL)
@staticmethod
def _register_agentscope_providers() -> None:
"""Register the agentscope providers with mem0.
Raises:
`ImportError`:
If the mem0 library is not installed.
"""
try:
from mem0.configs.llms.base import BaseLlmConfig
from mem0.utils.factory import LlmFactory, EmbedderFactory
from packaging import version
# Check mem0 version
current_version = metadata.version("mem0ai")
is_mem0_version_low = version.parse(
current_version,
) <= version.parse("0.1.115")
# Register the agentscope providers with mem0
EmbedderFactory.provider_to_class["agentscope"] = (
"agentscope.memory._long_term_memory._mem0."
"_mem0_utils.AgentScopeEmbedding"
)
if is_mem0_version_low:
# For mem0 version <= 0.1.115, use the old style
LlmFactory.provider_to_class["agentscope"] = (
"agentscope.memory._long_term_memory._mem0."
"_mem0_utils.AgentScopeLLM"
)
else:
# For mem0 version > 0.1.115, use the new style
LlmFactory.provider_to_class["agentscope"] = (
"agentscope.memory._long_term_memory._mem0."
"_mem0_utils.AgentScopeLLM",
BaseLlmConfig,
)
except ImportError as e:
raise ImportError(
"Please install the mem0 library by `pip install mem0ai`",
) from e
@staticmethod
def _validate_identifiers(
agent_name: str | None,
user_name: str | None,
run_name: str | None,
) -> None:
"""Validate that at least one identifier is provided.
Args:
agent_name (`str | None`):
The name of the agent.
user_name (`str | None`):
The name of the user.
run_name (`str | None`):
The name of the run/session.
Raises:
`ValueError`:
If all identifiers are None.
"""
if agent_name is None and user_name is None and run_name is None:
raise ValueError(
"at least one of agent_name, user_name, and run_name is "
"required",
)
@staticmethod
def _configure_mem0_config(
mem0_config: MemoryConfig | None,
model: ChatModelBase | None,
embedding_model: EmbeddingModelBase | None,
vector_store_config: VectorStoreConfig | None,
_ASLlmConfig: type,
_ASEmbedderConfig: type,
**kwargs: Any,
) -> MemoryConfig:
"""Configure the mem0 MemoryConfig object.
Args:
mem0_config (`MemoryConfig | None`):
The existing mem0 config, if any.
model (`ChatModelBase | None`):
The chat model to use.
embedding_model (`EmbeddingModelBase | None`):
The embedding model to use.
vector_store_config (`VectorStoreConfig | None`):
The vector store config to use.
_ASLlmConfig (`type`):
The custom LLM config class for agentscope.
_ASEmbedderConfig (`type`):
The custom embedder config class for agentscope.
**kwargs (`Any`):
Additional keyword arguments, including 'on_disk' for
vector store configuration.
Returns:
`MemoryConfig`:
The configured MemoryConfig object.
Raises:
`ValueError`:
If `mem0_config` is None and either `model` or
`embedding_model` is None.
"""
import mem0
if mem0_config is not None:
# Case 1: mem0_config is provided - override specific
# configurations if individual params are given
# Override LLM configuration if model is provided
if model is not None:
mem0_config.llm = _ASLlmConfig(
provider="agentscope",
config={"model": model},
)
# Override embedder configuration if embedding_model is provided
if embedding_model is not None:
mem0_config.embedder = _ASEmbedderConfig(
provider="agentscope",
config={"model": embedding_model},
)
# Override vector store configuration if vector_store_config is
# provided
if vector_store_config is not None:
mem0_config.vector_store = vector_store_config
else:
# Case 2: mem0_config is not provided - create new configuration
# from individual parameters
# Validate that required parameters are provided
if model is None or embedding_model is None:
raise ValueError(
"model and embedding_model are required if mem0_config "
"is not provided",
)
# Create new MemoryConfig with provided LLM and embedder
mem0_config = mem0.configs.base.MemoryConfig(
llm=_ASLlmConfig(
provider="agentscope",
config={"model": model},
),
embedder=_ASEmbedderConfig(
provider="agentscope",
config={"model": embedding_model},
),
)
# Set vector store configuration
if vector_store_config is not None:
# Use provided vector store configuration
mem0_config.vector_store = vector_store_config
else:
# Use default Qdrant configuration with on-disk storage for
# persistence set on_disk to True to enable persistence,
# otherwise it will be in memory only
on_disk = kwargs.get("on_disk", True)
mem0_config.vector_store = (
mem0.vector_stores.configs.VectorStoreConfig(
config={"on_disk": on_disk},
)
)
return mem0_config
def __init__(
self,
agent_name: str | None = None,
user_name: str | None = None,
run_name: str | None = None,
model: ChatModelBase | None = None,
embedding_model: EmbeddingModelBase | None = None,
vector_store_config: VectorStoreConfig | None = None,
mem0_config: MemoryConfig | None = None,
default_memory_type: str | None = None,
suppress_mem0_logging: bool = True,
**kwargs: Any,
) -> None:
"""Initialize the Mem0LongTermMemory instance
Args:
agent_name (`str | None`, optional):
The name of the agent. Default is None.
user_name (`str | None`, optional):
The name of the user. Default is None.
run_name (`str | None`, optional):
The name of the run/session. Default is None.
.. note::
1. At least one of `agent_name`, `user_name`, or `run_name` is
required.
2. During memory recording, these parameters become metadata
for the stored memories.
3. **Important**: mem0 will extract memories from messages
containing role of "user" by default. If you want to
extract memories from messages containing role of
"assistant", you need to provide `agent_name`.
4. During memory retrieval, only memories with matching
metadata values will be returned.
model (`ChatModelBase | None`, optional):
The chat model to use for the long-term memory. If
mem0_config is provided, this will override the LLM
configuration. If mem0_config is None, this is required.
embedding_model (`EmbeddingModelBase | None`, optional):
The embedding model to use for the long-term memory. If
mem0_config is provided, this will override the embedder
configuration. If mem0_config is None, this is required.
vector_store_config (`VectorStoreConfig | None`, optional):
The vector store config to use for the long-term memory.
If mem0_config is provided, this will override the vector store
configuration. If mem0_config is None and this is not
provided, defaults to Qdrant with on_disk=True.
mem0_config (`MemoryConfig | None`, optional):
The mem0 config to use for the long-term memory.
If provided, individual
model/embedding_model/vector_store_config parameters will
override the corresponding configurations in mem0_config. If
None, a new MemoryConfig will be created using the provided
parameters.
default_memory_type (`str | None`, optional):
The type of memory to use. Default is None, to create a
semantic memory.
suppress_mem0_logging (`bool`, optional):
Whether to suppress mem0 logging. Default is True.
.. note::
When using vector database QDRANT with mem0 1.0.3, you may
encounter validation errors:
"Error awaiting memory task (async): 6 validation errors for
PointStruct vector.list[float] Input should be a valid list
[type=list_type, ...]"
According to the mem0 community
(see https://github.com/mem0ai/mem0/issues/3780),
these error messages are harmless and can be safely ignored.
Setting `suppress_mem0_logging=True` (the default) will
suppress these error messages.
Raises:
`ValueError`:
If `mem0_config` is None and either `model` or
`embedding_model` is None.
"""
super().__init__()
# Suppress mem0 logging if requested
self._setup_mem0_logging(suppress_mem0_logging)
# Register agentscope providers with mem0
self._register_agentscope_providers()
# Create the custom config classes for agentscope providers dynamically
_ASLlmConfig, _ASEmbedderConfig = _create_agentscope_config_classes()
# Validate identifiers
self._validate_identifiers(agent_name, user_name, run_name)
# Store agent and user identifiers for memory management
self.agent_id = agent_name
self.user_id = user_name
self.run_id = run_name
# Configure mem0_config
import mem0
mem0_config = self._configure_mem0_config(
mem0_config=mem0_config,
model=model,
embedding_model=embedding_model,
vector_store_config=vector_store_config,
_ASLlmConfig=_ASLlmConfig,
_ASEmbedderConfig=_ASEmbedderConfig,
**kwargs,
)
# Initialize the async memory instance with the configured settings
self.long_term_working_memory = mem0.AsyncMemory(mem0_config)
# Store the default memory type for future use
self.default_memory_type = default_memory_type
async def record_to_memory(
self,
thinking: str,
content: list[str],
**kwargs: Any,
) -> ToolResponse:
"""Use this function to record important information that you may
need later. The target content should be specific and concise, e.g.
who, when, where, do what, why, how, etc.
Args:
thinking (`str`):
Your thinking and reasoning about what to record.
content (`list[str]`):
The content to remember, which is a list of strings.
"""
# Multi-strategy recording approach to ensure content persistence:
#
# This method employs a three-tier fallback strategy to maximize
# successful memory recording:
#
# 1. Primary: Record as "user" role message
# - This is the default approach for capturing user-related
# content
# - Mem0 extracts and infers memories from messages containing
# role of "user"
#
# 2. Fallback (if agent_id exists): Record as "assistant" role
# message
# - Triggered when primary recording yields no results
# - In this case, mem0 will use the AGENT_MEMORY_EXTRACTION_PROMPT
# in mem0/mem0/configs/prompts.py to extract memories from
# messages containing role of "assistant", if agent_id is
# provided, otherwise it will use the
# USER_MEMORY_EXTRACTION_PROMPT in mem0/mem0/configs/prompts.py
# to extract memories.
#
# 3. Last resort: Record as "assistant" with infer=False
# - Used when both previous attempts yield no results
# - Bypasses mem0's inference mechanism, which means no
# inference is performed, mem0 will only record the content
# as is.
#
# This graduated approach ensures that even if mem0's inference fails
# to extract meaningful memories, the raw content is still preserved.
try:
if thinking:
content = [thinking] + content
# Strategy 1: Record as user message first
results = await self._mem0_record(
[
{
"role": "user",
"content": "\n".join(content),
"name": "user",
},
],
**kwargs,
)
# Strategy 2: Fallback to assistant message. In this case, if
# agent_id is provided, mem0 will use the
# AGENT_MEMORY_EXTRACTION_PROMPT in mem0/mem0/configs/prompts.py
# to extract memories from messages containing role of
# "assistant". If agent_id is not provided, mem0 will still use
# the USER_MEMORY_EXTRACTION_PROMPT in
# mem0/mem0/configs/prompts.py to extract memories.
if (
results
and isinstance(results, dict)
and "results" in results
and len(results["results"]) == 0
):
results = await self._mem0_record(
[
{
"role": "assistant",
"content": "\n".join(content),
"name": "assistant",
},
],
**kwargs,
)
# Strategy 3: Last resort - direct recording without inference.
# In this case, mem0 will not use any prompts to extract
# memories, it will only record the content as is.
if (
results
and isinstance(results, dict)
and "results" in results
and len(results["results"]) == 0
):
results = await self._mem0_record(
[
{
"role": "assistant",
"content": "\n".join(content),
"name": "assistant",
},
],
infer=False,
**kwargs,
)
return ToolResponse(
content=[
TextBlock(
type="text",
text=f"Successfully recorded content to memory "
f"{results}",
),
],
)
except Exception as e:
return ToolResponse(
content=[
TextBlock(
type="text",
text=f"Error recording memory: {str(e)}",
),
],
)
async def retrieve_from_memory(
self,
keywords: list[str],
limit: int = 5,
**kwargs: Any,
) -> ToolResponse:
"""Retrieve the memory based on the given keywords.
Args:
keywords (`list[str]`):
Short, targeted search phrases (for example, a person's name,
a specific date, a location, or a phrase describing something
you want to retrieve from the memory). Each keyword is issued
as an independent query against the memory store.
limit (`int`, optional):
The maximum number of memories to retrieve per search,
defaults to 5.
i.e.,the number of memories to retrieve for
each keyword.
Returns:
`ToolResponse`:
A ToolResponse containing the retrieved memories as JSON text.
"""
try:
results = []
search_coroutines = [
self.long_term_working_memory.search(
query=keyword,
agent_id=self.agent_id,
user_id=self.user_id,
run_id=self.run_id,
limit=limit,
)
for keyword in keywords
]
search_results = await asyncio.gather(*search_coroutines)
for result in search_results:
if result:
results.extend(
[item["memory"] for item in result["results"]],
)
if "relations" in result.keys():
results.extend(
self._format_relations(result),
)
return ToolResponse(
content=[
TextBlock(
type="text",
text="\n".join(results),
),
],
)
except Exception as e:
return ToolResponse(
content=[
TextBlock(
type="text",
text=f"Error retrieving memory: {str(e)}",
),
],
)
async def record(
self,
msgs: list[Msg | None],
memory_type: str | None = None,
infer: bool = True,
**kwargs: Any,
) -> dict:
"""Record the content to the long-term memory.
Args:
msgs (`list[Msg | None]`):
The messages to record to memory.
memory_type (`str | None`, optional):
The type of memory to use. Default is None, to create a
semantic memory. "procedural_memory" is explicitly used for
procedural memories.
infer (`bool`, optional):
Whether to infer memory from the content. Default is True.
**kwargs (`Any`):
Additional keyword arguments for the mem0 recording.
"""
if isinstance(msgs, Msg):
msgs = [msgs]
# Filter out None
msg_list = [_ for _ in msgs if _]
if not all(isinstance(_, Msg) for _ in msg_list):
raise TypeError(
"The input messages must be a list of Msg objects.",
)
messages = [
{
"role": "assistant",
"content": "\n".join([str(_.content) for _ in msg_list]),
"name": "assistant",
},
]
results = await self._mem0_record(
messages,
memory_type=memory_type,
infer=infer,
**kwargs,
)
return results
def _format_relations(self, result: dict) -> list:
"""Format relations from search result.
Args:
result (`dict`):
The result from the memory search operation.
Returns:
`list`:
The formatted relations.
Each relation is a string in the format of:
"{source} -- {relationship} -- {destination}"
"""
if "relations" not in result:
return []
return [
f"{relation['source']} -- "
f"{relation['relationship']} -- "
f"{relation['destination']}"
for relation in result["relations"]
]
async def _mem0_record(
self,
messages: str | list[dict],
memory_type: str | None = None,
infer: bool = True,
**kwargs: Any,
) -> dict:
"""Record the content to the long-term memory.
Args:
messages (`str`):
The content to remember, which is a string or a list of
dictionaries representing messages.
memory_type (`str | None`, optional):
The type of memory to use. Default is None, to create a
semantic memory. "procedural_memory" is explicitly used for
procedural memories.
infer (`bool`, optional):
Whether to infer memory from the content. Default is True.
**kwargs (`Any`):
Additional keyword arguments.
Returns:
`dict`:
The result from the memory recording operation.
"""
results = await self.long_term_working_memory.add(
messages=messages,
agent_id=self.agent_id,
user_id=self.user_id,
run_id=self.run_id,
memory_type=(
memory_type
if memory_type is not None
else self.default_memory_type
),
infer=infer,
**kwargs,
)
return results
async def retrieve(
self,
msg: Msg | list[Msg] | None,
limit: int = 5,
**kwargs: Any,
) -> str:
"""Retrieve the content from the long-term memory.
Args:
msg (`Msg | list[Msg] | None`):
The message to search for in the memory, which should be
specific and concise, e.g. the person's name, the date, the
location, etc.
limit (`int`, optional):
The maximum number of memories to retrieve per search, i.e.,
the number of memories to retrieve for the message. if the
message is a list of messages, the limit will be applied to
each message. If the message is a single message, then the
limit is the total number of memories to retrieve for the
message. Defaults to 5.
**kwargs (`Any`):
Additional keyword arguments.
Returns:
`str`:
The retrieved memory
"""
if isinstance(msg, Msg):
msg = [msg]
if not isinstance(msg, list) or not all(
isinstance(_, Msg) for _ in msg
):
raise TypeError(
"The input message must be a Msg or a list of Msg objects.",
)
msg_strs = [
json.dumps(_.to_dict()["content"], ensure_ascii=False) for _ in msg
]
results = []
search_coroutines = [
self.long_term_working_memory.search(
query=item,
agent_id=self.agent_id,
user_id=self.user_id,
run_id=self.run_id,
limit=limit,
)
for item in msg_strs
]
search_results = await asyncio.gather(*search_coroutines)
for result in search_results:
if result:
results.extend(
[memory["memory"] for memory in result["results"]],
)
if "relations" in result.keys():
results.extend(
self._format_relations(result),
)
return "\n".join(results)

View File

@@ -0,0 +1,363 @@
# -*- coding: utf-8 -*-
"""Utility classes for integrating AgentScope with mem0 library.
This module provides wrapper classes that allow AgentScope models to be used
with the mem0 library for long-term memory functionality.
"""
import asyncio
import atexit
import threading
from typing import Any, Coroutine, Dict, List, Literal
from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.configs.llms.base import BaseLlmConfig
from mem0.embeddings.base import EmbeddingBase
from mem0.llms.base import LLMBase
from ....embedding import EmbeddingModelBase
from ....model import ChatModelBase, ChatResponse
class _EventLoopManager:
"""Global event loop manager for running async operations in sync context.
This manager creates and maintains a persistent background event loop
that runs in a separate daemon thread. This ensures that async model
clients (like Ollama AsyncClient) remain bound to the same event loop
across multiple calls, avoiding "Event loop is closed" errors.
"""
_DEFAULT_TIMEOUT = 5.0 # Default timeout in seconds
def __init__(self) -> None:
"""Initialize the event loop manager."""
self.loop: asyncio.AbstractEventLoop | None = None
self.thread: threading.Thread | None = None
self.lock = threading.Lock()
self.loop_started = threading.Event()
# Register cleanup function to be called on program exit
atexit.register(self.cleanup)
def get_loop(self) -> asyncio.AbstractEventLoop:
"""Get or create the persistent background event loop.
Returns:
`asyncio.AbstractEventLoop`:
The persistent event loop running in a background thread.
Raises:
`RuntimeError`: If the event loop fails to start within the
timeout.
"""
with self.lock:
if self.loop is None or self.loop.is_closed():
def run_loop() -> None:
"""Run the event loop in the background thread."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Store the loop reference before starting
self.loop = loop
self.loop_started.set()
loop.run_forever()
# Clear the event before starting the thread
self.loop_started.clear()
# Create and start the background thread
self.thread = threading.Thread(
target=run_loop,
daemon=True,
name="AgentScope-AsyncLoop",
)
self.thread.start()
# Wait for the loop to be ready
if not self.loop_started.wait(timeout=self._DEFAULT_TIMEOUT):
raise RuntimeError(
"Timeout waiting for event loop to start",
)
# After waiting, self.loop should be set by the background thread
assert (
self.loop is not None
), "Event loop was not initialized properly"
return self.loop
def cleanup(self) -> None:
"""Cleanup the event loop and thread on program exit."""
with self.lock:
if self.loop is not None and not self.loop.is_closed():
# Stop the event loop gracefully
self.loop.call_soon_threadsafe(self.loop.stop)
# Wait for the thread to finish
if self.thread is not None and self.thread.is_alive():
self.thread.join(timeout=self._DEFAULT_TIMEOUT)
# Close the loop
self.loop.close()
self.loop = None
self.thread = None
# Global event loop manager instance
_event_loop_manager = _EventLoopManager()
def _run_async_in_persistent_loop(coro: Coroutine) -> Any:
"""Run an async coroutine in the persistent background event loop.
This function uses a global event loop manager to ensure that all
async operations run in the same event loop, which is crucial for
async clients like Ollama that bind to a specific event loop.
Args:
coro (`Coroutine`):
The coroutine to run.
Returns:
`Any`:
The result of the coroutine.
Raises:
`RuntimeError`:
If there's an error running the coroutine.
"""
loop = _event_loop_manager.get_loop()
future = asyncio.run_coroutine_threadsafe(coro, loop)
return future.result()
class AgentScopeLLM(LLMBase):
"""Wrapper for the AgentScope LLM.
This class is a wrapper for the AgentScope LLM. It is used to generate
responses using the AgentScope LLM in mem0.
"""
def __init__(self, config: BaseLlmConfig | None = None):
"""Initialize the AgentScopeLLM wrapper.
Args:
config (`BaseLlmConfig | None`, optional):
Configuration object for the LLM. Default is None.
"""
super().__init__(config)
if self.config.model is None:
raise ValueError("`model` parameter is required")
if not isinstance(self.config.model, ChatModelBase):
raise ValueError("`model` must be an instance of ChatModelBase")
self.agentscope_model = self.config.model
def _parse_response(
self,
model_response: ChatResponse,
has_tool: bool,
) -> str | dict:
"""Parse the model response into a string or
a dict to follow the mem0 library's format.
Args:
model_response (`ChatResponse`): The response from the model.
has_tool (`bool`): Whether there are tool calls in the response.
Returns:
`str | dict`:
The parsed response. If has_tool is True, return a dict
with "content" and "tool_calls" keys. Otherwise, return
a string.
"""
text_parts: list[str] = []
thinking_parts: list[str] = []
tool_parts = []
for block in model_response.content:
# Handle TextBlock
if isinstance(block, dict) and block.get("type") == "text":
text_parts.append(str(block.get("text", "")))
# Handle ThinkingBlock
elif isinstance(block, dict) and block.get("type") == "thinking":
thinking_parts.append(
f"[Thinking: {block.get('thinking', '')}]",
)
# Handle ToolUseBlock
elif isinstance(block, dict) and block.get("type") == "tool_use":
tool_name = block.get("name")
tool_input = block.get("input", {})
tool_parts.append(
{
"name": tool_name,
"arguments": tool_input,
},
)
text_part = thinking_parts + text_parts
if has_tool:
# If there are tool calls, return the content and tool calls
return {
"content": "\n".join(text_part) if len(text_part) > 0 else "",
"tool_calls": tool_parts,
}
else:
return "\n".join(text_part) if len(text_part) > 0 else ""
def generate_response(
self,
messages: List[Dict[str, str]],
response_format: Any | None = None,
tools: List[Dict] | None = None,
tool_choice: str = "auto",
) -> str | dict:
"""Generate a response based on the given messages using agentscope.
Args:
messages (`List[Dict[str, str]]`):
List of message dicts containing 'role' and 'content'.
response_format (`Any | None`, optional):
Format of the response. Not used in AgentScope.
tools (`List[Dict] | None`, optional):
List of tools that the model can call. Not used in AgentScope.
tool_choice (`str`, optional):
Tool choice method. Not used in AgentScope.
Returns:
`str | dict`:
The generated response.
"""
# pylint: disable=unused-argument
try:
# Convert the messages to AgentScope's format
agentscope_messages = []
for message in messages:
role = message["role"]
content = message["content"]
if role in ["system", "user", "assistant"]:
agentscope_messages.append(
{"role": role, "content": content},
)
if not agentscope_messages:
raise ValueError(
"No valid messages found in the messages list",
)
# Use the agentscope model to generate response (async call)
async def _async_call() -> ChatResponse:
# TODO: handle the streaming response or forbidden streaming
# mode
return await self.agentscope_model( # type: ignore
agentscope_messages,
tools=tools,
)
# Run in the persistent event loop
# This ensures the model client (e.g., Ollama AsyncClient)
# always runs in the same event loop, avoiding binding issues
response = _run_async_in_persistent_loop(
_async_call(),
)
has_tool = tools is not None
# Extract text from the response content blocks
if not response.content:
if has_tool:
return {
"content": "",
"tool_calls": [],
}
else:
return ""
return self._parse_response(response, has_tool)
except Exception as e:
raise RuntimeError(
f"Error generating response using agentscope model: {str(e)}",
) from e
class AgentScopeEmbedding(EmbeddingBase):
"""Wrapper for the AgentScope Embedding model.
This class is a wrapper for the AgentScope Embedding model. It is used
to generate embeddings using the AgentScope Embedding model in mem0.
"""
def __init__(self, config: BaseEmbedderConfig | None = None):
"""Initialize the AgentScopeEmbedding wrapper.
Args:
config (`BaseEmbedderConfig | None`, optional):
Configuration object for the embedder. Default is None.
"""
super().__init__(config)
if self.config.model is None:
raise ValueError("`model` parameter is required")
if not isinstance(self.config.model, EmbeddingModelBase):
raise ValueError(
"`model` must be an instance of EmbeddingModelBase",
)
self.agentscope_model = self.config.model
def embed(
self,
text: str | List[str],
memory_action: Literal[ # pylint: disable=unused-argument
"add",
"search",
"update",
]
| None = None,
) -> List[float]:
"""Get the embedding for the given text using AgentScope.
Args:
text (`str | List[str]`):
The text to embed.
memory_action (`Literal["add", "search", "update"] | None`, \
optional):
The type of embedding to use. Must be one of "add", "search",
or "update". Defaults to None.
Returns:
`List[float]`:
The embedding vector.
"""
try:
# Convert single text to list for AgentScope embedding model
text_list = [text] if isinstance(text, str) else text
# Use the agentscope model to generate embedding (async call)
async def _async_call() -> Any:
response = await self.agentscope_model(text_list)
return response
# Run in the persistent event loop
# This ensures the model client (e.g., Ollama AsyncClient)
# always runs in the same event loop, avoiding binding issues
response = _run_async_in_persistent_loop(
_async_call(),
)
# Extract the embedding vector from the first Embedding object
# response.embeddings is a list of Embedding objects
# Each Embedding object has an 'embedding' attribute containing
# the vector
embedding = response.embeddings[0]
if embedding is None:
raise ValueError("Failed to extract embedding from response")
return embedding
except Exception as e:
raise RuntimeError(
f"Error generating embedding using agentscope model: {str(e)}",
) from e

View File

@@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-
"""The reme memory module."""
from ._reme_personal_long_term_memory import ReMePersonalLongTermMemory
from ._reme_task_long_term_memory import ReMeTaskLongTermMemory
from ._reme_tool_long_term_memory import ReMeToolLongTermMemory
__all__ = [
"ReMePersonalLongTermMemory",
"ReMeTaskLongTermMemory",
"ReMeToolLongTermMemory",
]

View File

@@ -0,0 +1,364 @@
# -*- coding: utf-8 -*-
"""Base long-term memory implementation using ReMe library.
This module provides a base class for long-term memory implementations
that integrate with the ReMe library. ReMe enables agents to maintain
persistent, searchable memories across sessions and contexts.
The module handles the integration between AgentScope's memory system and
the ReMe library, including:
- Model configuration and API credential management
- Context lifecycle management (async context managers)
- Graceful handling of missing dependencies
- Error handling with helpful installation instructions
Key Features:
- Supports both DashScope and OpenAI model providers
- Automatic extraction of API credentials and endpoints
- Flexible configuration via config files or kwargs
- Safe fallback behavior when reme_ai is not installed
Dependencies:
The ReMe library is an optional dependency that must be installed:
.. code-block:: bash
pip install reme-ai
For more information, visit: https://github.com/modelscope/reMe
Subclasses:
This base class is extended by specific memory type implementations:
- ReMeToolLongTermMemory: For tool execution patterns and guidelines
- ReMeTaskLongTermMemory: For task execution experiences and learnings
- ReMePersonalLongTermMemory: For user preferences and personal information
Example:
.. code-block:: python
from agentscope.models import OpenAIChatModel
from agentscope.embedding import OpenAITextEmbedding
from agentscope.memory._reme import ReMeToolLongTermMemory
# Initialize models
model = OpenAIChatModel(model_name="gpt-4", api_key="...")
embedding = OpenAITextEmbedding(
model_name="text-embedding-3-small", api_key="...")
# Create memory instance
memory = ReMeToolLongTermMemory(
agent_name="my_agent",
user_name="user_123",
model=model,
embedding_model=embedding
)
# Use memory in async context
async with memory:
# Record tool execution
await memory.record_to_memory(
thinking="This tool worked well for data processing",
content=['{"tool_name": "process_data", "success": true, ...}']
)
# Retrieve tool guidelines
result = await memory.retrieve_from_memory(
keywords=["process_data"]
)
"""
from abc import ABCMeta
from typing import Any
from .._long_term_memory_base import LongTermMemoryBase
from ....embedding import DashScopeTextEmbedding, OpenAITextEmbedding
from ....model import DashScopeChatModel, OpenAIChatModel
class ReMeLongTermMemoryBase(LongTermMemoryBase, metaclass=ABCMeta):
"""Base class for ReMe-based long-term memory implementations.
This class provides the foundation for integrating AgentScope with the ReMe
library, enabling agents to maintain and retrieve long-term memories across
different contexts.
The ReMe library must be installed separately:
pip install reme-ai
If the library is not installed, a warning will be issued during
initialization,
and runtime errors with installation instructions will be raised
when attempting
to use memory operations.
"""
def __init__(
self,
agent_name: str | None = None,
user_name: str | None = None,
run_name: str | None = None,
model: DashScopeChatModel | OpenAIChatModel | None = None,
embedding_model: (
DashScopeTextEmbedding | OpenAITextEmbedding | None
) = None,
reme_config_path: str | None = None,
**kwargs: Any,
) -> None:
"""Initialize the ReMe-based long-term memory.
This constructor sets up the connection to the ReMe
library and configures
the necessary models for memory operations. The ReMe app
will be initialized
with the provided model configurations.
Args:
agent_name (`str | None`, optional):
Name identifier for the agent. Used for organizing
memories by agent.
user_name (`str | None`, optional):
Unique identifier for the user or workspace. This maps
to workspace_id in ReMe and helps isolate memories across
different users/workspaces.
run_name (`str | None`, optional):
Name identifier for the current execution run or session.
model (`DashScopeChatModel | OpenAIChatModel | None`, optional):
The chat model to use for memory operations. The model's
API credentials and endpoint will be extracted and
passed to ReMe.
embedding_model (`DashScopeTextEmbedding | OpenAITextEmbedding | \
None`, optional):
The embedding model to use for semantic memory retrieval.
The model's API credentials and endpoint will be
extracted and passed to ReMe.
reme_config_path (`str | None`, optional):
Path to a custom ReMe configuration file. If not provided, ReMe
will use its default configuration.
**kwargs (`Any`):
Additional keyword arguments to pass to the
ReMeApp constructor.
These can include custom ReMe configuration parameters.
Raises:
`ValueError`:
If the provided model is not a DashScopeChatModel or
OpenAIChatModel, or if the embedding_model is not a
DashScopeTextEmbedding or OpenAITextEmbedding.
Note:
If the reme_ai library is not installed, a warning will be
issued and self.app will be set to None. Subsequent memory
operations will raise RuntimeError with installation
instructions.
Example:
.. code-block:: python
from agentscope.models import OpenAIChatModel
from agentscope.embedding import OpenAITextEmbedding
from agentscope.memory._reme import ReMeToolLongTermMemory
# Initialize models
model = OpenAIChatModel(
model_name="gpt-4",
api_key="your-api-key"
)
embedding = OpenAITextEmbedding(
model_name="text-embedding-3-small",
api_key="your-api-key"
)
# Create memory instance
memory = ReMeToolLongTermMemory(
agent_name="my_agent",
user_name="user_123",
run_name="session_001",
model=model,
embedding_model=embedding
)
# Use with async context manager
async with memory:
# Memory operations...
pass
"""
super().__init__()
# Store agent and workspace identifiers
self.agent_name = agent_name
# Maps to ReMe's workspace_id concept
self.workspace_id = user_name
self.run_name = run_name
# Build configuration arguments for ReMeApp
# These will be passed as command-line style config overrides
config_args = []
# Extract LLM API credentials based on model type
# DashScope uses a fixed endpoint, OpenAI can have custom base_url
if isinstance(model, DashScopeChatModel):
llm_api_base = "https://dashscope.aliyuncs.com/compatible-mode/v1"
llm_api_key = model.api_key
elif isinstance(model, OpenAIChatModel):
llm_api_base = str(getattr(model.client, "base_url", None))
llm_api_key = str(getattr(model.client, "api_key", None))
else:
raise ValueError(
f"model must be a DashScopeChatModel or "
f"OpenAIChatModel instance. "
f"Got {type(model).__name__} instead.",
)
# Extract model name and add to config if provided
llm_model_name = model.model_name
if llm_model_name:
config_args.append(f"llm.default.model_name={llm_model_name}")
# Extract embedding model API credentials based on type
# Similar to LLM, DashScope uses fixed endpoint,
# OpenAI can be customized
if isinstance(embedding_model, DashScopeTextEmbedding):
embedding_api_base = (
"https://dashscope.aliyuncs.com/compatible-mode/v1"
)
embedding_api_key = embedding_model.api_key
elif isinstance(embedding_model, OpenAITextEmbedding):
embedding_api_base = str(
getattr(embedding_model.client, "base_url", None),
)
embedding_api_key = str(
getattr(embedding_model.client, "api_key", None),
)
else:
raise ValueError(
"embedding_model must be a DashScopeTextEmbedding or "
"OpenAITextEmbedding instance. "
f"Got {type(embedding_model).__name__} instead.",
)
# Extract embedding model name and add to config if provided
embedding_model_name = embedding_model.model_name
if embedding_model_name:
config_args.append(
f"embedding_model.default.model_name={embedding_model_name}",
)
dimensions = embedding_model.dimensions
config_args.append(
f'embedding_model.default.params={{"dimensions": {dimensions}}}',
)
# Attempt to import and initialize ReMe
# If import fails, set app to None and issue a warning
# This allows the class to be instantiated even without
# reme_ai installed
try:
from reme_ai import ReMeApp
except ImportError as e:
raise ImportError(
"The 'reme_ai' library is required for ReMe-based "
"long-term memory. Please install it by `pip install reme-ai`,"
"and visit: https://github.com/modelscope/reMe for more "
"information.",
) from e
# Initialize ReMe with extracted configurations
self.app = ReMeApp(
*config_args, # Config overrides as positional args
llm_api_key=llm_api_key,
llm_api_base=llm_api_base,
embedding_api_key=embedding_api_key,
embedding_api_base=embedding_api_base,
# Optional custom config file
config_path=reme_config_path,
# Additional ReMe-specific configurations
**kwargs,
)
# Track if the app context is active (started via __aenter__)
self._app_started = False
async def __aenter__(self) -> "ReMeLongTermMemoryBase":
"""Async context manager entry point.
This method is called when entering an async context
(using 'async with'). It initializes the ReMe app context if
available, enabling memory operations within the context block.
Returns:
`ReMeLongTermMemoryBase`:
The memory instance itself, allowing it to be used in
the context.
Example:
.. code-block:: python
memory = ReMeToolLongTermMemory(
agent_name="my_agent",
model=model,
embedding_model=embedding
)
async with memory:
# Memory operations can be performed here
await memory.record_to_memory(
thinking="Recording tool usage",
content=[...]
)
"""
if self.app is not None:
await self.app.__aenter__()
self._app_started = True
return self
async def __aexit__(
self,
exc_type: Any = None,
exc_val: Any = None,
exc_tb: Any = None,
) -> None:
"""Async context manager exit point.
This method is called when exiting an async context (at the end
of 'async with' block or when an exception occurs). It properly
cleans up the ReMe app context and resources.
Args:
exc_type (`Any`):
The type of exception that occurred, if any. None if no
exception.
exc_val (`Any`):
The exception instance that occurred, if any. None if no
exception.
exc_tb (`Any`):
The traceback object for the exception, if any. None if
no exception.
.. note:: This method will gracefully handle the case where self.app
is None (reme_ai not installed) by skipping the cleanup but still
marking the app as stopped. It will also always set _app_started
to False, ensuring the memory state is properly reset.
Example:
.. code-block:: python
async with memory:
try:
# Memory operations
await memory.record_to_memory(...)
except Exception as e:
# __aexit__ will be called even if an exception occurs
print(f"Error: {e}")
# __aexit__ has been called and resources are cleaned up
"""
if self.app is not None:
await self.app.__aexit__(exc_type, exc_val, exc_tb)
self._app_started = False

View File

@@ -0,0 +1,414 @@
# -*- coding: utf-8 -*-
"""Personal memory implementation using ReMe library.
This module provides a personal memory implementation that integrates
with the ReMe library to provide persistent personal memory storage and
retrieval capabilities for AgentScope agents.
"""
from typing import Any
from ._reme_long_term_memory_base import ReMeLongTermMemoryBase
from ...._logging import logger
from ....message import Msg, TextBlock
from ....tool import ToolResponse
class ReMePersonalLongTermMemory(ReMeLongTermMemoryBase):
"""Personal memory implementation using ReMe library."""
async def record_to_memory(
self,
thinking: str,
content: list[str],
**kwargs: Any,
) -> ToolResponse:
"""Record important user information to long-term memory.
Record important user information to long-term memory for future
reference.
Use this function to save user's personal information,
preferences, habits, and facts that you may need in future
conversations. This enables you to provide personalized and
contextually relevant responses.
When to record:
- User shares personal preferences (e.g., "I prefer homestays
when traveling")
- User mentions habits or routines (e.g., "I start work at 9 AM")
- User states likes/dislikes (e.g., "I enjoy drinking green tea")
- User provides personal facts (e.g., "I work as a software
engineer")
What to record: Be specific and structured. Include who, when,
where, what, why, and how when relevant.
Args:
thinking (`str`):
Your reasoning about why this information is worth
recording and how it might be useful later.
content (`list[str]`):
List of specific facts to remember. Each string should be
a clear, standalone piece of information. Examples:
["User prefers homestays in Hangzhou", "User likes
visiting West Lake in the morning"].
**kwargs (`Any`):
Additional keyword arguments for the recording operation.
Returns:
`ToolResponse`:
Confirmation message indicating successful memory
recording.
"""
logger.info(
"[ReMePersonalMemory] Entering record_to_memory - "
"thinking: %s, content: %s, kwargs: %s",
thinking,
content,
kwargs,
)
if not self._app_started:
raise RuntimeError(
"ReMeApp context not started. "
"Please use 'async with' to initialize the app.",
)
try:
# Prepare messages for personal memory recording
messages = []
# Add thinking as a user message if provided
if thinking:
messages.append(
{
"role": "user",
"content": thinking,
},
)
# Add content items as user messages
for item in content:
messages.append(
{
"role": "user",
"content": item,
},
)
# Add a simple assistant acknowledgment
messages.append(
{
"role": "assistant",
"content": (
"I understand and will remember this "
"information."
),
},
)
result = await self.app.async_execute(
name="summary_personal_memory",
workspace_id=self.workspace_id,
trajectories=[
{
"messages": messages,
},
],
**kwargs,
)
# Extract metadata about stored memories if available
metadata = result.get("metadata", {})
memory_list = metadata.get("memory_list", [])
if memory_list:
summary_text = (
f"Successfully recorded {len(memory_list)} "
f"memory/memories to personal memory."
)
else:
summary_text = "Memory recording completed."
return ToolResponse(
content=[
TextBlock(
type="text",
text=summary_text,
),
],
metadata={"result": result},
)
except Exception as e:
logger.exception("Error recording memory: %s", str(e))
return ToolResponse(
content=[
TextBlock(
type="text",
text=f"Error recording memory: {str(e)}",
),
],
)
async def retrieve_from_memory(
self,
keywords: list[str],
limit: int = 5,
**kwargs: Any,
) -> ToolResponse:
"""Search and retrieve relevant information from long-term memory.
.. note:: You should call this function BEFORE answering
questions about the user's preferences, past information, or
personal details. This ensures you provide accurate information
based on stored memories rather than guessing.
Use this when:
- User asks "what do I like?", "what are my preferences?",
"what do you know about me?"
- User asks about their past behaviors, habits, or stated
preferences
- User refers to information they shared in previous
conversations
- You need to personalize responses based on user's history
Args:
keywords (`list[str]`):
Keywords to search for in memory. Be specific and use
multiple keywords for better results. Examples:
["travel preferences", "Hangzhou"], ["work habits",
"morning routine"], ["food preferences", "tea"].
limit (`int`, optional):
The maximum number of memories to retrieve per search, i.e.,
the number of memories to retrieve for each keyword. Defaults
to 3.
**kwargs (`Any`):
Additional keyword arguments for the retrieval operation.
Returns:
`ToolResponse`:
Retrieved memories matching the keywords. If no memories
found, you'll receive a message indicating that.
"""
logger.info(
"[ReMePersonalMemory] Entering retrieve_from_memory - "
"keywords: %s, kwargs: %s",
keywords,
kwargs,
)
if not self._app_started:
raise RuntimeError(
"ReMeApp context not started. "
"Please use 'async with' to initialize the app.",
)
try:
results = []
# Search for each keyword
for keyword in keywords:
result = await self.app.async_execute(
name="retrieve_personal_memory",
workspace_id=self.workspace_id,
query=keyword,
top_k=limit,
**kwargs,
)
# Extract the answer from the result
answer = result.get("answer", "")
if answer:
results.append(f"Keyword '{keyword}':\n{answer}")
# Combine all results
if results:
combined_text = "\n\n".join(results)
else:
combined_text = "No memories found for the given keywords."
return ToolResponse(
content=[
TextBlock(
type="text",
text=combined_text,
),
],
)
except Exception as e:
logger.exception("Error retrieving memory: %s", str(e))
return ToolResponse(
content=[
TextBlock(
type="text",
text=f"Error retrieving memory: {str(e)}",
),
],
)
async def record(
self,
msgs: list[Msg | None],
**kwargs: Any,
) -> None:
"""Record the content to the long-term memory.
This method converts AgentScope messages to ReMe's format and
records them using the personal memory flow.
Args:
msgs (`list[Msg | None]`):
The messages to record to memory.
**kwargs (`Any`):
Additional keyword arguments for the mem0 recording.
"""
if isinstance(msgs, Msg):
msgs = [msgs]
# Filter out None
msg_list = [_ for _ in msgs if _]
if not msg_list:
return
if not all(isinstance(_, Msg) for _ in msg_list):
raise TypeError(
"The input messages must be a list of Msg objects.",
)
if not self._app_started:
raise RuntimeError(
"ReMeApp context not started. "
"Please use 'async with' to initialize the app.",
)
try:
# Convert AgentScope messages to ReMe format
messages = []
for msg in msg_list:
# Extract content as string
if isinstance(msg.content, str):
content_str = msg.content
elif isinstance(msg.content, list):
# Join content blocks into a single string
content_parts = []
for block in msg.content:
if isinstance(block, dict) and "text" in block:
content_parts.append(block["text"])
elif isinstance(block, dict) and "thinking" in block:
content_parts.append(block["thinking"])
content_str = "\n".join(content_parts)
else:
content_str = str(msg.content)
messages.append(
{
"role": msg.role,
"content": content_str,
},
)
await self.app.async_execute(
name="summary_personal_memory",
workspace_id=self.workspace_id,
trajectories=[
{
"messages": messages,
},
],
**kwargs,
)
except Exception as e:
# Log the error but don't raise to maintain compatibility
logger.exception("Error recording messages to memory: %s", str(e))
import warnings
warnings.warn(f"Error recording messages to memory: {str(e)}")
async def retrieve(
self,
msg: Msg | list[Msg] | None,
limit: int = 5,
**kwargs: Any,
) -> str:
"""Retrieve the content from the long-term memory.
Args:
msg (`Msg | list[Msg] | None`):
The message to search for in the memory, which should be
specific and concise, e.g. the person's name, the date, the
location, etc.
limit (`int`, optional):
The maximum number of memories to retrieve per search, i.e.,
the number of memories to retrieve for the message. If the
message is a list of messages, the limit applies to each
message. If the message is a single message, the limit is the
total number of memories to retrieve for that message. Defaults
to 5.
**kwargs (`Any`):
Additional keyword arguments.
Returns:
`str`:
The retrieved memory as a string.
"""
if msg is None:
return ""
if isinstance(msg, Msg):
msg = [msg]
if not isinstance(msg, list) or not all(
isinstance(_, Msg) for _ in msg
):
raise TypeError(
"The input message must be a Msg or a list of Msg objects.",
)
if not self._app_started:
raise RuntimeError(
"ReMeApp context not started. "
"Please use 'async with' to initialize the app.",
)
try:
# Only use the last message's content for retrieval
last_msg = msg[-1]
query = ""
if isinstance(last_msg.content, str):
query = last_msg.content
elif isinstance(last_msg.content, list):
# Extract text from content blocks
content_parts = []
for block in last_msg.content:
if isinstance(block, dict) and "text" in block:
content_parts.append(block["text"])
elif isinstance(block, dict) and "thinking" in block:
content_parts.append(block["thinking"])
query = "\n".join(content_parts)
if not query:
return ""
# Retrieve using the query from the last message
result = await self.app.async_execute(
name="retrieve_personal_memory",
workspace_id=self.workspace_id,
query=query,
top_k=limit,
**kwargs,
)
return result.get("answer", "")
except Exception as e:
logger.exception("Error retrieving memory: %s", str(e))
import warnings
warnings.warn(f"Error retrieving memory: {str(e)}")
return ""

View File

@@ -0,0 +1,436 @@
# -*- coding: utf-8 -*-
"""Task memory implementation using ReMe library.
This module provides a task memory implementation that integrates
with the ReMe library to learn from execution trajectories and
retrieve relevant task experiences.
"""
from typing import Any
from ._reme_long_term_memory_base import ReMeLongTermMemoryBase
from ...._logging import logger
from ....message import Msg, TextBlock
from ....tool import ToolResponse
class ReMeTaskLongTermMemory(ReMeLongTermMemoryBase):
"""Task memory implementation using ReMe library.
Task memory learns from execution trajectories and provides
retrieval of relevant task experiences.
"""
async def record_to_memory(
self,
thinking: str,
content: list[str],
**kwargs: Any,
) -> ToolResponse:
"""Record task execution experiences and learnings.
Record task execution experiences and learnings to long-term
memory.
Use this function to save valuable task-related knowledge that
can help with future similar tasks. This enables learning from
experience and improving over time.
When to record:
- After solving technical problems or completing tasks
- When discovering useful techniques or approaches
- After implementing solutions with specific steps
- When learning best practices or important lessons
What to record: Be detailed and actionable. Include:
- Task description and context
- Step-by-step execution details
- Specific techniques and methods used
- Results, outcomes, and effectiveness
- Lessons learned and considerations
Args:
thinking (`str`):
Your reasoning about why this task experience is valuable
and what makes it worth remembering for future reference.
content (`list[str]`):
List of specific task insights to remember. Each string
should be a clear, actionable piece of information.
Examples: ["Add indexes on WHERE clause columns to speed
up queries", "Use EXPLAIN ANALYZE to identify missing
indexes"].
**kwargs (`Any`):
Additional keyword arguments. Can include 'score' (float)
to indicate the quality/success of this approach
(default: 1.0).
Returns:
`ToolResponse`:
Confirmation message indicating successful memory
recording.
"""
logger.info(
"[ReMeTaskMemory] Entering record_to_memory - "
"thinking: %s, content: %s, kwargs: %s",
thinking,
content,
kwargs,
)
if not self._app_started:
raise RuntimeError(
"ReMeApp context not started. "
"Please use 'async with' to initialize the app.",
)
try:
# Prepare messages for task memory recording
messages = []
# Add thinking as a user message if provided
if thinking:
messages.append(
{
"role": "user",
"content": thinking,
},
)
# Add content items as user-assistant pairs
for item in content:
messages.append(
{
"role": "user",
"content": item,
},
)
# Add a simple assistant acknowledgment
messages.append(
{
"role": "assistant",
"content": "Task information recorded.",
},
)
result = await self.app.async_execute(
name="summary_task_memory",
workspace_id=self.workspace_id,
trajectories=[
{
"messages": messages,
"score": kwargs.pop("score", 1.0),
},
],
**kwargs,
)
# Extract metadata if available
summary_text = (
f"Successfully recorded {len(content)} task memory/memories."
)
return ToolResponse(
content=[
TextBlock(
type="text",
text=summary_text,
),
],
metadata={"result": result},
)
except Exception as e:
logger.exception("Error recording task memory: %s", str(e))
return ToolResponse(
content=[
TextBlock(
type="text",
text=f"Error recording task memory: {str(e)}",
),
],
)
async def retrieve_from_memory(
self,
keywords: list[str],
limit: int = 5,
**kwargs: Any,
) -> ToolResponse:
"""Search and retrieve relevant task experiences.
Search and retrieve relevant task experiences from long-term
memory.
IMPORTANT: You should call this function BEFORE attempting to
solve problems or answer technical questions. This ensures you
leverage experiences and proven solutions rather than
starting from scratch.
Use this when:
- Asked to solve a technical problem or implement a solution
- Asked for recommendations, best practices, or approaches
- Asked "what do you know about...?" or "have you seen this
before?"
- Dealing with tasks that may be similar to experiences
- Need to recall specific techniques or methods
Benefits of retrieving first:
- Learn from past successes and mistakes
- Provide more accurate, battle-tested solutions
- Avoid reinventing the wheel
- Give consistent, informed recommendations
Args:
keywords (`list[str]`):
Keywords describing the task or problem domain. Be
specific and use technical terms. Examples:
["database optimization", "slow queries"], ["API design",
"rate limiting"], ["code refactoring", "Python"].
limit (`int`, optional):
The maximum number of memories to retrieve per search, i.e.,
the number of memories to retrieve for each keyword. Defaults
to 5.
**kwargs (`Any`):
Additional keyword arguments. Can include 'top_k' (int)
to specify number of experiences to retrieve
(default: 3).
Returns:
`ToolResponse`:
Retrieved task experiences and learnings. If no relevant
experiences found, you'll receive a message indicating
that.
"""
logger.info(
"[ReMeTaskMemory] Entering retrieve_from_memory - "
"keywords: %s, kwargs: %s",
keywords,
kwargs,
)
if not self._app_started:
raise RuntimeError(
"ReMeApp context not started. "
"Please use 'async with' to initialize the app.",
)
try:
results = []
# Search for each keyword
for keyword in keywords:
result = await self.app.async_execute(
name="retrieve_task_memory",
workspace_id=self.workspace_id,
query=keyword,
top_k=limit,
**kwargs,
)
# Extract the answer from the result
answer = result.get("answer", "")
if answer:
results.append(f"Keyword '{keyword}':\n{answer}")
# Combine all results
if results:
combined_text = "\n\n".join(results)
else:
combined_text = (
"No task experiences found for the given keywords."
)
return ToolResponse(
content=[
TextBlock(
type="text",
text=combined_text,
),
],
)
except Exception as e:
logger.exception("Error retrieving task memory: %s", str(e))
return ToolResponse(
content=[
TextBlock(
type="text",
text=f"Error retrieving task memory: {str(e)}",
),
],
)
async def record(
self,
msgs: list[Msg | None],
**kwargs: Any,
) -> None:
"""Record the content to the task memory.
This method converts AgentScope messages to ReMe's format and
records them as a task execution trajectory.
Args:
msgs (`list[Msg | None]`):
The messages to record to memory.
**kwargs (`Any`):
Additional keyword arguments for the recording.
Can include 'score' (float) for trajectory scoring
(default: 1.0).
"""
if isinstance(msgs, Msg):
msgs = [msgs]
# Filter out None
msg_list = [_ for _ in msgs if _]
if not msg_list:
return
if not all(isinstance(_, Msg) for _ in msg_list):
raise TypeError(
"The input messages must be a list of Msg objects.",
)
if not self._app_started:
raise RuntimeError(
"ReMeApp context not started. "
"Please use 'async with' to initialize the app.",
)
try:
# Convert AgentScope messages to ReMe format
messages = []
for msg in msg_list:
# Extract content as string
if isinstance(msg.content, str):
content_str = msg.content
elif isinstance(msg.content, list):
# Join content blocks into a single string
content_parts = []
for block in msg.content:
if isinstance(block, dict) and "text" in block:
content_parts.append(block["text"])
elif isinstance(block, dict) and "thinking" in block:
content_parts.append(block["thinking"])
content_str = "\n".join(content_parts)
else:
content_str = str(msg.content)
messages.append(
{
"role": msg.role,
"content": content_str,
},
)
# Extract score from kwargs if provided, default to 1.0
score = kwargs.pop("score", 1.0)
await self.app.async_execute(
name="summary_task_memory",
workspace_id=self.workspace_id,
trajectories=[
{
"messages": messages,
"score": score,
},
],
**kwargs,
)
except Exception as e:
# Log the error but don't raise to maintain compatibility
logger.exception(
"Error recording messages to task memory: %s",
str(e),
)
import warnings
warnings.warn(
f"Error recording messages to task memory: {str(e)}",
)
async def retrieve(
self,
msg: Msg | list[Msg] | None,
limit: int = 5,
**kwargs: Any,
) -> str:
"""Retrieve relevant task experiences from memory.
Args:
msg (`Msg | list[Msg] | None`):
The message to search for relevant task experiences.
limit (`int`, optional):
The maximum number of memories to retrieve per search, i.e.,
the number of memories to retrieve for the message. If the
message is a list of messages, the limit applies to each
message. If the message is a single message, the limit is the
total number of memories to retrieve for that message. Defaults
to 3.
**kwargs (`Any`):
Additional keyword arguments.
Returns:
`str`:
The retrieved task experiences as a string.
"""
if msg is None:
return ""
if isinstance(msg, Msg):
msg = [msg]
if not isinstance(msg, list) or not all(
isinstance(_, Msg) for _ in msg
):
raise TypeError(
"The input message must be a Msg or a list of Msg objects.",
)
if not self._app_started:
raise RuntimeError(
"ReMeApp context not started. "
"Please use 'async with' to initialize the app.",
)
try:
# Only use the last message's content for retrieval
last_msg = msg[-1]
query = ""
if isinstance(last_msg.content, str):
query = last_msg.content
elif isinstance(last_msg.content, list):
# Extract text from content blocks
content_parts = []
for block in last_msg.content:
if isinstance(block, dict) and "text" in block:
content_parts.append(block["text"])
elif isinstance(block, dict) and "thinking" in block:
content_parts.append(block["thinking"])
query = "\n".join(content_parts)
if not query:
return ""
# Retrieve using the query from the last message
result = await self.app.async_execute(
name="retrieve_task_memory",
workspace_id=self.workspace_id,
query=query,
top_k=limit,
**kwargs,
)
return result.get("answer", "")
except Exception as e:
logger.exception("Error retrieving task memory: %s", str(e))
import warnings
warnings.warn(f"Error retrieving task memory: {str(e)}")
return ""

View File

@@ -0,0 +1,545 @@
# -*- coding: utf-8 -*-
"""Tool memory implementation using ReMe library.
This module provides a tool memory implementation that integrates
with the ReMe library to record tool execution results and retrieve
tool usage guidelines.
"""
from typing import Any
from ._reme_long_term_memory_base import ReMeLongTermMemoryBase
from ...._logging import logger
from ....message import Msg, TextBlock
from ....tool import ToolResponse
class ReMeToolLongTermMemory(ReMeLongTermMemoryBase):
"""Tool memory implementation using ReMe library.
Tool memory records tool execution results and generates usage
guidelines from the execution history.
"""
async def record_to_memory(
self,
thinking: str,
content: list[str],
**kwargs: Any,
) -> ToolResponse:
"""Record tool execution results to build tool usage patterns.
Record tool execution results to build a knowledge base of tool
usage patterns.
Use this function after successfully using tools to capture
execution details, results, and performance metrics. Over time,
this builds comprehensive usage guidelines and best practices
for each tool.
When to record:
- After successfully executing any tool
- After tool failures (to learn what doesn't work)
- When discovering effective parameter combinations
- After noteworthy tool usage patterns
What to record: Each tool execution should include complete
execution details.
Args:
thinking (`str`):
Your reasoning about why this tool execution is worth
recording. Mention what worked well, what could be
improved, or lessons learned.
content (`list[str]`):
List of JSON strings, each representing a tool execution.
Each JSON must have these fields:
- create_time: Timestamp in format "YYYY-MM-DD HH:MM:SS"
- tool_name: Name of the tool executed
- input: Input parameters as a dict
- output: Tool's output as a string
- token_cost: Token cost (integer)
- success: Whether execution succeeded (boolean)
- time_cost: Execution time in seconds (float)
Example: '{"create_time": "2024-01-01 10:00:00",
"tool_name": "search", "input": {"query": "Python"},
"output": "Found 10 results", "token_cost": 100,
"success": true, "time_cost": 1.2}'
**kwargs (`Any`):
Additional keyword arguments for the recording operation.
Returns:
`ToolResponse`:
Confirmation message with number of executions recorded
and guidelines generated.
"""
logger.info(
"[ReMeToolMemory] Entering record_to_memory - "
"thinking: %s, content: %s, kwargs: %s",
thinking,
content,
kwargs,
)
if not self._app_started:
raise RuntimeError(
"ReMeApp context not started. "
"Please use 'async with' to initialize the app.",
)
try:
import json
# Parse each content item as a tool_call_result
tool_call_results = []
tool_names_set = set()
for item in content:
try:
# Parse JSON string to dict
tool_call_result = json.loads(item)
tool_call_results.append(tool_call_result)
# Track tool names for summary
if "tool_name" in tool_call_result:
tool_names_set.add(tool_call_result["tool_name"])
except json.JSONDecodeError as e:
# Skip invalid JSON items
import warnings
warnings.warn(
f"Failed to parse tool call result JSON: {item}. "
f"Error: {str(e)}",
)
continue
if not tool_call_results:
return ToolResponse(
content=[
TextBlock(
type="text",
text="No valid tool call results to record.",
),
],
)
# First, add the tool call results
await self.app.async_execute(
name="add_tool_call_result",
workspace_id=self.workspace_id,
tool_call_results=tool_call_results,
**kwargs,
)
# Then, summarize the tool memory for the affected tools
if tool_names_set:
tool_names_list = list(tool_names_set)
await self.app.async_execute(
name="summary_tool_memory",
workspace_id=self.workspace_id,
tool_names=tool_names_list,
**kwargs,
)
num_results = len(tool_call_results)
summary_text = (
f"Successfully recorded {num_results} tool execution "
f"result{'s' if num_results > 1 else ''} and generated "
f"usage guidelines."
)
return ToolResponse(
content=[
TextBlock(
type="text",
text=summary_text,
),
],
)
except Exception as e:
logger.exception("Error recording tool memory: %s", str(e))
return ToolResponse(
content=[
TextBlock(
type="text",
text=f"Error recording tool memory: {str(e)}",
),
],
)
async def retrieve_from_memory(
self,
keywords: list[str],
limit: int = 5,
**kwargs: Any,
) -> ToolResponse:
"""Retrieve usage guidelines and best practices for tools.
Retrieve usage guidelines and best practices for specific tools.
.. note:: You should call this function BEFORE using a tool,
especially if you're uncertain about its proper usage or want to
follow established best practices. This retrieves synthesized
guidelines based on past tool executions.
Use this when:
- About to use a tool and want to know the best practices
- Uncertain about tool parameters or usage patterns
- Want to learn from past successful/failed tool executions
- User asks "how should I use this tool?" or "what's the best
way to..."
- Need to understand tool performance characteristics or
limitations
Benefits of retrieving first:
- Learn from accumulated tool usage experience
- Avoid common mistakes and pitfalls
- Use optimal parameter combinations
- Understand tool performance and cost characteristics
- Follow established best practices
Args:
keywords (`list[str]`):
List of tool names to retrieve guidelines for. Use the
exact tool names. Examples: ["search"],
["database_query", "cache_get"], ["api_call"].
limit (`int`, optional):
The maximum number of memories to retrieve per search, i.e.,
the number of memories to retrieve for each keyword. Defaults
to 5.
**kwargs (`Any`):
Additional keyword arguments for the retrieval operation.
Returns:
`ToolResponse`:
Retrieved usage guidelines and best practices for the
specified tools. If no guidelines exist yet, you'll
receive a message indicating that.
"""
logger.info(
"[ReMeToolMemory] Entering retrieve_from_memory - "
"keywords: %s, kwargs: %s",
keywords,
kwargs,
)
if not self._app_started:
raise RuntimeError(
"ReMeApp context not started. "
"Please use 'async with' to initialize the app.",
)
try:
# Join all tool names with comma
tool_names = ",".join(keywords)
# Retrieve tool guidelines for all tools at once
result = await self.app.async_execute(
name="retrieve_tool_memory",
workspace_id=self.workspace_id,
tool_names=tool_names,
top_k=limit,
**kwargs,
)
# Extract the answer from the result
answer = result.get("answer", "")
if answer:
combined_text = answer
else:
combined_text = f"No tool guidelines found for: {tool_names}"
return ToolResponse(
content=[
TextBlock(
type="text",
text=combined_text,
),
],
)
except Exception as e:
logger.exception("Error retrieving tool memory: %s", str(e))
return ToolResponse(
content=[
TextBlock(
type="text",
text=f"Error retrieving tool memory: {str(e)}",
),
],
)
def _extract_content_from_messages(self, msg_list: list[Msg]) -> list[str]:
"""Extract content strings from messages.
Args:
msg_list (`list[Msg]`):
List of messages to extract content from.
Returns:
`list[str]`:
List of extracted content strings.
"""
content_list = []
for msg in msg_list:
if isinstance(msg.content, str):
content_list.append(msg.content)
elif isinstance(msg.content, list):
content_list.extend(
self._extract_text_from_blocks(msg.content),
)
return content_list
def _extract_text_from_blocks(self, blocks: list) -> list[str]:
"""Extract text from content blocks.
Args:
blocks (`list`):
List of content blocks.
Returns:
`list[str]`:
List of extracted text strings.
"""
texts = []
for block in blocks:
if isinstance(block, dict) and block.get("type") == "text":
texts.append(block.get("text", ""))
elif isinstance(block, str):
texts.append(block)
return texts
def _parse_tool_call_results(
self,
content_list: list[str],
) -> tuple[list[dict], set[str]]:
"""Parse JSON content strings into tool call results.
Args:
content_list (`list[str]`):
List of JSON strings to parse.
Returns:
`tuple[list[dict], set[str]]`:
Tuple of (tool_call_results, tool_names_set).
"""
import json
import warnings
tool_call_results = []
tool_names_set = set()
for item in content_list:
try:
tool_call_result = json.loads(item)
tool_call_results.append(tool_call_result)
if "tool_name" in tool_call_result:
tool_names_set.add(tool_call_result["tool_name"])
except json.JSONDecodeError as e:
warnings.warn(
f"Failed to parse tool call result JSON: {item}. "
f"Error: {str(e)}",
)
return tool_call_results, tool_names_set
async def record(
self,
msgs: list[Msg | None],
**kwargs: Any,
) -> None:
"""Record the content to the tool memory.
This method extracts content from messages and treats them as
JSON strings representing tool_call_results, similar to
record_to_memory.
Args:
msgs (`list[Msg | None]`):
The messages to record to memory. Each message's content
should be a JSON string or list of JSON strings
representing tool_call_results.
**kwargs (`Any`):
Additional keyword arguments for the recording.
"""
if isinstance(msgs, Msg):
msgs = [msgs]
# Filter out None
msg_list = [_ for _ in msgs if _]
if not msg_list:
return
if not all(isinstance(_, Msg) for _ in msg_list):
raise TypeError(
"The input messages must be a list of Msg objects.",
)
if not self._app_started:
raise RuntimeError(
"ReMeApp context not started. "
"Please use 'async with' to initialize the app.",
)
try:
# Extract content from messages and parse as tool_call_results
content_list = self._extract_content_from_messages(msg_list)
if not content_list:
return
# Parse each content item as a tool_call_result
tool_call_results, tool_names_set = self._parse_tool_call_results(
content_list,
)
if not tool_call_results:
return
# First, add the tool call results
await self.app.async_execute(
name="add_tool_call_result",
workspace_id=self.workspace_id,
tool_call_results=tool_call_results,
**kwargs,
)
# Then, summarize the tool memory for the affected tools
if tool_names_set:
tool_names_list = list(tool_names_set)
await self.app.async_execute(
name="summary_tool_memory",
workspace_id=self.workspace_id,
tool_names=tool_names_list,
**kwargs,
)
except Exception as e:
# Log the error but don't raise to maintain compatibility
logger.exception(
"Error recording tool messages to memory: %s",
str(e),
)
import warnings
warnings.warn(
f"Error recording tool messages to memory: {str(e)}",
)
def _extract_tool_names_from_message(self, msg: Msg) -> str:
"""Extract tool names from a message.
Args:
msg (`Msg`):
Message to extract tool names from.
Returns:
`str`:
Extracted tool names as a string.
"""
if isinstance(msg.content, str):
return msg.content
if isinstance(msg.content, list):
content_parts = []
for block in msg.content:
if isinstance(block, dict) and "text" in block:
content_parts.append(block["text"])
return " ".join(content_parts)
return ""
def _format_retrieve_result(self, result: Any) -> str:
"""Format the retrieve result into a string.
Args:
result (`Any`):
Result from the retrieve operation.
Returns:
`str`:
Formatted result string.
"""
if isinstance(result, dict) and "answer" in result:
return result["answer"]
if isinstance(result, str):
return result
return str(result)
async def retrieve(
self,
msg: Msg | list[Msg] | None,
limit: int = 5,
**kwargs: Any,
) -> str:
"""Retrieve tool guidelines from memory.
Retrieve tool guidelines from memory based on message content.
Args:
msg (`Msg | list[Msg] | None`):
The message containing tool names or queries to
retrieve guidelines for.
limit (`int`, optional):
The maximum number of memories to retrieve per search, i.e.,
the number of memories to retrieve for the message. If the
message is a list of messages, the limit applies to each
message. If the message is a single message, the limit is the
total number of memories to retrieve for that message. Defaults
to 5.
**kwargs (`Any`):
Additional keyword arguments.
Returns:
`str`:
The retrieved tool guidelines as a string.
"""
if msg is None:
return ""
if isinstance(msg, Msg):
msg = [msg]
if not isinstance(msg, list) or not all(
isinstance(_, Msg) for _ in msg
):
raise TypeError(
"The input message must be a Msg or a list of Msg objects.",
)
if not self._app_started:
raise RuntimeError(
"ReMeApp context not started. "
"Please use 'async with' to initialize the app.",
)
try:
# Extract tool names from the last message
last_msg = msg[-1]
tool_names = self._extract_tool_names_from_message(last_msg)
if not tool_names:
return ""
# Retrieve tool guidelines
result = await self.app.async_execute(
name="retrieve_tool_memory",
workspace_id=self.workspace_id,
tool_names=tool_names,
top_k=limit,
**kwargs,
)
return self._format_retrieve_result(result)
except Exception as e:
logger.exception("Error retrieving tool guidelines: %s", str(e))
import warnings
warnings.warn(f"Error retrieving tool guidelines: {str(e)}")
return ""

View File

@@ -0,0 +1,16 @@
# -*- coding: utf-8 -*-
"""The working memory module in AgentScope, which provides various memory
storage implementations. In AgentScope, such module is responsible for
storing and managing the short-term memory with specific marks."""
from ._base import MemoryBase
from ._in_memory_memory import InMemoryMemory
from ._redis_memory import RedisMemory
from ._sqlalchemy_memory import AsyncSQLAlchemyMemory
__all__ = [
"MemoryBase",
"InMemoryMemory",
"RedisMemory",
"AsyncSQLAlchemyMemory",
]

View File

@@ -0,0 +1,168 @@
# -*- coding: utf-8 -*-
"""The memory base class."""
from abc import abstractmethod
from typing import Any
from ...message import Msg
from ...module import StateModule
class MemoryBase(StateModule):
"""The base class for memory in agentscope."""
def __init__(self) -> None:
"""Initialize the memory base."""
super().__init__()
self._compressed_summary: str = ""
self.register_state("_compressed_summary")
async def update_compressed_summary(self, summary: str) -> None:
"""Update the compressed summary of the memory.
Args:
summary (`str`):
The new compressed summary.
"""
self._compressed_summary = summary
@abstractmethod
async def add(
self,
memories: Msg | list[Msg] | None,
marks: str | list[str] | None = None,
**kwargs: Any,
) -> None:
"""Add message(s) into the memory storage with the given mark
(if provided).
Args:
memories (`Msg | list[Msg] | None`):
The message(s) to be added.
marks (`str | list[str] | None`, optional):
The mark(s) to associate with the message(s). If `None`, no
mark is associated.
"""
@abstractmethod
async def delete(
self,
msg_ids: list[str],
**kwargs: Any,
) -> int:
"""Remove message(s) from the storage by their IDs.
Args:
msg_ids (`list[str]`):
The list of message IDs to be removed.
Returns:
`int`:
The number of messages removed.
"""
async def delete_by_mark(
self,
mark: str | list[str],
*args: Any,
**kwargs: Any,
) -> int:
"""Remove messages from the memory by their marks.
Args:
mark (`str | list[str]`):
The mark(s) of the messages to be removed.
Raises:
`TypeError`:
If the provided mark is not a string or a list of strings.
Returns:
`int`:
The number of messages removed.
"""
raise NotImplementedError(
"The delete_by_mark method is not implemented in "
f"{self.__class__.__name__} class.",
)
@abstractmethod
async def size(self) -> int:
"""Get the number of messages in the storage.
Returns:
`int`:
The number of messages in the storage.
"""
@abstractmethod
async def clear(self) -> None:
"""Clear the memory content."""
@abstractmethod
async def get_memory(
self,
mark: str | None = None,
exclude_mark: str | None = None,
prepend_summary: bool = True,
**kwargs: Any,
) -> list[Msg]:
"""Get the messages from the memory by mark (if provided). Otherwise,
get all messages.
.. note:: If `mark` and `exclude_mark` are both provided, the messages
will be filtered by both arguments.
.. note:: `mark` and `exclude_mark` should not overlap.
Args:
mark (`str | None`, optional):
The mark to filter messages. If `None`, return all messages.
exclude_mark (`str | None`, optional):
The mark to exclude messages. If provided, messages with
this mark will be excluded from the results.
prepend_summary (`bool`, defaults to True):
Whether to prepend the compressed summary as a message
Returns:
`list[Msg]`:
The list of messages retrieved from the storage.
"""
async def update_messages_mark(
self,
new_mark: str | None,
old_mark: str | None = None,
msg_ids: list[str] | None = None,
) -> int:
"""A unified method to update marks of messages in the storage (add,
remove, or change marks).
- If `msg_ids` is provided, the update will be applied to the messages
with the specified IDs.
- If `old_mark` is provided, the update will be applied to the
messages with the specified old mark. Otherwise, the `new_mark` will
be added to all messages (or those filtered by `msg_ids`).
- If `new_mark` is `None`, the mark will be removed from the messages.
Args:
new_mark (`str | None`, optional):
The new mark to set for the messages. If `None`, the mark
will be removed.
old_mark (`str | None`, optional):
The old mark to filter messages. If `None`, this constraint
is ignored.
msg_ids (`list[str] | None`, optional):
The list of message IDs to be updated. If `None`, this
constraint is ignored.
Returns:
`int`:
The number of messages updated.
"""
raise NotImplementedError(
"The update_messages_mark method is not implemented in "
f"{self.__class__.__name__} class.",
)

View File

@@ -0,0 +1,305 @@
# -*- coding: utf-8 -*-
"""The in-memory storage module for memory storage."""
from copy import deepcopy
from typing import Any
from ...message import Msg
from ._base import MemoryBase
class InMemoryMemory(MemoryBase):
"""The in-memory implementation of memory storage."""
def __init__(self) -> None:
"""Initialize the in-memory storage."""
super().__init__()
# Use a list of tuples to store messages along with their marks
self.content: list[tuple[Msg, list[str]]] = []
# Register the state for serialization
self.register_state("content")
async def get_memory(
self,
mark: str | None = None,
exclude_mark: str | None = None,
prepend_summary: bool = True,
**kwargs: Any,
) -> list[Msg]:
"""Get the messages from the memory by mark (if provided). Otherwise,
get all messages.
.. note:: If `mark` and `exclude_mark` are both provided, the messages
will be filtered by both arguments.
.. note:: `mark` and `exclude_mark` should not overlap.
Args:
mark (`str | None`, optional):
The mark to filter messages. If `None`, return all messages.
exclude_mark (`str | None`, optional):
The mark to exclude messages. If provided, messages with
this mark will be excluded from the results.
prepend_summary (`bool`, defaults to True):
Whether to prepend the compressed summary as a message
Raises:
`TypeError`:
If the provided mark is not a string or None.
Returns:
`list[Msg]`:
The list of messages retrieved from the storage.
"""
# Type checks
if not (mark is None or isinstance(mark, str)):
raise TypeError(
f"The mark should be a string or None, but got {type(mark)}.",
)
if not (exclude_mark is None or isinstance(exclude_mark, str)):
raise TypeError(
f"The exclude_mark should be a string or None, but got "
f"{type(exclude_mark)}.",
)
# Filter messages based on mark
filtered_content = [
(msg, marks)
for msg, marks in self.content
if mark is None or mark in marks
]
# Further filter messages based on exclude_mark
if exclude_mark is not None:
filtered_content = [
(msg, marks)
for msg, marks in filtered_content
if exclude_mark not in marks
]
if prepend_summary and self._compressed_summary:
return [
Msg(
"user",
self._compressed_summary,
"user",
),
*[msg for msg, _ in filtered_content],
]
return [msg for msg, _ in filtered_content]
async def add(
self,
memories: Msg | list[Msg] | None,
marks: str | list[str] | None = None,
allow_duplicates: bool = False,
**kwargs: Any,
) -> None:
"""Add message(s) into the memory storage with the given mark
(if provided).
Args:
memories (`Msg | list[Msg] | None`):
The message(s) to be added.
marks (`str | list[str] | None`, optional):
The mark(s) to associate with the message(s). If `None`, no
mark is associated.
allow_duplicates (`bool`, defaults to `False`):
Whether to allow duplicate messages in the storage.
"""
if memories is None:
return
if isinstance(memories, Msg):
memories = [memories]
if marks is None:
marks = []
elif isinstance(marks, str):
marks = [marks]
elif not isinstance(marks, list) or not all(
isinstance(m, str) for m in marks
):
raise TypeError(
f"The mark should be a string, a list of strings, or None, "
f"but got {type(marks)}.",
)
if not allow_duplicates:
existing_ids = {msg.id for msg, _ in self.content}
memories = [msg for msg in memories if msg.id not in existing_ids]
for msg in memories:
self.content.append((deepcopy(msg), deepcopy(marks)))
async def delete(
self,
msg_ids: list[str],
**kwargs: Any,
) -> int:
"""Remove message(s) from the storage by their IDs.
Args:
msg_ids (`list[str]`):
The list of message IDs to be removed.
Returns:
`int`:
The number of messages removed.
"""
initial_size = len(self.content)
self.content = [
(msg, marks)
for msg, marks in self.content
if msg.id not in msg_ids
]
return initial_size - len(self.content)
async def delete_by_mark(
self,
mark: str | list[str],
**kwargs: Any,
) -> int:
"""Remove messages from the memory by their marks.
Args:
mark (`str | list[str]`):
The mark(s) of the messages to be removed.
Raises:
`TypeError`:
If the provided mark is not a string or a list of strings.
Returns:
`int`:
The number of messages removed.
"""
if isinstance(mark, str):
mark = [mark]
if isinstance(mark, list) and not all(
isinstance(m, str) for m in mark
):
raise TypeError(
f"The mark should be a string or a list of strings, "
f"but got {type(mark)} with elements of types "
f"{[type(m) for m in mark]}.",
)
initial_size = len(self.content)
for m in mark:
self.content = [
(msg, marks) for msg, marks in self.content if m not in marks
]
return initial_size - len(self.content)
async def clear(self) -> None:
"""Clear all messages from the storage."""
self.content.clear()
async def size(self) -> int:
"""Get the number of messages in the storage.
Returns:
`int`:
The number of messages in the storage.
"""
return len(self.content)
async def update_messages_mark(
self,
new_mark: str | None,
old_mark: str | None = None,
msg_ids: list[str] | None = None,
) -> int:
"""A unified method to update marks of messages in the storage (add,
remove, or change marks).
- If `msg_ids` is provided, the update will be applied to the messages
with the specified IDs.
- If `old_mark` is provided, the update will be applied to the
messages with the specified old mark. Otherwise, the `new_mark` will
be added to all messages (or those filtered by `msg_ids`).
- If `new_mark` is `None`, the mark will be removed from the messages.
Args:
new_mark (`str | None`, optional):
The new mark to set for the messages. If `None`, the mark
will be removed.
old_mark (`str | None`, optional):
The old mark to filter messages. If `None`, this constraint
is ignored.
msg_ids (`list[str] | None`, optional):
The list of message IDs to be updated. If `None`, this
constraint is ignored.
Returns:
`int`:
The number of messages updated.
"""
updated_count = 0
for idx, (msg, marks) in enumerate(self.content):
# If msg_ids is provided, skip messages not in the list
if msg_ids is not None and msg.id not in msg_ids:
continue
# If old_mark is provided, skip messages that do not have the old
# mark
if old_mark is not None and old_mark not in marks:
continue
# If new_mark is None, remove the old_mark
if new_mark is None:
if old_mark in marks:
marks.remove(old_mark)
updated_count += 1
else:
# If new_mark is provided, add or replace the old_mark
if old_mark is not None and old_mark in marks:
marks.remove(old_mark)
if new_mark not in marks:
marks.append(new_mark)
updated_count += 1
self.content[idx] = (msg, marks)
return updated_count
def state_dict(self) -> dict:
"""Get the state dictionary for serialization."""
return {
**super().state_dict(),
"content": [[msg.to_dict(), marks] for msg, marks in self.content],
}
def load_state_dict(self, state_dict: dict, strict: bool = True) -> None:
"""Load the state dictionary for deserialization."""
if strict and "content" not in state_dict:
raise KeyError(
"The state_dict does not contain 'content' "
"keys required for InMemoryMemory.",
)
self._compressed_summary = state_dict.get("_compressed_summary", "")
self.content = []
for item in state_dict.get("content", []):
if isinstance(item, (tuple, list)) and len(item) == 2:
msg_dict, marks = item
msg = Msg.from_dict(msg_dict)
self.content.append((msg, marks))
elif isinstance(item, dict):
# For compatibility with older versions
msg = Msg.from_dict(item)
self.content.append((msg, []))
else:
raise ValueError(
"Invalid item format in state_dict for InMemoryMemory.",
)

View File

@@ -0,0 +1,827 @@
# -*- coding: utf-8 -*-
"""The redis based memory storage implementation."""
import json
from typing import Any, TYPE_CHECKING
from ._base import MemoryBase
from ...message import Msg
if TYPE_CHECKING:
from redis.asyncio import ConnectionPool, Redis
else:
ConnectionPool = Any
Redis = Any
class RedisMemory(MemoryBase):
"""Redis memory storage implementation, which supports session and user
context.
.. note:: All the operations in this class are within a specific session
and user context, identified by `session_id` and `user_id`. Cross-session
or cross-user operations are not supported. For example, the
`remove_messages` method will only remove messages that belong to the
specified `session_id` and `user_id`.
.. note:: All Redis keys used by this class will be prefixed by `prefix`
(if provided) to support multi-tenant / multi-app isolation.
**Mark Index Storage:**
This class maintains a `marks_index` (Redis Set) to efficiently track all
mark names within a session. When a mark is created via `add_mark()`, the
mark name is added to this set. This allows quick retrieval of all marks
without scanning all Redis keys. The marks_index key pattern is:
``user_id:{user_id}:session:{session_id}:marks_index``
Each individual mark stores its associated message IDs in a separate Redis
List with the key pattern:
``user_id:{user_id}:session:{session_id}:mark:{mark}``
"""
SESSION_KEY = "user_id:{user_id}:session:{session_id}:messages"
"""Redis key pattern (without prefix) for storing message IDs (ordered) for
a specific session.
"""
SESSION_PATTERN = "user_id:{user_id}:session:{session_id}:*"
"""Redis key pattern (without prefix) for scanning all keys belong to
a specific user and session."""
MARK_KEY = "user_id:{user_id}:session:{session_id}:mark:{mark}"
"""Redis key pattern (without prefix) for storing message IDs that belong
to a specific mark.
"""
MESSAGE_KEY = "user_id:{user_id}:session:{session_id}:msg:{msg_id}"
"""Redis key pattern (without prefix) for storing message payload data."""
MARKS_INDEX_KEY = "user_id:{user_id}:session:{session_id}:marks_index"
"""Redis key pattern (without prefix) for storing all mark names as a set.
This is used to avoid scanning all keys to find marks.
"""
def __init__(
self,
session_id: str = "default_session",
user_id: str = "default_user",
host: str = "localhost",
port: int = 6379,
db: int = 0,
password: str | None = None,
connection_pool: ConnectionPool | None = None,
key_prefix: str = "",
key_ttl: int | None = None,
**kwargs: Any,
) -> None:
"""Initialize the Redis based storage by connecting to the Redis
server. You can provide either the connection parameters or an
existing connection pool.
Args:
session_id (`str`, default to `"default_session"`):
The session ID for the storage.
user_id (`str`, default to `"default_user"`):
The user ID for the storage.
host (`str`, default to `"localhost"`):
The Redis server host.
port (`int`, default to `6379`):
The Redis server port.
db (`int`, default to `0`):
The Redis database index.
password (`str | None`, optional):
The password for the Redis server, if required.
connection_pool (`ConnectionPool | None`, optional):
An optional Redis connection pool. If provided, it will be used
instead of creating a new connection.
key_prefix (`str`, default to `""`):
Optional Redis key prefix prepended to every key generated by
this storage. Useful for isolating keys across
apps/environments (e.g. `"prod"`, `"staging"`, `"myapp"`).
key_ttl (`int | None`, default to `None`):
The expired time in seconds for each key. If provided, the
expiration will be refreshed on every access (sliding TTL). If
`None`, the keys will not expire. **Note** the ttl will update
all session related keys, so it's not recommended to set for
large sessions.
**kwargs (`Any`):
Additional keyword arguments to pass to the Redis client.
"""
try:
import redis.asyncio as redis
except ImportError as e:
raise ImportError(
"The 'redis' package is required for RedisStorage. "
"Please install it via 'pip install redis[async]'.",
) from e
super().__init__()
self.session_id = session_id
self.user_id = user_id
self.key_prefix = key_prefix or ""
self.key_ttl = key_ttl
self._client = redis.Redis(
host=host,
port=port,
db=db,
password=password,
connection_pool=connection_pool,
decode_responses=True,
**kwargs,
)
def get_client(self) -> Redis:
"""Get the underlying Redis client.
Returns:
`Redis`:
The Redis client instance.
"""
return self._client
def _decode_if_bytes(self, data: Any) -> Any:
"""Helper method to decode bytes to str if needed.
Args:
data (`Any`):
The data to decode, which may be bytes, bytearray, or str.
Returns:
`Any`:
The decoded string if input was bytes/bytearray, otherwise
the original data.
"""
if isinstance(data, (bytes, bytearray)):
return data.decode("utf-8")
return data
def _decode_list(self, data_list: list) -> list:
"""Helper method to decode a list of potential bytes.
Args:
data_list (`list`):
A list that may contain bytes, bytearray, or str elements.
Returns:
`list`:
A list with all bytes/bytearray elements decoded to str.
"""
return [self._decode_if_bytes(item) for item in data_list]
def _get_session_key(self) -> str:
"""Get the Redis key for the current session.
Returns:
`str`:
The Redis key for storing messages in the current session.
"""
return self.key_prefix + self.SESSION_KEY.format(
user_id=self.user_id,
session_id=self.session_id,
)
def _get_session_pattern(self) -> str:
"""Get the Redis key pattern for all keys in the current session.
Returns:
`str`:
The Redis key pattern for all keys in the current session.
"""
return self.key_prefix + self.SESSION_PATTERN.format(
user_id=self.user_id,
session_id=self.session_id,
)
def _get_mark_key(self, mark: str) -> str:
"""Get the Redis key for a specific mark.
Args:
mark (`str`):
The mark name.
Returns:
`str`:
The Redis key for storing message IDs with the given mark.
"""
return self.key_prefix + self.MARK_KEY.format(
user_id=self.user_id,
session_id=self.session_id,
mark=mark,
)
def _get_mark_pattern(self) -> str:
"""Get the Redis key pattern for all marks in the current session.
Returns:
`str`:
The Redis key pattern for all mark keys.
"""
return self.key_prefix + self.MARK_KEY.format(
user_id=self.user_id,
session_id=self.session_id,
mark="*",
)
def _get_marks_index_key(self) -> str:
"""Get the Redis key for the marks index set.
Returns:
`str`:
The Redis key for storing all mark names as a set.
"""
return self.key_prefix + self.MARKS_INDEX_KEY.format(
user_id=self.user_id,
session_id=self.session_id,
)
def _extract_mark_from_key(self, mark_key: str) -> str:
"""Extract the mark name from a full mark key.
Args:
mark_key (`str`):
The full Redis key for a mark.
Returns:
`str`:
The mark name extracted from the key.
"""
# Remove the prefix and the base pattern to get the mark name
# Example: "prefix:user_id:xxx:session:yyy:mark:my_mark" -> "my_mark"
prefix_pattern = self.key_prefix + self.MARK_KEY.format(
user_id=self.user_id,
session_id=self.session_id,
mark="",
)
return mark_key.replace(prefix_pattern, "")
def _get_message_key(self, msg_id: str) -> str:
"""Get the Redis key for a specific message.
Args:
msg_id (`str`):
The message ID.
Returns:
`str`:
The Redis key for storing the message data.
"""
return self.key_prefix + self.MESSAGE_KEY.format(
user_id=self.user_id,
session_id=self.session_id,
msg_id=msg_id,
)
async def _refresh_session_ttl(
self,
pipe: Any | None = None,
) -> None:
"""Refresh the TTL for the session keys (if `key_ttl` is set).
Args:
pipe (`Any | None`, optional):
An optional Redis pipeline to use. If `None`, a new pipeline
will be created and executed immediately. If provided, the
expired commands will be added to the pipeline without
executing it.
"""
if self.key_ttl is None:
return
# Create a new pipeline if not provided
should_execute = pipe is None
if pipe is None:
pipe = self._client.pipeline()
cursor = 0
while True:
cursor, keys = await self._client.scan(
cursor,
match=self._get_session_pattern(),
count=100,
)
# Decode keys if they are bytes
keys = self._decode_list(keys)
for key in keys:
await pipe.expire(key, self.key_ttl)
if cursor == 0:
break
if should_execute:
await pipe.execute()
async def _scan_and_migrate_marks(self) -> list[str]:
"""Scan all mark keys and migrate them to the marks index.
This method is only called once for old data that doesn't have
a marks index yet. After migration, the marks index will be
maintained automatically.
Returns:
`list[str]`:
The list of all mark keys found.
"""
mark_keys = []
cursor = 0
while True:
cursor, keys = await self._client.scan(
cursor,
match=self._get_mark_pattern(),
count=50,
)
keys = self._decode_list(keys)
mark_keys.extend(keys)
if cursor == 0:
break
# Build the marks index
if mark_keys:
pipe = self._client.pipeline()
for mark_key in mark_keys:
mark = self._extract_mark_from_key(mark_key)
await pipe.sadd(self._get_marks_index_key(), mark)
await pipe.execute()
return mark_keys
async def _get_all_mark_keys(self) -> list[str]:
"""Get all mark keys, compatible with both old and new data.
For new data (with marks index), this method uses the index directly.
For old data (without marks index), this method scans once and
migrates to the new structure.
Returns:
`list[str]`:
The list of all mark keys.
"""
marks_index_key = self._get_marks_index_key()
# Try to read from the index first
marks = await self._client.smembers(marks_index_key)
if marks:
# Index exists, use it
marks = self._decode_list(list(marks))
return [self._get_mark_key(mark) for mark in marks]
# Index doesn't exist, check if this is a new session
session_exists = await self._client.exists(self._get_session_key())
if not session_exists:
# New session, no data at all, return empty
return []
# Old session without index, need to scan and migrate (only once)
mark_keys = await self._scan_and_migrate_marks()
return mark_keys
async def get_memory(
self,
mark: str | None = None,
exclude_mark: str | None = None,
prepend_summary: bool = True,
**kwargs: Any,
) -> list[Msg]:
"""Get the messages from the memory by mark (if provided). Otherwise,
get all messages.
.. note:: If `mark` and `exclude_mark` are both provided, the messages
will be filtered by both arguments.
.. note:: `mark` and `exclude_mark` should not overlap.
Args:
mark (`str | None`, optional):
The mark to filter messages. If `None`, return all messages.
exclude_mark (`str | None`, optional):
The mark to exclude messages. If provided, messages with
this mark will be excluded from the results.
prepend_summary (`bool`, defaults to True):
Whether to prepend the compressed summary as a message
Returns:
`list[Msg]`:
The list of messages retrieved from the storage.
"""
# Type checks
if not (mark is None or isinstance(mark, str)):
raise TypeError(
f"The mark should be a string or None, but got {type(mark)}.",
)
if not (exclude_mark is None or isinstance(exclude_mark, str)):
raise TypeError(
f"The exclude_mark should be a string or None, but got "
f"{type(exclude_mark)}.",
)
if mark is None:
# Obtain the message IDs from the session list
msg_ids = await self._client.lrange(self._get_session_key(), 0, -1)
else:
# Obtain the message IDs from the mark list
msg_ids = await self._client.lrange(
self._get_mark_key(mark),
0,
-1,
)
msg_ids = self._decode_list(msg_ids)
# Exclude messages by exclude_mark
if exclude_mark:
exclude_msg_ids = await self._client.lrange(
self._get_mark_key(exclude_mark),
0,
-1,
)
exclude_msg_ids = self._decode_list(exclude_msg_ids)
msg_ids = [_ for _ in msg_ids if _ not in exclude_msg_ids]
# Use mget for batch retrieval to avoid N+1 queries
messages: list[Msg] = []
if msg_ids:
msg_keys = [self._get_message_key(msg_id) for msg_id in msg_ids]
msg_data_list = await self._client.mget(msg_keys)
for msg_data in msg_data_list:
if msg_data is not None:
# Decode if bytes
msg_data = self._decode_if_bytes(msg_data)
msg_dict = json.loads(msg_data)
messages.append(Msg.from_dict(msg_dict))
# Refresh TTLs
await self._refresh_session_ttl()
if prepend_summary and self._compressed_summary:
return [
Msg(
"user",
self._compressed_summary,
"user",
),
*messages,
]
return messages
async def add(
self,
memories: Msg | list[Msg] | None,
marks: str | list[str] | None = None,
skip_duplicated: bool = True,
**kwargs: Any,
) -> None:
"""Add message into the storage with the given mark (if provided).
Args:
memories (`Msg | list[Msg]`):
The message(s) to be added.
marks (`str | list[str] | None`, optional):
The mark(s) to associate with the message(s). If `None`, no
mark is associated.
skip_duplicated (`bool`, defaults to `True`):
If `True`, skip messages with duplicate IDs that already exist
in the storage. If `False`, allow duplicate message IDs to be
added to the session list (though the message data will be
overwritten).
"""
if memories is None:
return
if isinstance(memories, Msg):
memories = [memories]
# Normalize marks to a list
if marks is None:
mark_list = []
elif isinstance(marks, str):
mark_list = [marks]
else:
mark_list = marks
# Filter out existing messages if skip_duplicated is True
messages_to_add = memories
if skip_duplicated:
# Get all existing message IDs in the current session
existing_msg_ids = await self._client.lrange(
self._get_session_key(),
0,
-1,
)
existing_msg_ids = self._decode_list(existing_msg_ids)
existing_msg_ids_set = set(existing_msg_ids)
# Filter out messages that already exist
messages_to_add = [
m for m in memories if m.id not in existing_msg_ids_set
]
# If all messages are duplicates, return early
if not messages_to_add:
return
# Use pipeline for atomic operations
pipe = self._client.pipeline()
# Push message ids into the session list
if messages_to_add:
await pipe.rpush(
self._get_session_key(),
*[m.id for m in messages_to_add],
)
# Store message data and marks
for m in messages_to_add:
# Record the message data
await pipe.set(
self._get_message_key(m.id),
json.dumps(m.to_dict(), ensure_ascii=False),
)
# Record the marks if provided
for mark in mark_list:
await pipe.rpush(self._get_mark_key(mark), m.id)
# Maintain the marks index
await pipe.sadd(self._get_marks_index_key(), mark)
# Refresh TTLs
await self._refresh_session_ttl(pipe=pipe)
await pipe.execute()
async def delete(
self,
msg_ids: list[str],
**kwargs: Any,
) -> int:
"""Remove message(s) from the storage by their IDs.
Args:
msg_ids (`list[str]`):
The list of message IDs to be removed.
Returns:
`int`:
The number of messages removed.
"""
if not msg_ids:
return 0
# Get all mark keys using the new method (compatible with old data)
mark_keys = await self._get_all_mark_keys()
pipe = self._client.pipeline()
for msg_id in msg_ids:
# Remove from the session (0 means remove all occurrences)
await pipe.lrem(self._get_session_key(), 0, msg_id)
# Remove the message data
await pipe.delete(self._get_message_key(msg_id))
# Remove from all marks
for mark_key in mark_keys:
await pipe.lrem(mark_key, 0, msg_id)
# Refresh TTLs
await self._refresh_session_ttl(pipe=pipe)
results = await pipe.execute()
# Count actual deletions from lrem results (every 3rd result
# starting from 0)
removed_count = sum(
1
for i in range(
0,
len(msg_ids) * (2 + len(mark_keys)),
2 + len(mark_keys),
)
if results[i] > 0
)
return removed_count
async def delete_by_mark(
self,
mark: str | list[str],
**kwargs: Any,
) -> int:
"""Remove messages from the storage by their marks.
Args:
mark (`str | list[str]`):
The mark(s) of the messages to be removed.
Returns:
`int`:
The number of messages removed.
"""
if isinstance(mark, str):
mark = [mark]
total_removed = 0
for m in mark:
mark_key = self._get_mark_key(m)
msg_ids = await self._client.lrange(mark_key, 0, -1)
msg_ids = self._decode_list(msg_ids)
if not msg_ids:
continue
# Remove messages by IDs
removed_count = await self.delete(
msg_ids,
)
total_removed += removed_count
# Delete the mark list
await self._client.delete(mark_key)
# Remove from the marks index
await self._client.srem(self._get_marks_index_key(), m)
# Refresh TTLs
await self._refresh_session_ttl()
return total_removed
async def clear(self) -> None:
"""Clear all messages belong to this session from the storage."""
msg_ids = await self._client.lrange(self._get_session_key(), 0, -1)
msg_ids = self._decode_list(msg_ids)
# Get all mark keys using the new method (compatible with old data)
mark_keys = await self._get_all_mark_keys()
pipe = self._client.pipeline()
for msg_id in msg_ids:
# Remove the message data
await pipe.delete(self._get_message_key(msg_id))
# Delete the session list
await pipe.delete(self._get_session_key())
# Delete all mark lists
for mark_key in mark_keys:
await pipe.delete(mark_key)
# Delete the marks index
await pipe.delete(self._get_marks_index_key())
await pipe.execute()
async def size(self) -> int:
"""Get the number of messages in the storage.
Returns:
`int`:
The number of messages in the storage.
"""
size = await self._client.llen(self._get_session_key())
await self._refresh_session_ttl()
return size
async def update_messages_mark(
self,
new_mark: str | None,
old_mark: str | None = None,
msg_ids: list[str] | None = None,
) -> int:
"""A unified method to update marks of messages in the storage (add,
remove, or change marks).
- If `msg_ids` is provided, the update will be applied to the messages
with the specified IDs.
- If `old_mark` is provided, the update will be applied to the
messages with the specified old mark. Otherwise, the `new_mark` will
be added to all messages (or those filtered by `msg_ids`).
- If `new_mark` is `None`, the mark will be removed from the messages.
Args:
new_mark (`str | None`, optional):
The new mark to set for the messages. If `None`, the mark
will be removed.
old_mark (`str | None`, optional):
The old mark to filter messages. If `None`, this constraint
is ignored.
msg_ids (`list[str] | None`, optional):
The list of message IDs to be updated. If `None`, this
constraint is ignored.
Returns:
`int`:
The number of messages updated.
"""
# Determine which message IDs to update
# Get source key based on old_mark
source_key = (
self._get_mark_key(old_mark)
if old_mark is not None
else self._get_session_key()
)
mark_msg_ids = await self._client.lrange(source_key, 0, -1)
mark_msg_ids = self._decode_list(mark_msg_ids)
# Check if we're removing all messages from old_mark
removing_all_from_old_mark = old_mark is not None and (
msg_ids is None or all(mid in set(msg_ids) for mid in mark_msg_ids)
)
# Filter by msg_ids if provided
if msg_ids is not None:
msg_ids_set = set(msg_ids)
mark_msg_ids = [mid for mid in mark_msg_ids if mid in msg_ids_set]
if not mark_msg_ids:
return 0
# Get existing IDs in new_mark list once (if needed)
existing_ids_set = set()
new_mark_key = None
if new_mark is not None:
new_mark_key = self._get_mark_key(new_mark)
existing_ids = await self._client.lrange(new_mark_key, 0, -1)
existing_ids = self._decode_list(existing_ids)
existing_ids_set = set(existing_ids)
# Use pipeline for batch operations
pipe = self._client.pipeline()
updated_count = 0
for msg_id in mark_msg_ids:
# Remove from old_mark list if applicable
if old_mark is not None:
await pipe.lrem(
self._get_mark_key(old_mark),
0,
msg_id,
)
# Add to new_mark list only if not already present
if new_mark is not None and msg_id not in existing_ids_set:
await pipe.rpush(new_mark_key, msg_id)
existing_ids_set.add(msg_id)
# Maintain the marks index
await pipe.sadd(self._get_marks_index_key(), new_mark)
# Count update only if we actually did something
if old_mark is not None or new_mark is not None:
updated_count += 1
# Clean up old_mark only if we removed ALL messages from it
if old_mark is not None and removing_all_from_old_mark:
old_mark_key = self._get_mark_key(old_mark)
# After lrem operations, the old mark list will be empty
# Delete the mark key and remove from index
await pipe.delete(old_mark_key)
await pipe.srem(self._get_marks_index_key(), old_mark)
await self._refresh_session_ttl(pipe=pipe)
await pipe.execute()
return updated_count
async def close(self, close_connection_pool: bool | None = None) -> None:
"""Close the Redis client connection.
Args:
close_connection_pool (`bool | None`, optional):
Decides whether to close the connection pool used by this
Redis client, overriding Redis.auto_close_connection_pool.
By default, let Redis.auto_close_connection_pool decide
whether to close the connection pool
"""
await self._client.aclose(close_connection_pool=close_connection_pool)
async def __aenter__(self) -> "RedisMemory":
"""Enter the async context manager.
Returns:
`RedisMemory`:
The memory instance itself.
"""
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: Any,
) -> None:
"""Exit the async context manager and close the session.
Args:
exc_type (`type[BaseException] | None`):
The exception type if an exception was raised.
exc_value (`BaseException | None`):
The exception instance if an exception was raised.
traceback (`Any`):
The traceback object if an exception was raised.
"""
await self.close()

View File

@@ -0,0 +1,873 @@
# -*- coding: utf-8 -*-
"""The SQLAlchemy database storage module, which supports storing messages in
a SQL database using SQLAlchemy ORM (e.g., SQLite, PostgreSQL, MySQL)."""
from typing import Any
from sqlalchemy import (
Column,
String,
JSON,
BigInteger,
ForeignKey,
select,
delete,
update,
func,
)
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
)
from sqlalchemy.orm import declarative_base, relationship
from ._base import MemoryBase
from ...message import Msg
Base: Any = declarative_base()
class AsyncSQLAlchemyMemory(MemoryBase):
"""The SQLAlchemy memory storage class for storing messages in a SQL
database using SQLAlchemy ORM, such as SQLite, PostgreSQL, MySQL, etc.
.. note:: All the operations in this class are within a specific session
and user context, identified by `session_id` and `user_id`. Cross-session
or cross-user operations are not supported. For example, the
`remove_messages` method will only remove messages that belong to the
specified `session_id` and `user_id`.
"""
class MessageTable(Base):
"""The default message table definition."""
__tablename__ = "message"
"""The table name"""
id = Column(String(255), primary_key=True)
"""The id column, we use the f"{user_id}-{session_id}-{message_id}"
as the primary key to ensure uniqueness across users and sessions."""
msg = Column(JSON, nullable=False)
"""The message JSON content column"""
session = relationship(
"SessionTable",
back_populates="messages",
)
"""The foreign key to the session id relationship"""
session_id = Column(
String(255),
ForeignKey("session.id"),
nullable=False,
)
"""The foreign key to the session id"""
index = Column(BigInteger, nullable=False, index=True)
"""The index column for ordering messages, so that we can retrieve
messages in the order they were added."""
class MessageMarkTable(Base):
"""The default message mark table definition."""
__tablename__ = "message_mark"
"""The table name"""
msg_id = Column(
String(255),
ForeignKey("message.id", ondelete="CASCADE"),
primary_key=True,
)
"""The message id column"""
mark = Column(String(255), primary_key=True)
"""The mark column"""
class SessionTable(Base):
"""The default session table definition."""
__tablename__ = "session"
"""The table name"""
id = Column(String(255), primary_key=True)
"""The session id column"""
user = relationship("UserTable", back_populates="sessions")
"""The foreign key to the user id relationship"""
user_id = Column(String(255), ForeignKey("users.id"), nullable=False)
"""The foreign key to the user id"""
messages = relationship("MessageTable", back_populates="session")
"""The relationship to messages"""
class UserTable(Base):
"""The default user table definition."""
__tablename__ = "users"
"""The table name"""
id = Column(String(255), primary_key=True)
"""The user id column"""
sessions = relationship("SessionTable", back_populates="user")
"""The relationship to sessions"""
def __init__(
self,
engine_or_session: AsyncEngine | AsyncSession,
session_id: str | None = None,
user_id: str | None = None,
) -> None:
"""Initialize the SqlAlchemyDBStorage with a SQLAlchemy session.
Args:
engine_or_session (`AsyncEngine | AsyncSession`):
The SQLAlchemy asynchronous engine or session to use for
database operations. If you're using a connection pool, maybe
you want to pass in an `AsyncSession` instance.
session_id (`str | None`, optional):
The session ID for the messages. If `None`, a default session
ID will be used.
user_id (`str | None`, optional):
The user ID for the messages. If `None`, a default user ID
will be used.
Raises:
`ValueError`:
If the `engine` parameter is not an instance of
`sqlalchemy.ext.asyncio.AsyncEngine` or `sqlalchemy.
ext.asyncio.AsyncSession`.
"""
super().__init__()
self._db_session: AsyncSession | None = None
if isinstance(engine_or_session, AsyncEngine):
self._session_factory = async_sessionmaker(
bind=engine_or_session,
expire_on_commit=False,
)
elif isinstance(engine_or_session, AsyncSession):
self._session_factory = None
self._db_session = engine_or_session
else:
raise ValueError(
"The 'engine_or_session' parameter must be an instance of "
"sqlalchemy.ext.asyncio.AsyncEngine.",
)
self.session_id = session_id or "default_session"
self.user_id = user_id or "default_user"
# Flag to track if tables and records have been initialized
self._initialized = False
def _make_message_id(self, msg_id: str) -> str:
"""Generate a composite primary key for a message.
Args:
msg_id (`str`):
The original message ID.
Returns:
`str`:
The composite primary key in the format
"{user_id}-{session_id}-{message_id}".
"""
return f"{self.user_id}-{self.session_id}-{msg_id}"
@property
def session(self) -> AsyncSession:
"""Get the current database session, creating one if it doesn't exist.
Returns:
`AsyncSession`:
The current database session.
Note:
- If an external session was provided, it will be returned as-is
- If using internal session factory, a new session will be created
if the current one is None or inactive, and _initialized flag
will be reset to ensure proper re-initialization
"""
# External session: return as-is (managed by caller)
if self._session_factory is None:
return self._db_session
# Internal session: check validity and recreate if needed
if self._db_session is None or not self._db_session.is_active:
self._db_session = self._session_factory()
# Reset initialized flag when creating new session
self._initialized = False
return self._db_session
async def _create_table(self) -> None:
"""Create tables in database."""
# Skip if already initialized
if self._initialized:
return
# Obtain the engine first
engine: AsyncEngine = self.session.bind
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# Track if we need to commit
needs_commit = False
# Create user record if not exists
result = await self.session.execute(
select(self.UserTable).filter(
self.UserTable.id == self.user_id,
),
)
user_record = result.scalar_one_or_none()
if user_record is None:
user_record = self.UserTable(
id=self.user_id,
)
self.session.add(user_record)
needs_commit = True
# Create session record if not exists
result = await self.session.execute(
select(self.SessionTable).filter(
self.SessionTable.id == self.session_id,
),
)
session_record = result.scalar_one_or_none()
if session_record is None:
session_record = self.SessionTable(
id=self.session_id,
user_id=self.user_id,
)
self.session.add(session_record)
needs_commit = True
# Commit once if any records were added
if needs_commit:
await self.session.commit()
# Mark as initialized
self._initialized = True
async def get_memory(
self,
mark: str | None = None,
exclude_mark: str | None = None,
prepend_summary: bool = True,
**kwargs: Any,
) -> list[Msg]:
"""Get the messages from the memory by mark (if provided). Otherwise,
get all messages.
.. note:: If `mark` and `exclude_mark` are both provided, the messages
will be filtered by both arguments.
.. note:: `mark` and `exclude_mark` should not overlap.
Args:
mark (`str | None`, optional):
The mark to filter messages. If `None`, return all messages.
exclude_mark (`str | None`, optional):
The mark to exclude messages. If provided, messages with
this mark will be excluded from the results.
prepend_summary (`bool`, defaults to True):
Whether to prepend the compressed summary as a message
Raises:
`TypeError`:
If the provided mark is not a string or None.
Returns:
`list[Msg]`:
The list of messages retrieved from the storage.
"""
# Type checks
if mark is not None and not isinstance(mark, str):
raise TypeError(
f"The mark should be a string or None, but got {type(mark)}.",
)
if exclude_mark is not None and not isinstance(exclude_mark, str):
raise TypeError(
f"The exclude_mark should be a string or None, but got "
f"{type(exclude_mark)}.",
)
await self._create_table()
# Step 1: First filter by session_id to narrow down the dataset
# This ensures the database uses the session_id index first
base_query = select(self.MessageTable).filter(
self.MessageTable.session_id == self.session_id,
)
# Step 2: Apply mark filtering if provided
if mark:
# Join with mark table only on the session-filtered messages
base_query = base_query.join(
self.MessageMarkTable,
self.MessageTable.id == self.MessageMarkTable.msg_id,
).filter(
self.MessageMarkTable.mark == mark,
)
# Step 3: Apply exclude_mark filtering if provided
if exclude_mark:
# Use a subquery to find message IDs with the exclude_mark
# within the current session only
exclude_subquery = (
select(self.MessageMarkTable.msg_id)
.filter(
self.MessageMarkTable.msg_id.in_(
select(self.MessageTable.id).filter(
self.MessageTable.session_id == self.session_id,
),
),
self.MessageMarkTable.mark == exclude_mark,
)
.scalar_subquery()
)
# Exclude messages whose IDs are in the subquery
base_query = base_query.filter(
self.MessageTable.id.notin_(exclude_subquery),
)
# Step 4: Order by index to maintain message order
query = base_query.order_by(self.MessageTable.index)
result = await self.session.execute(query)
results = result.scalars().all()
msgs = [Msg.from_dict(result.msg) for result in results]
if prepend_summary and self._compressed_summary:
return [
Msg(
"user",
self._compressed_summary,
"user",
),
*msgs,
]
return msgs
async def add(
self,
memories: Msg | list[Msg] | None,
marks: str | list[str] | None = None,
skip_duplicated: bool = True,
**kwargs: Any,
) -> None:
"""Add message into the storage with the given mark (if provided).
Args:
memories (`Msg | list[Msg] | None`):
The message(s) to be added.
marks (`str | list[str] | None`, optional):
The mark(s) to associate with the message(s). If `None`, no
mark is associated.
skip_duplicated (`bool`, defaults to `True`):
If `True`, skip messages with duplicate IDs that already exist
in the storage. If `False`, raise an `IntegrityError` when
attempting to add a message with an existing ID.
Raises:
`IntegrityError`:
If a message with the same ID already exists in the storage
and `skip_duplicated` is set to `False`.
"""
if memories is None:
return
# Type checking
if isinstance(memories, Msg):
memories = [memories]
elif not (
isinstance(memories, list)
and all(isinstance(_, Msg) for _ in memories)
):
raise TypeError(
"The 'memories' parameter must be a Msg instance or a list of "
f"Msg instances, but got {type(memories)}.",
)
if isinstance(marks, str):
marks = [marks]
elif marks is not None and not (
isinstance(marks, list) and all(isinstance(m, str) for m in marks)
):
raise TypeError(
"The 'marks' parameter must be a string or a list of strings, "
f"but got {type(marks)}.",
)
# Create table if not exists
await self._create_table()
# If skip_duplicated is True, filter out existing messages
messages_to_add = memories
if skip_duplicated:
existing_msg_ids = set()
result = await self.session.execute(
select(self.MessageTable.id).filter(
self.MessageTable.id.in_(
[self._make_message_id(m.id) for m in memories],
),
),
)
existing_msg_ids = {row[0] for row in result.fetchall()}
messages_to_add = [
m
for m in memories
if self._make_message_id(m.id) not in existing_msg_ids
]
# If all messages are duplicates, return early
if not messages_to_add:
return
# Get the starting index once to avoid race conditions
start_index = await self._get_next_index()
# Add messages to message table
for i, m in enumerate(messages_to_add):
message_record = self.MessageTable(
id=self._make_message_id(m.id),
msg=m.to_dict(),
session_id=self.session_id,
index=start_index + i,
)
self.session.add(message_record)
# Create mark records if marks are provided (use bulk insert)
if marks:
mark_records = [
{"msg_id": self._make_message_id(msg.id), "mark": mark}
for msg in messages_to_add
for mark in marks
]
if mark_records:
if skip_duplicated:
# Query existing mark combinations to avoid duplicates
result = await self.session.execute(
select(
self.MessageMarkTable.msg_id,
self.MessageMarkTable.mark,
),
)
existing_marks = {
(row[0], row[1]) for row in result.fetchall()
}
# Filter out existing mark combinations
mark_records = [
r
for r in mark_records
if (r["msg_id"], r["mark"]) not in existing_marks
]
if mark_records:
await self.session.run_sync(
lambda session: session.bulk_insert_mappings(
self.MessageMarkTable,
mark_records,
),
)
await self.session.commit()
async def _get_next_index(self) -> int:
"""Get the next index for a new message in the current session.
Returns:
`int`:
The next index value.
"""
result = await self.session.execute(
select(self.MessageTable.index)
.filter(self.MessageTable.session_id == self.session_id)
.order_by(self.MessageTable.index.desc())
.limit(1),
)
max_index = result.scalar_one_or_none()
return (max_index + 1) if max_index is not None else 0
async def size(self) -> int:
"""Get the size of the messages in the storage."""
result = await self.session.execute(
select(func.count(self.MessageTable.id)).filter(
self.MessageTable.session_id == self.session_id,
),
)
return result.scalar_one()
async def clear(self) -> None:
"""Clear all messages from the storage."""
# Delete all marks for messages in this session
await self.session.execute(
delete(self.MessageMarkTable).where(
self.MessageMarkTable.msg_id.in_(
select(self.MessageTable.id).filter(
self.MessageTable.session_id == self.session_id,
),
),
),
)
# Then delete all messages
await self.session.execute(
delete(self.MessageTable).filter(
self.MessageTable.session_id == self.session_id,
),
)
await self.session.commit()
async def delete_by_mark(
self,
mark: str | list[str],
**kwargs: Any,
) -> int:
"""Remove messages from the storage by their marks.
Args:
mark (`str | list[str]`):
The mark(s) of the messages to be removed.
Returns:
`int`:
The number of messages removed.
"""
if isinstance(mark, str):
mark = [mark]
# First, find message IDs that have the specified marks
query = (
select(self.MessageTable.id)
.join(
self.MessageMarkTable,
self.MessageTable.id == self.MessageMarkTable.msg_id,
)
.filter(
self.MessageTable.session_id == self.session_id,
self.MessageMarkTable.mark.in_(mark),
)
)
result = await self.session.execute(query)
msg_ids = [row[0] for row in result.all()]
if not msg_ids:
return 0
# Store the count before deletion
deleted_count = len(msg_ids)
# Delete marks first
await self.session.execute(
delete(self.MessageMarkTable).filter(
self.MessageMarkTable.msg_id.in_(msg_ids),
),
)
# Then delete the messages
await self.session.execute(
delete(self.MessageTable).filter(
self.MessageTable.session_id == self.session_id,
self.MessageTable.id.in_(msg_ids),
),
)
await self.session.commit()
return deleted_count
async def delete(
self,
msg_ids: list[str],
**kwargs: Any,
) -> int:
"""Remove message(s) from the storage by their IDs.
.. note:: Although MessageMarkTable has CASCADE delete on foreign key,
we explicitly delete marks first for reliability across all database
engines and configurations. SQLAlchemy's bulk delete bypasses
ORM-level cascades, and SQLite requires foreign keys to be
explicitly enabled.
Args:
msg_ids (`list[str]`):
The list of message IDs to be removed.
Returns:
`int`:
The number of messages removed.
"""
# Convert to composite keys
composite_ids = [self._make_message_id(msg_id) for msg_id in msg_ids]
if not composite_ids:
return 0
# Store the count before deletion
deleted_count = len(composite_ids)
# Delete related marks first (explicit cleanup for reliability)
await self.session.execute(
delete(self.MessageMarkTable).filter(
self.MessageMarkTable.msg_id.in_(composite_ids),
),
)
# Then delete the messages
await self.session.execute(
delete(self.MessageTable).filter(
self.MessageTable.session_id == self.session_id,
self.MessageTable.id.in_(composite_ids),
),
)
await self.session.commit()
return deleted_count
async def update_messages_mark(
self,
new_mark: str | None,
old_mark: str | None = None,
msg_ids: list[str] | None = None,
) -> int:
"""A unified method to update marks of messages in the storage (add,
remove, or change marks).
- If `msg_ids` is provided, the update will be applied to the messages
with the specified IDs.
- If `old_mark` is provided, the update will be applied to the
messages with the specified old mark. Otherwise, the `new_mark` will
be added to all messages (or those filtered by `msg_ids`).
- If `new_mark` is `None`, the mark will be removed from the messages.
Args:
new_mark (`str | None`, optional):
The new mark to set for the messages. If `None`, the mark
will be removed.
old_mark (`str | None`, optional):
The old mark to filter messages. If `None`, this constraint
is ignored.
msg_ids (`list[str] | None`, optional):
The list of message IDs to be updated. If `None`, this
constraint is ignored.
Returns:
`int`:
The number of messages updated.
"""
# Type checking
if new_mark is not None and not isinstance(new_mark, str):
raise ValueError(
f"The 'new_mark' parameter must be a string or None, "
f"but got {type(new_mark)}.",
)
if old_mark is not None and not isinstance(old_mark, str):
raise ValueError(
f"The 'old_mark' parameter must be a string or None, "
f"but got {type(old_mark)}.",
)
if msg_ids is not None and not (
isinstance(msg_ids, list)
and all(isinstance(_, str) for _ in msg_ids)
):
raise ValueError(
f"The 'msg_ids' parameter must be a list of strings or None, "
f"but got {type(msg_ids)}.",
)
# First obtain the message ids that belong to this session
query = select(self.MessageTable).filter(
self.MessageTable.session_id == self.session_id,
)
# Filter by msg_ids if provided
if msg_ids is not None:
# Convert to composite keys
composite_ids = [
self._make_message_id(msg_id) for msg_id in msg_ids
]
query = query.filter(self.MessageTable.id.in_(composite_ids))
# Filter by old_mark if provided
if old_mark is not None:
query = query.join(
self.MessageMarkTable,
self.MessageTable.id == self.MessageMarkTable.msg_id,
).filter(self.MessageMarkTable.mark == old_mark)
# Obtain the message records
result = await self.session.execute(query)
msg_ids = [str(_.id) for _ in result.scalars().all()]
# Return early if no messages found
if not msg_ids:
return 0
if new_mark:
if old_mark:
# Replace old_mark with new_mark
return await self._replace_message_mark(
msg_ids=msg_ids,
old_mark=old_mark,
new_mark=new_mark,
)
# Add new_mark to the messages
return await self._add_message_mark(
msg_ids=msg_ids,
mark=new_mark,
)
# Remove all marks from the messages
return await self._remove_message_mark(
msg_ids=msg_ids,
old_mark=old_mark,
)
async def _replace_message_mark(
self,
msg_ids: list[str],
old_mark: str,
new_mark: str,
) -> int:
"""Replace the old mark with the new mark for the given messages by
updating records in the message_mark table.
Args:
msg_ids (`list[str]`):
The list of message IDs to be updated.
old_mark (`str`):
The old mark to be replaced.
new_mark (`str`):
The new mark to be set.
Returns:
`int`:
The number of messages updated.
"""
await self.session.execute(
update(self.MessageMarkTable)
.filter(
self.MessageMarkTable.msg_id.in_(msg_ids),
self.MessageMarkTable.mark == old_mark,
)
.values(mark=new_mark),
)
await self.session.commit()
return len(msg_ids)
async def _add_message_mark(self, msg_ids: list[str], mark: str) -> int:
"""Mark the messages with the given mark by adding records to the
message_mark table.
Args:
msg_ids (`list[str]`):
The list of message IDs to be marked.
mark (`str`):
The mark to be added to the messages.
Returns:
`int`:
The number of messages marked.
"""
# Use bulk insert for better performance
mark_records = [{"msg_id": msg_id, "mark": mark} for msg_id in msg_ids]
if mark_records:
await self.session.run_sync(
lambda session: session.bulk_insert_mappings(
self.MessageMarkTable,
mark_records,
),
)
await self.session.commit()
return len(msg_ids)
async def _remove_message_mark(
self,
msg_ids: list[str],
old_mark: str | None,
) -> int:
"""Remove marks from the messages by deleting records from the
message_mark table.
Args:
msg_ids (`list[str]`):
The list of message IDs to be unmarked.
old_mark (`str | None`):
The old mark to be removed. If `None`, all marks will be
removed from the messages.
Returns:
`int`:
The number of messages unmarked.
"""
delete_query = delete(self.MessageMarkTable).filter(
self.MessageMarkTable.msg_id.in_(msg_ids),
)
if old_mark:
delete_query = delete_query.filter(
self.MessageMarkTable.mark == old_mark,
)
await self.session.execute(delete_query)
await self.session.commit()
return len(msg_ids)
async def close(self) -> None:
"""Close the database session."""
if self._db_session and self._db_session.is_active:
await self._db_session.close()
self._db_session = None
self._initialized = False
async def __aenter__(self) -> "AsyncSQLAlchemyMemory":
"""Enter the async context manager.
Returns:
`AsyncSQLAlchemyMemory`:
The memory instance itself.
"""
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: Any,
) -> None:
"""Exit the async context manager and close the session.
Args:
exc_type (`type[BaseException] | None`):
The exception type if an exception was raised.
exc_value (`BaseException | None`):
The exception instance if an exception was raised.
traceback (`Any`):
The traceback object if an exception was raised.
"""
await self.close()

View File

@@ -0,0 +1,31 @@
# -*- coding: utf-8 -*-
"""The message module in agentscope."""
from ._message_block import (
ContentBlock,
TextBlock,
ThinkingBlock,
ToolUseBlock,
ToolResultBlock,
ImageBlock,
AudioBlock,
VideoBlock,
Base64Source,
URLSource,
)
from ._message_base import Msg
__all__ = [
"TextBlock",
"ThinkingBlock",
"Base64Source",
"URLSource",
"ImageBlock",
"AudioBlock",
"VideoBlock",
"ToolUseBlock",
"ToolResultBlock",
"ContentBlock",
"Msg",
]

View File

@@ -0,0 +1,241 @@
# -*- coding: utf-8 -*-
"""The message class in agentscope."""
from datetime import datetime
from typing import Literal, List, overload, Sequence
import shortuuid
from ._message_block import (
TextBlock,
ToolUseBlock,
ImageBlock,
AudioBlock,
ContentBlock,
VideoBlock,
ToolResultBlock,
ContentBlockTypes,
)
from ..types import JSONSerializableObject
class Msg:
"""The message class in agentscope."""
def __init__(
self,
name: str,
content: str | Sequence[ContentBlock],
role: Literal["user", "assistant", "system"],
metadata: dict[str, JSONSerializableObject] | None = None,
timestamp: str | None = None,
invocation_id: str | None = None,
) -> None:
"""Initialize the Msg object.
Args:
name (`str`):
The name of the message sender.
content (`str | list[ContentBlock]`):
The content of the message.
role (`Literal["user", "assistant", "system"]`):
The role of the message sender.
metadata (`dict[str, JSONSerializableObject] | None`, optional):
The metadata of the message, e.g. structured output.
timestamp (`str | None`, optional):
The created timestamp of the message. If not given, the
timestamp will be set automatically.
invocation_id (`str | None`, optional):
The related API invocation id, if any. This is useful for
tracking the message in the context of an API call.
"""
self.name = name
assert isinstance(
content,
(list, str),
), "The content must be a string or a list of content blocks."
self.content = content
assert role in ["user", "assistant", "system"]
self.role = role
self.metadata = metadata or {}
self.id = shortuuid.uuid()
self.timestamp = (
timestamp
or datetime.now().strftime(
"%Y-%m-%d %H:%M:%S.%f",
)[:-3]
)
self.invocation_id = invocation_id
def to_dict(self) -> dict:
"""Convert the message into JSON dict data."""
return {
"id": self.id,
"name": self.name,
"role": self.role,
"content": self.content,
"metadata": self.metadata,
"timestamp": self.timestamp,
}
@classmethod
def from_dict(cls, json_data: dict) -> "Msg":
"""Load a message object from the given JSON data."""
new_obj = cls(
name=json_data["name"],
content=json_data["content"],
role=json_data["role"],
metadata=json_data.get("metadata", None),
timestamp=json_data.get("timestamp", None),
invocation_id=json_data.get("invocation_id", None),
)
new_obj.id = json_data.get("id", new_obj.id)
return new_obj
def has_content_blocks(
self,
block_type: Literal[
"text",
"tool_use",
"tool_result",
"image",
"audio",
"video",
]
| None = None,
) -> bool:
"""Check if the message has content blocks of the given type.
Args:
block_type (Literal["text", "tool_use", "tool_result", "image", \
"audio", "video"] | None, defaults to None):
The type of the block to be checked. If `None`, it will
check if there are any content blocks.
"""
return len(self.get_content_blocks(block_type)) > 0
def get_text_content(self, separator: str = "\n") -> str | None:
"""Get the pure text blocks from the message content.
Args:
separator (`str`, defaults to `\n`):
The separator to use when concatenating multiple text blocks.
Defaults to newline character.
Returns:
`str | None`:
The concatenated text content, or `None` if there is no text
content.
"""
if isinstance(self.content, str):
return self.content
gathered_text = []
for block in self.content:
if block.get("type") == "text":
gathered_text.append(block["text"])
if gathered_text:
return separator.join(gathered_text)
return None
@overload
def get_content_blocks(
self,
block_type: Literal["text"],
) -> Sequence[TextBlock]:
...
@overload
def get_content_blocks(
self,
block_type: Literal["tool_use"],
) -> Sequence[ToolUseBlock]:
...
@overload
def get_content_blocks(
self,
block_type: Literal["tool_result"],
) -> Sequence[ToolResultBlock]:
...
@overload
def get_content_blocks(
self,
block_type: Literal["image"],
) -> Sequence[ImageBlock]:
...
@overload
def get_content_blocks(
self,
block_type: Literal["audio"],
) -> Sequence[AudioBlock]:
...
@overload
def get_content_blocks(
self,
block_type: Literal["video"],
) -> Sequence[VideoBlock]:
...
@overload
def get_content_blocks(
self,
block_type: None = None,
) -> Sequence[ContentBlock]:
...
def get_content_blocks(
self,
block_type: ContentBlockTypes | List[ContentBlockTypes] | None = None,
) -> Sequence[ContentBlock]:
"""Get the content in block format. If the content is a string,
it will be converted to a text block.
Args:
block_type (`ContentBlockTypes | List[ContentBlockTypes] | None`, \
optional):
The type of the block to be extracted. If `None`, all blocks
will be returned.
Returns:
`List[ContentBlock]`:
The content blocks.
"""
blocks = []
if isinstance(self.content, str):
blocks.append(
TextBlock(type="text", text=self.content),
)
else:
blocks = self.content or []
if isinstance(block_type, str):
blocks = [_ for _ in blocks if _["type"] == block_type]
elif isinstance(block_type, list):
blocks = [_ for _ in blocks if _["type"] in block_type]
return blocks
def __repr__(self) -> str:
"""Get the string representation of the message."""
return (
f"Msg(id='{self.id}', "
f"name='{self.name}', "
f"content={repr(self.content)}, "
f"role='{self.role}', "
f"metadata={repr(self.metadata)}, "
f"timestamp='{self.timestamp}', "
f"invocation_id='{self.invocation_id}')"
)

View File

@@ -0,0 +1,128 @@
# -*- coding: utf-8 -*-
# pylint: disable=R0901
"""The content blocks of messages"""
from typing import Literal, List
from typing_extensions import TypedDict, Required
class TextBlock(TypedDict, total=False):
"""The text block."""
type: Required[Literal["text"]]
"""The type of the block"""
text: str
"""The text content"""
class ThinkingBlock(TypedDict, total=False):
"""The thinking block."""
type: Required[Literal["thinking"]]
"""The type of the block"""
thinking: str
class Base64Source(TypedDict, total=False):
"""The base64 source"""
type: Required[Literal["base64"]]
"""The type of the src, must be `base64`"""
media_type: Required[str]
"""The media type of the data, e.g. `image/jpeg` or `audio/mpeg`"""
data: Required[str]
"""The base64 data, in format of RFC 2397"""
class URLSource(TypedDict, total=False):
"""The URL source"""
type: Required[Literal["url"]]
"""The type of the src, must be `url`"""
url: Required[str]
"""The URL of the image or audio"""
class ImageBlock(TypedDict, total=False):
"""The image block"""
type: Required[Literal["image"]]
"""The type of the block"""
source: Required[Base64Source | URLSource]
"""The src of the image"""
class AudioBlock(TypedDict, total=False):
"""The audio block"""
type: Required[Literal["audio"]]
"""The type of the block"""
source: Required[Base64Source | URLSource]
"""The src of the audio"""
class VideoBlock(TypedDict, total=False):
"""The video block"""
type: Required[Literal["video"]]
"""The type of the block"""
source: Required[Base64Source | URLSource]
"""The src of the audio"""
class ToolUseBlock(TypedDict, total=False):
"""The tool use block."""
type: Required[Literal["tool_use"]]
"""The type of the block, must be `tool_use`"""
id: Required[str]
"""The identity of the tool call"""
name: Required[str]
"""The name of the tool"""
input: Required[dict[str, object]]
"""The input of the tool"""
raw_input: str
"""The raw string input of the tool from the model API"""
class ToolResultBlock(TypedDict, total=False):
"""The tool result block."""
type: Required[Literal["tool_result"]]
"""The type of the block"""
id: Required[str]
"""The identity of the tool call result"""
output: Required[
str | List[TextBlock | ImageBlock | AudioBlock | VideoBlock]
]
"""The output of the tool function"""
name: Required[str]
"""The name of the tool function"""
# The content block
ContentBlock = (
ToolUseBlock
| ToolResultBlock
| TextBlock
| ThinkingBlock
| ImageBlock
| AudioBlock
| VideoBlock
)
ContentBlockTypes = Literal[
"text",
"thinking",
"tool_use",
"tool_result",
"image",
"audio",
"video",
]

View File

@@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
"""The model module."""
from ._model_base import ChatModelBase
from ._model_response import ChatResponse
from ._dashscope_model import DashScopeChatModel
from ._openai_model import OpenAIChatModel
from ._anthropic_model import AnthropicChatModel
from ._ollama_model import OllamaChatModel
from ._gemini_model import GeminiChatModel
from ._trinity_model import TrinityChatModel
__all__ = [
"ChatModelBase",
"ChatResponse",
"DashScopeChatModel",
"OpenAIChatModel",
"AnthropicChatModel",
"OllamaChatModel",
"GeminiChatModel",
"TrinityChatModel",
]

View File

@@ -0,0 +1,590 @@
# -*- coding: utf-8 -*-
# pylint: disable=too-many-branches, too-many-statements
"""The Anthropic API model classes."""
import copy
import json
import warnings
from datetime import datetime
from typing import (
Any,
AsyncGenerator,
TYPE_CHECKING,
List,
Literal,
Type,
)
from collections import OrderedDict
from pydantic import BaseModel
from ._model_base import ChatModelBase
from ._model_response import ChatResponse
from ._model_usage import ChatUsage
from .._logging import logger
from .._utils._common import (
_json_loads_with_repair,
_create_tool_from_base_model,
)
from ..message import TextBlock, ToolUseBlock, ThinkingBlock
from ..tracing import trace_llm
from ..types._json import JSONSerializableObject
if TYPE_CHECKING:
from anthropic.types.message import Message
from anthropic import AsyncStream
else:
Message = "anthropic.types.message.Message"
AsyncStream = "anthropic.AsyncStream"
class AnthropicChatModel(ChatModelBase):
"""The Anthropic model wrapper for AgentScope."""
def __init__(
self,
model_name: str,
api_key: str | None = None,
max_tokens: int = 2048,
stream: bool = True,
thinking: dict | None = None,
stream_tool_parsing: bool = True,
client_kwargs: dict[str, JSONSerializableObject] | None = None,
generate_kwargs: dict[str, JSONSerializableObject] | None = None,
**kwargs: Any,
) -> None:
"""Initialize the Anthropic chat model.
Args:
model_name (`str`):
The model names.
api_key (`str`):
The anthropic API key.
stream (`bool`):
The streaming output or not
max_tokens (`int`):
Limit the maximum token count the model can generate.
thinking (`dict | None`, default `None`):
Configuration for Claude's internal reasoning process.
.. code-block:: python
:caption: Example of thinking
{
"type": "enabled" | "disabled",
"budget_tokens": 1024
}
stream_tool_parsing (`bool`, default to `True`):
Whether to parse incomplete tool use JSON during streaming
with auto-repair. If True, partial JSON (e.g., `'{"a": "x'`)
is repaired to valid dicts ({"a": "x"}) in real-time for
immediate tool function input. Otherwise, the input field
remains {} until the final chunk arrives.
client_kwargs (`dict[str, JSONSerializableObject] | None`, \
optional):
The extra keyword arguments to initialize the Anthropic client.
generate_kwargs (`dict[str, JSONSerializableObject] | None`, \
optional):
The extra keyword arguments used in Anthropic API generation,
e.g. `temperature`, `seed`.
**kwargs (`Any`):
Additional keyword arguments.
"""
# Handle deprecated client_args parameter from kwargs
client_args = kwargs.pop("client_args", None)
if client_args is not None and client_kwargs is not None:
raise ValueError(
"Cannot specify both 'client_args' and 'client_kwargs'. "
"Please use only 'client_kwargs' (client_args is deprecated).",
)
if client_args is not None:
logger.warning(
"The parameter 'client_args' is deprecated and will be "
"removed in a future version. Please use 'client_kwargs' "
"instead. Automatically converting 'client_args' to "
"'client_kwargs'.",
)
client_kwargs = client_args
if kwargs:
logger.warning(
"Unknown keyword arguments: %s. These will be ignored.",
list(kwargs.keys()),
)
try:
import anthropic
except ImportError as e:
raise ImportError(
"Please install the `anthropic` package by running "
"`pip install anthropic`.",
) from e
super().__init__(model_name, stream)
self.client = anthropic.AsyncAnthropic(
api_key=api_key,
**(client_kwargs or {}),
)
self.max_tokens = max_tokens
self.thinking = thinking
self.stream_tool_parsing = stream_tool_parsing
self.generate_kwargs = generate_kwargs or {}
@trace_llm
async def __call__(
self,
messages: list[dict[str, Any]],
tools: list[dict] | None = None,
tool_choice: Literal["auto", "none", "required"] | str | None = None,
structured_model: Type[BaseModel] | None = None,
**generate_kwargs: Any,
) -> ChatResponse | AsyncGenerator[ChatResponse, None]:
"""Get the response from Anthropic chat completions API by the given
arguments.
Args:
messages (`list[dict]`):
A list of dictionaries, where `role` and `content` fields are
required, and `name` field is optional.
tools (`list[dict]`, default `None`):
The tools JSON schemas that in format of:
.. code-block:: python
:caption: Example of tools JSON schemas
[
{
"type": "function",
"function": {
"name": "xxx",
"description": "xxx",
"parameters": {
"type": "object",
"properties": {
"param1": {
"type": "string",
"description": "..."
},
# Add more parameters as needed
},
"required": ["param1"]
}
},
# More schemas here
]
tool_choice (`Literal["auto", "none", "required"] | str \
| None`, default `None`):
Controls which (if any) tool is called by the model.
Can be "auto", "none", "required", or specific tool
name. For more details, please refer to
https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/implement-tool-use
structured_model (`Type[BaseModel] | None`, default `None`):
A Pydantic BaseModel class that defines the expected structure
for the model's output. When provided, the model will be forced
to return data that conforms to this schema by automatically
converting the BaseModel to a tool function and setting
`tool_choice` to enforce its usage. This enables structured
output generation.
.. note:: When `structured_model` is specified,
both `tools` and `tool_choice` parameters are ignored,
and the model will only perform structured output
generation without calling any other tools.
**generate_kwargs (`Any`):
The keyword arguments for Anthropic chat completions API,
e.g. `temperature`, `top_p`, etc. Please
refer to the Anthropic API documentation for more details.
Returns:
`ChatResponse | AsyncGenerator[ChatResponse, None]`:
The response from the Anthropic chat completions API."""
kwargs: dict[str, Any] = {
"model": self.model_name,
"max_tokens": self.max_tokens,
"stream": self.stream,
**self.generate_kwargs,
**generate_kwargs,
}
if self.thinking and "thinking" not in kwargs:
kwargs["thinking"] = self.thinking
if tools:
kwargs["tools"] = self._format_tools_json_schemas(tools)
if tool_choice:
# Handle deprecated "any" option with warning
if tool_choice == "any":
warnings.warn(
'"any" is deprecated and will be removed in a future '
"version.",
DeprecationWarning,
)
tool_choice = "required"
self._validate_tool_choice(tool_choice, tools)
kwargs["tool_choice"] = self._format_tool_choice(tool_choice)
if structured_model:
if tools or tool_choice:
logger.warning(
"structured_model is provided. Both 'tools' and "
"'tool_choice' parameters will be overridden and "
"ignored. The model will only perform structured output "
"generation without calling any other tools.",
)
format_tool = _create_tool_from_base_model(structured_model)
kwargs["tools"] = self._format_tools_json_schemas(
[format_tool],
)
kwargs["tool_choice"] = self._format_tool_choice(
format_tool["function"]["name"],
)
# Extract the system message
if messages[0]["role"] == "system":
kwargs["system"] = messages[0]["content"]
messages = messages[1:]
kwargs["messages"] = messages
start_datetime = datetime.now()
response = await self.client.messages.create(**kwargs)
if self.stream:
return self._parse_anthropic_stream_completion_response(
start_datetime,
response,
structured_model,
)
# Non-streaming response
parsed_response = await self._parse_anthropic_completion_response(
start_datetime,
response,
structured_model,
)
return parsed_response
async def _parse_anthropic_completion_response(
self,
start_datetime: datetime,
response: Message,
structured_model: Type[BaseModel] | None = None,
) -> ChatResponse:
"""Given an Anthropic Message object, extract the content blocks and
usages from it.
Args:
start_datetime (`datetime`):
The start datetime of the response generation.
response (`Message`):
Anthropic Message object to parse.
structured_model (`Type[BaseModel] | None`, default `None`):
A Pydantic BaseModel class that defines the expected structure
for the model's output.
Returns:
ChatResponse (`ChatResponse`):
A ChatResponse object containing the content blocks and usage.
.. note::
If `structured_model` is not `None`, the expected structured output
will be stored in the metadata of the `ChatResponse`.
"""
content_blocks: List[ThinkingBlock | TextBlock | ToolUseBlock] = []
metadata = None
if hasattr(response, "content") and response.content:
for content_block in response.content:
if (
hasattr(content_block, "type")
and content_block.type == "thinking"
):
thinking_block = ThinkingBlock(
type="thinking",
thinking=content_block.thinking,
)
thinking_block["signature"] = content_block.signature
content_blocks.append(thinking_block)
elif (
hasattr(content_block, "type")
and content_block.type == "text"
):
content_blocks.append(
TextBlock(
type="text",
text=content_block.text,
),
)
elif (
hasattr(content_block, "type")
and content_block.type == "tool_use"
):
content_blocks.append(
ToolUseBlock(
type="tool_use",
id=content_block.id,
name=content_block.name,
input=content_block.input,
),
)
if structured_model:
metadata = content_block.input
usage = None
if response.usage:
usage = ChatUsage(
input_tokens=response.usage.input_tokens,
output_tokens=response.usage.output_tokens,
time=(datetime.now() - start_datetime).total_seconds(),
)
parsed_response = ChatResponse(
content=content_blocks,
usage=usage,
metadata=metadata,
)
return parsed_response
async def _parse_anthropic_stream_completion_response(
self,
start_datetime: datetime,
response: AsyncStream,
structured_model: Type[BaseModel] | None = None,
) -> AsyncGenerator[ChatResponse, None]:
"""Given an Anthropic streaming response, extract the content blocks
and usages from it and yield ChatResponse objects.
Args:
start_datetime (`datetime`):
The start datetime of the response generation.
response (`AsyncStream`):
Anthropic AsyncStream object to parse.
structured_model (`Type[BaseModel] | None`, default `None`):
A Pydantic BaseModel class that defines the expected structure
for the model's output.
Returns:
`AsyncGenerator[ChatResponse, None]`:
An async generator that yields ChatResponse objects containing
the content blocks and usage information for each chunk in
the streaming response.
.. note::
If `structured_model` is not `None`, the expected structured output
will be stored in the metadata of the `ChatResponse`.
"""
usage = None
text_buffer = ""
thinking_buffer = ""
thinking_signature = ""
tool_calls = OrderedDict()
tool_call_buffers = {}
last_input_objs = {} # Store last input_obj for each tool_call
res = None
metadata = None
# Record the last yielded content to parse the tools' input
last_content = None
async for event in response:
content_changed = False
thinking_changed = False
if event.type == "message_start":
message = event.message
if message.usage:
usage = ChatUsage(
input_tokens=message.usage.input_tokens,
output_tokens=getattr(
message.usage,
"output_tokens",
0,
),
time=(datetime.now() - start_datetime).total_seconds(),
)
elif event.type == "content_block_start":
if event.content_block.type == "tool_use":
block_index = event.index
tool_block = event.content_block
tool_calls[block_index] = {
"type": "tool_use",
"id": tool_block.id,
"name": tool_block.name,
"input": "",
}
tool_call_buffers[block_index] = ""
content_changed = True
elif event.type == "content_block_delta":
block_index = event.index
delta = event.delta
if delta.type == "text_delta":
text_buffer += delta.text
content_changed = True
elif delta.type == "thinking_delta":
thinking_buffer += delta.thinking
thinking_changed = True
elif delta.type == "signature_delta":
thinking_signature = delta.signature
elif (
delta.type == "input_json_delta"
and block_index in tool_calls
):
tool_call_buffers[block_index] += delta.partial_json or ""
tool_calls[block_index]["input"] = tool_call_buffers[
block_index
]
content_changed = True
elif event.type == "message_delta":
if event.usage and usage:
usage.output_tokens = event.usage.output_tokens
if (thinking_changed or content_changed) and usage:
contents: list = []
if thinking_buffer:
thinking_block = ThinkingBlock(
type="thinking",
thinking=thinking_buffer,
)
thinking_block["signature"] = thinking_signature
contents.append(thinking_block)
if text_buffer:
contents.append(
TextBlock(
type="text",
text=text_buffer,
),
)
for block_index, tool_call in tool_calls.items():
input_str = tool_call["input"]
tool_id = tool_call["id"]
# If parsing the tool input in streaming mode
if self.stream_tool_parsing:
repaired_input = _json_loads_with_repair(
input_str or "{}",
)
# If the new repaired input is shorter than one in the
# last chunk, use the last one to avoid regression
last_input = last_input_objs.get(tool_id, {})
if len(json.dumps(last_input)) > len(
json.dumps(repaired_input),
):
repaired_input = last_input
last_input_objs[tool_id] = repaired_input
else:
repaired_input = {}
contents.append(
ToolUseBlock(
type=tool_call["type"],
id=tool_call["id"],
name=tool_call["name"],
input=repaired_input,
raw_input=input_str,
),
)
if structured_model:
metadata = repaired_input
if contents:
res = ChatResponse(
content=contents,
usage=usage,
metadata=metadata,
)
yield res
last_content = copy.deepcopy(contents)
# If stream_tool_parsing is False, yield last contents
if not self.stream_tool_parsing and last_content and tool_calls:
metadata = None
# Update tool use blocks in last_contents inplace
for block in last_content:
if block.get("type") == "tool_use":
block["input"] = input_obj = _json_loads_with_repair(
block.get("raw_input") or "{}",
)
if structured_model:
metadata = input_obj
yield ChatResponse(
content=last_content,
usage=usage,
metadata=metadata,
)
def _format_tools_json_schemas(
self,
schemas: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Format the JSON schemas of the tool functions to the format that
Anthropic API expects."""
formatted_schemas = []
for schema in schemas:
assert (
"function" in schema
), f"Invalid schema: {schema}, expect key 'function'."
assert "name" in schema["function"], (
f"Invalid schema: {schema}, "
"expect key 'name' in 'function' field."
)
formatted_schemas.append(
{
"name": schema["function"]["name"],
"description": schema["function"].get("description", ""),
"input_schema": schema["function"].get("parameters", {}),
},
)
return formatted_schemas
def _format_tool_choice(
self,
tool_choice: Literal["auto", "none", "required"] | str | None,
) -> dict | None:
"""Format tool_choice parameter for API compatibility.
Args:
tool_choice (`Literal["auto", "none", "required"] | str \
| None`, default `None`):
Controls which (if any) tool is called by the model.
Can be "auto", "none", "required", or specific tool
name. For more details, please refer to
https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/implement-tool-use
Returns:
`dict | None`:
The formatted tool choice configuration dict, or None if
tool_choice is None.
"""
if tool_choice is None:
return None
type_mapping = {
"auto": {"type": "auto"},
"none": {"type": "none"},
"required": {"type": "any"},
}
if tool_choice in type_mapping:
return type_mapping[tool_choice]
return {"type": "tool", "name": tool_choice}

View File

@@ -0,0 +1,632 @@
# -*- coding: utf-8 -*-
"""The dashscope API model classes."""
import copy
import collections
import json
import os
import warnings
from datetime import datetime
from http import HTTPStatus
from typing import (
Any,
AsyncGenerator,
Generator,
Union,
TYPE_CHECKING,
List,
Literal,
Type,
)
from pydantic import BaseModel
from aioitertools import iter as giter
from ._model_base import ChatModelBase
from ._model_response import ChatResponse
from ._model_usage import ChatUsage
from .._utils._common import (
_json_loads_with_repair,
_create_tool_from_base_model,
)
from ..message import TextBlock, ToolUseBlock, ThinkingBlock
from ..tracing import trace_llm
from ..types import JSONSerializableObject
from .._logging import logger
if TYPE_CHECKING:
from dashscope.api_entities.dashscope_response import GenerationResponse
from dashscope.api_entities.dashscope_response import (
MultiModalConversationResponse,
)
else:
GenerationResponse = (
"dashscope.api_entities.dashscope_response.GenerationResponse"
)
MultiModalConversationResponse = (
"dashscope.api_entities.dashscope_response."
"MultiModalConversationResponse"
)
class DashScopeChatModel(ChatModelBase):
"""The DashScope chat model class, which unifies the Generation and
MultimodalConversation APIs into one method.
This class provides a unified interface for DashScope API by automatically
selecting between text-only (Generation API) and multimodal
(MultiModalConversation API) endpoints. The `multimodality` parameter
allows explicit control over API selection:
- When `multimodality=True`: Forces use of MultiModalConversation API
for handling images, videos, and other multimodal inputs
- When `multimodality=False`: Forces use of Generation API for
text-only processing
- When `multimodality=None` (default): Automatically selects the API
based on model name (e.g., models with "-vl" suffix or starting
with "qvq" will use MultiModalConversation API)
This design enables seamless switching between text and multimodal
models without changing code structure, making it easier to work with
DashScope's diverse model offerings.
"""
def __init__(
self,
model_name: str,
api_key: str,
stream: bool = True,
enable_thinking: bool | None = None,
multimodality: bool | None = None,
generate_kwargs: dict[str, JSONSerializableObject] | None = None,
base_http_api_url: str | None = None,
stream_tool_parsing: bool = True,
**_kwargs: Any,
) -> None:
"""Initialize the DashScope chat model.
Args:
model_name (`str`):
The model names.
api_key (`str`):
The dashscope API key.
stream (`bool`):
The streaming output or not
enable_thinking (`bool | None`, optional):
Enable thinking or not, only support Qwen3, QwQ, DeepSeek-R1.
Refer to `DashScope documentation
<https://help.aliyun.com/zh/model-studio/deep-thinking>`_
for more details.
multimodality (`bool | None`, optional):
Whether to use multimodal conversation API. If `True`,
it will use `dashscope.MultiModalConversation.call`
to process multimodal inputs such as images and text. If
`False`, it will use
`dashscope.aigc.generation.AioGeneration.call` to process
text inputs. If `None` (default), the choice is based on
the model name.
generate_kwargs (`dict[str, JSONSerializableObject] | None`, \
optional):
The extra keyword arguments used in DashScope API generation,
e.g. `temperature`, `seed`.
base_http_api_url (`str | None`, optional):
The base URL for DashScope API requests. If not provided,
the default base URL from the DashScope SDK will be used.
stream_tool_parsing (`bool`, default to `True`):
Whether to parse incomplete tool use JSON in streaming mode
with auto-repair. If True, partial JSON (e.g., `'{"a": "x'`)
is repaired to valid dicts (`{"a": "x"}`) in real-time for
immediate tool function input. Otherwise, the input field
remains {} until the final chunk arrives.
**_kwargs (`Any`):
Additional keyword arguments.
"""
if enable_thinking and not stream:
logger.info(
"In DashScope API, `stream` must be True when "
"`enable_thinking` is True. ",
)
stream = True
super().__init__(model_name, stream)
self.api_key = api_key
self.enable_thinking = enable_thinking
self.multimodality = multimodality
self.generate_kwargs = generate_kwargs or {}
self.stream_tool_parsing = stream_tool_parsing
if base_http_api_url is not None:
import dashscope
dashscope.base_http_api_url = base_http_api_url
# Load headers from environment variable if exists
headers = os.getenv("DASHSCOPE_API_HEADERS")
if headers:
try:
headers = json.loads(str(headers))
if not isinstance(headers, dict):
raise json.JSONDecodeError("", "", 0)
if self.generate_kwargs.get("headers"):
headers.update(self.generate_kwargs["headers"])
self.generate_kwargs["headers"] = headers
except json.JSONDecodeError:
logger.warning(
"Failed to parse DASHSCOPE_API_HEADERS environment "
"variable as JSON. It should be a JSON object.",
)
@trace_llm
async def __call__(
self,
messages: list[dict[str, Any]],
tools: list[dict] | None = None,
tool_choice: Literal["auto", "none", "required"] | str | None = None,
structured_model: Type[BaseModel] | None = None,
**kwargs: Any,
) -> ChatResponse | AsyncGenerator[ChatResponse, None]:
"""Get the response from the dashscope
Generation/MultimodalConversation API by the given arguments.
.. note:: We unify the dashscope generation and multimodal conversation
APIs into one method, since they support similar arguments and share
the same functionality.
Args:
messages (`list[dict[str, Any]]`):
A list of dictionaries, where `role` and `content` fields are
required.
tools (`list[dict] | None`, default `None`):
The tools JSON schemas that the model can use.
tool_choice (`Literal["auto", "none", "required"] | str \
| None`, default `None`):
Controls which (if any) tool is called by the model.
Can be "auto", "none", "required", or specific tool name.
Note: DashScope API only supports "auto" and "none", so
"required" will be converted to "auto".
For more details, please refer to
https://help.aliyun.com/zh/model-studio/qwen-function-calling
structured_model (`Type[BaseModel] | None`, default `None`):
A Pydantic BaseModel class that defines the expected structure
for the model's output. When provided, the model will be forced
to return data that conforms to this schema by automatically
converting the BaseModel to a tool function and setting
`tool_choice` to enforce its usage. This enables structured
output generation.
.. note:: When `structured_model` is specified,
both `tools` and `tool_choice` parameters are ignored,
and the model will only perform structured output
generation without calling any other tools.
**kwargs (`Any`):
The keyword arguments for DashScope chat completions API,
e.g. `temperature`, `max_tokens`, `top_p`, etc. Please
refer to `DashScope documentation
<https://help.aliyun.com/zh/dashscope/developer-reference/api-details>`_
for more detailed arguments.
"""
import dashscope
kwargs = {
"messages": messages,
"model": self.model_name,
"stream": self.stream,
**self.generate_kwargs,
**kwargs,
"result_format": "message",
# In agentscope, the `incremental_output` must be `True` when
# `self.stream` is True
"incremental_output": self.stream,
}
if tools:
kwargs["tools"] = self._format_tools_json_schemas(tools)
if tool_choice:
# Handle deprecated "any" option with warning
if tool_choice in ["any", "required"]:
warnings.warn(
f"'{tool_choice}' is not supported by DashScope API. "
"It will be converted to 'auto'.",
DeprecationWarning,
)
tool_choice = "auto"
self._validate_tool_choice(tool_choice, tools)
kwargs["tool_choice"] = self._format_tool_choice(tool_choice)
if (
self.enable_thinking is not None
and "enable_thinking" not in kwargs
):
kwargs["enable_thinking"] = self.enable_thinking
if structured_model:
if tools or tool_choice:
logger.warning(
"structured_model is provided. Both 'tools' and "
"'tool_choice' parameters will be overridden and "
"ignored. The model will only perform structured output "
"generation without calling any other tools.",
)
format_tool = _create_tool_from_base_model(structured_model)
kwargs["tools"] = self._format_tools_json_schemas(
[format_tool],
)
kwargs["tool_choice"] = self._format_tool_choice(
format_tool["function"]["name"],
)
start_datetime = datetime.now()
if self.multimodality or (
self.multimodality is None
and (
self.model_name.startswith(
"qvq",
)
or "-vl" in self.model_name
)
):
response = dashscope.MultiModalConversation.call(
api_key=self.api_key,
**kwargs,
)
else:
response = await dashscope.aigc.generation.AioGeneration.call(
api_key=self.api_key,
**kwargs,
)
if self.stream:
return self._parse_dashscope_stream_response(
start_datetime,
response,
structured_model,
)
parsed_response = await self._parse_dashscope_generation_response(
start_datetime,
response,
structured_model,
)
return parsed_response
# pylint: disable=too-many-branches, too-many-statements
async def _parse_dashscope_stream_response(
self,
start_datetime: datetime,
response: Union[
AsyncGenerator[GenerationResponse, None],
Generator[MultiModalConversationResponse, None, None],
],
structured_model: Type[BaseModel] | None = None,
) -> AsyncGenerator[ChatResponse, Any]:
"""Given a DashScope streaming response generator, extract the content
blocks and usages from it and yield ChatResponse objects.
Args:
start_datetime (`datetime`):
The start datetime of the response generation.
response (
`Union[AsyncGenerator[GenerationResponse, None], Generator[ \
MultiModalConversationResponse, None, None]]`
):
DashScope streaming response generator (GenerationResponse or
MultiModalConversationResponse) to parse.
structured_model (`Type[BaseModel] | None`, default `None`):
A Pydantic BaseModel class that defines the expected structure
for the model's output.
Returns:
AsyncGenerator[ChatResponse, Any]:
An async generator that yields ChatResponse objects containing
the content blocks and usage information for each chunk in the
streaming response.
.. note::
If `structured_model` is not `None`, the expected structured output
will be stored in the metadata of the `ChatResponse`.
"""
acc_content, acc_thinking_content = "", ""
acc_tool_calls = collections.defaultdict(dict)
last_input_objs = {} # Store last input_obj for each tool_call
metadata = None
last_content = None
usage = None
async for chunk in giter(response):
if chunk.status_code != HTTPStatus.OK:
raise RuntimeError(
f"Failed to get response from _ API: {chunk}",
)
message = chunk.output.choices[0].message
# Update reasoning content
if isinstance(message.get("reasoning_content"), str):
acc_thinking_content += message["reasoning_content"]
# Update text content
if isinstance(message.content, str):
acc_content += message.content
elif isinstance(message.content, list):
for item in message.content:
if isinstance(item, dict) and "text" in item:
acc_content += item["text"]
# Update tool calls
for tool_call in message.get("tool_calls", []):
index = tool_call.get("index", 0)
if "id" in tool_call and tool_call["id"] != acc_tool_calls[
index
].get("id"):
acc_tool_calls[index]["id"] = (
acc_tool_calls[index].get("id", "") + tool_call["id"]
)
if "function" in tool_call:
func = tool_call["function"]
if "name" in func:
acc_tool_calls[index]["name"] = (
acc_tool_calls[index].get("name", "")
+ func["name"]
)
if "arguments" in func:
acc_tool_calls[index]["arguments"] = (
acc_tool_calls[index].get("arguments", "")
+ func["arguments"]
)
# Build content blocks (always include thinking and text)
content_blocks: list[TextBlock | ToolUseBlock | ThinkingBlock] = []
if acc_thinking_content:
content_blocks.append(
ThinkingBlock(
type="thinking",
thinking=acc_thinking_content,
),
)
if acc_content:
content_blocks.append(
TextBlock(
type="text",
text=acc_content,
),
)
for tool_call in acc_tool_calls.values():
# Only add intermediate tool use blocks if
# stream_tool_parsing is True
tool_id = tool_call.get("id", "")
input_str = tool_call.get("arguments")
# If parsing the tool input in streaming mode
if self.stream_tool_parsing:
repaired_input = _json_loads_with_repair(
input_str or "{}",
)
# If the new repaired input is shorter than one in the last
# chunk, use the last one to avoid regression
last_input = last_input_objs.get(tool_id, {})
if len(json.dumps(last_input)) > len(
json.dumps(repaired_input),
):
repaired_input = last_input
last_input_objs[tool_id] = repaired_input
else:
# Otherwise, keep input as empty dict until the final chunk
repaired_input = {}
content_blocks.append(
ToolUseBlock(
type="tool_use",
id=tool_id,
name=tool_call.get("name", ""),
input=repaired_input,
raw_input=input_str,
),
)
if structured_model:
metadata = repaired_input
if chunk.usage:
usage = ChatUsage(
input_tokens=chunk.usage.input_tokens,
output_tokens=chunk.usage.output_tokens,
time=(datetime.now() - start_datetime).total_seconds(),
metadata=chunk.usage,
)
if content_blocks:
parsed_chunk = ChatResponse(
content=content_blocks,
usage=usage,
metadata=metadata,
)
yield parsed_chunk
last_content = copy.deepcopy(content_blocks)
# If stream_tool_parsing is False, we need to parse the final tool
# use inputs here
if not self.stream_tool_parsing and last_content and acc_tool_calls:
metadata = None
# Update tool use blocks in last_contents inplace
for block in last_content:
if block.get("type") == "tool_use":
block["input"] = input_obj = _json_loads_with_repair(
str(block.get("raw_input") or "{}"),
)
if structured_model:
metadata = input_obj
yield ChatResponse(
content=last_content,
usage=usage,
metadata=metadata,
)
async def _parse_dashscope_generation_response(
self,
start_datetime: datetime,
response: Union[
GenerationResponse,
MultiModalConversationResponse,
],
structured_model: Type[BaseModel] | None = None,
) -> ChatResponse:
"""Given a DashScope GenerationResponse object, extract the content
blocks and usages from it.
Args:
start_datetime (`datetime`):
The start datetime of the response generation.
response (
`Union[GenerationResponse, MultiModalConversationResponse]`
):
Dashscope GenerationResponse | MultiModalConversationResponse
object to parse.
structured_model (`Type[BaseModel] | None`, default `None`):
A Pydantic BaseModel class that defines the expected structure
for the model's output.
Returns:
ChatResponse (`ChatResponse`):
A ChatResponse object containing the content blocks and usage.
.. note::
If `structured_model` is not `None`, the expected structured output
will be stored in the metadata of the `ChatResponse`.
"""
# Collect the content blocks from the response.
if response.status_code != 200:
raise RuntimeError(response)
content_blocks: List[TextBlock | ToolUseBlock] = []
metadata: dict | None = None
message = response.output.choices[0].message
content = message.get("content")
if response.output.choices[0].message.get("content") not in [
None,
"",
[],
]:
if isinstance(content, list):
for item in content:
if isinstance(item, dict) and "text" in item:
content_blocks.append(
TextBlock(
type="text",
text=item["text"],
),
)
else:
content_blocks.append(
TextBlock(
type="text",
text=content,
),
)
if message.get("tool_calls"):
for tool_call in message["tool_calls"]:
input_ = _json_loads_with_repair(
tool_call["function"].get(
"arguments",
"{}",
)
or "{}",
)
content_blocks.append(
ToolUseBlock(
type="tool_use",
name=tool_call["function"]["name"],
input=input_,
id=tool_call["id"],
),
)
if structured_model:
metadata = input_
# Usage information
usage = None
if response.usage:
usage = ChatUsage(
input_tokens=response.usage.input_tokens,
output_tokens=response.usage.output_tokens,
time=(datetime.now() - start_datetime).total_seconds(),
metadata=response.usage,
)
parsed_response = ChatResponse(
content=content_blocks,
usage=usage,
metadata=metadata,
)
return parsed_response
def _format_tools_json_schemas(
self,
schemas: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Format the tools JSON schema into required format for DashScope API.
Args:
schemas (`dict[str, dict[str, Any]]`):
The tools JSON schemas.
"""
# Check schemas format
for value in schemas:
if (
not isinstance(value, dict)
or "type" not in value
or value["type"] != "function"
or "function" not in value
):
raise ValueError(
f"Each schema must be a dict with 'type' as 'function' "
f"and 'function' key, got {value}",
)
return schemas
def _format_tool_choice(
self,
tool_choice: Literal["auto", "none", "required"] | str | None,
) -> str | dict | None:
"""Format tool_choice parameter for API compatibility.
Args:
tool_choice (`Literal["auto", "none", "required"] | str \
| None`, default `None`):
Controls which (if any) tool is called by the model. For more
details, please refer to
https://help.aliyun.com/zh/model-studio/qwen-function-calling
Returns:
`dict | None`:
The formatted tool choice configuration dict, or None if
tool_choice is None.
"""
if tool_choice is None:
return None
if tool_choice in ["auto", "none"]:
return tool_choice
if tool_choice == "required":
return "auto"
return {"type": "function", "function": {"name": tool_choice}}

Some files were not shown because too many files have changed in this diff Show More