chore: initial import of standalone agentscope project
Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled
Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled
This commit is contained in:
182
src/agentscope/__init__.py
Normal file
182
src/agentscope/__init__.py
Normal file
@@ -0,0 +1,182 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# flake8: noqa: E402
|
||||
# pylint: disable=wrong-import-position
|
||||
"""The agentscope serialization module"""
|
||||
import os
|
||||
import warnings
|
||||
from contextvars import ContextVar
|
||||
from datetime import datetime
|
||||
|
||||
import requests
|
||||
import shortuuid
|
||||
|
||||
from ._run_config import _ConfigCls
|
||||
|
||||
|
||||
def _generate_random_suffix(length: int) -> str:
    """Return the first ``length`` characters of a freshly generated
    shortuuid string.

    Args:
        length (`int`):
            The number of leading characters to keep.

    Returns:
        `str`:
            The truncated random string.
    """
    token = shortuuid.uuid()
    return token[:length]
|
||||
|
||||
|
||||
# A thread and async safe global configuration instance. Each field lives in
# its own ContextVar; the default values below are computed once, when this
# module is executed at import time.
_config = _ConfigCls(
    run_id=ContextVar("run_id", default=shortuuid.uuid()),
    project=ContextVar(
        "project",
        default="UnnamedProject_At" + datetime.now().strftime("%Y%m%d"),
    ),
    name=ContextVar(
        "name",
        default=datetime.now().strftime("%H%M%S_")
        + _generate_random_suffix(4),
    ),
    created_at=ContextVar(
        "created_at",
        # Drop the last three microsecond digits to keep millisecond precision
        default=datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3],
    ),
    trace_enabled=ContextVar(
        "trace_enabled",
        default=False,
    ),
)
|
||||
|
||||
from . import exception
|
||||
from . import module
|
||||
from . import message
|
||||
from . import model
|
||||
from . import tool
|
||||
from . import formatter
|
||||
from . import memory
|
||||
from . import agent
|
||||
from . import session
|
||||
from . import embedding
|
||||
from . import token
|
||||
from . import evaluate
|
||||
from . import pipeline
|
||||
from . import tracing
|
||||
from . import rag
|
||||
from . import a2a
|
||||
from . import realtime
|
||||
|
||||
from ._logging import (
|
||||
logger,
|
||||
setup_logger,
|
||||
)
|
||||
from .hooks import _equip_as_studio_hooks
|
||||
from ._version import __version__
|
||||
|
||||
# Raise each warning only once
|
||||
warnings.filterwarnings("once", category=DeprecationWarning)
|
||||
|
||||
|
||||
def init(
    project: str | None = None,
    name: str | None = None,
    run_id: str | None = None,
    logging_path: str | None = None,
    logging_level: str = "INFO",
    studio_url: str | None = None,
    tracing_url: str | None = None,
) -> None:
    """Initialize the agentscope library.

    Args:
        project (`str | None`, optional):
            The project name.
        name (`str | None`, optional):
            The name of the run.
        run_id (`str | None`, optional):
            The identity of a running instance, which can be an agent, or a
            multi-agent system. The `run_id` is used in AgentScope-Studio to
            distinguish different runs.
        logging_path (`str | None`, optional):
            The path to saving the log file. If not provided, logs will not be
            saved.
        logging_level (`str`, optional):
            The logging level. Defaults to "INFO".
        studio_url (`str | None`, optional):
            The URL of the AgentScope Studio to connect to.
        tracing_url (`str | None`, optional):
            The URL of the tracing endpoint, which can connect to third-party
            OpenTelemetry tracing platforms like Arize-Phoenix and Langfuse.
            If not provided and `studio_url` is provided, it will send traces
            to the AgentScope Studio's tracing endpoint.

    Raises:
        `requests.HTTPError`:
            If the studio answers the run registration request with an
            error status.
    """

    # Only override the import-time defaults for values that were given
    if project:
        _config.project = project

    if name:
        _config.name = name

    if run_id:
        _config.run_id = run_id

    setup_logger(logging_level, logging_path)

    # Strip surrounding slashes once so that every URL derived from the
    # studio address is joined with exactly one "/" (a trailing slash in
    # `studio_url` previously produced ".../" + "/trpc/registerRun").
    studio_base = studio_url.strip("/") if studio_url else None

    if studio_url:
        # Register the run
        data = {
            "id": _config.run_id,
            "project": _config.project,
            "name": _config.name,
            "timestamp": _config.created_at,
            "pid": os.getpid(),
            "status": "running",
            # Deprecated fields
            "run_dir": "",
        }
        response = requests.post(
            url=f"{studio_base}/trpc/registerRun",
            json=data,
        )
        response.raise_for_status()

        # Imported lazily, inside the function, as in the original code
        from .agent import UserAgent, StudioUserInput

        UserAgent.override_class_input_method(
            StudioUserInput(
                studio_url=studio_url,
                run_id=_config.run_id,
                max_retries=3,
            ),
        )

        _equip_as_studio_hooks(studio_url)

    # An explicit tracing endpoint wins; otherwise fall back to the studio
    if tracing_url:
        endpoint = tracing_url
    else:
        endpoint = f"{studio_base}/v1/traces" if studio_url else None

    if endpoint:
        from .tracing import setup_tracing

        setup_tracing(endpoint=endpoint)
        _config.trace_enabled = True
|
||||
|
||||
|
||||
# Public API of the package. Every submodule imported above is listed here,
# including `realtime`, which was imported but previously not exported.
__all__ = [
    # modules
    "exception",
    "module",
    "message",
    "model",
    "tool",
    "formatter",
    "memory",
    "agent",
    "session",
    "logger",
    "embedding",
    "token",
    "evaluate",
    "pipeline",
    "tracing",
    "rag",
    "a2a",
    "realtime",
    # functions
    "init",
    "setup_logger",
    "__version__",
]
|
||||
47
src/agentscope/_logging.py
Normal file
47
src/agentscope/_logging.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The logger for agentscope."""
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
_DEFAULT_FORMAT = (
|
||||
"%(asctime)s | %(levelname)-7s | "
|
||||
"%(module)s:%(funcName)s:%(lineno)s - %(message)s"
|
||||
)
|
||||
|
||||
logger = logging.getLogger("as")
|
||||
|
||||
|
||||
def setup_logger(
|
||||
level: str,
|
||||
filepath: str | None = None,
|
||||
) -> None:
|
||||
"""Set up the agentscope logger.
|
||||
|
||||
Args:
|
||||
level (`str`):
|
||||
The logging level, chosen from "INFO", "DEBUG", "WARNING",
|
||||
"ERROR", "CRITICAL".
|
||||
filepath (`str | None`, optional):
|
||||
The filepath to save the logging output.
|
||||
"""
|
||||
if level not in ["INFO", "DEBUG", "WARNING", "ERROR", "CRITICAL"]:
|
||||
raise ValueError(
|
||||
f"Invalid logging level: {level}. Must be one of "
|
||||
f"'INFO', 'DEBUG', 'WARNING', 'ERROR', 'CRITICAL'.",
|
||||
)
|
||||
logger.handlers.clear()
|
||||
logger.setLevel(level)
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(logging.Formatter(_DEFAULT_FORMAT))
|
||||
logger.addHandler(handler)
|
||||
|
||||
if filepath:
|
||||
handler = logging.FileHandler(filepath)
|
||||
handler.setFormatter(logging.Formatter(_DEFAULT_FORMAT))
|
||||
logger.addHandler(handler)
|
||||
|
||||
logger.propagate = False
|
||||
|
||||
|
||||
setup_logger("INFO")
|
||||
73
src/agentscope/_run_config.py
Normal file
73
src/agentscope/_run_config.py
Normal file
@@ -0,0 +1,73 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The run instance configuration in agentscope."""
|
||||
from contextvars import ContextVar
|
||||
|
||||
|
||||
class _ConfigCls:
|
||||
"""The run instance configuration in agentscope."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
run_id: ContextVar[str],
|
||||
project: ContextVar[str],
|
||||
name: ContextVar[str],
|
||||
created_at: ContextVar[str],
|
||||
trace_enabled: ContextVar[bool],
|
||||
) -> None:
|
||||
"""The constructor for _Config class."""
|
||||
# Copy the default context variables
|
||||
self._run_id = run_id
|
||||
self._created_at = created_at
|
||||
self._project = project
|
||||
self._name = name
|
||||
self._trace_enabled = trace_enabled
|
||||
|
||||
@property
|
||||
def run_id(self) -> str:
|
||||
"""Get the run ID."""
|
||||
return self._run_id.get()
|
||||
|
||||
@run_id.setter
|
||||
def run_id(self, value: str) -> None:
|
||||
"""Set the run ID."""
|
||||
self._run_id.set(value)
|
||||
|
||||
@property
|
||||
def created_at(self) -> str:
|
||||
"""Get the creation time."""
|
||||
return self._created_at.get()
|
||||
|
||||
@created_at.setter
|
||||
def created_at(self, value: str) -> None:
|
||||
"""Set the creation time."""
|
||||
self._created_at.set(value)
|
||||
|
||||
@property
|
||||
def project(self) -> str:
|
||||
"""Get the project name."""
|
||||
return self._project.get()
|
||||
|
||||
@project.setter
|
||||
def project(self, value: str) -> None:
|
||||
"""Set the project name."""
|
||||
self._project.set(value)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Get the run name."""
|
||||
return self._name.get()
|
||||
|
||||
@name.setter
|
||||
def name(self, value: str) -> None:
|
||||
"""Set the run name."""
|
||||
self._name.set(value)
|
||||
|
||||
@property
|
||||
def trace_enabled(self) -> bool:
|
||||
"""Get whether tracing is enabled."""
|
||||
return self._trace_enabled.get()
|
||||
|
||||
@trace_enabled.setter
|
||||
def trace_enabled(self, value: bool) -> None:
|
||||
"""Set whether tracing is enabled."""
|
||||
self._trace_enabled.set(value)
|
||||
0
src/agentscope/_utils/__init__.py
Normal file
0
src/agentscope/_utils/__init__.py
Normal file
474
src/agentscope/_utils/_common.py
Normal file
474
src/agentscope/_utils/_common.py
Normal file
@@ -0,0 +1,474 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The common utilities for agentscope library."""
|
||||
import asyncio
|
||||
import base64
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import types
|
||||
import typing
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Type, Dict
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
from docstring_parser import parse
|
||||
from json_repair import repair_json
|
||||
from pydantic import BaseModel, Field, create_model, ConfigDict
|
||||
|
||||
from .._logging import logger
|
||||
from ..types import ToolFunction
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from mcp.types import Tool
|
||||
else:
|
||||
Tool = "mcp.types.Tool"
|
||||
|
||||
|
||||
def _json_loads_with_repair(
    json_str: str,
) -> dict:
    """Repair a possibly truncated JSON string (e.g. '{"key') and load it
    as a dictionary.

    .. note::
        This function is currently only used for parsing the streaming output
        of the argument field in `tool_use`, so the parsed result must be a
        dict.

    Args:
        json_str (`str`):
            The JSON string to parse, which may be incomplete or malformed.

    Returns:
        `dict`:
            The dictionary obtained after repairing and parsing. An empty
            dict is returned when repairing/parsing raises, or when the
            parsed value is not a dictionary.
    """
    try:
        parsed = json.loads(repair_json(json_str, stream_stable=True))
    except Exception:
        # Keep the log line short: show at most the first 100 characters
        if len(json_str) > 100:
            preview = json_str[:100] + "..."
        else:
            preview = json_str

        logger.warning(
            "Failed to load JSON dict from string: %s. Returning empty dict "
            "instead.",
            preview,
        )
        return {}

    # Non-dict results (lists, strings, numbers...) are discarded silently
    return parsed if isinstance(parsed, dict) else {}
|
||||
|
||||
|
||||
def _is_accessible_local_file(url: str) -> bool:
|
||||
"""Check if the given URL is a local URL."""
|
||||
return os.path.isfile(url)
|
||||
|
||||
|
||||
def _get_timestamp(add_random_suffix: bool = False) -> str:
|
||||
"""Get the current timestamp in the format YYYY-MM-DD HH:MM:SS.sss."""
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
|
||||
|
||||
if add_random_suffix:
|
||||
# Add a random suffix to the timestamp
|
||||
timestamp += f"_{os.urandom(3).hex()}"
|
||||
|
||||
return timestamp
|
||||
|
||||
|
||||
async def _is_async_func(func: Callable) -> bool:
|
||||
"""Check if the given function is an async function, including
|
||||
coroutine functions, async generators, and coroutine objects.
|
||||
"""
|
||||
|
||||
return (
|
||||
inspect.iscoroutinefunction(func)
|
||||
or inspect.isasyncgenfunction(func)
|
||||
or isinstance(func, types.CoroutineType)
|
||||
or isinstance(func, types.GeneratorType)
|
||||
and asyncio.iscoroutine(func)
|
||||
or isinstance(func, functools.partial)
|
||||
and await _is_async_func(func.func)
|
||||
)
|
||||
|
||||
|
||||
async def _execute_async_or_sync_func(
    func: Callable,
    *args: Any,
    **kwargs: Any,
) -> Any:
    """Execute an async or sync function based on its type.

    Args:
        func (`Callable`):
            The function to be executed, which can be either async or sync.
        *args (`Any`):
            Positional arguments to be passed to the function.
        **kwargs (`Any`):
            Keyword arguments to be passed to the function.

    Returns:
        `Any`:
            The result of the function execution. For an async generator
            function the async generator object itself is returned.
    """

    if await _is_async_func(func):
        outcome = func(*args, **kwargs)
        # `_is_async_func` also accepts async generator functions, but the
        # async generator they return cannot be awaited (TypeError). Only
        # await what is actually awaitable and hand anything else back.
        if inspect.isawaitable(outcome):
            return await outcome
        return outcome

    return func(*args, **kwargs)
|
||||
|
||||
|
||||
def _get_bytes_from_web_url(
    url: str,
    max_retries: int = 3,
) -> str:
    """Get the bytes from a given URL.

    Args:
        url (`str`):
            The URL to fetch the bytes from.
        max_retries (`int`, defaults to `3`):
            The maximum number of retries.

    Returns:
        `str`:
            The response body decoded as UTF-8 text. If the body is not
            valid UTF-8, its base64 encoding is returned instead.

    Raises:
        `RuntimeError`:
            If no attempt succeeded.
    """
    for _ in range(max_retries):
        try:
            response = requests.get(url)
            response.raise_for_status()
            content = response.content

        except Exception as e:
            logger.info(
                "Failed to fetch bytes from URL %s. Error %s. Retrying...",
                url,
                str(e),
            )
            continue

        # The decoding fallback is kept apart from the request so that it
        # only ever sees the body of the response fetched in this attempt.
        try:
            return content.decode("utf-8")

        except UnicodeDecodeError:
            return base64.b64encode(content).decode("ascii")

    raise RuntimeError(
        f"Failed to fetch bytes from URL `{url}` after {max_retries} retries.",
    )
|
||||
|
||||
|
||||
def _save_base64_data(
|
||||
media_type: str,
|
||||
base64_data: str,
|
||||
) -> str:
|
||||
"""Save the base64 data to a temp file and return the file path. The
|
||||
extension is guessed from the MIME type.
|
||||
|
||||
Args:
|
||||
media_type (`str`):
|
||||
The MIME type of the data, e.g. "image/png", "audio/mpeg".
|
||||
base64_data (`str):
|
||||
The base64 data to be saved.
|
||||
"""
|
||||
extension = "." + media_type.split("/")[-1]
|
||||
|
||||
with tempfile.NamedTemporaryFile(
|
||||
suffix=f".{extension}",
|
||||
delete=False,
|
||||
) as temp_file:
|
||||
decoded_data = base64.b64decode(base64_data)
|
||||
temp_file.write(decoded_data)
|
||||
temp_file.close()
|
||||
return temp_file.name
|
||||
|
||||
|
||||
def _extract_json_schema_from_mcp_tool(tool: Tool) -> dict[str, Any]:
    """Build a function-calling JSON schema from an MCP tool.

    Args:
        tool (`Tool`):
            The MCP tool whose `name`, `description` and `inputSchema` are
            read.

    Returns:
        `dict[str, Any]`:
            A dictionary of type "function" whose parameters carry the
            "properties" and "required" entries of the tool's input schema
            (empty dict / empty list when absent).
    """
    input_schema = tool.inputSchema

    parameters = {
        "type": "object",
        "properties": input_schema.get("properties", {}),
        "required": input_schema.get("required", []),
    }

    function_spec = {
        "name": tool.name,
        "description": tool.description,
        "parameters": parameters,
    }

    return {"type": "function", "function": function_spec}
|
||||
|
||||
|
||||
def _remove_title_field(schema: dict) -> None:
|
||||
"""Remove the title field from the JSON schema to avoid
|
||||
misleading the LLM."""
|
||||
# The top level title field
|
||||
if "title" in schema:
|
||||
schema.pop("title")
|
||||
|
||||
# properties
|
||||
if "properties" in schema:
|
||||
for prop in schema["properties"].values():
|
||||
if isinstance(prop, dict):
|
||||
_remove_title_field(prop)
|
||||
|
||||
# items
|
||||
if "items" in schema and isinstance(schema["items"], dict):
|
||||
_remove_title_field(schema["items"])
|
||||
|
||||
# additionalProperties
|
||||
if "additionalProperties" in schema and isinstance(
|
||||
schema["additionalProperties"],
|
||||
dict,
|
||||
):
|
||||
_remove_title_field(
|
||||
schema["additionalProperties"],
|
||||
)
|
||||
|
||||
|
||||
def _create_tool_from_base_model(
    structured_model: Type[BaseModel],
    tool_name: str = "generate_structured_output",
) -> Dict[str, Any]:
    """Turn a Pydantic BaseModel class into a function tool definition.

    The JSON schema of the model becomes the `parameters` of the tool, so
    that an LLM forced to call this tool produces data in the structure
    described by the model.

    Args:
        structured_model (`Type[BaseModel]`):
            The Pydantic BaseModel class describing the expected structure
            of the output.
        tool_name (`str`, default `"generate_structured_output"`):
            The name given to the tool the LLM is forced to call.

    Returns:
        `Dict[str, Any]`: A tool definition dictionary with the keys
            "type" (always "function") and "function", the latter holding
            the name, a fixed description and the parameters (JSON schema).

    .. code-block:: python
        :caption: Example usage

        from pydantic import BaseModel

        class PersonInfo(BaseModel):
            name: str
            age: int
            email: str

        tool = _create_tool_from_base_model(PersonInfo, "extract_person")
        print(tool["function"]["name"])  # extract_person
        print(tool["type"])  # function

    .. note:: The 'title' fields are removed from the JSON schema by the
        internal ``_remove_title_field()`` function before the schema is
        embedded in the tool definition.
    """
    parameters = structured_model.model_json_schema()
    _remove_title_field(parameters)

    function_spec = {
        "name": tool_name,
        "description": "Generate the required structured output with "
        "this function",
        "parameters": parameters,
    }

    return {
        "type": "function",
        "function": function_spec,
    }
|
||||
|
||||
|
||||
def _map_text_to_uuid(text: str) -> str:
|
||||
"""Map the given text to a deterministic UUID string.
|
||||
|
||||
Args:
|
||||
text (`str`):
|
||||
The input text to be mapped to a UUID.
|
||||
|
||||
Returns:
|
||||
`str`:
|
||||
A deterministic UUID string derived from the input text.
|
||||
"""
|
||||
return str(uuid.uuid3(uuid.NAMESPACE_DNS, text))
|
||||
|
||||
|
||||
def _parse_tool_function(
    tool_func: ToolFunction,
    include_long_description: bool,
    include_var_positional: bool,
    include_var_keyword: bool,
) -> dict:
    """Build the JSON schema of a tool function from its signature and
    docstring.

    Args:
        tool_func (`ToolFunction`):
            The tool function to extract the JSON schema from.
        include_long_description (`bool`):
            Whether to include the long description in the JSON schema.
        include_var_positional (`bool`):
            Whether to include variable positional arguments in the JSON
            schema.
        include_var_keyword (`bool`):
            Whether to include variable keyword arguments in the JSON schema.

    Returns:
        `dict`:
            The extracted JSON schema.
    """
    parsed_doc = parse(tool_func.__doc__)
    param_docs = {p.arg_name: p.description for p in parsed_doc.params}

    # Function description: short part first, long part on request
    description_parts = []
    if parsed_doc.short_description is not None:
        description_parts.append(parsed_doc.short_description)

    if include_long_description and parsed_doc.long_description is not None:
        description_parts.append(parsed_doc.long_description)

    func_description = "\n".join(description_parts)

    # Collect one (annotation, Field) pair per parameter for a dynamic model
    fields = {}
    for name, param in inspect.signature(tool_func).parameters.items():
        # Skip the `self` and `cls` parameters
        if name in ["self", "cls"]:
            continue

        if param.kind == inspect.Parameter.VAR_KEYWORD:
            # Handle `**kwargs`
            if not include_var_keyword:
                continue

            if param.annotation == inspect.Parameter.empty:
                annotation = Dict[str, Any]
            else:
                annotation = Dict[str, param.annotation]  # type: ignore
            # The docstring may document the argument with or without "**"
            description = param_docs.get(
                f"**{name}",
                param_docs.get(name, None),
            )
            fallback_default = {}

        elif param.kind == inspect.Parameter.VAR_POSITIONAL:
            # Handle `*args`
            if not include_var_positional:
                continue

            if param.annotation == inspect.Parameter.empty:
                annotation = list[Any]
            else:
                annotation = list[param.annotation]  # type: ignore
            # The docstring may document the argument with or without "*"
            description = param_docs.get(
                f"*{name}",
                param_docs.get(name, None),
            )
            fallback_default = []

        else:
            if param.annotation == inspect.Parameter.empty:
                annotation = Any
            else:
                annotation = param.annotation
            description = param_docs.get(name, None)
            # `...` marks the field as required
            fallback_default = ...

        if param.default is param.empty:
            default = fallback_default
        else:
            default = param.default

        fields[name] = (
            annotation,
            Field(description=description, default=default),
        )

    base_model = create_model(
        "_StructuredOutputDynamicClass",
        __config__=ConfigDict(arbitrary_types_allowed=True),
        **fields,
    )
    params_json_schema = base_model.model_json_schema()

    # Remove the title from the json schema
    _remove_title_field(params_json_schema)

    func_json_schema: dict = {
        "type": "function",
        "function": {
            "name": tool_func.__name__,
            "parameters": params_json_schema,
        },
    }

    # The description is only attached when there is something to say
    if func_description:
        func_json_schema["function"]["description"] = func_description

    return func_json_schema
|
||||
|
||||
|
||||
def _resample_pcm_delta(
|
||||
pcm_base64: str,
|
||||
sample_rate: int,
|
||||
target_rate: int,
|
||||
) -> str:
|
||||
"""Resampling the input pcm base64 data into the target rate.
|
||||
|
||||
Args:
|
||||
pcm_base64 (`str`):
|
||||
The input base64 audio data in pcm format.
|
||||
sample_rate (`int`):
|
||||
The sampling rate of the input data.
|
||||
target_rate (`int`):
|
||||
The target rate of the input data.
|
||||
|
||||
Returns:
|
||||
`str`:
|
||||
The resampling base64 audio data in the required sampling
|
||||
rate.
|
||||
"""
|
||||
pcm_data = base64.b64decode(pcm_base64)
|
||||
|
||||
# Into numpy array first
|
||||
audio_array = np.frombuffer(pcm_data, dtype=np.int16)
|
||||
|
||||
# return directly if the same
|
||||
if sample_rate == target_rate:
|
||||
return pcm_base64
|
||||
|
||||
# compute the number of samples
|
||||
num_samples = int(len(audio_array) * target_rate / sample_rate)
|
||||
|
||||
from scipy import signal
|
||||
|
||||
# Use scipy to resample
|
||||
resampled_audio = signal.resample(audio_array, num_samples)
|
||||
|
||||
# Turn it back into bytes
|
||||
resampled_audio = np.clip(resampled_audio, -32768, 32767).astype(np.int16)
|
||||
|
||||
# into base64
|
||||
resampled_bytes = resampled_audio.tobytes()
|
||||
resampled_base64 = base64.b64encode(resampled_bytes).decode("utf-8")
|
||||
|
||||
return resampled_base64
|
||||
9
src/agentscope/_utils/_mixin.py
Normal file
9
src/agentscope/_utils/_mixin.py
Normal file
@@ -0,0 +1,9 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The mixin for agentscope."""
|
||||
|
||||
|
||||
class _MissingAttributeError(AttributeError, KeyError):
    """Raised by ``DictMixin`` when an attribute has no matching key.

    It is an ``AttributeError`` so that ``hasattr``, ``getattr`` with a
    default and ``copy.deepcopy`` work as usual, and also a ``KeyError`` so
    that code catching the previously raised exception keeps working.
    """


class DictMixin(dict):
    """The dictionary mixin that allows attribute-style access."""

    # Attribute assignment stores the value as a dictionary item
    __setattr__ = dict.__setitem__

    def __getattr__(self, name):
        """Return the item stored under ``name``.

        Called only when normal attribute lookup fails.

        Args:
            name (`str`):
                The attribute name, used as the dictionary key.

        Raises:
            `AttributeError`:
                If there is no such key. The raised exception is also an
                instance of `KeyError`.
        """
        try:
            return dict.__getitem__(self, name)
        except KeyError:
            # A bare KeyError would escape `hasattr`/`getattr(..., default)`
            # and break protocols that probe for optional attributes.
            raise _MissingAttributeError(name) from None
|
||||
4
src/agentscope/_version.py
Normal file
4
src/agentscope/_version.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The version of agentscope."""
|
||||
|
||||
__version__ = "1.0.16"
|
||||
14
src/agentscope/a2a/__init__.py
Normal file
14
src/agentscope/a2a/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The A2A related modules."""
|
||||
from ._base import AgentCardResolverBase
|
||||
from ._file_resolver import FileAgentCardResolver
|
||||
from ._well_known_resolver import WellKnownAgentCardResolver
|
||||
from ._nacos_resolver import NacosAgentCardResolver
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AgentCardResolverBase",
|
||||
"FileAgentCardResolver",
|
||||
"WellKnownAgentCardResolver",
|
||||
"NacosAgentCardResolver",
|
||||
]
|
||||
25
src/agentscope/a2a/_base.py
Normal file
25
src/agentscope/a2a/_base.py
Normal file
@@ -0,0 +1,25 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The A2A agent card resolver base class."""
|
||||
from abc import abstractmethod
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import AgentCard
|
||||
else:
|
||||
AgentCard = "a2a.types.AgentCard"
|
||||
|
||||
|
||||
class AgentCardResolverBase:
    """Base class for A2A agent card resolvers, responsible for fetching
    agent cards from various sources. Implementations must provide the
    `get_agent_card` method to retrieve the agent card.
    """

    # NOTE: this class does not derive from `abc.ABC`, so `@abstractmethod`
    # only marks the intent here; it does not prevent instantiation.
    @abstractmethod
    async def get_agent_card(self, *args: Any, **kwargs: Any) -> AgentCard:
        """Get Agent Card from the configured source.

        Args:
            *args (`Any`):
                Positional arguments; their meaning is defined by the
                implementation.
            **kwargs (`Any`):
                Keyword arguments; their meaning is defined by the
                implementation.

        Returns:
            `AgentCard`:
                The resolved agent card object.
        """
|
||||
78
src/agentscope/a2a/_file_resolver.py
Normal file
78
src/agentscope/a2a/_file_resolver.py
Normal file
@@ -0,0 +1,78 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The JSON file based A2A agent card resolver."""
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ._base import AgentCardResolverBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import AgentCard
|
||||
else:
|
||||
AgentCard = "a2a.types.AgentCard"
|
||||
|
||||
|
||||
class FileAgentCardResolver(AgentCardResolverBase):
    """Resolver that reads an AgentCard from a JSON file on disk.

    The JSON file should contain an AgentCard object with the following
    required fields:

    - name (str): The name of the agent
    - url (str): The URL of the agent
    - version (str): The version of the agent
    - capabilities (dict): The capabilities of the agent
    - default_input_modes (list[str]): Default input modes
    - default_output_modes (list[str]): Default output modes
    - skills (list): List of agent skills

    Example JSON file content:

    .. code-block:: json

        {
            "name": "RemoteAgent",
            "url": "http://localhost:8000",
            "description": "A remote A2A agent",
            "version": "1.0.0",
            "capabilities": {},
            "default_input_modes": ["text/plain"],
            "default_output_modes": ["text/plain"],
            "skills": []
        }

    """

    def __init__(
        self,
        file_path: str,
    ) -> None:
        """Remember the location of the JSON file; the file is only read
        when the card is requested.

        Args:
            file_path (`str`):
                The path to the JSON file containing the agent card.
        """
        self._file_path = file_path

    async def get_agent_card(self) -> AgentCard:
        """Load and validate the agent card stored in the JSON file.

        Returns:
            `AgentCard`:
                The agent card loaded from the file.

        Raises:
            `FileNotFoundError`:
                If the path does not exist.
            `ValueError`:
                If the path exists but is not a file.
        """
        from a2a.types import AgentCard

        card_path = Path(self._file_path)

        if not card_path.exists():
            raise FileNotFoundError(
                f"Agent card file not found: {self._file_path}",
            )

        if not card_path.is_file():
            raise ValueError(f"Path is not a file: {self._file_path}")

        raw_text = card_path.read_text(encoding="utf-8")
        return AgentCard.model_validate(json.loads(raw_text))
|
||||
98
src/agentscope/a2a/_nacos_resolver.py
Normal file
98
src/agentscope/a2a/_nacos_resolver.py
Normal file
@@ -0,0 +1,98 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The Nacos-based A2A Agent Card resolver."""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ._base import AgentCardResolverBase
|
||||
from .._logging import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import AgentCard
|
||||
from v2.nacos.common.client_config import ClientConfig
|
||||
else:
|
||||
AgentCard = "a2a.types.AgentCard"
|
||||
ClientConfig = "v2.nacos.common.client_config.ClientConfig"
|
||||
|
||||
|
||||
class NacosAgentCardResolver(AgentCardResolverBase):
    """Nacos-based A2A Agent Card resolver.

    Nacos is a dynamic service discovery, configuration and service
    management platform for building cloud native applications. Each call
    to `get_agent_card` creates a Nacos AI service client, fetches the
    agent card with it and shuts the client down again.
    """

    def __init__(
        self,
        remote_agent_name: str,
        nacos_client_config: ClientConfig,
        version: str | None = None,
    ) -> None:
        """Validate and store the lookup parameters.

        Args:
            remote_agent_name (`str`):
                Name of the remote agent in Nacos.
            nacos_client_config (`ClientConfig`):
                Nacos client configuration, where a `server_addresses`
                parameter is required.
            version (`str | None`, optional):
                Version of the agent card to fetch. Defaults to None, in
                which case the latest version is fetched.

        Raises:
            `ValueError`:
                If `remote_agent_name` or `nacos_client_config` is empty.
        """
        if not remote_agent_name:
            raise ValueError(
                "The remote_agent_name cannot be empty.",
            )

        if not nacos_client_config:
            raise ValueError(
                "The nacos_client_config cannot be None.",
            )

        self._remote_agent_name = remote_agent_name
        self._version = version
        self._nacos_client_config = nacos_client_config

    async def get_agent_card(self) -> AgentCard:
        """Fetch the agent card from Nacos.

        The Nacos SDK is imported lazily, so it is only needed when this
        method is called.

        Returns:
            `AgentCard`:
                The resolved agent card from Nacos.

        Raises:
            `ImportError`:
                If the Nacos SDK is not installed.
        """
        try:
            from v2.nacos.ai.model.ai_param import GetAgentCardParam
            from v2.nacos.ai.nacos_ai_service import NacosAIService
        except ImportError as e:
            raise ImportError(
                "Please install the nacos sdk by running `pip install "
                "nacos-sdk-python>=3.0.0` first.",
            ) from e

        nacos_service = None
        try:
            nacos_service = await NacosAIService.create_ai_service(
                self._nacos_client_config,
            )
            await nacos_service.start()

            card = await nacos_service.get_agent_card(
                GetAgentCardParam(
                    agent_name=self._remote_agent_name,
                    version=self._version,
                ),
            )
            return card

        finally:
            # Close the Nacos client to free resources; a failing shutdown
            # is logged and does not replace the outcome of the lookup.
            if nacos_service:
                try:
                    await nacos_service.shutdown()
                except Exception as shutdown_error:
                    logger.warning(
                        "Failed to shutdown Nacos client: %s",
                        str(shutdown_error),
                    )
|
||||
90
src/agentscope/a2a/_well_known_resolver.py
Normal file
90
src/agentscope/a2a/_well_known_resolver.py
Normal file
@@ -0,0 +1,90 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The A2A well-known agent card resolver."""
|
||||
from typing import TYPE_CHECKING
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from ._base import AgentCardResolverBase
|
||||
from .._logging import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import AgentCard
|
||||
else:
|
||||
AgentCard = "a2a.types.AgentCard"
|
||||
|
||||
|
||||
class WellKnownAgentCardResolver(AgentCardResolverBase):
    """Resolve an `AgentCard` by fetching it over HTTP from a URL."""

    def __init__(
        self,
        base_url: str,
        agent_card_path: str | None = None,
    ) -> None:
        """Remember where the agent card should be fetched from.

        Args:
            base_url (`str`):
                The URL to resolve the agent card from. It must contain a
                scheme and a network location; this is checked when
                `get_agent_card` is called, not here.
            agent_card_path (`str | None`, optional):
                The agent card path given to `A2ACardResolver`. When
                `None`, `AGENT_CARD_WELL_KNOWN_PATH` from `a2a.utils` is
                used instead.
        """
        self._base_url = base_url
        self._agent_card_path = agent_card_path

    async def get_agent_card(self) -> AgentCard:
        """Download the agent card via `A2ACardResolver`.

        The configured URL is split into `scheme://netloc`, which becomes
        the resolver's base URL, and its path component, which is passed as
        `relative_card_path` to the resolver.

        Returns:
            `AgentCard`:
                The agent card returned by the resolver.

        Raises:
            `RuntimeError`:
                If anything fails while resolving, including an invalid
                URL. The original exception is chained as the cause.
        """
        import httpx
        from a2a.client import A2ACardResolver
        from a2a.utils import AGENT_CARD_WELL_KNOWN_PATH

        resolver_name = self.__class__.__name__

        try:
            url_parts = urlparse(self._base_url)
            if not (url_parts.scheme and url_parts.netloc):
                logger.error(
                    "[%s] Invalid URL format: %s",
                    resolver_name,
                    self._base_url,
                )
                # Raised inside the try block on purpose: it is wrapped in
                # the RuntimeError below like every other failure.
                raise ValueError(
                    f"Invalid URL format: {self._base_url}",
                )

            origin = f"{url_parts.scheme}://{url_parts.netloc}"
            path_in_url = url_parts.path

            # Fall back to the default path only when none was configured;
            # an explicitly configured empty string is kept as-is.
            if self._agent_card_path is not None:
                card_path = self._agent_card_path
            else:
                card_path = AGENT_CARD_WELL_KNOWN_PATH

            # The context manager closes the HTTP client on every exit path.
            request_timeout = httpx.Timeout(timeout=600)
            async with httpx.AsyncClient(
                timeout=request_timeout,
            ) as http_client:
                card_resolver = A2ACardResolver(
                    httpx_client=http_client,
                    base_url=origin,
                    agent_card_path=card_path,
                )
                return await card_resolver.get_agent_card(
                    relative_card_path=path_in_url,
                )
        except Exception as err:
            logger.error(
                "[%s] Failed to resolve agent card from URL %s: %s",
                resolver_name,
                self._base_url,
                err,
            )
            raise RuntimeError(
                f"Failed to resolve AgentCard from URL "
                f"{self._base_url}: {err}",
            ) from err
|
||||
28
src/agentscope/agent/__init__.py
Normal file
28
src/agentscope/agent/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The agent base class."""
|
||||
from ._agent_base import AgentBase
|
||||
from ._react_agent_base import ReActAgentBase
|
||||
from ._react_agent import ReActAgent
|
||||
from ._user_input import (
|
||||
UserInputBase,
|
||||
UserInputData,
|
||||
TerminalUserInput,
|
||||
StudioUserInput,
|
||||
)
|
||||
from ._user_agent import UserAgent
|
||||
from ._a2a_agent import A2AAgent
|
||||
from ._realtime_agent import RealtimeAgent
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AgentBase",
|
||||
"ReActAgentBase",
|
||||
"ReActAgent",
|
||||
"UserInputData",
|
||||
"UserInputBase",
|
||||
"TerminalUserInput",
|
||||
"StudioUserInput",
|
||||
"UserAgent",
|
||||
"A2AAgent",
|
||||
"RealtimeAgent",
|
||||
]
|
||||
288
src/agentscope/agent/_a2a_agent.py
Normal file
288
src/agentscope/agent/_a2a_agent.py
Normal file
@@ -0,0 +1,288 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""A2A agent implementation for AgentScope.
|
||||
|
||||
This module provides the A2A (Agent-to-Agent) protocol implementation,
|
||||
enabling AgentScope agents to communicate with remote agents using the
|
||||
A2A standard protocol.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, Any, Type
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._agent_base import AgentBase
|
||||
from ..message import Msg
|
||||
from ..formatter import A2AChatFormatter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import AgentCard
|
||||
from a2a.client import ClientConfig, Consumer
|
||||
from a2a.client.client_factory import TransportProducer
|
||||
else:
|
||||
AgentCard = "a2a.types.AgentCard"
|
||||
ClientConfig = "a2a.client.ClientConfig"
|
||||
Consumer = "a2a.client.Consumer"
|
||||
TransportProducer = "a2a.client.client_factory.TransportProducer"
|
||||
|
||||
|
||||
class A2AAgent(AgentBase):
    """An A2A agent implementation in AgentScope, which supports

    - Communication with remote agents using the A2A protocol
    - Bidirectional message conversion between AgentScope and A2A formats
    - Task lifecycle management with streaming and polling
    - Artifact handling and status tracking

    .. note:: Due to the limitation of A2A protocol. The A2AAgent class

     - Only support chatbot-scenario (a user and an assistant) interactions.
       To support multi-agent interactions requires the server side to
       handle the `name` field in the A2A messages properly.
     - Does not support structured output in `reply()` method due to the
       lack of structured output support in A2A protocol.
     - Stores observed messages locally and merges them with input messages
       when `reply()` is called. Observed messages are cleared after a
       reply has been processed.
    """

    def __init__(
        self,
        agent_card: AgentCard,
        client_config: ClientConfig | None = None,
        consumers: list[Consumer] | None = None,
        additional_transport_producers: dict[str, TransportProducer]
        | None = None,
    ) -> None:
        """Initialize the A2A agent instance by the given agent card.

        Args:
            agent_card (`AgentCard`):
                The agent card that contains the information about the
                remote agent, such as its URL and capabilities.
            client_config (`ClientConfig | None`, optional):
                The configuration for the A2A client, including transport
                preferences and streaming options. When falsy, a default
                `ClientConfig` with an `httpx.AsyncClient` using a 600
                second timeout is created.
            consumers (`list[Consumer] | None`, optional):
                The list of consumers for handling A2A client events.
                These intercept request/response flows for logging,
                metrics, and security.
            additional_transport_producers (`dict[str, TransportProducer] | \
            None`, optional):
                Additional transport producers for creating A2A clients
                with specific transport protocols. Each entry is registered
                on the client factory under its key.

        Raises:
            `ValueError`:
                If `agent_card` is not an `AgentCard` instance.
        """
        super().__init__()

        # Imported here so the a2a SDK is only needed when an A2AAgent is
        # actually constructed. These shadow the module-level placeholders.
        from a2a.types import AgentCard
        from a2a.client import ClientConfig, ClientFactory

        if not isinstance(agent_card, AgentCard):
            raise ValueError(
                f"agent_card must be an instance of AgentCard, "
                f"got {type(agent_card)}",
            )

        self.name: str = agent_card.name
        self.agent_card = agent_card

        # Create the client factory so that we can create clients later
        # in reply()
        self._a2a_client_factory = ClientFactory(
            config=client_config
            or ClientConfig(
                httpx_client=httpx.AsyncClient(
                    timeout=httpx.Timeout(timeout=600),
                ),
            ),
            consumers=consumers,
        )

        # Register additional transport producers if provided
        if additional_transport_producers:
            for label, producer in additional_transport_producers.items():
                self._a2a_client_factory.register(
                    label,
                    producer,
                )

        # The variables to store observed messages
        self._observed_msgs: list[Msg] = []

        # The formatter used for message conversion
        self.formatter = A2AChatFormatter()

    def state_dict(self) -> dict:
        """Get the state dictionary of the A2A agent.

        Returns:
            `dict`:
                A dictionary with the single key `_observed_msgs`, holding
                the observed messages converted with `Msg.to_dict`.
        """
        return {
            "_observed_msgs": [msg.to_dict() for msg in self._observed_msgs],
        }

    def load_state_dict(self, state_dict: dict, strict: bool = True) -> None:
        """Load the state dictionary into the agent.

        The dictionary is validated before anything is assigned, so a
        rejected `state_dict` leaves the agent unchanged.

        Args:
            state_dict (`dict`):
                The state dictionary to load. The only recognised key is
                `_observed_msgs`.
            strict (`bool`, defaults to `True`):
                If `True`, the `_observed_msgs` key is required and any
                other key is rejected. If `False`, a missing
                `_observed_msgs` key is skipped (the currently stored
                observed messages are kept) and other keys are ignored.

        Raises:
            `KeyError`:
                When `strict` is `True` and `_observed_msgs` is missing or
                an unexpected key is present.
        """
        has_observed = "_observed_msgs" in state_dict

        if strict:
            if not has_observed:
                raise KeyError(
                    "_observed_msgs key not found in state_dict",
                )
            for key in state_dict:
                if key != "_observed_msgs":
                    raise KeyError(f"Unexpected key {key} in state_dict")

        if has_observed:
            self._observed_msgs = [
                Msg.from_dict(d) for d in state_dict["_observed_msgs"]
            ]

    async def observe(self, msg: Msg | list[Msg] | None) -> None:
        """Receive the given message(s) without generating a reply.

        The observed messages are stored and will be merged with the
        input messages when `reply` is called. After `reply` completes,
        the stored messages will be cleared.

        Args:
            msg (`Msg | list[Msg] | None`):
                The message(s) to be observed. If None, no action is taken.

        Raises:
            `TypeError`:
                If `msg` is neither a `Msg` nor a list containing only
                `Msg` objects.
        """
        if msg is None:
            return

        if isinstance(msg, Msg):
            self._observed_msgs.append(msg)
        elif isinstance(msg, list) and all(isinstance(m, Msg) for m in msg):
            self._observed_msgs.extend(msg)
        else:
            raise TypeError(
                f"msg must be a Msg or a list of Msg, got {type(msg)}",
            )

    async def reply(
        self,
        msg: Msg | list[Msg] | None = None,
        **kwargs: Any,
    ) -> Msg:
        """Send message(s) to the remote A2A agent and receive a response.

        .. note:: This method merges any previously observed messages with
         the input messages, sends them to the remote agent, and clears the
         observed messages after processing.

        .. note:: The A2A protocol does not support structured output, so
         the `structured_model` parameter is not supported in this method.

        Args:
            msg (`Msg | list[Msg] | None`, optional):
                The message(s) to send to the remote agent. Can be a single
                Msg, a list of Msgs, or None. If None, only observed
                messages will be sent. Defaults to None.

        Returns:
            `Msg`:
                The last message produced while consuming the remote
                agent's response stream: either a converted A2A message or
                the last message converted from a task update.

        Raises:
            `ValueError`:
                If `structured_model` is passed as a keyword argument, or
                if the response stream yields no message.
        """
        if "structured_model" in kwargs:
            raise ValueError(
                "structured_model is not supported in A2AAgent.reply() "
                "due to the lack of structured output support in A2A "
                "protocol.",
            )

        from a2a.types import Message as A2AMessage

        # Merge observed messages with input messages. The input is
        # appended to the observed-message list itself, and that list is
        # only cleared after the response stream has been consumed, so the
        # messages stay queued if the call raises or is cancelled.
        msgs_list = self._observed_msgs

        if msg is not None:
            if isinstance(msg, Msg):
                msgs_list.append(msg)
            else:
                msgs_list.extend(msg)

        # Create A2A client and send message
        client = self._a2a_client_factory.create(
            card=self.agent_card,
        )

        # Convert Msg objects into A2A Message object, dropping falsy items
        a2a_message = await self.formatter.format([m for m in msgs_list if m])

        response_msg = None
        async for item in client.send_message(a2a_message):
            if isinstance(item, A2AMessage):
                response_msg = await self.formatter.format_a2a_message(
                    self.name,
                    item,
                )
                await self.print(response_msg)

            elif isinstance(item, tuple):
                # Only the first element (the task) of the pair is used
                task, _ = item

                if task is not None:
                    for task_msg in await self.formatter.format_a2a_task(
                        self.name,
                        task,
                    ):
                        await self.print(task_msg)
                        response_msg = task_msg

        # Clear the observed messages after processing
        self._observed_msgs.clear()

        if response_msg:
            return response_msg

        raise ValueError(
            "No response received from remote agent",
        )

    # pylint: disable=unused-argument
    async def handle_interrupt(
        self,
        msg: Msg | list[Msg] | None = None,
        structured_model: Type[BaseModel] | None = None,
    ) -> Msg:
        """The post-processing logic when the reply is interrupted by the
        user or something else.

        Args:
            msg (`Msg | list[Msg] | None`, optional):
                Unused; accepted for signature compatibility.
            structured_model (`Type[BaseModel] | None`, optional):
                Unused; accepted for signature compatibility.

        Returns:
            `Msg`:
                A fixed assistant message whose metadata carries
                `_is_interrupted: True`.
        """

        response_msg = Msg(
            self.name,
            "I noticed that you have interrupted me. What can I "
            "do for you?",
            "assistant",
            metadata={
                # Expose this field to indicate the interruption
                "_is_interrupted": True,
            },
        )

        await self.print(response_msg, True)

        # Add to observed messages for context in next reply
        self._observed_msgs.append(response_msg)

        return response_msg
|
||||
736
src/agentscope/agent/_agent_base.py
Normal file
736
src/agentscope/agent/_agent_base.py
Normal file
@@ -0,0 +1,736 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The agent base class in agentscope."""
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
from asyncio import Task, Queue
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
from typing import Callable, Any
|
||||
import base64
|
||||
import shortuuid
|
||||
import numpy as np
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from ._agent_meta import _AgentMeta
|
||||
from .._logging import logger
|
||||
from ..module import StateModule
|
||||
from ..message import (
|
||||
Msg,
|
||||
AudioBlock,
|
||||
ToolUseBlock,
|
||||
ToolResultBlock,
|
||||
ImageBlock,
|
||||
VideoBlock,
|
||||
)
|
||||
from ..types import AgentHookTypes
|
||||
|
||||
|
||||
class AgentBase(StateModule, metaclass=_AgentMeta):
|
||||
"""Base class for asynchronous agents."""
|
||||
|
||||
id: str
|
||||
"""The agent's unique identifier, generated using shortuuid."""
|
||||
|
||||
supported_hook_types: list[str] = [
|
||||
"pre_reply",
|
||||
"post_reply",
|
||||
"pre_print",
|
||||
"post_print",
|
||||
"pre_observe",
|
||||
"post_observe",
|
||||
]
|
||||
"""Supported hook types for the agent base class."""
|
||||
|
||||
_class_pre_reply_hooks: dict[
|
||||
str,
|
||||
Callable[
|
||||
[
|
||||
"AgentBase", # self
|
||||
dict[str, Any], # kwargs
|
||||
],
|
||||
dict[str, Any] | None, # The modified kwargs or None
|
||||
],
|
||||
] = OrderedDict()
|
||||
"""The class-level hook functions that will be called before the reply
|
||||
function, taking `self` object, the input arguments as input, and
|
||||
generating the modified arguments (if needed). Then input arguments of the
|
||||
reply function will be re-organized into a keyword arguments dictionary.
|
||||
If the one hook returns a new dictionary, the modified arguments will be
|
||||
passed to the next hook or the original reply function."""
|
||||
|
||||
_class_post_reply_hooks: dict[
|
||||
str,
|
||||
Callable[
|
||||
[
|
||||
"AgentBase", # self
|
||||
dict[str, Any], # kwargs
|
||||
Msg, # output, the output message
|
||||
],
|
||||
Msg | None,
|
||||
],
|
||||
] = OrderedDict()
|
||||
"""The class-level hook functions that will be called after the reply
|
||||
function, which takes the `self` object and deep copied
|
||||
positional and keyword arguments (args and kwargs), and the output message
|
||||
as input. If the hook returns a message, the new message will be passed
|
||||
to the next hook or the original reply function. Otherwise, the original
|
||||
output will be passed instead."""
|
||||
|
||||
_class_pre_print_hooks: dict[
|
||||
str,
|
||||
Callable[
|
||||
[
|
||||
"AgentBase", # self
|
||||
dict[str, Any], # kwargs
|
||||
],
|
||||
dict[str, Any] | None, # The modified kwargs or None
|
||||
],
|
||||
] = OrderedDict()
|
||||
"""The class-level hook functions that will be called before printing,
|
||||
which takes the `self` object, a deep copied arguments dictionary as input,
|
||||
and output the modified arguments (if needed). """
|
||||
|
||||
_class_post_print_hooks: dict[
|
||||
str,
|
||||
Callable[
|
||||
[
|
||||
"AgentBase", # self
|
||||
dict[str, Any], # kwargs
|
||||
Any, # output, `None` if no output
|
||||
],
|
||||
Any,
|
||||
],
|
||||
] = OrderedDict()
|
||||
"""The class-level hook functions that will be called after the speak
|
||||
function, which takes the `self` object as input."""
|
||||
|
||||
_class_pre_observe_hooks: dict[
|
||||
str,
|
||||
Callable[
|
||||
[
|
||||
"AgentBase", # self
|
||||
dict[str, Any], # kwargs
|
||||
],
|
||||
dict[str, Any] | None, # The modified kwargs or None
|
||||
],
|
||||
] = OrderedDict()
|
||||
"""The class-level hook functions that will be called before the observe
|
||||
function, which takes the `self` object and a deep copied input
|
||||
arguments dictionary as input. To change the input arguments, the hook
|
||||
function needs to output the modified arguments dictionary, which will be
|
||||
used as the input of the next hook function or the original observe
|
||||
function."""
|
||||
|
||||
_class_post_observe_hooks: dict[
|
||||
str,
|
||||
Callable[
|
||||
[
|
||||
"AgentBase", # self
|
||||
dict[str, Any], # kwargs
|
||||
None, # The output, `None` if no output
|
||||
],
|
||||
None,
|
||||
],
|
||||
] = OrderedDict()
|
||||
"""The class-level hook functions that will be called after the observe
|
||||
function, which takes the `self` object as input."""
|
||||
|
||||
def __init__(self) -> None:
    """Set up the per-instance state of the agent.

    Creates a fresh identifier, the reply-task bookkeeping, empty
    instance-level hook registries, the streaming print cache, the
    subscriber map, and the console/queue output switches.
    """
    super().__init__()

    self.id = shortuuid.uuid()

    # Bookkeeping for the reply that is currently running, if any
    self._reply_task: Task | None = None
    self._reply_id: str | None = None

    # One empty, ordered registry per instance-level hook type
    for hook_type in (
        "pre_print",
        "post_print",
        "pre_reply",
        "post_reply",
        "pre_observe",
        "post_observe",
    ):
        setattr(self, f"_instance_{hook_type}_hooks", OrderedDict())

    # Per-message cache used while printing streamed output. It keeps the
    # text already printed and the audio player plus the audio data
    # already played,
    # e.g. {"text": "xxx", "audio": (stream_obj, "{base64_data}")}
    self._stream_prefix = {}

    # Agents that receive this agent's reply through their `observe`
    # method, keyed by MsgHub id.
    self._subscribers: dict[str, list[AgentBase]] = {}

    # Console output can be switched off through an environment variable,
    # e.g. in a production environment. Only the value "true" (in any
    # letter case) disables it.
    console_flag = os.getenv(
        "AGENTSCOPE_DISABLE_CONSOLE_OUTPUT",
        "false",
    )
    self._disable_console_output: bool = console_flag.lower() == "true"

    # The message queue used to export printed messages as a generator;
    # it is disabled and unset by default.
    self._disable_msg_queue: bool = True
    self.msg_queue = None
|
||||
|
||||
async def observe(self, msg: Msg | list[Msg] | None) -> None:
    """Receive message(s) without replying; subclasses must override this.

    Args:
        msg (`Msg | list[Msg] | None`):
            The message(s) to be observed.

    Raises:
        `NotImplementedError`:
            Always, in this base implementation.
    """
    class_name = self.__class__.__name__
    raise NotImplementedError(
        f"The observe function is not implemented in {class_name} class.",
    )
|
||||
|
||||
async def reply(self, *args: Any, **kwargs: Any) -> Msg:
    """Generate a reply from the current state and the given arguments.

    This is the agent's main logic and must be overridden by subclasses.

    Raises:
        `NotImplementedError`:
            Always, in this base implementation.
    """
    class_name = self.__class__.__name__
    raise NotImplementedError(
        f"The reply function is not implemented in {class_name} class.",
    )
|
||||
|
||||
async def print(
    self,
    msg: Msg,
    last: bool = True,
    speech: AudioBlock | list[AudioBlock] | None = None,
) -> None:
    """Display a message on the console and/or export it to the queue.

    Args:
        msg (`Msg`):
            The message object to be printed.
        last (`bool`, defaults to `True`):
            Whether this is the last one in streaming messages. For
            non-streaming message, this should always be `True`.
        speech (`AudioBlock | list[AudioBlock] | None`, optional):
            The audio content block(s) to be played along with the
            message.
    """
    if not self._disable_msg_queue:
        # Export a deep copy so later changes to `msg` do not alter the
        # queued item.
        queued_item = (deepcopy(msg), last, speech)
        await self.msg_queue.put(queued_item)
        # Hand control back to the event loop so queue consumers get a
        # chance to run before this producer continues.
        await asyncio.sleep(0)

    if self._disable_console_output:
        return

    # Text and thinking blocks are accumulated here and printed together
    rendered_parts = []

    for block in msg.get_content_blocks():
        block_type = block["type"]
        if block_type == "text":
            label = msg.name
        elif block_type == "thinking":
            label = f"{msg.name}(thinking)"
        else:
            # Every other block type is only printed with the final chunk
            if last:
                self._print_last_block(block, msg)
            continue

        # For these two types the content is stored under the type name
        self._print_text_block(
            msg.id,
            name_prefix=label,
            text_content=block[block_type],
            thinking_and_text_to_print=rendered_parts,
        )

    # Play the audio block(s), if any were given
    if isinstance(speech, list):
        audio_blocks = speech
    elif isinstance(speech, dict):
        audio_blocks = [speech]
    else:
        audio_blocks = []

    for audio_block in audio_blocks:
        self._process_audio_block(msg.id, audio_block)

    # The streaming cache is only released with the last chunk
    if not (last and msg.id in self._stream_prefix):
        return

    if "audio" in self._stream_prefix[msg.id]:
        # Close the audio player cached for this message
        player, _ = self._stream_prefix[msg.id]["audio"]
        player.close()

    finished = self._stream_prefix.pop(msg.id)
    # Terminate the line unless the printed text already ended with one
    if "text" in finished and not finished["text"].endswith("\n"):
        print()
|
||||
|
||||
def _process_audio_block(
    self,
    msg_id: str,
    audio_block: AudioBlock,
) -> None:
    """Play the audio carried by an audio block.

    A `url` source is downloaded, parsed as a WAV file and played to
    completion; any failure in that path is logged, not raised. A `base64`
    source is treated as accumulated streaming data: only the part that
    has not been played yet for `msg_id` is decoded as 16-bit PCM and
    written to a cached output stream.

    Args:
        msg_id (`str`):
            The unique identifier of the message
        audio_block (`AudioBlock`):
            The audio content block

    Raises:
        `ValueError`:
            If the block has no `source` field or the source type is
            neither `url` nor `base64`.
    """
    if "source" not in audio_block:
        raise ValueError(
            "The audio block must contain the 'source' field.",
        )

    if audio_block["source"]["type"] == "url":
        import urllib.request
        import wave
        import sounddevice as sd

        url = audio_block["source"]["url"]
        try:
            with urllib.request.urlopen(url) as response:
                audio_data = response.read()

            with wave.open(io.BytesIO(audio_data), "rb") as wf:
                samplerate = wf.getframerate()
                n_channels = wf.getnchannels()
                sample_width = wf.getsampwidth()
                audio_frames = wf.readframes(wf.getnframes())

            # Decode with the sample width the file declares instead of
            # assuming 16-bit samples (8-bit WAV data is unsigned).
            sample_dtypes = {1: np.uint8, 2: np.int16, 4: np.int32}
            if sample_width not in sample_dtypes:
                raise ValueError(
                    f"Unsupported WAV sample width: {sample_width} bytes",
                )
            audio_np = np.frombuffer(
                audio_frames,
                dtype=sample_dtypes[sample_width],
            )

            # WAV frames interleave the channels. Shape the samples as
            # (frames, channels) so multi-channel audio is not played back
            # as one long mono stream.
            if n_channels > 1:
                audio_np = audio_np.reshape(-1, n_channels)

            # Play audio and block until playback has finished
            sd.play(audio_np, samplerate)
            sd.wait()

        except Exception as e:
            logger.error(
                "Failed to play audio from url %s: %s",
                url,
                str(e),
            )

    elif audio_block["source"]["type"] == "base64":
        data = audio_block["source"]["data"]

        if msg_id not in self._stream_prefix:
            self._stream_prefix[msg_id] = {}

        audio_prefix = self._stream_prefix[msg_id].get("audio", None)

        import sounddevice as sd

        # The player and the prefix data is cached for streaming audio
        if audio_prefix:
            player, audio_prefix_data = audio_prefix
        else:
            player = sd.OutputStream(
                samplerate=24000,
                channels=1,
                dtype=np.float32,
                blocksize=1024,
                latency="low",
            )
            player.start()
            audio_prefix_data = ""
            # Register the new player right away so it stays reachable
            # for clean-up even if decoding or writing below fails.
            self._stream_prefix[msg_id]["audio"] = (
                player,
                audio_prefix_data,
            )

        # Only the data beyond what was already played is new
        new_audio_data = data[len(audio_prefix_data) :]
        if new_audio_data:
            audio_bytes = base64.b64decode(new_audio_data)
            audio_np = np.frombuffer(audio_bytes, dtype=np.int16)
            # Scale 16-bit integer samples to floats for the stream
            audio_float = audio_np.astype(np.float32) / 32768.0

            # Write to the audio output stream
            player.write(audio_float)

        # save the player and the prefix data
        self._stream_prefix[msg_id]["audio"] = (
            player,
            data,
        )

    else:
        raise ValueError(
            "Unsupported audio source type: "
            f"{audio_block['source']['type']}",
        )
|
||||
|
||||
def _print_text_block(
|
||||
self,
|
||||
msg_id: str,
|
||||
name_prefix: str,
|
||||
text_content: str,
|
||||
thinking_and_text_to_print: list[str],
|
||||
) -> None:
|
||||
"""Print the text block and thinking block content.
|
||||
|
||||
Args:
|
||||
msg_id (`str`):
|
||||
The unique identifier of the message
|
||||
name_prefix (`str`):
|
||||
The prefix for the message, e.g. "{name}: " for text block and
|
||||
"{name}(thinking): " for thinking block.
|
||||
text_content (`str`):
|
||||
The textual content to be printed.
|
||||
thinking_and_text_to_print (`list[str]`):
|
||||
A list of textual content to be printed together. Here we
|
||||
gather the text and thinking blocks to print them together.
|
||||
"""
|
||||
thinking_and_text_to_print.append(
|
||||
f"{name_prefix}: {text_content}",
|
||||
)
|
||||
# The accumulated text and thinking blocks to print
|
||||
to_print = "\n".join(thinking_and_text_to_print)
|
||||
|
||||
# The text prefix that has been printed
|
||||
if msg_id not in self._stream_prefix:
|
||||
self._stream_prefix[msg_id] = {}
|
||||
|
||||
text_prefix = self._stream_prefix[msg_id].get("text", "")
|
||||
|
||||
# Only print when there is new text content
|
||||
if len(to_print) > len(text_prefix):
|
||||
print(to_print[len(text_prefix) :], end="")
|
||||
|
||||
# Save the printed text prefix
|
||||
self._stream_prefix[msg_id]["text"] = to_print
|
||||
|
||||
def _print_last_block(
    self,
    block: ToolUseBlock
    | ToolResultBlock
    | ImageBlock
    | VideoBlock
    | AudioBlock,
    msg: Msg,
) -> None:
    """Print a content block that is neither text nor thinking.

    Image, video and audio blocks are skipped; any other block is printed
    as indented JSON.

    Args:
        block (`ToolUseBlock | ToolResultBlock | ImageBlock | VideoBlock \
        | AudioBlock`):
            The content block to be printed
        msg (`Msg`):
            The message object
    """
    # TODO: We should consider how to handle the multimodal blocks in the
    #  terminal, since the base64 data may be too long to display.
    if block.get("type") in ["image", "video", "audio"]:
        return

    rendered = json.dumps(block, indent=4, ensure_ascii=False)
    already_printed = self._stream_prefix.get(msg.id, {}).get("text", "")

    if not already_printed:
        # Nothing was printed for this message yet: lead with the name
        print(f"{msg.name}: {rendered}")
        return

    # Start on a new line unless the printed text already ended with one
    separator = "" if already_printed.endswith("\n") else "\n"
    print(f"{separator}{rendered}")
|
||||
|
||||
async def __call__(self, *args: Any, **kwargs: Any) -> Msg:
    """Run `reply` with the given arguments and broadcast its result.

    If the running task is cancelled while replying, `handle_interrupt`
    is called with the same arguments and its result is used instead.

    Returns:
        `Msg`:
            The message produced by `reply` or by `handle_interrupt`.
    """
    self._reply_id = shortuuid.uuid()

    outcome: Msg | None = None
    try:
        # Remember the running task so `interrupt` can cancel it
        self._reply_task = asyncio.current_task()
        outcome = await self.reply(*args, **kwargs)

    except asyncio.CancelledError:
        # Cancellation is how an interruption reaches the reply
        outcome = await self.handle_interrupt(*args, **kwargs)

    finally:
        # Subscribers only hear about replies that were actually produced
        if outcome:
            await self._broadcast_to_subscribers(outcome)
        self._reply_task = None

    return outcome
|
||||
|
||||
async def _broadcast_to_subscribers(
    self,
    msg: Msg | list[Msg] | None,
) -> None:
    """Pass the message to the `observe` method of every subscriber.

    Subscribers are awaited one after another, group by group.

    Args:
        msg (`Msg | list[Msg] | None`):
            The message(s) handed to each subscriber.
    """
    every_subscriber = (
        agent for group in self._subscribers.values() for agent in group
    )
    for agent in every_subscriber:
        await agent.observe(msg)
|
||||
|
||||
async def handle_interrupt(
    self,
    *args: Any,
    **kwargs: Any,
) -> Msg:
    """Post-process a reply that was interrupted; subclasses must
    override this.

    Raises:
        `NotImplementedError`:
            Always, in this base implementation.
    """
    class_name = self.__class__.__name__
    raise NotImplementedError(
        f"The handle_interrupt function is not implemented in {class_name}",
    )
|
||||
|
||||
async def interrupt(self, msg: Msg | list[Msg] | None = None) -> None:
    """Cancel the reply task that is currently running, if there is one.

    Args:
        msg (`Msg | list[Msg] | None`, optional):
            Passed on as the argument of the task's `cancel` call.
    """
    running = self._reply_task
    if not running or running.done():
        # No reply in flight, or it has already finished
        return
    running.cancel(msg)
|
||||
|
||||
def register_instance_hook(
    self,
    hook_type: AgentHookTypes,
    hook_name: str,
    hook: Callable,
) -> None:
    """Register a hook on this agent instance only.

    Args:
        hook_type (`str`):
            The type of the hook, indicating where the hook is to be
            triggered. It selects the `_instance_{hook_type}_hooks`
            registry.
        hook_name (`str`):
            The name of the hook. If the name is already registered, the
            hook will be overwritten.
        hook (`Callable`):
            The hook function.

    Raises:
        `TypeError`:
            If `self` is not an `AgentBase` instance.
    """
    if not isinstance(self, AgentBase):
        raise TypeError(
            "The register_instance_hook method should be called on an "
            f"instance of AsyncAgentBase, but got {self} of "
            f"type {type(self)}.",
        )

    registry = getattr(self, f"_instance_{hook_type}_hooks")
    registry[hook_name] = hook
|
||||
|
||||
def remove_instance_hook(
    self,
    hook_type: AgentHookTypes,
    hook_name: str,
) -> None:
    """Remove an instance-level hook from the agent instance.

    Args:
        hook_type (`AgentHookTypes`):
            The type of the hook, indicating where the hook is to be
            triggered.
        hook_name (`str`):
            The name of the hook to remove.

    Raises:
        `TypeError`:
            If ``self`` is not an `AgentBase` instance.
        `ValueError`:
            If no hook named ``hook_name`` is registered for
            ``hook_type`` on this instance.
    """
    if not isinstance(self, AgentBase):
        # The message names the class that is actually checked above.
        raise TypeError(
            "The remove_instance_hook method should be called on an "
            f"instance of AgentBase, but got {self} of "
            f"type {type(self)}.",
        )
    hooks = getattr(self, f"_instance_{hook_type}_hooks")
    if hook_name not in hooks:
        raise ValueError(
            f"Hook '{hook_name}' not found in '{hook_type}' hooks of "
            f"{self.__class__.__name__} instance.",
        )
    del hooks[hook_name]
|
||||
|
||||
@classmethod
def register_class_hook(
    cls,
    hook_type: AgentHookTypes,
    hook_name: str,
    hook: Callable,
) -> None:
    """Attach a hook at class level so that it applies to all instances.

    Args:
        hook_type (`AgentHookTypes`):
            Where the hook is triggered; must be one of
            ``cls.supported_hook_types``.
        hook_name (`str`):
            The key under which the hook is stored. An existing hook with
            the same name is replaced.
        hook (`Callable`):
            The hook function.
    """
    is_supported = hook_type in cls.supported_hook_types
    assert is_supported, f"Invalid hook type: {hook_type}"

    getattr(cls, f"_class_{hook_type}_hooks")[hook_name] = hook
|
||||
|
||||
@classmethod
def remove_class_hook(
    cls,
    hook_type: AgentHookTypes,
    hook_name: str,
) -> None:
    """Detach a previously registered class-level hook.

    Args:
        hook_type (`AgentHookTypes`):
            Where the hook is triggered; must be one of
            ``cls.supported_hook_types``.
        hook_name (`str`):
            The name of the hook to remove.

    Raises:
        `ValueError`:
            If no hook with that name exists for ``hook_type``.
    """
    is_supported = hook_type in cls.supported_hook_types
    assert is_supported, f"Invalid hook type: {hook_type}"

    registry = getattr(cls, f"_class_{hook_type}_hooks")
    if hook_name not in registry:
        raise ValueError(
            f"Hook '{hook_name}' not found in '{hook_type}' hooks of "
            f"{cls.__name__} class.",
        )
    del registry[hook_name]
|
||||
|
||||
@classmethod
def clear_class_hooks(
    cls,
    hook_type: AgentHookTypes | None = None,
) -> None:
    """Empty the class-level hook registries.

    Args:
        hook_type (`AgentHookTypes`, optional):
            The single hook type to clear. When omitted, the registry of
            every type in ``cls.supported_hook_types`` is cleared.
    """
    if hook_type is None:
        targets = cls.supported_hook_types
    else:
        is_supported = hook_type in cls.supported_hook_types
        assert is_supported, f"Invalid hook type: {hook_type}"
        targets = [hook_type]

    for typ in targets:
        getattr(cls, f"_class_{typ}_hooks").clear()
|
||||
|
||||
def clear_instance_hooks(
    self,
    hook_type: AgentHookTypes | None = None,
) -> None:
    """Empty the instance-level hook registries.

    Args:
        hook_type (`AgentHookTypes`, optional):
            The single hook type to clear. When omitted, the registry of
            every type in ``self.supported_hook_types`` is cleared.

    Raises:
        `ValueError`:
            If a registry attribute is missing on the instance, with a
            message asking to call ``super().__init__()``.
    """
    if hook_type is None:
        targets = self.supported_hook_types
    else:
        is_supported = hook_type in self.supported_hook_types
        assert is_supported, f"Invalid hook type: {hook_type}"
        targets = [hook_type]

    # Registries are checked and cleared one by one, in order, so earlier
    # ones are already emptied if a later one turns out to be missing.
    for typ in targets:
        attr_name = f"_instance_{typ}_hooks"
        if not hasattr(self, attr_name):
            raise ValueError(
                f"Call super().__init__() in the constructor "
                f"to initialize the instance-level hooks for "
                f"{self.__class__.__name__}.",
            )
        getattr(self, attr_name).clear()
|
||||
|
||||
def reset_subscribers(
    self,
    msghub_name: str,
    subscribers: list["AgentBase"],
) -> None:
    """Replace the subscriber list kept for one MsgHub.

    The agent itself is filtered out of the stored list.

    Args:
        msghub_name (`str`):
            The name of the MsgHub that manages the subscribers.
        subscribers (`list[AgentBase]`):
            The agents that should receive this agent's reply messages
            through their `observe` method.
    """
    others = [agent for agent in subscribers if agent != self]
    self._subscribers[msghub_name] = others
|
||||
|
||||
def remove_subscribers(self, msghub_name: str) -> None:
    """Drop the subscriber list kept for the given MsgHub.

    An unknown hub name is not an error; a warning is logged instead.

    Args:
        msghub_name (`str`):
            The name of the MsgHub that manages the subscribers.
    """
    if msghub_name in self._subscribers:
        self._subscribers.pop(msghub_name)
        return

    logger.warning(
        "MsgHub named '%s' not found",
        msghub_name,
    )
|
||||
|
||||
@deprecated("Please use set_console_output_enabled() instead.")
def disable_console_output(self) -> None:
    """Turn off the agent's console output.

    Deprecated in favour of `set_console_output_enabled`. Useful e.g. in a
    production environment to avoid messy logs.
    """
    self._disable_console_output = True
|
||||
|
||||
def set_console_output_enabled(self, enabled: bool) -> None:
    """Switch the agent's console output on or off.

    Disabling it can be useful e.g. in a production environment to avoid
    messy logs.

    Args:
        enabled (`bool`):
            `True` enables the console output, `False` disables it.
    """
    # The flag is stored in negated form.
    muted = not enabled
    self._disable_console_output = muted
|
||||
|
||||
def set_msg_queue_enabled(
    self,
    enabled: bool,
    queue: Queue | None = None,
) -> None:
    """Switch the message queue used for streaming outputs on or off.

    Args:
        enabled (`bool`):
            `True` enables the message queue, `False` disables it and
            drops the current queue.
        queue (`Queue | None`, optional):
            The queue instance to use when `enabled` is `True`. When
            omitted, an existing queue is kept, otherwise a new
            `asyncio.Queue` with `maxsize=100` is created.
    """
    if not enabled:
        self.msg_queue = None
    elif queue is not None:
        self.msg_queue = queue
    elif self.msg_queue is None:
        self.msg_queue = asyncio.Queue(maxsize=100)

    self._disable_msg_queue = not enabled
|
||||
180
src/agentscope/agent/_agent_meta.py
Normal file
180
src/agentscope/agent/_agent_meta.py
Normal file
@@ -0,0 +1,180 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The metaclass for agents in agentscope."""
|
||||
import inspect
|
||||
from copy import deepcopy
|
||||
from functools import wraps
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
TYPE_CHECKING,
|
||||
Callable,
|
||||
)
|
||||
|
||||
from .._utils._common import _execute_async_or_sync_func
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._agent_base import AgentBase
|
||||
else:
|
||||
AgentBase = "AgentBase"
|
||||
|
||||
|
||||
def _normalize_to_kwargs(
    func: Callable,
    self: Any,
    *args: Any,
    **kwargs: Any,
) -> dict:
    """Normalize the provided positional and keyword arguments into a
    keyword arguments dictionary that matches the function signature.

    Args:
        func (`Callable`):
            The function whose signature the arguments are bound to.
        self (`Any`):
            The receiver object, bound as the first positional argument.
        *args (`Any`):
            The remaining positional arguments of the call.
        **kwargs (`Any`):
            The keyword arguments of the call.

    Returns:
        `dict`:
            A mapping from parameter name to value with defaults applied
            and the receiver removed. Variadic parameters appear under
            their own names as a tuple and a dict respectively.

    Raises:
        `TypeError`:
            If the arguments cannot be bound to the signature of `func`.
    """
    sig = inspect.signature(func)
    try:
        # Bind the provided arguments to the function signature
        bound = sig.bind(self, *args, **kwargs)
    except TypeError as e:
        # If failed to bind, we raise a TypeError with more context.
        # Not every callable carries a __name__ (e.g. functools.partial).
        func_label = getattr(func, "__name__", repr(func))
        raise TypeError(
            f"Failed to bind parameters for function '{func_label}': {e}\n"
            f"Expected parameters: {list(sig.parameters)}\n"
            f"Provided {len(args)} positional args and kwargs: "
            f"{list(kwargs)}",
        ) from e

    # Apply the default values for parameters
    bound.apply_defaults()
    res = dict(bound.arguments)

    # The receiver was bound positionally, so it sits in the first
    # parameter whatever that parameter is called. Binding succeeded with
    # one positional argument, hence the first parameter exists and is
    # either a regular positional parameter or `*args`.
    receiver = next(iter(sig.parameters.values()))
    if receiver.kind is inspect.Parameter.VAR_POSITIONAL:
        res[receiver.name] = res[receiver.name][1:]
    else:
        res.pop(receiver.name)
    return res
|
||||
|
||||
|
||||
def _wrap_with_hooks(
    original_func: Callable,
) -> Callable:
    """Wrap an async function so that pre- and post-hooks run around it.

    Args:
        original_func (`Callable`):
            The original async function to be wrapped with hooks.

    Returns:
        `Callable`:
            An async wrapper that runs the pre-hooks, awaits
            `original_func`, runs the post-hooks and returns the (possibly
            replaced) output.
    """
    # Hook attribute names are derived from the function name with every
    # underscore removed, e.g. "_reasoning" -> "reasoning".
    func_name = original_func.__name__.replace("_", "")

    @wraps(original_func)
    async def async_wrapper(
        self: AgentBase,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """The wrapped function, which calls the pre- and post-hooks before
        and after the original function."""

        # Unify all positional and keyword arguments into a keyword arguments
        normalized_kwargs = _normalize_to_kwargs(
            original_func,
            self,
            *args,
            **kwargs,
        )

        current_normalized_kwargs = normalized_kwargs
        # All four hook registries must exist on the instance / its class.
        assert (
            hasattr(self, f"_instance_pre_{func_name}_hooks")
            and hasattr(self, f"_instance_post_{func_name}_hooks")
            and hasattr(self.__class__, f"_class_pre_{func_name}_hooks")
            and hasattr(self.__class__, f"_class_post_{func_name}_hooks")
        ), f"Hooks for {func_name} not found in {self.__class__.__name__}"

        # pre-hooks: instance-level hooks run first, then class-level ones
        pre_hooks = list(
            getattr(self, f"_instance_pre_{func_name}_hooks").values(),
        ) + list(
            getattr(self, f"_class_pre_{func_name}_hooks").values(),
        )
        for pre_hook in pre_hooks:
            # Each hook receives a deep copy of the current arguments; a
            # non-None return value replaces them for the following hooks
            # and for the original function.
            modified_keywords = await _execute_async_or_sync_func(
                pre_hook,
                self,
                deepcopy(current_normalized_kwargs),
            )
            if modified_keywords is not None:
                assert isinstance(modified_keywords, dict), (
                    f"Pre-hook must return a dict of keyword arguments, rather"
                    f" than {type(modified_keywords)} from hook "
                    f"{pre_hook.__name__}"
                )
                current_normalized_kwargs = modified_keywords

        # original function
        # handle positional and keyword arguments specifically: the entries
        # stored under the literal keys "args" and "kwargs" are unpacked as
        # variadic arguments, every other entry is passed by keyword.
        args = current_normalized_kwargs.get("args", [])
        kwargs = current_normalized_kwargs.get("kwargs", {})
        others = {
            k: v
            for k, v in current_normalized_kwargs.items()
            if k not in ["args", "kwargs"]
        }
        current_output = await original_func(
            self,
            *args,
            **others,
            **kwargs,
        )

        # post_hooks: instance-level hooks run first, then class-level ones
        post_hooks = list(
            getattr(self, f"_instance_post_{func_name}_hooks").values(),
        ) + list(
            getattr(self, f"_class_post_{func_name}_hooks").values(),
        )
        for post_hook in post_hooks:
            # Each hook receives deep copies of the final arguments and of
            # the current output; a non-None return value replaces the
            # output.
            modified_output = await _execute_async_or_sync_func(
                post_hook,
                self,
                deepcopy(current_normalized_kwargs),
                deepcopy(current_output),
            )
            if modified_output is not None:
                current_output = modified_output
        return current_output

    return async_wrapper
|
||||
|
||||
|
||||
class _AgentMeta(type):
    """Metaclass that decorates an agent's `reply`, `print` and `observe`
    functions with pre- and post-hooks when the class body defines them."""

    def __new__(mcs, name: Any, bases: Any, attrs: Dict) -> Any:
        """Create the class after wrapping its hookable functions."""
        hooked = {
            func_name: _wrap_with_hooks(attrs[func_name])
            for func_name in ("reply", "print", "observe")
            if func_name in attrs
        }
        attrs.update(hooked)

        return super().__new__(mcs, name, bases, attrs)
|
||||
|
||||
|
||||
class _ReActAgentMeta(_AgentMeta):
    """ReAct metaclass that additionally decorates the `_reasoning` and
    `_acting` functions with pre- and post-hooks when the class body
    defines them."""

    def __new__(mcs, name: Any, bases: Any, attrs: Dict) -> Any:
        """Create the class after wrapping `_reasoning` and `_acting`."""
        hooked = {
            func_name: _wrap_with_hooks(attrs[func_name])
            for func_name in ("_reasoning", "_acting")
            if func_name in attrs
        }
        attrs.update(hooked)

        return super().__new__(mcs, name, bases, attrs)
|
||||
1133
src/agentscope/agent/_react_agent.py
Normal file
1133
src/agentscope/agent/_react_agent.py
Normal file
File diff suppressed because it is too large
Load Diff
116
src/agentscope/agent/_react_agent_base.py
Normal file
116
src/agentscope/agent/_react_agent_base.py
Normal file
@@ -0,0 +1,116 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The base class for ReAct agent in agentscope."""
|
||||
from abc import abstractmethod
|
||||
from collections import OrderedDict
|
||||
from typing import Callable, Any
|
||||
|
||||
from ._agent_base import AgentBase
|
||||
from ._agent_meta import _ReActAgentMeta
|
||||
from ..message import Msg
|
||||
|
||||
|
||||
class ReActAgentBase(AgentBase, metaclass=_ReActAgentMeta):
    """The ReAct agent base class.

    To support ReAct algorithm, this class extends the AgentBase class by
    adding two abstract interfaces: reasoning and acting, while supporting
    hook functions at four positions: pre-reasoning, post-reasoning,
    pre-acting, and post-acting by the `_ReActAgentMeta` metaclass.
    """

    # The reply/print/observe entries are listed together with the four
    # reasoning/acting hook types declared by this class.
    supported_hook_types: list[str] = [
        "pre_reply",
        "post_reply",
        "pre_print",
        "post_print",
        "pre_observe",
        "post_observe",
        "pre_reasoning",
        "post_reasoning",
        "pre_acting",
        "post_acting",
    ]
    """Supported hook types for the agent base class."""

    # Class-level registries are class attributes: the same mapping object
    # is seen by every class that does not rebind the attribute.
    _class_pre_reasoning_hooks: dict[
        str,
        Callable[
            [
                "ReActAgentBase",  # self
                dict[str, Any],  # kwargs
            ],
            dict[str, Any] | None,  # The modified kwargs or None
        ],
    ] = OrderedDict()
    """The class-level pre-reasoning hooks, taking `self` object, the input
    arguments as input"""

    _class_post_reasoning_hooks: dict[
        str,
        Callable[
            [
                "ReActAgentBase",  # self
                dict[str, Any],  # kwargs
                Any,  # output
            ],
            Msg | None,  # the modified output message or None
        ],
    ] = OrderedDict()
    """The class-level post-reasoning hooks, taking `self` object, the input
    arguments and the output message as input, and return the modified output
    message or None if no modification is needed."""

    _class_pre_acting_hooks: dict[
        str,
        Callable[
            [
                "ReActAgentBase",  # self
                dict[str, Any],  # kwargs
            ],
            dict[str, Any] | None,  # The modified kwargs or None
        ],
    ] = OrderedDict()
    """The class-level pre-acting hooks, taking `self` object, the input
    arguments as input, and return the modified input arguments or None if no
    modification is needed."""

    _class_post_acting_hooks: dict[
        str,
        Callable[
            [
                "ReActAgentBase",  # self
                dict[str, Any],  # kwargs
                Any,  # output
            ],
            Msg | None,  # the modified output message or None
        ],
    ] = OrderedDict()
    """The class-level post-acting hooks, taking `self` object, the input
    arguments and the output message as input, and return the modified output
    message or None if no modification is needed."""

    def __init__(
        self,
    ) -> None:
        """Initialize the ReAct agent base class and create the empty
        instance-level registries for the reasoning and acting hooks."""
        super().__init__()

        # Init reasoning and acting hooks
        self._instance_pre_reasoning_hooks = OrderedDict()
        self._instance_post_reasoning_hooks = OrderedDict()
        self._instance_pre_acting_hooks = OrderedDict()
        self._instance_post_acting_hooks = OrderedDict()

    @abstractmethod
    async def _reasoning(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """The reasoning process of the ReAct agent, which will be wrapped
        with pre- and post-hooks."""

    @abstractmethod
    async def _acting(self, *args: Any, **kwargs: Any) -> Any:
        """The acting process of the ReAct agent, which will be wrapped with
        pre- and post-hooks."""
|
||||
360
src/agentscope/agent/_realtime_agent.py
Normal file
360
src/agentscope/agent/_realtime_agent.py
Normal file
@@ -0,0 +1,360 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The realtime agent class."""
|
||||
import asyncio
|
||||
from asyncio import Queue
|
||||
|
||||
import shortuuid
|
||||
|
||||
from .._logging import logger
|
||||
from .._utils._common import _resample_pcm_delta
|
||||
from ..message import (
|
||||
AudioBlock,
|
||||
Base64Source,
|
||||
TextBlock,
|
||||
ImageBlock,
|
||||
ToolUseBlock,
|
||||
ToolResultBlock,
|
||||
)
|
||||
from ..module import StateModule
|
||||
from ..realtime import (
|
||||
ModelEvents,
|
||||
RealtimeModelBase,
|
||||
ServerEvents,
|
||||
ClientEvents,
|
||||
)
|
||||
from ..tool import Toolkit
|
||||
|
||||
|
||||
class RealtimeAgent(StateModule):
|
||||
"""The realtime agent class. Different from the `AgentBase` class,
|
||||
this class is designed for real-time interaction scenarios, such as
|
||||
realtime chat, voice assistants, etc.
|
||||
|
||||
Example:
|
||||
This realtime agent requires a queue to handle outgoing messages to
|
||||
the frontend and other agents, and its lifecycle is managed by the
|
||||
`start` and `stop` methods.
|
||||
|
||||
.. code-block:: python
|
||||
:caption: An example of using the RealtimeAgent class.
|
||||
|
||||
from agentscope.agent import RealtimeAgent
|
||||
from agentscope.realtime import DashScopeRealtimeModel
|
||||
import asyncio
|
||||
|
||||
agent = RealtimeAgent(
|
||||
name="Friday",
|
||||
sys_prompt="You are a helpful assistant.",
|
||||
model=DashScopeRealtimeModel(
|
||||
model_name="qwen3-omni-flash-realtime",
|
||||
api_key=os.getenv("DASHSCOPE_API_KEY"),
|
||||
)
|
||||
)
|
||||
|
||||
queue = asyncio.Queue()
|
||||
await agent.start(queue)
|
||||
|
||||
# handle the outgoing messages from the agent in another asyncio
|
||||
# task
|
||||
...
|
||||
|
||||
await agent.stop()
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
sys_prompt: str,
|
||||
model: RealtimeModelBase,
|
||||
toolkit: Toolkit | None = None,
|
||||
) -> None:
|
||||
"""Initialize the RealtimeAgent class.
|
||||
|
||||
Args:
|
||||
name (`str`):
|
||||
The name of the agent.
|
||||
sys_prompt (`str`):
|
||||
The system prompt of the agent.
|
||||
model (`RealtimeModelBase`):
|
||||
The realtime model used by the agent.
|
||||
toolkit (`Toolkit | None`, optional):
|
||||
A `Toolkit` object that contains the tool functions. If not
|
||||
provided, a default empty `Toolkit` will be created.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.id = shortuuid.uuid()
|
||||
self.name = name
|
||||
self.sys_prompt = sys_prompt
|
||||
self.model = model
|
||||
self.toolkit = toolkit
|
||||
|
||||
# A queue to handle the incoming events from other agents or the
|
||||
# frontend.
|
||||
self._incoming_queue = Queue()
|
||||
self._external_event_handling_task = None
|
||||
|
||||
# The queue to gather model responses.
|
||||
self._model_response_queue = Queue()
|
||||
self._model_response_handling_task = None
|
||||
|
||||
async def start(self, outgoing_queue: Queue) -> None:
|
||||
"""Establish a connection for real-time interaction.
|
||||
|
||||
Args:
|
||||
outgoing_queue (`Queue`):
|
||||
The queue to push messages to the frontend and other agents.
|
||||
"""
|
||||
# Start the realtime model connection.
|
||||
await self.model.connect(
|
||||
self._model_response_queue,
|
||||
instructions=self.sys_prompt,
|
||||
tools=self.toolkit.get_json_schemas() if self.toolkit else None,
|
||||
)
|
||||
|
||||
# Start the forwarding loop.
|
||||
self._external_event_handling_task = asyncio.create_task(
|
||||
self._forward_loop(),
|
||||
)
|
||||
|
||||
# Start the response handling loop.
|
||||
self._model_response_handling_task = asyncio.create_task(
|
||||
self._model_response_loop(outgoing_queue),
|
||||
)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Close the connection."""
|
||||
|
||||
if not self._external_event_handling_task.done():
|
||||
self._external_event_handling_task.cancel()
|
||||
|
||||
await self.model.disconnect()
|
||||
|
||||
async def _forward_loop(self) -> None:
|
||||
"""The loop to forward messages from other agents or the frontend to
|
||||
the realtime model for processing.
|
||||
|
||||
outside ==> agent ==> realtime model
|
||||
"""
|
||||
logger.info(
|
||||
"Agent '%s' begins the loops to receive external events",
|
||||
self.name,
|
||||
)
|
||||
|
||||
while True:
|
||||
event = await self._incoming_queue.get()
|
||||
|
||||
match event:
|
||||
# Only handle the events that we need
|
||||
case ServerEvents.AgentResponseAudioDeltaEvent() as event:
|
||||
# Convert the sample rate to the required format by the
|
||||
# model
|
||||
receive_rate = event.format.rate
|
||||
if self.model.input_sample_rate != receive_rate:
|
||||
delta = _resample_pcm_delta(
|
||||
event.delta,
|
||||
receive_rate,
|
||||
self.model.input_sample_rate,
|
||||
)
|
||||
|
||||
else:
|
||||
delta = event.delta
|
||||
|
||||
await self.model.send(
|
||||
AudioBlock(
|
||||
type="audio",
|
||||
source=Base64Source(
|
||||
type="base64",
|
||||
media_type=event.format.type,
|
||||
data=delta,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
case ServerEvents.AgentResponseAudioDoneEvent():
|
||||
# Send a silence audio block to indicate the end of audio
|
||||
pass
|
||||
|
||||
case ClientEvents.ClientAudioAppendEvent() as event:
|
||||
# Construct media_type from format info
|
||||
# format contains: {"sample_rate": 16000, "encoding":
|
||||
# "pcm16"}
|
||||
# encoding = event.format.get("encoding", "pcm16")
|
||||
# media_type = (
|
||||
# f"audio/{encoding.replace('16', '')}"
|
||||
# if "pcm" in encoding
|
||||
# else "audio/pcm"
|
||||
# )
|
||||
|
||||
await self.model.send(
|
||||
AudioBlock(
|
||||
type="audio",
|
||||
source=Base64Source(
|
||||
type="base64",
|
||||
media_type=event.format.type,
|
||||
data=event.audio,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
case ClientEvents.ClientTextAppendEvent() as event:
|
||||
await self.model.send(
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=event.text,
|
||||
),
|
||||
)
|
||||
case ClientEvents.ClientImageAppendEvent() as event:
|
||||
# Construct media_type from format info
|
||||
media_type = event.format.get("type", "image/jpeg")
|
||||
|
||||
await self.model.send(
|
||||
ImageBlock(
|
||||
type="image",
|
||||
source=Base64Source(
|
||||
type="base64",
|
||||
media_type=media_type,
|
||||
data=event.image,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
async def _model_response_loop(self, outgoing_queue: Queue) -> None:
|
||||
"""The loop to handle model responses and forward them to the
|
||||
frontend and other agents.
|
||||
|
||||
realtime model ==> agent ==> outside
|
||||
|
||||
Args:
|
||||
outgoing_queue (`Queue`):
|
||||
The queue to push messages to the frontend and other agents.
|
||||
"""
|
||||
while True:
|
||||
model_event = await self._model_response_queue.get()
|
||||
|
||||
agent_kwargs = {"agent_id": self.id, "agent_name": self.name}
|
||||
|
||||
agent_event = None
|
||||
match model_event:
|
||||
# The events that can be converted from model events to agent
|
||||
# events directly
|
||||
case (
|
||||
ModelEvents.ModelResponseCreatedEvent()
|
||||
| ModelEvents.ModelResponseDoneEvent()
|
||||
| ModelEvents.ModelResponseAudioDeltaEvent()
|
||||
| ModelEvents.ModelResponseAudioDoneEvent()
|
||||
| ModelEvents.ModelResponseAudioTranscriptDeltaEvent()
|
||||
| ModelEvents.ModelResponseAudioTranscriptDoneEvent()
|
||||
| ModelEvents.ModelResponseToolUseDeltaEvent()
|
||||
| ModelEvents.ModelInputTranscriptionDeltaEvent()
|
||||
| ModelEvents.ModelInputTranscriptionDoneEvent()
|
||||
| ModelEvents.ModelInputStartedEvent()
|
||||
| ModelEvents.ModelInputDoneEvent()
|
||||
| ModelEvents.ModelErrorEvent()
|
||||
) as event:
|
||||
# Directly map the model event to agent event
|
||||
agent_event = ServerEvents.from_model_event(
|
||||
event,
|
||||
**agent_kwargs,
|
||||
)
|
||||
|
||||
# The events that need special handling
|
||||
case ModelEvents.ModelSessionCreatedEvent():
|
||||
# Send the agent ready event to the outside.
|
||||
agent_event = ServerEvents.AgentReadyEvent(**agent_kwargs)
|
||||
|
||||
case ModelEvents.ModelSessionEndedEvent():
|
||||
# Send the agent session ended event to the outside.
|
||||
agent_event = ServerEvents.AgentEndedEvent(**agent_kwargs)
|
||||
|
||||
# The tool use done that requires executing the tool
|
||||
# Such event may generate multiple outgoing events:
|
||||
# 1. Tool use done event
|
||||
# 2. Tool result event
|
||||
case ModelEvents.ModelResponseToolUseDoneEvent() as event:
|
||||
# Send the tool use done event immediately
|
||||
done_event = ServerEvents.AgentResponseToolUseDoneEvent(
|
||||
response_id=event.response_id,
|
||||
item_id=event.item_id,
|
||||
tool_use=event.tool_use,
|
||||
**agent_kwargs,
|
||||
)
|
||||
|
||||
# Directly put the done event to the outgoing queue
|
||||
await outgoing_queue.put(done_event)
|
||||
|
||||
# Then execute the tool call using accumulated arguments
|
||||
if self.toolkit:
|
||||
# Execute the tool call asynchronously
|
||||
asyncio.create_task(
|
||||
self._acting(
|
||||
tool_use=event.tool_use,
|
||||
outgoing_queue=outgoing_queue,
|
||||
),
|
||||
)
|
||||
|
||||
case _:
|
||||
logger.debug(
|
||||
"Unknown model event type: %s",
|
||||
type(model_event),
|
||||
)
|
||||
|
||||
if agent_event is not None:
|
||||
# Put the processed response to the outgoing queue.
|
||||
await outgoing_queue.put(agent_event)
|
||||
|
||||
async def handle_input(
|
||||
self,
|
||||
event: ClientEvents.EventBase | ServerEvents.EventBase,
|
||||
) -> None:
|
||||
"""Handle the input message from the frontend or the other agents.
|
||||
|
||||
Args:
|
||||
event (`ClientEvents.EventBase | ServerEvents.EventBase`):
|
||||
The input event from the frontend or other agents.
|
||||
"""
|
||||
await self._incoming_queue.put(event)
|
||||
|
||||
async def _acting(
|
||||
self,
|
||||
tool_use: ToolUseBlock,
|
||||
outgoing_queue: Queue,
|
||||
) -> None:
|
||||
"""Execute the tool call and send the result back to the outside (
|
||||
frontend or other agents).
|
||||
|
||||
Args:
|
||||
tool_use (`ToolUseBlock`):
|
||||
The tool use block containing the tool call information.
|
||||
outgoing_queue (`Queue`):
|
||||
The queue to push messages to the frontend and other agents.
|
||||
"""
|
||||
if not self.toolkit:
|
||||
return
|
||||
|
||||
res = await self.toolkit.call_tool_function(tool_use)
|
||||
|
||||
last_chunk = None
|
||||
async for chunk in res:
|
||||
last_chunk = chunk
|
||||
|
||||
if last_chunk:
|
||||
tool_result_block = ToolResultBlock(
|
||||
type="tool_result",
|
||||
id=tool_use.get("id"),
|
||||
name=tool_use.get("name"),
|
||||
output=last_chunk.content,
|
||||
)
|
||||
|
||||
# Send the tool result back to the model
|
||||
await self.model.send(tool_result_block)
|
||||
|
||||
# Also send event to frontend/other agents
|
||||
outgoing_event = ServerEvents.AgentResponseToolResultEvent(
|
||||
tool_result=tool_result_block,
|
||||
agent_id=self.id,
|
||||
agent_name=self.name,
|
||||
)
|
||||
|
||||
await outgoing_queue.put(outgoing_event)
|
||||
128
src/agentscope/agent/_user_agent.py
Normal file
128
src/agentscope/agent/_user_agent.py
Normal file
@@ -0,0 +1,128 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The user agent class."""
|
||||
from typing import Type, Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._agent_base import AgentBase
|
||||
from ._user_input import UserInputBase, TerminalUserInput
|
||||
from ..message import Msg
|
||||
|
||||
|
||||
class UserAgent(AgentBase):
|
||||
"""The class for user interaction, allowing developers to handle the user
|
||||
input from different sources, such as web UI, cli, and other interfaces.
|
||||
"""
|
||||
|
||||
_input_method: UserInputBase = TerminalUserInput()
|
||||
"""The user input method, can be overridden by calling the
|
||||
`register_instance/class_input_method` function."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
) -> None:
|
||||
"""Initialize the user agent with a name."""
|
||||
super().__init__()
|
||||
|
||||
self.name = name
|
||||
|
||||
async def reply(
|
||||
self,
|
||||
msg: Msg | list[Msg] | None = None,
|
||||
structured_model: Type[BaseModel] | None = None,
|
||||
) -> Msg:
|
||||
"""Receive input message(s) and generate a reply message from the user.
|
||||
|
||||
Args:
|
||||
msg (`Msg | list[Msg] | None`, defaults to `None`):
|
||||
The message(s) to be replied. If `None`, the agent will wait
|
||||
for user input.
|
||||
structured_model (`Type[BaseModel] | None`, defaults to `None`):
|
||||
A child class of `pydantic.BaseModel` that defines the
|
||||
structured output format. If provided, the user will be
|
||||
prompted to fill in the required fields.
|
||||
|
||||
Returns:
|
||||
`Msg`:
|
||||
The reply message generated by the user.
|
||||
"""
|
||||
|
||||
# Get the input from the specified input method.
|
||||
input_data = await self._input_method(
|
||||
agent_id=self.id,
|
||||
agent_name=self.name,
|
||||
structured_model=structured_model,
|
||||
)
|
||||
|
||||
blocks_input = input_data.blocks_input
|
||||
if (
|
||||
blocks_input
|
||||
and len(blocks_input) == 1
|
||||
and blocks_input[0].get("type") == "text"
|
||||
):
|
||||
# Turn blocks_input into a string if only one text block exists
|
||||
blocks_input = blocks_input[0].get("text")
|
||||
|
||||
msg = Msg(
|
||||
self.name,
|
||||
content=blocks_input,
|
||||
role="user",
|
||||
metadata=input_data.structured_input,
|
||||
)
|
||||
|
||||
await self.print(msg)
|
||||
|
||||
return msg
|
||||
|
||||
def override_instance_input_method(
|
||||
self,
|
||||
input_method: UserInputBase,
|
||||
) -> None:
|
||||
"""Override the input method of the current UserAgent instance.
|
||||
|
||||
Args:
|
||||
input_method (`UserInputBase`):
|
||||
The callable input method, which should be an object of a
|
||||
class that inherits from `UserInputBase`.
|
||||
"""
|
||||
if not isinstance(input_method, UserInputBase):
|
||||
raise ValueError(
|
||||
f"The input method should be an instance of the child class "
|
||||
f"of `UserInputBase`, but got {type(input_method)} instead.",
|
||||
)
|
||||
self._input_method = input_method
|
||||
|
||||
@classmethod
|
||||
def override_class_input_method(
|
||||
cls,
|
||||
input_method: UserInputBase,
|
||||
) -> None:
|
||||
"""Override the input method of the current UserAgent class.
|
||||
|
||||
Args:
|
||||
input_method (`UserInputBase`):
|
||||
The callable input method, which should be an object of a
|
||||
class that inherits from `UserInputBase`.
|
||||
"""
|
||||
if not isinstance(input_method, UserInputBase):
|
||||
raise ValueError(
|
||||
f"The input method should be an instance of the child class "
|
||||
f"of `UserInputBase`, but got {type(input_method)} instead.",
|
||||
)
|
||||
cls._input_method = input_method
|
||||
|
||||
async def handle_interrupt(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Msg:
|
||||
"""The post-processing logic when the reply is interrupted by the
|
||||
user or something else."""
|
||||
raise NotImplementedError(
|
||||
f"The handle_interrupt function is not implemented in "
|
||||
f"{self.__class__.__name__}",
|
||||
)
|
||||
|
||||
async def observe(self, msg: Msg | list[Msg] | None) -> None:
    """Observe the message(s) from the other agents or the environment.

    The body consists of this docstring only, so the observed message(s)
    are neither stored nor processed here.

    Args:
        msg (`Msg | list[Msg] | None`):
            The message(s) to be observed.
    """
|
||||
415
src/agentscope/agent/_user_input.py
Normal file
415
src/agentscope/agent/_user_input.py
Normal file
@@ -0,0 +1,415 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The user input related classes."""
|
||||
import json.decoder
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from queue import Queue
|
||||
from threading import Event
|
||||
from typing import Any, Type, List
|
||||
|
||||
import jsonschema
|
||||
import requests
|
||||
import shortuuid
|
||||
import socketio
|
||||
from pydantic import BaseModel
|
||||
import json5
|
||||
|
||||
from .. import _config
|
||||
from .._logging import logger
|
||||
from ..message import (
|
||||
TextBlock,
|
||||
VideoBlock,
|
||||
AudioBlock,
|
||||
ImageBlock,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class UserInputData:
    """The user input data returned by a `UserInputBase` implementation."""

    # The default is `None`, so the annotation has to admit `None` as well.
    blocks_input: (
        List[TextBlock | ImageBlock | AudioBlock | VideoBlock] | None
    ) = None
    """The content blocks (text, image, audio or video) from the user, or
    `None` if no blocks were provided"""

    structured_input: dict[str, Any] | None = None
    """The structured input from the user, or `None` if no structured
    input was requested"""
|
||||
|
||||
|
||||
class UserInputBase:
    """The base class used to handle the user input from different sources.

    .. note:: `__call__` is decorated with `abstractmethod`, but this class
     does not use `ABCMeta` as its metaclass, so the decorator is not
     enforced when the class is instantiated. The method body below holds
     only a docstring; subclasses are expected to override it.
    """

    @abstractmethod
    async def __call__(
        self,
        agent_id: str,
        agent_name: str,
        *args: Any,
        structured_model: Type[BaseModel] | None = None,
        **kwargs: Any,
    ) -> UserInputData:
        """The user input method, which returns the user input and the
        required structured data.

        Args:
            agent_id (`str`):
                The agent identifier.
            agent_name (`str`):
                The agent name.
            *args (`Any`):
                Additional positional arguments; not used by this base
                declaration.
            structured_model (`Type[BaseModel] | None`, optional):
                A base model class that defines the structured input format.
            **kwargs (`Any`):
                Additional keyword arguments; not used by this base
                declaration.

        Returns:
            `UserInputData`:
                The user input data.
        """
|
||||
|
||||
|
||||
class TerminalUserInput(UserInputBase):
    """The terminal user input.

    The free-text input and, when a structured model is given, every field
    of the structured input are read from standard input with the built-in
    `input` function.
    """

    # JSON schema types whose values have to be parsed from the typed text:
    # a raw string can never validate against any of them.
    _PARSED_TYPES = ("integer", "number", "boolean", "array", "object")

    def __init__(self, input_hint: str = "User Input: ") -> None:
        """Initialize the terminal user input with a hint.

        Args:
            input_hint (`str`, defaults to `"User Input: "`):
                The prompt printed before reading the free-text input.
        """
        self.input_hint = input_hint

    async def __call__(
        self,
        agent_id: str,
        agent_name: str,
        *args: Any,
        structured_model: Type[BaseModel] | None = None,
        **kwargs: Any,
    ) -> UserInputData:
        """Handle the user input from the terminal.

        Args:
            agent_id (`str`):
                The agent identifier.
            agent_name (`str`):
                The agent name.
            structured_model (`Type[BaseModel] | None`, optional):
                A base model class that defines the structured input format.

        Returns:
            `UserInputData`:
                The user input data. `blocks_input` holds a single text
                block; `structured_input` is `None` when no structured
                model is given.
        """
        # NOTE: `input` blocks the running event loop until the user
        # presses Enter. Encoding with `errors="ignore"` drops characters
        # that cannot be represented in UTF-8.
        text_input = (
            input(self.input_hint)
            .encode("utf-8", errors="ignore")
            .decode("utf-8")
        )

        structured_input = None
        if structured_model is not None:
            structured_input = self._read_structured_input(structured_model)

        return UserInputData(
            blocks_input=[TextBlock(type="text", text=text_input)],
            structured_input=structured_input,
        )

    def _read_structured_input(
        self,
        structured_model: Type[BaseModel],
    ) -> dict[str, Any]:
        """Prompt on the terminal for each property of the model's schema.

        Every property is asked for repeatedly until the answer validates
        against that property's JSON schema.

        Args:
            structured_model (`Type[BaseModel]`):
                The model class whose JSON schema drives the prompts.

        Returns:
            `dict[str, Any]`:
                The validated values keyed by property name.
        """
        json_schema = structured_model.model_json_schema()
        required = json_schema.get("required", [])
        print("Structured input (press Enter to skip for optional):)")

        structured_input: dict[str, Any] = {}
        # A schema without "properties" simply yields no prompts.
        for key, item in json_schema.get("properties", {}).items():
            # Show the constraints without the title, which may be absent.
            requirements = {k: v for k, v in item.items() if k != "title"}

            # "type" is missing for e.g. `anyOf` schemas; the raw text is
            # used for those.
            declared_type = item.get("type")
            needs_parsing = (
                isinstance(declared_type, str)
                and declared_type.lower() in self._PARSED_TYPES
            )

            while True:
                res: Any = input(f"\t{key} ({requirements}): ")

                if res == "":
                    if key in required:
                        print(f"Key {key} is required.")
                        continue

                    # The schema default is already a Python value and
                    # must not be parsed as text.
                    res = item.get("default", None)

                elif needs_parsing:
                    try:
                        res = json5.loads(res)
                    # json5 reports malformed input as `ValueError`, which
                    # is also the base class of `JSONDecodeError`.
                    except ValueError as e:
                        print(
                            "\033[31mInvalid input with error:\n"
                            "```\n"
                            f"{e}\n"
                            "```\033[0m",
                        )
                        continue

                try:
                    jsonschema.validate(res, item)
                except jsonschema.ValidationError as e:
                    print(
                        f"\033[31mValidation error:\n```\n{e}\n```\033[0m",
                    )
                    # Short pause before asking for the same key again.
                    time.sleep(0.5)
                    continue

                structured_input[key] = res
                break

        return structured_input
|
||||
|
||||
|
||||
class StudioUserInput(UserInputBase):
    """The class that hosts the user input on the AgentScope Studio.

    A socket.io client is connected in `__init__`; each call posts a user
    input request over HTTP and then waits for the matching
    "forwardUserInput" websocket event.
    """

    _websocket_namespace: str = "/python"

    def __init__(
        self,
        studio_url: str,
        run_id: str,
        max_retries: int = 3,
        reconnect_attempts: int = 3,
        reconnection_delay: int = 1,
        reconnection_delay_max: int = 5,
    ) -> None:
        """Initialize the StudioUserInput object.

        Args:
            studio_url (`str`):
                The URL of the AgentScope Studio.
            run_id (`str`):
                The current run identity.
            max_retries (`int`, defaults to `3`):
                The maximum number of retries to get user input.
            reconnect_attempts (`int`, defaults to `3`):
                Passed to `socketio.Client` as `reconnection_attempts`.
            reconnection_delay (`int`, defaults to `1`):
                Passed to `socketio.Client` as `reconnection_delay`.
            reconnection_delay_max (`int`, defaults to `5`):
                Passed to `socketio.Client` as `reconnection_delay_max`.

        Raises:
            `RuntimeError`:
                If the connection to AgentScope Studio cannot be
                established.
        """
        self._is_connected = False
        self._is_reconnecting = False

        self.studio_url = studio_url
        self.run_id = run_id
        self.max_retries = max_retries

        # Init Websocket
        self.sio = socketio.Client(
            reconnection=True,
            reconnection_attempts=reconnect_attempts,
            reconnection_delay=reconnection_delay,
            reconnection_delay_max=reconnection_delay_max,
        )

        # Per-request queue and event, keyed by request id; filled by the
        # "forwardUserInput" handler and consumed in `__call__`.
        self.input_queues = {}
        self.input_events = {}

        @self.sio.on("connect", namespace=self._websocket_namespace)
        def on_connect() -> None:
            self._is_connected = True
            logger.info(
                'Connected to AgentScope Studio at "%s" with '
                'run name "%s".',
                self.studio_url,
                run_id,
            )
            logger.info(
                "View the run at: %s/projects/%s",
                self.studio_url,
                _config.project,
            )

        @self.sio.on("disconnect", namespace=self._websocket_namespace)
        def on_disconnect() -> None:
            self._is_connected = False
            logger.info(
                "Disconnected from AgentScope Studio at %s",
                self.studio_url,
            )

        @self.sio.on("reconnect", namespace=self._websocket_namespace)
        def on_reconnect(attempt_number: int) -> None:
            self._is_connected = True
            self._is_reconnecting = False
            logger.info(
                "Reconnected to AgentScope Studio at %s with run_id %s after "
                "%d attempts",
                self.studio_url,
                self.run_id,
                attempt_number,
            )

        @self.sio.on("reconnect_attempt", namespace=self._websocket_namespace)
        def on_reconnect_attempt(attempt_number: int) -> None:
            self._is_reconnecting = True
            logger.info(
                "Attempting to reconnect to AgentScope Studio at %s "
                "(attempt %d)",
                self.studio_url,
                attempt_number,
            )

        @self.sio.on("reconnect_failed", namespace=self._websocket_namespace)
        def on_reconnect_failed() -> None:
            self._is_reconnecting = False
            logger.error(
                "Failed to reconnect to AgentScope Studio at %s",
                self.studio_url,
            )

        @self.sio.on("reconnect_error", namespace=self._websocket_namespace)
        def on_reconnect_error(error: Any) -> None:
            logger.error(
                "Error while reconnecting to AgentScope Studio at %s: %s",
                self.studio_url,
                str(error),
            )

        # The "forwardUserInput" event delivers the user input for a
        # pending request to the current python run
        @self.sio.on("forwardUserInput", namespace=self._websocket_namespace)
        def receive_user_input(
            request_id: str,
            blocks_input: List[
                TextBlock | ImageBlock | AudioBlock | VideoBlock
            ],
            structured_input: dict[str, Any],
        ) -> None:
            # Inputs for unknown (e.g. already finished) requests are
            # dropped.
            if request_id in self.input_queues:
                self.input_queues[request_id].put(
                    UserInputData(
                        blocks_input=blocks_input,
                        structured_input=structured_input,
                    ),
                )
                self.input_events[request_id].set()

        try:
            self.sio.connect(
                f"{self.studio_url}",
                # Same namespace the handlers above are registered on.
                namespaces=[self._websocket_namespace],
                auth={"run_id": self.run_id},
            )
        except Exception as e:
            raise RuntimeError(
                f"Failed to connect to AgentScope Studio at {self.studio_url}",
            ) from e

    def _ensure_connected(
        self,
        timeout: float = 30.0,
        check_interval: float = 5.0,
    ) -> None:
        """Ensure the connection is established or wait for reconnection.

        Args:
            timeout (`float`):
                Maximum time to wait for reconnection in seconds. Defaults
                to 30.0.
            check_interval (`float`):
                Interval between connection checks in seconds. Defaults to
                5.0.

        Raises:
            `RuntimeError`:
                If the reconnection does not finish within the timeout, or
                if the client is neither connected nor reconnecting.
        """
        if self._is_connected:
            return

        if self._is_reconnecting:
            start_time = time.time()
            while self._is_reconnecting:
                # Check timeout
                elapsed_time = time.time() - start_time
                if elapsed_time > timeout:
                    raise RuntimeError(
                        f"Reconnection timeout after {elapsed_time} seconds",
                    )

                # Log status
                logger.info(
                    "Waiting for reconnection... (%.1fs / %.1fs)",
                    elapsed_time,
                    timeout,
                )

                # Wait for next check (blocks the calling thread)
                time.sleep(check_interval)

            # After reconnection attempt completed, check final status
            if self._is_connected:
                return

        # Not connected and not reconnecting
        raise RuntimeError(
            f"Not connected to AgentScope Studio at {self.studio_url}.",
        )

    async def __call__(  # type: ignore[override]
        self,
        agent_id: str,
        agent_name: str,
        *args: Any,
        structured_model: Type[BaseModel] | None = None,
    ) -> UserInputData:
        """Get the user input from AgentScope Studio.

        Args:
            agent_id (`str`):
                The identity of the agent.
            agent_name (`str`):
                The name of the agent.
            structured_model (`Type[BaseModel] | None`, optional):
                The base model class of the structured input.

        Raises:
            `RuntimeError`:
                Failed to get user input from AgentScope Studio.

        Returns:
            `UserInputData`:
                The user input.
        """
        self._ensure_connected()

        request_id = shortuuid.uuid()

        self.input_queues[request_id] = Queue()
        self.input_events[request_id] = Event()

        # The `finally` spans the HTTP request as well, so the per-request
        # queue and event are released even when the request fails.
        try:
            if structured_model is None:
                structured_input = None
            else:
                structured_input = structured_model.model_json_schema()

            n_retry = 0
            while True:
                try:
                    response = requests.post(
                        f"{self.studio_url}/trpc/requestUserInput",
                        json={
                            "requestId": request_id,
                            "runId": self.run_id,
                            "agentId": agent_id,
                            "agentName": agent_name,
                            "structuredInput": structured_input,
                        },
                    )
                    response.raise_for_status()
                    break
                except Exception as e:
                    # One initial attempt plus up to `max_retries` retries.
                    if n_retry < self.max_retries:
                        n_retry += 1
                        continue

                    raise RuntimeError(
                        "Failed to get user input from AgentScope Studio",
                    ) from e

            # NOTE: `Event.wait` blocks the calling thread (and thus the
            # running event loop) until the websocket handler sets it.
            self.input_events[request_id].wait()
            response_data = self.input_queues[request_id].get()
            return response_data

        finally:
            self.input_queues.pop(request_id, None)
            self.input_events.pop(request_id, None)

    def __del__(self) -> None:
        """Cleanup socket connection when the object is destroyed"""
        try:
            self.sio.disconnect()
        except Exception as e:
            logger.error(
                "Failed to disconnect from AgentScope Studio at %s: %s",
                self.studio_url,
                str(e),
            )
|
||||
18
src/agentscope/agent/_utils.py
Normal file
18
src/agentscope/agent/_utils.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Utils for agents in agentscope."""
|
||||
from typing import Any
|
||||
|
||||
|
||||
class _AsyncNullContext:
|
||||
"""An async null context manager."""
|
||||
|
||||
async def __aenter__(self) -> None:
|
||||
return None
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Any,
|
||||
exc_val: Any,
|
||||
exc_tb: Any,
|
||||
) -> None:
|
||||
return None
|
||||
27
src/agentscope/embedding/__init__.py
Normal file
27
src/agentscope/embedding/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The embedding module in agentscope."""
|
||||
|
||||
from ._embedding_base import EmbeddingModelBase
|
||||
from ._embedding_usage import EmbeddingUsage
|
||||
from ._embedding_response import EmbeddingResponse
|
||||
from ._dashscope_embedding import DashScopeTextEmbedding
|
||||
from ._dashscope_multimodal_embedding import DashScopeMultiModalEmbedding
|
||||
from ._openai_embedding import OpenAITextEmbedding
|
||||
from ._gemini_embedding import GeminiTextEmbedding
|
||||
from ._ollama_embedding import OllamaTextEmbedding
|
||||
from ._cache_base import EmbeddingCacheBase
|
||||
from ._file_cache import FileEmbeddingCache
|
||||
|
||||
|
||||
# Names exported by `from agentscope.embedding import *`; each one is
# imported from a private submodule above.
__all__ = [
    "EmbeddingModelBase",
    "EmbeddingUsage",
    "EmbeddingResponse",
    "DashScopeTextEmbedding",
    "DashScopeMultiModalEmbedding",
    "OpenAITextEmbedding",
    "GeminiTextEmbedding",
    "OllamaTextEmbedding",
    "EmbeddingCacheBase",
    "FileEmbeddingCache",
]
|
||||
63
src/agentscope/embedding/_cache_base.py
Normal file
63
src/agentscope/embedding/_cache_base.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The embedding cache base class."""
|
||||
from abc import abstractmethod
|
||||
from typing import List, Any
|
||||
|
||||
from ..types import (
|
||||
JSONSerializableObject,
|
||||
Embedding,
|
||||
)
|
||||
|
||||
|
||||
class EmbeddingCacheBase:
    """Base class for embedding caches, which is responsible for storing and
    retrieving embeddings.

    Every method below is declared with `abstractmethod` and has no body
    besides its docstring. The class itself does not use `ABCMeta`, so the
    decorator is not enforced at instantiation time.
    """

    @abstractmethod
    async def store(
        self,
        embeddings: List[Embedding],
        identifier: JSONSerializableObject,
        overwrite: bool = False,
        **kwargs: Any,
    ) -> None:
        """Store the embeddings with the given identifier.

        Args:
            embeddings (`List[Embedding]`):
                The embeddings to store.
            identifier (`JSONSerializableObject`):
                The identifier to distinguish the embeddings.
            overwrite (`bool`, defaults to `False`):
                Whether to overwrite existing embeddings with the same
                identifier. If `True`, existing embeddings will be replaced.
            **kwargs (`Any`):
                Additional keyword arguments; their meaning is left to the
                implementation.
        """

    @abstractmethod
    async def retrieve(
        self,
        identifier: JSONSerializableObject,
    ) -> List[Embedding] | None:
        """Retrieve the embeddings with the given identifier. If not
        found, return `None`.

        Args:
            identifier (`JSONSerializableObject`):
                The identifier to retrieve the embeddings.

        Returns:
            `List[Embedding] | None`:
                The stored embeddings, or `None` if nothing is found.
        """

    @abstractmethod
    async def remove(
        self,
        identifier: JSONSerializableObject,
    ) -> None:
        """Remove the embeddings with the given identifier.

        Args:
            identifier (`JSONSerializableObject`):
                The identifier to remove the embeddings.
        """

    @abstractmethod
    async def clear(self) -> None:
        """Clear all cached embeddings."""
|
||||
169
src/agentscope/embedding/_dashscope_embedding.py
Normal file
169
src/agentscope/embedding/_dashscope_embedding.py
Normal file
@@ -0,0 +1,169 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The dashscope embedding module in agentscope."""
|
||||
from datetime import datetime
|
||||
from typing import Any, List, Literal
|
||||
|
||||
from ._cache_base import EmbeddingCacheBase
|
||||
from ._embedding_response import EmbeddingResponse
|
||||
from ._embedding_usage import EmbeddingUsage
|
||||
from ._embedding_base import EmbeddingModelBase
|
||||
from .._logging import logger
|
||||
from ..message import TextBlock
|
||||
|
||||
|
||||
class DashScopeTextEmbedding(EmbeddingModelBase):
    """Text embedding model backed by the DashScope API.

    .. note:: From the `official documentation
     <https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2712515>`_:

     - The max batch size that DashScope text embedding API
      supports is 10 for `text-embedding-v4` and `text-embedding-v3`
      models, and 25 for `text-embedding-v2` and `text-embedding-v1`
      models.
     - The max token limit for a single input is 8192 tokens for `v4` and
      `v3` models, and 2048 tokens for `v2` and `v1` models.

    """

    supported_modalities: list[str] = ["text"]
    """This class only supports text input."""

    def __init__(
        self,
        api_key: str,
        model_name: str,
        dimensions: int = 1024,
        embedding_cache: EmbeddingCacheBase | None = None,
    ) -> None:
        """Set up the DashScope text embedding model.

        Args:
            api_key (`str`):
                The dashscope API key.
            model_name (`str`):
                The name of the embedding model.
            dimensions (`int`, defaults to 1024):
                The dimension of the embedding vector, refer to the
                `official documentation
                <https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2712515>`_
                for more details.
            embedding_cache (`EmbeddingCacheBase`):
                The embedding cache class instance, used to cache the
                embedding results to avoid repeated API calls.
        """
        super().__init__(model_name, dimensions)

        self.api_key = api_key
        self.embedding_cache = embedding_cache
        # Number of texts sent to the API in a single request.
        self.batch_size_limit = 10

    async def _call_api(self, kwargs: dict[str, Any]) -> EmbeddingResponse:
        """Embed one batch, answering from the cache when possible.

        Args:
            kwargs (`dict[str, Any]`):
                The keyword arguments of the API call; they also serve as
                the cache identifier.

        Raises:
            `RuntimeError`:
                If the API responds with a status code other than 200.

        Returns:
            `EmbeddingResponse`:
                The embeddings of the batch together with the usage.
        """
        if self.embedding_cache:
            cache_hit = await self.embedding_cache.retrieve(
                identifier=kwargs,
            )
            if cache_hit:
                # Nothing was spent on a cached answer.
                free_usage = EmbeddingUsage(tokens=0, time=0)
                return EmbeddingResponse(
                    embeddings=cache_hit,
                    usage=free_usage,
                    source="cache",
                )

        import dashscope

        started_at = datetime.now()
        response = dashscope.embeddings.TextEmbedding.call(
            api_key=self.api_key,
            **kwargs,
        )
        elapsed = (datetime.now() - started_at).total_seconds()

        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to get embedding from DashScope API: {response}",
            )

        if self.embedding_cache:
            vectors_to_cache = [
                entry["embedding"] for entry in response.output["embeddings"]
            ]
            await self.embedding_cache.store(
                identifier=kwargs,
                embeddings=vectors_to_cache,
            )

        api_usage = EmbeddingUsage(
            tokens=response.usage["total_tokens"],
            time=elapsed,
        )
        return EmbeddingResponse(
            embeddings=[
                entry["embedding"] for entry in response.output["embeddings"]
            ],
            usage=api_usage,
        )

    async def __call__(
        self,
        text: List[str | TextBlock],
        **kwargs: Any,
    ) -> EmbeddingResponse:
        """Embed the given texts, splitting them into batches as needed.

        Args:
            text (`List[str | TextBlock]`):
                The input text to be embedded. It can be a list of strings.
            **kwargs (`Any`):
                Extra keyword arguments merged into every API call.

        Raises:
            `ValueError`:
                If an item is neither a string nor a dict with a "text"
                key.

        Returns:
            `EmbeddingResponse`:
                The embeddings of all inputs with the accumulated usage.
        """
        plain_texts = []
        for entry in text:
            if isinstance(entry, str):
                plain_texts.append(entry)
            elif isinstance(entry, dict) and "text" in entry:
                plain_texts.append(entry["text"])
            else:
                raise ValueError(
                    "Input text must be a list of strings or TextBlock dicts.",
                )

        total = len(plain_texts)
        if total > self.batch_size_limit:
            # Ceiling division: number of requests needed for all texts.
            n_calls = (
                total + self.batch_size_limit - 1
            ) // self.batch_size_limit
            logger.info(
                "The input texts (%d) will be embedded with %d API calls due "
                f"to the batch size limit of {self.batch_size_limit} for "
                f"DashScope embedding API.",
                total,
                n_calls,
            )

        # Handle the batch size limit for DashScope embedding API
        all_embeddings = []
        total_time = 0.0
        total_tokens = 0
        # Stays "cache" unless at least one batch was answered by the API.
        overall_source: Literal["cache", "api"] = "cache"
        for start in range(0, total, self.batch_size_limit):
            stop = start + self.batch_size_limit
            request_kwargs = {
                "input": plain_texts[start:stop],
                "model": self.model_name,
                "dimension": self.dimensions,
            } | kwargs

            batch = await self._call_api(request_kwargs)

            all_embeddings.extend(batch.embeddings)
            total_time += batch.usage.time
            if batch.usage.tokens:
                total_tokens += batch.usage.tokens
            if batch.source == "api":
                overall_source = "api"

        return EmbeddingResponse(
            embeddings=all_embeddings,
            usage=EmbeddingUsage(
                tokens=total_tokens,
                time=total_time,
            ),
            source=overall_source,
        )
|
||||
244
src/agentscope/embedding/_dashscope_multimodal_embedding.py
Normal file
244
src/agentscope/embedding/_dashscope_multimodal_embedding.py
Normal file
@@ -0,0 +1,244 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The dashscope multimodal embedding model in agentscope."""
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
|
||||
from ._cache_base import EmbeddingCacheBase
|
||||
from ._embedding_response import EmbeddingResponse
|
||||
from ._embedding_usage import EmbeddingUsage
|
||||
from ._embedding_base import EmbeddingModelBase
|
||||
from ..message import (
|
||||
VideoBlock,
|
||||
ImageBlock,
|
||||
TextBlock,
|
||||
)
|
||||
|
||||
|
||||
class DashScopeMultiModalEmbedding(EmbeddingModelBase):
    """The DashScope multimodal embedding API, supporting text, image and
    video embedding."""

    supported_modalities: list[str] = ["text", "image", "video"]
    """This class supports text, image and video input."""

    # (model name prefix, batch size limit, fixed embedding dimension)
    _MODEL_SPECS = (
        ("tongyi-embedding-vision-plus", 8, 1152),
        ("tongyi-embedding-vision-flash", 8, 768),
        ("multimodal-embedding-v", 1, 1024),
    )

    def __init__(
        self,
        api_key: str,
        model_name: str,
        dimensions: int | None = None,
        embedding_cache: EmbeddingCacheBase | None = None,
    ) -> None:
        """Initialize the DashScope multimodal embedding model class.

        Args:
            api_key (`str`):
                The dashscope API key.
            model_name (`str`):
                The name of the embedding model, e.g. "multimodal-embedding-
                v1", "tongyi-embedding-vision-plus".
            dimensions (`int | None`, defaults to `None`):
                The dimension of the embedding vector. When `None`, the
                dimension fixed for the model is used (1152 for
                "tongyi-embedding-vision-plus", 768 for
                "tongyi-embedding-vision-flash", 1024 otherwise). Refer to
                the `official documentation
                <https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2712517>`_
                for more details.
            embedding_cache (`EmbeddingCacheBase`):
                The embedding cache class instance, used to cache the
                embedding results to avoid repeated API calls.

        Raises:
            `ValueError`:
                If `dimensions` is given but differs from the dimension
                fixed for a known model.
        """
        path_doc = (
            "https://bailian.console.aliyun.com/?tab=api#/api/?type=model&"
            "url=2712517"
        )
        self.batch_size_limit = 1

        for prefix, batch_limit, fixed_dimensions in self._MODEL_SPECS:
            if not model_name.startswith(prefix):
                continue

            self.batch_size_limit = batch_limit
            if dimensions is None:
                dimensions = fixed_dimensions
            elif dimensions != fixed_dimensions:
                raise ValueError(
                    f"The dimension of model {model_name} must be "
                    f"{fixed_dimensions}, "
                    "refer to the official documentation for more details: "
                    f"{path_doc}",
                )
            break

        # Models without a known prefix fall back to 1024 dimensions.
        refined_dimensions: int = 1024 if dimensions is None else dimensions
        super().__init__(model_name, refined_dimensions)

        self.api_key = api_key
        self.embedding_cache = embedding_cache

    @staticmethod
    def _format_block(block: Any) -> dict[str, str]:
        """Validate one input block and convert it to the API format.

        Args:
            block (`Any`):
                A text, image or video block given as a dict.

        Raises:
            `ValueError`:
                If the block is not a valid text, image or video block.

        Returns:
            `dict[str, str]`:
                A single-key dict (`"text"`, `"image"` or `"video"`).
        """
        if (
            not isinstance(block, dict)
            or "type" not in block
            or block["type"] not in ["text", "image", "video"]
        ):
            raise ValueError(
                f"Invalid data : {block}. It should be a list of "
                "TextBlock, ImageBlock, or VideoBlock.",
            )

        block_type = block["type"]

        if (
            block_type == "video"
            and block.get("source", {}).get("type") != "url"
        ):
            raise ValueError(
                f"The multimodal embedding API only supports URL input "
                f"for video data, but got {block}.",
            )

        if block_type == "text":
            # Raised explicitly (instead of `assert`) so that the check
            # survives `python -O` and matches the other validations.
            if "text" not in block:
                raise ValueError(
                    f"Invalid text block: {block}. It should contain a "
                    f"'text' field.",
                )
            return {"text": block["text"]}

        if block_type == "video":
            return {"video": block["source"]["url"]}

        # Only image blocks reach this point.
        source = block.get("source")
        if "source" in block and source.get("type") == "base64":
            return {
                "image": f'data:{source["media_type"]};'
                f'base64,{source["data"]}',
            }
        if "source" in block and source.get("type") == "url":
            return {"image": source["url"]}

        raise ValueError(
            f"Invalid block {block}. It should be a valid TextBlock, "
            f"ImageBlock, or VideoBlock.",
        )

    async def __call__(
        self,
        inputs: list[TextBlock | ImageBlock | VideoBlock],
        **kwargs: Any,
    ) -> EmbeddingResponse:
        """Call the DashScope multimodal embedding API, which accepts text,
        image, and video data.

        Args:
            inputs (`list[TextBlock | ImageBlock | VideoBlock]`):
                The input data to be embedded. It can be a list of text,
                image, and video blocks.
            **kwargs (`Any`):
                Extra keyword arguments merged into every API call.

        Raises:
            `ValueError`:
                If an input is not a valid text, image or video block.

        Returns:
            `EmbeddingResponse`:
                The embedding response object, which contains the embeddings
                and usage information.
        """
        # Validate everything up front so that no API call is made when
        # any of the inputs is invalid.
        formatted_data = [self._format_block(block) for block in inputs]

        # Handle the batch size limit of the DashScope multimodal embedding API
        collected_embeddings = []
        collected_time = 0.0
        collected_tokens = 0
        # Stays "cache" unless at least one batch was answered by the API.
        collected_source: Literal["cache", "api"] = "cache"
        for start in range(0, len(formatted_data), self.batch_size_limit):
            batch_data = formatted_data[start : start + self.batch_size_limit]
            batch_kwargs = {
                "input": batch_data,
                "model": self.model_name,
                **kwargs,
            }
            res = await self._call_api(batch_kwargs)

            collected_embeddings.extend(res.embeddings)
            collected_time += res.usage.time
            if res.usage.tokens:
                collected_tokens += res.usage.tokens
            if res.source == "api":
                collected_source = "api"

        return EmbeddingResponse(
            embeddings=collected_embeddings,
            usage=EmbeddingUsage(
                tokens=collected_tokens,
                time=collected_time,
            ),
            source=collected_source,
        )

    async def _call_api(self, kwargs: dict[str, Any]) -> EmbeddingResponse:
        """Call the DashScope multimodal embedding API by the given
        arguments, answering from the cache when possible.

        Args:
            kwargs (`dict[str, Any]`):
                The keyword arguments of the API call; they also serve as
                the cache identifier and are left unmodified.

        Raises:
            `RuntimeError`:
                If the API responds with a status code other than 200.

        Returns:
            `EmbeddingResponse`:
                The embeddings of the batch together with the usage.
        """
        # Search in cache first
        if self.embedding_cache:
            cached_embeddings = await self.embedding_cache.retrieve(
                identifier=kwargs,
            )
            if cached_embeddings:
                return EmbeddingResponse(
                    embeddings=cached_embeddings,
                    usage=EmbeddingUsage(
                        tokens=0,
                        time=0,
                    ),
                    source="cache",
                )

        import dashscope

        # The API key goes into a copy: `kwargs` is the cache identifier
        # and must be the same for `retrieve` above and `store` below.
        call_kwargs = {**kwargs, "api_key": self.api_key}

        start_time = datetime.now()
        res = dashscope.MultiModalEmbedding.call(**call_kwargs)
        time = (datetime.now() - start_time).total_seconds()

        if res.status_code != 200:
            raise RuntimeError(
                f"Failed to get embedding from DashScope API: {res}",
            )

        embeddings = [_["embedding"] for _ in res.output["embeddings"]]

        # Populate the cache so that the lookup above can ever succeed.
        if self.embedding_cache:
            await self.embedding_cache.store(
                identifier=kwargs,
                embeddings=list(embeddings),
            )

        return EmbeddingResponse(
            embeddings=embeddings,
            usage=EmbeddingUsage(
                tokens=res.usage.get(
                    "image_tokens",
                    0,
                )
                + res.usage.get(
                    "input_tokens",
                    0,
                ),
                time=time,
            ),
            source="api",
        )
|
||||
45
src/agentscope/embedding/_embedding_base.py
Normal file
45
src/agentscope/embedding/_embedding_base.py
Normal file
@@ -0,0 +1,45 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The embedding model base class."""
|
||||
from typing import Any
|
||||
|
||||
from ._embedding_response import EmbeddingResponse
|
||||
|
||||
|
||||
class EmbeddingModelBase:
    """Common base of the embedding model classes."""

    model_name: str
    """The embedding model name"""

    supported_modalities: list[str]
    """The supported data modalities, e.g. "text", "image", "video"."""

    dimensions: int
    """The dimensions of the embedding vector."""

    def __init__(
        self,
        model_name: str,
        dimensions: int,
    ) -> None:
        """Record the model name and the embedding dimension.

        Args:
            model_name (`str`):
                The name of the embedding model.
            dimensions (`int`):
                The dimension of the embedding vector.
        """
        self.model_name, self.dimensions = model_name, dimensions

    async def __call__(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> EmbeddingResponse:
        """Call the embedding API with the given arguments.

        Raises:
            `NotImplementedError`:
                Always; subclasses have to provide the implementation.
        """
        owner = self.__class__.__name__
        raise NotImplementedError(
            f"The {owner} class does not implement "
            f"the __call__ method.",
        )
|
||||
32
src/agentscope/embedding/_embedding_response.py
Normal file
32
src/agentscope/embedding/_embedding_response.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The embedding response class."""
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, List
|
||||
|
||||
from ._embedding_usage import EmbeddingUsage
|
||||
from .._utils._common import _get_timestamp
|
||||
from .._utils._mixin import DictMixin
|
||||
from ..types import Embedding
|
||||
|
||||
|
||||
@dataclass
class EmbeddingResponse(DictMixin):
    """The embedding response class.

    Only `embeddings` is required; every other field has a default that is
    produced by a `default_factory`.
    """

    # NOTE(review): even constant defaults go through `default_factory`,
    # presumably so that no class-level attribute is created next to
    # `DictMixin` (defined elsewhere) — confirm before simplifying.

    embeddings: List[Embedding]
    """The embedding data"""

    # `_get_timestamp` is defined elsewhere; here it is called with `True`,
    # while `created_at` below calls it without arguments.
    id: str = field(default_factory=lambda: _get_timestamp(True))
    """The identity of the embedding response"""

    created_at: str = field(default_factory=_get_timestamp)
    """The timestamp of the embedding response creation"""

    type: Literal["embedding"] = field(default_factory=lambda: "embedding")
    """The type of the response, must be `embedding`."""

    usage: EmbeddingUsage | None = field(default_factory=lambda: None)
    """The usage of the embedding model API invocation, if available."""

    source: Literal["cache", "api"] = field(default_factory=lambda: "api")
    """If the response comes from the cache or the API."""
|
||||
20
src/agentscope/embedding/_embedding_usage.py
Normal file
20
src/agentscope/embedding/_embedding_usage.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The embedding usage class in agentscope."""
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal
|
||||
|
||||
from .._utils._mixin import DictMixin
|
||||
|
||||
|
||||
@dataclass
class EmbeddingUsage(DictMixin):
    """The usage record of a single embedding model API invocation."""

    time: float
    """The elapsed time of the invocation, in seconds."""

    # NOTE: defaults are supplied through ``default_factory`` on purpose;
    # keep it that way when editing so no default is stored on the class.
    tokens: int | None = field(default_factory=lambda: None)
    """The number of tokens consumed, or `None` when not reported."""

    type: Literal["embedding"] = field(default_factory=lambda: "embedding")
    """The type tag of the usage record, which must be `embedding`."""
|
||||
187
src/agentscope/embedding/_file_cache.py
Normal file
187
src/agentscope/embedding/_file_cache.py
Normal file
@@ -0,0 +1,187 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""A file embedding cache implementation for storing and retrieving
|
||||
embeddings in binary files."""
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
from typing import Any, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ._cache_base import EmbeddingCacheBase
|
||||
from .._logging import logger
|
||||
from ..types import (
|
||||
Embedding,
|
||||
JSONSerializableObject,
|
||||
)
|
||||
|
||||
|
||||
class FileEmbeddingCache(EmbeddingCacheBase):
    """The embedding cache class that stores each embeddings vector in
    binary files."""

    def __init__(
        self,
        cache_dir: str = "./.cache/embeddings",
        max_file_number: int | None = None,
        max_cache_size: int | None = None,
    ) -> None:
        """Initialize the file embedding cache class.

        Args:
            cache_dir (`str`, defaults to `"./.cache/embeddings"`):
                The directory to store the embedding files.
            max_file_number (`int | None`, defaults to `None`):
                The maximum number of files to keep in the cache directory. If
                exceeded, the oldest files will be removed.
            max_cache_size (`int | None`, defaults to `None`):
                The maximum size of the cache directory in MB. If exceeded,
                the oldest files will be removed until the size is within the
                limit.
        """
        self._cache_dir = os.path.abspath(cache_dir)
        self.max_file_number = max_file_number
        self.max_cache_size = max_cache_size

    @property
    def cache_dir(self) -> str:
        """The cache directory where the embedding files are stored. The
        directory is created on access when it does not exist yet."""
        if not os.path.exists(self._cache_dir):
            os.makedirs(self._cache_dir, exist_ok=True)
        return self._cache_dir

    async def store(
        self,
        embeddings: List[Embedding],
        identifier: JSONSerializableObject,
        overwrite: bool = False,
        **kwargs: Any,
    ) -> None:
        """Store the embeddings with the given identifier.

        If a file for the identifier already exists and `overwrite` is
        `False`, nothing is written and the cache limits are not re-checked.

        Args:
            embeddings (`List[Embedding]`):
                The embeddings to store.
            identifier (`JSONSerializableObject`):
                The identifier to distinguish the embeddings, which will be
                used to generate a hashable filename, so it should be
                JSON serializable (e.g. a string, number, list, dict).
            overwrite (`bool`, defaults to `False`):
                Whether to overwrite existing embeddings with the same
                identifier. If `True`, existing embeddings will be replaced.

        Raises:
            `RuntimeError`:
                If the target path exists but is not a regular file.
        """
        filename = self._get_filename(identifier)
        path_file = os.path.join(self.cache_dir, filename)

        if os.path.exists(path_file):
            if not os.path.isfile(path_file):
                raise RuntimeError(
                    f"Path {path_file} exists but is not a file.",
                )
            if not overwrite:
                # Keep the existing entry untouched.
                return

        np.save(path_file, embeddings)
        await self._maintain_cache_dir()

    async def retrieve(
        self,
        identifier: JSONSerializableObject,
    ) -> List[Embedding] | None:
        """Retrieve the embeddings with the given identifier. If not found,
        return `None`.

        Args:
            identifier (`JSONSerializableObject`):
                The identifier to retrieve the embeddings, which will be
                used to generate a hashable filename, so it should be
                JSON serializable (e.g. a string, number, list, dict).

        Returns:
            `List[Embedding] | None`:
                The stored embeddings converted back to nested lists, or
                `None` when no file exists for the identifier.
        """
        filename = self._get_filename(identifier)
        path_file = os.path.join(self.cache_dir, filename)

        if os.path.exists(path_file):
            return np.load(path_file).tolist()
        return None

    async def remove(self, identifier: JSONSerializableObject) -> None:
        """Remove the embeddings with the given identifier.

        Args:
            identifier (`JSONSerializableObject`):
                The identifiers to remove the embeddings, which will be
                used to generate a hashable filename, so it should be
                JSON serializable (e.g. a string, number, list, dict).

        Raises:
            `FileNotFoundError`:
                If no cached file exists for the identifier.
        """
        filename = self._get_filename(identifier)
        path_file = os.path.join(self.cache_dir, filename)

        if not os.path.exists(path_file):
            raise FileNotFoundError(f"File {path_file} does not exist.")
        os.remove(path_file)

    async def clear(self) -> None:
        """Clear the cache directory by removing all `.npy` files."""
        for filename in os.listdir(self.cache_dir):
            if filename.endswith(".npy"):
                os.remove(os.path.join(self.cache_dir, filename))

    def _get_cache_size(self) -> float:
        """Get the current size of the cache directory in MB, counting only
        the regular `.npy` files."""
        total_size = 0
        for filename in os.listdir(self.cache_dir):
            if filename.endswith(".npy"):
                path_file = os.path.join(self.cache_dir, filename)
                if os.path.isfile(path_file):
                    total_size += os.path.getsize(path_file)
        return total_size / (1024.0 * 1024.0)

    @staticmethod
    def _get_filename(identifier: JSONSerializableObject) -> str:
        """Generate a filename based on the identifier: the SHA-256 hex
        digest of its JSON dump, followed by `.npy`."""
        json_str = json.dumps(identifier, ensure_ascii=False)
        return hashlib.sha256(json_str.encode("utf-8")).hexdigest() + ".npy"

    async def _maintain_cache_dir(self) -> None:
        """Maintain the cache directory by removing old files if the number of
        files exceeds the maximum limit or if the cache size exceeds the
        maximum size."""
        # Gather name, modification time and size in a single directory
        # scan, so the size limit below can be enforced with a running total
        # instead of re-listing the whole directory after every removal.
        files = []
        with os.scandir(self.cache_dir) as entries:
            for entry in entries:
                if entry.is_file() and entry.name.endswith(".npy"):
                    stat = entry.stat()
                    files.append((entry.name, stat.st_mtime, stat.st_size))

        # Oldest files first.
        files.sort(key=lambda item: item[1])

        if self.max_file_number and len(files) > self.max_file_number:
            for file_name, _, _ in files[: -self.max_file_number]:
                os.remove(os.path.join(self.cache_dir, file_name))
                logger.info(
                    "Remove cached embedding file %s for limited number "
                    "of files (%d).",
                    file_name,
                    self.max_file_number,
                )
            files = files[-self.max_file_number :]

        if self.max_cache_size is None:
            return

        bytes_per_mb = 1024.0 * 1024.0
        total_bytes = sum(size for _, _, size in files)
        if total_bytes / bytes_per_mb <= self.max_cache_size:
            return

        removed_count = 0
        for file_name, _, size in files:
            os.remove(os.path.join(self.cache_dir, file_name))
            removed_count += 1
            total_bytes -= size
            if total_bytes / bytes_per_mb <= self.max_cache_size:
                break

        if removed_count:
            logger.info(
                "Remove %d cached embedding file(s) for limited "
                "cache size (%d MB).",
                removed_count,
                self.max_cache_size,
            )
|
||||
109
src/agentscope/embedding/_gemini_embedding.py
Normal file
109
src/agentscope/embedding/_gemini_embedding.py
Normal file
@@ -0,0 +1,109 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The gemini text embedding model class."""
|
||||
from datetime import datetime
|
||||
from typing import Any, List
|
||||
|
||||
from ._embedding_response import EmbeddingResponse
|
||||
from ._embedding_usage import EmbeddingUsage
|
||||
from ._cache_base import EmbeddingCacheBase
|
||||
from ._embedding_base import EmbeddingModelBase
|
||||
from ..message import TextBlock
|
||||
|
||||
|
||||
class GeminiTextEmbedding(EmbeddingModelBase):
    """The Gemini text embedding model."""

    supported_modalities: list[str] = ["text"]
    """This class only supports text input."""

    def __init__(
        self,
        api_key: str,
        model_name: str,
        dimensions: int = 3072,
        embedding_cache: EmbeddingCacheBase | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the Gemini text embedding model class.

        Args:
            api_key (`str`):
                The Gemini API key.
            model_name (`str`):
                The name of the embedding model.
            dimensions (`int`, defaults to 3072):
                The dimension of the embedding vector, refer to the
                `official documentation
                <https://ai.google.dev/gemini-api/docs/embeddings?hl=zh-cn#control-embedding-size>`_
                for more details.
            embedding_cache (`EmbeddingCacheBase | None`, defaults to `None`):
                The embedding cache class instance, used to cache the
                embedding results to avoid repeated API calls.
            **kwargs (`Any`):
                Extra keyword arguments forwarded to `genai.Client`.
        """
        from google import genai

        super().__init__(model_name, dimensions)

        self.client = genai.Client(api_key=api_key, **kwargs)
        self.embedding_cache = embedding_cache

    async def __call__(
        self,
        text: List[str | TextBlock],
        **kwargs: Any,
    ) -> EmbeddingResponse:
        """The Gemini embedding API call.

        Args:
            text (`List[str | TextBlock]`):
                The input text to be embedded. It can be a list of strings.
            **kwargs (`Any`):
                Passed to the API call as its `config` argument.

        Returns:
            `EmbeddingResponse`:
                The embeddings, marked with `source="cache"` when they were
                served from the embedding cache.

        Raises:
            `ValueError`:
                If an input item is neither a string nor a dict with a
                `text` key.
        """
        # TODO: handle the batch size limit
        plain_texts = []
        for item in text:
            if isinstance(item, str):
                plain_texts.append(item)
            elif isinstance(item, dict) and "text" in item:
                plain_texts.append(item["text"])
            else:
                raise ValueError(
                    "Input text must be a list of strings or TextBlock dicts.",
                )

        # The request arguments double as the cache identifier.
        request_kwargs = {
            "model": self.model_name,
            "contents": plain_texts,
            "config": kwargs,
        }

        if self.embedding_cache:
            cached_embeddings = await self.embedding_cache.retrieve(
                identifier=request_kwargs,
            )
            if cached_embeddings:
                return EmbeddingResponse(
                    embeddings=cached_embeddings,
                    usage=EmbeddingUsage(
                        tokens=0,
                        time=0,
                    ),
                    source="cache",
                )

        # NOTE(review): this client call is not awaited - presumably the
        # synchronous API; confirm whether it blocks the event loop.
        started_at = datetime.now()
        response = self.client.models.embed_content(**request_kwargs)
        elapsed = (datetime.now() - started_at).total_seconds()

        if self.embedding_cache:
            await self.embedding_cache.store(
                identifier=request_kwargs,
                embeddings=[item.values for item in response.embeddings],
            )

        return EmbeddingResponse(
            embeddings=[item.values for item in response.embeddings],
            usage=EmbeddingUsage(
                time=elapsed,
            ),
        )
|
||||
106
src/agentscope/embedding/_ollama_embedding.py
Normal file
106
src/agentscope/embedding/_ollama_embedding.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The ollama text embedding model class."""
|
||||
from datetime import datetime
|
||||
from typing import List, Any
|
||||
|
||||
from ._embedding_response import EmbeddingResponse
|
||||
from ._embedding_usage import EmbeddingUsage
|
||||
from ._cache_base import EmbeddingCacheBase
|
||||
from ..embedding import EmbeddingModelBase
|
||||
from ..message import TextBlock
|
||||
|
||||
|
||||
class OllamaTextEmbedding(EmbeddingModelBase):
    """The Ollama embedding model."""

    supported_modalities: list[str] = ["text"]
    """This class only supports text input."""

    def __init__(
        self,
        model_name: str,
        dimensions: int,
        host: str | None = None,
        embedding_cache: EmbeddingCacheBase | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the Ollama text embedding model class.

        Args:
            model_name (`str`):
                The name of the embedding model.
            dimensions (`int`):
                The dimension of the embedding vector, the parameter should be
                provided according to the model used.
            host (`str | None`, defaults to `None`):
                The host URL for the Ollama API.
            embedding_cache (`EmbeddingCacheBase | None`, defaults to `None`):
                The embedding cache class instance, used to cache the
                embedding results to avoid repeated API calls.
            **kwargs (`Any`):
                Extra keyword arguments forwarded to `ollama.AsyncClient`.
        """
        import ollama

        super().__init__(model_name, dimensions)

        self.client = ollama.AsyncClient(host=host, **kwargs)
        self.embedding_cache = embedding_cache

    async def __call__(
        self,
        text: List[str | TextBlock],
        **kwargs: Any,
    ) -> EmbeddingResponse:
        """Call the Ollama embedding API.

        Args:
            text (`List[str | TextBlock]`):
                The input text to be embedded. It can be a list of strings.
            **kwargs (`Any`):
                Extra request arguments; they override the `input`, `model`
                and `dimensions` entries when the same keys are given.

        Returns:
            `EmbeddingResponse`:
                The embeddings, marked with `source="cache"` when they were
                served from the embedding cache.

        Raises:
            `ValueError`:
                If an input item is neither a string nor a dict with a
                `text` key.
        """
        plain_texts = []
        for item in text:
            if isinstance(item, str):
                plain_texts.append(item)
            elif isinstance(item, dict) and "text" in item:
                plain_texts.append(item["text"])
            else:
                raise ValueError(
                    "Input text must be a list of strings or TextBlock dicts.",
                )

        # The request arguments double as the cache identifier.
        request_kwargs = {
            "input": plain_texts,
            "model": self.model_name,
            "dimensions": self.dimensions,
            **kwargs,
        }

        if self.embedding_cache:
            cached_embeddings = await self.embedding_cache.retrieve(
                identifier=request_kwargs,
            )
            if cached_embeddings:
                return EmbeddingResponse(
                    embeddings=cached_embeddings,
                    usage=EmbeddingUsage(
                        tokens=0,
                        time=0,
                    ),
                    source="cache",
                )

        started_at = datetime.now()
        response = await self.client.embed(**request_kwargs)
        elapsed = (datetime.now() - started_at).total_seconds()

        if self.embedding_cache:
            await self.embedding_cache.store(
                identifier=request_kwargs,
                embeddings=response.embeddings,
            )

        return EmbeddingResponse(
            embeddings=response.embeddings,
            usage=EmbeddingUsage(
                time=elapsed,
            ),
        )
|
||||
109
src/agentscope/embedding/_openai_embedding.py
Normal file
109
src/agentscope/embedding/_openai_embedding.py
Normal file
@@ -0,0 +1,109 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The OpenAI text embedding model class."""
|
||||
from datetime import datetime
|
||||
from typing import Any, List
|
||||
|
||||
from ._embedding_response import EmbeddingResponse
|
||||
from ._embedding_usage import EmbeddingUsage
|
||||
from ._cache_base import EmbeddingCacheBase
|
||||
from ._embedding_base import EmbeddingModelBase
|
||||
from ..message import TextBlock
|
||||
|
||||
|
||||
class OpenAITextEmbedding(EmbeddingModelBase):
    """OpenAI text embedding model class."""

    supported_modalities: list[str] = ["text"]
    """This class only supports text input."""

    def __init__(
        self,
        api_key: str,
        model_name: str,
        dimensions: int = 1024,
        embedding_cache: EmbeddingCacheBase | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the OpenAI text embedding model class.

        Args:
            api_key (`str`):
                The OpenAI API key.
            model_name (`str`):
                The name of the embedding model.
            dimensions (`int`, defaults to 1024):
                The dimension of the embedding vector.
            embedding_cache (`EmbeddingCacheBase | None`, defaults to `None`):
                The embedding cache class instance, used to cache the
                embedding results to avoid repeated API calls.
            **kwargs (`Any`):
                Extra keyword arguments forwarded to `openai.AsyncClient`.
        """
        # TODO: handle batch size limit and token limit
        import openai

        super().__init__(model_name, dimensions)

        self.client = openai.AsyncClient(api_key=api_key, **kwargs)
        self.embedding_cache = embedding_cache

    async def __call__(
        self,
        text: List[str | TextBlock],
        **kwargs: Any,
    ) -> EmbeddingResponse:
        """Call the OpenAI embedding API.

        Args:
            text (`List[str | TextBlock]`):
                The input text to be embedded. It can be a list of strings.
            **kwargs (`Any`):
                Extra request arguments; they override the `input`, `model`,
                `dimensions` and `encoding_format` entries when the same keys
                are given.

        Returns:
            `EmbeddingResponse`:
                The embeddings, marked with `source="cache"` when they were
                served from the embedding cache.

        Raises:
            `ValueError`:
                If an input item is neither a string nor a dict with a
                `text` key.
        """
        plain_texts = []
        for item in text:
            if isinstance(item, str):
                plain_texts.append(item)
            elif isinstance(item, dict) and "text" in item:
                plain_texts.append(item["text"])
            else:
                raise ValueError(
                    "Input text must be a list of strings or TextBlock dicts.",
                )

        # The request arguments double as the cache identifier.
        request_kwargs = {
            "input": plain_texts,
            "model": self.model_name,
            "dimensions": self.dimensions,
            "encoding_format": "float",
            **kwargs,
        }

        if self.embedding_cache:
            cached_embeddings = await self.embedding_cache.retrieve(
                identifier=request_kwargs,
            )
            if cached_embeddings:
                return EmbeddingResponse(
                    embeddings=cached_embeddings,
                    usage=EmbeddingUsage(
                        tokens=0,
                        time=0,
                    ),
                    source="cache",
                )

        started_at = datetime.now()
        response = await self.client.embeddings.create(**request_kwargs)
        elapsed = (datetime.now() - started_at).total_seconds()

        if self.embedding_cache:
            await self.embedding_cache.store(
                identifier=request_kwargs,
                embeddings=[item.embedding for item in response.data],
            )

        return EmbeddingResponse(
            embeddings=[item.embedding for item in response.data],
            usage=EmbeddingUsage(
                tokens=response.usage.total_tokens,
                time=elapsed,
            ),
        )
|
||||
44
src/agentscope/evaluate/__init__.py
Normal file
44
src/agentscope/evaluate/__init__.py
Normal file
@@ -0,0 +1,44 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The evaluation module in AgentScope."""
|
||||
|
||||
from ._evaluator import (
|
||||
EvaluatorBase,
|
||||
RayEvaluator,
|
||||
GeneralEvaluator,
|
||||
)
|
||||
from ._metric_base import (
|
||||
MetricBase,
|
||||
MetricResult,
|
||||
MetricType,
|
||||
)
|
||||
from ._task import Task
|
||||
from ._solution import SolutionOutput
|
||||
from ._benchmark_base import BenchmarkBase
|
||||
from ._evaluator_storage import (
|
||||
EvaluatorStorageBase,
|
||||
FileEvaluatorStorage,
|
||||
)
|
||||
from ._ace_benchmark import (
|
||||
ACEBenchmark,
|
||||
ACEAccuracy,
|
||||
ACEProcessAccuracy,
|
||||
ACEPhone,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BenchmarkBase",
|
||||
"EvaluatorBase",
|
||||
"RayEvaluator",
|
||||
"GeneralEvaluator",
|
||||
"MetricBase",
|
||||
"MetricResult",
|
||||
"MetricType",
|
||||
"EvaluatorStorageBase",
|
||||
"FileEvaluatorStorage",
|
||||
"Task",
|
||||
"SolutionOutput",
|
||||
"ACEBenchmark",
|
||||
"ACEAccuracy",
|
||||
"ACEProcessAccuracy",
|
||||
"ACEPhone",
|
||||
]
|
||||
16
src/agentscope/evaluate/_ace_benchmark/__init__.py
Normal file
16
src/agentscope/evaluate/_ace_benchmark/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The ACE benchmark related implementations in AgentScope."""
|
||||
|
||||
from ._ace_benchmark import ACEBenchmark
|
||||
from ._ace_metric import (
|
||||
ACEAccuracy,
|
||||
ACEProcessAccuracy,
|
||||
)
|
||||
from ._ace_tools_zh import ACEPhone
|
||||
|
||||
__all__ = [
|
||||
"ACEBenchmark",
|
||||
"ACEPhone",
|
||||
"ACEAccuracy",
|
||||
"ACEProcessAccuracy",
|
||||
]
|
||||
240
src/agentscope/evaluate/_ace_benchmark/_ace_benchmark.py
Normal file
240
src/agentscope/evaluate/_ace_benchmark/_ace_benchmark.py
Normal file
@@ -0,0 +1,240 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The ACE benchmark class in agentscope. The code is implemented with
|
||||
reference to the `ACEBench <https://github.com/ACEBench/ACEBench>`_
|
||||
under the MIT license."""
|
||||
import json
|
||||
import os
|
||||
from typing import Generator
|
||||
|
||||
import json5
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
from ._ace_metric import ACEAccuracy, ACEProcessAccuracy
|
||||
from ._ace_tools_zh import ACEPhone
|
||||
from .._benchmark_base import BenchmarkBase
|
||||
from .._task import Task
|
||||
|
||||
|
||||
class ACEBenchmark(BenchmarkBase):
    """The ACE benchmark for evaluating AI agents."""

    data_dir_url: str = (
        "https://raw.githubusercontent.com/ACEBench/ACEBench/main/data_all"
    )
    """The URL to the data dir"""

    data_subdir: list[str] = [
        # "data_en",  # TODO: enable English version
        "data_zh",
    ]
    """The language sub-directories to load"""

    ground_truth_dir: str = "possible_answer"
    """The sub-directory name that holds the ground truth files"""

    data_files: list[str] = [
        "data_agent_multi_step.json",
        "data_agent_multi_turn.json",
        # "data_normal_atom_bool.json",
        # "data_normal_atom_enum.json",
        # "data_normal_atom_list.json",
        # "data_normal_atom_number.json",
        # "data_normal_atom_object_deep.json",
        # "data_normal_atom_object_short.json",
        #
        # "data_normal_multi_turn_user_adjust.json",
        # "data_normal_multi_turn_user_switch.json",
        #
        # "data_normal_preference.json",
        # "data_normal_similar_api.json",
        # "data_normal_single_turn_parallel_function.json",
        # "data_normal_single_turn_single_function.json",
        #
        # "data_special_error_param.json",
        # "data_special_incomplete.json",
        # "data_special_irrelevant.json",
    ]
    """The data filenames"""

    download_timeout: float = 60.0
    """The timeout in seconds applied to each download request, so that an
    unresponsive server cannot hang the benchmark initialization forever."""

    def __init__(
        self,
        data_dir: str,
    ) -> None:
        """Initialize the ACEBenchmark

        The dataset is downloaded into `data_dir` when any expected file is
        missing, and is then loaded into memory.

        Args:
            data_dir (`str`):
                The directory where the dataset is downloaded and saved.

        Raises:
            `RuntimeError`:
                If `data_dir` exists but is not a directory.
        """
        super().__init__(
            name="ACEBench",
            description="The ACE benchmark for evaluating AI agents.",
        )

        # Use the absolute path consistently for all file operations.
        self.data_dir = os.path.abspath(data_dir)

        if os.path.exists(self.data_dir) and not os.path.isdir(self.data_dir):
            raise RuntimeError(
                f"The data_dir `{data_dir}` is not a valid directory path.",
            )

        os.makedirs(self.data_dir, exist_ok=True)

        if not self._verify_data():
            self._download_data()

        self.dataset = self._load_data()

    def _load_data(self) -> list[dict]:
        """Load the dataset from the data directory.

        Each line of a data file is parsed as one record and joined with the
        ground truth record that has the same `id`.

        Returns:
            `list[dict]`:
                The records, each extended with the `ground_truth`,
                `mile_stone`, `language` and `tags` fields.
        """
        dataset = []
        for subdir in self.data_subdir:
            language = subdir.rsplit("_", maxsplit=1)[-1]

            for filename in self.data_files:
                file_path = os.path.join(self.data_dir, subdir, filename)
                gt_path = os.path.join(
                    self.data_dir,
                    subdir,
                    self.ground_truth_dir,
                    filename,
                )
                category = filename.split(".", maxsplit=1)[0].removeprefix(
                    "data_",
                )

                gt_dataset = {}
                with open(gt_path, "r", encoding="utf-8") as gt_file:
                    for line in gt_file:
                        # Tolerate blank lines, e.g. a trailing empty line.
                        if not line.strip():
                            continue
                        gt_data = json5.loads(line)
                        gt_dataset[gt_data["id"]] = gt_data

                with open(file_path, "r", encoding="utf-8") as f:
                    for line in f:
                        if not line.strip():
                            continue
                        data = json5.loads(line)
                        gt = gt_dataset[data["id"]]
                        data["ground_truth"] = gt["ground_truth"]
                        data["mile_stone"] = gt["mile_stone"]
                        data["language"] = language
                        data["tags"] = {
                            "language": language,
                            "category": category,
                        }
                        dataset.append(data)

        return dataset

    def _verify_data(self) -> bool:
        """Verify the data completeness, i.e. that every expected data file
        and ground truth file exists."""
        for subdir in self.data_subdir:
            for filename in self.data_files:
                file_path = os.path.join(self.data_dir, subdir, filename)
                if not os.path.exists(file_path):
                    return False

                gt_path = os.path.join(
                    self.data_dir,
                    subdir,
                    self.ground_truth_dir,
                    filename,
                )
                if not os.path.exists(gt_path):
                    return False

        return True

    def _download_file(self, url: str, save_path: str) -> None:
        """Download one file from the URL and write it to the given path."""
        response = requests.get(url, timeout=self.download_timeout)
        response.raise_for_status()
        with open(save_path, "wb") as f:
            f.write(response.content)

    def _download_data(self) -> None:
        """Download the data files and ground truth files from the URL."""
        for subdir in self.data_subdir:
            subdir_path = os.path.join(self.data_dir, subdir)
            subdir_gt_path = os.path.join(subdir_path, self.ground_truth_dir)
            os.makedirs(subdir_path, exist_ok=True)
            os.makedirs(subdir_gt_path, exist_ok=True)

            for filename in tqdm(
                self.data_files,
                desc=f"Downloading {subdir}",
            ):
                self._download_file(
                    f"{self.data_dir_url}/{subdir}/{filename}",
                    os.path.join(subdir_path, filename),
                )
                self._download_file(
                    f"{self.data_dir_url}/{subdir}/"
                    f"{self.ground_truth_dir}/{filename}",
                    os.path.join(subdir_gt_path, filename),
                )

    @staticmethod
    def _data_to_task(item: dict) -> Task:
        """Convert a dataset item to a Task object."""
        # Start the simulated phone and load initial configuration
        ace_phone = ACEPhone()
        ace_phone.load_initial_config(item["initial_config"])

        # Obtain tool functions
        tools: list[tuple] = []
        for function_schema in item["function"]:
            name = function_schema["name"]

            # Handle the schema differences
            formatted_schema = json.loads(
                json.dumps(
                    function_schema,
                ).replace(
                    '"type": "dict"',
                    '"type": "object"',
                ),
            )

            tool_function = ace_phone.get_tool_function(name)
            tools.append(
                (
                    tool_function,
                    {
                        "type": "function",
                        "function": formatted_schema,
                    },
                ),
            )

        # Read the milestones once so the ground truth and the metric agree.
        mile_stone = item.get("mile_stone", [])

        return Task(
            id=item["id"],
            input=item["question"],
            ground_truth={
                "state": item["ground_truth"],
                "mile_stone": mile_stone,
            },
            tags=item.get("tags", {}),
            metrics=[
                ACEAccuracy(item["ground_truth"]),
                ACEProcessAccuracy(mile_stone),
            ],
            metadata={
                # The phone is used to extract the final state after finishing
                # the task.
                "phone": ace_phone,
                # The provided tools for this task, used to equip the agent
                "tools": tools,
            },
        )

    def __iter__(self) -> Generator[Task, None, None]:
        """Iterate over the benchmark."""
        for item in self.dataset:
            yield self._data_to_task(item)

    def __getitem__(self, index: int) -> Task:
        """Get a task by index."""
        return self._data_to_task(self.dataset[index])

    def __len__(self) -> int:
        """Get the length of the benchmark."""
        return len(self.dataset)
|
||||
131
src/agentscope/evaluate/_ace_benchmark/_ace_metric.py
Normal file
131
src/agentscope/evaluate/_ace_benchmark/_ace_metric.py
Normal file
@@ -0,0 +1,131 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The ACE benchmark metric implementations in AgentScope."""
|
||||
|
||||
from .._solution import SolutionOutput
|
||||
from .._metric_base import MetricBase, MetricResult, MetricType
|
||||
|
||||
|
||||
class ACEProcessAccuracy(MetricBase):
    """The ace benchmark process accuracy metric."""

    def __init__(
        self,
        mile_stone: list[str],
    ) -> None:
        """Initialize the AceBench process accuracy metric.

        Args:
            mile_stone (`list[str]`):
                The tool calls, rendered in ACEBench format, that must all
                appear in the trajectory of a solution.
        """
        super().__init__(
            name="process_accuracy",
            metric_type=MetricType.NUMERICAL,
            description="The AceBench Agent eval process accuracy metric.",
        )
        self.mile_stone = mile_stone

    async def __call__(
        self,
        solution: SolutionOutput,
    ) -> MetricResult:
        """Calculate the metric result.

        Args:
            solution (`SolutionOutput`):
                The solution whose trajectory is checked.

        Returns:
            `MetricResult`:
                A result of `1` when every milestone is found in the
                trajectory, otherwise `0` with the first missing milestone
                named in the message.
        """
        # Turn the tool use block sequence into ACEBench format
        # e.g. func(arg1='dfd', arg2=44)
        gathered_trajectory = []
        for step in solution.trajectory:
            if step.get("type") != "tool_use":
                continue

            function_name = step.get("name")
            arguments = step.get("input")

            # String values are wrapped in single quotes, others are not.
            kwargs_str = ", ".join(
                f"{key}='{value}'"
                if isinstance(value, str)
                else f"{key}={value}"
                for key, value in arguments.items()
            )
            gathered_trajectory.append(
                f"[{function_name}({kwargs_str})]",
            )

        for stone in self.mile_stone:
            if stone in gathered_trajectory:
                continue
            return MetricResult(
                name=self.name,
                result=0,
                message=f"Error: Missing milestone '{stone}' in "
                "the given trajectory.",
            )

        return MetricResult(
            name=self.name,
            result=1,
            message="Success",
        )
|
||||
|
||||
|
||||
class ACEAccuracy(MetricBase):
    """The ace benchmark metric"""

    def __init__(
        self,
        state: list[dict],
    ) -> None:
        """Initialize the _metric object.

        Args:
            state (`list[dict]`):
                The ground truth state that the solution output is compared
                against.
        """
        super().__init__(
            "accuracy",
            MetricType.NUMERICAL,
            "The AceBench Agent eval accuracy metric.",
        )
        self.state = state

    async def __call__(
        self,
        solution: SolutionOutput,
    ) -> MetricResult:
        """Calculate the metric result.

        Args:
            solution (`SolutionOutput`):
                The solution whose `output` is compared with the ground
                truth state.

        Returns:
            `MetricResult`:
                A result of `1` when every ground truth key has an equal
                value in the solution output, otherwise `0` with the first
                mismatching key in the message.

        Raises:
            `ValueError`:
                If the solution output is not a list, or if it lacks keys
                that exist in the ground truth state.
        """
        # Check if the solution matches the ground truth
        if not isinstance(solution.output, list):
            # The check is on the solution output, so say so in the message.
            raise ValueError("Solution output must be a list.")

        # Handle the typos in ACEBench dataset
        gathered_state = {}
        for item in self.state:
            for key, value in item.items():
                if key.endswith("API"):
                    key = key.replace("API", "Api")
                elif key.endswith("rpi"):
                    key = key.replace("pi", "Api")
                gathered_state[key] = value

        gathered_output = {}
        for item in solution.output:
            for key, value in item.items():
                gathered_output[key] = value

        if not set(gathered_state.keys()).issubset(gathered_output.keys()):
            raise ValueError(
                "Missing keys in solution output compared to state, "
                f"ground truth keys: {gathered_state.keys()}, "
                f"solution keys: {gathered_output.keys()}",
            )

        for key, value in gathered_state.items():
            if value != gathered_output.get(key):
                return MetricResult(
                    name=self.name,
                    result=0,
                    message=(
                        f"Error: Mismatch in key '{key}':"
                        f"\n{value}\n{gathered_output.get(key)}"
                    ),
                )

        return MetricResult(
            name=self.name,
            result=1,
            message="Success: All keys match",
        )
|
||||
@@ -0,0 +1,14 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The ACEBench simulation tools in AgentScope."""
|
||||
|
||||
from ._message_api import MessageApi
|
||||
from ._travel_api import TravelApi
|
||||
from ._reminder_api import ReminderApi
|
||||
from ._food_platform_api import FoodPlatformApi
|
||||
|
||||
__all__ = [
|
||||
"MessageApi",
|
||||
"TravelApi",
|
||||
"ReminderApi",
|
||||
"FoodPlatformApi",
|
||||
]
|
||||
@@ -0,0 +1,302 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The food platform API in the ACEBench evaluation."""
|
||||
|
||||
from ._shared_state import SharedState
|
||||
|
||||
|
||||
class FoodPlatformApi(SharedState):
|
||||
"""The food platform Api in the ACEBench evaluation."""
|
||||
|
||||
tool_functions: list[str] = [
|
||||
"login_food_platform",
|
||||
"view_logged_in_users",
|
||||
"check_balance",
|
||||
"add_food_delivery_order",
|
||||
"get_products",
|
||||
"view_orders",
|
||||
"search_orders",
|
||||
]
|
||||
|
||||
def __init__(self, shared_state: dict) -> None:
    """Initialize the simulated food platform.

    Creates the user accounts, the merchant catalogue, and empty
    login-session and order records.

    Args:
        shared_state (`dict`):
            The shared state dict, forwarded unchanged to the base class
            initializer.
    """
    super().__init__(shared_state)

    # Accounts as (name, user id, password, starting balance).
    account_rows = [
        ("Eve", "U100", "password123", 500.0),
        ("Frank", "U101", "password456", 300.0),
        ("Grace", "U102", "password789", 150.0),
        ("Helen", "U103", "password321", 800.0),
        ("Isaac", "U104", "password654", 400.0),
        ("Jack", "U105", "password654", 120.0),
    ]
    self.users: dict = {
        name: {"user_id": user_id, "password": secret, "balance": funds}
        for name, user_id, secret, funds in account_rows
    }

    # Six merchants as (name, merchant id, service type, menu), where the
    # menu is a list of (product, price) pairs.
    merchant_rows = [
        (
            "达美乐",
            "M100",
            "Pizza",
            [("玛格丽特披萨", 68.0), ("超级至尊披萨", 88.0)],
        ),
        (
            "米村拌饭",
            "M101",
            "Bibimbap",
            [("石锅拌饭", 35.0), ("韩式牛肉拌饭", 45.0)],
        ),
        (
            "海底捞",
            "M102",
            "Hotpot",
            [("牛肉卷", 68.0), ("海鲜拼盘", 88.0)],
        ),
        (
            "喜茶",
            "M103",
            "Milk Tea",
            [("芝士奶茶", 25.0), ("四季春奶茶", 22.0)],
        ),
        (
            "盒马生鲜",
            "M104",
            "Fresh Grocery",
            [("有机蔬菜包", 15.0), ("生鲜大礼包", 99.0)],
        ),
        (
            "九田家烤肉",
            "M105",
            "BBQ",
            [("韩式烤牛肉", 128.0), ("烤五花肉", 78.0)],
        ),
    ]
    self.merchant_list: dict[str, dict] = {
        name: {
            "merchant_id": merchant_id,
            "service_type": service_type,
            "menu": [
                {"product": product, "price": price}
                for product, price in menu
            ],
        }
        for name, merchant_id, service_type, menu in merchant_rows
    }

    # Names of the users that are currently logged in.
    self.logged_in_users: list[str] = []
    # Orders placed so far, in creation order.
    self.orders: list = []
|
||||
|
||||
def get_state_dict(self) -> dict:
    """Get the current state dict of the FoodPlatformApi.

    Returns:
        `dict`:
            A dict with the single key "FoodPlatform" whose value holds
            the logged-in users, the orders and the user accounts. The
            values are the live objects, not copies.
    """
    platform_state = {}
    platform_state["logged_in_users"] = self.logged_in_users
    platform_state["orders"] = self.orders
    platform_state["users"] = self.users
    return {"FoodPlatform": platform_state}
|
||||
|
||||
def login_food_platform(
|
||||
self,
|
||||
username: str,
|
||||
password: str,
|
||||
) -> dict[str, bool | str]:
|
||||
"""使用用户名和密码登录外卖平台。
|
||||
|
||||
Args:
|
||||
username (`str`):
|
||||
用户的用户名。
|
||||
password (`str`):
|
||||
用户的密码。
|
||||
"""
|
||||
if not self.wifi:
|
||||
return {"status": False, "message": "wifi未打开,无法登录"}
|
||||
if username not in self.users:
|
||||
return {"status": False, "message": "用户不存在"}
|
||||
if self.users[username]["password"] != password:
|
||||
return {"status": False, "message": "密码错误"}
|
||||
|
||||
# 检查是否已经有用户登录
|
||||
if username in self.logged_in_users:
|
||||
return {"status": False, "message": f"{username} 已经登录"}
|
||||
|
||||
# 记录已登录用户
|
||||
self.logged_in_users.append(username)
|
||||
return {"status": True, "message": f"用户{username}登陆成功!"}
|
||||
|
||||
def view_logged_in_users(self) -> dict:
    """View all users that are currently logged in.

    Returns:
        `dict`:
            On success, "status" is True and "logged_in_users" is the
            list of logged-in user names; otherwise "status" is False
            with an explanatory "message".
    """
    active_users = self.logged_in_users
    if active_users:
        return {"status": True, "logged_in_users": active_users}

    return {
        "status": False,
        "message": "当前没有登录food platform",
    }
|
||||
|
||||
def check_balance(self, user_name: str) -> float:
    """Query the balance of the given user.

    Args:
        user_name (`str`):
            The user's username.

    Returns:
        `float`:
            The stored balance, or 0.0 when the user is unknown.
    """
    known = user_name in self.users
    return self.users[user_name]["balance"] if known else 0.0
|
||||
|
||||
def add_food_delivery_order(
|
||||
self,
|
||||
username: str,
|
||||
merchant_name: str,
|
||||
items: list[dict[str, str | int]],
|
||||
) -> dict[str, bool | str]:
|
||||
"""订外卖
|
||||
|
||||
Args:
|
||||
username (`str`):
|
||||
下订单的用户姓名。
|
||||
merchant_name (`str`):
|
||||
下订单的商家名称。
|
||||
items (`list[dict[str, str | int]]`):
|
||||
订单中商品的列表,每个商品包含名称和数量。
|
||||
"""
|
||||
if username not in self.logged_in_users:
|
||||
return {
|
||||
"status": False,
|
||||
"message": f"用户 {username} 未登录food platform",
|
||||
}
|
||||
|
||||
if merchant_name not in self.merchant_list:
|
||||
return {"status": False, "message": "商家不存在"}
|
||||
|
||||
total_price = 0.0
|
||||
order_items = []
|
||||
|
||||
for item in items:
|
||||
product_name = item.get("product")
|
||||
quantity = item.get("quantity", 1)
|
||||
|
||||
if not isinstance(quantity, int) or quantity <= 0:
|
||||
return {
|
||||
"status": False,
|
||||
"message": f"无效的数量 {quantity} 对于商品 {product_name}",
|
||||
}
|
||||
|
||||
# 查找商品价格
|
||||
product_found = False
|
||||
for product in self.merchant_list[merchant_name]["menu"]:
|
||||
if product["product"] == product_name:
|
||||
total_price += product["price"] * quantity
|
||||
order_items.append(
|
||||
{
|
||||
"product": product_name,
|
||||
"quantity": quantity,
|
||||
"price_per_unit": product["price"],
|
||||
},
|
||||
)
|
||||
product_found = True
|
||||
break
|
||||
if not product_found:
|
||||
return {
|
||||
"status": False,
|
||||
"message": f"商品 {product_name} 不存在于 "
|
||||
f"{merchant_name} 的菜单中",
|
||||
}
|
||||
|
||||
# 检查余额是否足够
|
||||
if total_price >= self.users[username]["balance"]:
|
||||
return {"status": False, "message": "余额不足,无法下单"}
|
||||
|
||||
# 扣除余额并创建订单
|
||||
self.users[username]["balance"] -= total_price
|
||||
order = {
|
||||
"user_name": username,
|
||||
"merchant_name": merchant_name,
|
||||
"items": order_items,
|
||||
"total_price": total_price,
|
||||
}
|
||||
self.orders.append(order)
|
||||
return {
|
||||
"status": True,
|
||||
"message": f"外卖订单成功下单给 {merchant_name}," f"总金额为 {total_price} 元",
|
||||
}
|
||||
|
||||
def get_products(
|
||||
self,
|
||||
merchant_name: str,
|
||||
) -> list[dict[str, str | float]] | dict[str, bool | str]:
|
||||
"""获取特定商家的商品列表。
|
||||
|
||||
Args:
|
||||
merchant_name (`str`):
|
||||
要获取商品的商家名称。
|
||||
"""
|
||||
merchant = self.merchant_list.get(merchant_name)
|
||||
if merchant:
|
||||
return merchant["menu"]
|
||||
else:
|
||||
return {
|
||||
"status": False,
|
||||
"message": f"商家 '{merchant_name}' 不存在",
|
||||
}
|
||||
|
||||
def view_orders(
|
||||
self,
|
||||
user_name: str,
|
||||
) -> dict[str, bool | str | list[dict[str, str | int | float]]]:
|
||||
"""查看用户的所有订单"""
|
||||
user_orders = [
|
||||
order for order in self.orders if order["user_name"] == user_name
|
||||
]
|
||||
|
||||
if not user_orders:
|
||||
return {"status": False, "message": "用户没有订单记录"}
|
||||
|
||||
return {"status": True, "orders": user_orders}
|
||||
|
||||
def search_orders(
|
||||
self,
|
||||
keyword: str,
|
||||
) -> dict[str, bool | str | list[dict[str, str | float]]]:
|
||||
"""根据关键字搜索订单。"""
|
||||
matched_orders = [
|
||||
order
|
||||
for order in self.orders
|
||||
if keyword.lower() in order["merchant_name"].lower()
|
||||
or any(
|
||||
keyword.lower() in item.lower()
|
||||
for item in order.get("items", [])
|
||||
)
|
||||
]
|
||||
|
||||
if not matched_orders:
|
||||
return {"status": False, "message": "没有找到匹配的订单"}
|
||||
|
||||
return {"status": True, "orders": matched_orders}
|
||||
@@ -0,0 +1,340 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The Message API in the ACEBench evaluation."""
|
||||
from datetime import datetime
|
||||
|
||||
from ._shared_state import SharedState
|
||||
|
||||
|
||||
class MessageApi(SharedState):
|
||||
"""The message Api in the ACEBench evaluation."""
|
||||
|
||||
tool_functions: list[str] = [
|
||||
"send_message",
|
||||
"delete_message",
|
||||
"view_messages_between_users",
|
||||
"search_messages",
|
||||
"get_all_message_times_with_ids",
|
||||
"get_latest_message_id",
|
||||
"get_earliest_message_id",
|
||||
]
|
||||
|
||||
def __init__(self, share_state: dict) -> None:
    """Initialize the MessageApi with shared state.

    Args:
        share_state (`dict`):
            The shared state dict, forwarded unchanged to the base class
            initializer.
    """
    super().__init__(share_state)

    # The inbox holds at most six messages.
    self.max_capacity = 6

    # Six users as (name, user id, phone number, occupation).
    user_rows = [
        ("Eve", "USR100", "123-456-7890", "Software Engineer"),
        ("Frank", "USR101", "234-567-8901", "Data Scientist"),
        ("Grace", "USR102", "345-678-9012", "Product Manager"),
        ("Helen", "USR103", "456-789-0123", "UX Designer"),
        ("Isaac", "USR104", "567-890-1234", "DevOps Engineer"),
        ("Jack", "USR105", "678-901-2345", "Marketing Specialist"),
    ]
    self.user_list: dict[str, dict[str, str | int]] = {
        name: {
            "user_id": user_id,
            "phone_number": phone,
            "occupation": job,
        }
        for name, user_id, phone, job in user_rows
    }

    # Seed messages as (id, sender id, receiver id, text, date).
    # Message 1 goes together with the reminder tool and message 2 with
    # the food tool.
    message_rows = [
        (
            1,
            "USR100",
            "USR101",
            "Hey Frank, don't forget about our meeting on "
            "2024-06-11 at 4 PM in Conference Room 1.",
            "2024-06-09",
        ),
        (
            2,
            "USR101",
            "USR102",
            '你能帮我点一个"玛格丽特披萨"的外卖吗,商家是达美乐。',
            "2024-03-09",
        ),
        (
            3,
            "USR102",
            "USR103",
            "帮我查一些喜茶有哪些奶茶外卖,买一杯便宜些的奶茶。"
            "买完以后记得回复我,回复的内容是(已经买好了)",
            "2023-12-05",
        ),
        (
            4,
            "USR103",
            "USR102",
            "No problem Helen, I can assist you.",
            "2024-09-09",
        ),
        (
            5,
            "USR104",
            "USR105",
            "Isaac, are you available for a call?",
            "2024-06-06",
        ),
        (
            6,
            "USR105",
            "USR104",
            "Yes Jack, let's do it in 30 minutes.",
            "2024-01-15",
        ),
    ]
    self.inbox: dict[int, dict[str, str | int]] = {
        message_id: {
            "sender_id": sender_id,
            "receiver_id": receiver_id,
            "message": text,
            "time": sent_on,
        }
        for message_id, sender_id, receiver_id, text, sent_on in message_rows
    }

    # Last used message id; send_message increments it before use.
    self.message_id_counter: int = 6
|
||||
|
||||
def get_state_dict(self) -> dict:
    """Get the current state dict of the MessageApi.

    Returns:
        `dict`:
            A dict with the single key "MessageApi" holding the inbox,
            with the message ids converted to strings.
    """
    # String keys are used to avoid the error in ACEBench dataset.
    inbox_state = {
        str(message_id): record for message_id, record in self.inbox.items()
    }
    return {"MessageApi": {"inbox": inbox_state}}
|
||||
|
||||
def send_message(
|
||||
self,
|
||||
sender_name: str,
|
||||
receiver_name: str,
|
||||
message: str,
|
||||
) -> dict[str, bool | str]:
|
||||
"""将一条消息从一个用户发送给另一个用户。
|
||||
|
||||
Args:
|
||||
sender_name (`str`):
|
||||
发送消息的用户姓名。
|
||||
receiver_name (`str`):
|
||||
接收消息的用户姓名。
|
||||
message (`str`):
|
||||
要发送的消息内容。
|
||||
"""
|
||||
if not self.logged_in:
|
||||
return {"status": False, "message": "device未登录,无法发送短信"}
|
||||
|
||||
if not self.wifi:
|
||||
return {"status": False, "message": "wifi关闭,此时不能发送信息"}
|
||||
|
||||
if len(self.inbox) >= self.max_capacity:
|
||||
return {
|
||||
"status": False,
|
||||
"message": "内存容量不够了,你需要询问user删除哪一条短信。",
|
||||
}
|
||||
|
||||
# 验证发送者和接收者是否存在
|
||||
if (
|
||||
sender_name not in self.user_list
|
||||
or receiver_name not in self.user_list
|
||||
):
|
||||
return {"status": False, "message": "发送者或接收者不存在"}
|
||||
|
||||
sender_id = self.user_list[sender_name]["user_id"]
|
||||
receiver_id = self.user_list[receiver_name]["user_id"]
|
||||
|
||||
# 将短信添加到inbox
|
||||
self.message_id_counter += 1
|
||||
self.inbox[self.message_id_counter] = {
|
||||
"sender_id": sender_id,
|
||||
"receiver_id": receiver_id,
|
||||
"message": message,
|
||||
}
|
||||
|
||||
return {"status": True, "message": f"短信成功发送给{receiver_name}。"}
|
||||
|
||||
def delete_message(self, message_id: int) -> dict[str, bool | str]:
|
||||
"""根据消息 ID 删除一条消息。
|
||||
|
||||
Args:
|
||||
message_id (`int`):
|
||||
要删除的消息的 ID。
|
||||
"""
|
||||
if not self.logged_in:
|
||||
return {"status": False, "message": "device未登录,无法删除短信"}
|
||||
if message_id not in self.inbox:
|
||||
return {"status": False, "message": "短信ID不存在"}
|
||||
|
||||
del self.inbox[message_id]
|
||||
return {"status": True, "message": f"短信ID {message_id} 已成功删除。"}
|
||||
|
||||
def view_messages_between_users(
    self,
    sender_name: str,
    receiver_name: str,
) -> dict:
    """Get all messages sent by one user to another user.

    Only messages in the sender-to-receiver direction are returned.

    Args:
        sender_name (`str`):
            The name of the user who sent the messages.
        receiver_name (`str`):
            The name of the user who received the messages.

    Returns:
        `dict`:
            On success, "status" is True and "messages" holds the
            matching messages; otherwise "status" is False with a
            "message".
    """
    if not self.logged_in:
        return {
            "status": False,
            "message": "device未登录,无法查看短信信息",
        }

    if sender_name not in self.user_list:
        return {"status": False, "message": "发送者不存在"}

    if receiver_name not in self.user_list:
        return {"status": False, "message": "接收者不存在"}

    from_id = self.user_list[sender_name]["user_id"]
    to_id = self.user_list[receiver_name]["user_id"]

    # Scan the inbox for messages sent from from_id to to_id.
    conversation = [
        {
            "id": message_id,
            "sender": sender_name,
            "receiver": receiver_name,
            "message": record["message"],
        }
        for message_id, record in self.inbox.items()
        if record["sender_id"] == from_id and record["receiver_id"] == to_id
    ]

    if conversation:
        return {"status": True, "messages": conversation}

    return {"status": False, "message": "没有找到相关的短信记录"}
|
||||
|
||||
def search_messages(
    self,
    user_name: str,
    keyword: str,
) -> dict:
    """Search a user's messages for those containing a keyword.

    A message is considered when the user is its sender or its receiver;
    the keyword comparison ignores case.

    Args:
        user_name (`str`):
            The name of the user whose messages are searched.
        keyword (`str`):
            The keyword to look for in the message text.

    Returns:
        `dict`:
            On success, "status" is True and "messages" holds the
            matching messages; otherwise "status" is False with a
            "message".
    """
    if user_name not in self.user_list:
        return {"status": False, "message": "用户不存在"}

    owner_id = self.user_list[user_name]["user_id"]

    # Walk the inbox and keep sent or received messages with the keyword.
    hits = []
    for message_id, record in self.inbox.items():
        involved = owner_id in (record["sender_id"], record["receiver_id"])
        if involved and keyword.lower() in record["message"].lower():
            hits.append(
                {
                    "id": message_id,
                    "sender_id": record["sender_id"],
                    "receiver_id": record["receiver_id"],
                    "message": record["message"],
                },
            )

    if hits:
        return {"status": True, "messages": hits}

    return {"status": False, "message": "没有找到包含关键词的短信"}
|
||||
|
||||
def get_all_message_times_with_ids(
    self,
) -> dict:
    """Get the time of every message together with its message id.

    Returns:
        `dict`:
            A mapping from message id to the message's "time" value, or a
            failure dict with "status" and "message" when the device is
            not logged in. A message that has no "time" field maps to
            None.
    """
    if not self.logged_in:
        return {
            "status": False,
            "message": "device未登录,获取所有短信的时间以及对应的短信编号。",
        }
    # Messages stored by send_message carry no "time" key, so indexing
    # with ["time"] raised KeyError once a message had been sent.
    return {
        msg_id: msg_data.get("time") for msg_id, msg_data in self.inbox.items()
    }
|
||||
|
||||
def get_latest_message_id(self) -> dict:
    """Get the ID of the most recently sent message.

    Messages are compared by their "time" field, parsed as "%Y-%m-%d".
    Messages without a "time" field are ignored. When several messages
    share the latest date, the first one in inbox order wins.

    Returns:
        `dict`:
            On success, "status" is True and "message_id" is the ID of
            the latest message; otherwise "status" is False with a
            "message".
    """
    if not self.logged_in:
        return {
            "status": False,
            "message": "device未登录,无法获取最新发送的短信ID。",
        }
    if not self.inbox:
        return {"status": False, "message": "短信记录为空"}

    # Messages stored by send_message carry no "time" key; indexing them
    # raised KeyError, so only timestamped messages are compared.
    timed = [
        (datetime.strptime(str(message_data["time"]), "%Y-%m-%d"), message_id)
        for message_id, message_data in self.inbox.items()
        if "time" in message_data
    ]
    if not timed:
        return {"status": False, "message": "没有带时间的短信记录"}

    # max() returns the first maximal entry, matching a strict ">" scan.
    _, latest_message_id = max(timed, key=lambda pair: pair[0])

    return {
        "status": True,
        "message": f"最新的短信ID是 {latest_message_id}",
        "message_id": latest_message_id,
    }
|
||||
|
||||
def get_earliest_message_id(self) -> dict:
    """Get the ID of the earliest sent message.

    Messages are compared by their "time" field, parsed as "%Y-%m-%d".
    Messages without a "time" field are ignored. When several messages
    share the earliest date, the first one in inbox order wins.

    Returns:
        `dict`:
            On success, "status" is True and "message_id" is the ID of
            the earliest message; otherwise "status" is False with a
            "message".
    """
    if not self.logged_in:
        return {
            "status": False,
            "message": "device未登录,无法获取最早发送的短信ID",
        }
    if not self.inbox:
        return {"status": False, "message": "短信记录为空"}

    # Messages stored by send_message carry no "time" key; indexing them
    # raised KeyError, so only timestamped messages are compared.
    timed = [
        (datetime.strptime(str(message_data["time"]), "%Y-%m-%d"), message_id)
        for message_id, message_data in self.inbox.items()
        if "time" in message_data
    ]
    if not timed:
        return {"status": False, "message": "没有带时间的短信记录"}

    # min() returns the first minimal entry, matching a strict "<" scan.
    _, earliest_message_id = min(timed, key=lambda pair: pair[0])

    return {
        "status": True,
        "message": f"最早的短信ID是 {earliest_message_id}",
        "message_id": earliest_message_id,
    }
|
||||
@@ -0,0 +1,214 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The reminder API in ACEBench simulation tools."""
|
||||
from datetime import datetime
|
||||
|
||||
from ._shared_state import SharedState
|
||||
|
||||
|
||||
class ReminderApi(SharedState):
|
||||
"""The reminder Api in the ACEBench evaluation."""
|
||||
|
||||
tool_functions: list[str] = [
|
||||
"view_reminder_by_title",
|
||||
"add_reminder",
|
||||
"delete_reminder",
|
||||
"view_all_reminders",
|
||||
"mark_as_notified",
|
||||
"search_reminders",
|
||||
]
|
||||
|
||||
def __init__(self, share_state: dict) -> None:
    """Initialize the Reminder Api in the ACEBench evaluation.

    Args:
        share_state (`dict`):
            The shared state dict, forwarded unchanged to the base class
            initializer.
    """
    super().__init__(share_state)

    # The reminder list holds at most six reminders.
    self.max_capacity = 6

    # Seed reminders as (key, reminder id, title, description, time).
    seed_rows = [
        (
            1,
            1001,
            "Doctor's Appointment",
            "Visit Dr. Smith for a checkup.",
            "2024-07-15 09:30",
        ),
        (
            2,
            1002,
            "Team Meeting",
            "Monthly project review with the team.",
            "2024-07-17 11:00",
        ),
        (
            3,
            1003,
            "To-do list",
            '首先帮Frank在"盒马生鲜"点外卖,'
            '需要定两个"生鲜大礼包",再发短信告诉Frank:'
            '"购买商品的价格是()元"。要把括号换成实际金额,'
            "保留一位小数。",
            "2024-07-16 11:00",
        ),
    ]
    self.reminder_list: dict[int, dict] = {
        key: {
            "reminder_id": reminder_id,
            "title": title,
            "description": description,
            "time": when,
            "notified": False,
        }
        for key, reminder_id, title, description, when in seed_rows
    }
    # Last used list key; add_reminder increments it before use.
    self.reminder_id_counter: int = 3
|
||||
|
||||
def get_state_dict(self) -> dict:
    """Get the current state dict of the ReminderApi.

    Returns:
        `dict`:
            A dict with the single key "ReminderApi" holding the live
            reminder list (not a copy).
    """
    reminder_state = {"reminder_list": self.reminder_list}
    return {"ReminderApi": reminder_state}
|
||||
|
||||
def _check_capacity(self) -> bool:
    """Check whether the reminder list is full.

    Returns:
        `bool`:
            True when the number of reminders has reached the maximum
            capacity.
    """
    has_room = len(self.reminder_list) < self.max_capacity
    return not has_room
|
||||
|
||||
def view_reminder_by_title(
|
||||
self,
|
||||
title: str,
|
||||
) -> dict[str, str | bool | dict[str, str | bool | datetime]]:
|
||||
"""根据提醒的标题查看特定的提醒。
|
||||
|
||||
Args:
|
||||
title (str): 提醒的标题。
|
||||
|
||||
Returns:
|
||||
dict[str, str | bool | dict[str, str | bool | datetime]]:
|
||||
包含查找状态和提醒详情的字典。
|
||||
"""
|
||||
if not self.logged_in:
|
||||
return {"status": False, "message": "device未登录,无法查看提醒"}
|
||||
for reminder in self.reminder_list.values():
|
||||
if reminder["title"] == title:
|
||||
return {"status": True, "reminder": reminder}
|
||||
|
||||
return {"status": False, "message": f"没有找到标题为 '{title}' 的提醒"}
|
||||
|
||||
def add_reminder(
|
||||
self,
|
||||
title: str,
|
||||
description: str,
|
||||
time: datetime,
|
||||
) -> dict[str, bool | str]:
|
||||
"""添加一个新的提醒。
|
||||
|
||||
Args:
|
||||
title (str): 提醒标题。
|
||||
description (str): 提醒描述。
|
||||
time (datetime): 提醒时间, 一定遵循格式"YYYY-MM-DD HH:MM"。
|
||||
|
||||
Returns:
|
||||
dict[str, bool | str]: 包含添加状态和结果的字典。
|
||||
"""
|
||||
if not self.logged_in:
|
||||
return {
|
||||
"status": False,
|
||||
"message": "device未登录,无法添加一个新的提醒",
|
||||
}
|
||||
if self._check_capacity():
|
||||
return {"status": False, "message": "提醒容量已满,无法添加新的提醒"}
|
||||
|
||||
self.reminder_id_counter += 1
|
||||
reminder_id = self.reminder_id_counter
|
||||
self.reminder_list[reminder_id] = {
|
||||
"reminder_id": reminder_id,
|
||||
"title": title,
|
||||
"description": description,
|
||||
"time": time,
|
||||
"notified": False,
|
||||
}
|
||||
return {"status": True, "message": f"提醒 '{title}' 已成功添加"}
|
||||
|
||||
def delete_reminder(self, reminder_id: int) -> dict[str, bool | str]:
|
||||
"""删除指定的提醒。
|
||||
|
||||
Args:
|
||||
reminder_id (int): 要删除的提醒ID。
|
||||
|
||||
Returns:
|
||||
dict[str, bool | str]: 包含删除状态和结果的字典。
|
||||
"""
|
||||
if not self.logged_in:
|
||||
return {"status": False, "message": "device未登录,无法删除指定的提醒"}
|
||||
if reminder_id not in self.reminder_list:
|
||||
return {"status": False, "message": "提醒ID不存在"}
|
||||
|
||||
del self.reminder_list[reminder_id]
|
||||
return {"status": True, "message": f"提醒ID {reminder_id} 已成功删除"}
|
||||
|
||||
def view_all_reminders(
    self,
) -> dict:
    """View all reminders.

    Returns:
        dict:
            On success, "status" is True and "reminders" is a list with
            the title, description, time and notified flag of every
            reminder; otherwise "status" is False with a "message".
    """
    if not self.reminder_list:
        return {"status": False, "message": "没有任何提醒"}

    shown_fields = ("title", "description", "time", "notified")
    summaries = [
        {field: entry[field] for field in shown_fields}
        for entry in self.reminder_list.values()
    ]
    return {"status": True, "reminders": summaries}
|
||||
|
||||
def mark_as_notified(
|
||||
self,
|
||||
reminder_id: int,
|
||||
) -> dict[str, bool | str]:
|
||||
"""标记提醒为已通知。
|
||||
|
||||
Args:
|
||||
reminder_id (int): 要标记为已通知的提醒ID。
|
||||
|
||||
Returns:
|
||||
dict[str, bool | str]:: 包含操作结果的字典。
|
||||
"""
|
||||
if reminder_id not in self.reminder_list:
|
||||
return {"status": False, "message": "提醒ID不存在"}
|
||||
|
||||
self.reminder_list[reminder_id]["notified"] = True
|
||||
return {"status": True, "message": f"提醒ID {reminder_id} 已标记为已通知"}
|
||||
|
||||
def search_reminders(
    self,
    keyword: str,
) -> dict:
    """Search reminders by keyword.

    A reminder matches when the keyword occurs, case-insensitively, in
    its title or in its description.

    Args:
        keyword (str): The search keyword.

    Returns:
        `dict`:
            On success, "status" is True and "reminders" is a list with
            the title, description and time of every match, the time
            rendered as a "YYYY-MM-DD HH:MM" string; otherwise "status"
            is False with a "message".
    """
    needle = keyword.lower()
    matched_reminders = []

    for reminder in self.reminder_list.values():
        if (
            needle in reminder["title"].lower()
            or needle in reminder["description"].lower()
        ):
            when = reminder["time"]
            # The seeded reminders store "time" as a string and
            # add_reminder stores whatever it is given, so calling
            # .strftime unconditionally raised AttributeError on strings.
            # Only datetime values are formatted; others pass through.
            if isinstance(when, datetime):
                when = when.strftime("%Y-%m-%d %H:%M")
            matched_reminders.append(
                {
                    "title": reminder["title"],
                    "description": reminder["description"],
                    "time": when,
                },
            )

    if not matched_reminders:
        return {"status": False, "message": "没有找到包含该关键词的提醒"}

    return {"status": True, "reminders": matched_reminders}
|
||||
@@ -0,0 +1,20 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The shared state class for ACEBench simulation tools."""
|
||||
|
||||
|
||||
class SharedState:
    """The sharing state class for ACEBench simulation tools.

    Wraps a state dict and exposes its "wifi" and "logged_in" entries as
    read-only properties. The dict is kept by reference, so later changes
    to it are visible through the properties.
    """

    def __init__(self, shared_state: dict) -> None:
        """Initialize the shared state.

        Args:
            shared_state (`dict`):
                The state dict; it is expected to contain the keys
                "wifi" and "logged_in".
        """
        self._shared_state = shared_state

    def _flag(self, name: str) -> bool:
        """Read one entry of the shared state dict by key."""
        return self._shared_state[name]

    @property
    def wifi(self) -> bool:
        """The Wi-Fi state."""
        return self._flag("wifi")

    @property
    def logged_in(self) -> bool:
        """The logged-in state."""
        return self._flag("logged_in")
|
||||
@@ -0,0 +1,834 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# type: ignore
|
||||
# pylint: disable=too-many-lines
|
||||
# pylint: disable=too-many-statements
|
||||
# pylint: disable=too-many-branches
|
||||
# pylint: disable=too-many-statements
|
||||
# pylint: disable=too-many-return-statements
|
||||
"""The travel API for the ACEBench simulation tools in AgentScope."""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
class TravelApi:
|
||||
"""旅行预订系统类。
|
||||
|
||||
提供航班查询、用户认证、预订管理等功能的旅行系统。
|
||||
支持直飞和中转航班查询、航班预订、预订修改和取消等功能。
|
||||
"""
|
||||
|
||||
tool_functions: list[str] = [
|
||||
"get_user_details",
|
||||
"get_flight_details",
|
||||
"get_reservation_details",
|
||||
"reserve_flight",
|
||||
"cancel_reservation",
|
||||
"modify_flight",
|
||||
]
|
||||
|
||||
def __init__(self) -> None:
    """Initialize the travel system.

    Sets up the user profiles, the flight table and the existing
    reservations.
    """
    # User profiles as (user id, name, password, cash balance, bank
    # balance, membership level).
    user_rows = [
        ("user1", "Eve", "password123", 2000.0, 50000.0, "regular"),
        ("user2", "Frank", "password456", 8000.0, 8000.0, "silver"),
        ("user3", "Grace", "password789", 1000.0, 5000.0, "gold"),
    ]
    self.users = {
        user_id: {
            "user_name": name,
            "password": secret,
            "cash_balance": cash,
            "bank_balance": bank,
            "membership_level": level,
        }
        for user_id, name, secret, cash, bank, level in user_rows
    }

    # Flights as (flight no, origin, destination, departure, arrival,
    # seats available, economy price, business price). Every flight
    # starts with status "available".
    flight_rows = [
        ("CA1234", "北京", "上海", "2024-07-15 08:00:00", "2024-07-15 10:30:00", 5, 1200, 3000),
        ("MU5678", "上海", "北京", "2024-07-16 09:00:00", "2024-07-16 11:30:00", 3, 1900, 3000),
        ("CZ4321", "上海", "北京", "2024-07-16 20:00:00", "2024-07-16 22:00:00", 8, 2500, 4000),
        ("CZ4352", "上海", "北京", "2024-07-17 20:00:00", "2024-07-17 22:00:00", 8, 1600, 2500),
        ("MU3561", "北京", "南京", "2024-07-18 08:00:00", "2024-07-18 10:00:00", 8, 1500, 4000),
        ("MU1566", "北京", "南京", "2024-07-18 20:00:00", "2024-07-18 22:00:00", 8, 1500, 4000),
        ("CZ1765", "南京", "深圳", "2024-07-17 20:30:00", "2024-07-17 22:00:00", 8, 1500, 2500),
        ("CZ1765", "南京", "深圳", "2024-07-18 12:30:00", "2024-07-18 15:00:00", 8, 1500, 2500),
        ("MH1765", "厦门", "成都", "2024-07-17 12:30:00", "2024-07-17 15:00:00", 8, 1500, 2500),
        ("MH2616", "成都", "厦门", "2024-07-18 18:30:00", "2024-07-18 21:00:00", 8, 1500, 2500),
        ("MH2616", "成都", "福州", "2024-07-16 18:30:00", "2024-07-16 21:00:00", 8, 1500, 2500),
    ]
    self.flights = [
        {
            "flight_no": flight_no,
            "origin": origin,
            "destination": destination,
            "depart_time": depart_time,
            "arrival_time": arrival_time,
            "status": "available",
            "seats_available": seats,
            "economy_price": economy_price,
            "business_price": business_price,
        }
        for (
            flight_no,
            origin,
            destination,
            depart_time,
            arrival_time,
            seats,
            economy_price,
            business_price,
        ) in flight_rows
    ]

    # Reservations as (reservation id, user id, flight no, cabin, origin,
    # destination). Each one is paid by "bank" and has one baggage item.
    reservation_rows = [
        ("res_1", "user1", "CA1234", "经济舱", "北京", "上海"),
        ("res_2", "user1", "MU5678", "商务舱", "上海", "北京"),
        ("res_3", "user2", "MH1765", "商务舱", "厦门", "成都"),
        ("res_4", "user2", "MU2616", "商务舱", "成都", "厦门"),
    ]
    self.reservations = [
        {
            "reservation_id": reservation_id,
            "user_id": user_id,
            "flight_no": flight_no,
            "payment_method": "bank",
            "cabin": cabin,
            "baggage": 1,
            "origin": origin,
            "destination": destination,
        }
        for (
            reservation_id,
            user_id,
            flight_no,
            cabin,
            origin,
            destination,
        ) in reservation_rows
    ]
|
||||
|
||||
def get_state_dict(self) -> dict:
    """Get the current state dict of the TravelApi.

    Returns:
        `dict`:
            A dict with the single key "Travel" holding the live user and
            reservation records (not copies).
    """
    travel_state = {}
    travel_state["users"] = self.users
    travel_state["reservations"] = self.reservations
    return {"Travel": travel_state}
|
||||
|
||||
# Query flights by origin and destination.
|
||||
|
||||
def get_flight_details(
|
||||
self,
|
||||
origin: str = None,
|
||||
destination: str = None,
|
||||
) -> list[dict] | str:
|
||||
"""根据出发地和到达地查询航班的基本信息。
|
||||
|
||||
Args:
|
||||
origin (str, optional): 出发地城市名称。默认为None。
|
||||
destination (str, optional): 目的地城市名称。默认为None。
|
||||
|
||||
Returns:
|
||||
list[dict] | str: 符合条件的航班列表或无航班的提示信息。
|
||||
"""
|
||||
flights = self.flights
|
||||
|
||||
# 过滤出发地
|
||||
if origin:
|
||||
flights = [
|
||||
flight for flight in flights if flight["origin"] == origin
|
||||
]
|
||||
|
||||
# 过滤到达地
|
||||
if destination:
|
||||
flights = [
|
||||
flight
|
||||
for flight in flights
|
||||
if flight["destination"] == destination
|
||||
]
|
||||
if len(flights) == 0:
|
||||
return "没有符合条件的直达航班"
|
||||
# 返回查询结果
|
||||
return [
|
||||
{
|
||||
"flight_no": flight["flight_no"],
|
||||
"origin": flight["origin"],
|
||||
"destination": flight["destination"],
|
||||
"depart_time": flight["depart_time"],
|
||||
"arrival_time": flight["arrival_time"],
|
||||
"status": flight["status"],
|
||||
"seats_available": flight["seats_available"],
|
||||
"economy_price": flight["economy_price"],
|
||||
"business_price": flight["business_price"],
|
||||
}
|
||||
for flight in flights
|
||||
]
|
||||
|
||||
def get_user_details(self, user_id: str, password: str) -> dict:
    """Query user information by user ID and password.

    Args:
        user_id (str): The user ID.
        password (str): The user's password.

    Returns:
        dict: A copy of the user record without the password, or an error
            dict when the credentials do not match.
    """
    profile = self.users.get(user_id)
    if not profile or profile["password"] != password:
        return {"status": "error", "message": "用户名或密码不正确"}

    public_view = dict(profile)
    public_view.pop("password", None)
    return public_view
|
||||
|
||||
def get_reservation_details(
|
||||
self,
|
||||
reservation_id: str = None,
|
||||
user_id: str = None,
|
||||
) -> list[dict] | dict:
|
||||
"""根据预订ID或用户ID查询预订信息,包括对应航班的基本信息。
|
||||
|
||||
Args:
|
||||
reservation_id (str, optional): 预订ID。默认为None。
|
||||
user_id (str, optional): 用户ID。默认为None。
|
||||
|
||||
Returns:
|
||||
`list[dict] | dict`:
|
||||
详细预订信息列表或错误信息字典。
|
||||
"""
|
||||
# 根据预订ID或用户ID筛选预订信息
|
||||
if reservation_id:
|
||||
reservations = [
|
||||
reservation
|
||||
for reservation in self.reservations
|
||||
if reservation["reservation_id"] == reservation_id
|
||||
]
|
||||
elif user_id:
|
||||
reservations = [
|
||||
reservation
|
||||
for reservation in self.reservations
|
||||
if reservation["user_id"] == user_id
|
||||
]
|
||||
else:
|
||||
return {"status": "error", "message": "请提供有效的预订ID或用户ID"}
|
||||
|
||||
# 对每个预订,附加航班信息
|
||||
detailed_reservations = []
|
||||
for reservation in reservations:
|
||||
flight_info = next(
|
||||
(
|
||||
flight
|
||||
for flight in self.flights
|
||||
if flight["flight_no"] == reservation["flight_no"]
|
||||
),
|
||||
None,
|
||||
)
|
||||
detailed_reservation = {**reservation, "flight_info": flight_info}
|
||||
detailed_reservations.append(detailed_reservation)
|
||||
|
||||
return detailed_reservations
|
||||
|
||||
def authenticate_user(self, user_id: str, password: str) -> dict:
    """Check a user's credentials.

    Args:
        user_id (str): The user ID used as key into ``self.users``.
        password (str): The password to compare with the stored one.

    Returns:
        `dict`:
            The stored user record (the same object held in
            ``self.users``, including the password) on success, otherwise
            an error dict with ``status`` and ``message`` entries.
    """
    account = self.users.get(user_id)
    credentials_ok = bool(account) and account["password"] == password
    if not credentials_ok:
        return {"status": "error", "message": "用户名或密码不正确"}
    return account
|
||||
|
||||
def get_baggage_allowance(
    self,
    membership_level: str,
    cabin_class: str,
) -> int:
    """Return the number of free checked bags for a member and cabin.

    Args:
        membership_level (str): Membership level ("regular", "silver" or
            "gold").
        cabin_class (str): Cabin class; the table below has entries for
            "经济舱" and "商务舱".

    Returns:
        int: The free checked-bag count, or 0 when the membership level or
        the cabin class is not in the table.
    """
    allowance_table = {
        "regular": {"经济舱": 1, "商务舱": 2},
        "silver": {"经济舱": 2, "商务舱": 3},
        "gold": {"经济舱": 3, "商务舱": 3},
    }
    per_cabin = allowance_table.get(membership_level)
    if per_cabin is None:
        return 0
    return per_cabin.get(cabin_class, 0)
|
||||
|
||||
def find_transfer_flights(
|
||||
self,
|
||||
origin_city: str,
|
||||
transfer_city: str,
|
||||
destination_city: str,
|
||||
) -> list[dict] | str:
|
||||
"""查找从出发城市到目的地城市的中转航班。
|
||||
|
||||
确保第一班航班降落时间早于第二班航班起飞时间。
|
||||
|
||||
Args:
|
||||
origin_city (str): 出发城市。
|
||||
transfer_city (str): 中转城市。
|
||||
destination_city (str): 到达城市。
|
||||
|
||||
Returns:
|
||||
list[dict] | str:
|
||||
满足条件的中转航班列表,每个航班包含两段航程的信息,或无航班提示。
|
||||
"""
|
||||
# 获取从出发城市到中转城市的航班
|
||||
first_leg_flights: list[dict] = [
|
||||
flight
|
||||
for flight in self.flights
|
||||
if flight["origin"] == origin_city
|
||||
and flight["destination"] == transfer_city
|
||||
and flight["status"] == "available"
|
||||
]
|
||||
|
||||
# 获取从中转城市到目的地城市的航班
|
||||
second_leg_flights = [
|
||||
flight
|
||||
for flight in self.flights
|
||||
if flight["origin"] == transfer_city
|
||||
and flight["destination"] == destination_city
|
||||
and flight["status"] == "available"
|
||||
]
|
||||
|
||||
# 存储符合条件的中转航班
|
||||
transfer_flights = []
|
||||
|
||||
# 遍历第一段航班和第二段航班,查找符合时间条件的组合
|
||||
for first_flight in first_leg_flights:
|
||||
first_arrival = datetime.strptime(
|
||||
first_flight["arrival_time"],
|
||||
"%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
for second_flight in second_leg_flights:
|
||||
second_departure = datetime.strptime(
|
||||
str(second_flight["depart_time"]),
|
||||
"%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
# 检查第一班航班降落时间早于第二班航班起飞时间
|
||||
if first_arrival < second_departure:
|
||||
transfer_flights.append(
|
||||
{
|
||||
"first_leg": first_flight,
|
||||
"second_leg": second_flight,
|
||||
},
|
||||
)
|
||||
|
||||
# 返回符合条件的中转航班列表
|
||||
if transfer_flights:
|
||||
return transfer_flights
|
||||
else:
|
||||
return "未找到符合条件的中转航班。"
|
||||
|
||||
def calculate_baggage_fee(
    self,
    membership_level: str,
    cabin_class: str,
    baggage_count: int,
) -> float:
    """Compute the fee for checked bags beyond the free allowance.

    Args:
        membership_level (str): Membership level; must be a key of the
            allowance table ("regular", "silver" or "gold").
        cabin_class (str): Cabin class; must be "经济舱" or "商务舱".
        baggage_count (int): Number of checked bags.

    Returns:
        float: 50 per bag above the free allowance, 0 when within it.

    Raises:
        KeyError: If the membership level or the cabin class is not in
            the allowance table.
    """
    fee_per_extra_bag = 50
    allowance_table = {
        "regular": {"经济舱": 1, "商务舱": 2},
        "silver": {"经济舱": 2, "商务舱": 3},
        "gold": {"经济舱": 3, "商务舱": 3},
    }
    free_bags = allowance_table[membership_level][cabin_class]
    chargeable_bags = max(baggage_count - free_bags, 0)
    return chargeable_bags * fee_per_extra_bag
|
||||
|
||||
def update_balance(
    self,
    user: dict,
    payment_method: str,
    amount: float,
) -> bool:
    """Apply a balance change to one of the user's accounts.

    Args:
        user (dict): User record holding ``"cash_balance"`` and
            ``"bank_balance"`` entries; it is updated in place.
        payment_method (str): Which balance to change ("cash" or "bank").
        amount (float): Change to apply (positive adds, negative deducts).

    Returns:
        bool: True when the balance was updated. False when the change
        would make the balance negative or when ``payment_method`` is not
        a known method; the record is left untouched in both cases.
    """
    balance_fields = {"cash": "cash_balance", "bank": "bank_balance"}
    field = balance_fields.get(payment_method)
    if field is None:
        # Unknown method: nothing was charged or refunded, so do not
        # report success to the caller.
        return False

    if user[field] + amount < 0:
        return False  # insufficient funds

    user[field] += amount
    return True
|
||||
|
||||
def reserve_flight(
    self,
    user_id: str,
    password: str,
    flight_no: str,
    cabin: str,
    payment_method: str,
    baggage_count: int,
) -> str:
    """Book a flight for a user and charge the ticket plus baggage fee.

    Args:
        user_id (str): User ID.
        password (str): User password.
        flight_no (str): Flight number; the flight must have status
            ``"available"``.
        cabin (str): Cabin class. "经济舱" is charged the economy price,
            any other value the business price.
        payment_method (str): Payment method ("cash" or "bank").
        baggage_count (int): Number of checked bags.

    Returns:
        str: A confirmation message with the reservation ID and total
        cost, or a message describing why the booking failed.
    """
    user = self.authenticate_user(user_id, password)
    # authenticate_user signals failure with a (truthy) error dict, so a
    # plain truthiness check is not enough to detect a failed login.
    if not user or user.get("status") == "error":
        return "认证失败,请检查用户ID和密码。"

    # Locate the requested flight among the available ones.
    flight = next(
        (
            f
            for f in self.flights
            if f["flight_no"] == flight_no and f["status"] == "available"
        ),
        None,
    )
    if flight is None:
        return "航班信息未找到。"

    # Ticket price for the chosen cabin.
    price: int = (
        flight["economy_price"]
        if cabin == "经济舱"
        else flight["business_price"]
    )

    # Add the baggage fee on top of the ticket price.
    baggage_fee = self.calculate_baggage_fee(
        user["membership_level"],
        cabin,
        baggage_count,
    )
    total_cost = price + baggage_fee

    if payment_method not in ["cash", "bank"]:
        return "支付方式无效"

    # Charge the selected balance of the stored user record.
    account = self.users.get(user_id)
    balance_field = f"{payment_method}_balance"
    if total_cost > account[balance_field]:
        return f"{payment_method}余额不足,请考虑换一种支付方式"
    account[balance_field] -= total_cost

    # Take one seat and record the reservation.
    flight["seats_available"] -= 1
    reservation_id = f"res_{len(self.reservations) + 1}"
    self.reservations.append(
        {
            "reservation_id": reservation_id,
            "user_id": user_id,
            "flight_no": flight_no,
            "payment_method": payment_method,
            "cabin": cabin,
            "baggage": baggage_count,
        },
    )

    return f"预订成功,预订号:{reservation_id}," f"总费用:{total_cost}元(包含行李费用)。"
|
||||
|
||||
def modify_flight(
    self,
    user_id: str,
    reservation_id: str,
    new_flight_no: str = None,
    new_cabin: str = None,
    add_baggage: int = 0,
    new_payment_method: str = None,
) -> str:
    """Modify a reservation: change flight, cabin and/or add baggage.

    A cabin change or baggage addition that costs money is only applied
    when the charge succeeds; on insufficient balance the reservation
    keeps its previous cabin / baggage count.

    Args:
        user_id (str): User ID that must own the reservation.
        reservation_id (str): Reservation ID.
        new_flight_no (str, optional): New flight number. It must share
            origin and destination with the current flight. Defaults to
            None.
        new_cabin (str, optional): New cabin class. Defaults to None.
        add_baggage (int, optional): Number of checked bags to add.
            Defaults to 0.
        new_payment_method (str, optional): Payment method used for
            charges and refunds; falls back to the one stored on the
            reservation. Defaults to None.

    Returns:
        str: The result messages of the applied steps joined by spaces,
        or a single failure message.
    """
    # Find the reservation owned by this user.
    reservation = next(
        (
            r
            for r in self.reservations
            if r["reservation_id"] == reservation_id
            and r["user_id"] == user_id
        ),
        None,
    )
    if not reservation:
        return "预订未找到或用户ID不匹配。"

    # Flight currently attached to the reservation.
    current_flight = next(
        (
            f
            for f in self.flights
            if f["flight_no"] == reservation["flight_no"]
        ),
        None,
    )
    if not current_flight:
        return "航班信息未找到。"

    payment_method = new_payment_method or reservation["payment_method"]

    # .get() so that a missing user yields the message below instead of
    # a KeyError.
    user = self.users.get(user_id)
    if not user:
        return "用户信息未找到。"

    result_messages = []

    # Change the flight number (same origin and destination required).
    if new_flight_no and new_flight_no != reservation["flight_no"]:
        new_flight = next(
            (f for f in self.flights if f["flight_no"] == new_flight_no),
            None,
        )
        same_route = (
            new_flight
            and new_flight["origin"] == current_flight["origin"]
            and new_flight["destination"] == current_flight["destination"]
        )
        if not same_route:
            return "航班更改失败:新的航班号无效或目的地不匹配。"
        reservation["flight_no"] = new_flight_no
        result_messages.append("航班号已更改。")

    # Change the cabin and settle the price difference.
    # NOTE(review): the difference is priced on the flight the reservation
    # had on entry, even if the flight number was just changed - confirm
    # this is intended.
    if new_cabin and new_cabin != reservation.get("cabin"):
        price_difference = self.calculate_price_difference(
            current_flight,
            reservation.get("cabin"),
            new_cabin,
        )
        if price_difference > 0:
            # Upgrade: switch cabins only once the surcharge is paid.
            if self.update_balance(
                user,
                payment_method,
                -price_difference,
            ):
                reservation["cabin"] = new_cabin
                result_messages.append(
                    f"舱位更改成功。已支付差价: {price_difference}。",
                )
            else:
                result_messages.append("余额不足,无法支付舱位差价。")
        else:
            reservation["cabin"] = new_cabin
            if price_difference < 0:
                # Downgrade: negating the negative difference credits it.
                self.update_balance(user, payment_method, -price_difference)
                result_messages.append(
                    f"舱位更改成功。已退款差价: {-price_difference}。",
                )

    # Add checked bags, charging for those beyond the free allowance.
    if add_baggage > 0:
        max_free_baggage = self.get_baggage_allowance(
            user["membership_level"],
            reservation["cabin"],
        )
        total_baggage = reservation.get("baggage", 0) + add_baggage
        extra_baggage = max(0, total_baggage - max_free_baggage)
        baggage_cost = extra_baggage * 50
        if baggage_cost > 0:
            # Record the new bag count only once the fee is paid.
            if self.update_balance(user, payment_method, -baggage_cost):
                reservation["baggage"] = total_baggage
                result_messages.append(
                    f"行李已增加。需支付额外费用: {baggage_cost}。",
                )
            else:
                result_messages.append("余额不足,无法支付额外行李费用。")
        else:
            reservation["baggage"] = total_baggage

    if not result_messages:
        result_messages.append("修改完成,无需额外费用。")
    return " ".join(result_messages)
|
||||
|
||||
def cancel_reservation(
    self,
    user_id: str,
    reservation_id: str,
    reason: str,
) -> str:
    """Cancel a reservation and refund the ticket price.

    The "current time" is fixed at 2024-07-14 06:00:00. The refund is
    based on the ticket price of the reserved cabin and is paid out
    through ``self.process_refund``.

    Args:
        user_id (str): User ID that must own the reservation.
        reservation_id (str): Reservation ID.
        reason (str): Cancellation reason; "航空公司取消航班" yields a
            full refund regardless of the time left before departure.

    Returns:
        str: A message describing the refund, or why the cancellation
        was rejected.
    """
    # Simulated "now": 06:00 on 14 July 2024.
    now = datetime(2024, 7, 14, 6, 0, 0)

    user = self.users.get(user_id, None)
    if not user:
        return "用户ID无效。"

    # The reservation must exist and belong to this user.
    reservation = None
    for booking in self.reservations:
        if (
            booking["reservation_id"] == reservation_id
            and booking["user_id"] == user_id
        ):
            reservation = booking
            break
    if not reservation:
        return "预订ID无效或与该用户无关。"

    # The reserved flight must be known.
    flight = None
    for candidate in self.flights:
        if candidate["flight_no"] == reservation["flight_no"]:
            flight = candidate
            break
    if not flight:
        return "航班信息无效。"

    # Flights that have already departed cannot be cancelled.
    depart_time = datetime.strptime(
        flight["depart_time"],
        "%Y-%m-%d %H:%M:%S",
    )
    if now > depart_time:
        return "航段已使用,无法取消。"
    time_left = depart_time - now

    if reservation["cabin"] == "经济舱":
        flight_price = flight["economy_price"]
    else:
        flight_price = flight["business_price"]

    if reason == "航空公司取消航班":
        # Cancelled by the airline: full refund.
        refund_amount = flight_price
        self.process_refund(user, refund_amount)
        return f"航班已取消,您的预订将被免费取消,已退款{refund_amount}元。"

    if time_left > timedelta(days=1):
        # More than 24 hours before departure: free cancellation.
        refund_amount = flight_price
        self.process_refund(user, refund_amount)
        return f"距离出发时间超过24小时,免费取消成功,已退款{refund_amount}元。"

    # Otherwise keep a cancellation fee of 10% of the ticket price.
    cancel_fee = flight_price * 0.1
    refund_amount = flight_price - cancel_fee
    self.process_refund(user, refund_amount)
    return f"距离出发时间不足24小时,已扣除取消费{cancel_fee}元,退款{refund_amount}元。"
|
||||
|
||||
def process_refund(self, user: dict, amount: float) -> str:
    """Credit a refund to the user's cash balance.

    Args:
        user (dict): User record; its ``"cash_balance"`` entry is
            increased in place and ``"user_name"`` is used in the message.
        amount (float): Refund amount.

    Returns:
        str: A confirmation message naming the user and the amount.
    """
    user["cash_balance"] = user["cash_balance"] + amount
    holder = user["user_name"]
    return f"已成功处理退款,{holder}的现金余额增加了{amount}元。"
|
||||
|
||||
def calculate_price_difference(
    self,
    flight: dict,
    old_cabin: str,
    new_cabin: str,
) -> float:
    """Compute the price difference between two cabins of one flight.

    Args:
        flight (dict): Flight record with ``"economy_price"`` and
            ``"business_price"`` entries.
        old_cabin (str): Current cabin class.
        new_cabin (str): Requested cabin class.

    Returns:
        float: New price minus old price (positive means a surcharge,
        negative a refund). A cabin other than "经济舱" / "商务舱" is
        priced at 0.
    """
    fares = {
        "经济舱": flight["economy_price"],
        "商务舱": flight["business_price"],
    }
    return fares.get(new_cabin, 0) - fares.get(old_cabin, 0)
|
||||
122
src/agentscope/evaluate/_ace_benchmark/_ace_tools_zh.py
Normal file
122
src/agentscope/evaluate/_ace_benchmark/_ace_tools_zh.py
Normal file
@@ -0,0 +1,122 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The Chinese tools for ACEBench evaluation."""
|
||||
from functools import wraps
|
||||
from typing import Callable, Any
|
||||
|
||||
from ._ace_tools_api import (
|
||||
ReminderApi,
|
||||
FoodPlatformApi,
|
||||
TravelApi,
|
||||
MessageApi,
|
||||
)
|
||||
from ...message import TextBlock
|
||||
from ...tool import ToolResponse
|
||||
|
||||
|
||||
def _tool_function_wrapper(get_tool_function: Callable) -> Callable:
    """Decorate a tool getter so the tools it returns yield ToolResponse.

    Args:
        get_tool_function (`Callable`):
            A method taking ``(self, name)`` and returning a tool function.

    Returns:
        `Callable`:
            A method with the same signature whose returned tool function
            wraps the ``str()`` of the original result in a single text
            block of a ``ToolResponse``.
    """

    @wraps(get_tool_function)
    def wrapper(self: "ACEPhone", name: str) -> Callable:
        """Look up the tool and wrap it to return ToolResponse."""
        raw_tool = get_tool_function(self, name)

        @wraps(raw_tool)
        def wrapper_tool_function(*args: Any, **kwargs: Any) -> ToolResponse:
            """The wrapped tool function"""
            outcome = raw_tool(*args, **kwargs)
            text_block = TextBlock(type="text", text=str(outcome))
            return ToolResponse(content=[text_block])

        return wrapper_tool_function

    return wrapper
|
||||
|
||||
|
||||
class ACEPhone:
    """Simulate a user phone with various apps and functionalities in
    ACEBench. The code is implemented with reference to the
    `ACEBench <https://github.com/ACEBench/ACEBench>`_.
    """

    def __init__(self) -> None:
        """Create the shared device state and the app instances."""
        shared_state = {
            "wifi": False,
            "logged_in": False,
        }
        self._state = shared_state
        # The message, reminder and food apps receive the same state dict.
        self._message_app = MessageApi(shared_state)
        self._reminder_app = ReminderApi(shared_state)
        self._food_platform_app = FoodPlatformApi(shared_state)
        self._travel = TravelApi()

    def turn_on_wifi(self) -> dict[str, bool | str]:
        """Turn on the WiFi connection."""
        self._state["wifi"] = True
        return {"status": True, "message": "wifi已经打开"}

    def login_device(self) -> dict[str, bool | str]:
        """Log in to the device."""
        self._state["logged_in"] = True
        return {"status": True, "message": "设备已经登录"}

    def load_initial_config(self, initial_config: dict) -> None:
        """Load the device state from the application configuration.

        Args:
            initial_config (`dict`):
                The initial configuration. An empty dict is ignored. A
                ``"Baspi"`` key is renamed to ``"BaseApi"`` in place.
        """
        # Nothing to load from an empty config.
        if len(initial_config) == 0:
            return

        # Fix the typo in ACEBench by renaming "Baspi" to "BaseApi".
        if "Baspi" in initial_config:
            initial_config["BaseApi"] = initial_config.pop("Baspi")

        # The base section must provide both state flags.
        assert (
            "BaseApi" in initial_config
            and "wifi" in initial_config["BaseApi"]
            and "logged_in" in initial_config["BaseApi"]
        ), f"Invalid initial config: {initial_config}"

        base_section = initial_config["BaseApi"]
        self._state["wifi"] = base_section["wifi"]
        self._state["logged_in"] = base_section["logged_in"]

    def get_current_state(self) -> list[dict]:
        """Follow ACEBench to get the current state of the ACEPhone."""
        app_states = [
            self._message_app.get_state_dict(),
            self._reminder_app.get_state_dict(),
            self._food_platform_app.get_state_dict(),
            self._travel.get_state_dict(),
        ]
        return [{"BaseApi": self._state}, *app_states]

    @_tool_function_wrapper
    def get_tool_function(self, name: str) -> Callable:
        """Get a tool function by name.

        The phone's own tools are checked first, then the message, food
        platform, reminder and travel apps, in that order.

        Raises:
            ValueError: If no tool function with that name is found.
        """
        if name in ("turn_on_wifi", "login_device"):
            return getattr(self, name)

        apps = (
            self._message_app,
            self._food_platform_app,
            self._reminder_app,
            self._travel,
        )
        for app in apps:
            if name in app.tool_functions:
                return getattr(app, name)

        raise ValueError(
            f"Tool function '{name}' not found in ACEPhone.",
        )
|
||||
43
src/agentscope/evaluate/_benchmark_base.py
Normal file
43
src/agentscope/evaluate/_benchmark_base.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The base class for benchmark evaluation."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Generator
|
||||
|
||||
from ._task import Task
|
||||
|
||||
|
||||
class BenchmarkBase(ABC):
    """Abstract base class that every evaluation benchmark derives from.

    Subclasses provide iteration, length and index access over their
    tasks.
    """

    name: str
    """The name of the benchmark."""

    description: str
    """The description of the benchmark."""

    def __init__(self, name: str, description: str) -> None:
        """Store the benchmark name and description.

        Args:
            name (`str`):
                The name of the benchmark.
            description (`str`):
                A brief description of the benchmark.
        """
        self.name, self.description = name, description

    @abstractmethod
    def __iter__(self) -> Generator[Task, None, None]:
        """Yield the tasks of the benchmark one by one."""
        raise NotImplementedError("Subclasses must implement this method.")

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of tasks in the benchmark."""
        raise NotImplementedError("Subclasses must implement this method.")

    @abstractmethod
    def __getitem__(self, index: int) -> Task:
        """Return the task stored at the given index."""
        raise NotImplementedError("Subclasses must implement this method.")
|
||||
12
src/agentscope/evaluate/_evaluator/__init__.py
Normal file
12
src/agentscope/evaluate/_evaluator/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The evaluator module in AgentScope."""
|
||||
|
||||
from ._evaluator_base import EvaluatorBase
|
||||
from ._ray_evaluator import RayEvaluator
|
||||
from ._general_evaluator import GeneralEvaluator
|
||||
|
||||
__all__ = [
|
||||
"EvaluatorBase",
|
||||
"RayEvaluator",
|
||||
"GeneralEvaluator",
|
||||
]
|
||||
304
src/agentscope/evaluate/_evaluator/_evaluator_base.py
Normal file
304
src/agentscope/evaluate/_evaluator/_evaluator_base.py
Normal file
@@ -0,0 +1,304 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The base class for evaluator in evaluation."""
|
||||
import collections
|
||||
import json
|
||||
from abc import abstractmethod
|
||||
from dataclasses import asdict
|
||||
from typing import Callable, Coroutine, Any
|
||||
from collections import defaultdict
|
||||
|
||||
from .._solution import SolutionOutput
|
||||
from .._task import Task
|
||||
from .._benchmark_base import BenchmarkBase
|
||||
from .._evaluator_storage import EvaluatorStorageBase
|
||||
from .._metric_base import MetricType
|
||||
from ..._utils._common import _get_timestamp
|
||||
|
||||
|
||||
class EvaluatorBase:
    """The class that runs the evaluation process."""

    def __init__(
        self,
        name: str,
        benchmark: BenchmarkBase,
        n_repeat: int,
        storage: EvaluatorStorageBase,
    ) -> None:
        """Initialize the evaluator.

        Args:
            name (`str`):
                The name of this evaluator.
            benchmark: (`BenchmarkBase`):
                A benchmark instance inheriting from `BenchmarkBase` that
                defines the evaluation dataset.
            n_repeat (`int`):
                How many times to repeat the evaluation for each task.
            storage (`EvaluatorStorageBase`):
                A instance inheriting from the child class of
                `EvaluatorStorageBase` that supports storing and loading
                solution output and evaluation results.
        """
        self.name = name
        self.benchmark = benchmark
        self.n_repeat = n_repeat
        self.storage = storage

    @abstractmethod
    async def run(
        self,
        solution: Callable[
            [Task, Callable],
            Coroutine[Any, Any, SolutionOutput],
        ],
    ) -> None:
        """Run the evaluation and return the results.

        Args:
            solution (`Callable[[Task, Callable], Coroutine[Any, Any, \
            SolutionOutput]]`):
                A async function that takes a `Task` instance and a pre-hook
                as input and returns a `SolutionOutput` instance.
        """

    async def _save_evaluation_meta(self) -> None:
        """Save the evaluation meta information."""
        self.storage.save_evaluation_meta(
            {
                "evaluation_name": self.name,
                "created_at": _get_timestamp(),
                "total_repeats": self.n_repeat,
                "benchmark": {
                    "name": self.benchmark.name,
                    "description": self.benchmark.description,
                    "total_tasks": len(self.benchmark),
                },
                "schema_version": 1,
            },
        )

    async def _save_task_meta(self, task: Task) -> None:
        """Save the task meta information.

        Args:
            task (`Task`):
                The task instance.
        """
        meta_info = asdict(task)
        # The "metadata" field is not part of the stored meta information.
        meta_info.pop("metadata")
        self.storage.save_task_meta(
            task.id,
            meta_info,
        )

    @staticmethod
    def _empty_stats() -> dict:
        """Create an empty usage-statistics accumulator."""
        return {
            "llm": defaultdict(int),
            "agent": 0,
            "tool": defaultdict(int),
            "embedding": defaultdict(int),
            "chat_usage": {},
        }

    @staticmethod
    def _accumulate_stats(target: dict, source: dict) -> None:
        """Add the usage statistics in ``source`` onto ``target`` in place.

        Args:
            target (`dict`):
                An accumulator created by `_empty_stats`.
            source (`dict`):
                Statistics to add; missing sections count as empty.
        """
        # Per-name call counters.
        for section in ("llm", "tool", "embedding"):
            for item_name, cnt in source.get(section, {}).items():
                target[section][item_name] += cnt

        target["agent"] += source.get("agent", 0)

        # Token usage per chat model.
        for model_name, usage in source.get("chat_usage", {}).items():
            if model_name not in target["chat_usage"]:
                target["chat_usage"][model_name] = defaultdict(int)
            totals = target["chat_usage"][model_name]
            totals["input_tokens"] += usage.get("input_tokens", 0)
            totals["output_tokens"] += usage.get("output_tokens", 0)

    def _record_metric(
        self,
        current_repeat: dict,
        task: Task,
        repeat_id: str,
        metric: Any,
    ) -> None:
        """Record one metric of one task into the per-repeat summary.

        Args:
            current_repeat (`dict`):
                The summary of the current repeat, updated in place.
            task (`Task`):
                The task the metric belongs to.
            repeat_id (`str`):
                The identifier of the current repeat.
            metric (`Any`):
                The metric object; its ``name`` and ``metric_type``
                attributes are read.
        """
        metrics = current_repeat["metrics"]
        if metric.name not in metrics:
            metrics[metric.name] = {
                "type": metric.metric_type,
                "involved_tasks": 0,
                "completed_tasks": 0,
                "incomplete_tasks": 0,
                "aggregation": {},
                "distribution": defaultdict(list),
            }
        entry = metrics[metric.name]

        # Record the submitted task
        entry["involved_tasks"] += 1

        # Not finished
        if not self.storage.evaluation_result_exists(
            task.id,
            repeat_id,
            metric.name,
        ):
            if task.id not in current_repeat["incomplete_ids"]:
                current_repeat["incomplete_tasks"] += 1
                current_repeat["incomplete_ids"].append(task.id)
            entry["incomplete_tasks"] += 1
            return

        if task.id not in current_repeat["completed_ids"]:
            current_repeat["completed_tasks"] += 1
            current_repeat["completed_ids"].append(task.id)
        entry["completed_tasks"] += 1

        # Get the evaluation result
        eval_result = self.storage.get_evaluation_result(
            task.id,
            repeat_id,
            metric.name,
        )

        # Category metrics map category -> task ids; numerical metrics map
        # task id -> score.
        if metric.metric_type == MetricType.CATEGORY:
            entry["distribution"][eval_result.result].append(task.id)
        elif metric.metric_type == MetricType.NUMERICAL:
            entry["distribution"][task.id] = eval_result.result

    async def aggregate(self) -> None:
        """Aggregate the evaluation results and save an overall result.

        For every repeat, the solution statistics and metric results of
        all tasks are collected, a summary is printed, and the per-repeat
        statistics are added to the overall totals. The final result is
        passed to ``self.storage.save_aggregation_result``.
        """
        meta_info: dict = {
            "total_tasks": len(self.benchmark),
            "total_repeats": self.n_repeat,
            "total_stats": self._empty_stats(),
            "repeats": {},
            "schema_version": 1,
        }

        for repeat_index in range(self.n_repeat):
            repeat_id = str(repeat_index)
            current_repeat: dict = {
                "completed_tasks": 0,
                "incomplete_tasks": 0,
                "metrics": {},
                "completed_ids": [],
                "incomplete_ids": [],
                "stats": self._empty_stats(),
            }

            for task in self.benchmark:
                self._accumulate_stats(
                    current_repeat["stats"],
                    self.storage.get_solution_stats(task.id, repeat_id),
                )
                for metric in task.metrics:
                    self._record_metric(
                        current_repeat,
                        task,
                        repeat_id,
                        metric,
                    )

            print("Repeat ID:", repeat_id)
            for metric_name, value in current_repeat["metrics"].items():
                print("\tMetric:", metric_name)
                print("\t\tType:", value["type"])
                print("\t\tInvolved tasks:", value["involved_tasks"])
                print("\t\tCompleted tasks:", value["completed_tasks"])
                print("\t\tIncomplete tasks:", value["incomplete_tasks"])

                if value["type"] == MetricType.CATEGORY:
                    # Share of involved tasks that fell into each category
                    for category, task_ids in value["distribution"].items():
                        value["aggregation"][category] = (
                            len(task_ids) * 1.0 / value["involved_tasks"]
                        )

                elif value["type"] == MetricType.NUMERICAL:
                    scores = list(value["distribution"].values())
                    # The mean is taken over all involved tasks. max/min
                    # fall back to None when no task has a score yet
                    # instead of raising on an empty sequence.
                    value["aggregation"] = {
                        "mean": sum(scores) / value["involved_tasks"],
                        "max": max(scores, default=None),
                        "min": min(scores, default=None),
                    }

                print(
                    "\t\tAggregation:",
                    json.dumps(
                        value["aggregation"],
                        indent=4,
                        ensure_ascii=False,
                    ).replace("\n", "\n\t\t"),
                )

            meta_info["repeats"][repeat_id] = current_repeat

            # Aggregate total stats
            self._accumulate_stats(
                meta_info["total_stats"],
                current_repeat["stats"],
            )

        # save
        self.storage.save_aggregation_result(meta_info)
|
||||
178
src/agentscope/evaluate/_evaluator/_general_evaluator.py
Normal file
178
src/agentscope/evaluate/_evaluator/_general_evaluator.py
Normal file
@@ -0,0 +1,178 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""General evaluator implementation in AgentScope, which is easy to debug
|
||||
compared to the RayEvaluator."""
|
||||
from typing import Callable, Awaitable, Coroutine, Any
|
||||
|
||||
from ._evaluator_base import EvaluatorBase
|
||||
from ._in_memory_exporter import _InMemoryExporter
|
||||
from .._evaluator_storage import EvaluatorStorageBase
|
||||
from .._task import Task
|
||||
from .._solution import SolutionOutput
|
||||
from .._benchmark_base import BenchmarkBase
|
||||
|
||||
|
||||
class GeneralEvaluator(EvaluatorBase):
|
||||
"""The general evaluator that support users to debug their evaluation"""
|
||||
|
||||
def __init__(
    self,
    name: str,
    benchmark: BenchmarkBase,
    n_repeat: int,
    storage: EvaluatorStorageBase,
    n_workers: int,
) -> None:
    """Initialize the general evaluator.

    Args:
        name (`str`):
            The name of this evaluator.
        benchmark (`BenchmarkBase`):
            The benchmark to evaluate; must be a `BenchmarkBase` instance.
        n_repeat (`int`):
            How many times to repeat the evaluation for each task; must be
            at least 1.
        storage (`EvaluatorStorageBase`):
            The storage for solution outputs and evaluation results.
        n_workers (`int`):
            The number of workers; must be at least 1.
    """
    super().__init__(
        name=name,
        benchmark=benchmark,
        n_repeat=n_repeat,
        storage=storage,
    )

    # Validate the arguments once the base attributes are in place.
    assert isinstance(benchmark, BenchmarkBase)
    assert n_repeat >= 1, "n_repeat must be at least 1"
    assert n_workers >= 1, "n_workers must be at least 1"

    self.benchmark, self.n_repeat = benchmark, n_repeat
    self.n_workers = n_workers
|
||||
|
||||
async def run_evaluation(
    self,
    task: Task,
    repeat_id: str,
    solution_output: SolutionOutput,
) -> None:
    """Evaluate a solution output for a task and store every result.

    Args:
        task (`Task`):
            The task whose ``evaluate`` coroutine is awaited.
        repeat_id (`str`):
            The identifier of the current repeat.
        solution_output (`SolutionOutput`):
            The solution output to evaluate.
    """
    for evaluation in await task.evaluate(solution_output):
        # Persist each evaluation result under this task and repeat.
        self.storage.save_evaluation_result(
            task_id=task.id,
            repeat_id=repeat_id,
            evaluation=evaluation,
        )
|
||||
|
||||
async def run_solution(
|
||||
self,
|
||||
repeat_id: str,
|
||||
task: Task,
|
||||
solution: Callable[[Task, Callable], Awaitable[SolutionOutput]],
|
||||
) -> None:
|
||||
"""Generate a solution to a task and evaluate."""
|
||||
if self.storage.solution_result_exists(task.id, repeat_id):
|
||||
# Obtain from storage
|
||||
solution_result = self.storage.get_solution_result(
|
||||
task.id,
|
||||
repeat_id,
|
||||
)
|
||||
|
||||
else:
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.context import attach, detach
|
||||
from opentelemetry import baggage
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
# Set baggage
|
||||
ctx = baggage.set_baggage("task_id", task.id)
|
||||
ctx = baggage.set_baggage("repeat_id", repeat_id, context=ctx)
|
||||
|
||||
# Activate the context
|
||||
token = attach(ctx)
|
||||
|
||||
try:
|
||||
with tracer.start_as_current_span(
|
||||
name=f"Solution_{task.id}_{repeat_id}",
|
||||
):
|
||||
from ... import _config
|
||||
|
||||
_config.trace_enabled = True
|
||||
|
||||
# Run the solution
|
||||
solution_result = await solution(
|
||||
task,
|
||||
self.storage.get_agent_pre_print_hook(
|
||||
task.id,
|
||||
repeat_id,
|
||||
),
|
||||
)
|
||||
self.storage.save_solution_result(
|
||||
task.id,
|
||||
repeat_id,
|
||||
solution_result,
|
||||
)
|
||||
finally:
|
||||
detach(token)
|
||||
|
||||
# Evaluate the solution with the
|
||||
for metric in task.metrics:
|
||||
if not self.storage.evaluation_result_exists(
|
||||
task.id,
|
||||
repeat_id,
|
||||
metric.name,
|
||||
):
|
||||
await self.run_evaluation(
|
||||
task,
|
||||
repeat_id,
|
||||
solution_result,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
solution: Callable[
|
||||
[Task, Callable],
|
||||
Coroutine[Any, Any, SolutionOutput],
|
||||
],
|
||||
) -> None:
|
||||
"""Run the ray-based distributed and parallel evaluation, and get the
|
||||
results.
|
||||
|
||||
Args:
|
||||
solution (`Callable[[Task, Callable], Coroutine[Any, Any, \
|
||||
SolutionOutput]]`):
|
||||
A async function that takes a `Task` instance and a pre-print
|
||||
hook function as input, returns a `SolutionOutput` instance.
|
||||
"""
|
||||
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
|
||||
|
||||
exporter = _InMemoryExporter()
|
||||
span_processor = SimpleSpanProcessor(exporter)
|
||||
|
||||
tracer_provider: TracerProvider = trace.get_tracer_provider()
|
||||
if not isinstance(tracer_provider, TracerProvider):
|
||||
# Create a new tracer provider if not exists
|
||||
tracer_provider = TracerProvider()
|
||||
tracer_provider.add_span_processor(span_processor)
|
||||
trace.set_tracer_provider(tracer_provider)
|
||||
|
||||
await self._save_evaluation_meta()
|
||||
|
||||
for task in self.benchmark:
|
||||
await self._save_task_meta(task)
|
||||
|
||||
for repeat_id in range(self.n_repeat):
|
||||
await self.run_solution(
|
||||
str(repeat_id),
|
||||
task,
|
||||
solution,
|
||||
)
|
||||
|
||||
# Save the exporter data
|
||||
if (
|
||||
task.id in exporter.cnt
|
||||
and str(repeat_id) in exporter.cnt[task.id]
|
||||
):
|
||||
self.storage.save_solution_stats(
|
||||
task.id,
|
||||
str(repeat_id),
|
||||
exporter.cnt[task.id][str(repeat_id)],
|
||||
)
|
||||
|
||||
await self.aggregate()
|
||||
106
src/agentscope/evaluate/_evaluator/_in_memory_exporter.py
Normal file
106
src/agentscope/evaluate/_evaluator/_in_memory_exporter.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""An in memory exporter of OpenTelemetry traces for AgentScope evaluator, used
|
||||
to record the token usage during evaluation."""
|
||||
from collections import defaultdict
|
||||
from typing import Sequence
|
||||
|
||||
from opentelemetry import baggage
|
||||
from opentelemetry.sdk.trace import ReadableSpan
|
||||
from opentelemetry.sdk.trace.export import SpanExporter, SpanExportResult
|
||||
|
||||
from ...tracing._attributes import SpanAttributes, OperationNameValues
|
||||
|
||||
|
||||
class _InMemoryExporter(SpanExporter):
    """An in memory exporter that tallies usage counters from the spans it
    receives.

    The counters are kept in ``self.cnt``, a nested dictionary indexed first
    by the ``task_id`` and then by the ``repeat_id`` found in the
    OpenTelemetry baggage of the context in which ``export`` is called.
    """

    def __init__(self) -> None:
        """Initialize the in memory exporter."""
        # task_id -> repeat_id -> counters
        self.cnt: dict = {}
        self._stopped = False

    def _counters_for(self, task_id: str, repeat_id: str) -> dict:
        """Return the counters of a task/repeat pair, creating them on
        first use."""
        per_task = self.cnt.setdefault(task_id, {})
        if repeat_id not in per_task:
            per_task[repeat_id] = {
                "llm": defaultdict(int),
                "agent": 0,
                "tool": defaultdict(int),
                "embedding": defaultdict(int),
                "chat_usage": {},
            }
        return per_task[repeat_id]

    def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult:
        """Exports a batch of telemetry data.

        Args:
            spans (`Sequence[ReadableSpan]`):
                The list of `opentelemetry.trace.Span` objects to be exported

        Returns:
            `SpanExportResult`:
                Always `SpanExportResult.SUCCESS`; spans exported outside of
                a tagged context are ignored.
        """
        # The baggage is read from the current context, not from the spans,
        # so the lookup is the same for every span of the batch.
        task_id = baggage.get_baggage("task_id")
        repeat_id = baggage.get_baggage("repeat_id")
        if task_id is None or repeat_id is None:
            return SpanExportResult.SUCCESS

        for span in spans:
            counters = self._counters_for(task_id, repeat_id)

            # `ReadableSpan.attributes` is optional; treat a missing mapping
            # as "no attributes" instead of failing on `None.get`.
            attributes = span.attributes or {}

            span_kind = attributes.get(
                SpanAttributes.GEN_AI_OPERATION_NAME,
            )
            if span_kind == OperationNameValues.CHAT:
                model_name = attributes.get(
                    SpanAttributes.GEN_AI_REQUEST_MODEL,
                    "unknown",
                )
                counters["llm"][model_name] += 1

                chat_usage = counters["chat_usage"]
                if model_name not in chat_usage:
                    chat_usage[model_name] = defaultdict(int)

                chat_usage[model_name]["input_tokens"] += attributes.get(
                    SpanAttributes.GEN_AI_USAGE_INPUT_TOKENS,
                    0,
                )
                chat_usage[model_name]["output_tokens"] += attributes.get(
                    SpanAttributes.GEN_AI_USAGE_OUTPUT_TOKENS,
                    0,
                )

            elif span_kind == OperationNameValues.INVOKE_AGENT:
                counters["agent"] += 1

            elif span_kind == OperationNameValues.EXECUTE_TOOL:
                tool_name = attributes.get(
                    SpanAttributes.GEN_AI_TOOL_NAME,
                    "unknown",
                )
                counters["tool"][tool_name] += 1

            elif span_kind == OperationNameValues.EMBEDDINGS:
                embedding_model = attributes.get(
                    SpanAttributes.GEN_AI_REQUEST_MODEL,
                    "unknown",
                )
                counters["embedding"][embedding_model] += 1

        return SpanExportResult.SUCCESS

    def shutdown(self) -> None:
        """Shuts down the exporter."""
        self._stopped = True
|
||||
267
src/agentscope/evaluate/_evaluator/_ray_evaluator.py
Normal file
267
src/agentscope/evaluate/_evaluator/_ray_evaluator.py
Normal file
@@ -0,0 +1,267 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The ray-based evaluator in agentscope."""
|
||||
import asyncio
|
||||
from typing import Callable, Awaitable, Coroutine, Any
|
||||
|
||||
from ._in_memory_exporter import _InMemoryExporter
|
||||
from .._benchmark_base import BenchmarkBase
|
||||
from .._evaluator._evaluator_base import EvaluatorBase
|
||||
from .._solution import SolutionOutput
|
||||
from .._task import Task
|
||||
from .._evaluator_storage import EvaluatorStorageBase
|
||||
|
||||
|
||||
def _check_ray_available() -> None:
    """Make sure the `ray` package can be imported.

    Raises:
        `ImportError`:
            If importing `ray` fails; the original error is chained as the
            cause.
    """
    try:
        __import__("ray")
    except ImportError as exc:
        message = (
            "Ray is not installed. Please install it with `pip install ray` "
            "to use the RayEvaluator."
        )
        raise ImportError(message) from exc
|
||||
|
||||
|
||||
# Create a conditional decorator for ray.remote
|
||||
def _ray_remote_decorator(cls: Any) -> Any:
    """Wrap `cls` with `ray.remote` when ray can be imported.

    Args:
        cls (`Any`):
            The class to decorate.

    Returns:
        `Any`:
            `ray.remote(cls)` if ray is importable, otherwise `cls` itself.
    """
    decorated = cls
    try:
        import ray

        decorated = ray.remote(cls)
    except ImportError:
        # Deliberate fallback: without ray the class is left untouched so
        # that the module can still be imported.
        pass
    return decorated
|
||||
|
||||
|
||||
@_ray_remote_decorator
class RayEvaluationActor:
    """Actor that evaluates a solution output and stores the results."""

    @staticmethod
    async def run(
        storage: EvaluatorStorageBase,
        task: Task,
        repeat_id: str,
        solution_output: SolutionOutput,
    ) -> None:
        """Evaluate a solution output and store every returned result.

        Args:
            storage (`EvaluatorStorageBase`):
                The storage the evaluation results are saved into.
            task (`Task`):
                The task whose `evaluate` method is awaited.
            repeat_id (`str`):
                The repeat ID the results are stored under.
            solution_output (`SolutionOutput`):
                The solution output handed to `task.evaluate`.
        """
        for metric_result in await task.evaluate(solution_output):
            storage.save_evaluation_result(
                task_id=task.id,
                repeat_id=repeat_id,
                evaluation=metric_result,
            )
|
||||
|
||||
|
||||
@_ray_remote_decorator
class RaySolutionActor:
    """Actor that runs agent solutions and triggers their evaluation."""

    def __init__(self, n_workers: int = 1) -> None:
        """Create the evaluation actor and hook up the span exporter.

        Args:
            n_workers (`int`, defaults to `1`):
                The `max_concurrency` given to the evaluation actor.
        """
        self.eval_actor = RayEvaluationActor.options(
            max_concurrency=n_workers,
        ).remote()

        # Exporter shared by all the runs handled by this actor
        self.exporter = _InMemoryExporter()

        from opentelemetry import trace
        from opentelemetry.sdk.trace import TracerProvider
        from opentelemetry.sdk.trace.export import SimpleSpanProcessor

        tracer_provider = trace.get_tracer_provider()
        if not isinstance(tracer_provider, TracerProvider):
            # No SDK provider installed yet: create and register one. An
            # existing SDK provider is reused without registering it again,
            # because a second registration only logs an override warning.
            tracer_provider = TracerProvider()
            trace.set_tracer_provider(tracer_provider)
        tracer_provider.add_span_processor(
            SimpleSpanProcessor(self.exporter),
        )

    async def run(
        self,
        storage: EvaluatorStorageBase,
        repeat_id: str,
        task: Task,
        solution: Callable[
            [Task, Callable],
            Coroutine[Any, Any, SolutionOutput],
        ],
    ) -> None:
        """Generate (or reload) the solution of a task and evaluate it.

        Args:
            storage (`EvaluatorStorageBase`):
                The storage used to persist and look up the results.
            repeat_id (`str`):
                The repeat ID of this run.
            task (`Task`):
                The task to solve.
            solution (`Callable[[Task, Callable], Coroutine[Any, Any, \
            SolutionOutput]]`):
                An async function called with the task and the pre-print
                hook obtained from the storage.
        """
        if storage.solution_result_exists(task.id, repeat_id):
            # Resume: reuse the solution that is already in the storage
            solution_result = storage.get_solution_result(
                task.id,
                repeat_id,
            )

        else:
            from opentelemetry import trace, baggage
            from opentelemetry.context import attach, detach

            tracer = trace.get_tracer(__name__)

            # Tag the context with the task/repeat IDs; the in-memory
            # exporter reads them back from the baggage.
            ctx = baggage.set_baggage("task_id", task.id)
            ctx = baggage.set_baggage("repeat_id", repeat_id, context=ctx)

            # Attach the context with baggage
            token = attach(ctx)

            try:
                with tracer.start_as_current_span(
                    name=f"Solution_{task.id}_{repeat_id}",
                ):
                    from ... import _config

                    _config.trace_enabled = True

                    # Run the solution
                    solution_result = await solution(
                        task,
                        storage.get_agent_pre_print_hook(
                            task.id,
                            repeat_id,
                        ),
                    )
            finally:
                detach(token)
                # Ensure all spans are flushed
                trace.get_tracer_provider().force_flush()

            # NOTE(review): the nesting of the two saves below was
            # reconstructed from whitespace-stripped text; they are placed
            # after the try/finally so they only run when the solution
            # succeeded -- confirm against the repository.
            storage.save_solution_stats(
                task.id,
                repeat_id,
                self.exporter.cnt.get(task.id, {}).get(repeat_id, {}),
            )

            storage.save_solution_result(
                task.id,
                repeat_id,
                solution_result,
            )

        # The evaluation actor is not metric specific: one call stores every
        # result returned by `task.evaluate`. Submit a single evaluation as
        # soon as at least one metric result is missing, instead of one
        # identical evaluation per missing metric.
        evaluation_missing = any(
            not storage.evaluation_result_exists(
                task.id,
                repeat_id,
                metric.name,
            )
            for metric in task.metrics
        )
        if evaluation_missing:
            await asyncio.gather(
                self.eval_actor.run.remote(
                    storage,
                    task,
                    repeat_id,
                    solution_result,
                ),
            )
|
||||
|
||||
|
||||
class RayEvaluator(EvaluatorBase):
    """Evaluator that dispatches every (task, repeat) pair to a ray solution
    actor."""

    def __init__(
        self,
        name: str,
        benchmark: BenchmarkBase,
        n_repeat: int,
        storage: EvaluatorStorageBase,
        n_workers: int,
    ) -> None:
        """Initialize the evaluator.

        Args:
            name (`str`):
                The name forwarded to the base evaluator.
            benchmark (`BenchmarkBase`):
                The benchmark whose tasks are iterated over.
            n_repeat (`int`):
                How many times each task is solved, at least 1.
            storage (`EvaluatorStorageBase`):
                The storage used to persist and look up the results.
            n_workers (`int`):
                The `max_concurrency` of the solution actor, at least 1.

        Raises:
            `ImportError`:
                If ray is not installed.
        """
        super().__init__(
            name=name,
            benchmark=benchmark,
            n_repeat=n_repeat,
            storage=storage,
        )

        # Fail early when ray is missing
        _check_ray_available()

        assert isinstance(benchmark, BenchmarkBase)

        assert n_repeat >= 1, "n_repeat must be at least 1"

        assert n_workers >= 1, "n_workers must be at least 1"

        self.n_workers = n_workers
        self.n_repeat = n_repeat
        self.benchmark = benchmark

    async def run(
        self,
        solution: Callable[
            [Task, Callable],
            Awaitable[SolutionOutput] | SolutionOutput,
        ],
    ) -> None:
        """Submit every (task, repeat) pair to the solution actor, wait for
        all of them, then aggregate the results.

        Args:
            solution (`Callable[[Task, Callable], \
            Awaitable[SolutionOutput] | SolutionOutput]`):
                The function forwarded unchanged to the solution actor,
                which awaits it with the task and a pre-print hook.
        """

        await self._save_evaluation_meta()

        # One solution actor serves all the submitted runs
        pending = []
        actor_handle = RaySolutionActor.options(
            max_concurrency=self.n_workers,
        ).remote(n_workers=self.n_workers)

        for task in self.benchmark:
            # Save the task meta information before submitting its runs
            await self._save_task_meta(task)

            # Submit n_repeat runs of the task
            pending.extend(
                actor_handle.run.remote(
                    self.storage,
                    str(repeat_index),
                    task,
                    solution,
                )
                for repeat_index in range(self.n_repeat)
            )

        # Wait for every submitted run
        if pending:
            await asyncio.gather(*pending)

        # Aggregate the results
        await self.aggregate()
|
||||
10
src/agentscope/evaluate/_evaluator_storage/__init__.py
Normal file
10
src/agentscope/evaluate/_evaluator_storage/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The evaluator storage module in AgentScope."""
|
||||
|
||||
from ._evaluator_storage_base import EvaluatorStorageBase
|
||||
from ._file_evaluator_storage import FileEvaluatorStorage
|
||||
|
||||
__all__ = [
|
||||
"EvaluatorStorageBase",
|
||||
"FileEvaluatorStorage",
|
||||
]
|
||||
@@ -0,0 +1,254 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The evaluator storage base class for storing solution and evaluation
|
||||
results."""
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Callable
|
||||
|
||||
from .._metric_base import MetricResult
|
||||
from .._solution import SolutionOutput
|
||||
from ...agent import AgentBase
|
||||
from ...types import JSONSerializableObject
|
||||
|
||||
|
||||
class EvaluatorStorageBase:
    """Interface of the storage that keeps solution and evaluation results,
    so that an evaluation can be resumed.

    NOTE(review): the methods are marked with `@abstractmethod`, but the
    class has no `ABC` base, so the marker is not enforced at instantiation.
    """

    @abstractmethod
    def save_solution_result(
        self,
        task_id: str,
        repeat_id: str,
        output: SolutionOutput,
        **kwargs: Any,
    ) -> None:
        """Persist the solution output of one task run.

        Args:
            task_id (`str`):
                The task ID.
            repeat_id (`str`):
                The repeat ID for the task, usually the index of the repeat
                evaluation.
            output (`SolutionOutput`):
                The solution output to be saved.
        """

    @abstractmethod
    def get_evaluation_result(
        self,
        task_id: str,
        repeat_id: str,
        metric_name: str,
    ) -> MetricResult:
        """Load the evaluation result of one metric for one task run.

        Args:
            task_id (`str`):
                The task ID.
            repeat_id (`str`):
                The repeat ID for the task, usually the index of the repeat
                evaluation.
            metric_name (`str`):
                The metric name.

        Returns:
            `MetricResult`:
                The evaluation result for the given task and repeat ID.
        """

    @abstractmethod
    def save_evaluation_result(
        self,
        task_id: str,
        repeat_id: str,
        evaluation: MetricResult,
        **kwargs: Any,
    ) -> None:
        """Persist one evaluation result of a task run.

        Args:
            task_id (`str`):
                The task ID.
            repeat_id (`str`):
                The repeat ID for the task, usually the index of the repeat
                evaluation.
            evaluation (`MetricResult`):
                The evaluation result to be saved.
        """

    @abstractmethod
    def get_solution_result(
        self,
        task_id: str,
        repeat_id: str,
        **kwargs: Any,
    ) -> SolutionOutput:
        """Load the solution output of one task run.

        Args:
            task_id (`str`):
                The task ID.
            repeat_id (`str`):
                The repeat ID for the task, usually the index of the repeat
                evaluation.

        Returns:
            `SolutionOutput`:
                The solution output for the given task and repeat ID.
        """

    @abstractmethod
    def solution_result_exists(self, task_id: str, repeat_id: str) -> bool:
        """Tell whether the solution of a task run is already stored.

        Args:
            task_id (`str`):
                The task ID.
            repeat_id (`str`):
                The repeat ID for the task, usually the index of the repeat
                evaluation.

        Returns:
            `bool`:
                True if the solution result exists, False otherwise.
        """

    @abstractmethod
    def evaluation_result_exists(
        self,
        task_id: str,
        repeat_id: str,
        metric_name: str,
    ) -> bool:
        """Tell whether the evaluation result of one metric is already
        stored for a task run.

        Args:
            task_id (`str`):
                The task ID.
            repeat_id (`str`):
                The repeat ID for the task, usually the index of the repeat
                evaluation.
            metric_name (`str`):
                The name of the metric.

        Returns:
            `bool`:
                True if the evaluation result exists, False otherwise.
        """

    @abstractmethod
    def save_aggregation_result(
        self,
        aggregation_result: dict,
        **kwargs: Any,
    ) -> None:
        """Persist the aggregated result of the whole evaluation.

        Args:
            aggregation_result (`dict`):
                A dictionary containing the aggregation result.
        """

    @abstractmethod
    def aggregation_result_exists(
        self,
        **kwargs: Any,
    ) -> bool:
        """Tell whether the aggregation result is already stored.

        Returns:
            `bool`:
                `True` if the aggregation result exists.
        """

    @abstractmethod
    def save_evaluation_meta(self, meta_info: dict) -> None:
        """Persist the meta information of the evaluation.

        Args:
            meta_info (`dict`):
                A dictionary containing the meta information.
        """

    @abstractmethod
    def save_task_meta(
        self,
        task_id: str,
        meta_info: dict[str, JSONSerializableObject],
    ) -> None:
        """Persist the meta information of one task.

        Args:
            task_id (`str`):
                The task ID.
            meta_info (`dict[str, JSONSerializableObject]`):
                The task meta information to be saved, which should be JSON
                serializable.
        """

    @abstractmethod
    def save_solution_stats(
        self,
        task_id: str,
        repeat_id: str,
        stats: dict,
    ) -> None:
        """Persist the statistics collected while solving one task run.

        Args:
            task_id (`str`):
                The task ID.
            repeat_id (`str`):
                The repeat ID for the task, usually the index of the repeat
                evaluation.
            stats (`dict`):
                A dictionary containing the solution statistics to be saved.
        """

    @abstractmethod
    def get_solution_stats(
        self,
        task_id: str,
        repeat_id: str,
    ) -> dict:
        """Load the statistics collected while solving one task run.

        Args:
            task_id (`str`):
                The task ID.
            repeat_id (`str`):
                The repeat ID for the task, usually the index of the repeat
                evaluation.

        Returns:
            `dict`:
                A dictionary containing the solution statistics for the given
                task and repeat ID.
        """

    @abstractmethod
    def get_agent_pre_print_hook(
        self,
        task_id: str,
        repeat_id: str,
    ) -> Callable[[AgentBase, dict], None]:
        """Build the pre-print hook that records what an agent prints during
        one task run.

        Args:
            task_id (`str`):
                The task ID.
            repeat_id (`str`):
                The repeat ID for the task, usually the index of the repeat
                evaluation.

        Returns:
            `Callable[[AgentBase, dict], None]`:
                A hook function that takes an `AgentBase` instance and a
                keyword arguments dictionary as input, saving the agent's
                printing Msg into the evaluation storage.
        """
|
||||
@@ -0,0 +1,460 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""A file system based evaluator storage."""
|
||||
import json
|
||||
import os
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, Callable
|
||||
|
||||
from ._evaluator_storage_base import EvaluatorStorageBase
|
||||
from .._solution import SolutionOutput
|
||||
from .._metric_base import MetricResult
|
||||
from ...agent import AgentBase
|
||||
from ...message import Msg
|
||||
from ...types import JSONSerializableObject
|
||||
|
||||
|
||||
class FileEvaluatorStorage(EvaluatorStorageBase):
|
||||
"""File system based evaluator storage, providing methods to save and
|
||||
retrieve evaluation results. So that the evaluation process can be resumed
|
||||
from the last saved state.
|
||||
|
||||
The files are organized in a directory structure:
|
||||
- save_dir/
|
||||
- evaluation_result.json
|
||||
- evaluation_meta.json
|
||||
- {task_id}/
|
||||
- {repeat_id}/
|
||||
- solution.json
|
||||
- evaluation/
|
||||
- {metric_name}.json
|
||||
"""
|
||||
|
||||
SOLUTION_FILE_NAME = "solution.json"
|
||||
SOLUTION_STATS_FILE_NAME = "stats.json"
|
||||
EVALUATION_DIR_NAME = "evaluation"
|
||||
EVALUATION_RESULT_FILE = "evaluation_result.json"
|
||||
EVALUATION_META_FILE = "evaluation_meta.json"
|
||||
TASK_META_FILE = "task_meta.json"
|
||||
AGENT_PRINTING_LOG = "logging.txt"
|
||||
|
||||
def __init__(self, save_dir: str) -> None:
    """Remember the absolute form of the directory results are saved in.

    Args:
        save_dir (`str`):
            The directory used as the root of the storage.
    """
    absolute_dir = os.path.abspath(save_dir)
    self.save_dir = absolute_dir
|
||||
|
||||
def _get_save_path(
    self,
    task_id: str,
    repeat_id: str | None,
    *args: str,
) -> str:
    """Build a path below the save directory and create its parent
    directory.

    Args:
        task_id (`str`):
            The task ID.
        repeat_id (`str | None`):
            The repeat ID for the task, usually the index of the repeat
            evaluation. If None, it will be ignored in the path.
        *args (`str`):
            Additional path components to be appended.

    Returns:
        `str`:
            The save directory joined with every non-empty component, in
            the order task ID, repeat ID, additional components.
    """
    # Falsy components (None or empty strings) are left out of the path
    kept_parts = [part for part in (task_id, repeat_id, *args) if part]
    target = os.path.join(self.save_dir, *kept_parts)

    # The parent directory is created on every call, also for look-ups
    os.makedirs(os.path.dirname(target), exist_ok=True)
    return target
|
||||
|
||||
def save_solution_result(
    self,
    task_id: str,
    repeat_id: str,
    output: SolutionOutput,
    **kwargs: Any,
) -> None:
    """Save the solution result as a JSON file.

    Args:
        task_id (`str`):
            The task ID.
        repeat_id (`str`):
            The repeat ID for the task, usually the index of the repeat
            evaluation.
        output (`SolutionOutput`):
            The solution output to be saved.

    Raises:
        `TypeError`:
            If `output` cannot be encoded as JSON; nothing is written in
            that case.
    """
    # Encode before resolving and opening the file: an output that cannot
    # be encoded must not leave a truncated, non-empty file behind, which
    # would be mistaken for a finished solution.
    payload = json.dumps(output, ensure_ascii=False, indent=4)

    path_file = self._get_save_path(
        task_id,
        repeat_id,
        self.SOLUTION_FILE_NAME,
    )
    with open(path_file, "w", encoding="utf-8") as f:
        f.write(payload)
|
||||
|
||||
def save_evaluation_result(
    self,
    task_id: str,
    repeat_id: str,
    evaluation: MetricResult,
    **kwargs: Any,
) -> None:
    """Save the evaluation result as a JSON file named after the result.

    Args:
        task_id (`str`):
            The task ID.
        repeat_id (`str`):
            The repeat ID for the task, usually the index of the repeat
            evaluation.
        evaluation (`MetricResult`):
            The evaluation result to be saved; its `name` attribute gives
            the file name.

    Raises:
        `TypeError`:
            If `evaluation` cannot be encoded as JSON; nothing is written in
            that case.
    """
    # Encode before resolving and opening the file: a result that cannot be
    # encoded must not leave a truncated, non-empty file behind, which would
    # be mistaken for a finished evaluation.
    payload = json.dumps(evaluation, ensure_ascii=False, indent=4)

    path_file = self._get_save_path(
        task_id,
        repeat_id,
        self.EVALUATION_DIR_NAME,
        f"{evaluation.name}.json",
    )
    with open(path_file, "w", encoding="utf-8") as f:
        f.write(payload)
|
||||
|
||||
def get_evaluation_result(
    self,
    task_id: str,
    repeat_id: str,
    metric_name: str,
) -> MetricResult:
    """Load the stored evaluation result of one metric for a task run.

    Args:
        task_id (`str`):
            The task ID.
        repeat_id (`str`):
            The repeat ID for the task, usually the index of the repeat
            evaluation.
        metric_name (`str`):
            The metric name.

    Raises:
        `FileNotFoundError`:
            If no result file exists; the path is the error argument.

    Returns:
        `MetricResult`:
            The evaluation result for the given task and repeat ID.
    """
    result_path = self._get_save_path(
        task_id,
        repeat_id,
        self.EVALUATION_DIR_NAME,
        f"{metric_name}.json",
    )
    if os.path.exists(result_path):
        with open(result_path, "r", encoding="utf-8") as result_file:
            stored_fields = json.loads(result_file.read())
        # The stored JSON object is expanded into keyword arguments
        return MetricResult(**stored_fields)

    raise FileNotFoundError(result_path)
|
||||
|
||||
def get_solution_result(
    self,
    task_id: str,
    repeat_id: str,
    **kwargs: Any,
) -> SolutionOutput:
    """Load the stored solution output of a task run from the file system.

    Args:
        task_id (`str`):
            The task ID.
        repeat_id (`str`):
            The repeat ID for the task, usually the index of the repeat
            evaluation.

    Raises:
        `FileNotFoundError`:
            If the solution result file does not exist for the given task
            and repeat ID.
        `JSONDecodeError`:
            If the file content is not valid JSON; the message names the
            file.

    Returns:
        `SolutionOutput`:
            The solution output for the given task and repeat ID.
    """
    solution_path = self._get_save_path(
        task_id,
        repeat_id,
        self.SOLUTION_FILE_NAME,
    )
    if not os.path.exists(solution_path):
        raise FileNotFoundError(
            f"Solution result for task {task_id} and repeat {repeat_id} "
            "not found.",
        )

    with open(solution_path, "r", encoding="utf-8") as solution_file:
        raw_text = solution_file.read()

    try:
        stored_fields = json.loads(raw_text)
    except JSONDecodeError as err:
        # Re-raise with the file path so the broken file can be located
        raise JSONDecodeError(
            f"Failed to load JSON from {solution_path}: {err.msg}",
            err.doc,
            err.pos,
        ) from err

    return SolutionOutput(**stored_fields)
|
||||
|
||||
def solution_result_exists(self, task_id: str, repeat_id: str) -> bool:
    """Tell whether a non-empty solution file exists for a task run.

    Args:
        task_id (`str`):
            The task ID.
        repeat_id (`str`):
            The repeat ID for the task, usually the index of the repeat
            evaluation.

    Returns:
        `bool`:
            True if the solution result file exists and is not empty, False
            otherwise.
    """
    solution_path = self._get_save_path(
        task_id,
        repeat_id,
        self.SOLUTION_FILE_NAME,
    )

    if not os.path.exists(solution_path):
        return False
    # An empty file does not count as a stored solution
    return os.path.getsize(solution_path) > 0
|
||||
|
||||
def evaluation_result_exists(
    self,
    task_id: str,
    repeat_id: str,
    metric_name: str,
) -> bool:
    """Tell whether a non-empty evaluation file exists for one metric of a
    task run.

    Args:
        task_id (`str`):
            The task ID.
        repeat_id (`str`):
            The repeat ID for the task, usually the index of the repeat
            evaluation.
        metric_name (`str`):
            The name of the metric.

    Returns:
        `bool`:
            True if the evaluation result file exists and is not empty,
            False otherwise.
    """
    result_path = self._get_save_path(
        task_id,
        repeat_id,
        self.EVALUATION_DIR_NAME,
        f"{metric_name}.json",
    )
    if not os.path.exists(result_path):
        return False
    # An empty file does not count as a stored evaluation
    return os.path.getsize(result_path) > 0
|
||||
|
||||
def save_aggregation_result(
    self,
    aggregation_result: dict,
    **kwargs: Any,
) -> None:
    """Save the aggregation result as a JSON file in the save directory.

    Args:
        aggregation_result (`dict`):
            A dictionary containing the aggregation result.

    Raises:
        `TypeError`:
            If `aggregation_result` cannot be encoded as JSON; nothing is
            written in that case.
    """
    # Encode before opening the file: a value that cannot be encoded must
    # not leave a truncated, non-empty file behind, which would be reported
    # as an existing aggregation result.
    payload = json.dumps(aggregation_result, ensure_ascii=False, indent=4)

    path_file = os.path.join(
        self.save_dir,
        self.EVALUATION_RESULT_FILE,
    )
    os.makedirs(os.path.dirname(path_file), exist_ok=True)
    with open(path_file, "w", encoding="utf-8") as f:
        f.write(payload)
|
||||
|
||||
def aggregation_result_exists(
    self,
    **kwargs: Any,
) -> bool:
    """Tell whether a non-empty aggregation result file exists.

    Returns:
        `bool`:
            `True` if the aggregation result file exists and is not empty.
    """
    result_path = os.path.join(
        self.save_dir,
        self.EVALUATION_RESULT_FILE,
    )
    if not os.path.exists(result_path):
        return False
    # An empty file does not count as a stored aggregation result
    return os.path.getsize(result_path) > 0
|
||||
|
||||
def save_evaluation_meta(self, meta_info: dict) -> None:
    """Write the evaluation meta information as a JSON file in the save
    directory.

    Args:
        meta_info (`dict`):
            A dictionary containing the meta information.
    """
    meta_path = os.path.join(self.save_dir, self.EVALUATION_META_FILE)

    # Make sure the directory holding the file exists before writing
    parent_dir = os.path.dirname(meta_path)
    os.makedirs(parent_dir, exist_ok=True)

    with open(meta_path, "w", encoding="utf-8") as meta_file:
        json.dump(meta_info, meta_file, ensure_ascii=False, indent=4)
|
||||
|
||||
def save_task_meta(
    self,
    task_id: str,
    meta_info: dict[str, JSONSerializableObject],
) -> None:
    """Write the meta information of a task as a JSON file.

    Args:
        task_id (`str`):
            The task ID.
        meta_info (`dict[str, JSONSerializableObject]`):
            The task meta information to be saved, which should be JSON
            serializable.
    """
    # No repeat ID: the path is built from the task ID and the file name
    meta_path = self._get_save_path(task_id, None, self.TASK_META_FILE)
    with open(meta_path, "w", encoding="utf-8") as meta_file:
        json.dump(meta_info, meta_file, ensure_ascii=False, indent=4)
|
||||
|
||||
def save_solution_stats(
    self,
    task_id: str,
    repeat_id: str,
    stats: dict,
) -> None:
    """Persist the solution statistics for a given task and repeat ID.

    Nothing is written when the statistics file already exists, so an
    earlier record is never overwritten.

    Args:
        task_id (`str`):
            The task ID.
        repeat_id (`str`):
            The repeat ID for the task, usually the index of the repeat
            evaluation.
        stats (`dict`):
            A dictionary containing the solution statistics to be saved.
    """
    stats_path = self._get_save_path(
        task_id,
        repeat_id,
        self.SOLUTION_STATS_FILE_NAME,
    )
    # Keep the first record: an existing file short-circuits the save.
    if os.path.exists(stats_path):
        return

    with open(stats_path, mode="w", encoding="utf-8") as stats_file:
        json.dump(stats, stats_file, ensure_ascii=False, indent=4)
|
||||
|
||||
def get_solution_stats(
    self,
    task_id: str,
    repeat_id: str,
) -> dict:
    """Load the solution statistics for a given task and repeat ID.

    Args:
        task_id (`str`):
            The task ID.
        repeat_id (`str`):
            The repeat ID for the task, usually the index of the repeat
            evaluation.

    Returns:
        `dict`:
            The content parsed from the solution statistics JSON file.

    Raises:
        `FileNotFoundError`:
            If the statistics file does not exist.
        `JSONDecodeError`:
            If the file content is not valid JSON; the message is
            prefixed with the file path.
    """
    stats_path = self._get_save_path(
        task_id,
        repeat_id,
        self.SOLUTION_STATS_FILE_NAME,
    )

    if os.path.exists(stats_path):
        try:
            with open(stats_path, mode="r", encoding="utf-8") as stats_file:
                loaded = json.load(stats_file)
        except JSONDecodeError as err:
            # Re-raise with the offending path so the failure is traceable.
            raise JSONDecodeError(
                f"Failed to load JSON from {stats_path}: {err.msg}",
                err.doc,
                err.pos,
            ) from err
        return loaded

    raise FileNotFoundError(
        f"Solution statistics for task {task_id} and repeat "
        f"{repeat_id} not found.",
    )
|
||||
|
||||
def get_agent_pre_print_hook(
|
||||
self,
|
||||
task_id: str,
|
||||
repeat_id: str,
|
||||
) -> Callable[[AgentBase, dict], None]:
|
||||
"""Get a pre-print hook function for the agent to save the agent
|
||||
printing in the evaluation storage.
|
||||
|
||||
Args:
|
||||
task_id (`str`):
|
||||
The task ID.
|
||||
repeat_id (`str`):
|
||||
The repeat ID for the task, usually the index of the repeat
|
||||
evaluation.
|
||||
|
||||
Returns:
|
||||
`Callable[[AgentBase, dict], None]`:
|
||||
A hook function that takes an `AgentBase` instance and a
|
||||
keyword arguments dictionary as input, saving the agent's
|
||||
printing Msg into the evaluation storage.
|
||||
"""
|
||||
|
||||
def pre_print_hook(_agent: AgentBase, kwargs: dict) -> None:
|
||||
"""Hook function to save agent's printing."""
|
||||
msg: Msg | None = kwargs.get("msg", None)
|
||||
last: bool = kwargs.get("last", False)
|
||||
|
||||
if msg is None or not last:
|
||||
return
|
||||
|
||||
# Only save the last message
|
||||
printing_str = []
|
||||
for block in msg.get_content_blocks():
|
||||
match block["type"]:
|
||||
case "text":
|
||||
printing_str.append(
|
||||
f"{msg.name}: {block['text']}",
|
||||
)
|
||||
case "thinking":
|
||||
printing_str.append(
|
||||
f"{msg.name} (thinking): {block['text']}",
|
||||
)
|
||||
case _:
|
||||
block_str = json.dumps(
|
||||
block,
|
||||
ensure_ascii=False,
|
||||
indent=4,
|
||||
)
|
||||
if printing_str:
|
||||
printing_str.append(block_str)
|
||||
else:
|
||||
printing_str.append(f"{msg.name}: {block_str}")
|
||||
|
||||
path_file = self._get_save_path(
|
||||
task_id,
|
||||
repeat_id,
|
||||
self.AGENT_PRINTING_LOG,
|
||||
)
|
||||
with open(path_file, "a", encoding="utf-8") as f:
|
||||
f.write("\n".join(printing_str) + "\n")
|
||||
|
||||
return pre_print_hook
|
||||
101
src/agentscope/evaluate/_metric_base.py
Normal file
101
src/agentscope/evaluate/_metric_base.py
Normal file
@@ -0,0 +1,101 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The base class for _metric in evaluation."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from .._utils._common import _get_timestamp
|
||||
from .._utils._mixin import DictMixin
|
||||
from ..types import JSONSerializableObject
|
||||
|
||||
|
||||
@dataclass
class MetricResult(DictMixin):
    """The result produced by evaluating one metric."""

    name: str
    """The metric name."""

    result: str | float | int
    """The metric result."""

    created_at: str = field(default_factory=_get_timestamp)
    """The timestamp when the metric result was created."""

    message: str | None = field(default_factory=lambda: None)
    """An optional message for the metric result, can be used to provide
    additional information or context about the result."""

    # A `default_factory` is used (as for `message` above) instead of
    # `default=None`: a plain default is also stored as a class attribute,
    # whereas a factory keeps the default on the instance only.
    # NOTE(review): presumably DictMixin resolves attribute reads through
    # `__getattr__`, which a class-level `metadata = None` would shadow --
    # confirm against DictMixin.
    metadata: dict[str, JSONSerializableObject] | None = field(
        default_factory=lambda: None,
    )
    """Optional metadata for the metric result, can be used to store
    additional information related to the metric result."""
|
||||
|
||||
|
||||
class MetricType(str, Enum):
    """Enumeration of the kinds of result a metric can produce."""

    CATEGORY = "category"
    """The result is one label out of a set, e.g. "pass" or "fail"."""

    NUMERICAL = "numerical"
    """The result is a number, e.g. 0.95 or 100."""
|
||||
|
||||
|
||||
@dataclass
class MetricBase(ABC):
    """Abstract base class that every evaluation metric derives from."""

    name: str
    """The name of the metric."""

    metric_type: MetricType
    """The type of the metric."""

    description: str | None
    """The description of the metric."""

    categories: list[str] | None
    """The candidate categories. They must be provided when `metric_type`
    is "category"; otherwise they should be `None`."""

    def __init__(
        self,
        name: str,
        metric_type: MetricType,
        description: str | None = None,
        categories: list[str] | None = None,
    ) -> None:
        """Initialize the metric object.

        Args:
            name (`str`):
                The name of the metric.
            metric_type (`MetricType`):
                The type of the metric, either "category" or
                "numerical", which determines how to display the result.
            description (`str | None`, optional):
                The description of the metric.
            categories (`list[str] | None`, optional):
                The candidate categories. They must be provided when
                `metric_type` is "category"; otherwise they should be
                `None`.

        Raises:
            `ValueError`:
                If `metric_type` is "category" and `categories` is
                `None`.
        """
        self.name = name
        self.metric_type = metric_type
        self.description = description

        is_category_metric = metric_type == MetricType.CATEGORY
        if is_category_metric and categories is None:
            raise ValueError(
                "Categories must be provided for category metrics.",
            )

        self.categories = categories

    @abstractmethod
    async def __call__(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> MetricResult:
        """Calculate the metric result; implemented by subclasses."""
|
||||
36
src/agentscope/evaluate/_solution.py
Normal file
36
src/agentscope/evaluate/_solution.py
Normal file
@@ -0,0 +1,36 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Solution class for evaluation tasks."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..message import (
|
||||
ToolResultBlock,
|
||||
ToolUseBlock,
|
||||
TextBlock,
|
||||
)
|
||||
from ..types._json import JSONSerializableObject
|
||||
from .._utils._mixin import DictMixin
|
||||
|
||||
|
||||
@dataclass
class SolutionOutput(DictMixin):
    """What a solution produces when it is run on an evaluation task."""

    success: bool
    """Whether the solution was executed successfully. It should be set to
    False when the solution raises an exception."""
    output: JSONSerializableObject
    """The final output of the solution."""
    trajectory: list[ToolUseBlock | ToolResultBlock | TextBlock]
    """The trajectory of tool calls, tool results and text blocks."""
    meta: dict[str, Any] | None = field(default_factory=lambda: None)
    """Additional metadata for the solution."""

    def __getstate__(self) -> dict[str, Any]:
        """Return a shallow copy of the instance `__dict__` for pickling
        (custom handling for the dataclass + DictMixin inheritance)."""
        return dict(self.__dict__)

    def __setstate__(self, state: dict[str, Any]) -> None:
        """Restore the instance `__dict__` from a pickled state (custom
        handling for the dataclass + DictMixin inheritance)."""
        self.__dict__.update(state)
|
||||
53
src/agentscope/evaluate/_task.py
Normal file
53
src/agentscope/evaluate/_task.py
Normal file
@@ -0,0 +1,53 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The base class for task in evaluation."""
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ._solution import SolutionOutput
|
||||
from ._metric_base import MetricBase, MetricResult
|
||||
from ..types._json import JSONSerializableObject
|
||||
|
||||
|
||||
@dataclass
class Task:
    """A single evaluation task together with the metrics that score it."""

    id: str
    """The unique identifier for the task."""

    input: JSONSerializableObject
    """The task input, which should be a JSON serializable object."""

    ground_truth: JSONSerializableObject
    """The task ground truth if exists, which should be a JSON serializable
    object."""

    metrics: list[MetricBase]
    """The metrics used to evaluate the task, as a list of `MetricBase`
    objects."""

    tags: dict[str, str] | None = field(default_factory=lambda: None)
    """Tags to categorize the task, e.g. `{"difficulty": "easy",
    "cate": "math"}`."""

    metadata: dict[str, Any] | None = field(
        default_factory=lambda: None,
    )
    """Additional metadata for the task."""

    async def evaluate(self, solution: SolutionOutput) -> list[MetricResult]:
        """Score the given solution with every metric of the task.

        The metrics are awaited one after another, in list order.

        Args:
            solution (`SolutionOutput`):
                The solution to evaluate the task with.

        Returns:
            `list[MetricResult]`:
                One result per metric, in the order of `self.metrics`.
        """
        return [await metric(solution) for metric in self.metrics]
|
||||
16
src/agentscope/exception/__init__.py
Normal file
16
src/agentscope/exception/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The exception module in agentscope."""
|
||||
|
||||
from ._exception_base import AgentOrientedExceptionBase
|
||||
from ._tool import (
|
||||
ToolInterruptedError,
|
||||
ToolNotFoundError,
|
||||
ToolInvalidArgumentsError,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AgentOrientedExceptionBase",
|
||||
"ToolInterruptedError",
|
||||
"ToolNotFoundError",
|
||||
"ToolInvalidArgumentsError",
|
||||
]
|
||||
18
src/agentscope/exception/_exception_base.py
Normal file
18
src/agentscope/exception/_exception_base.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The base exception class in agentscope."""
|
||||
|
||||
|
||||
class AgentOrientedExceptionBase(Exception):
    """The base class for all agent-oriented exceptions. These exceptions
    are expected to be captured and exposed to the agent during runtime, so
    that the agent can handle the error appropriately.
    """

    def __init__(self, message: str):
        """Initialize the exception and keep the message on the instance.

        Args:
            message (`str`):
                The error message.
        """
        super().__init__(message)
        self.message = message

    def __str__(self) -> str:
        """Return the message prefixed with the concrete class name."""
        class_name = self.__class__.__name__
        return f"{class_name}: {self.message}"
|
||||
16
src/agentscope/exception/_tool.py
Normal file
16
src/agentscope/exception/_tool.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The tool-related exceptions in agentscope."""
|
||||
|
||||
from ._exception_base import AgentOrientedExceptionBase
|
||||
|
||||
|
||||
class ToolNotFoundError(AgentOrientedExceptionBase):
    """Raised when the requested tool cannot be found."""
|
||||
|
||||
|
||||
class ToolInterruptedError(AgentOrientedExceptionBase):
    """Raised when a tool call is interrupted by the user."""
|
||||
|
||||
|
||||
class ToolInvalidArgumentsError(AgentOrientedExceptionBase):
    """Raised when a tool is called with invalid arguments."""
|
||||
48
src/agentscope/formatter/__init__.py
Normal file
48
src/agentscope/formatter/__init__.py
Normal file
@@ -0,0 +1,48 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The formatter module in agentscope."""
|
||||
|
||||
from ._formatter_base import FormatterBase
|
||||
from ._truncated_formatter_base import TruncatedFormatterBase
|
||||
from ._dashscope_formatter import (
|
||||
DashScopeChatFormatter,
|
||||
DashScopeMultiAgentFormatter,
|
||||
)
|
||||
from ._anthropic_formatter import (
|
||||
AnthropicChatFormatter,
|
||||
AnthropicMultiAgentFormatter,
|
||||
)
|
||||
from ._openai_formatter import (
|
||||
OpenAIChatFormatter,
|
||||
OpenAIMultiAgentFormatter,
|
||||
)
|
||||
from ._gemini_formatter import (
|
||||
GeminiChatFormatter,
|
||||
GeminiMultiAgentFormatter,
|
||||
)
|
||||
from ._ollama_formatter import (
|
||||
OllamaChatFormatter,
|
||||
OllamaMultiAgentFormatter,
|
||||
)
|
||||
from ._deepseek_formatter import (
|
||||
DeepSeekChatFormatter,
|
||||
DeepSeekMultiAgentFormatter,
|
||||
)
|
||||
from ._a2a_formatter import A2AChatFormatter
|
||||
|
||||
__all__ = [
|
||||
"FormatterBase",
|
||||
"TruncatedFormatterBase",
|
||||
"DashScopeChatFormatter",
|
||||
"DashScopeMultiAgentFormatter",
|
||||
"OpenAIChatFormatter",
|
||||
"OpenAIMultiAgentFormatter",
|
||||
"AnthropicChatFormatter",
|
||||
"AnthropicMultiAgentFormatter",
|
||||
"GeminiChatFormatter",
|
||||
"GeminiMultiAgentFormatter",
|
||||
"OllamaChatFormatter",
|
||||
"OllamaMultiAgentFormatter",
|
||||
"DeepSeekChatFormatter",
|
||||
"DeepSeekMultiAgentFormatter",
|
||||
"A2AChatFormatter",
|
||||
]
|
||||
364
src/agentscope/formatter/_a2a_formatter.py
Normal file
364
src/agentscope/formatter/_a2a_formatter.py
Normal file
@@ -0,0 +1,364 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The A2A message formatter class."""
|
||||
import mimetypes
|
||||
import uuid
|
||||
from typing import Literal, TYPE_CHECKING
|
||||
|
||||
|
||||
from .._logging import logger
|
||||
from ._formatter_base import FormatterBase
|
||||
from ..message import (
|
||||
Msg,
|
||||
TextBlock,
|
||||
URLSource,
|
||||
Base64Source,
|
||||
ContentBlock,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import (
|
||||
Message,
|
||||
Task,
|
||||
Part,
|
||||
)
|
||||
else:
|
||||
Message = "a2a.types.Message"
|
||||
Task = "a2a.types.Task"
|
||||
Part = "a2a.types.Part"
|
||||
|
||||
|
||||
class A2AChatFormatter(FormatterBase):
    """A2A message formatter class, which converts AgentScope messages into
    the A2A message format, and A2A messages/tasks back into AgentScope
    messages."""

    async def format(self, msgs: list[Msg]) -> Message:
        """Convert AgentScope messages into a A2A message object. Note that
        A2A server only supports single request message, so the input msgs
        list will be merged into a single A2A Message.

        .. note:: Note the A2A protocol receives a single message per request,
         so multi-message inputs will be merged into one A2A Message with role
         'user'.

        Args:
            msgs (`list[Msg]`):
                List of AgentScope Msg objects to be converted.

        Returns:
            `Message`:
                The converted A2A Message object.

        Raises:
            `ValueError`:
                If a media block carries a source whose type is neither
                "url" nor "base64".
        """

        from a2a.types import (
            Part,
            TextPart,
            FilePart,
            FileWithUri,
            FileWithBytes,
            DataPart,
            Role,
            Message,
        )

        self.assert_list_of_msgs(msgs)

        parts = []
        for msg in msgs:
            for block in msg.get_content_blocks():
                block_type = block.get("type")
                if block_type == "text" and block.get("text"):
                    parts.append(
                        Part(
                            root=TextPart(
                                text=block.get("text"),
                            ),
                        ),
                    )

                elif block_type == "thinking" and block.get("thinking"):
                    # Thinking content is forwarded as a plain text part.
                    parts.append(
                        Part(
                            root=TextPart(
                                text=block.get("thinking"),
                            ),
                        ),
                    )

                elif block_type in [
                    "image",
                    "video",
                    "audio",
                ] and block.get("source"):
                    source = block.get("source", {})
                    source_type = source.get("type")

                    if source_type == "url":
                        parts.append(
                            Part(
                                root=FilePart(
                                    file=FileWithUri(
                                        uri=source.get("url"),
                                    ),
                                ),
                            ),
                        )

                    elif source_type == "base64":
                        parts.append(
                            Part(
                                root=FilePart(
                                    file=FileWithBytes(
                                        bytes=source.get("data"),
                                        mime_type=source.get("media_type"),
                                    ),
                                ),
                            ),
                        )

                    else:
                        raise ValueError(
                            f"Unsupported source type: {source_type}",
                        )

                elif block_type in ["tool_use", "tool_result"]:
                    # Tool blocks travel as raw data parts so that they can
                    # be recovered by `_format_a2a_part`.
                    parts.append(
                        Part(
                            root=DataPart(
                                data=block,
                            ),
                        ),
                    )

                else:
                    # Also reached by text/thinking blocks with empty
                    # content and media blocks without a source.
                    logger.error(
                        "Unsupported block type %s in A2AFormatter.",
                        block_type,
                    )

        a2a_message = Message(
            message_id=str(uuid.uuid4()),
            role=Role.user,
            parts=parts,
        )

        return a2a_message

    async def format_a2a_message(self, name: str, message: Message) -> Msg:
        """Convert A2A Message object back to AgentScope Msg format.

        Args:
            name (`str`):
                The name of the message sender.
            message (`Message`):
                The A2A Message object to be converted.

        Returns:
            `Msg`:
                The converted AgentScope Msg object, whose content holds
                one block per A2A part and whose metadata is `None`.

        Raises:
            `ValueError`:
                If the role of the A2A message is neither user nor agent,
                or a part cannot be converted.
        """

        from a2a.types import Role

        # Parts are converted before the role check, one after another.
        content = [await self._format_a2a_part(part) for part in message.parts]

        if message.role == Role.user:
            role: Literal["user", "assistant"] = "user"
        elif message.role == Role.agent:
            role = "assistant"
        else:
            raise ValueError(
                f"Unsupported role: {message.role} in A2A message.",
            )

        return Msg(
            name=name,
            role=role,
            content=content,
            metadata=None,
        )

    @staticmethod
    def _guess_type(
        uri: str | None = None,
        mime_type: str | None = None,
    ) -> Literal["image", "video", "audio", "unknown"]:
        """Guess the content type from the uri or mime type.

        Args:
            uri (`str | None`, optional):
                The uri of the content; only consulted when `mime_type`
                is `None`.
            mime_type (`str | None`, optional):
                The mime type of the content.

        Returns:
            `Literal["image", "video", "audio", "unknown"]`:
                The guessed content type, "unknown" when the mime type
                matches none of the image/video/audio prefixes.

        Raises:
            `ValueError`:
                If both `uri` and `mime_type` are `None`.
        """
        if mime_type is None and uri is None:
            raise ValueError(
                "Either uri or mime_type must be provided to guess the"
                " content type.",
            )

        if mime_type is None:
            mime_type, _encoding = mimetypes.guess_type(uri or "")

        if isinstance(mime_type, str):
            if mime_type.startswith("image/"):
                return "image"

            if mime_type.startswith("video/"):
                return "video"

            if mime_type.startswith("audio/"):
                return "audio"

        return "unknown"

    async def format_a2a_task(self, name: str, task: Task) -> list[Msg]:
        """Convert A2A Task object back to AgentScope Msg format.

        The status message (if any) comes first; the parts of every
        artifact are then appended to the trailing assistant message, or
        put into a new assistant message when there is none.

        Args:
            name (`str`):
                The name of the message sender.
            task (`Task`):
                The A2A Task object to be converted.

        Returns:
            `list[Msg]`:
                Converted AgentScope Msg objects.
        """
        msgs = []
        if task.status and task.status.message:
            msgs.append(
                await self.format_a2a_message(name, task.status.message),
            )

        # Merge consecutive messages that share the same role.
        merged_msgs = []
        for msg in msgs:
            if merged_msgs and merged_msgs[-1].role == msg.role:
                merged_msgs[-1].content.extend(msg.content)

            else:
                merged_msgs.append(msg)

        if task.artifacts:
            for artifact in task.artifacts:
                artifact_content = [
                    await self._format_a2a_part(_) for _ in artifact.parts
                ]

                if merged_msgs and merged_msgs[-1].role == "assistant":
                    merged_msgs[-1].content.extend(artifact_content)
                    # The metadata of the latest merged artifact wins.
                    merged_msgs[-1].metadata = artifact.metadata

                else:
                    merged_msgs.append(
                        Msg(
                            name=name,
                            role="assistant",
                            content=artifact_content,
                            metadata=artifact.metadata,
                        ),
                    )

        return merged_msgs

    async def _format_a2a_part(self, part: Part) -> ContentBlock:
        """Convert a single A2A Part object into AgentScope ContentBlock.

        .. note:: We will try to convert the `DataPart` into tool use and tool
         result blocks if possible.

        Args:
            part (`Part`):
                The A2A Part object to be converted.

        Returns:
            `ContentBlock`:
                The converted AgentScope ContentBlock.

        Raises:
            `ValueError`:
                If the part, or the file inside a file part, is of an
                unsupported type.
        """

        from a2a.types import (
            TextPart,
            FilePart,
            FileWithUri,
            FileWithBytes,
            DataPart,
        )

        if isinstance(part.root, TextPart):
            return TextBlock(
                type="text",
                text=part.root.text,
            )

        if isinstance(part.root, FilePart):
            if isinstance(part.root.file, FileWithUri):
                return {  # type: ignore[return-value, misc]
                    "type": self._guess_type(
                        part.root.file.uri,
                        part.root.file.mime_type,
                    ),
                    "source": URLSource(
                        type="url",
                        url=part.root.file.uri,
                    ),
                }

            if isinstance(part.root.file, FileWithBytes):
                # Apply the fallback mime type BEFORE guessing the block
                # type: `_guess_type` raises when it gets neither a uri nor
                # a mime type, which used to crash on byte files that carry
                # no mime type.
                mime_type = (
                    part.root.file.mime_type or "application/octet-stream"
                )
                return {  # type: ignore[return-value, misc]
                    "type": self._guess_type(mime_type=mime_type),
                    "source": Base64Source(
                        type="base64",
                        media_type=mime_type,
                        data=part.root.file.bytes,
                    ),
                }

            raise ValueError(
                f"Unsupported File type: {type(part.root.file)} in A2A "
                "message.",
            )

        if isinstance(part.root, DataPart):
            # Maybe the tool use and tool result blocks
            if {
                "type",
                "name",
                "input",
                "id",
            } <= part.root.data.keys() and part.root.data[
                "type"
            ] == "tool_use":
                return part.root.data

            if {
                "type",
                "name",
                "output",
                "id",
            } <= part.root.data.keys() and part.root.data[
                "type"
            ] == "tool_result":
                return part.root.data

            # TODO: what about the other data parts?
            return TextBlock(
                type="text",
                text=str(part.root.data),
            )

        raise ValueError(
            f"Unsupported Part type: {type(part.root)} in A2A message"
            f": {part.root}",
        )
|
||||
253
src/agentscope/formatter/_anthropic_formatter.py
Normal file
253
src/agentscope/formatter/_anthropic_formatter.py
Normal file
@@ -0,0 +1,253 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# pylint: disable=too-many-branches
|
||||
"""The Anthropic formatter module."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from ._truncated_formatter_base import TruncatedFormatterBase
|
||||
from .._logging import logger
|
||||
from ..message import Msg, TextBlock, ImageBlock, ToolUseBlock, ToolResultBlock
|
||||
from ..token import TokenCounterBase
|
||||
|
||||
|
||||
class AnthropicChatFormatter(TruncatedFormatterBase):
    """The Anthropic formatter for the chatbot scenario, in which only a
    user and an agent take part. The `role` field is used to tell the
    entities of the conversation apart.
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = False
    """Whether support multi-agent conversations"""

    support_vision: bool = True
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        # Multimodal
        ImageBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    async def _format(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Format message objects into Anthropic API format.

        Every tool result block becomes its own "user" message, emitted
        as soon as the block is met; the remaining blocks of a message are
        gathered into one message that is emitted after its blocks have
        been processed. Messages left without content are dropped.

        Args:
            msgs (`list[Msg]`):
                The list of message objects to format.

        Returns:
            `list[dict[str, Any]]`:
                The formatted messages as a list of dictionaries.

        .. note:: Anthropic suggests always passing all previous thinking
         blocks back to the API in subsequent calls to maintain reasoning
         continuity. For more details, please refer to
         `Anthropic's documentation
         <https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#preserving-thinking-blocks>`_.
        """
        self.assert_list_of_msgs(msgs)

        formatted: list[dict] = []
        for position, msg in enumerate(msgs):
            blocks = []

            for block in msg.get_content_blocks():
                block_type = block.get("type")

                if block_type in ("thinking", "text", "image"):
                    # Passed through as a shallow copy.
                    blocks.append(dict(block))
                    continue

                if block_type == "tool_use":
                    blocks.append(
                        {
                            "id": block.get("id"),
                            "type": "tool_use",
                            "name": block.get("name"),
                            "input": block.get("input", {}),
                        },
                    )
                    continue

                if block_type == "tool_result":
                    output = block.get("output")
                    if isinstance(output, list):
                        result_content = output
                    elif output is None:
                        result_content = [{"type": "text", "text": None}]
                    else:
                        result_content = [
                            {"type": "text", "text": str(output)},
                        ]

                    formatted.append(
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "tool_result",
                                    "tool_use_id": block.get("id"),
                                    "content": result_content,
                                },
                            ],
                        },
                    )
                    continue

                logger.warning(
                    "Unsupported block type %s in the message, skipped.",
                    block_type,
                )

            # Claude only allow the first message to be system message
            demote_system = msg.role == "system" and position != 0
            role = "user" if demote_system else msg.role

            # A message without any content block is skipped
            if blocks:
                formatted.append({"role": role, "content": blocks})

        return formatted
|
||||
|
||||
|
||||
class AnthropicMultiAgentFormatter(TruncatedFormatterBase):
    """
    Anthropic formatter for multi-agent conversations, in which more than
    a user and an agent take part.
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = True
    """Whether support multi-agent conversations"""

    support_vision: bool = True
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        # Multimodal
        ImageBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    def __init__(
        self,
        conversation_history_prompt: str = (
            "# Conversation History\n"
            "The content between <history></history> tags contains "
            "your conversation history\n"
        ),
        token_counter: TokenCounterBase | None = None,
        max_tokens: int | None = None,
    ) -> None:
        """Initialize the Anthropic multi-agent formatter.

        Args:
            conversation_history_prompt (`str`):
                The prompt to use for the conversation history section.
            token_counter (`TokenCounterBase | None`, optional):
                Forwarded to the base formatter.
            max_tokens (`int | None`, optional):
                Forwarded to the base formatter.
        """
        super().__init__(token_counter=token_counter, max_tokens=max_tokens)
        self.conversation_history_prompt = conversation_history_prompt

    async def _format_tool_sequence(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Format a sequence of tool call/result messages for the Anthropic
        API by delegating to a fresh `AnthropicChatFormatter`."""
        chat_formatter = AnthropicChatFormatter()
        return await chat_formatter.format(msgs)

    async def _format_agent_message(
        self,
        msgs: list[Msg],
        is_first: bool = True,
    ) -> list[dict[str, Any]]:
        """Format a sequence of messages without tool calls/results for the
        Anthropic API.

        Text blocks are merged into "name: text" lines, image blocks are
        kept in place, and the whole content is wrapped in
        <history></history> tags inside one "user" message.

        Args:
            msgs (`list[Msg]`):
                The messages to format.
            is_first (`bool`, defaults to `True`):
                When `True`, the conversation history prompt is put in
                front of the opening <history> tag.

        Returns:
            `list[dict[str, Any]]`:
                A list holding a single "user" message, or an empty list
                when the messages yield no text or image block.
        """
        prompt_prefix = self.conversation_history_prompt if is_first else ""

        # Text lines wait in `pending_lines` until an image block (or the
        # end of the input) forces them out as a single text block.
        blocks: list = []
        pending_lines = []
        for msg in msgs:
            for block in msg.get_content_blocks():
                kind = block["type"]
                if kind == "text":
                    pending_lines.append(f"{msg.name}: {block['text']}")

                elif kind == "image":
                    if pending_lines:
                        blocks.append(
                            {
                                "text": "\n".join(pending_lines),
                                "type": "text",
                            },
                        )
                        pending_lines.clear()

                    blocks.append({**block})

        if pending_lines:
            blocks.append(
                {
                    "text": "\n".join(pending_lines),
                    "type": "text",
                },
            )

        if not blocks:
            return []

        # Opening tag: merged into a leading text block when there is one.
        head = blocks[0]
        if head.get("text"):
            head["text"] = prompt_prefix + "<history>\n" + head["text"]
        else:
            blocks.insert(
                0,
                {
                    "type": "text",
                    "text": prompt_prefix + "<history>\n",
                },
            )

        # Closing tag: merged into a trailing text block when there is one.
        tail = blocks[-1]
        if tail.get("text"):
            tail["text"] += "\n</history>"
        else:
            blocks.append(
                {"type": "text", "text": "</history>"},
            )

        return [
            {
                "role": "user",
                "content": blocks,
            },
        ]
|
||||
639
src/agentscope/formatter/_dashscope_formatter.py
Normal file
639
src/agentscope/formatter/_dashscope_formatter.py
Normal file
@@ -0,0 +1,639 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# pylint: disable=too-many-branches
|
||||
"""The dashscope formatter module."""
|
||||
|
||||
import json
|
||||
import os.path
|
||||
from typing import Any
|
||||
|
||||
from ._truncated_formatter_base import TruncatedFormatterBase
|
||||
from .._logging import logger
|
||||
from .._utils._common import _is_accessible_local_file
|
||||
from ..message import (
|
||||
Msg,
|
||||
TextBlock,
|
||||
ImageBlock,
|
||||
AudioBlock,
|
||||
VideoBlock,
|
||||
ToolUseBlock,
|
||||
ToolResultBlock,
|
||||
URLSource,
|
||||
)
|
||||
from ..token import TokenCounterBase
|
||||
|
||||
|
||||
def _format_dashscope_media_block(
|
||||
block: ImageBlock | AudioBlock,
|
||||
) -> dict[str, str]:
|
||||
"""Format an image or audio block for DashScope API.
|
||||
|
||||
Args:
|
||||
block (`ImageBlock` | `AudioBlock`):
|
||||
The image or audio block to format.
|
||||
|
||||
Returns:
|
||||
`dict[str, str]`:
|
||||
A dictionary with "image" or "audio" key and the formatted URL or
|
||||
data URI as value.
|
||||
|
||||
Raises:
|
||||
`NotImplementedError`:
|
||||
If the source type is not supported.
|
||||
"""
|
||||
typ = block["type"]
|
||||
source = block["source"]
|
||||
if source["type"] == "url":
|
||||
url = source["url"]
|
||||
if _is_accessible_local_file(url):
|
||||
return {typ: "file://" + os.path.abspath(url)}
|
||||
else:
|
||||
# treat as web url
|
||||
return {typ: url}
|
||||
|
||||
elif source["type"] == "base64":
|
||||
media_type = source["media_type"]
|
||||
base64_data = source["data"]
|
||||
return {
|
||||
typ: f"data:{media_type};base64,{base64_data}",
|
||||
}
|
||||
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unsupported source type '{source.get('type')}' "
|
||||
f"for {typ} block.",
|
||||
)
|
||||
|
||||
|
||||
def _reformat_messages(
|
||||
messages: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Reformat the content to be compatible with HuggingFaceTokenCounter.
|
||||
|
||||
This function processes a list of messages and converts multi-part
|
||||
text content into single string content when all parts are plain text.
|
||||
This is necessary for compatibility with HuggingFaceTokenCounter which
|
||||
expects simple string content rather than structured content with
|
||||
multiple parts.
|
||||
|
||||
Args:
|
||||
messages (list[dict[str, Any]]):
|
||||
A list of message dictionaries where each message may contain a
|
||||
"content" field. The content can be either:
|
||||
- A string (unchanged)
|
||||
- A list of content items, where each item is a dict that may
|
||||
contain "text", "type", and other fields
|
||||
|
||||
Returns:
|
||||
list[dict[str, Any]]:
|
||||
A list of reformatted messages. For messages where all content
|
||||
items are plain text (have "text" field and either no "type"
|
||||
field or "type" == "text"), the content list is converted to a
|
||||
single newline-joined string. Other messages remain unchanged.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
# Case 1: All text content - will be converted
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"text": "Hello", "type": "text"},
|
||||
{"text": "World", "type": "text"}
|
||||
]
|
||||
}
|
||||
]
|
||||
result = _reformat_messages(messages)
|
||||
print(result[0]["content"])
|
||||
# Output: "Hello\nWorld"
|
||||
|
||||
# Case 2: Mixed content - will remain unchanged
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"text": "Hello", "type": "text"},
|
||||
{"image_url": "...", "type": "image"}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
result = _reformat_messages(messages) # remain unchanged
|
||||
print(type(result[0]["content"]))
|
||||
# Output: <class 'list'>
|
||||
|
||||
"""
|
||||
for message in messages:
|
||||
content = message.get("content", [])
|
||||
|
||||
is_all_text = True
|
||||
texts = []
|
||||
for item in content:
|
||||
if not isinstance(item, dict) or "text" not in item:
|
||||
is_all_text = False
|
||||
break
|
||||
if "type" in item and item["type"] != "text":
|
||||
is_all_text = False
|
||||
break
|
||||
if item["text"]:
|
||||
texts.append(item["text"])
|
||||
|
||||
if is_all_text and texts:
|
||||
message["content"] = "\n".join(texts)
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
class DashScopeChatFormatter(TruncatedFormatterBase):
    """The DashScope formatter class for chatbot scenario, where only a user
    and an agent are involved. We use the `role` field to identify different
    entities in the conversation.

    .. warning::
        Known Issues with DashScope API:

        1. **Missing content field**: When messages lack the 'content' field,
           qwen-vl-max models will raise ``KeyError: 'content'``.

        2. **None content value**: When content is ``None``, qwen-vl-max models
           will raise ``TypeError: 'NoneType' object is not iterable``.

        3. **Empty text in content**: When content contains
           ``[{"text": None}]``, qwen3-max may repeatedly invoke tools
           multiple times. Note that when qwen3-max initiates tool calls,
           the returned message contains ``"content": ""``.

        To avoid these issues, this formatter assigns content as an empty
        list ``[]`` for messages without valid content blocks.

    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = False
    """Whether support multi-agent conversations"""

    support_vision: bool = True
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        ImageBlock,
        AudioBlock,
        VideoBlock,
        ToolUseBlock,
        ToolResultBlock,
    ]

    def __init__(
        self,
        promote_tool_result_images: bool = False,
        promote_tool_result_audios: bool = False,
        promote_tool_result_videos: bool = False,
        token_counter: TokenCounterBase | None = None,
        max_tokens: int | None = None,
    ) -> None:
        """Initialize the DashScope chat formatter.

        Args:
            promote_tool_result_images (`bool`, defaults to `False`):
                Whether to promote images from tool results to user messages.
                Most LLM APIs don't support images in tool result blocks, but
                do support them in user message blocks. When `True`, images are
                extracted and appended as a separate user message with
                explanatory text indicating their source.
            promote_tool_result_audios (`bool`, defaults to `False`):
                Whether to promote audios from tool results to user messages.
                Most LLM APIs don't support audios in tool result blocks, but
                do support them in user message blocks. When `True`, audios are
                extracted and appended as a separate user message with
                explanatory text indicating their source.
            promote_tool_result_videos (`bool`, defaults to `False`):
                Whether to promote videos from tool results to user messages.
                Most LLM APIs don't support videos in tool result blocks, but
                do support them in user message blocks. When `True`, videos are
                extracted and appended as a separate user message with
                explanatory text indicating their source.
            token_counter (`TokenCounterBase | None`, optional):
                A token counter instance used to count tokens in the messages.
                If not provided, the formatter will format the messages
                without considering token limits.
            max_tokens (`int | None`, optional):
                The maximum number of tokens allowed in the formatted
                messages. If not provided, the formatter will not truncate
                the messages.
        """
        super().__init__(token_counter, max_tokens)
        self.promote_tool_result_images = promote_tool_result_images
        self.promote_tool_result_audios = promote_tool_result_audios
        self.promote_tool_result_videos = promote_tool_result_videos

    def _build_promoted_blocks(
        self,
        block: ToolResultBlock,
        multimodal_data: Any,
    ) -> list:
        """Build the content of the user message that carries the media of a
        tool result, honouring the ``promote_tool_result_*`` switches.

        Args:
            block (`ToolResultBlock`):
                The tool result block the media came from; its ``name`` is
                quoted in the introductory text.
            multimodal_data:
                The ``(url, block)`` pairs returned by
                ``convert_tool_result_to_string``.

        Returns:
            `list`:
                The blocks wrapped in ``<system-info>`` text blocks, or an
                empty list when nothing has to be promoted.
        """
        # Media type -> (promotion enabled?, block class to build).
        promotable = {
            "image": (self.promote_tool_result_images, ImageBlock),
            "audio": (self.promote_tool_result_audios, AudioBlock),
            "video": (self.promote_tool_result_videos, VideoBlock),
        }

        promoted: list = []
        for url, multimodal_block in multimodal_data:
            kind = multimodal_block["type"]
            enabled, block_cls = promotable.get(kind, (False, None))
            if not enabled:
                continue
            promoted.extend(
                [
                    TextBlock(
                        type="text",
                        text=f"\n- The {kind} from '{url}': ",
                    ),
                    block_cls(
                        type=kind,
                        source=URLSource(
                            type="url",
                            url=url,
                        ),
                    ),
                ],
            )

        if not promoted:
            return []

        return [
            TextBlock(
                type="text",
                text="<system-info>The following are "
                f"the media contents from the tool "
                f"result of '{block['name']}':",
            ),
            *promoted,
            TextBlock(
                type="text",
                text="</system-info>",
            ),
        ]

    async def _format(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Format message objects into DashScope API format.

        Tool results are emitted as separate ``"tool"`` role messages. When
        promotion is enabled, their media is placed in an extra user message
        that is formatted right after the message holding the tool result.

        Args:
            msgs (`list[Msg]`):
                The list of message objects to format. The list itself is
                not modified.

        Returns:
            `list[dict[str, Any]]`:
                The formatted messages as a list of dictionaries.
        """
        self.assert_list_of_msgs(msgs)

        # Promoted media messages are inserted into the list while it is
        # being walked. Work on a shallow copy so the caller's list does not
        # grow as a side effect (and is not re-promoted on a later call).
        msgs = list(msgs)

        formatted_msgs: list[dict] = []

        i = 0
        while i < len(msgs):
            msg = msgs[i]
            content_blocks: list[dict[str, Any]] = []
            tool_calls = []

            for block in msg.get_content_blocks():
                typ = block.get("type")

                if typ == "text":
                    content_blocks.append(
                        {
                            "text": block.get("text"),
                        },
                    )

                elif typ in ["image", "audio", "video"]:
                    content_blocks.append(
                        _format_dashscope_media_block(
                            block,  # type: ignore[arg-type]
                        ),
                    )

                elif typ == "tool_use":
                    tool_calls.append(
                        {
                            "id": block.get("id"),
                            "type": "function",
                            "function": {
                                "name": block.get("name"),
                                "arguments": json.dumps(
                                    block.get("input", {}),
                                    ensure_ascii=False,
                                ),
                            },
                        },
                    )

                elif typ == "tool_result":
                    (
                        textual_output,
                        multimodal_data,
                    ) = self.convert_tool_result_to_string(block["output"])

                    # First add the tool result message in DashScope API
                    # format
                    formatted_msgs.append(
                        {
                            "role": "tool",
                            "tool_call_id": block.get("id"),
                            "content": textual_output,
                            "name": block.get("name"),
                        },
                    )

                    # Then, handle the multimodal data if any
                    promoted_blocks = self._build_promoted_blocks(
                        block,
                        multimodal_data,
                    )
                    if promoted_blocks:
                        # Queue the promoted blocks as the next message so
                        # they are formatted by the following iteration
                        msgs.insert(
                            i + 1,
                            Msg(
                                name="user",
                                content=promoted_blocks,
                                role="user",
                            ),
                        )

                else:
                    logger.warning(
                        "Unsupported block type %s in the message, skipped.",
                        typ,
                    )

            msg_dashscope = {
                "role": msg.role,
                "content": content_blocks,
            }

            if tool_calls:
                msg_dashscope["tool_calls"] = tool_calls

            # Messages with neither content nor tool calls (e.g. those that
            # only held tool results) are dropped
            if msg_dashscope["content"] or msg_dashscope.get("tool_calls"):
                formatted_msgs.append(msg_dashscope)

            # Move to next message
            i += 1

        return _reformat_messages(formatted_msgs)
|
||||
|
||||
|
||||
class DashScopeMultiAgentFormatter(TruncatedFormatterBase):
    """DashScope formatter for multi-agent conversations, where more than
    a user and an agent are involved.

    .. note:: This formatter will combine previous messages (except tool
     calls/results) into a history section in the first system message with
     the conversation history prompt.

    .. note:: For tool calls/results, they will be presented as separate
     messages as required by the DashScope API. Therefore, the tool calls/
     results messages are expected to be placed at the end of the input
     messages.

    .. tip:: Telling the assistant's name in the system prompt is very
     important in multi-agent conversations. So that LLM can know who it
     is playing as.

    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = True
    """Whether support multi-agent conversations"""

    support_vision: bool = True
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        # Multimodal
        ImageBlock,
        AudioBlock,
        VideoBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    def __init__(
        self,
        conversation_history_prompt: str = (
            "# Conversation History\n"
            "The content between <history></history> tags contains "
            "your conversation history\n"
        ),
        promote_tool_result_images: bool = False,
        promote_tool_result_audios: bool = False,
        promote_tool_result_videos: bool = False,
        token_counter: TokenCounterBase | None = None,
        max_tokens: int | None = None,
    ) -> None:
        """Create a DashScope formatter for multi-agent conversations.

        Args:
            conversation_history_prompt (`str`):
                The text placed in front of the ``<history>`` section of the
                first agent message.
            promote_tool_result_images (`bool`, defaults to `False`):
                If `True`, images found in tool results are moved into a
                separate user message, together with a short text naming
                their source. Useful because most LLM APIs accept images in
                user messages but not in tool results.
            promote_tool_result_audios (`bool`, defaults to `False`):
                Same as ``promote_tool_result_images``, for audios.
            promote_tool_result_videos (`bool`, defaults to `False`):
                Same as ``promote_tool_result_images``, for videos.
            token_counter (`TokenCounterBase | None`, optional):
                The token counter used for truncation.
            max_tokens (`int | None`, optional):
                The maximum number of tokens allowed in the formatted
                messages. If `None`, no truncation will be applied.
        """
        super().__init__(token_counter=token_counter, max_tokens=max_tokens)
        self.conversation_history_prompt = conversation_history_prompt
        self.promote_tool_result_images = promote_tool_result_images
        self.promote_tool_result_audios = promote_tool_result_audios
        self.promote_tool_result_videos = promote_tool_result_videos

    async def _format_tool_sequence(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Format a run of tool call/result messages for the DashScope API
        by delegating to a `DashScopeChatFormatter` configured with the same
        promotion switches.

        Args:
            msgs (`list[Msg]`):
                The list of messages containing tool calls/results to format.

        Returns:
            `list[dict[str, Any]]`:
                A list of dictionaries formatted for the DashScope API.
        """
        chat_formatter = DashScopeChatFormatter(
            promote_tool_result_images=self.promote_tool_result_images,
            promote_tool_result_audios=self.promote_tool_result_audios,
            promote_tool_result_videos=self.promote_tool_result_videos,
        )
        return await chat_formatter.format(msgs)

    async def _format_agent_message(
        self,
        msgs: list[Msg],
        is_first: bool = True,
    ) -> list[dict[str, Any]]:
        """Merge messages without tool calls/results into a single user
        message whose content is wrapped in ``<history>`` tags.

        Consecutive text blocks are joined into one ``"<name>: <text>"``
        block per line; media blocks are kept as separate content items in
        their original position.

        Args:
            msgs (`list[Msg]`):
                A list of Msg objects to be formatted.
            is_first (`bool`, defaults to `True`):
                Whether this is the first agent message in the conversation.
                If `True`, the conversation history prompt will be included.

        Returns:
            `list[dict[str, Any]]`:
                A list with at most one user message formatted for the
                DashScope API; empty when no supported block was found.
        """
        history_prompt = self.conversation_history_prompt if is_first else ""

        blocks: list = []
        pending_lines: list = []

        for msg in msgs:
            for block in msg.get_content_blocks():
                kind = block["type"]

                if kind == "text":
                    pending_lines.append(f"{msg.name}: {block['text']}")
                    continue

                if kind not in ("image", "audio", "video"):
                    logger.warning(
                        "Unsupported block type %s in the message, "
                        "skipped.",
                        kind,
                    )
                    continue

                # Flush the text gathered so far as one block so the media
                # item keeps its position in the conversation
                if pending_lines:
                    blocks.append({"text": "\n".join(pending_lines)})
                    pending_lines.clear()

                source = block["source"]
                if source["type"] == "url":
                    url = source["url"]
                    if _is_accessible_local_file(url):
                        media_ref = "file://" + os.path.abspath(url)
                    else:
                        media_ref = url
                    blocks.append({kind: media_ref})

                elif source["type"] == "base64":
                    media_type = source["media_type"]
                    base64_data = source["data"]
                    blocks.append(
                        {kind: f"data:{media_type};base64,{base64_data}"},
                    )

        if pending_lines:
            blocks.append({"text": "\n".join(pending_lines)})

        formatted_msgs: list[dict] = []

        if blocks:
            # Opening tag: prepend to a leading text block, otherwise add a
            # new text block in front of the leading media item
            head = blocks[0]
            if head.get("text"):
                head["text"] = history_prompt + "<history>\n" + head["text"]
            else:
                blocks.insert(0, {"text": history_prompt + "<history>\n"})

            # Closing tag: same idea at the tail of the content
            tail = blocks[-1]
            if tail.get("text"):
                tail["text"] += "\n</history>"
            else:
                blocks.append({"text": "</history>"})

            formatted_msgs.append({"role": "user", "content": blocks})

        return _reformat_messages(formatted_msgs)

    async def _format_system_message(
        self,
        msg: Msg,
    ) -> dict[str, Any]:
        """Format system message for DashScope API, using only the text
        content of the message."""
        system_text = msg.get_text_content()
        return {"role": "system", "content": system_text}
|
||||
265
src/agentscope/formatter/_deepseek_formatter.py
Normal file
265
src/agentscope/formatter/_deepseek_formatter.py
Normal file
@@ -0,0 +1,265 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# pylint: disable=too-many-branches
|
||||
"""The DeepSeek formatter module."""
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from ._truncated_formatter_base import TruncatedFormatterBase
|
||||
from .._logging import logger
|
||||
from ..message import Msg, TextBlock, ToolUseBlock, ToolResultBlock
|
||||
from ..token import TokenCounterBase
|
||||
|
||||
|
||||
class DeepSeekChatFormatter(TruncatedFormatterBase):
    """The DeepSeek formatter class for chatbot scenario, where only a user
    and an agent are involved. We use the `role` field to identify different
    entities in the conversation.
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = False
    """Whether support multi-agent conversations"""

    support_vision: bool = False
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    async def _format(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Convert message objects into the DeepSeek API message format.

        Text blocks of a message are joined with newlines into ``content``,
        thinking blocks into ``reasoning_content``; tool uses become
        ``tool_calls`` and every tool result becomes its own ``"tool"``
        role message, emitted before the message that contained it.

        Args:
            msgs (`list[Msg]`):
                The list of message objects to format.

        Returns:
            `list[dict[str, Any]]`:
                The formatted messages as a list of dictionaries.
        """
        self.assert_list_of_msgs(msgs)

        messages: list[dict] = []
        for msg in msgs:
            text_parts: list = []
            thinking_parts: list = []
            tool_calls = []

            for block in msg.get_content_blocks():
                typ = block.get("type")

                if typ == "text":
                    text_parts.append(block.get("text", ""))
                    continue

                if typ == "thinking":
                    thinking_parts.append(block.get("thinking", ""))
                    continue

                if typ == "tool_use":
                    arguments = json.dumps(
                        block.get("input", {}),
                        ensure_ascii=False,
                    )
                    tool_calls.append(
                        {
                            "id": block.get("id"),
                            "type": "function",
                            "function": {
                                "name": block.get("name"),
                                "arguments": arguments,
                            },
                        },
                    )
                    continue

                if typ == "tool_result":
                    # Multimodal data of the tool result is not forwarded,
                    # only its textual form
                    textual_output, _ = self.convert_tool_result_to_string(
                        block.get("output"),  # type: ignore[arg-type]
                    )
                    messages.append(
                        {
                            "role": "tool",
                            "tool_call_id": block.get("id"),
                            "content": textual_output,
                            "name": block.get("name"),
                        },
                    )
                    continue

                logger.warning(
                    "Unsupported block type %s in the message, skipped.",
                    typ,
                )

            content_msg = "\n".join(text_parts)
            reasoning_msg = "\n".join(thinking_parts)

            # Empty text is sent as ``None`` rather than an empty string
            msg_deepseek = {
                "role": msg.role,
                "content": content_msg or None,
            }

            if reasoning_msg:
                msg_deepseek["reasoning_content"] = reasoning_msg

            if tool_calls:
                msg_deepseek["tool_calls"] = tool_calls

            # Skip messages that carry neither text nor tool calls
            if msg_deepseek["content"] or msg_deepseek.get("tool_calls"):
                messages.append(msg_deepseek)

        return messages
|
||||
|
||||
|
||||
class DeepSeekMultiAgentFormatter(TruncatedFormatterBase):
    """
    DeepSeek formatter for multi-agent conversations, where more than
    a user and an agent are involved.
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = True
    """Whether support multi-agent conversations"""

    support_vision: bool = False
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    def __init__(
        self,
        conversation_history_prompt: str = (
            "# Conversation History\n"
            "The content between <history></history> tags contains "
            "your conversation history\n"
        ),
        token_counter: TokenCounterBase | None = None,
        max_tokens: int | None = None,
    ) -> None:
        """Initialize the DeepSeek multi-agent formatter.

        Args:
            conversation_history_prompt (`str`):
                The prompt to use for the conversation history section.
            token_counter (`TokenCounterBase | None`, optional):
                A token counter instance used to count tokens in the messages.
                If not provided, the formatter will format the messages
                without considering token limits.
            max_tokens (`int | None`, optional):
                The maximum number of tokens allowed in the formatted
                messages. If not provided, the formatter will not truncate
                the messages.
        """
        super().__init__(token_counter=token_counter, max_tokens=max_tokens)
        self.conversation_history_prompt = conversation_history_prompt

    async def _format_tool_sequence(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Given a sequence of tool call/result messages, format them into
        the required format for the DeepSeek API.

        Args:
            msgs (`list[Msg]`):
                The list of messages containing tool calls/results to format.

        Returns:
            `list[dict[str, Any]]`:
                A list of dictionaries formatted for the DeepSeek API.
        """
        return await DeepSeekChatFormatter().format(msgs)

    async def _format_agent_message(
        self,
        msgs: list[Msg],
        is_first: bool = True,
    ) -> list[dict[str, Any]]:
        """Given a sequence of messages without tool calls/results, format
        them into the required format for the DeepSeek API.

        Every text block becomes one ``"<name>: <text>"`` line; the lines
        are wrapped in ``<history>`` tags and sent as a single user message.
        Non-text blocks are ignored.

        Args:
            msgs (`list[Msg]`):
                A list of Msg objects to be formatted.
            is_first (`bool`, defaults to `True`):
                Whether this is the first agent message in the conversation.
                If `True`, the conversation history prompt will be included.

        Returns:
            `list[dict[str, Any]]`:
                A list holding one user message formatted for the DeepSeek
                API, or an empty list when the messages contain no text
                block.
        """
        conversation_history_prompt = (
            self.conversation_history_prompt if is_first else ""
        )

        history_lines = [
            f"{msg.name}: {block['text']}"
            for msg in msgs
            for block in msg.get_content_blocks()
            if block["type"] == "text"
        ]

        if not history_lines:
            return []

        # Only text is collected, so the history is always one text block
        # and the tags can be attached directly around it.
        history = (
            conversation_history_prompt
            + "<history>\n"
            + "\n".join(history_lines)
            + "\n</history>"
        )

        return [{"role": "user", "content": history}]
|
||||
129
src/agentscope/formatter/_formatter_base.py
Normal file
129
src/agentscope/formatter/_formatter_base.py
Normal file
@@ -0,0 +1,129 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The formatter module."""
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import Any, List, Tuple, Sequence
|
||||
|
||||
from .._utils._common import _save_base64_data
|
||||
from ..message import Msg, AudioBlock, ImageBlock, TextBlock, VideoBlock
|
||||
|
||||
|
||||
class FormatterBase:
    """The base class for formatters."""

    # NOTE: this class does not derive from ``abc.ABC``, so the decorator
    # marks the method as abstract without blocking instantiation.
    @abstractmethod
    async def format(self, *args: Any, **kwargs: Any) -> list[dict[str, Any]]:
        """Format the Msg objects to a list of dictionaries that satisfy the
        API requirements."""

    @staticmethod
    def assert_list_of_msgs(msgs: list[Msg]) -> None:
        """Assert that the input is a list of Msg objects.

        Args:
            msgs (`list[Msg]`):
                A list of Msg objects to be validated.

        Raises:
            `TypeError`:
                If the input is not a list, or one of its items is not a
                Msg object.
        """
        if not isinstance(msgs, list):
            raise TypeError("Input must be a list of Msg objects.")

        for msg in msgs:
            if not isinstance(msg, Msg):
                raise TypeError(
                    f"Expected Msg object, got {type(msg)} instead.",
                )

    @staticmethod
    def convert_tool_result_to_string(
        output: str | List[TextBlock | ImageBlock | AudioBlock | VideoBlock],
    ) -> tuple[
        str,
        Sequence[
            Tuple[
                str,
                ImageBlock | AudioBlock | TextBlock | VideoBlock,
            ]
        ],
    ]:
        """Turn the tool result list into a textual output to be compatible
        with the LLM API that doesn't support multimodal data in the tool
        result.

        For URL-based media, the URL is included in the list. For
        base64-encoded media, the data is saved through
        ``_save_base64_data`` and the location it returns is included in
        the returned list.

        Args:
            output (`str | List[TextBlock | ImageBlock | AudioBlock | \
            VideoBlock]`):
                The output of the tool response, including text and multimodal
                data like images and audio.

        Returns:
            `tuple[str, list[Tuple[str, ImageBlock | AudioBlock | \
            VideoBlock | TextBlock]]]`:
                A tuple containing the textual representation of the tool
                result and a list of tuples. The first element of each tuple
                is the local file path or URL of the multimodal data, and the
                second element is the corresponding block. When the result
                has exactly one textual item it is returned as is; otherwise
                the items are joined as a "- " bulleted list.

        Raises:
            `ValueError`:
                If a block has an unsupported type, or a media block has a
                source type other than 'url' or 'base64'.
        """
        if isinstance(output, str):
            return output, []

        textual_output = []
        multimodal_data = []
        for block in output:
            assert isinstance(block, dict) and "type" in block, (
                f"Invalid block: {block}, a TextBlock, ImageBlock, "
                f"AudioBlock, or VideoBlock is expected."
            )
            block_type = block["type"]

            if block_type == "text":
                textual_output.append(block["text"])

            elif block_type in ["image", "audio", "video"]:
                assert "source" in block, (
                    f"Invalid {block_type} block: {block}, 'source' key "
                    "is required."
                )
                source = block["source"]

                if source["type"] == "url":
                    # URL sources are referenced directly
                    path_multimodal_file = source["url"]

                elif source["type"] == "base64":
                    # Base64 data is saved and referenced by the returned
                    # location
                    path_multimodal_file = _save_base64_data(
                        source["media_type"],
                        source["data"],
                    )

                else:
                    raise ValueError(
                        f"Invalid {block_type} source: {source}, "
                        "expected 'url' or 'base64'.",
                    )

                textual_output.append(
                    f"The returned {block_type} can be found "
                    f"at: {path_multimodal_file}",
                )
                multimodal_data.append(
                    (path_multimodal_file, block),
                )

            else:
                raise ValueError(
                    f"Unsupported block type: {block_type}, "
                    "expected 'text', 'image', 'audio', or 'video'.",
                )

        if len(textual_output) == 1:
            return textual_output[0], multimodal_data

        return (
            "\n".join("- " + line for line in textual_output),
            multimodal_data,
        )
|
||||
507
src/agentscope/formatter/_gemini_formatter.py
Normal file
507
src/agentscope/formatter/_gemini_formatter.py
Normal file
@@ -0,0 +1,507 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# pylint: disable=too-many-branches
|
||||
"""Google gemini API formatter in agentscope."""
|
||||
import base64
|
||||
import os
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from ._truncated_formatter_base import TruncatedFormatterBase
|
||||
from .._utils._common import _get_bytes_from_web_url
|
||||
from ..message import (
|
||||
Msg,
|
||||
TextBlock,
|
||||
ImageBlock,
|
||||
AudioBlock,
|
||||
ToolUseBlock,
|
||||
ToolResultBlock,
|
||||
VideoBlock,
|
||||
URLSource,
|
||||
)
|
||||
from .._logging import logger
|
||||
from ..token import TokenCounterBase
|
||||
|
||||
|
||||
def _format_gemini_media_block(
    media_block: ImageBlock | AudioBlock | VideoBlock,
) -> dict[str, Any]:
    """Turn an image/audio/video block into a Gemini ``inline_data`` part.

    Args:
        media_block (`ImageBlock | AudioBlock | VideoBlock`):
            The media block whose ``source`` should be converted.

    Returns:
        `dict[str, Any]`:
            A dictionary holding a single "inline_data" entry. For a base64
            source the entry carries the given data and media type; for a
            url source it is whatever `_to_gemini_inline_data` returns.

    Raises:
        `ValueError`:
            If the source type is neither "base64" nor "url".
    """
    src = media_block["source"]
    src_type = src["type"]

    if src_type == "base64":
        # The payload is already encoded, only the key names change.
        inline_data = {
            "data": src["data"],
            "mime_type": src["media_type"],
        }
    elif src_type == "url":
        inline_data = _to_gemini_inline_data(src["url"])
    else:
        raise ValueError(
            f"Unsupported source type: {src_type}",
        )

    return {"inline_data": inline_data}
|
||||
|
||||
|
||||
def _to_gemini_inline_data(url: str) -> dict:
    """Convert a web URL or a local file path into Gemini inline data.

    The media type ("image", "audio" or "video") is looked up from the file
    extension in `GeminiChatFormatter.supported_extensions`, and the MIME
    type is built as ``<media type>/<extension>``.

    Args:
        url (`str`):
            A web URL, or the path of an existing local file.

    Returns:
        `dict`:
            A dictionary with the keys "data" and "mime_type". For a local
            file "data" is the base64 encoded file content; for a web URL it
            is the value returned by `_get_bytes_from_web_url`.

    Raises:
        `TypeError`:
            If the file extension is not one of the supported extensions.
        `ValueError`:
            If `url` is neither an existing local file nor a URL with a
            scheme.
    """
    parsed_url = urlparse(url)
    is_local_file = os.path.exists(url)

    if not is_local_file and parsed_url.scheme == "":
        raise ValueError(
            f"The URL `{url}` is not a valid image URL or local file.",
        )

    # Extension candidates, most trustworthy first. For a web URL the path
    # component is tried first so that a query string or fragment (e.g.
    # "a.png?sig=1.2") cannot hide or replace the real extension. The
    # extension of the whole string is kept as a fallback so that every URL
    # accepted before is still accepted.
    candidates = [url.split(".")[-1].lower()]
    if not is_local_file:
        candidates.insert(0, parsed_url.path.split(".")[-1].lower())

    supported = GeminiChatFormatter.supported_extensions
    typ = None
    extension = candidates[0]
    for candidate in candidates:
        typ = next(
            (k for k, v in supported.items() if candidate in v),
            None,
        )
        if typ is not None:
            extension = candidate
            break

    if typ is None:
        raise TypeError(
            f"Unsupported file extension: {extension}, expected "
            f"{supported}",
        )

    if is_local_file:
        with open(url, "rb") as f:
            data = base64.b64encode(f.read()).decode("utf-8")
    else:
        data = _get_bytes_from_web_url(url)

    return {
        "data": data,
        "mime_type": f"{typ}/{extension}",
    }
|
||||
|
||||
|
||||
class GeminiChatFormatter(TruncatedFormatterBase):
    """The Gemini formatter class for chatbot scenario, where only a user
    and an agent are involved. We use the `role` field to identify different
    entities in the conversation.
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = False
    """Whether support multi-agent conversations"""

    support_vision: bool = True
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        # Multimodal
        ImageBlock,
        VideoBlock,
        AudioBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    supported_extensions: dict[str, list[str]] = {
        "image": ["png", "jpeg", "webp", "heic", "heif"],
        "video": [
            "mp4",
            "mpeg",
            "mov",
            "avi",
            "x-flv",
            "mpg",
            "webm",
            "wmv",
            "3gpp",
        ],
        "audio": ["mp3", "wav", "aiff", "aac", "ogg", "flac"],
    }
    """The supported file extensions, grouped by media type"""

    def __init__(
        self,
        promote_tool_result_images: bool = False,
        token_counter: TokenCounterBase | None = None,
        max_tokens: int | None = None,
    ) -> None:
        """Initialize the Gemini chat formatter.

        Args:
            promote_tool_result_images (`bool`, defaults to `False`):
                Whether to promote images from tool results to user messages.
                Most LLM APIs don't support images in tool result blocks, but
                do support them in user message blocks. When `True`, images are
                extracted and appended as a separate user message with
                explanatory text indicating their source.
            token_counter (`TokenCounterBase | None`, optional):
                A token counter instance used to count tokens in the messages.
                If not provided, the formatter will format the messages
                without considering token limits.
            max_tokens (`int | None`, optional):
                The maximum number of tokens allowed in the formatted
                messages. If not provided, the formatter will not truncate
                the messages.
        """
        super().__init__(token_counter, max_tokens)
        self.promote_tool_result_images = promote_tool_result_images

    async def _format(
        self,
        msgs: list[Msg],
    ) -> list[dict]:
        """Format message objects into Gemini API required format.

        Text, tool-use and media blocks of one message become the `parts`
        of a single Gemini message. Every tool result is emitted as its own
        "user" message holding a `function_response` part. Blocks of any
        other type are skipped with a warning.

        Args:
            msgs (`list[Msg]`):
                The messages to format. The list itself is left unchanged.

        Returns:
            `list[dict]`:
                The formatted messages, each a dictionary with the keys
                "role" ("model" for assistant messages, "user" otherwise)
                and "parts".
        """
        self.assert_list_of_msgs(msgs)

        # Promoted tool-result images are spliced into the message list
        # below. Work on a shallow copy so the caller's list is not
        # modified; otherwise formatting the same list a second time would
        # promote the same images again.
        msgs = list(msgs)

        messages: list = []
        i = 0
        while i < len(msgs):
            msg = msgs[i]
            parts = []

            for block in msg.get_content_blocks():
                typ = block.get("type")
                if typ == "text":
                    parts.append(
                        {
                            "text": block.get("text"),
                        },
                    )

                elif typ == "tool_use":
                    # NOTE(review): the block id is passed on as
                    # `thought_signature` while the function call id is
                    # left as None - confirm this is intended.
                    parts.append(
                        {
                            "function_call": {
                                "id": None,
                                "name": block["name"],
                                "args": block["input"],
                            },
                            "thought_signature": block.get("id", None),
                        },
                    )

                elif typ == "tool_result":
                    (
                        textual_output,
                        multimodal_data,
                    ) = self.convert_tool_result_to_string(block["output"])

                    # First add the tool result message in Gemini API format
                    messages.append(
                        {
                            "role": "user",
                            "parts": [
                                {
                                    "function_response": {
                                        "id": block["id"],
                                        "name": block["name"],
                                        "response": {
                                            "output": textual_output,
                                        },
                                    },
                                },
                            ],
                        },
                    )

                    # Then collect the images of the tool result that
                    # should be shown to the model in a user message
                    promoted_blocks: list = []
                    for url, multimodal_block in multimodal_data:
                        if (
                            multimodal_block["type"] == "image"
                            and self.promote_tool_result_images
                        ):
                            promoted_blocks.extend(
                                [
                                    TextBlock(
                                        type="text",
                                        text=f"\n- The image from '{url}': ",
                                    ),
                                    ImageBlock(
                                        type="image",
                                        source=URLSource(
                                            type="url",
                                            url=url,
                                        ),
                                    ),
                                ],
                            )

                    if promoted_blocks:
                        # Insert promoted blocks as new user message(s)
                        promoted_blocks = [
                            TextBlock(
                                type="text",
                                text="<system-info>The following are "
                                "the image contents from the tool "
                                f"result of '{block['name']}':",
                            ),
                            *promoted_blocks,
                            TextBlock(
                                type="text",
                                text="</system-info>",
                            ),
                        ]

                        msgs.insert(
                            i + 1,
                            Msg(
                                name="user",
                                content=promoted_blocks,
                                role="user",
                            ),
                        )

                elif typ in ["image", "audio", "video"]:
                    parts.append(
                        _format_gemini_media_block(
                            block,  # type: ignore[arg-type]
                        ),
                    )

                else:
                    logger.warning(
                        "Unsupported block type: %s in the message, skipped. ",
                        typ,
                    )

            role = "model" if msg.role == "assistant" else "user"

            # A message that only carried tool results has no parts left
            if parts:
                messages.append(
                    {
                        "role": role,
                        "parts": parts,
                    },
                )

            # Move to next message (including inserted messages, which will
            # be processed in subsequent iterations)
            i += 1

        return messages
|
||||
|
||||
|
||||
class GeminiMultiAgentFormatter(TruncatedFormatterBase):
    """The multi-agent formatter for Google Gemini API, where more than a
    user and an agent are involved.

    .. note:: Messages without tool calls/results are merged into a single
     user message whose content is wrapped in ``<history></history>`` tags.
     The conversation history prompt is put in front of the first of these
     messages.

    .. note:: Tool calls/results are formatted as separate messages by
     `GeminiChatFormatter`, as required by the Gemini API. Therefore, the
     tool calls/results messages are expected to be placed at the end of
     the input messages.

    .. tip:: Telling the assistant's name in the system prompt is very
     important in multi-agent conversations. So that LLM can know who it
     is playing as.

    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = True
    """Whether support multi-agent conversations"""

    support_vision: bool = True
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        # Multimodal
        ImageBlock,
        VideoBlock,
        AudioBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    def __init__(
        self,
        conversation_history_prompt: str = (
            "# Conversation History\n"
            "The content between <history></history> tags contains "
            "your conversation history\n"
        ),
        promote_tool_result_images: bool = False,
        token_counter: TokenCounterBase | None = None,
        max_tokens: int | None = None,
    ) -> None:
        """Initialize the Gemini multi-agent formatter.

        Args:
            conversation_history_prompt (`str`):
                The text placed in front of the first conversation history
                section.
            promote_tool_result_images (`bool`, defaults to `False`):
                Whether images found in tool results are promoted to user
                messages. The flag is handed to the `GeminiChatFormatter`
                that formats the tool calls/results.
            token_counter (`TokenCounterBase | None`, optional):
                The token counter used for truncation.
            max_tokens (`int | None`, optional):
                The maximum number of tokens allowed in the formatted
                messages. If `None`, no truncation will be applied.
        """
        super().__init__(token_counter=token_counter, max_tokens=max_tokens)
        self.promote_tool_result_images = promote_tool_result_images
        self.conversation_history_prompt = conversation_history_prompt

    async def _format_system_message(
        self,
        msg: Msg,
    ) -> dict[str, Any]:
        """Format the system message as a Gemini "user" message that
        carries the text content of `msg` in a single part."""
        text_part = {"text": msg.get_text_content()}
        return {"role": "user", "parts": [text_part]}

    async def _format_tool_sequence(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Format a sequence of tool call/result messages for the Gemini
        API by delegating to a `GeminiChatFormatter`.

        Args:
            msgs (`list[Msg]`):
                The messages containing tool calls/results.

        Returns:
            `list[dict[str, Any]]`:
                The messages formatted by the chat formatter.
        """
        chat_formatter = GeminiChatFormatter(
            promote_tool_result_images=self.promote_tool_result_images,
        )
        return await chat_formatter.format(msgs)

    async def _format_agent_message(
        self,
        msgs: list[Msg],
        is_first: bool = True,
    ) -> list[dict[str, Any]]:
        """Merge messages without tool calls/results into one Gemini user
        message wrapped in ``<history></history>`` tags.

        Args:
            msgs (`list[Msg]`):
                The messages to merge.
            is_first (`bool`, defaults to `True`):
                Whether this is the first agent message in the conversation.
                Only then the conversation history prompt is prepended.

        Returns:
            `list[dict[str, Any]]`:
                A list with one "user" message, or an empty list when the
                messages contain neither text nor media blocks.
        """
        history_prompt = self.conversation_history_prompt if is_first else ""

        # Consecutive text blocks are gathered as "<name>: <text>" lines and
        # written out as one text part whenever a media block interrupts
        # them, and once more at the end.
        parts: list = []
        pending_lines: list = []

        for msg in msgs:
            for block in msg.get_content_blocks():
                block_type = block["type"]

                if block_type == "text":
                    pending_lines.append(f"{msg.name}: {block['text']}")
                    continue

                if block_type not in ("image", "video", "audio"):
                    continue

                if pending_lines:
                    parts.append({"text": "\n".join(pending_lines)})
                    pending_lines = []

                parts.append(
                    _format_gemini_media_block(
                        block,  # type: ignore[arg-type]
                    ),
                )

        if pending_lines:
            parts.append({"text": "\n".join(pending_lines)})

        if not parts:
            return []

        # Open the history section: extend a leading text part, or put a
        # new text part in front of a leading media part.
        opening = history_prompt + "<history>"
        head = parts[0]
        if head.get("text"):
            head["text"] = opening + head["text"]
        else:
            parts.insert(0, {"text": opening})

        # Close the history section the same way at the other end.
        tail = parts[-1]
        if tail.get("text"):
            tail["text"] += "\n</history>"
        else:
            parts.append({"text": "</history>"})

        return [{"role": "user", "parts": parts}]
|
||||
441
src/agentscope/formatter/_ollama_formatter.py
Normal file
441
src/agentscope/formatter/_ollama_formatter.py
Normal file
@@ -0,0 +1,441 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# pylint: disable=too-many-branches
|
||||
"""The Ollama formatter module."""
|
||||
import base64
|
||||
import os
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from ._truncated_formatter_base import TruncatedFormatterBase
|
||||
from .._logging import logger
|
||||
from .._utils._common import _get_bytes_from_web_url
|
||||
from ..message import (
|
||||
Msg,
|
||||
TextBlock,
|
||||
ImageBlock,
|
||||
ToolUseBlock,
|
||||
ToolResultBlock,
|
||||
URLSource,
|
||||
)
|
||||
from ..token import TokenCounterBase
|
||||
|
||||
|
||||
def _format_ollama_image_block(
    image_block: ImageBlock,
) -> str:
    """Extract the image data of an image block for the Ollama API.

    Args:
        image_block (`ImageBlock`):
            The image block to convert.

    Returns:
        `str`:
            The "data" field of a base64 source, or the result of
            `_convert_ollama_image_url_to_base64_data` for a url source.

    Raises:
        `ValueError`:
            If the source type is neither "url" nor "base64".
    """
    src = image_block["source"]
    src_type = src["type"]

    if src_type == "base64":
        # Already encoded, hand it over unchanged.
        return src["data"]

    if src_type == "url":
        return _convert_ollama_image_url_to_base64_data(src["url"])

    raise ValueError(
        f"Unsupported image source type: {src_type}",
    )
|
||||
|
||||
|
||||
def _convert_ollama_image_url_to_base64_data(url: str) -> str:
    """Load the image behind a web URL or a local path as base64 data.

    Args:
        url (`str`):
            A web URL, or the path of an existing local file.

    Returns:
        `str`:
            The base64 encoded content of a local file, or the value
            returned by `_get_bytes_from_web_url` for a web URL.

    Raises:
        `ValueError`:
            If `url` is neither an existing local file nor a URL with a
            scheme.
    """
    has_scheme = urlparse(url).scheme != ""

    if os.path.exists(url):
        # Local file
        with open(url, "rb") as image_file:
            raw = image_file.read()
        return base64.b64encode(raw).decode("utf-8")

    if has_scheme:
        # Web url
        return _get_bytes_from_web_url(url)

    raise ValueError(
        f"The URL `{url}` is not a valid image URL or local file.",
    )
|
||||
|
||||
|
||||
class OllamaChatFormatter(TruncatedFormatterBase):
    """The Ollama formatter class for chatbot scenario, where only a user
    and an agent are involved. We use the `role` field to identify different
    participants in the conversation.
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = False
    """Whether support multi-agent conversations"""

    support_vision: bool = True
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        # Multimodal
        ImageBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    def __init__(
        self,
        promote_tool_result_images: bool = False,
        token_counter: TokenCounterBase | None = None,
        max_tokens: int | None = None,
    ) -> None:
        """Initialize the Ollama chat formatter.

        Args:
            promote_tool_result_images (`bool`, defaults to `False`):
                Whether to promote images from tool results to user messages.
                Most LLM APIs don't support images in tool result blocks, but
                do support them in user message blocks. When `True`, images are
                extracted and appended as a separate user message with
                explanatory text indicating their source.
            token_counter (`TokenCounterBase | None`, optional):
                A token counter instance used to count tokens in the messages.
                If not provided, the formatter will format the messages
                without considering token limits.
            max_tokens (`int | None`, optional):
                The maximum number of tokens allowed in the formatted
                messages. If not provided, the formatter will not truncate
                the messages.
        """
        super().__init__(token_counter, max_tokens)
        self.promote_tool_result_images = promote_tool_result_images

    async def _format(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Format message objects into Ollama API format.

        The text blocks of one message are joined with newlines into its
        "content"; tool-use blocks become "tool_calls" and image blocks
        become "images". Every tool result is emitted as its own message
        with the role "tool". Blocks of any other type are skipped with a
        warning.

        Args:
            msgs (`list[Msg]`):
                The list of message objects to format. The list itself is
                left unchanged.

        Returns:
            `list[dict[str, Any]]`:
                The formatted messages as a list of dictionaries.
        """
        self.assert_list_of_msgs(msgs)

        # Promoted tool-result images are spliced into the message list
        # below. Work on a shallow copy so the caller's list is not
        # modified; otherwise formatting the same list a second time would
        # promote the same images again.
        msgs = list(msgs)

        messages: list = []
        i = 0
        while i < len(msgs):
            msg = msgs[i]
            content_blocks: list = []
            tool_calls = []
            images = []

            for block in msg.get_content_blocks():
                typ = block.get("type")
                if typ == "text":
                    content_blocks.append({**block})

                elif typ == "tool_use":
                    tool_calls.append(
                        {
                            "id": block.get("id"),
                            "type": "function",
                            "function": {
                                "name": block.get("name"),
                                "arguments": block.get("input", {}),
                            },
                        },
                    )

                elif typ == "tool_result":
                    (
                        textual_output,
                        multimodal_data,
                    ) = self.convert_tool_result_to_string(block["output"])

                    # First add the tool result message itself
                    messages.append(
                        {
                            "role": "tool",
                            "tool_call_id": block.get("id"),
                            "content": textual_output,
                            "name": block.get("name"),
                        },
                    )

                    # Then, handle the multimodal data if any
                    promoted_blocks: list = []
                    for url, multimodal_block in multimodal_data:
                        if (
                            multimodal_block["type"] == "image"
                            and self.promote_tool_result_images
                        ):
                            promoted_blocks.extend(
                                [
                                    TextBlock(
                                        type="text",
                                        text=f"\n- The image from '{url}': ",
                                    ),
                                    ImageBlock(
                                        type="image",
                                        source=URLSource(
                                            type="url",
                                            url=url,
                                        ),
                                    ),
                                ],
                            )

                    if promoted_blocks:
                        # Insert promoted blocks as new user message(s)
                        promoted_blocks = [
                            TextBlock(
                                type="text",
                                text="<system-info>The following are "
                                "the image contents from the tool "
                                f"result of '{block['name']}':",
                            ),
                            *promoted_blocks,
                            TextBlock(
                                type="text",
                                text="</system-info>",
                            ),
                        ]

                        msgs.insert(
                            i + 1,
                            Msg(
                                name="user",
                                content=promoted_blocks,
                                role="user",
                            ),
                        )

                elif typ == "image":
                    images.append(
                        _format_ollama_image_block(
                            block,  # type: ignore[arg-type]
                        ),
                    )

                else:
                    logger.warning(
                        "Unsupported block type %s in the message, skipped.",
                        typ,
                    )

            content_msg = "\n".join(
                content.get("text", "") for content in content_blocks
            )
            msg_ollama = {
                "role": msg.role,
                "content": content_msg or None,
            }

            if tool_calls:
                msg_ollama["tool_calls"] = tool_calls

            if images:
                msg_ollama["images"] = images

            # A message that only carried tool results has nothing left to
            # send, so it is dropped
            if (
                msg_ollama["content"]
                or msg_ollama.get("images")
                or msg_ollama.get("tool_calls")
            ):
                messages.append(msg_ollama)

            # Move to next message (including inserted messages, which will
            # be processed in subsequent iterations)
            i += 1

        return messages
|
||||
|
||||
|
||||
class OllamaMultiAgentFormatter(TruncatedFormatterBase):
    """
    The multi-agent formatter for the Ollama API, where more than a user
    and an agent are involved.
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = True
    """Whether support multi-agent conversations"""

    support_vision: bool = True
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        # Multimodal
        ImageBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    def __init__(
        self,
        conversation_history_prompt: str = (
            "# Conversation History\n"
            "The content between <history></history> tags contains "
            "your conversation history\n"
        ),
        promote_tool_result_images: bool = False,
        token_counter: TokenCounterBase | None = None,
        max_tokens: int | None = None,
    ) -> None:
        """Initialize the Ollama multi-agent formatter.

        Args:
            conversation_history_prompt (`str`):
                The text placed in front of the first conversation history
                section.
            promote_tool_result_images (`bool`, defaults to `False`):
                Whether images found in tool results are promoted to user
                messages. The flag is handed to the `OllamaChatFormatter`
                that formats the tool calls/results.
            token_counter (`TokenCounterBase | None`, optional):
                The token counter used for truncation.
            max_tokens (`int | None`, optional):
                The maximum number of tokens allowed in the formatted
                messages. If `None`, no truncation will be applied.
        """
        super().__init__(token_counter=token_counter, max_tokens=max_tokens)
        self.promote_tool_result_images = promote_tool_result_images
        self.conversation_history_prompt = conversation_history_prompt

    async def _format_system_message(
        self,
        msg: Msg,
    ) -> dict[str, Any]:
        """Format the system message as an Ollama "system" message holding
        the text content of `msg`."""
        system_text = msg.get_text_content()
        return {"role": "system", "content": system_text}

    async def _format_tool_sequence(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Format a sequence of tool call/result messages for the Ollama
        API by delegating to an `OllamaChatFormatter`.

        Args:
            msgs (`list[Msg]`):
                The messages containing tool calls/results.

        Returns:
            `list[dict[str, Any]]`:
                The messages formatted by the chat formatter.
        """
        chat_formatter = OllamaChatFormatter(
            promote_tool_result_images=self.promote_tool_result_images,
        )
        return await chat_formatter.format(msgs)

    async def _format_agent_message(
        self,
        msgs: list[Msg],
        is_first: bool = True,
    ) -> list[dict[str, Any]]:
        """Merge messages without tool calls/results into one Ollama user
        message wrapped in ``<history></history>`` tags.

        Args:
            msgs (`list[Msg]`):
                The messages to merge.
            is_first (`bool`, defaults to `True`):
                Whether this is the first agent message in the conversation.
                Only then the conversation history prompt is prepended.

        Returns:
            `list[dict[str, Any]]`:
                A list with one "user" message, or an empty list when the
                messages contain neither text nor image blocks.
        """
        history_prompt = self.conversation_history_prompt if is_first else ""

        # Consecutive text blocks are gathered as "<name>: <text>" lines and
        # written out as one text entry whenever an image interrupts them,
        # and once more at the end. Image data is collected separately.
        segments: list = []
        pending_lines: list = []
        images = []

        for msg in msgs:
            for block in msg.get_content_blocks():
                block_type = block["type"]

                if block_type == "text":
                    pending_lines.append(f"{msg.name}: {block['text']}")
                    continue

                if block_type != "image":
                    continue

                if pending_lines:
                    segments.append({"text": "\n".join(pending_lines)})
                    pending_lines = []

                images.append(_format_ollama_image_block(block))
                # The image keeps its position as an entry without text,
                # which contributes an empty line to the content below.
                segments.append(dict(block))

        if pending_lines:
            segments.append({"text": "\n".join(pending_lines)})

        if not segments:
            return []

        # Open the history section: extend a leading text entry, or put a
        # new text entry in front of a leading image entry.
        opening = history_prompt + "<history>\n"
        head = segments[0]
        if head.get("text"):
            head["text"] = opening + head["text"]
        else:
            segments.insert(0, {"text": opening})

        # Close the history section the same way at the other end.
        tail = segments[-1]
        if tail.get("text"):
            tail["text"] += "\n</history>"
        else:
            segments.append({"text": "</history>"})

        content_lines = [segment.get("text", "") for segment in segments]
        user_message = {
            "role": "user",
            "content": "\n".join(content_lines),
        }
        if images:
            user_message["images"] = images

        return [user_message]
|
||||
530
src/agentscope/formatter/_openai_formatter.py
Normal file
530
src/agentscope/formatter/_openai_formatter.py
Normal file
@@ -0,0 +1,530 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# pylint: disable=too-many-branches, too-many-nested-blocks
|
||||
"""The OpenAI formatter for agentscope."""
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
|
||||
from ._truncated_formatter_base import TruncatedFormatterBase
|
||||
from .._logging import logger
|
||||
from ..message import (
|
||||
Msg,
|
||||
URLSource,
|
||||
TextBlock,
|
||||
ImageBlock,
|
||||
AudioBlock,
|
||||
Base64Source,
|
||||
ToolUseBlock,
|
||||
ToolResultBlock,
|
||||
)
|
||||
from ..token import TokenCounterBase
|
||||
|
||||
|
||||
def _format_openai_image_block(
    image_block: ImageBlock,
) -> dict[str, Any]:
    """Turn an image block into an OpenAI ``image_url`` content item.

    Args:
        image_block (`ImageBlock`):
            The image block to convert.

    Returns:
        `dict[str, Any]`:
            A dictionary with the "type" "image_url" and an "image_url"
            entry holding the url. A base64 source is written as a
            ``data:<media type>;base64,<data>`` url; a url source is passed
            through `_to_openai_image_url`.

    Raises:
        `ValueError`:
            If the source type is neither "url" nor "base64".
    """
    src = image_block["source"]
    src_type = src["type"]

    if src_type == "url":
        image_url = _to_openai_image_url(src["url"])
    elif src_type == "base64":
        payload = src["data"]
        mime = src["media_type"]
        image_url = f"data:{mime};base64,{payload}"
    else:
        raise ValueError(
            f"Unsupported image source type: {src_type}",
        )

    return {
        "type": "image_url",
        "image_url": {"url": image_url},
    }
|
||||
|
||||
|
||||
def _to_openai_image_url(url: str) -> str:
    """Convert an image url to openai format. If the given url is a local
    file, it will be converted to a base64 data url. Otherwise, it will be
    returned directly.

    The extension check is case-insensitive for both web urls and local
    files.

    Args:
        url (`str`):
            The local or public url of the image.

    Returns:
        `str`:
            The unchanged web url, or a ``data:image/<type>;base64,<data>``
            url with the content of the local file.

    Raises:
        `TypeError`:
            If the url does not end with a supported image extension, or if
            it is neither a web url nor an existing local file.
    """
    # See https://platform.openai.com/docs/guides/vision for details of
    # support image extensions.
    support_image_extensions = (
        ".png",
        ".jpg",
        ".jpeg",
        ".gif",
        ".webp",
    )

    parsed_url = urlparse(url)
    is_existing_path = os.path.exists(url)

    # Web url
    if not is_existing_path and parsed_url.scheme != "":
        # Only the path names the file; a query string is ignored. The path
        # is lower-cased so that "IMG.PNG" is accepted like a local file.
        target = parsed_url.path if parsed_url.path else parsed_url.netloc
        if target.lower().endswith(support_image_extensions):
            return url

    # Check if it is a local file
    elif is_existing_path and os.path.isfile(url):
        lower_url = url.lower()
        if lower_url.endswith(support_image_extensions):
            with open(url, "rb") as image_file:
                base64_image = base64.b64encode(image_file.read()).decode(
                    "utf-8",
                )

            # Take the extension from the file name itself: the parsed url
            # path would be cut at a "#" or "?" inside a local path.
            extension = lower_url.rsplit(".", 1)[-1]
            # "image/jpg" is not a registered MIME type, use "image/jpeg"
            if extension == "jpg":
                extension = "jpeg"
            mime_type = f"image/{extension}"
            return f"data:{mime_type};base64,{base64_image}"

    raise TypeError(f'"{url}" should end with {support_image_extensions}.')
|
||||
|
||||
|
||||
def _to_openai_audio_data(source: URLSource | Base64Source) -> dict:
    """Convert an audio source to OpenAI format.

    Args:
        source (`URLSource | Base64Source`):
            The audio source. A url source may point to a local file or to
            a web URL; only wav and mp3 audio is supported.

    Returns:
        `dict`:
            A dictionary with the base64 encoded audio under "data" and
            the audio format ("wav" or "mp3") under "format".

    Raises:
        `TypeError`:
            If the file extension, the media type or the source type is
            not supported.
        `ValueError`:
            If a url source is neither an existing local file nor a URL
            with a scheme.
    """
    supported_formats = ("wav", "mp3")

    if source["type"] == "url":
        url = source["url"]
        parsed_url = urlparse(url)
        is_local_file = os.path.exists(url)

        # Extension candidates, most trustworthy first. For a web URL the
        # path component is tried first so that a query string or fragment
        # (e.g. "a.wav?sig=1.2") cannot hide the real extension. The
        # extension of the whole string is kept as a fallback so that
        # every URL accepted before is still accepted.
        candidates = [url.split(".")[-1].lower()]
        if not is_local_file and parsed_url.scheme != "":
            candidates.insert(0, parsed_url.path.split(".")[-1].lower())

        extension = next(
            (_ for _ in candidates if _ in supported_formats),
            None,
        )
        if extension is None:
            raise TypeError(
                f"Unsupported audio file extension: {candidates[0]}, "
                "wav and mp3 are supported.",
            )

        if is_local_file:
            with open(url, "rb") as audio_file:
                data = base64.b64encode(audio_file.read()).decode("utf-8")

        # web url
        elif parsed_url.scheme != "":
            response = requests.get(url)
            response.raise_for_status()
            data = base64.b64encode(response.content).decode("utf-8")

        else:
            raise ValueError(
                f"Unsupported audio source: {source['url']}, "
                "it should be a local file or a web URL.",
            )

        return {
            "data": data,
            "format": extension,
        }

    if source["type"] == "base64":
        data = source["data"]
        media_type = source["media_type"]

        if media_type not in ["audio/wav", "audio/mp3"]:
            raise TypeError(
                f"Unsupported audio media type: {media_type}, "
                "only audio/wav and audio/mp3 are supported.",
            )

        return {
            "data": data,
            # "audio/wav" -> "wav"
            "format": media_type.split("/")[-1],
        }

    raise TypeError(f"Unsupported audio source: {source['type']}.")
|
||||
|
||||
|
||||
class OpenAIChatFormatter(TruncatedFormatterBase):
    """The OpenAI formatter class for chatbot scenario, where only a user
    and an agent are involved. We use the `name` field in OpenAI API to
    identify different entities in the conversation.
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = True
    """Whether support multi-agent conversation"""

    support_vision: bool = True
    """Whether support vision models"""

    supported_blocks: list[type] = [
        TextBlock,
        ImageBlock,
        AudioBlock,
        ToolUseBlock,
        ToolResultBlock,
    ]
    """Supported message blocks for OpenAI API"""

    def __init__(
        self,
        promote_tool_result_images: bool = False,
        token_counter: TokenCounterBase | None = None,
        max_tokens: int | None = None,
    ) -> None:
        """Initialize the OpenAI chat formatter.

        Args:
            promote_tool_result_images (`bool`, defaults to `False`):
                Whether to promote images from tool results to user messages.
                Most LLM APIs don't support images in tool result blocks, but
                do support them in user message blocks. When `True`, images are
                extracted and appended as a separate user message with
                explanatory text indicating their source.
            token_counter (`TokenCounterBase | None`, optional):
                A token counter instance used to count tokens in the messages.
                If not provided, the formatter will format the messages
                without considering token limits.
            max_tokens (`int | None`, optional):
                The maximum number of tokens allowed in the formatted
                messages. If not provided, the formatter will not truncate
                the messages.
        """
        super().__init__(token_counter=token_counter, max_tokens=max_tokens)
        self.promote_tool_result_images = promote_tool_result_images

    async def _format(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Format message objects into OpenAI API required format.

        Args:
            msgs (`list[Msg]`):
                The list of Msg objects to format. The list itself is not
                modified.

        Returns:
            `list[dict[str, Any]]`:
                A list of dictionaries. Regular messages have "name",
                "role", and "content" keys (plus "tool_calls" when the
                message contains tool use blocks), and each tool result is
                emitted as a separate message with role "tool".
        """
        self.assert_list_of_msgs(msgs)

        # Promoted tool-result images are spliced into the queue as extra
        # user messages. Work on a shallow copy, so that the caller's list
        # is left untouched: this method is called again on the same list
        # after truncation, and mutating it in place would duplicate the
        # promoted messages or keep them after their tool result is gone.
        pending_msgs = list(msgs)

        messages: list[dict] = []
        i = 0
        while i < len(pending_msgs):
            msg = pending_msgs[i]
            content_blocks = []
            tool_calls = []

            for block in msg.get_content_blocks():
                typ = block.get("type")
                if typ == "text":
                    content_blocks.append({**block})

                elif typ == "tool_use":
                    tool_calls.append(
                        {
                            "id": block.get("id"),
                            "type": "function",
                            "function": {
                                "name": block.get("name"),
                                "arguments": json.dumps(
                                    block.get("input", {}),
                                    ensure_ascii=False,
                                ),
                            },
                        },
                    )

                elif typ == "tool_result":
                    (
                        textual_output,
                        multimodal_data,
                    ) = self.convert_tool_result_to_string(block["output"])

                    # The tool result is emitted right away as its own
                    # message, i.e. before the remaining content of the
                    # current message
                    messages.append(
                        {
                            "role": "tool",
                            "tool_call_id": block.get("id"),
                            "content": (  # type: ignore[arg-type]
                                textual_output
                            ),
                            "name": block.get("name"),
                        },
                    )

                    # Then, handle the multimodal data if any
                    promoted_blocks: list = []
                    for url, multimodal_block in multimodal_data:
                        if (
                            multimodal_block["type"] == "image"
                            and self.promote_tool_result_images
                        ):
                            promoted_blocks.extend(
                                [
                                    TextBlock(
                                        type="text",
                                        text=f"\n- The image from '{url}': ",
                                    ),
                                    ImageBlock(
                                        type="image",
                                        source=URLSource(
                                            type="url",
                                            url=url,
                                        ),
                                    ),
                                ],
                            )

                    if promoted_blocks:
                        # Queue the promoted blocks as a new user message
                        # that is formatted right after the current one
                        promoted_blocks = [
                            TextBlock(
                                type="text",
                                text="<system-info>The following are "
                                "the image contents from the tool "
                                f"result of '{block['name']}':",
                            ),
                            *promoted_blocks,
                            TextBlock(
                                type="text",
                                text="</system-info>",
                            ),
                        ]

                        pending_msgs.insert(
                            i + 1,
                            Msg(
                                name="user",
                                content=promoted_blocks,
                                role="user",
                            ),
                        )

                elif typ == "image":
                    content_blocks.append(
                        _format_openai_image_block(
                            block,  # type: ignore[arg-type]
                        ),
                    )

                elif typ == "audio":
                    # Filter out audio content when the multimodal model
                    # outputs both text and audio, to prevent errors in
                    # subsequent model calls
                    if msg.role == "assistant":
                        continue
                    input_audio = _to_openai_audio_data(block["source"])
                    content_blocks.append(
                        {
                            "type": "input_audio",
                            "input_audio": input_audio,
                        },
                    )

                else:
                    logger.warning(
                        "Unsupported block type %s in the message, skipped.",
                        typ,
                    )

            msg_openai = {
                "role": msg.role,
                "name": msg.name,
                "content": content_blocks or None,
            }

            if tool_calls:
                msg_openai["tool_calls"] = tool_calls

            # When both content and tool_calls are None, skipped
            if msg_openai["content"] or msg_openai.get("tool_calls"):
                messages.append(msg_openai)

            # Move to next message
            i += 1

        return messages
|
||||
|
||||
|
||||
class OpenAIMultiAgentFormatter(TruncatedFormatterBase):
    """
    OpenAI formatter for multi-agent conversations, where more than
    a user and an agent are involved.
    .. tip:: This formatter is compatible with OpenAI API and
    OpenAI-compatible services like vLLM, Azure OpenAI, and others.
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = True
    """Whether support multi-agent conversation"""

    support_vision: bool = True
    """Whether support vision models"""

    supported_blocks: list[type] = [
        TextBlock,
        ImageBlock,
        AudioBlock,
        ToolUseBlock,
        ToolResultBlock,
    ]
    """Supported message blocks for OpenAI API"""

    def __init__(
        self,
        conversation_history_prompt: str = (
            "# Conversation History\n"
            "The content between <history></history> tags contains "
            "your conversation history\n"
        ),
        promote_tool_result_images: bool = False,
        token_counter: TokenCounterBase | None = None,
        max_tokens: int | None = None,
    ) -> None:
        """Initialize the OpenAI multi-agent formatter.

        Args:
            conversation_history_prompt (`str`):
                The prompt placed in front of the conversation history.
            promote_tool_result_images (`bool`, defaults to `False`):
                Whether images found in tool results are promoted to user
                messages. Most LLM APIs accept images in user message
                blocks but not in tool result blocks. When `True`, the
                images are extracted and appended as a separate user
                message with explanatory text indicating their source.
            token_counter (`TokenCounterBase | None`, optional):
                The token counter used to count the tokens of the
                formatted messages. Without it, token limits are not
                considered.
            max_tokens (`int | None`, optional):
                The token limit of the formatted messages. Without it, the
                messages are not truncated.
        """
        super().__init__(token_counter=token_counter, max_tokens=max_tokens)
        self.conversation_history_prompt = conversation_history_prompt
        self.promote_tool_result_images = promote_tool_result_images

    async def _format_tool_sequence(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Format a sequence of tool call/result messages by delegating to
        a chat formatter that shares the image promotion setting."""
        chat_formatter = OpenAIChatFormatter(
            promote_tool_result_images=self.promote_tool_result_images,
        )
        return await chat_formatter.format(msgs)

    async def _format_agent_message(
        self,
        msgs: list[Msg],
        is_first: bool = True,
    ) -> list[dict[str, Any]]:
        """Merge a sequence of messages without tool calls/results into at
        most one user message for the OpenAI API.

        The text blocks are joined into one text part wrapped in
        ``<history>`` tags (prefixed with the conversation history prompt
        when `is_first` is `True`), followed by the image parts and then
        the audio parts.
        """
        history_prompt = self.conversation_history_prompt if is_first else ""

        text_lines: list[str] = []
        image_parts: list = []
        audio_parts: list = []

        for msg in msgs:
            for block in msg.get_content_blocks():
                if block["type"] == "text":
                    text_lines.append(f"{msg.name}: {block['text']}")
                    continue

                if block["type"] == "image":
                    image_parts.append(_format_openai_image_block(block))
                    continue

                # Audio generated by the assistant is dropped: when the
                # multimodal model outputs both text and audio, sending the
                # audio back causes errors in subsequent model calls
                if block["type"] == "audio" and msg.role != "assistant":
                    audio_parts.append(
                        {
                            "type": "input_audio",
                            "input_audio": _to_openai_audio_data(
                                block["source"],
                            ),
                        },
                    )

        content_list: list[dict[str, Any]] = []
        if text_lines:
            history_text = (
                history_prompt
                + "<history>\n"
                + "\n".join(text_lines)
                + "\n</history>"
            )
            content_list.append({"type": "text", "text": history_text})

        content_list.extend(image_parts)
        content_list.extend(audio_parts)

        # Nothing to send when the messages carry no supported content
        if not content_list:
            return []

        return [
            {
                "role": "user",
                "content": content_list,
            },
        ]
|
||||
297
src/agentscope/formatter/_truncated_formatter_base.py
Normal file
297
src/agentscope/formatter/_truncated_formatter_base.py
Normal file
@@ -0,0 +1,297 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The truncated formatter base class, which allows to truncate the input
|
||||
messages."""
|
||||
from abc import ABC
|
||||
from copy import deepcopy
|
||||
from typing import (
|
||||
Any,
|
||||
Tuple,
|
||||
Literal,
|
||||
AsyncGenerator,
|
||||
)
|
||||
|
||||
from ._formatter_base import FormatterBase
|
||||
from ..message import Msg
|
||||
from ..token import TokenCounterBase
|
||||
from ..tracing import trace_format
|
||||
|
||||
|
||||
class TruncatedFormatterBase(FormatterBase, ABC):
|
||||
"""Base class for truncated formatters, which formats input messages into
|
||||
required formats with tokens under a specified limit."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
token_counter: TokenCounterBase | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> None:
|
||||
"""Initialize the TruncatedFormatterBase.
|
||||
|
||||
Args:
|
||||
token_counter (`TokenCounterBase | None`, optional):
|
||||
A token counter instance used to count tokens in the messages.
|
||||
If not provided, the formatter will format the messages
|
||||
without considering token limits.
|
||||
max_tokens (`int | None`, optional):
|
||||
The maximum number of tokens allowed in the formatted
|
||||
messages. If not provided, the formatter will not truncate
|
||||
the messages.
|
||||
"""
|
||||
self.token_counter = token_counter
|
||||
|
||||
assert (
|
||||
max_tokens is None or 0 < max_tokens
|
||||
), "max_tokens must be greater than 0"
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
@trace_format
|
||||
async def format(
|
||||
self,
|
||||
msgs: list[Msg],
|
||||
**kwargs: Any,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Format the input messages into the required format. If token
|
||||
counter and max token limit are provided, the messages will be
|
||||
truncated to fit the limit.
|
||||
|
||||
Args:
|
||||
msgs (`list[Msg]`):
|
||||
The input messages to be formatted.
|
||||
|
||||
Returns:
|
||||
`list[dict[str, Any]]`:
|
||||
The formatted messages in the required format.
|
||||
"""
|
||||
|
||||
# Check if the input messages are valid
|
||||
self.assert_list_of_msgs(msgs)
|
||||
|
||||
msgs = deepcopy(msgs)
|
||||
|
||||
while True:
|
||||
formatted_msgs = await self._format(msgs)
|
||||
n_tokens = await self._count(formatted_msgs)
|
||||
|
||||
if (
|
||||
n_tokens is None
|
||||
or self.max_tokens is None
|
||||
or n_tokens <= self.max_tokens
|
||||
):
|
||||
return formatted_msgs
|
||||
|
||||
# truncate the input messages
|
||||
msgs = await self._truncate(msgs)
|
||||
|
||||
async def _format(self, msgs: list[Msg]) -> list[dict[str, Any]]:
|
||||
"""Format the input messages into the required format. This method
|
||||
should be implemented by the subclasses."""
|
||||
|
||||
formatted_msgs = []
|
||||
start_index = 0
|
||||
if len(msgs) > 0 and msgs[0].role == "system":
|
||||
formatted_msgs.append(
|
||||
await self._format_system_message(msgs[0]),
|
||||
)
|
||||
start_index = 1
|
||||
|
||||
is_first_agent_message = True
|
||||
async for typ, group in self._group_messages(msgs[start_index:]):
|
||||
match typ:
|
||||
case "tool_sequence":
|
||||
formatted_msgs.extend(
|
||||
await self._format_tool_sequence(group),
|
||||
)
|
||||
case "agent_message":
|
||||
formatted_msgs.extend(
|
||||
await self._format_agent_message(
|
||||
group,
|
||||
is_first_agent_message,
|
||||
),
|
||||
)
|
||||
is_first_agent_message = False
|
||||
|
||||
return formatted_msgs
|
||||
|
||||
async def _format_system_message(
|
||||
self,
|
||||
msg: Msg,
|
||||
) -> dict[str, Any]:
|
||||
"""Format system message for the LLM API.
|
||||
|
||||
.. note:: This is the default implementation. For certain LLM APIs
|
||||
with specific requirements, you may need to implement a custom
|
||||
formatting function to accommodate those particular needs.
|
||||
"""
|
||||
return {
|
||||
"role": "system",
|
||||
"content": msg.get_content_blocks("text"),
|
||||
}
|
||||
|
||||
async def _format_tool_sequence(
|
||||
self,
|
||||
msgs: list[Msg],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Given a sequence of tool call/result messages, format them into
|
||||
the required format for the LLM API."""
|
||||
raise NotImplementedError(
|
||||
"_format_tool_sequence is not implemented",
|
||||
)
|
||||
|
||||
async def _format_agent_message(
|
||||
self,
|
||||
msgs: list[Msg],
|
||||
is_first: bool = True,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Given a sequence of messages without tool calls/results, format
|
||||
them into the required format for the LLM API."""
|
||||
raise NotImplementedError(
|
||||
"_format_agent_message is not implemented",
|
||||
)
|
||||
|
||||
async def _truncate(self, msgs: list[Msg]) -> list[Msg]:
|
||||
"""Truncate the input messages, so that it can fit the token limit.
|
||||
This function is called only when
|
||||
|
||||
- both `token_counter` and `max_tokens` are provided,
|
||||
- the formatted output of the input messages exceeds the token limit.
|
||||
|
||||
.. tip:: This function only provides a simple strategy, and developers
|
||||
can override this method to implement more sophisticated
|
||||
truncation strategies.
|
||||
|
||||
.. note:: The tool call message should be truncated together with
|
||||
its corresponding tool result message to satisfy the LLM API
|
||||
requirements.
|
||||
|
||||
Args:
|
||||
msgs (`list[Msg]`):
|
||||
The input messages to be truncated.
|
||||
|
||||
Raises:
|
||||
`ValueError`:
|
||||
If the system prompt message already exceeds the token limit,
|
||||
or if there are tool calls without corresponding tool results.
|
||||
|
||||
Returns:
|
||||
`list[Msg]`:
|
||||
The truncated messages.
|
||||
"""
|
||||
start_index = 0
|
||||
if len(msgs) > 0 and msgs[0].role == "system":
|
||||
if len(msgs) == 1:
|
||||
# If the system prompt already exceeds the token limit, we
|
||||
# raise an error.
|
||||
raise ValueError(
|
||||
f"The system prompt message already exceeds the token "
|
||||
f"limit ({self.max_tokens} tokens).",
|
||||
)
|
||||
|
||||
start_index = 1
|
||||
|
||||
# Create a tool call IDs queues to delete the corresponding tool
|
||||
# result message
|
||||
tool_call_ids = set()
|
||||
for i in range(start_index, len(msgs)):
|
||||
msg = msgs[i]
|
||||
for block in msg.get_content_blocks("tool_use"):
|
||||
tool_call_ids.add(block["id"])
|
||||
|
||||
for block in msg.get_content_blocks("tool_result"):
|
||||
try:
|
||||
tool_call_ids.remove(block["id"])
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
# We can stop truncating if the queue is empty
|
||||
if len(tool_call_ids) == 0:
|
||||
return msgs[:start_index] + msgs[i + 1 :]
|
||||
|
||||
if len(tool_call_ids) > 0:
|
||||
raise ValueError(
|
||||
"The input messages contains tool call(s) that do not have "
|
||||
f"the corresponding tool result(s): {tool_call_ids}. ",
|
||||
)
|
||||
|
||||
return msgs[:start_index]
|
||||
|
||||
async def _count(self, msgs: list[dict[str, Any]]) -> int | None:
|
||||
"""Count the number of tokens in the input messages. If token counter
|
||||
is not provided, `None` will be returned.
|
||||
|
||||
Args:
|
||||
msgs (`list[Msg]`):
|
||||
The input messages to count tokens for.
|
||||
"""
|
||||
if self.token_counter is None:
|
||||
return None
|
||||
|
||||
return await self.token_counter.count(msgs)
|
||||
|
||||
@staticmethod
|
||||
async def _group_messages(
|
||||
msgs: list[Msg],
|
||||
) -> AsyncGenerator[
|
||||
Tuple[Literal["tool_sequence", "agent_message"], list[Msg]],
|
||||
None,
|
||||
]:
|
||||
"""Group the input messages into two types and yield them as a
|
||||
generator. The two types are:
|
||||
|
||||
- agent message that doesn't contain tool calls/results, and
|
||||
- tool sequence that consisted of a sequence of tool calls/results
|
||||
|
||||
.. note:: The group operation is used in multi-agent scenario, where
|
||||
multiple entities are involved in the input messages. So that to be
|
||||
compatible with tools API, we have to group the messages and format
|
||||
them with different strategies.
|
||||
|
||||
Args:
|
||||
msgs (`list[Msg]`):
|
||||
The input messages to be grouped, where the system prompt
|
||||
message shouldn't be included.
|
||||
|
||||
Yields:
|
||||
`AsyncGenerator[Tuple[str, list[Msg]], None]`:
|
||||
A generator that yields tuples of group type and the list of
|
||||
messages in that group. The group type can be either
|
||||
"tool_sequence" or "agent_message".
|
||||
"""
|
||||
|
||||
group_type: Literal["tool_sequence", "agent_message"] | None = None
|
||||
group = []
|
||||
for msg in msgs:
|
||||
if group_type is None:
|
||||
if msg.has_content_blocks(
|
||||
"tool_use",
|
||||
) or msg.has_content_blocks("tool_result"):
|
||||
group_type = "tool_sequence"
|
||||
else:
|
||||
group_type = "agent_message"
|
||||
|
||||
group.append(msg)
|
||||
continue
|
||||
|
||||
# determine if this msg has the same type as the current group
|
||||
if group_type == "tool_sequence":
|
||||
if msg.has_content_blocks(
|
||||
"tool_use",
|
||||
) or msg.has_content_blocks("tool_result"):
|
||||
group.append(msg)
|
||||
|
||||
else:
|
||||
yield group_type, group
|
||||
group = [msg]
|
||||
group_type = "agent_message"
|
||||
|
||||
elif group_type == "agent_message":
|
||||
if msg.has_content_blocks(
|
||||
"tool_use",
|
||||
) or msg.has_content_blocks("tool_result"):
|
||||
yield group_type, group
|
||||
group = [msg]
|
||||
group_type = "tool_sequence"
|
||||
|
||||
else:
|
||||
group.append(msg)
|
||||
if group_type:
|
||||
yield group_type, group
|
||||
29
src/agentscope/hooks/__init__.py
Normal file
29
src/agentscope/hooks/__init__.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The built-in hook functions in agentscope."""
|
||||
from functools import partial
|
||||
|
||||
from ._studio_hooks import (
|
||||
as_studio_forward_message_pre_print_hook,
|
||||
)
|
||||
from .. import _config
|
||||
from ..agent import AgentBase
|
||||
|
||||
|
||||
__all__ = [
|
||||
"as_studio_forward_message_pre_print_hook",
|
||||
]
|
||||
|
||||
|
||||
def _equip_as_studio_hooks(
    studio_url: str,
) -> None:
    """Register the studio forwarding hook as a class-level "pre_print"
    hook of `AgentBase`, bound to the given studio URL and the run id
    taken from the config."""
    forward_hook = partial(
        as_studio_forward_message_pre_print_hook,
        studio_url=studio_url,
        run_id=_config.run_id,
    )

    AgentBase.register_class_hook(
        "pre_print",
        "as_studio_forward_message_pre_print_hook",
        forward_hook,
    )
|
||||
53
src/agentscope/hooks/_studio_hooks.py
Normal file
53
src/agentscope/hooks/_studio_hooks.py
Normal file
@@ -0,0 +1,53 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The studio related hook functions in agentscope."""
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
import shortuuid
|
||||
|
||||
from ..agent import AgentBase, UserAgent
|
||||
|
||||
|
||||
def as_studio_forward_message_pre_print_hook(
    self: AgentBase,
    kwargs: dict[str, Any],
    studio_url: str,
    run_id: str,
) -> None:
    """The pre-print hook that forwards the message to the studio.

    The message found in `kwargs["msg"]` is posted to the `pushMessage`
    endpoint of the studio. A failed request is retried up to three times
    before the last exception is raised.
    """
    # Nothing is forwarded when the console output is disabled
    if self._disable_console_output:  # pylint: disable=protected-access
        return

    msg = kwargs["msg"]
    message_data = msg.to_dict()

    # Use the reply id of the agent if it has one, otherwise a fresh id
    reply_id = (
        getattr(self, "_reply_id")
        if hasattr(self, "_reply_id")
        else shortuuid.uuid()
    )

    # The payload is the same for every attempt, so it is built only once
    payload = {
        "runId": run_id,
        "replyId": reply_id,
        "replyName": getattr(self, "name", msg.name),
        "replyRole": "user" if isinstance(self, UserAgent) else "assistant",
        "msg": message_data,
    }

    max_retries = 3
    for attempt in range(max_retries + 1):
        try:
            res = requests.post(
                f"{studio_url}/trpc/pushMessage",
                json=payload,
            )
            res.raise_for_status()
        except Exception as e:
            if attempt < max_retries:
                continue

            raise e from None
        else:
            return
|
||||
20
src/agentscope/mcp/__init__.py
Normal file
20
src/agentscope/mcp/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The MCP module in AgentScope, that provides fine-grained control over
|
||||
the MCP servers."""
|
||||
|
||||
from ._client_base import MCPClientBase
|
||||
from ._mcp_function import MCPToolFunction
|
||||
from ._stateful_client_base import StatefulClientBase
|
||||
from ._stdio_stateful_client import StdIOStatefulClient
|
||||
from ._http_stateless_client import HttpStatelessClient
|
||||
from ._http_stateful_client import HttpStatefulClient
|
||||
|
||||
|
||||
__all__ = [
|
||||
"MCPToolFunction",
|
||||
"MCPClientBase",
|
||||
"StatefulClientBase",
|
||||
"StdIOStatefulClient",
|
||||
"HttpStatelessClient",
|
||||
"HttpStatefulClient",
|
||||
]
|
||||
101
src/agentscope/mcp/_client_base.py
Normal file
101
src/agentscope/mcp/_client_base.py
Normal file
@@ -0,0 +1,101 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The base class for MCP clients in AgentScope."""
|
||||
from abc import abstractmethod
|
||||
from typing import Callable, List
|
||||
|
||||
import mcp.types
|
||||
|
||||
from .._logging import logger
|
||||
from ..message import (
|
||||
ImageBlock,
|
||||
Base64Source,
|
||||
AudioBlock,
|
||||
TextBlock,
|
||||
VideoBlock,
|
||||
)
|
||||
|
||||
|
||||
class MCPClientBase:
    """Base class for MCP clients."""

    def __init__(self, name: str) -> None:
        """Initialize the MCP client with a name.

        Args:
            name (`str`):
                The name identifying the MCP server. It should be unique
                across the MCP servers.
        """
        self.name = name

    @abstractmethod
    async def get_callable_function(
        self,
        func_name: str,
        wrap_tool_result: bool = True,
    ) -> Callable:
        """Get a tool function by its name."""

    @staticmethod
    def _convert_mcp_content_to_as_blocks(
        mcp_content_blocks: list,
    ) -> List[TextBlock | ImageBlock | AudioBlock | VideoBlock]:
        """Convert MCP content to AgentScope blocks.

        Text, image and audio contents as well as embedded text resources
        are converted. Any other content is logged and skipped.
        """

        def _base64_source(item: Any) -> Base64Source:
            """Wrap the data and mime type of an MCP content item."""
            return Base64Source(
                type="base64",
                media_type=item.mimeType,
                data=item.data,
            )

        blocks: list = []
        for item in mcp_content_blocks:
            if isinstance(item, mcp.types.TextContent):
                blocks.append(TextBlock(type="text", text=item.text))
                continue

            if isinstance(item, mcp.types.ImageContent):
                blocks.append(
                    ImageBlock(type="image", source=_base64_source(item)),
                )
                continue

            if isinstance(item, mcp.types.AudioContent):
                blocks.append(
                    AudioBlock(type="audio", source=_base64_source(item)),
                )
                continue

            if not isinstance(item, mcp.types.EmbeddedResource):
                logger.warning(
                    "Unsupported content type: %s. Skipping this content.",
                    type(item),
                )
                continue

            # Embedded resource: only the text resource is supported
            resource = item.resource
            if isinstance(resource, mcp.types.TextResourceContents):
                blocks.append(
                    TextBlock(
                        type="text",
                        text=resource.model_dump_json(indent=2),
                    ),
                )
            else:
                # TODO: support the BlobResourceContents in the future,
                #  which is a base64-encoded string representing the
                #  binary data
                logger.error(
                    "Unsupported EmbeddedResource content type: %s. "
                    "Skipping this content.",
                    type(resource),
                )

        return blocks
|
||||
84
src/agentscope/mcp/_http_stateful_client.py
Normal file
84
src/agentscope/mcp/_http_stateful_client.py
Normal file
@@ -0,0 +1,84 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The MCP stateful HTTP client module in AgentScope."""
|
||||
from typing import Any, Literal
|
||||
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
|
||||
from ._stateful_client_base import StatefulClientBase
|
||||
|
||||
|
||||
class HttpStatefulClient(StatefulClientBase):
    """The stateful MCP client in AgentScope that talks to the server over
    sse or streamable HTTP.

    .. tip:: The stateful client is recommended for MCP servers that need to
     maintain session states, e.g. web browsers or other interactive
     MCP servers.

    .. note:: The stateful client will maintain one session across multiple
     tool calls, until the client is closed by explicitly calling the
     `close()` method.

    .. note:: When multiple HttpStatefulClient instances are connected,
     they should be closed following the Last In First Out (LIFO) principle
     to avoid potential errors. Always close the most recently registered
     client first, then work backwards to the first one.
     For more details, please refer to this `issue
     <https://github.com/modelcontextprotocol/python-sdk/issues/577>`_.
    """

    def __init__(
        self,
        name: str,
        transport: Literal["streamable_http", "sse"],
        url: str,
        headers: dict[str, str] | None = None,
        timeout: float = 30,
        sse_read_timeout: float = 60 * 5,
        **client_kwargs: Any,
    ) -> None:
        """Initialize the stateful HTTP MCP client.

        Args:
            name (`str`):
                The name identifying the MCP server. It should be unique
                across the MCP servers.
            transport (`Literal["streamable_http", "sse"]`):
                The transport type of MCP server. Generally, the URL of sse
                transport should end with `/sse`, while the streamable HTTP
                URL ends with `/mcp`.
            url (`str`):
                The URL to the MCP server.
            headers (`dict[str, str] | None`, optional):
                Additional headers to include in the HTTP request.
            timeout (`float`, optional):
                The timeout for the HTTP request in seconds. Defaults to 30.
            sse_read_timeout (`float`, optional):
                The timeout for reading Server-Sent Events (SSE) in seconds.
                Defaults to 300 (5 minutes).
            **client_kwargs (`Any`):
                The additional keyword arguments passed to the underlying
                sse or streamable HTTP client.
        """
        super().__init__(name=name)

        assert transport in ["streamable_http", "sse"]
        self.transport = transport

        # Both transports take the same arguments, only the factory differs
        client_factory = (
            streamablehttp_client
            if self.transport == "streamable_http"
            else sse_client
        )
        self.client = client_factory(
            url=url,
            headers=headers,
            timeout=timeout,
            sse_read_timeout=sse_read_timeout,
            **client_kwargs,
        )
|
||||
152
src/agentscope/mcp/_http_stateless_client.py
Normal file
152
src/agentscope/mcp/_http_stateless_client.py
Normal file
@@ -0,0 +1,152 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The MCP streamable HTTP server."""
|
||||
from contextlib import _AsyncGeneratorContextManager
|
||||
from typing import Any, Callable, Awaitable, Literal, List
|
||||
|
||||
import mcp.types
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
|
||||
from . import MCPToolFunction
|
||||
from ._client_base import MCPClientBase
|
||||
from ..tool import ToolResponse
|
||||
|
||||
|
||||
class HttpStatelessClient(MCPClientBase):
    """The sse/streamable HTTP MCP client implementation in AgentScope.

    .. note:: Note this client is stateless, meaning it won't maintain the
     session state across multiple tool calls. Each tool call will start a
     new session and close it after the call is done.

    """

    stateful: bool = False
    """Whether the MCP server is stateful, meaning it will maintain the
    session state across multiple tool calls, or stateless, meaning it
    will start a new session for each tool call."""

    def __init__(
        self,
        name: str,
        transport: Literal["streamable_http", "sse"],
        url: str,
        headers: dict[str, str] | None = None,
        timeout: float = 30,
        sse_read_timeout: float = 60 * 5,
        **client_kwargs: Any,
    ) -> None:
        """Initialize the stateless sse/streamable HTTP MCP client.

        Args:
            name (`str`):
                The name to identify the MCP server, which should be unique
                across the MCP servers.
            transport (`Literal["streamable_http", "sse"]`):
                The transport type of MCP server. Generally, the URL of sse
                transport should end with `/sse`, while the streamable HTTP
                URL ends with `/mcp`.
            url (`str`):
                The URL of the MCP server.
            headers (`dict[str, str] | None`, optional):
                Additional headers to include in the HTTP request.
            timeout (`float`, optional):
                The timeout for the HTTP request in seconds. Defaults to 30.
            sse_read_timeout (`float`, optional):
                The timeout for reading Server-Sent Events (SSE) in seconds.
                Defaults to 300 (5 minutes).
            **client_kwargs (`Any`):
                The additional keyword arguments to pass to the underlying
                sse/streamable HTTP client factory.

        Raises:
            `ValueError`:
                If `transport` is neither "streamable_http" nor "sse".
        """
        super().__init__(name=name)

        # Validate with an explicit exception instead of ``assert``:
        # assertions are removed when Python runs with ``-O``, which would
        # let an invalid transport through until the first tool call.
        if transport not in ("streamable_http", "sse"):
            raise ValueError(
                f"Unsupported transport type: {transport}. "
                "Supported types are 'sse' and 'streamable_http'.",
            )

        self.transport = transport

        # Keyword arguments replayed on every ``get_client()`` call, since a
        # fresh client is created for each session.
        self.client_config = {
            "url": url,
            "headers": headers or {},
            "timeout": timeout,
            "sse_read_timeout": sse_read_timeout,
            **client_kwargs,
        }

        # Tool definitions fetched by ``list_tools()``; ``None`` until then.
        self._tools = None

    def get_client(self) -> _AsyncGeneratorContextManager[Any]:
        """The disposable MCP client object, which is a context manager.

        Returns:
            `_AsyncGeneratorContextManager[Any]`:
                A new client context manager built from the stored
                configuration for the configured transport.

        Raises:
            `ValueError`:
                If `self.transport` is not a supported transport type.
        """
        if self.transport == "sse":
            return sse_client(**self.client_config)

        if self.transport == "streamable_http":
            return streamablehttp_client(**self.client_config)

        # Reachable only if ``self.transport`` was changed after __init__.
        raise ValueError(
            f"Unsupported transport type: {self.transport}. "
            "Supported types are 'sse' and 'streamable_http'.",
        )

    async def get_callable_function(
        self,
        func_name: str,
        wrap_tool_result: bool = True,
        execution_timeout: float | None = None,
    ) -> Callable[..., Awaitable[mcp.types.CallToolResult | ToolResponse]]:
        """Get a tool function by its name.

        Args:
            func_name (`str`):
                The name of the tool function.
            wrap_tool_result (`bool`, defaults to `True`):
                Whether to wrap the tool result into agentscope's
                `ToolResponse` object. If `False`, the raw result type
                `mcp.types.CallToolResult` will be returned.
            execution_timeout (`float | None`, optional):
                The preset timeout in seconds for calling the tool function.

        Returns:
            `Callable[..., Awaitable[mcp.types.CallToolResult | \
            ToolResponse]]`:
                An async tool function that returns either
                `mcp.types.CallToolResult` or `ToolResponse` when called.

        Raises:
            `ValueError`:
                If no tool named `func_name` is found on the MCP server.
        """
        # Fetch the tool list lazily on first use.
        if self._tools is None:
            await self.list_tools()

        target_tool = next(
            (tool for tool in self._tools if tool.name == func_name),
            None,
        )

        if target_tool is None:
            raise ValueError(
                f"Tool '{func_name}' not found in the MCP server ",
            )

        # Hand over the client factory (not a session) so that every call
        # of the returned function opens its own session.
        return MCPToolFunction(
            mcp_name=self.name,
            tool=target_tool,
            wrap_tool_result=wrap_tool_result,
            client_gen=self.get_client,
            timeout=execution_timeout,
        )

    async def list_tools(self) -> List[mcp.types.Tool]:
        """List all tools available on the MCP server.

        A new client and session are opened for this call and closed
        afterwards; the fetched tools are also stored on the instance for
        later lookups by `get_callable_function`.

        Returns:
            `List[mcp.types.Tool]`:
                The tools reported by the MCP server.
        """
        async with self.get_client() as cli:
            read_stream, write_stream = cli[0], cli[1]
            async with ClientSession(read_stream, write_stream) as session:
                await session.initialize()
                res = await session.list_tools()
                self._tools = res.tools
                return res.tools
|
||||
115
src/agentscope/mcp/_mcp_function.py
Normal file
115
src/agentscope/mcp/_mcp_function.py
Normal file
@@ -0,0 +1,115 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The MCP tool function class in AgentScope."""
|
||||
from contextlib import _AsyncGeneratorContextManager
|
||||
from datetime import timedelta
|
||||
from typing import Any, Callable
|
||||
|
||||
import mcp
|
||||
from mcp import ClientSession
|
||||
|
||||
from ._client_base import MCPClientBase
|
||||
from .._utils._common import _extract_json_schema_from_mcp_tool
|
||||
from ..tool import ToolResponse
|
||||
|
||||
|
||||
class MCPToolFunction:
    """Callable wrapper around one tool exposed by an MCP server."""

    name: str
    """The name of the tool function."""

    description: str
    """The description of the tool function."""

    json_schema: dict[str, Any]
    """JSON schema of the tool function"""

    def __init__(
        self,
        mcp_name: str,
        tool: mcp.types.Tool,
        wrap_tool_result: bool,
        client_gen: Callable[..., _AsyncGeneratorContextManager[Any]]
        | None = None,
        session: ClientSession | None = None,
        timeout: float | None = None,
    ) -> None:
        """Build the callable for a single MCP tool.

        Args:
            mcp_name (`str`):
                The name of the MCP instance.
            tool (`mcp.types.Tool`):
                The MCP tool definition.
            wrap_tool_result (`bool`):
                Whether to wrap the tool result into `ToolResponse` in
                AgentScope.
            client_gen (`Callable[..., _AsyncGeneratorContextManager[Any]] | \
            None`, *optional*):
                The MCP client generator function. Either this or `session`
                must be provided.
            session (`ClientSession | None`, *optional*):
                The MCP client session. Either this or `client_gen` must be
                provided.
            timeout (`float | None`, *optional*):
                The timeout in seconds for tool execution. If not provided,
                no timeout will be set.

        Raises:
            `ValueError`:
                If both or neither of `client_gen` and `session` are given.
        """
        self.mcp_name = mcp_name
        self.name = tool.name
        self.description = tool.description
        self.json_schema = _extract_json_schema_from_mcp_tool(tool)
        self.wrap_tool_result = wrap_tool_result

        # A falsy timeout (None or 0) means "no timeout".
        self.timeout = timedelta(seconds=timeout) if timeout else None

        # Exactly one of the two connection sources must be supplied: the
        # check fails when both are missing or both are present.
        if (client_gen is None) == (session is None):
            raise ValueError(
                "Either client or session must be provided, but not both.",
            )

        self.client_gen = client_gen
        self.session = session

    async def __call__(
        self,
        **kwargs: Any,
    ) -> mcp.types.CallToolResult | ToolResponse:
        """Invoke the MCP tool with the given keyword arguments.

        Returns:
            `mcp.types.CallToolResult | ToolResponse`:
                A `ToolResponse` when `wrap_tool_result` is set, otherwise
                the raw result of the tool call.
        """
        if not self.client_gen:
            # A session was supplied at construction time: reuse it.
            result = await self.session.call_tool(
                self.name,
                arguments=kwargs,
                read_timeout_seconds=self.timeout,
            )
        else:
            # No session supplied: open a client and a session just for
            # this call and let the context managers close them.
            async with self.client_gen() as streams:
                reader, writer = streams[0], streams[1]
                async with ClientSession(reader, writer) as one_shot:
                    await one_shot.initialize()
                    result = await one_shot.call_tool(
                        self.name,
                        arguments=kwargs,
                        read_timeout_seconds=self.timeout,
                    )

        if not self.wrap_tool_result:
            return result

        blocks = MCPClientBase._convert_mcp_content_to_as_blocks(
            result.content,
        )
        return ToolResponse(
            content=blocks,
            metadata=result.meta,
        )
|
||||
176
src/agentscope/mcp/_stateful_client_base.py
Normal file
176
src/agentscope/mcp/_stateful_client_base.py
Normal file
@@ -0,0 +1,176 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The base MCP stateful client class in AgentScope, that provides basic
|
||||
functionality for stateful MCP clients."""
|
||||
from abc import ABC
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import List
|
||||
|
||||
import mcp
|
||||
from mcp import ClientSession
|
||||
|
||||
from ._client_base import MCPClientBase
|
||||
from ._mcp_function import MCPToolFunction
|
||||
from .._logging import logger
|
||||
|
||||
|
||||
class StatefulClientBase(MCPClientBase, ABC):
    """The base class for stateful MCP clients in AgentScope, which maintains
    the session state across multiple tool calls.

    The developers should use `connect()` and `close()` methods to manage
    the client lifecycle.
    """

    is_connected: bool
    """If connected to the MCP server"""

    def __init__(self, name: str) -> None:
        """Initialize the stateful MCP client.

        Args:
            name (`str`):
                The name to identify the MCP server, which should be unique
                across the MCP servers.
        """

        super().__init__(name=name)

        # ``client`` is left as None here; ``connect()`` enters it as an
        # async context manager, so subclasses are expected to assign it.
        self.client = None
        self.stack = None
        self.session = None
        self.is_connected = False

        # Cache the tools to avoid fetching them multiple times
        self._cached_tools = None

    async def connect(self) -> None:
        """Connect to MCP server.

        Raises:
            `RuntimeError`:
                If the client is already connected.
        """
        if self.is_connected:
            raise RuntimeError(
                "The MCP server is already connected. Call close() "
                "before connecting again.",
            )

        self.stack = AsyncExitStack()

        try:
            context = await self.stack.enter_async_context(
                self.client,
            )
            read_stream, write_stream = context[0], context[1]
            self.session = ClientSession(read_stream, write_stream)
            await self.stack.enter_async_context(self.session)
            await self.session.initialize()

            self.is_connected = True
            logger.info("MCP client connected.")
        except Exception:
            # Unwind whatever was entered so far, and drop the references so
            # a failed connect leaves the same state that close() leaves
            # (no stale session object pointing at closed streams).
            await self.stack.aclose()
            self.stack = None
            self.session = None
            raise

    async def close(self, ignore_errors: bool = True) -> None:
        """Clean up the MCP client resources. You must call this method when
        your application is done.

        Args:
            ignore_errors (`bool`):
                Whether to ignore errors during cleanup. Defaults to `True`.

        Raises:
            `RuntimeError`:
                If the client is not connected.
        """
        if not self.is_connected:
            raise RuntimeError(
                "The MCP server is not connected. Call connect() before "
                "closing.",
            )

        try:
            await self.stack.aclose()
        except Exception as e:
            if not ignore_errors:
                # Bare raise keeps the original traceback intact.
                raise
            logger.warning("Error during MCP client cleanup: %s", e)
        finally:
            # Reset the state whether or not the cleanup succeeded.
            self.stack = None
            self.session = None
            self.is_connected = False

    async def list_tools(self) -> List[mcp.types.Tool]:
        """Get all available tools from the server.

        Returns:
            `List[mcp.types.Tool]`:
                A list of available MCP tools.

        Raises:
            `RuntimeError`:
                If the client is not connected.
        """
        self._validate_connection()

        res = await self.session.list_tools()

        # Cache the tools for later use
        self._cached_tools = res.tools
        return res.tools

    async def get_callable_function(
        self,
        func_name: str,
        wrap_tool_result: bool = True,
        execution_timeout: float | None = None,
    ) -> MCPToolFunction:
        """Get an async tool function from the MCP server by its name, so
        that you can call it directly, wrap it into your own function, or
        anyway you like.

        .. note:: Currently, only the text, image, and audio results are
         supported in this function.

        Args:
            func_name (`str`):
                The name of the tool function to get.
            wrap_tool_result (`bool`):
                Whether to wrap the tool result into agentscope's
                `ToolResponse` object. If `False`, the raw result type
                `mcp.types.CallToolResult` will be returned.
            execution_timeout (`float | None`, optional):
                The preset timeout in seconds for calling the tool function.

        Returns:
            `MCPToolFunction`:
                A callable async function that returns either
                `mcp.types.CallToolResult` or `ToolResponse` when called.

        Raises:
            `RuntimeError`:
                If the client is not connected.
            `ValueError`:
                If no tool named `func_name` is found on the MCP server.
        """
        self._validate_connection()

        # Fetch the tool list lazily on first use.
        if self._cached_tools is None:
            await self.list_tools()

        target_tool = next(
            (tool for tool in self._cached_tools if tool.name == func_name),
            None,
        )

        if target_tool is None:
            raise ValueError(
                f"Tool '{func_name}' not found in the MCP server",
            )

        # The returned function shares this client's session.
        return MCPToolFunction(
            mcp_name=self.name,
            tool=target_tool,
            wrap_tool_result=wrap_tool_result,
            session=self.session,
            timeout=execution_timeout,
        )

    def _validate_connection(self) -> None:
        """Validate the connection to the MCP server.

        Raises:
            `RuntimeError`:
                If the client is not connected or has no session.
        """
        if not self.is_connected:
            raise RuntimeError(
                "The connection is not established. Call connect() "
                "before using the client.",
            )

        if not self.session:
            raise RuntimeError(
                "The session is not initialized. Call connect() "
                "before using the client.",
            )
|
||||
77
src/agentscope/mcp/_stdio_stateful_client.py
Normal file
77
src/agentscope/mcp/_stdio_stateful_client.py
Normal file
@@ -0,0 +1,77 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The StdIO MCP server implementation in AgentScope, which provides
|
||||
function-level fine-grained control over the MCP servers using standard IO."""
|
||||
from typing import Literal
|
||||
|
||||
from mcp import stdio_client, StdioServerParameters
|
||||
|
||||
from ._stateful_client_base import StatefulClientBase
|
||||
|
||||
|
||||
class StdIOStatefulClient(StatefulClientBase):
    """Stateful MCP client that talks to an MCP server over standard IO and
    offers function-level fine-grained control over that server.

    .. tip:: Prefer the stateful client for MCP servers that need to keep
     session state, e.g. web browsers or other interactive MCP servers.

    .. note:: One session is kept across multiple tool calls until the
     client is closed by explicitly calling the `close()` method.

    .. note:: When multiple StdIOStatefulClient instances are connected,
     close them in Last In First Out (LIFO) order to avoid potential
     errors: close the most recently registered client first, then work
     backwards to the first one. For more details, please refer to this
     `issue
     <https://github.com/modelcontextprotocol/python-sdk/issues/577>`_.
    """

    def __init__(
        self,
        name: str,
        command: str,
        args: list[str] | None = None,
        env: dict[str, str] | None = None,
        cwd: str | None = None,
        encoding: str = "utf-8",
        encoding_error_handler: Literal[
            "strict",
            "ignore",
            "replace",
        ] = "strict",
    ) -> None:
        """Set up the stdio MCP client.

        Args:
            name (`str`):
                The name to identify the MCP server, which should be unique
                across the MCP servers.
            command (`str`):
                The executable to run to start the server.
            args (`list[str] | None`, optional):
                Command line arguments to pass to the executable.
            env (`dict[str, str] | None`, optional):
                The environment to use when spawning the process.
            cwd (`str | None`, optional):
                The working directory to use when spawning the process.
            encoding (`str`, optional):
                The text encoding used when sending/receiving messages to the
                server. Defaults to "utf-8".
            encoding_error_handler (`Literal["strict", "ignore", "replace"]`, \
            defaults to "strict"):
                The text encoding error handler.
        """
        super().__init__(name=name)

        # Describe the server process first, then wrap it in the stdio
        # client; a missing/empty ``args`` becomes an empty list.
        server_params = StdioServerParameters(
            command=command,
            args=args if args else [],
            env=env,
            cwd=cwd,
            encoding=encoding,
            encoding_error_handler=encoding_error_handler,
        )
        self.client = stdio_client(server_params)
|
||||
31
src/agentscope/memory/__init__.py
Normal file
31
src/agentscope/memory/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The memory module."""
|
||||
|
||||
from ._working_memory import (
|
||||
MemoryBase,
|
||||
InMemoryMemory,
|
||||
RedisMemory,
|
||||
AsyncSQLAlchemyMemory,
|
||||
)
|
||||
from ._long_term_memory import (
|
||||
LongTermMemoryBase,
|
||||
Mem0LongTermMemory,
|
||||
ReMePersonalLongTermMemory,
|
||||
ReMeTaskLongTermMemory,
|
||||
ReMeToolLongTermMemory,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
# Working memory
|
||||
"MemoryBase",
|
||||
"InMemoryMemory",
|
||||
"RedisMemory",
|
||||
"AsyncSQLAlchemyMemory",
|
||||
# Long-term memory
|
||||
"LongTermMemoryBase",
|
||||
"Mem0LongTermMemory",
|
||||
"ReMePersonalLongTermMemory",
|
||||
"ReMeTaskLongTermMemory",
|
||||
"ReMeToolLongTermMemory",
|
||||
]
|
||||
18
src/agentscope/memory/_long_term_memory/__init__.py
Normal file
18
src/agentscope/memory/_long_term_memory/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The long-term memory module for AgentScope."""
|
||||
|
||||
from ._long_term_memory_base import LongTermMemoryBase
|
||||
from ._mem0 import Mem0LongTermMemory
|
||||
from ._reme import (
|
||||
ReMePersonalLongTermMemory,
|
||||
ReMeTaskLongTermMemory,
|
||||
ReMeToolLongTermMemory,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"LongTermMemoryBase",
|
||||
"Mem0LongTermMemory",
|
||||
"ReMePersonalLongTermMemory",
|
||||
"ReMeTaskLongTermMemory",
|
||||
"ReMeToolLongTermMemory",
|
||||
]
|
||||
@@ -0,0 +1,94 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The long-term memory base class."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from agentscope.message import Msg
|
||||
from agentscope.module import StateModule
|
||||
from agentscope.tool import ToolResponse
|
||||
|
||||
|
||||
class LongTermMemoryBase(StateModule):
    """The long-term memory base class, which should be a time-series
    memory management system.

    The `record_to_memory` and `retrieve_from_memory` methods are two tool
    functions for agent to manage the long-term memory voluntarily. You can
    choose not to implement these two functions.

    The `record` and `retrieve` methods are for developers to use. For example,
    retrieving/recording memory at the beginning of each reply, and adding
    the retrieved memory to the system prompt.

    Every method in this base class raises `NotImplementedError`; concrete
    long-term memory classes override the ones they support.
    """

    async def record(
        self,
        msgs: list[Msg | None],
        **kwargs: Any,
    ) -> Any:
        """A developer-designed method to record information from the given
        input message(s) to the long-term memory.

        Args:
            msgs (`list[Msg | None]`):
                The message(s) to record.
            **kwargs (`Any`):
                Additional keyword arguments; unused by this base class.

        Raises:
            `NotImplementedError`:
                Always, in this base class.
        """
        raise NotImplementedError(
            "The `record` method is not implemented. ",
        )

    async def retrieve(
        self,
        msg: Msg | list[Msg] | None,
        limit: int = 5,
        **kwargs: Any,
    ) -> str:
        """A developer-designed method to retrieve information from the
        long-term memory based on the given input message(s). The retrieved
        information will be added to the system prompt of the agent.

        Args:
            msg (`Msg | list[Msg] | None`):
                The message(s) to retrieve memory for.
            limit (`int`, optional):
                Defaults to 5. Presumably the maximum number of memories to
                retrieve -- unused by this base class, confirm against the
                concrete implementations.
            **kwargs (`Any`):
                Additional keyword arguments; unused by this base class.

        Raises:
            `NotImplementedError`:
                Always, in this base class.
        """
        raise NotImplementedError(
            "The `retrieve` method is not implemented. ",
        )

    # NOTE(review): the class docstring describes the two methods below as
    # tool functions for the agent, so their docstrings are presumably
    # surfaced to the model as tool descriptions -- confirm before rewording
    # the summary or the Args sections.
    async def record_to_memory(
        self,
        thinking: str,
        content: list[str],
        **kwargs: Any,
    ) -> ToolResponse:
        """Use this function to record important information that you may
        need later. The target content should be specific and concise, e.g.
        who, when, where, do what, why, how, etc.

        Args:
            thinking (`str`):
                Your thinking and reasoning about what to record
            content (`list[str]`):
                The content to remember, which is a list of strings.
        """
        raise NotImplementedError(
            "The `record_to_memory` method is not implemented. "
            "You can implement it in your own long-term memory class.",
        )

    async def retrieve_from_memory(
        self,
        keywords: list[str],
        limit: int = 5,
        **kwargs: Any,
    ) -> ToolResponse:
        """Retrieve the memory based on the given keywords.

        Args:
            keywords (`list[str]`):
                The keywords to search for in the memory, which should be
                specific and concise, e.g. the person's name, the date, the
                location, etc.
            limit (`int`, optional):
                The maximum number of memories to retrieve per search, i.e.,
                the number of memories to retrieve for each keyword. Defaults
                to 5.

        Returns:
            `ToolResponse`:
                The tool response for the retrieval, per the return
                annotation. This base implementation never returns; it
                always raises `NotImplementedError`.
        """
        raise NotImplementedError(
            "The `retrieve_from_memory` method is not implemented. "
            "You can implement it in your own long-term memory class.",
        )
|
||||
@@ -0,0 +1,9 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The Mem0 long-term memory module for AgentScope."""
|
||||
|
||||
from ._mem0_long_term_memory import Mem0LongTermMemory
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Mem0LongTermMemory",
|
||||
]
|
||||
@@ -0,0 +1,746 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Long-term memory implementation using mem0 library.
|
||||
|
||||
This module provides a long-term memory implementation that integrates
|
||||
with the mem0 library to provide persistent memory storage and retrieval
|
||||
capabilities for AgentScope agents.
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, TYPE_CHECKING
|
||||
from importlib import metadata
|
||||
|
||||
from pydantic import field_validator
|
||||
|
||||
from ....embedding import EmbeddingModelBase
|
||||
from .._long_term_memory_base import LongTermMemoryBase
|
||||
from ....message import Msg, TextBlock
|
||||
from ....model import ChatModelBase
|
||||
from ....tool import ToolResponse
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mem0.configs.base import MemoryConfig
|
||||
from mem0.vector_stores.configs import VectorStoreConfig
|
||||
else:
|
||||
MemoryConfig = Any
|
||||
VectorStoreConfig = Any
|
||||
|
||||
|
||||
def _create_agentscope_config_classes() -> tuple:
    """Build the mem0 config subclasses that accept agentscope providers.

    The mem0 imports happen inside the function, so mem0 is only needed when
    this function is called.

    Returns:
        `tuple`:
            The pair `(_ASLlmConfig, _ASEmbedderConfig)`.
    """
    from mem0.embeddings.configs import EmbedderConfig
    from mem0.llms.configs import LlmConfig

    class _ASLlmConfig(LlmConfig):
        """LLM config whose `config` validator accepts any provider that is
        registered in mem0's `LlmFactory`.

        Attention: in mem0, the validate_config hardcodes the provider, so we
        need to override the validate_config method to support the agentscope
        providers. We will follow up with the mem0 to improve this.
        """

        @field_validator("config")
        @classmethod
        def validate_config(cls, v: Any, values: Any) -> Any:
            """Accept `v` when the provider is registered in `LlmFactory`."""
            from mem0.utils.factory import LlmFactory

            provider = values.data.get("provider")
            if provider not in LlmFactory.provider_to_class:
                raise ValueError(f"Unsupported LLM provider: {provider}")
            return v

    class _ASEmbedderConfig(EmbedderConfig):
        """Embedder config whose `config` validator accepts any provider
        that is registered in mem0's `EmbedderFactory`."""

        @field_validator("config")
        @classmethod
        def validate_config(cls, v: Any, values: Any) -> Any:
            """Accept `v` when the provider is registered in
            `EmbedderFactory`."""
            from mem0.utils.factory import EmbedderFactory

            provider = values.data.get("provider")
            if provider not in EmbedderFactory.provider_to_class:
                raise ValueError(f"Unsupported Embedder provider: {provider}")
            return v

    return _ASLlmConfig, _ASEmbedderConfig
|
||||
|
||||
|
||||
class Mem0LongTermMemory(LongTermMemoryBase):
|
||||
"""A class that implements the LongTermMemoryBase interface using mem0."""
|
||||
|
||||
@staticmethod
|
||||
def _setup_mem0_logging(suppress_mem0_logging: bool) -> None:
|
||||
"""Suppress mem0 logging if requested.
|
||||
|
||||
Args:
|
||||
suppress_mem0_logging (`bool`):
|
||||
Whether to suppress mem0 logging. See class docstring for
|
||||
details on QDRANT validation errors when using mem0 1.0.3.
|
||||
"""
|
||||
if suppress_mem0_logging:
|
||||
import logging
|
||||
|
||||
logging.getLogger("mem0").setLevel(logging.CRITICAL)
|
||||
logging.getLogger("mem0.memory").setLevel(logging.CRITICAL)
|
||||
logging.getLogger("mem0.memory.main").setLevel(logging.CRITICAL)
|
||||
|
||||
@staticmethod
|
||||
def _register_agentscope_providers() -> None:
|
||||
"""Register the agentscope providers with mem0.
|
||||
|
||||
Raises:
|
||||
`ImportError`:
|
||||
If the mem0 library is not installed.
|
||||
"""
|
||||
try:
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.utils.factory import LlmFactory, EmbedderFactory
|
||||
from packaging import version
|
||||
|
||||
# Check mem0 version
|
||||
current_version = metadata.version("mem0ai")
|
||||
is_mem0_version_low = version.parse(
|
||||
current_version,
|
||||
) <= version.parse("0.1.115")
|
||||
|
||||
# Register the agentscope providers with mem0
|
||||
EmbedderFactory.provider_to_class["agentscope"] = (
|
||||
"agentscope.memory._long_term_memory._mem0."
|
||||
"_mem0_utils.AgentScopeEmbedding"
|
||||
)
|
||||
if is_mem0_version_low:
|
||||
# For mem0 version <= 0.1.115, use the old style
|
||||
LlmFactory.provider_to_class["agentscope"] = (
|
||||
"agentscope.memory._long_term_memory._mem0."
|
||||
"_mem0_utils.AgentScopeLLM"
|
||||
)
|
||||
else:
|
||||
# For mem0 version > 0.1.115, use the new style
|
||||
LlmFactory.provider_to_class["agentscope"] = (
|
||||
"agentscope.memory._long_term_memory._mem0."
|
||||
"_mem0_utils.AgentScopeLLM",
|
||||
BaseLlmConfig,
|
||||
)
|
||||
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Please install the mem0 library by `pip install mem0ai`",
|
||||
) from e
|
||||
|
||||
@staticmethod
|
||||
def _validate_identifiers(
|
||||
agent_name: str | None,
|
||||
user_name: str | None,
|
||||
run_name: str | None,
|
||||
) -> None:
|
||||
"""Validate that at least one identifier is provided.
|
||||
|
||||
Args:
|
||||
agent_name (`str | None`):
|
||||
The name of the agent.
|
||||
user_name (`str | None`):
|
||||
The name of the user.
|
||||
run_name (`str | None`):
|
||||
The name of the run/session.
|
||||
|
||||
Raises:
|
||||
`ValueError`:
|
||||
If all identifiers are None.
|
||||
"""
|
||||
if agent_name is None and user_name is None and run_name is None:
|
||||
raise ValueError(
|
||||
"at least one of agent_name, user_name, and run_name is "
|
||||
"required",
|
||||
)
|
||||
|
||||
@staticmethod
def _configure_mem0_config(
    mem0_config: MemoryConfig | None,
    model: ChatModelBase | None,
    embedding_model: EmbeddingModelBase | None,
    vector_store_config: VectorStoreConfig | None,
    _ASLlmConfig: type,
    _ASEmbedderConfig: type,
    **kwargs: Any,
) -> MemoryConfig:
    """Produce the mem0 `MemoryConfig` to use for the long-term memory.

    When `mem0_config` is given, it is updated in place: each of `model`,
    `embedding_model` and `vector_store_config` that is not None replaces
    the matching section. When it is not given, a new config is created
    from `model` and `embedding_model`, with either the given vector store
    config or a default one built from the `on_disk` keyword argument.

    Args:
        mem0_config (`MemoryConfig | None`):
            The existing mem0 config, if any.
        model (`ChatModelBase | None`):
            The chat model to use.
        embedding_model (`EmbeddingModelBase | None`):
            The embedding model to use.
        vector_store_config (`VectorStoreConfig | None`):
            The vector store config to use.
        _ASLlmConfig (`type`):
            The custom LLM config class for agentscope.
        _ASEmbedderConfig (`type`):
            The custom embedder config class for agentscope.
        **kwargs (`Any`):
            Additional keyword arguments, including 'on_disk' for
            vector store configuration.

    Returns:
        `MemoryConfig`:
            The configured MemoryConfig object.

    Raises:
        `ValueError`:
            If `mem0_config` is None and either `model` or
            `embedding_model` is None.
    """
    import mem0

    if mem0_config is None:
        # No base config: build one from the individual parameters, which
        # makes both models mandatory.
        if model is None or embedding_model is None:
            raise ValueError(
                "model and embedding_model are required if mem0_config "
                "is not provided",
            )

        llm_section = _ASLlmConfig(
            provider="agentscope",
            config={"model": model},
        )
        embedder_section = _ASEmbedderConfig(
            provider="agentscope",
            config={"model": embedding_model},
        )
        mem0_config = mem0.configs.base.MemoryConfig(
            llm=llm_section,
            embedder=embedder_section,
        )

        if vector_store_config is None:
            # Default vector store config; ``on_disk`` defaults to True to
            # enable persistence, otherwise it will be in memory only.
            store_section = mem0.vector_stores.configs.VectorStoreConfig(
                config={"on_disk": kwargs.get("on_disk", True)},
            )
        else:
            store_section = vector_store_config

        mem0_config.vector_store = store_section
        return mem0_config

    # A base config was supplied: override only the sections for which an
    # individual parameter was given.
    if model is not None:
        mem0_config.llm = _ASLlmConfig(
            provider="agentscope",
            config={"model": model},
        )

    if embedding_model is not None:
        mem0_config.embedder = _ASEmbedderConfig(
            provider="agentscope",
            config={"model": embedding_model},
        )

    if vector_store_config is not None:
        mem0_config.vector_store = vector_store_config

    return mem0_config
|
||||
|
||||
def __init__(
    self,
    agent_name: str | None = None,
    user_name: str | None = None,
    run_name: str | None = None,
    model: ChatModelBase | None = None,
    embedding_model: EmbeddingModelBase | None = None,
    vector_store_config: VectorStoreConfig | None = None,
    mem0_config: MemoryConfig | None = None,
    default_memory_type: str | None = None,
    suppress_mem0_logging: bool = True,
    **kwargs: Any,
) -> None:
    """Build a long-term memory backed by a mem0 `AsyncMemory` instance.

    Args:
        agent_name (`str | None`, optional):
            The agent name, stored as `agent_id`. Defaults to None.
        user_name (`str | None`, optional):
            The user name, stored as `user_id`. Defaults to None.
        run_name (`str | None`, optional):
            The run/session name, stored as `run_id`. Defaults to None.

            .. note::
                1. At least one of `agent_name`, `user_name`, or
                   `run_name` is required.
                2. While recording, these values become metadata of the
                   stored memories.
                3. **Important**: by default mem0 extracts memories from
                   messages whose role is "user". Provide `agent_name`
                   if memories should also be extracted from messages
                   whose role is "assistant".
                4. While retrieving, only memories whose metadata
                   matches these values are returned.

        model (`ChatModelBase | None`, optional):
            Chat model used by the memory. Overrides the LLM section of
            `mem0_config` when both are given; required when
            `mem0_config` is None.
        embedding_model (`EmbeddingModelBase | None`, optional):
            Embedding model used by the memory. Overrides the embedder
            section of `mem0_config` when both are given; required when
            `mem0_config` is None.
        vector_store_config (`VectorStoreConfig | None`, optional):
            Vector store configuration. Overrides the vector store
            section of `mem0_config` when both are given. When neither
            is given, a default store with `on_disk=True` is configured.
        mem0_config (`MemoryConfig | None`, optional):
            A ready-made mem0 configuration. The individual
            `model` / `embedding_model` / `vector_store_config`
            arguments take precedence over the matching sections. When
            None, a new configuration is created from those arguments.
        default_memory_type (`str | None`, optional):
            Memory type used when a recording call does not specify one.
            None creates a semantic memory.
        suppress_mem0_logging (`bool`, optional):
            Whether mem0 logging is suppressed. Defaults to True.

            .. note::
                With the QDRANT vector database and mem0 1.0.3 you may
                see validation errors such as:
                "Error awaiting memory task (async): 6 validation errors
                for PointStruct vector.list[float] Input should be a
                valid list [type=list_type, ...]"
                According to the mem0 community
                (see https://github.com/mem0ai/mem0/issues/3780) these
                messages are harmless and can be ignored. Leaving
                `suppress_mem0_logging=True` hides them.

        **kwargs (`Any`):
            Extra options forwarded to the configuration helper (for
            example `on_disk` for the default vector store).

    Raises:
        `ValueError`:
            If `mem0_config` is None and either `model` or
            `embedding_model` is None.
    """
    super().__init__()

    # Logging and provider registration happen before any mem0 object is
    # built from the configuration.
    self._setup_mem0_logging(suppress_mem0_logging)
    self._register_agentscope_providers()

    # The agentscope-specific config classes are generated dynamically.
    llm_config_cls, embedder_config_cls = _create_agentscope_config_classes()

    # Reject invalid identifier combinations before storing them.
    self._validate_identifiers(agent_name, user_name, run_name)
    self.agent_id, self.user_id, self.run_id = (
        agent_name,
        user_name,
        run_name,
    )

    import mem0

    resolved_config = self._configure_mem0_config(
        mem0_config=mem0_config,
        model=model,
        embedding_model=embedding_model,
        vector_store_config=vector_store_config,
        _ASLlmConfig=llm_config_cls,
        _ASEmbedderConfig=embedder_config_cls,
        **kwargs,
    )

    # The async mem0 client that every record/search call goes through.
    self.long_term_working_memory = mem0.AsyncMemory(resolved_config)

    # Fallback memory type for recording calls that pass none.
    self.default_memory_type = default_memory_type
|
||||
|
||||
async def record_to_memory(
    self,
    thinking: str,
    content: list[str],
    **kwargs: Any,
) -> ToolResponse:
    """Use this function to record important information that you may
    need later. The target content should be specific and concise, e.g.
    who, when, where, do what, why, how, etc.

    Args:
        thinking (`str`):
            Your thinking and reasoning about what to record.
        content (`list[str]`):
            The content to remember, which is a list of strings.
    """
    # Recording escalates through three attempts so that content is kept
    # even when mem0's inference extracts nothing:
    #
    # 1. Send the text as a "user" message (mem0 infers memories from
    #    messages whose role is "user").
    # 2. If that produced an empty result list, resend it as an
    #    "assistant" message. With an agent_id mem0 uses
    #    AGENT_MEMORY_EXTRACTION_PROMPT (mem0/mem0/configs/prompts.py),
    #    otherwise it still uses USER_MEMORY_EXTRACTION_PROMPT.
    # 3. If the result list is still empty, resend as "assistant" with
    #    infer=False so mem0 stores the content as is.

    def _came_back_empty(outcome: Any) -> bool:
        """Tell whether mem0 answered with an empty "results" list."""
        return bool(
            outcome
            and isinstance(outcome, dict)
            and "results" in outcome
            and len(outcome["results"]) == 0,
        )

    try:
        if thinking:
            content = [thinking] + content

        # Joined inside the try block so that malformed content is
        # reported through the error response below.
        joined = "\n".join(content)

        def _payload(role: str) -> list[dict]:
            """Wrap the joined text in a single mem0 message."""
            return [{"role": role, "content": joined, "name": role}]

        # Attempt 1: user message.
        results = await self._mem0_record(_payload("user"), **kwargs)

        # Attempt 2: assistant message with inference.
        if _came_back_empty(results):
            results = await self._mem0_record(
                _payload("assistant"),
                **kwargs,
            )

        # Attempt 3: assistant message stored verbatim.
        if _came_back_empty(results):
            results = await self._mem0_record(
                _payload("assistant"),
                infer=False,
                **kwargs,
            )

        report = f"Successfully recorded content to memory {results}"
    except Exception as e:
        report = f"Error recording memory: {str(e)}"

    return ToolResponse(
        content=[
            TextBlock(
                type="text",
                text=report,
            ),
        ],
    )
|
||||
|
||||
async def retrieve_from_memory(
    self,
    keywords: list[str],
    limit: int = 5,
    **kwargs: Any,
) -> ToolResponse:
    """Retrieve the memory based on the given keywords.

    Args:
        keywords (`list[str]`):
            Short, targeted search phrases (for example, a person's name,
            a specific date, a location, or a phrase describing something
            you want to retrieve from the memory). Each keyword is issued
            as an independent query against the memory store.
        limit (`int`, optional):
            The maximum number of memories to retrieve per search, i.e.
            the number of memories to retrieve for each keyword.
            Defaults to 5.

    Returns:
        `ToolResponse`:
            A ToolResponse whose text block holds the retrieved memories
            (and formatted relations, if any), one per line, or an error
            message if the search failed.
    """
    try:
        # One search per keyword, all issued concurrently.
        pending = [
            self.long_term_working_memory.search(
                query=keyword,
                agent_id=self.agent_id,
                user_id=self.user_id,
                run_id=self.run_id,
                limit=limit,
            )
            for keyword in keywords
        ]

        collected: list[str] = []
        for found in await asyncio.gather(*pending):
            if not found:
                continue
            collected.extend(entry["memory"] for entry in found["results"])
            if "relations" in found:
                collected.extend(self._format_relations(found))

        text = "\n".join(collected)
    except Exception as e:
        text = f"Error retrieving memory: {str(e)}"

    return ToolResponse(
        content=[
            TextBlock(
                type="text",
                text=text,
            ),
        ],
    )
|
||||
|
||||
async def record(
    self,
    msgs: list[Msg | None],
    memory_type: str | None = None,
    infer: bool = True,
    **kwargs: Any,
) -> dict:
    """Store the given messages in the long-term memory.

    All messages are merged into one "assistant" message (their contents
    joined with newlines) before being handed to mem0.

    Args:
        msgs (`list[Msg | None]`):
            The messages to record. A single `Msg` is accepted as well;
            falsy entries such as None are dropped.
        memory_type (`str | None`, optional):
            The type of memory to use. Default is None, to create a
            semantic memory. "procedural_memory" is explicitly used for
            procedural memories.
        infer (`bool`, optional):
            Whether to infer memory from the content. Default is True.
        **kwargs (`Any`):
            Additional keyword arguments for the mem0 recording.

    Returns:
        `dict`:
            The result returned by the mem0 recording call.

    Raises:
        `TypeError`:
            If any remaining entry is not a `Msg` object.
    """
    candidates = [msgs] if isinstance(msgs, Msg) else msgs

    # Drop None (and any other falsy) entries before validating.
    kept = [entry for entry in candidates if entry]
    for entry in kept:
        if not isinstance(entry, Msg):
            raise TypeError(
                "The input messages must be a list of Msg objects.",
            )

    merged = "\n".join(str(entry.content) for entry in kept)
    payload = [
        {
            "role": "assistant",
            "content": merged,
            "name": "assistant",
        },
    ]

    return await self._mem0_record(
        payload,
        memory_type=memory_type,
        infer=infer,
        **kwargs,
    )
|
||||
|
||||
def _format_relations(self, result: dict) -> list:
    """Render the relations of a search result as display strings.

    Args:
        result (`dict`):
            The result from the memory search operation.

    Returns:
        `list`:
            One string per relation, in the format
            "{source} -- {relationship} -- {destination}". Empty when
            the result has no "relations" key.
    """
    relations = result["relations"] if "relations" in result else []

    formatted = []
    for relation in relations:
        formatted.append(
            f"{relation['source']} -- "
            f"{relation['relationship']} -- "
            f"{relation['destination']}",
        )
    return formatted
|
||||
|
||||
async def _mem0_record(
|
||||
self,
|
||||
messages: str | list[dict],
|
||||
memory_type: str | None = None,
|
||||
infer: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
"""Record the content to the long-term memory.
|
||||
|
||||
Args:
|
||||
messages (`str`):
|
||||
The content to remember, which is a string or a list of
|
||||
dictionaries representing messages.
|
||||
memory_type (`str | None`, optional):
|
||||
The type of memory to use. Default is None, to create a
|
||||
semantic memory. "procedural_memory" is explicitly used for
|
||||
procedural memories.
|
||||
infer (`bool`, optional):
|
||||
Whether to infer memory from the content. Default is True.
|
||||
**kwargs (`Any`):
|
||||
Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
`dict`:
|
||||
The result from the memory recording operation.
|
||||
"""
|
||||
results = await self.long_term_working_memory.add(
|
||||
messages=messages,
|
||||
agent_id=self.agent_id,
|
||||
user_id=self.user_id,
|
||||
run_id=self.run_id,
|
||||
memory_type=(
|
||||
memory_type
|
||||
if memory_type is not None
|
||||
else self.default_memory_type
|
||||
),
|
||||
infer=infer,
|
||||
**kwargs,
|
||||
)
|
||||
return results
|
||||
|
||||
async def retrieve(
    self,
    msg: Msg | list[Msg] | None,
    limit: int = 5,
    **kwargs: Any,
) -> str:
    """Retrieve the content from the long-term memory.

    Each message's content is serialised to JSON and used as one search
    query; all queries are issued concurrently.

    Args:
        msg (`Msg | list[Msg] | None`):
            The message to search for in the memory, which should be
            specific and concise, e.g. the person's name, the date, the
            location, etc. None means there is nothing to search for and
            yields an empty string.
        limit (`int`, optional):
            The maximum number of memories to retrieve per search. For a
            list of messages the limit applies to each message; for a
            single message it is the total number of memories retrieved.
            Defaults to 5.
        **kwargs (`Any`):
            Additional keyword arguments.

    Returns:
        `str`:
            The retrieved memories (and formatted relations, if any),
            one per line. Empty when nothing was found or `msg` is None.

    Raises:
        `TypeError`:
            If `msg` is neither None, a `Msg`, nor a list of `Msg`.
    """
    # The signature accepts None; treat it like an empty list of queries
    # instead of rejecting it as a type error.
    if msg is None:
        return ""

    if isinstance(msg, Msg):
        msg = [msg]

    if not isinstance(msg, list) or not all(
        isinstance(_, Msg) for _ in msg
    ):
        raise TypeError(
            "The input message must be a Msg or a list of Msg objects.",
        )

    queries = [
        json.dumps(_.to_dict()["content"], ensure_ascii=False) for _ in msg
    ]

    search_coroutines = [
        self.long_term_working_memory.search(
            query=query,
            agent_id=self.agent_id,
            user_id=self.user_id,
            run_id=self.run_id,
            limit=limit,
        )
        for query in queries
    ]
    search_results = await asyncio.gather(*search_coroutines)

    results: list[str] = []
    for result in search_results:
        if not result:
            continue
        results.extend(memory["memory"] for memory in result["results"])
        # _format_relations returns [] when there are no relations.
        results.extend(self._format_relations(result))

    return "\n".join(results)
|
||||
363
src/agentscope/memory/_long_term_memory/_mem0/_mem0_utils.py
Normal file
363
src/agentscope/memory/_long_term_memory/_mem0/_mem0_utils.py
Normal file
@@ -0,0 +1,363 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Utility classes for integrating AgentScope with mem0 library.
|
||||
|
||||
This module provides wrapper classes that allow AgentScope models to be used
|
||||
with the mem0 library for long-term memory functionality.
|
||||
"""
|
||||
import asyncio
|
||||
import atexit
|
||||
import threading
|
||||
from typing import Any, Coroutine, Dict, List, Literal
|
||||
|
||||
from mem0.configs.embeddings.base import BaseEmbedderConfig
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.embeddings.base import EmbeddingBase
|
||||
from mem0.llms.base import LLMBase
|
||||
|
||||
from ....embedding import EmbeddingModelBase
|
||||
from ....model import ChatModelBase, ChatResponse
|
||||
|
||||
|
||||
class _EventLoopManager:
    """Global event loop manager for running async operations in sync context.

    This manager creates and maintains a persistent background event loop
    that runs in a separate daemon thread. This ensures that async model
    clients (like Ollama AsyncClient) remain bound to the same event loop
    across multiple calls, avoiding "Event loop is closed" errors.
    """

    _DEFAULT_TIMEOUT = 5.0  # Default timeout in seconds

    def __init__(self) -> None:
        """Initialize the event loop manager."""
        self.loop: asyncio.AbstractEventLoop | None = None
        self.thread: threading.Thread | None = None
        # Serialises loop creation and teardown.
        self.lock = threading.Lock()
        # Set by the background thread once `self.loop` is assigned.
        self.loop_started = threading.Event()

        # Register cleanup function to be called on program exit
        atexit.register(self.cleanup)

    def get_loop(self) -> asyncio.AbstractEventLoop:
        """Get or create the persistent background event loop.

        Returns:
            `asyncio.AbstractEventLoop`:
                The persistent event loop running in a background thread.

        Raises:
            `RuntimeError`:
                If the event loop fails to start within the timeout.
        """
        with self.lock:
            if self.loop is None or self.loop.is_closed():

                def run_loop() -> None:
                    """Run the event loop in the background thread."""
                    loop = asyncio.new_event_loop()
                    asyncio.set_event_loop(loop)
                    # Publish the loop before signalling readiness.
                    self.loop = loop
                    self.loop_started.set()
                    try:
                        loop.run_forever()
                    finally:
                        # The owning thread closes its loop once it has
                        # stopped, so the loop is released even if
                        # cleanup() gave up waiting for this thread.
                        loop.close()

                # Clear the event before starting the thread
                self.loop_started.clear()

                # Create and start the background thread
                self.thread = threading.Thread(
                    target=run_loop,
                    daemon=True,
                    name="AgentScope-AsyncLoop",
                )
                self.thread.start()

                # Wait for the loop to be ready
                if not self.loop_started.wait(timeout=self._DEFAULT_TIMEOUT):
                    raise RuntimeError(
                        "Timeout waiting for event loop to start",
                    )

            # Raise instead of assert: asserts are stripped under -O.
            if self.loop is None:
                raise RuntimeError("Event loop was not initialized properly")
            return self.loop

    def cleanup(self) -> None:
        """Cleanup the event loop and thread on program exit.

        Safe to call more than once. If the background thread does not
        stop within the timeout, the loop is left for that thread to
        close rather than being closed while it is still running.
        """
        with self.lock:
            loop, thread = self.loop, self.thread
            if loop is not None and not loop.is_closed():
                # Stop the event loop gracefully
                loop.call_soon_threadsafe(loop.stop)

                # Wait for the thread to finish
                if thread is not None and thread.is_alive():
                    thread.join(timeout=self._DEFAULT_TIMEOUT)

                # Closing a loop that is still running raises
                # RuntimeError, so only close once the thread is gone.
                if thread is None or not thread.is_alive():
                    loop.close()

            self.loop = None
            self.thread = None
|
||||
|
||||
|
||||
# Global event loop manager instance: created once at import time and used by
# `_run_async_in_persistent_loop` to obtain the background loop.
_event_loop_manager = _EventLoopManager()
|
||||
|
||||
|
||||
def _run_async_in_persistent_loop(coro: Coroutine) -> Any:
    """Execute a coroutine on the shared background loop and wait for it.

    The coroutine is submitted to the loop owned by the global event loop
    manager, so that all async operations run in the same event loop.
    This matters for async clients like Ollama that bind to a specific
    event loop. The calling thread blocks until the coroutine finishes.

    Args:
        coro (`Coroutine`):
            The coroutine to run.

    Returns:
        `Any`:
            The result of the coroutine.

    Raises:
        `RuntimeError`:
            If the background event loop cannot be started. Exceptions
            raised by the coroutine itself are re-raised to the caller.
    """
    background_loop = _event_loop_manager.get_loop()
    return asyncio.run_coroutine_threadsafe(coro, background_loop).result()
|
||||
|
||||
|
||||
class AgentScopeLLM(LLMBase):
    """Wrapper for the AgentScope LLM.

    This class is a wrapper for the AgentScope LLM. It is used to generate
    responses using the AgentScope LLM in mem0.
    """

    def __init__(self, config: BaseLlmConfig | None = None):
        """Initialize the AgentScopeLLM wrapper.

        Args:
            config (`BaseLlmConfig | None`, optional):
                Configuration object for the LLM. Default is None.

        Raises:
            `ValueError`:
                If the configured `model` is missing or is not a
                `ChatModelBase` instance.
        """
        super().__init__(config)

        if self.config.model is None:
            raise ValueError("`model` parameter is required")

        if not isinstance(self.config.model, ChatModelBase):
            raise ValueError("`model` must be an instance of ChatModelBase")

        self.agentscope_model = self.config.model

    def _parse_response(
        self,
        model_response: ChatResponse,
        has_tool: bool,
    ) -> str | dict:
        """Parse the model response into a string or
        a dict to follow the mem0 library's format.

        Args:
            model_response (`ChatResponse`): The response from the model.
            has_tool (`bool`): Whether tools were offered to the model.

        Returns:
            `str | dict`:
                If has_tool is True, a dict with "content" and
                "tool_calls" keys. Otherwise the text alone. Thinking
                blocks are rendered as "[Thinking: ...]" and placed
                before the text blocks.
        """
        text_parts: list[str] = []
        thinking_parts: list[str] = []
        tool_parts: list[dict] = []

        for block in model_response.content:
            if not isinstance(block, dict):
                continue
            block_type = block.get("type")
            if block_type == "text":
                text_parts.append(str(block.get("text", "")))
            elif block_type == "thinking":
                thinking_parts.append(
                    f"[Thinking: {block.get('thinking', '')}]",
                )
            elif block_type == "tool_use":
                tool_parts.append(
                    {
                        "name": block.get("name"),
                        "arguments": block.get("input", {}),
                    },
                )

        # Joining an empty list already yields "".
        content = "\n".join(thinking_parts + text_parts)
        if has_tool:
            return {"content": content, "tool_calls": tool_parts}
        return content

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        response_format: Any | None = None,
        tools: List[Dict] | None = None,
        tool_choice: str = "auto",
    ) -> str | dict:
        """Generate a response based on the given messages using agentscope.

        Args:
            messages (`List[Dict[str, str]]`):
                List of message dicts containing 'role' and 'content'.
                Only "system", "user" and "assistant" roles are forwarded.
            response_format (`Any | None`, optional):
                Format of the response. Not used in AgentScope.
            tools (`List[Dict] | None`, optional):
                List of tools that the model can call. Forwarded to the
                model; when not None the return value is a dict.
            tool_choice (`str`, optional):
                Tool choice method. Not used in AgentScope.

        Returns:
            `str | dict`:
                The generated response, in the shape produced by
                `_parse_response`.

        Raises:
            `RuntimeError`:
                If no usable message is given or the model call fails.
        """
        # pylint: disable=unused-argument
        try:
            # Convert the messages to AgentScope's format
            agentscope_messages = [
                {"role": message["role"], "content": message["content"]}
                for message in messages
                if message["role"] in ["system", "user", "assistant"]
            ]

            if not agentscope_messages:
                raise ValueError(
                    "No valid messages found in the messages list",
                )

            async def _async_call() -> ChatResponse | None:
                result = await self.agentscope_model(  # type: ignore
                    agentscope_messages,
                    tools=tools,
                )
                # A streaming model hands back an async iterator instead
                # of a response object; drain it and keep the last chunk.
                # NOTE(review): assumes each streamed chunk carries the
                # accumulated content, so the last chunk is the complete
                # reply - confirm against the model implementation.
                if hasattr(result, "__aiter__"):
                    final_chunk = None
                    async for chunk in result:
                        final_chunk = chunk
                    return final_chunk
                return result

            # Run in the persistent event loop
            # This ensures the model client (e.g., Ollama AsyncClient)
            # always runs in the same event loop, avoiding binding issues
            response = _run_async_in_persistent_loop(
                _async_call(),
            )
            has_tool = tools is not None

            # Nothing came back (empty stream or empty content).
            if response is None or not response.content:
                if has_tool:
                    return {
                        "content": "",
                        "tool_calls": [],
                    }
                return ""

            return self._parse_response(response, has_tool)

        except Exception as e:
            raise RuntimeError(
                f"Error generating response using agentscope model: {str(e)}",
            ) from e
|
||||
|
||||
|
||||
class AgentScopeEmbedding(EmbeddingBase):
    """Adapter exposing an AgentScope embedding model to mem0.

    mem0 calls `embed` synchronously; the call is forwarded to the
    wrapped AgentScope embedding model.
    """

    def __init__(self, config: BaseEmbedderConfig | None = None):
        """Initialize the AgentScopeEmbedding wrapper.

        Args:
            config (`BaseEmbedderConfig | None`, optional):
                Configuration object for the embedder. Default is None.

        Raises:
            `ValueError`:
                If the configured `model` is missing or is not an
                `EmbeddingModelBase` instance.
        """
        super().__init__(config)

        configured_model = self.config.model
        if configured_model is None:
            raise ValueError("`model` parameter is required")
        if not isinstance(configured_model, EmbeddingModelBase):
            raise ValueError(
                "`model` must be an instance of EmbeddingModelBase",
            )

        self.agentscope_model = configured_model

    def embed(
        self,
        text: str | List[str],
        memory_action: Literal[  # pylint: disable=unused-argument
            "add",
            "search",
            "update",
        ]
        | None = None,
    ) -> List[float]:
        """Get the embedding for the given text using AgentScope.

        Args:
            text (`str | List[str]`):
                The text to embed. A single string is wrapped in a list
                before being passed to the model.
            memory_action (`Literal["add", "search", "update"] | None`, \
            optional):
                The type of embedding to use. Must be one of "add",
                "search", or "update". Defaults to None. Unused here.

        Returns:
            `List[float]`:
                The first entry of the model response's `embeddings`.

        Raises:
            `RuntimeError`:
                If the model call fails or yields no usable embedding.
        """
        try:
            if isinstance(text, str):
                batch = [text]
            else:
                batch = text

            async def _fetch() -> Any:
                return await self.agentscope_model(batch)

            # Run on the persistent loop so the model client (e.g.
            # Ollama AsyncClient) always sees the same event loop.
            reply = _run_async_in_persistent_loop(_fetch())

            # Only the first entry of the response is handed to mem0.
            first = reply.embeddings[0]
            if first is None:
                raise ValueError("Failed to extract embedding from response")
            return first

        except Exception as err:
            raise RuntimeError(
                f"Error generating embedding using agentscope model: {str(err)}",
            ) from err
|
||||
12
src/agentscope/memory/_long_term_memory/_reme/__init__.py
Normal file
12
src/agentscope/memory/_long_term_memory/_reme/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The reme memory module."""
|
||||
|
||||
from ._reme_personal_long_term_memory import ReMePersonalLongTermMemory
|
||||
from ._reme_task_long_term_memory import ReMeTaskLongTermMemory
|
||||
from ._reme_tool_long_term_memory import ReMeToolLongTermMemory
|
||||
|
||||
__all__ = [
|
||||
"ReMePersonalLongTermMemory",
|
||||
"ReMeTaskLongTermMemory",
|
||||
"ReMeToolLongTermMemory",
|
||||
]
|
||||
@@ -0,0 +1,364 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Base long-term memory implementation using ReMe library.
|
||||
|
||||
This module provides a base class for long-term memory implementations
|
||||
that integrate with the ReMe library. ReMe enables agents to maintain
|
||||
persistent, searchable memories across sessions and contexts.
|
||||
|
||||
The module handles the integration between AgentScope's memory system and
|
||||
the ReMe library, including:
|
||||
- Model configuration and API credential management
|
||||
- Context lifecycle management (async context managers)
|
||||
- Graceful handling of missing dependencies
|
||||
- Error handling with helpful installation instructions
|
||||
|
||||
Key Features:
|
||||
- Supports both DashScope and OpenAI model providers
|
||||
- Automatic extraction of API credentials and endpoints
|
||||
- Flexible configuration via config files or kwargs
|
||||
- Safe fallback behavior when reme_ai is not installed
|
||||
|
||||
Dependencies:
|
||||
The ReMe library is an optional dependency that must be installed:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install reme-ai
|
||||
For more information, visit: https://github.com/modelscope/reMe
|
||||
|
||||
Subclasses:
|
||||
This base class is extended by specific memory type implementations:
|
||||
- ReMeToolLongTermMemory: For tool execution patterns and guidelines
|
||||
- ReMeTaskLongTermMemory: For task execution experiences and learnings
|
||||
- ReMePersonalLongTermMemory: For user preferences and personal information
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from agentscope.model import OpenAIChatModel
|
||||
from agentscope.embedding import OpenAITextEmbedding
|
||||
from agentscope.memory._long_term_memory._reme import ReMeToolLongTermMemory
|
||||
|
||||
# Initialize models
|
||||
model = OpenAIChatModel(model_name="gpt-4", api_key="...")
|
||||
embedding = OpenAITextEmbedding(
|
||||
model_name="text-embedding-3-small", api_key="...")
|
||||
|
||||
# Create memory instance
|
||||
memory = ReMeToolLongTermMemory(
|
||||
agent_name="my_agent",
|
||||
user_name="user_123",
|
||||
model=model,
|
||||
embedding_model=embedding
|
||||
)
|
||||
|
||||
# Use memory in async context
|
||||
async with memory:
|
||||
# Record tool execution
|
||||
await memory.record_to_memory(
|
||||
thinking="This tool worked well for data processing",
|
||||
content=['{"tool_name": "process_data", "success": true, ...}']
|
||||
)
|
||||
|
||||
# Retrieve tool guidelines
|
||||
result = await memory.retrieve_from_memory(
|
||||
keywords=["process_data"]
|
||||
)
|
||||
|
||||
"""
|
||||
from abc import ABCMeta
|
||||
from typing import Any
|
||||
|
||||
from .._long_term_memory_base import LongTermMemoryBase
|
||||
from ....embedding import DashScopeTextEmbedding, OpenAITextEmbedding
|
||||
from ....model import DashScopeChatModel, OpenAIChatModel
|
||||
|
||||
|
||||
class ReMeLongTermMemoryBase(LongTermMemoryBase, metaclass=ABCMeta):
    """Base class for ReMe-based long-term memory implementations.

    This class provides the foundation for integrating AgentScope with the
    ReMe library, enabling agents to maintain and retrieve long-term
    memories across different contexts.

    The ReMe library must be installed separately:
        pip install reme-ai

    If the library is not installed, creating an instance raises an
    ``ImportError`` that contains the installation instructions.
    """

    # DashScope models always talk to this fixed endpoint.
    _DASHSCOPE_API_BASE = "https://dashscope.aliyuncs.com/compatible-mode/v1"

    @staticmethod
    def _client_setting(client: Any, name: str) -> str | None:
        """Return ``str(client.<name>)``, or ``None`` when it is unset.

        Stringifying a missing value directly would hand the literal text
        ``"None"`` to ReMe as a credential or endpoint.
        """
        value = getattr(client, name, None)
        return None if value is None else str(value)

    def __init__(
        self,
        agent_name: str | None = None,
        user_name: str | None = None,
        run_name: str | None = None,
        model: DashScopeChatModel | OpenAIChatModel | None = None,
        embedding_model: (
            DashScopeTextEmbedding | OpenAITextEmbedding | None
        ) = None,
        reme_config_path: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the ReMe-based long-term memory.

        The constructor extracts the API credentials, endpoints and model
        names from the given chat and embedding models and uses them to
        create the ReMe app stored in ``self.app``.

        Args:
            agent_name (`str | None`, optional):
                Name identifier for the agent. Used for organizing
                memories by agent.
            user_name (`str | None`, optional):
                Unique identifier for the user or workspace. This maps
                to workspace_id in ReMe and helps isolate memories across
                different users/workspaces.
            run_name (`str | None`, optional):
                Name identifier for the current execution run or session.
            model (`DashScopeChatModel | OpenAIChatModel | None`, optional):
                The chat model to use for memory operations. The model's
                API credentials and endpoint will be extracted and
                passed to ReMe. Although the default is None, a supported
                model instance is required.
            embedding_model (`DashScopeTextEmbedding | OpenAITextEmbedding | \
                None`, optional):
                The embedding model to use for semantic memory retrieval.
                The model's API credentials and endpoint will be
                extracted and passed to ReMe. Although the default is
                None, a supported embedding model instance is required.
            reme_config_path (`str | None`, optional):
                Path to a custom ReMe configuration file. If not provided,
                None is passed to ReMe as the config path.
            **kwargs (`Any`):
                Additional keyword arguments to pass to the
                ReMeApp constructor.

        Raises:
            `ValueError`:
                If the provided model is not a DashScopeChatModel or
                OpenAIChatModel, or if the embedding_model is not a
                DashScopeTextEmbedding or OpenAITextEmbedding.
            `ImportError`:
                If the reme_ai library is not installed.

        Example:
            .. code-block:: python

                from agentscope.model import OpenAIChatModel
                from agentscope.embedding import OpenAITextEmbedding
                from agentscope.memory._reme import ReMeToolLongTermMemory

                # Initialize models
                model = OpenAIChatModel(
                    model_name="gpt-4",
                    api_key="your-api-key"
                )
                embedding = OpenAITextEmbedding(
                    model_name="text-embedding-3-small",
                    api_key="your-api-key"
                )

                # Create memory instance
                memory = ReMeToolLongTermMemory(
                    agent_name="my_agent",
                    user_name="user_123",
                    run_name="session_001",
                    model=model,
                    embedding_model=embedding
                )

                # Use with async context manager
                async with memory:
                    # Memory operations...
                    pass

        """
        super().__init__()

        # Store agent and workspace identifiers
        self.agent_name = agent_name
        # Maps to ReMe's workspace_id concept
        self.workspace_id = user_name
        self.run_name = run_name

        # Config overrides handed to ReMeApp as positional
        # "dotted.key=value" strings
        config_args = []

        # Extract LLM API credentials based on model type
        # DashScope uses a fixed endpoint, OpenAI can have custom base_url
        if isinstance(model, DashScopeChatModel):
            llm_api_base = self._DASHSCOPE_API_BASE
            llm_api_key = model.api_key
        elif isinstance(model, OpenAIChatModel):
            llm_api_base = self._client_setting(model.client, "base_url")
            llm_api_key = self._client_setting(model.client, "api_key")
        else:
            raise ValueError(
                f"model must be a DashScopeChatModel or "
                f"OpenAIChatModel instance. "
                f"Got {type(model).__name__} instead.",
            )

        llm_model_name = model.model_name
        if llm_model_name:
            config_args.append(f"llm.default.model_name={llm_model_name}")

        # Same extraction for the embedding model
        if isinstance(embedding_model, DashScopeTextEmbedding):
            embedding_api_base = self._DASHSCOPE_API_BASE
            embedding_api_key = embedding_model.api_key
        elif isinstance(embedding_model, OpenAITextEmbedding):
            embedding_api_base = self._client_setting(
                embedding_model.client,
                "base_url",
            )
            embedding_api_key = self._client_setting(
                embedding_model.client,
                "api_key",
            )
        else:
            raise ValueError(
                "embedding_model must be a DashScopeTextEmbedding or "
                "OpenAITextEmbedding instance. "
                f"Got {type(embedding_model).__name__} instead.",
            )

        embedding_model_name = embedding_model.model_name
        if embedding_model_name:
            config_args.append(
                f"embedding_model.default.model_name={embedding_model_name}",
            )

        # Only forward the dimensions when they are known: formatting None
        # into the override would produce `{"dimensions": None}`, which is
        # not valid JSON.
        dimensions = getattr(embedding_model, "dimensions", None)
        if dimensions is not None:
            config_args.append(
                f'embedding_model.default.params={{"dimensions": {dimensions}}}',
            )

        # reme_ai is an optional dependency, so it is imported lazily and a
        # missing installation is reported with instructions.
        try:
            from reme_ai import ReMeApp
        except ImportError as e:
            raise ImportError(
                "The 'reme_ai' library is required for ReMe-based "
                "long-term memory. Please install it by `pip install reme-ai`, "
                "and visit: https://github.com/modelscope/reMe for more "
                "information.",
            ) from e

        # Initialize ReMe with extracted configurations
        self.app = ReMeApp(
            *config_args,  # Config overrides as positional args
            llm_api_key=llm_api_key,
            llm_api_base=llm_api_base,
            embedding_api_key=embedding_api_key,
            embedding_api_base=embedding_api_base,
            # Optional custom config file
            config_path=reme_config_path,
            # Additional ReMe-specific configurations
            **kwargs,
        )

        # Track if the app context is active (started via __aenter__)
        self._app_started = False

    async def __aenter__(self) -> "ReMeLongTermMemoryBase":
        """Async context manager entry point.

        Enters the ReMe app context (when ``self.app`` is not None) and
        marks the memory as started so that memory operations are allowed
        inside the ``async with`` block.

        Returns:
            `ReMeLongTermMemoryBase`:
                The memory instance itself, allowing it to be used in
                the context.

        Example:
            .. code-block:: python

                memory = ReMeToolLongTermMemory(
                    agent_name="my_agent",
                    model=model,
                    embedding_model=embedding
                )

                async with memory:
                    # Memory operations can be performed here
                    await memory.record_to_memory(
                        thinking="Recording tool usage",
                        content=[...]
                    )

        """
        if self.app is not None:
            await self.app.__aenter__()
        # Only reached when the app context was entered without raising
        self._app_started = True
        return self

    async def __aexit__(
        self,
        exc_type: Any = None,
        exc_val: Any = None,
        exc_tb: Any = None,
    ) -> None:
        """Async context manager exit point.

        Exits the ReMe app context (when ``self.app`` is not None) and
        resets the started flag.

        Args:
            exc_type (`Any`):
                The type of exception that occurred, if any. None if no
                exception.
            exc_val (`Any`):
                The exception instance that occurred, if any. None if no
                exception.
            exc_tb (`Any`):
                The traceback object for the exception, if any. None if
                no exception.

        .. note:: ``_app_started`` is always reset to False, even when
         ``self.app`` is None or when the app's own ``__aexit__`` raises.

        Example:
            .. code-block:: python

                async with memory:
                    try:
                        # Memory operations
                        await memory.record_to_memory(...)
                    except Exception as e:
                        # __aexit__ will be called even if an exception occurs
                        print(f"Error: {e}")
                # __aexit__ has been called and resources are cleaned up

        """
        try:
            if self.app is not None:
                await self.app.__aexit__(exc_type, exc_val, exc_tb)
        finally:
            # Reset the flag even if the app cleanup itself fails, so the
            # instance never claims to be started after leaving the context.
            self._app_started = False
|
||||
@@ -0,0 +1,414 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Personal memory implementation using ReMe library.
|
||||
|
||||
This module provides a personal memory implementation that integrates
|
||||
with the ReMe library to provide persistent personal memory storage and
|
||||
retrieval capabilities for AgentScope agents.
|
||||
|
||||
"""
|
||||
from typing import Any
|
||||
|
||||
from ._reme_long_term_memory_base import ReMeLongTermMemoryBase
|
||||
from ...._logging import logger
|
||||
from ....message import Msg, TextBlock
|
||||
from ....tool import ToolResponse
|
||||
|
||||
|
||||
class ReMePersonalLongTermMemory(ReMeLongTermMemoryBase):
    """Personal memory implementation using ReMe library."""

    def _require_started(self) -> None:
        """Raise ``RuntimeError`` unless the ReMe app context is active."""
        if not self._app_started:
            raise RuntimeError(
                "ReMeApp context not started. "
                "Please use 'async with' to initialize the app.",
            )

    @staticmethod
    def _join_text_blocks(blocks: list) -> str:
        """Join the "text"/"thinking" values of dict blocks with newlines.

        For each dict block the "text" value is preferred over "thinking";
        blocks that are not dicts or carry neither key are skipped.
        """
        parts = []
        for block in blocks:
            if not isinstance(block, dict):
                continue
            if "text" in block:
                parts.append(block["text"])
            elif "thinking" in block:
                parts.append(block["thinking"])
        return "\n".join(parts)

    async def record_to_memory(
        self,
        thinking: str,
        content: list[str],
        **kwargs: Any,
    ) -> ToolResponse:
        """Record important user information to long-term memory.

        Record important user information to long-term memory for future
        reference.

        Use this function to save user's personal information,
        preferences, habits, and facts that you may need in future
        conversations. This enables you to provide personalized and
        contextually relevant responses.

        When to record:

        - User shares personal preferences (e.g., "I prefer homestays
          when traveling")
        - User mentions habits or routines (e.g., "I start work at 9 AM")
        - User states likes/dislikes (e.g., "I enjoy drinking green tea")
        - User provides personal facts (e.g., "I work as a software
          engineer")

        What to record: Be specific and structured. Include who, when,
        where, what, why, and how when relevant.

        Args:
            thinking (`str`):
                Your reasoning about why this information is worth
                recording and how it might be useful later.
            content (`list[str]`):
                List of specific facts to remember. Each string should be
                a clear, standalone piece of information. Examples:
                ["User prefers homestays in Hangzhou", "User likes
                visiting West Lake in the morning"].
            **kwargs (`Any`):
                Additional keyword arguments for the recording operation.

        Returns:
            `ToolResponse`:
                Confirmation message indicating successful memory
                recording.
        """
        logger.info(
            "[ReMePersonalMemory] Entering record_to_memory - "
            "thinking: %s, content: %s, kwargs: %s",
            thinking,
            content,
            kwargs,
        )

        self._require_started()

        try:
            # Prepare messages for personal memory recording
            messages = []

            # Add thinking as a user message if provided
            if thinking:
                messages.append({"role": "user", "content": thinking})

            # Add content items as user messages
            for item in content:
                messages.append({"role": "user", "content": item})

            # Close the trajectory with a simple assistant acknowledgment
            messages.append(
                {
                    "role": "assistant",
                    "content": (
                        "I understand and will remember this "
                        "information."
                    ),
                },
            )

            result = await self.app.async_execute(
                name="summary_personal_memory",
                workspace_id=self.workspace_id,
                trajectories=[{"messages": messages}],
                **kwargs,
            )

            # `or` guards against explicit None values so that a successful
            # write is not reported as an error by the handler below.
            metadata = result.get("metadata") or {}
            memory_list = metadata.get("memory_list") or []

            if memory_list:
                summary_text = (
                    f"Successfully recorded {len(memory_list)} "
                    f"memory/memories to personal memory."
                )
            else:
                summary_text = "Memory recording completed."

            return ToolResponse(
                content=[TextBlock(type="text", text=summary_text)],
                metadata={"result": result},
            )

        except Exception as e:
            # Failures are reported back in the tool response rather than
            # raised to the caller.
            logger.exception("Error recording memory: %s", str(e))
            return ToolResponse(
                content=[
                    TextBlock(
                        type="text",
                        text=f"Error recording memory: {str(e)}",
                    ),
                ],
            )

    async def retrieve_from_memory(
        self,
        keywords: list[str],
        limit: int = 5,
        **kwargs: Any,
    ) -> ToolResponse:
        """Search and retrieve relevant information from long-term memory.

        .. note:: You should call this function BEFORE answering
         questions about the user's preferences, past information, or
         personal details. This ensures you provide accurate information
         based on stored memories rather than guessing.

        Use this when:

        - User asks "what do I like?", "what are my preferences?",
          "what do you know about me?"
        - User asks about their past behaviors, habits, or stated
          preferences
        - User refers to information they shared in previous
          conversations
        - You need to personalize responses based on user's history

        Args:
            keywords (`list[str]`):
                Keywords to search for in memory. Be specific and use
                multiple keywords for better results. Examples:
                ["travel preferences", "Hangzhou"], ["work habits",
                "morning routine"], ["food preferences", "tea"].
            limit (`int`, optional):
                The maximum number of memories to retrieve per search, i.e.,
                the number of memories to retrieve for each keyword. Defaults
                to 5.
            **kwargs (`Any`):
                Additional keyword arguments for the retrieval operation.

        Returns:
            `ToolResponse`:
                Retrieved memories matching the keywords. If no memories
                found, you'll receive a message indicating that.
        """
        logger.info(
            "[ReMePersonalMemory] Entering retrieve_from_memory - "
            "keywords: %s, kwargs: %s",
            keywords,
            kwargs,
        )

        self._require_started()

        try:
            results = []

            # One ReMe query per keyword
            for keyword in keywords:
                result = await self.app.async_execute(
                    name="retrieve_personal_memory",
                    workspace_id=self.workspace_id,
                    query=keyword,
                    top_k=limit,
                    **kwargs,
                )

                # Keep only keywords that produced a non-empty answer
                answer = result.get("answer", "")
                if answer:
                    results.append(f"Keyword '{keyword}':\n{answer}")

            if results:
                combined_text = "\n\n".join(results)
            else:
                combined_text = "No memories found for the given keywords."

            return ToolResponse(
                content=[TextBlock(type="text", text=combined_text)],
            )

        except Exception as e:
            logger.exception("Error retrieving memory: %s", str(e))
            return ToolResponse(
                content=[
                    TextBlock(
                        type="text",
                        text=f"Error retrieving memory: {str(e)}",
                    ),
                ],
            )

    async def record(
        self,
        msgs: list[Msg | None],
        **kwargs: Any,
    ) -> None:
        """Record the content to the long-term memory.

        This method converts AgentScope messages to ReMe's format and
        records them using the personal memory flow. Errors raised while
        converting or recording are logged and emitted as warnings instead
        of being propagated.

        Args:
            msgs (`list[Msg | None]`):
                The messages to record to memory. A single Msg is also
                accepted; None entries are ignored.
            **kwargs (`Any`):
                Additional keyword arguments for the ReMe recording.

        Raises:
            `TypeError`:
                If any remaining entry is not a Msg object.
            `RuntimeError`:
                If the ReMe app context has not been started.
        """
        # Nothing to record; also avoids iterating over None below
        if msgs is None:
            return

        if isinstance(msgs, Msg):
            msgs = [msgs]

        # Filter out None
        msg_list = [_ for _ in msgs if _]
        if not msg_list:
            return

        if not all(isinstance(_, Msg) for _ in msg_list):
            raise TypeError(
                "The input messages must be a list of Msg objects.",
            )

        self._require_started()

        try:
            # Convert AgentScope messages to ReMe format
            messages = []
            for msg in msg_list:
                if isinstance(msg.content, str):
                    content_str = msg.content
                elif isinstance(msg.content, list):
                    content_str = self._join_text_blocks(msg.content)
                else:
                    content_str = str(msg.content)

                messages.append({"role": msg.role, "content": content_str})

            await self.app.async_execute(
                name="summary_personal_memory",
                workspace_id=self.workspace_id,
                trajectories=[{"messages": messages}],
                **kwargs,
            )

        except Exception as e:
            # Log the error but don't raise to maintain compatibility
            logger.exception("Error recording messages to memory: %s", str(e))
            import warnings

            warnings.warn(f"Error recording messages to memory: {str(e)}")

    async def retrieve(
        self,
        msg: Msg | list[Msg] | None,
        limit: int = 5,
        **kwargs: Any,
    ) -> str:
        """Retrieve the content from the long-term memory.

        Args:
            msg (`Msg | list[Msg] | None`):
                The message to search for in the memory, which should be
                specific and concise, e.g. the person's name, the date, the
                location, etc. When a list is given, only the last message
                is used as the query.
            limit (`int`, optional):
                The maximum number of memories to retrieve for the query
                (passed to ReMe as top_k). Defaults to 5.
            **kwargs (`Any`):
                Additional keyword arguments.

        Returns:
            `str`:
                The retrieved memory as a string. An empty string is
                returned when msg is None, when no query text can be
                extracted, or when the retrieval fails.

        Raises:
            `TypeError`:
                If msg is neither a Msg nor a list of Msg objects.
            `RuntimeError`:
                If the ReMe app context has not been started.
        """
        if msg is None:
            return ""

        if isinstance(msg, Msg):
            msg = [msg]

        if not isinstance(msg, list) or not all(
            isinstance(_, Msg) for _ in msg
        ):
            raise TypeError(
                "The input message must be a Msg or a list of Msg objects.",
            )

        self._require_started()

        try:
            # Only use the last message's content for retrieval
            last_msg = msg[-1]
            query = ""

            if isinstance(last_msg.content, str):
                query = last_msg.content
            elif isinstance(last_msg.content, list):
                query = self._join_text_blocks(last_msg.content)

            if not query:
                return ""

            result = await self.app.async_execute(
                name="retrieve_personal_memory",
                workspace_id=self.workspace_id,
                query=query,
                top_k=limit,
                **kwargs,
            )

            return result.get("answer", "")

        except Exception as e:
            logger.exception("Error retrieving memory: %s", str(e))
            import warnings

            warnings.warn(f"Error retrieving memory: {str(e)}")
            return ""
|
||||
@@ -0,0 +1,436 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Task memory implementation using ReMe library.
|
||||
|
||||
This module provides a task memory implementation that integrates
|
||||
with the ReMe library to learn from execution trajectories and
|
||||
retrieve relevant task experiences.
|
||||
|
||||
"""
|
||||
from typing import Any
|
||||
|
||||
from ._reme_long_term_memory_base import ReMeLongTermMemoryBase
|
||||
from ...._logging import logger
|
||||
from ....message import Msg, TextBlock
|
||||
from ....tool import ToolResponse
|
||||
|
||||
|
||||
class ReMeTaskLongTermMemory(ReMeLongTermMemoryBase):
|
||||
"""Task memory implementation using ReMe library.
|
||||
|
||||
Task memory learns from execution trajectories and provides
|
||||
retrieval of relevant task experiences.
|
||||
|
||||
"""
|
||||
|
||||
async def record_to_memory(
    self,
    thinking: str,
    content: list[str],
    **kwargs: Any,
) -> ToolResponse:
    """Record task execution experiences and learnings.

    Record task execution experiences and learnings to long-term
    memory.

    Use this function to save valuable task-related knowledge that
    can help with future similar tasks. This enables learning from
    experience and improving over time.

    When to record:

    - After solving technical problems or completing tasks
    - When discovering useful techniques or approaches
    - After implementing solutions with specific steps
    - When learning best practices or important lessons

    What to record: Be detailed and actionable. Include:

    - Task description and context
    - Step-by-step execution details
    - Specific techniques and methods used
    - Results, outcomes, and effectiveness
    - Lessons learned and considerations

    Args:
        thinking (`str`):
            Your reasoning about why this task experience is valuable
            and what makes it worth remembering for future reference.
        content (`list[str]`):
            List of specific task insights to remember. Each string
            should be a clear, actionable piece of information.
            Examples: ["Add indexes on WHERE clause columns to speed
            up queries", "Use EXPLAIN ANALYZE to identify missing
            indexes"].
        **kwargs (`Any`):
            Additional keyword arguments. Can include 'score' (float)
            to indicate the quality/success of this approach
            (default: 1.0).

    Returns:
        `ToolResponse`:
            Confirmation message indicating successful memory
            recording.
    """
    logger.info(
        "[ReMeTaskMemory] Entering record_to_memory - "
        "thinking: %s, content: %s, kwargs: %s",
        thinking,
        content,
        kwargs,
    )

    if not self._app_started:
        raise RuntimeError(
            "ReMeApp context not started. "
            "Please use 'async with' to initialize the app.",
        )

    try:
        # The optional reasoning comes first, followed by every insight;
        # all of them are sent as user turns.
        user_turns = [thinking] if thinking else []
        user_turns.extend(content)
        messages = [{"role": "user", "content": text} for text in user_turns]

        # A single assistant acknowledgment closes the trajectory
        messages.append(
            {"role": "assistant", "content": "Task information recorded."},
        )

        # 'score' belongs to the trajectory, so it is taken out of kwargs
        # before the remaining kwargs are forwarded to ReMe.
        trajectory = {
            "messages": messages,
            "score": kwargs.pop("score", 1.0),
        }

        result = await self.app.async_execute(
            name="summary_task_memory",
            workspace_id=self.workspace_id,
            trajectories=[trajectory],
            **kwargs,
        )

        # The confirmation counts the submitted content items
        confirmation = TextBlock(
            type="text",
            text=f"Successfully recorded {len(content)} task memory/memories.",
        )
        return ToolResponse(
            content=[confirmation],
            metadata={"result": result},
        )

    except Exception as e:
        # Failures are reported back in the tool response, not raised
        logger.exception("Error recording task memory: %s", str(e))
        failure = TextBlock(
            type="text",
            text=f"Error recording task memory: {str(e)}",
        )
        return ToolResponse(content=[failure])
|
||||
|
||||
async def retrieve_from_memory(
    self,
    keywords: list[str],
    limit: int = 5,
    **kwargs: Any,
) -> ToolResponse:
    """Search and retrieve relevant task experiences.

    Search and retrieve relevant task experiences from long-term
    memory.

    IMPORTANT: You should call this function BEFORE attempting to
    solve problems or answer technical questions. This ensures you
    leverage experiences and proven solutions rather than
    starting from scratch.

    Use this when:
    - Asked to solve a technical problem or implement a solution
    - Asked for recommendations, best practices, or approaches
    - Asked "what do you know about...?" or "have you seen this
      before?"
    - Dealing with tasks that may be similar to experiences
    - Need to recall specific techniques or methods

    Benefits of retrieving first:
    - Learn from past successes and mistakes
    - Provide more accurate, battle-tested solutions
    - Avoid reinventing the wheel
    - Give consistent, informed recommendations

    Args:
        keywords (`list[str]`):
            Keywords describing the task or problem domain. Be
            specific and use technical terms. Examples:
            ["database optimization", "slow queries"], ["API design",
            "rate limiting"], ["code refactoring", "Python"].
        limit (`int`, optional):
            The maximum number of memories to retrieve per search, i.e.,
            the number of memories to retrieve for each keyword. Defaults
            to 5.
        **kwargs (`Any`):
            Additional keyword arguments. Can include 'top_k' (int),
            which overrides `limit` as the number of experiences to
            retrieve for each keyword.

    Returns:
        `ToolResponse`:
            Retrieved task experiences and learnings. If no relevant
            experiences found, you'll receive a message indicating
            that.
    """
    logger.info(
        "[ReMeTaskMemory] Entering retrieve_from_memory - "
        "keywords: %s, kwargs: %s",
        keywords,
        kwargs,
    )

    if not self._app_started:
        raise RuntimeError(
            "ReMeApp context not started. "
            "Please use 'async with' to initialize the app.",
        )

    # A caller-supplied 'top_k' must be removed from kwargs before the
    # call below; otherwise it collides with the explicit top_k argument
    # ("got multiple values for keyword argument") and every retrieval
    # fails.
    top_k = kwargs.pop("top_k", limit)

    try:
        results = []

        # One ReMe query per keyword
        for keyword in keywords:
            result = await self.app.async_execute(
                name="retrieve_task_memory",
                workspace_id=self.workspace_id,
                query=keyword,
                top_k=top_k,
                **kwargs,
            )

            # Keep only keywords that produced a non-empty answer
            answer = result.get("answer", "")
            if answer:
                results.append(f"Keyword '{keyword}':\n{answer}")

        if results:
            combined_text = "\n\n".join(results)
        else:
            combined_text = (
                "No task experiences found for the given keywords."
            )

        return ToolResponse(
            content=[
                TextBlock(
                    type="text",
                    text=combined_text,
                ),
            ],
        )

    except Exception as e:
        # Failures are reported back in the tool response, not raised
        logger.exception("Error retrieving task memory: %s", str(e))
        return ToolResponse(
            content=[
                TextBlock(
                    type="text",
                    text=f"Error retrieving task memory: {str(e)}",
                ),
            ],
        )
|
||||
|
||||
async def record(
    self,
    msgs: list[Msg | None],
    **kwargs: Any,
) -> None:
    """Store a sequence of messages as one task execution trajectory.

    Each message is flattened to a ``{"role", "content"}`` dict and the
    resulting list is submitted to ReMe's task-memory summary flow.
    Errors raised during conversion or submission are logged and emitted
    as warnings rather than propagated.

    Args:
        msgs (`list[Msg | None]`):
            The messages to record to memory. A single Msg is also
            accepted; falsy entries are dropped.
        **kwargs (`Any`):
            Additional keyword arguments for the recording.
            Can include 'score' (float) for trajectory scoring
            (default: 1.0).

    Raises:
        `TypeError`:
            If any remaining entry is not a Msg object.
        `RuntimeError`:
            If the ReMe app context has not been started.
    """
    candidates = [msgs] if isinstance(msgs, Msg) else msgs

    # Drop None (and other falsy) entries
    valid_msgs = [m for m in candidates if m]
    if not valid_msgs:
        return

    for entry in valid_msgs:
        if not isinstance(entry, Msg):
            raise TypeError(
                "The input messages must be a list of Msg objects.",
            )

    if not self._app_started:
        raise RuntimeError(
            "ReMeApp context not started. "
            "Please use 'async with' to initialize the app.",
        )

    try:
        messages = []
        for entry in valid_msgs:
            body = entry.content
            if isinstance(body, str):
                flattened = body
            elif isinstance(body, list):
                # Keep the "text" value of a dict block, otherwise its
                # "thinking" value; ignore every other block.
                flattened = "\n".join(
                    blk["text"] if "text" in blk else blk["thinking"]
                    for blk in body
                    if isinstance(blk, dict)
                    and ("text" in blk or "thinking" in blk)
                )
            else:
                flattened = str(body)

            messages.append({"role": entry.role, "content": flattened})

        # 'score' belongs to the trajectory, not to the forwarded kwargs
        trajectory = {
            "messages": messages,
            "score": kwargs.pop("score", 1.0),
        }

        await self.app.async_execute(
            name="summary_task_memory",
            workspace_id=self.workspace_id,
            trajectories=[trajectory],
            **kwargs,
        )

    except Exception as e:
        # Best effort: report the failure without raising to the caller
        logger.exception(
            "Error recording messages to task memory: %s",
            str(e),
        )
        import warnings

        warnings.warn(
            f"Error recording messages to task memory: {str(e)}",
        )
|
||||
|
||||
async def retrieve(
    self,
    msg: Msg | list[Msg] | None,
    limit: int = 5,
    **kwargs: Any,
) -> str:
    """Retrieve relevant task experiences from memory.

    Only the last message is used to build the search query: a string
    content is used as is, while the "text" and "thinking" entries of a
    list of content blocks are joined with newlines.

    Args:
        msg (`Msg | list[Msg] | None`):
            The message(s) to search relevant task experiences for. If a
            list is given, only its last message is used as the query.
        limit (`int`, optional):
            The maximum number of memories to retrieve for the query
            (forwarded as `top_k`). Defaults to 5.
        **kwargs (`Any`):
            Additional keyword arguments forwarded to the retrieval
            operation.

    Raises:
        `TypeError`:
            If `msg` is neither a `Msg` nor a list of `Msg` objects.
        `RuntimeError`:
            If the ReMe app context has not been started.

    Returns:
        `str`:
            The retrieved task experiences as a string. An empty string
            is returned when there is nothing to query with or when the
            retrieval fails.
    """
    if msg is None:
        return ""

    if isinstance(msg, Msg):
        msg = [msg]

    if not isinstance(msg, list) or not all(
        isinstance(_, Msg) for _ in msg
    ):
        raise TypeError(
            "The input message must be a Msg or a list of Msg objects.",
        )

    if not self._app_started:
        raise RuntimeError(
            "ReMeApp context not started. "
            "Please use 'async with' to initialize the app.",
        )

    # An empty list has no last message; return quietly instead of
    # letting `msg[-1]` raise an IndexError that ends up as a warning
    if not msg:
        return ""

    try:
        # Only use the last message's content for retrieval
        last_msg = msg[-1]
        query = ""

        if isinstance(last_msg.content, str):
            query = last_msg.content
        elif isinstance(last_msg.content, list):
            # Extract text from content blocks
            content_parts = []
            for block in last_msg.content:
                if isinstance(block, dict) and "text" in block:
                    content_parts.append(block["text"])
                elif isinstance(block, dict) and "thinking" in block:
                    content_parts.append(block["thinking"])
            query = "\n".join(content_parts)

        if not query:
            return ""

        # Retrieve using the query from the last message
        result = await self.app.async_execute(
            name="retrieve_task_memory",
            workspace_id=self.workspace_id,
            query=query,
            top_k=limit,
            **kwargs,
        )

        return result.get("answer", "")

    except Exception as e:
        # Best effort: log and warn, but never break the caller
        logger.exception("Error retrieving task memory: %s", str(e))
        import warnings

        warnings.warn(f"Error retrieving task memory: {str(e)}")
        return ""
|
||||
@@ -0,0 +1,545 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Tool memory implementation using ReMe library.
|
||||
|
||||
This module provides a tool memory implementation that integrates
|
||||
with the ReMe library to record tool execution results and retrieve
|
||||
tool usage guidelines.
|
||||
|
||||
"""
|
||||
from typing import Any
|
||||
|
||||
from ._reme_long_term_memory_base import ReMeLongTermMemoryBase
|
||||
from ...._logging import logger
|
||||
from ....message import Msg, TextBlock
|
||||
from ....tool import ToolResponse
|
||||
|
||||
|
||||
class ReMeToolLongTermMemory(ReMeLongTermMemoryBase):
    """Tool memory implementation using ReMe library.

    Tool memory records tool execution results and generates usage
    guidelines from the execution history.

    """

    async def record_to_memory(
        self,
        thinking: str,
        content: list[str],
        **kwargs: Any,
    ) -> ToolResponse:
        """Record tool execution results to build tool usage patterns.

        Record tool execution results to build a knowledge base of tool
        usage patterns.

        Use this function after successfully using tools to capture
        execution details, results, and performance metrics. Over time,
        this builds comprehensive usage guidelines and best practices
        for each tool.

        When to record:

        - After successfully executing any tool
        - After tool failures (to learn what doesn't work)
        - When discovering effective parameter combinations
        - After noteworthy tool usage patterns

        What to record: Each tool execution should include complete
        execution details.

        Args:
            thinking (`str`):
                Your reasoning about why this tool execution is worth
                recording. Mention what worked well, what could be
                improved, or lessons learned.
            content (`list[str]`):
                List of JSON strings, each representing a tool execution.
                Each JSON must have these fields:
                - create_time: Timestamp in format "YYYY-MM-DD HH:MM:SS"
                - tool_name: Name of the tool executed
                - input: Input parameters as a dict
                - output: Tool's output as a string
                - token_cost: Token cost (integer)
                - success: Whether execution succeeded (boolean)
                - time_cost: Execution time in seconds (float)

                Example: '{"create_time": "2024-01-01 10:00:00",
                "tool_name": "search", "input": {"query": "Python"},
                "output": "Found 10 results", "token_cost": 100,
                "success": true, "time_cost": 1.2}'
            **kwargs (`Any`):
                Additional keyword arguments for the recording operation.

        Returns:
            `ToolResponse`:
                Confirmation message with number of executions recorded
                and guidelines generated.
        """
        logger.info(
            "[ReMeToolMemory] Entering record_to_memory - "
            "thinking: %s, content: %s, kwargs: %s",
            thinking,
            content,
            kwargs,
        )

        if not self._app_started:
            raise RuntimeError(
                "ReMeApp context not started. "
                "Please use 'async with' to initialize the app.",
            )

        try:
            # A single JSON string would otherwise be iterated character
            # by character; treat it as a one-item list instead
            if isinstance(content, str):
                content = [content]

            # Use the shared parser so this tool function and `record`
            # accept and reject exactly the same payloads
            (
                tool_call_results,
                tool_names_set,
            ) = self._parse_tool_call_results(content)

            if not tool_call_results:
                return ToolResponse(
                    content=[
                        TextBlock(
                            type="text",
                            text="No valid tool call results to record.",
                        ),
                    ],
                )

            await self._add_and_summarize(
                tool_call_results,
                tool_names_set,
                **kwargs,
            )

            num_results = len(tool_call_results)
            summary_text = (
                f"Successfully recorded {num_results} tool execution "
                f"result{'s' if num_results > 1 else ''} and generated "
                f"usage guidelines."
            )

            return ToolResponse(
                content=[
                    TextBlock(
                        type="text",
                        text=summary_text,
                    ),
                ],
            )

        except Exception as e:
            # Report the failure to the agent instead of raising, so a
            # memory problem does not abort the tool call
            logger.exception("Error recording tool memory: %s", str(e))
            return ToolResponse(
                content=[
                    TextBlock(
                        type="text",
                        text=f"Error recording tool memory: {str(e)}",
                    ),
                ],
            )

    async def retrieve_from_memory(
        self,
        keywords: list[str],
        limit: int = 5,
        **kwargs: Any,
    ) -> ToolResponse:
        """Retrieve usage guidelines and best practices for tools.

        Retrieve usage guidelines and best practices for specific tools.

        .. note:: You should call this function BEFORE using a tool,
         especially if you're uncertain about its proper usage or want to
         follow established best practices. This retrieves synthesized
         guidelines based on past tool executions.

        Use this when:

        - About to use a tool and want to know the best practices
        - Uncertain about tool parameters or usage patterns
        - Want to learn from past successful/failed tool executions
        - User asks "how should I use this tool?" or "what's the best
          way to..."
        - Need to understand tool performance characteristics or
          limitations

        Benefits of retrieving first:

        - Learn from accumulated tool usage experience
        - Avoid common mistakes and pitfalls
        - Use optimal parameter combinations
        - Understand tool performance and cost characteristics
        - Follow established best practices

        Args:
            keywords (`list[str]`):
                List of tool names to retrieve guidelines for. Use the
                exact tool names. Examples: ["search"],
                ["database_query", "cache_get"], ["api_call"].
            limit (`int`, optional):
                The maximum number of memories to retrieve per search, i.e.,
                the number of memories to retrieve for each keyword. Defaults
                to 5.
            **kwargs (`Any`):
                Additional keyword arguments for the retrieval operation.

        Returns:
            `ToolResponse`:
                Retrieved usage guidelines and best practices for the
                specified tools. If no guidelines exist yet, you'll
                receive a message indicating that.
        """
        logger.info(
            "[ReMeToolMemory] Entering retrieve_from_memory - "
            "keywords: %s, kwargs: %s",
            keywords,
            kwargs,
        )

        if not self._app_started:
            raise RuntimeError(
                "ReMeApp context not started. "
                "Please use 'async with' to initialize the app.",
            )

        try:
            # Join all tool names with comma
            tool_names = ",".join(keywords)

            # Retrieve tool guidelines for all tools at once
            result = await self.app.async_execute(
                name="retrieve_tool_memory",
                workspace_id=self.workspace_id,
                tool_names=tool_names,
                top_k=limit,
                **kwargs,
            )

            # Extract the answer from the result
            answer = result.get("answer", "")
            if answer:
                combined_text = answer
            else:
                combined_text = f"No tool guidelines found for: {tool_names}"

            return ToolResponse(
                content=[
                    TextBlock(
                        type="text",
                        text=combined_text,
                    ),
                ],
            )

        except Exception as e:
            # Report the failure to the agent instead of raising
            logger.exception("Error retrieving tool memory: %s", str(e))
            return ToolResponse(
                content=[
                    TextBlock(
                        type="text",
                        text=f"Error retrieving tool memory: {str(e)}",
                    ),
                ],
            )

    async def _add_and_summarize(
        self,
        tool_call_results: list[dict],
        tool_names_set: set[str],
        **kwargs: Any,
    ) -> None:
        """Store parsed tool call results, then refresh the tool summaries.

        Args:
            tool_call_results (`list[dict]`):
                The parsed tool call results to add to the workspace.
            tool_names_set (`set[str]`):
                Names of the tools whose memory should be summarized
                afterwards. No summary is requested when it is empty.
            **kwargs (`Any`):
                Additional keyword arguments forwarded to both operations.
        """
        # First, add the tool call results
        await self.app.async_execute(
            name="add_tool_call_result",
            workspace_id=self.workspace_id,
            tool_call_results=tool_call_results,
            **kwargs,
        )

        # Then, summarize the tool memory for the affected tools
        if tool_names_set:
            await self.app.async_execute(
                name="summary_tool_memory",
                workspace_id=self.workspace_id,
                tool_names=list(tool_names_set),
                **kwargs,
            )

    def _extract_content_from_messages(self, msg_list: list[Msg]) -> list[str]:
        """Extract content strings from messages.

        String contents are taken as is; for list contents, the strings
        returned by `_extract_text_from_blocks` are collected. Contents of
        any other type are ignored.

        Args:
            msg_list (`list[Msg]`):
                List of messages to extract content from.

        Returns:
            `list[str]`:
                List of extracted content strings.
        """
        content_list = []
        for msg in msg_list:
            if isinstance(msg.content, str):
                content_list.append(msg.content)
            elif isinstance(msg.content, list):
                content_list.extend(
                    self._extract_text_from_blocks(msg.content),
                )
        return content_list

    def _extract_text_from_blocks(self, blocks: list) -> list[str]:
        """Extract text from content blocks.

        Args:
            blocks (`list`):
                List of content blocks. Dict blocks whose "type" is "text"
                contribute their "text" value (an empty string if the key
                is missing); plain strings are kept unchanged; everything
                else is ignored.

        Returns:
            `list[str]`:
                List of extracted text strings.
        """
        texts = []
        for block in blocks:
            if isinstance(block, dict) and block.get("type") == "text":
                texts.append(block.get("text", ""))
            elif isinstance(block, str):
                texts.append(block)
        return texts

    def _parse_tool_call_results(
        self,
        content_list: list[str],
    ) -> tuple[list[dict], set[str]]:
        """Parse JSON content strings into tool call results.

        Items that cannot be parsed, or that do not decode to a JSON
        object, are skipped with a warning so that a single malformed item
        does not discard the whole batch.

        Args:
            content_list (`list[str]`):
                List of JSON strings to parse.

        Returns:
            `tuple[list[dict], set[str]]`:
                Tuple of (tool_call_results, tool_names_set).
        """
        import json
        import warnings

        tool_call_results: list[dict] = []
        tool_names_set: set[str] = set()

        for item in content_list:
            try:
                tool_call_result = json.loads(item)
            except (json.JSONDecodeError, TypeError) as e:
                # TypeError covers items that are not str/bytes at all
                warnings.warn(
                    f"Failed to parse tool call result JSON: {item}. "
                    f"Error: {str(e)}",
                )
                continue

            # json.loads also yields numbers, strings, lists, ...; only a
            # JSON object can describe a tool execution
            if not isinstance(tool_call_result, dict):
                warnings.warn(
                    "Skipping tool call result that is not a JSON "
                    f"object: {item}",
                )
                continue

            tool_call_results.append(tool_call_result)

            # Track tool names for the summary step; non-string names
            # (e.g. null or a list) cannot identify a tool
            tool_name = tool_call_result.get("tool_name")
            if isinstance(tool_name, str):
                tool_names_set.add(tool_name)

        return tool_call_results, tool_names_set

    async def record(
        self,
        msgs: list[Msg | None],
        **kwargs: Any,
    ) -> None:
        """Record the content to the tool memory.

        This method extracts content from messages and treats them as
        JSON strings representing tool_call_results, similar to
        record_to_memory.

        Args:
            msgs (`list[Msg | None]`):
                The messages to record to memory. Each message's content
                should be a JSON string or list of JSON strings
                representing tool_call_results. A single `Msg` is also
                accepted, and `None` entries are ignored.
            **kwargs (`Any`):
                Additional keyword arguments for the recording.

        Raises:
            `TypeError`:
                If any remaining entry is not a `Msg` object.
            `RuntimeError`:
                If the ReMe app context has not been started.
        """
        if isinstance(msgs, Msg):
            msgs = [msgs]

        # Filter out None
        msg_list = [_ for _ in msgs if _]
        if not msg_list:
            return

        if not all(isinstance(_, Msg) for _ in msg_list):
            raise TypeError(
                "The input messages must be a list of Msg objects.",
            )

        if not self._app_started:
            raise RuntimeError(
                "ReMeApp context not started. "
                "Please use 'async with' to initialize the app.",
            )

        try:
            # Extract content from messages and parse as tool_call_results
            content_list = self._extract_content_from_messages(msg_list)
            if not content_list:
                return

            # Parse each content item as a tool_call_result
            (
                tool_call_results,
                tool_names_set,
            ) = self._parse_tool_call_results(content_list)
            if not tool_call_results:
                return

            await self._add_and_summarize(
                tool_call_results,
                tool_names_set,
                **kwargs,
            )

        except Exception as e:
            # Log the error but don't raise to maintain compatibility
            logger.exception(
                "Error recording tool messages to memory: %s",
                str(e),
            )
            import warnings

            warnings.warn(
                f"Error recording tool messages to memory: {str(e)}",
            )

    def _extract_tool_names_from_message(self, msg: Msg) -> str:
        """Extract tool names from a message.

        Args:
            msg (`Msg`):
                Message to extract tool names from. A string content is
                returned unchanged; for a list content, the "text" values
                of its dict blocks are joined with single spaces.

        Returns:
            `str`:
                Extracted tool names as a string, or an empty string for
                any other content type.
        """
        if isinstance(msg.content, str):
            return msg.content

        if isinstance(msg.content, list):
            content_parts = []
            for block in msg.content:
                if isinstance(block, dict) and "text" in block:
                    content_parts.append(block["text"])
            return " ".join(content_parts)

        return ""

    def _format_retrieve_result(self, result: Any) -> str:
        """Format the retrieve result into a string.

        Args:
            result (`Any`):
                Result from the retrieve operation.

        Returns:
            `str`:
                The "answer" entry when the result is a dict containing
                one, the result itself when it is a string, and
                `str(result)` otherwise.
        """
        if isinstance(result, dict) and "answer" in result:
            return result["answer"]
        if isinstance(result, str):
            return result
        return str(result)

    async def retrieve(
        self,
        msg: Msg | list[Msg] | None,
        limit: int = 5,
        **kwargs: Any,
    ) -> str:
        """Retrieve tool guidelines from memory.

        Retrieve tool guidelines from memory based on message content.
        Only the last message is used to determine the tool names.

        Args:
            msg (`Msg | list[Msg] | None`):
                The message containing tool names or queries to
                retrieve guidelines for. If a list is given, only its
                last message is used.
            limit (`int`, optional):
                The maximum number of memories to retrieve for the query
                (forwarded as `top_k`). Defaults to 5.
            **kwargs (`Any`):
                Additional keyword arguments.

        Raises:
            `TypeError`:
                If `msg` is neither a `Msg` nor a list of `Msg` objects.
            `RuntimeError`:
                If the ReMe app context has not been started.

        Returns:
            `str`:
                The retrieved tool guidelines as a string. An empty string
                is returned when there is nothing to query with or when
                the retrieval fails.
        """
        if msg is None:
            return ""

        if isinstance(msg, Msg):
            msg = [msg]

        if not isinstance(msg, list) or not all(
            isinstance(_, Msg) for _ in msg
        ):
            raise TypeError(
                "The input message must be a Msg or a list of Msg objects.",
            )

        if not self._app_started:
            raise RuntimeError(
                "ReMeApp context not started. "
                "Please use 'async with' to initialize the app.",
            )

        # An empty list has no last message; return quietly instead of
        # letting `msg[-1]` raise an IndexError that ends up as a warning
        if not msg:
            return ""

        try:
            # Extract tool names from the last message
            last_msg = msg[-1]
            tool_names = self._extract_tool_names_from_message(last_msg)

            if not tool_names:
                return ""

            # Retrieve tool guidelines
            result = await self.app.async_execute(
                name="retrieve_tool_memory",
                workspace_id=self.workspace_id,
                tool_names=tool_names,
                top_k=limit,
                **kwargs,
            )

            return self._format_retrieve_result(result)

        except Exception as e:
            # Best effort: log and warn, but never break the caller
            logger.exception("Error retrieving tool guidelines: %s", str(e))
            import warnings

            warnings.warn(f"Error retrieving tool guidelines: {str(e)}")
            return ""
|
||||
16
src/agentscope/memory/_working_memory/__init__.py
Normal file
16
src/agentscope/memory/_working_memory/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The working memory module in AgentScope, which provides various memory
|
||||
storage implementations. In AgentScope, such module is responsible for
|
||||
storing and managing the short-term memory with specific marks."""
|
||||
|
||||
from ._base import MemoryBase
|
||||
from ._in_memory_memory import InMemoryMemory
|
||||
from ._redis_memory import RedisMemory
|
||||
from ._sqlalchemy_memory import AsyncSQLAlchemyMemory
|
||||
|
||||
__all__ = [
|
||||
"MemoryBase",
|
||||
"InMemoryMemory",
|
||||
"RedisMemory",
|
||||
"AsyncSQLAlchemyMemory",
|
||||
]
|
||||
168
src/agentscope/memory/_working_memory/_base.py
Normal file
168
src/agentscope/memory/_working_memory/_base.py
Normal file
@@ -0,0 +1,168 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The memory base class."""
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from ...message import Msg
|
||||
from ...module import StateModule
|
||||
|
||||
|
||||
class MemoryBase(StateModule):
    """Common interface of the working memory storages in agentscope."""

    def __init__(self) -> None:
        """Set up the base state shared by all memory implementations."""
        super().__init__()

        # The compressed summary is kept as registered state alongside
        # whatever the subclass registers
        self._compressed_summary: str = ""
        self.register_state("_compressed_summary")

    async def update_compressed_summary(self, summary: str) -> None:
        """Replace the stored compressed summary.

        Args:
            summary (`str`):
                The summary text that supersedes the current one.
        """
        self._compressed_summary = summary

    @abstractmethod
    async def add(
        self,
        memories: Msg | list[Msg] | None,
        marks: str | list[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """Store message(s) in the memory, optionally tagged with mark(s).

        Args:
            memories (`Msg | list[Msg] | None`):
                The message(s) to store.
            marks (`str | list[str] | None`, optional):
                The mark(s) to attach to the message(s). With `None`, the
                message(s) are stored without any mark.
        """

    @abstractmethod
    async def delete(
        self,
        msg_ids: list[str],
        **kwargs: Any,
    ) -> int:
        """Delete the message(s) with the given IDs from the storage.

        Args:
            msg_ids (`list[str]`):
                IDs of the messages to delete.

        Returns:
            `int`:
                How many messages were deleted.
        """

    async def delete_by_mark(
        self,
        mark: str | list[str],
        *args: Any,
        **kwargs: Any,
    ) -> int:
        """Delete the messages carrying the given mark(s).

        The base implementation always raises `NotImplementedError`;
        subclasses that support marks override it.

        Args:
            mark (`str | list[str]`):
                The mark(s) whose messages should be deleted.

        Raises:
            `TypeError`:
                If the provided mark is not a string or a list of strings.

        Returns:
            `int`:
                How many messages were deleted.
        """
        owner = self.__class__.__name__
        raise NotImplementedError(
            f"The delete_by_mark method is not implemented in {owner} class.",
        )

    @abstractmethod
    async def size(self) -> int:
        """Report how many messages the storage holds.

        Returns:
            `int`:
                The number of stored messages.
        """

    @abstractmethod
    async def clear(self) -> None:
        """Remove the whole memory content."""

    @abstractmethod
    async def get_memory(
        self,
        mark: str | None = None,
        exclude_mark: str | None = None,
        prepend_summary: bool = True,
        **kwargs: Any,
    ) -> list[Msg]:
        """Fetch the stored messages, optionally filtered by mark.

        .. note:: When both `mark` and `exclude_mark` are given, the
         messages are filtered by both of them.

        .. note:: `mark` and `exclude_mark` should not overlap.

        Args:
            mark (`str | None`, optional):
                Only messages with this mark are returned. With `None`,
                all messages are returned.
            exclude_mark (`str | None`, optional):
                Messages with this mark are left out of the result.
            prepend_summary (`bool`, defaults to True):
                Whether to prepend the compressed summary as a message.

        Returns:
            `list[Msg]`:
                The messages fetched from the storage.
        """

    async def update_messages_mark(
        self,
        new_mark: str | None,
        old_mark: str | None = None,
        msg_ids: list[str] | None = None,
    ) -> int:
        """Add, remove, or change marks of stored messages in one call.

        The base implementation always raises `NotImplementedError`;
        subclasses that support marks override it with this contract:

        - If `msg_ids` is given, only the messages with these IDs are
          updated.
        - If `old_mark` is given, only the messages carrying it are
          updated. Otherwise `new_mark` is added to all messages (or to
          those selected by `msg_ids`).
        - If `new_mark` is `None`, the mark is removed from the messages.

        Args:
            new_mark (`str | None`, optional):
                The mark to set on the messages. With `None`, the mark is
                removed instead.
            old_mark (`str | None`, optional):
                The mark used to select messages. With `None`, this
                constraint is ignored.
            msg_ids (`list[str] | None`, optional):
                IDs of the messages to update. With `None`, this
                constraint is ignored.

        Returns:
            `int`:
                How many messages were updated.
        """
        owner = self.__class__.__name__
        raise NotImplementedError(
            "The update_messages_mark method is not implemented in "
            f"{owner} class.",
        )
|
||||
305
src/agentscope/memory/_working_memory/_in_memory_memory.py
Normal file
305
src/agentscope/memory/_working_memory/_in_memory_memory.py
Normal file
@@ -0,0 +1,305 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The in-memory storage module for memory storage."""
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
from ...message import Msg
|
||||
from ._base import MemoryBase
|
||||
|
||||
|
||||
class InMemoryMemory(MemoryBase):
|
||||
"""The in-memory implementation of memory storage."""
|
||||
|
||||
def __init__(self) -> None:
    """Create an empty in-memory storage."""
    super().__init__()

    # Each entry pairs a message with the list of marks attached to it
    entries: list[tuple[Msg, list[str]]] = []
    self.content = entries

    # Make the entries part of the registered (serializable) state
    self.register_state("content")
|
||||
|
||||
async def get_memory(
    self,
    mark: str | None = None,
    exclude_mark: str | None = None,
    prepend_summary: bool = True,
    **kwargs: Any,
) -> list[Msg]:
    """Fetch the stored messages, optionally filtered by mark.

    .. note:: When both `mark` and `exclude_mark` are given, a message
     must carry `mark` and must not carry `exclude_mark` to be returned.

    .. note:: `mark` and `exclude_mark` should not overlap.

    Args:
        mark (`str | None`, optional):
            Only messages with this mark are returned. With `None`, all
            messages are candidates.
        exclude_mark (`str | None`, optional):
            Messages with this mark are left out of the result.
        prepend_summary (`bool`, defaults to True):
            Whether to put the compressed summary, wrapped in a user
            message, in front of the result. Nothing is prepended when
            the summary is empty.

    Raises:
        `TypeError`:
            If `mark` or `exclude_mark` is neither a string nor None.

    Returns:
        `list[Msg]`:
            The messages fetched from the storage.
    """
    # Validate the two filters before touching the storage
    if mark is not None and not isinstance(mark, str):
        raise TypeError(
            f"The mark should be a string or None, but got {type(mark)}.",
        )

    if exclude_mark is not None and not isinstance(exclude_mark, str):
        raise TypeError(
            f"The exclude_mark should be a string or None, but got "
            f"{type(exclude_mark)}.",
        )

    # Apply the include and exclude filters in a single pass
    selected = [
        stored_msg
        for stored_msg, stored_marks in self.content
        if (mark is None or mark in stored_marks)
        and (exclude_mark is None or exclude_mark not in stored_marks)
    ]

    if not (prepend_summary and self._compressed_summary):
        return selected

    summary_msg = Msg(
        "user",
        self._compressed_summary,
        "user",
    )
    return [summary_msg, *selected]
|
||||
|
||||
async def add(
    self,
    memories: Msg | list[Msg] | None,
    marks: str | list[str] | None = None,
    allow_duplicates: bool = False,
    **kwargs: Any,
) -> None:
    """Add message(s) into the memory storage with the given mark
    (if provided).

    Args:
        memories (`Msg | list[Msg] | None`):
            The message(s) to be added. `None` is ignored.
        marks (`str | list[str] | None`, optional):
            The mark(s) to associate with the message(s). If `None`, no
            mark is associated.
        allow_duplicates (`bool`, defaults to `False`):
            Whether to allow duplicate messages in the storage. When
            `False`, a message is skipped if its ID is already stored or
            already appeared earlier in the same call.

    Raises:
        `TypeError`:
            If `marks` is not a string, a list of strings, or None.
    """
    if memories is None:
        return

    if isinstance(memories, Msg):
        memories = [memories]

    if marks is None:
        marks = []
    elif isinstance(marks, str):
        marks = [marks]
    elif not isinstance(marks, list) or not all(
        isinstance(m, str) for m in marks
    ):
        raise TypeError(
            f"The mark should be a string, a list of strings, or None, "
            f"but got {type(marks)}.",
        )

    if allow_duplicates:
        seen_ids = None
    else:
        seen_ids = {msg.id for msg, _ in self.content}

    for msg in memories:
        if seen_ids is not None:
            if msg.id in seen_ids:
                continue
            # Remember the ID right away so that a repeat later in the
            # same batch is filtered out as well
            seen_ids.add(msg.id)

        # Store copies so later changes made by the caller to the message
        # or the marks list do not leak into the storage
        self.content.append((deepcopy(msg), deepcopy(marks)))
|
||||
|
||||
async def delete(
    self,
    msg_ids: list[str],
    **kwargs: Any,
) -> int:
    """Remove message(s) from the storage by their IDs.

    Args:
        msg_ids (`list[str]`):
            The list of message IDs to be removed. IDs that are not in
            the storage are ignored.

    Returns:
        `int`:
            The number of messages removed.
    """
    # Build the lookup set once so each stored message costs a single
    # hash lookup instead of a scan over the whole ID list
    ids_to_remove = set(msg_ids)

    initial_size = len(self.content)
    self.content = [
        (msg, marks)
        for msg, marks in self.content
        if msg.id not in ids_to_remove
    ]
    return initial_size - len(self.content)
|
||||
|
||||
async def delete_by_mark(
    self,
    mark: str | list[str],
    **kwargs: Any,
) -> int:
    """Delete every message that carries at least one of the given marks.

    Args:
        mark (`str | list[str]`):
            The mark(s) whose messages should be deleted.

    Raises:
        `TypeError`:
            If the provided mark is not a string or a list of strings.

    Returns:
        `int`:
            How many messages were deleted.
    """
    # Treat a single mark as a one-element list
    targets = [mark] if isinstance(mark, str) else mark

    if isinstance(targets, list):
        if any(not isinstance(m, str) for m in targets):
            raise TypeError(
                f"The mark should be a string or a list of strings, "
                f"but got {type(targets)} with elements of types "
                f"{[type(m) for m in targets]}.",
            )

    size_before = len(self.content)

    # Narrow the entries down mark by mark
    for target in targets:
        kept = []
        for entry in self.content:
            if target not in entry[1]:
                kept.append(entry)
        self.content = kept

    return size_before - len(self.content)
|
||||
|
||||
async def clear(self) -> None:
    """Drop every stored message, emptying the storage list in place."""
    del self.content[:]
|
||||
|
||||
async def size(self) -> int:
    """Report how many messages the storage currently holds.

    Returns:
        `int`:
            The number of stored messages.
    """
    stored_entries = self.content
    return len(stored_entries)
|
||||
|
||||
async def update_messages_mark(
|
||||
self,
|
||||
new_mark: str | None,
|
||||
old_mark: str | None = None,
|
||||
msg_ids: list[str] | None = None,
|
||||
) -> int:
|
||||
"""A unified method to update marks of messages in the storage (add,
|
||||
remove, or change marks).
|
||||
|
||||
- If `msg_ids` is provided, the update will be applied to the messages
|
||||
with the specified IDs.
|
||||
- If `old_mark` is provided, the update will be applied to the
|
||||
messages with the specified old mark. Otherwise, the `new_mark` will
|
||||
be added to all messages (or those filtered by `msg_ids`).
|
||||
- If `new_mark` is `None`, the mark will be removed from the messages.
|
||||
|
||||
Args:
|
||||
new_mark (`str | None`, optional):
|
||||
The new mark to set for the messages. If `None`, the mark
|
||||
will be removed.
|
||||
old_mark (`str | None`, optional):
|
||||
The old mark to filter messages. If `None`, this constraint
|
||||
is ignored.
|
||||
msg_ids (`list[str] | None`, optional):
|
||||
The list of message IDs to be updated. If `None`, this
|
||||
constraint is ignored.
|
||||
|
||||
Returns:
|
||||
`int`:
|
||||
The number of messages updated.
|
||||
"""
|
||||
updated_count = 0
|
||||
|
||||
for idx, (msg, marks) in enumerate(self.content):
|
||||
# If msg_ids is provided, skip messages not in the list
|
||||
if msg_ids is not None and msg.id not in msg_ids:
|
||||
continue
|
||||
|
||||
# If old_mark is provided, skip messages that do not have the old
|
||||
# mark
|
||||
if old_mark is not None and old_mark not in marks:
|
||||
continue
|
||||
|
||||
# If new_mark is None, remove the old_mark
|
||||
if new_mark is None:
|
||||
if old_mark in marks:
|
||||
marks.remove(old_mark)
|
||||
updated_count += 1
|
||||
|
||||
else:
|
||||
# If new_mark is provided, add or replace the old_mark
|
||||
if old_mark is not None and old_mark in marks:
|
||||
marks.remove(old_mark)
|
||||
if new_mark not in marks:
|
||||
marks.append(new_mark)
|
||||
updated_count += 1
|
||||
|
||||
self.content[idx] = (msg, marks)
|
||||
|
||||
return updated_count
|
||||
|
||||
def state_dict(self) -> dict:
    """Build the serializable state: the parent state plus the content.

    Each stored entry is emitted as a two-item list holding the message
    dictionary and the marks attached to that message.
    """
    state = dict(super().state_dict())
    state["content"] = [
        [message.to_dict(), tags] for message, tags in self.content
    ]
    return state
|
||||
|
||||
def load_state_dict(self, state_dict: dict, strict: bool = True) -> None:
    """Restore the compressed summary and the content from a state dict.

    Args:
        state_dict (`dict`):
            The state to load. Each item of its `"content"` list is
            either a `[message_dict, marks]` pair or, for states written
            by older versions, a bare message dictionary (loaded with no
            marks).
        strict (`bool`, defaults to `True`):
            If `True`, a missing `"content"` key raises `KeyError`.

    Raises:
        `KeyError`:
            If `strict` is `True` and `"content"` is missing.
        `ValueError`:
            If an item of `"content"` has an unsupported format.
    """
    if strict and "content" not in state_dict:
        raise KeyError(
            "The state_dict does not contain 'content' "
            "keys required for InMemoryMemory.",
        )

    self._compressed_summary = state_dict.get("_compressed_summary", "")

    restored = []
    self.content = restored
    for entry in state_dict.get("content", []):
        if isinstance(entry, dict):
            # For compatibility with older versions
            restored.append((Msg.from_dict(entry), []))
            continue

        if isinstance(entry, (tuple, list)) and len(entry) == 2:
            message_dict, tags = entry
            restored.append((Msg.from_dict(message_dict), tags))
            continue

        raise ValueError(
            "Invalid item format in state_dict for InMemoryMemory.",
        )
|
||||
827
src/agentscope/memory/_working_memory/_redis_memory.py
Normal file
827
src/agentscope/memory/_working_memory/_redis_memory.py
Normal file
@@ -0,0 +1,827 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The redis based memory storage implementation."""
|
||||
import json
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
from ._base import MemoryBase
|
||||
from ...message import Msg
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.asyncio import ConnectionPool, Redis
|
||||
else:
|
||||
ConnectionPool = Any
|
||||
Redis = Any
|
||||
|
||||
|
||||
class RedisMemory(MemoryBase):
    """Redis memory storage implementation, which supports session and user
    context.

    .. note:: All the operations in this class are within a specific session
     and user context, identified by `session_id` and `user_id`. Cross-session
     or cross-user operations are not supported. For example, the
     `delete` method will only remove messages that belong to the
     specified `session_id` and `user_id`.

    .. note:: All Redis keys used by this class will be prefixed by
     `key_prefix` (if provided) to support multi-tenant / multi-app
     isolation. The prefix is concatenated verbatim in front of the key
     patterns below.

    **Mark Index Storage:**

    This class maintains a `marks_index` (Redis Set) to efficiently track all
    mark names within a session. When a mark is attached to a message via
    `add()` or `update_messages_mark()`, the mark name is added to this set.
    This allows quick retrieval of all marks without scanning all Redis keys.
    The marks_index key pattern is:
    ``user_id:{user_id}:session:{session_id}:marks_index``

    Each individual mark stores its associated message IDs in a separate Redis
    List with the key pattern:
    ``user_id:{user_id}:session:{session_id}:mark:{mark}``

    """

    SESSION_KEY = "user_id:{user_id}:session:{session_id}:messages"
    """Redis key pattern (without prefix) for storing message IDs (ordered) for
    a specific session.
    """

    SESSION_PATTERN = "user_id:{user_id}:session:{session_id}:*"
    """Redis key pattern (without prefix) for scanning all keys belonging to
    a specific user and session."""

    MARK_KEY = "user_id:{user_id}:session:{session_id}:mark:{mark}"
    """Redis key pattern (without prefix) for storing message IDs that belong
    to a specific mark.
    """

    MESSAGE_KEY = "user_id:{user_id}:session:{session_id}:msg:{msg_id}"
    """Redis key pattern (without prefix) for storing message payload data."""

    MARKS_INDEX_KEY = "user_id:{user_id}:session:{session_id}:marks_index"
    """Redis key pattern (without prefix) for storing all mark names as a set.
    This is used to avoid scanning all keys to find marks.
    """
|
||||
|
||||
def __init__(
    self,
    session_id: str = "default_session",
    user_id: str = "default_user",
    host: str = "localhost",
    port: int = 6379,
    db: int = 0,
    password: str | None = None,
    connection_pool: ConnectionPool | None = None,
    key_prefix: str = "",
    key_ttl: int | None = None,
    **kwargs: Any,
) -> None:
    """Initialize the Redis based storage by creating the Redis client.
    You can provide either the connection parameters or an existing
    connection pool.

    Args:
        session_id (`str`, default to `"default_session"`):
            The session ID for the storage.
        user_id (`str`, default to `"default_user"`):
            The user ID for the storage.
        host (`str`, default to `"localhost"`):
            The Redis server host.
        port (`int`, default to `6379`):
            The Redis server port.
        db (`int`, default to `0`):
            The Redis database index.
        password (`str | None`, optional):
            The password for the Redis server, if required.
        connection_pool (`ConnectionPool | None`, optional):
            An optional Redis connection pool. It is handed to the Redis
            client together with the connection parameters above.
        key_prefix (`str`, default to `""`):
            Optional Redis key prefix prepended to every key generated by
            this storage. Useful for isolating keys across
            apps/environments (e.g. `"prod"`, `"staging"`, `"myapp"`).
            The prefix is concatenated as-is, so include your own
            separator (e.g. `"prod:"`) if you want one.
        key_ttl (`int | None`, default to `None`):
            The expired time in seconds for each key. If provided, the
            expiration will be refreshed on every access (sliding TTL). If
            `None`, the keys will not expire. **Note** the ttl will update
            all session related keys, so it's not recommended to set for
            large sessions.
        **kwargs (`Any`):
            Additional keyword arguments to pass to the Redis client.

    Raises:
        `ImportError`:
            If the `redis` package is not installed.
    """
    try:
        import redis.asyncio as redis
    except ImportError as e:
        # The message names this class so that users can tell which
        # component needs the optional dependency.
        raise ImportError(
            "The 'redis' package is required for RedisMemory. "
            "Please install it via 'pip install redis[async]'.",
        ) from e

    super().__init__()

    self.session_id = session_id
    self.user_id = user_id
    # Normalize a `None` prefix to "" so key building can concatenate
    self.key_prefix = key_prefix or ""
    self.key_ttl = key_ttl

    # decode_responses=True asks the client to return `str` values
    self._client = redis.Redis(
        host=host,
        port=port,
        db=db,
        password=password,
        connection_pool=connection_pool,
        decode_responses=True,
        **kwargs,
    )
|
||||
|
||||
def get_client(self) -> Redis:
    """Expose the Redis client this memory operates on.

    Returns:
        `Redis`:
            The Redis client instance.
    """
    client = self._client
    return client
|
||||
|
||||
def _decode_if_bytes(self, data: Any) -> Any:
    """Return `data` as text when it is binary, untouched otherwise.

    Args:
        data (`Any`):
            The value to normalize; may be bytes, bytearray, str, or
            anything else.

    Returns:
        `Any`:
            The UTF-8 decoded string for bytes/bytearray input, otherwise
            the original object.
    """
    is_binary = isinstance(data, (bytes, bytearray))
    return data.decode("utf-8") if is_binary else data
|
||||
|
||||
def _decode_list(self, data_list: list) -> list:
    """Normalize every element of a list with `_decode_if_bytes`.

    Args:
        data_list (`list`):
            A list that may contain bytes, bytearray, or str elements.

    Returns:
        `list`:
            A new list in the same order with all bytes/bytearray
            elements decoded to str.
    """
    return list(map(self._decode_if_bytes, data_list))
|
||||
|
||||
def _get_session_key(self) -> str:
    """Build the key of the list holding this session's message IDs.

    Returns:
        `str`:
            The prefixed Redis key for storing messages in the current
            session.
    """
    unprefixed = self.SESSION_KEY.format(
        user_id=self.user_id,
        session_id=self.session_id,
    )
    return self.key_prefix + unprefixed
|
||||
|
||||
def _get_session_pattern(self) -> str:
    """Build the glob pattern covering every key of this session.

    Returns:
        `str`:
            The prefixed Redis key pattern for all keys in the current
            session.
    """
    unprefixed = self.SESSION_PATTERN.format(
        user_id=self.user_id,
        session_id=self.session_id,
    )
    return self.key_prefix + unprefixed
|
||||
|
||||
def _get_mark_key(self, mark: str) -> str:
    """Build the key of the list holding the message IDs of one mark.

    Args:
        mark (`str`):
            The mark name.

    Returns:
        `str`:
            The prefixed Redis key for storing message IDs with the given
            mark.
    """
    unprefixed = self.MARK_KEY.format(
        user_id=self.user_id,
        session_id=self.session_id,
        mark=mark,
    )
    return self.key_prefix + unprefixed
|
||||
|
||||
def _get_mark_pattern(self) -> str:
    """Build the glob pattern covering every mark key of this session.

    The pattern is the mark key with `*` in place of the mark name.

    Returns:
        `str`:
            The prefixed Redis key pattern for all mark keys.
    """
    wildcard = self.MARK_KEY.format(
        user_id=self.user_id,
        session_id=self.session_id,
        mark="*",
    )
    return self.key_prefix + wildcard
|
||||
|
||||
def _get_marks_index_key(self) -> str:
    """Build the key of the set that lists this session's mark names.

    Returns:
        `str`:
            The prefixed Redis key for storing all mark names as a set.
    """
    unprefixed = self.MARKS_INDEX_KEY.format(
        user_id=self.user_id,
        session_id=self.session_id,
    )
    return self.key_prefix + unprefixed
|
||||
|
||||
def _extract_mark_from_key(self, mark_key: str) -> str:
    """Extract the mark name from a full mark key.

    Args:
        mark_key (`str`):
            The full Redis key for a mark.

    Returns:
        `str`:
            The mark name extracted from the key. A key that does not
            start with this session's mark-key prefix is returned
            unchanged.
    """
    # Everything in front of the mark name, i.e. the mark key built with
    # an empty mark.
    # Example: "prefix:user_id:xxx:session:yyy:mark:my_mark" -> "my_mark"
    leading_part = self.key_prefix + self.MARK_KEY.format(
        user_id=self.user_id,
        session_id=self.session_id,
        mark="",
    )
    # Strip only the leading occurrence: `str.replace` would also delete
    # the same text if it happened to appear inside the mark name itself.
    return mark_key.removeprefix(leading_part)
|
||||
|
||||
def _get_message_key(self, msg_id: str) -> str:
    """Build the key under which one message payload is stored.

    Args:
        msg_id (`str`):
            The message ID.

    Returns:
        `str`:
            The prefixed Redis key for storing the message data.
    """
    unprefixed = self.MESSAGE_KEY.format(
        user_id=self.user_id,
        session_id=self.session_id,
        msg_id=msg_id,
    )
    return self.key_prefix + unprefixed
|
||||
|
||||
async def _refresh_session_ttl(
|
||||
self,
|
||||
pipe: Any | None = None,
|
||||
) -> None:
|
||||
"""Refresh the TTL for the session keys (if `key_ttl` is set).
|
||||
|
||||
Args:
|
||||
pipe (`Any | None`, optional):
|
||||
An optional Redis pipeline to use. If `None`, a new pipeline
|
||||
will be created and executed immediately. If provided, the
|
||||
expired commands will be added to the pipeline without
|
||||
executing it.
|
||||
"""
|
||||
if self.key_ttl is None:
|
||||
return
|
||||
|
||||
# Create a new pipeline if not provided
|
||||
should_execute = pipe is None
|
||||
if pipe is None:
|
||||
pipe = self._client.pipeline()
|
||||
|
||||
cursor = 0
|
||||
while True:
|
||||
cursor, keys = await self._client.scan(
|
||||
cursor,
|
||||
match=self._get_session_pattern(),
|
||||
count=100,
|
||||
)
|
||||
# Decode keys if they are bytes
|
||||
keys = self._decode_list(keys)
|
||||
for key in keys:
|
||||
await pipe.expire(key, self.key_ttl)
|
||||
if cursor == 0:
|
||||
break
|
||||
|
||||
if should_execute:
|
||||
await pipe.execute()
|
||||
|
||||
async def _scan_and_migrate_marks(self) -> list[str]:
    """Find all mark keys by scanning and record them in the marks index.

    This is the fallback for sessions whose data has no marks index yet:
    the keyspace is scanned with the mark pattern, and the mark name of
    every key found is added to the marks index set.

    Returns:
        `list[str]`:
            The list of all mark keys found.
    """
    found_keys = []
    cursor = 0
    scanning = True
    while scanning:
        cursor, batch = await self._client.scan(
            cursor,
            match=self._get_mark_pattern(),
            count=50,
        )
        found_keys += self._decode_list(batch)
        # A returned cursor of 0 marks the end of the scan
        scanning = cursor != 0

    if not found_keys:
        # No mark keys, so there is no index to build
        return found_keys

    # Build the marks index in a single pipeline round trip
    pipeline = self._client.pipeline()
    for found_key in found_keys:
        mark_name = self._extract_mark_from_key(found_key)
        await pipeline.sadd(self._get_marks_index_key(), mark_name)
    await pipeline.execute()

    return found_keys
|
||||
|
||||
async def _get_all_mark_keys(self) -> list[str]:
    """Get all mark keys, compatible with both old and new data.

    Resolution order:

    1. If the marks index set has members, the mark keys are derived
       from those names.
    2. Otherwise, if the session list does not exist, there is nothing
       stored and an empty list is returned.
    3. Otherwise the keyspace is scanned and the index is populated via
       `_scan_and_migrate_marks`.

    Returns:
        `list[str]`:
            The list of all mark keys.
    """
    indexed_names = await self._client.smembers(self._get_marks_index_key())
    if indexed_names:
        # Index exists, use it
        names = self._decode_list(list(indexed_names))
        return [self._get_mark_key(name) for name in names]

    # No index: either a brand-new session or data without an index
    if await self._client.exists(self._get_session_key()):
        return await self._scan_and_migrate_marks()

    return []
|
||||
|
||||
async def get_memory(
    self,
    mark: str | None = None,
    exclude_mark: str | None = None,
    prepend_summary: bool = True,
    **kwargs: Any,
) -> list[Msg]:
    """Get the messages from the memory by mark (if provided). Otherwise,
    get all messages.

    .. note:: If `mark` and `exclude_mark` are both provided, the messages
     will be filtered by both arguments.

    .. note:: `mark` and `exclude_mark` should not overlap.

    Args:
        mark (`str | None`, optional):
            The mark to filter messages. If `None`, return all messages.
        exclude_mark (`str | None`, optional):
            The mark to exclude messages. If provided, messages with
            this mark will be excluded from the results.
        prepend_summary (`bool`, defaults to True):
            Whether to prepend the compressed summary (when there is
            one) as a leading user message.
        **kwargs (`Any`):
            Accepted for interface compatibility; not used here.

    Raises:
        `TypeError`:
            If `mark` or `exclude_mark` is neither a string nor `None`.

    Returns:
        `list[Msg]`:
            The list of messages retrieved from the storage.
    """
    # Type checks
    if not (mark is None or isinstance(mark, str)):
        raise TypeError(
            f"The mark should be a string or None, but got {type(mark)}.",
        )

    if not (exclude_mark is None or isinstance(exclude_mark, str)):
        raise TypeError(
            f"The exclude_mark should be a string or None, but got "
            f"{type(exclude_mark)}.",
        )

    # The ordered IDs come from the session list, or from the mark list
    # when a mark is requested.
    if mark is None:
        source_key = self._get_session_key()
    else:
        source_key = self._get_mark_key(mark)
    msg_ids = self._decode_list(
        await self._client.lrange(source_key, 0, -1),
    )

    # Exclude messages by exclude_mark
    if exclude_mark:
        excluded = await self._client.lrange(
            self._get_mark_key(exclude_mark),
            0,
            -1,
        )
        # A set keeps the filter linear instead of scanning the excluded
        # list once per candidate ID.
        excluded_ids = set(self._decode_list(excluded))
        msg_ids = [mid for mid in msg_ids if mid not in excluded_ids]

    # Use mget for batch retrieval to avoid N+1 queries
    messages: list[Msg] = []
    if msg_ids:
        msg_keys = [self._get_message_key(msg_id) for msg_id in msg_ids]
        for msg_data in await self._client.mget(msg_keys):
            # IDs whose payload key is gone come back as None; skip them
            if msg_data is None:
                continue
            msg_dict = json.loads(self._decode_if_bytes(msg_data))
            messages.append(Msg.from_dict(msg_dict))

    # Refresh TTLs
    await self._refresh_session_ttl()

    if prepend_summary and self._compressed_summary:
        return [
            Msg(
                "user",
                self._compressed_summary,
                "user",
            ),
            *messages,
        ]

    return messages
|
||||
|
||||
async def add(
    self,
    memories: Msg | list[Msg] | None,
    marks: str | list[str] | None = None,
    skip_duplicated: bool = True,
    **kwargs: Any,
) -> None:
    """Add message into the storage with the given mark (if provided).

    Args:
        memories (`Msg | list[Msg] | None`):
            The message(s) to be added. `None` is a no-op.
        marks (`str | list[str] | None`, optional):
            The mark(s) to associate with the message(s). If `None`, no
            mark is associated.
        skip_duplicated (`bool`, defaults to `True`):
            If `True`, skip messages with duplicate IDs that already exist
            in the storage. If `False`, allow duplicate message IDs to be
            added to the session list (though the message data will be
            overwritten).
        **kwargs (`Any`):
            Accepted for interface compatibility; not used here.
    """
    if memories is None:
        return

    if isinstance(memories, Msg):
        memories = [memories]

    # Normalize marks to a list
    if marks is None:
        mark_list = []
    elif isinstance(marks, str):
        mark_list = [marks]
    else:
        mark_list = marks

    # Filter out existing messages if skip_duplicated is True
    messages_to_add = memories
    if skip_duplicated:
        # Get all existing message IDs in the current session
        existing_msg_ids = await self._client.lrange(
            self._get_session_key(),
            0,
            -1,
        )
        existing_msg_ids_set = set(self._decode_list(existing_msg_ids))

        # Filter out messages that already exist
        messages_to_add = [
            m for m in memories if m.id not in existing_msg_ids_set
        ]

        # If all messages are duplicates, return early
        if not messages_to_add:
            return

    # Queue all writes on one pipeline so they are sent together
    pipe = self._client.pipeline()

    # Push message ids into the session list
    if messages_to_add:
        await pipe.rpush(
            self._get_session_key(),
            *[m.id for m in messages_to_add],
        )

    # Store message data and marks
    for m in messages_to_add:
        # Record the message data
        await pipe.set(
            self._get_message_key(m.id),
            json.dumps(m.to_dict(), ensure_ascii=False),
        )

        # Record the marks if provided
        for mark in mark_list:
            await pipe.rpush(self._get_mark_key(mark), m.id)
            # Maintain the marks index
            await pipe.sadd(self._get_marks_index_key(), mark)

    await pipe.execute()

    # Refresh TTLs only after the writes have been executed: the refresh
    # works by scanning for existing keys, so running it beforehand would
    # leave every key created by this call without an expiry.
    await self._refresh_session_ttl()
|
||||
|
||||
async def delete(
    self,
    msg_ids: list[str],
    **kwargs: Any,
) -> int:
    """Remove message(s) from the storage by their IDs.

    For every ID, the ID is removed from the session list, its payload
    key is deleted, and the ID is removed from every known mark list.

    Args:
        msg_ids (`list[str]`):
            The list of message IDs to be removed.
        **kwargs (`Any`):
            Accepted for interface compatibility; not used here.

    Returns:
        `int`:
            The number of messages removed, i.e. the number of IDs that
            were actually present in the session list.
    """
    if not msg_ids:
        return 0

    # Mark keys come from the marks index (or a one-off scan for data
    # that has no index yet)
    mark_keys = await self._get_all_mark_keys()

    pipeline = self._client.pipeline()
    for target_id in msg_ids:
        # Session list first (count 0 removes all occurrences) ...
        await pipeline.lrem(self._get_session_key(), 0, target_id)
        # ... then the payload ...
        await pipeline.delete(self._get_message_key(target_id))
        # ... then every mark list
        for mark_key in mark_keys:
            await pipeline.lrem(mark_key, 0, target_id)

    # Refresh TTLs
    await self._refresh_session_ttl(pipe=pipeline)

    replies = await pipeline.execute()

    # Each ID queued `2 + len(mark_keys)` commands, the first of which is
    # the session-list lrem. Its reply is > 0 only when the ID was really
    # stored, so only those replies are counted.
    commands_per_id = 2 + len(mark_keys)
    return sum(
        1
        for slot in range(len(msg_ids))
        if replies[slot * commands_per_id] > 0
    )
|
||||
|
||||
async def delete_by_mark(
|
||||
self,
|
||||
mark: str | list[str],
|
||||
**kwargs: Any,
|
||||
) -> int:
|
||||
"""Remove messages from the storage by their marks.
|
||||
|
||||
Args:
|
||||
mark (`str | list[str]`):
|
||||
The mark(s) of the messages to be removed.
|
||||
|
||||
Returns:
|
||||
`int`:
|
||||
The number of messages removed.
|
||||
"""
|
||||
if isinstance(mark, str):
|
||||
mark = [mark]
|
||||
|
||||
total_removed = 0
|
||||
|
||||
for m in mark:
|
||||
mark_key = self._get_mark_key(m)
|
||||
msg_ids = await self._client.lrange(mark_key, 0, -1)
|
||||
msg_ids = self._decode_list(msg_ids)
|
||||
|
||||
if not msg_ids:
|
||||
continue
|
||||
|
||||
# Remove messages by IDs
|
||||
removed_count = await self.delete(
|
||||
msg_ids,
|
||||
)
|
||||
total_removed += removed_count
|
||||
|
||||
# Delete the mark list
|
||||
await self._client.delete(mark_key)
|
||||
|
||||
# Remove from the marks index
|
||||
await self._client.srem(self._get_marks_index_key(), m)
|
||||
|
||||
# Refresh TTLs
|
||||
await self._refresh_session_ttl()
|
||||
|
||||
return total_removed
|
||||
|
||||
async def clear(self) -> None:
    """Delete everything this session stored: message payloads, the
    session list, all mark lists, and the marks index."""
    stored_ids = self._decode_list(
        await self._client.lrange(self._get_session_key(), 0, -1),
    )

    # Mark keys come from the marks index (or a one-off scan for data
    # that has no index yet)
    mark_keys = await self._get_all_mark_keys()

    # Deletion order: payloads, session list, mark lists, marks index
    doomed_keys = [self._get_message_key(mid) for mid in stored_ids]
    doomed_keys.append(self._get_session_key())
    doomed_keys.extend(mark_keys)
    doomed_keys.append(self._get_marks_index_key())

    pipeline = self._client.pipeline()
    for doomed_key in doomed_keys:
        await pipeline.delete(doomed_key)
    await pipeline.execute()
|
||||
|
||||
async def size(self) -> int:
    """Count the messages of this session and refresh the session TTL.

    Returns:
        `int`:
            The length of the session's message ID list.
    """
    session_key = self._get_session_key()
    message_count = await self._client.llen(session_key)
    await self._refresh_session_ttl()
    return message_count
|
||||
|
||||
async def update_messages_mark(
|
||||
self,
|
||||
new_mark: str | None,
|
||||
old_mark: str | None = None,
|
||||
msg_ids: list[str] | None = None,
|
||||
) -> int:
|
||||
"""A unified method to update marks of messages in the storage (add,
|
||||
remove, or change marks).
|
||||
|
||||
- If `msg_ids` is provided, the update will be applied to the messages
|
||||
with the specified IDs.
|
||||
- If `old_mark` is provided, the update will be applied to the
|
||||
messages with the specified old mark. Otherwise, the `new_mark` will
|
||||
be added to all messages (or those filtered by `msg_ids`).
|
||||
- If `new_mark` is `None`, the mark will be removed from the messages.
|
||||
|
||||
Args:
|
||||
new_mark (`str | None`, optional):
|
||||
The new mark to set for the messages. If `None`, the mark
|
||||
will be removed.
|
||||
old_mark (`str | None`, optional):
|
||||
The old mark to filter messages. If `None`, this constraint
|
||||
is ignored.
|
||||
msg_ids (`list[str] | None`, optional):
|
||||
The list of message IDs to be updated. If `None`, this
|
||||
constraint is ignored.
|
||||
|
||||
Returns:
|
||||
`int`:
|
||||
The number of messages updated.
|
||||
"""
|
||||
# Determine which message IDs to update
|
||||
# Get source key based on old_mark
|
||||
source_key = (
|
||||
self._get_mark_key(old_mark)
|
||||
if old_mark is not None
|
||||
else self._get_session_key()
|
||||
)
|
||||
mark_msg_ids = await self._client.lrange(source_key, 0, -1)
|
||||
mark_msg_ids = self._decode_list(mark_msg_ids)
|
||||
|
||||
# Check if we're removing all messages from old_mark
|
||||
removing_all_from_old_mark = old_mark is not None and (
|
||||
msg_ids is None or all(mid in set(msg_ids) for mid in mark_msg_ids)
|
||||
)
|
||||
|
||||
# Filter by msg_ids if provided
|
||||
if msg_ids is not None:
|
||||
msg_ids_set = set(msg_ids)
|
||||
mark_msg_ids = [mid for mid in mark_msg_ids if mid in msg_ids_set]
|
||||
|
||||
if not mark_msg_ids:
|
||||
return 0
|
||||
|
||||
# Get existing IDs in new_mark list once (if needed)
|
||||
existing_ids_set = set()
|
||||
new_mark_key = None
|
||||
if new_mark is not None:
|
||||
new_mark_key = self._get_mark_key(new_mark)
|
||||
existing_ids = await self._client.lrange(new_mark_key, 0, -1)
|
||||
existing_ids = self._decode_list(existing_ids)
|
||||
existing_ids_set = set(existing_ids)
|
||||
|
||||
# Use pipeline for batch operations
|
||||
pipe = self._client.pipeline()
|
||||
updated_count = 0
|
||||
|
||||
for msg_id in mark_msg_ids:
|
||||
# Remove from old_mark list if applicable
|
||||
if old_mark is not None:
|
||||
await pipe.lrem(
|
||||
self._get_mark_key(old_mark),
|
||||
0,
|
||||
msg_id,
|
||||
)
|
||||
|
||||
# Add to new_mark list only if not already present
|
||||
if new_mark is not None and msg_id not in existing_ids_set:
|
||||
await pipe.rpush(new_mark_key, msg_id)
|
||||
existing_ids_set.add(msg_id)
|
||||
# Maintain the marks index
|
||||
await pipe.sadd(self._get_marks_index_key(), new_mark)
|
||||
|
||||
# Count update only if we actually did something
|
||||
if old_mark is not None or new_mark is not None:
|
||||
updated_count += 1
|
||||
|
||||
# Clean up old_mark only if we removed ALL messages from it
|
||||
if old_mark is not None and removing_all_from_old_mark:
|
||||
old_mark_key = self._get_mark_key(old_mark)
|
||||
# After lrem operations, the old mark list will be empty
|
||||
# Delete the mark key and remove from index
|
||||
await pipe.delete(old_mark_key)
|
||||
await pipe.srem(self._get_marks_index_key(), old_mark)
|
||||
|
||||
await self._refresh_session_ttl(pipe=pipe)
|
||||
|
||||
await pipe.execute()
|
||||
return updated_count
|
||||
|
||||
async def close(self, close_connection_pool: bool | None = None) -> None:
|
||||
"""Close the Redis client connection.
|
||||
|
||||
Args:
|
||||
close_connection_pool (`bool | None`, optional):
|
||||
Decides whether to close the connection pool used by this
|
||||
Redis client, overriding Redis.auto_close_connection_pool.
|
||||
By default, let Redis.auto_close_connection_pool decide
|
||||
whether to close the connection pool
|
||||
"""
|
||||
await self._client.aclose(close_connection_pool=close_connection_pool)
|
||||
|
||||
async def __aenter__(self) -> "RedisMemory":
    """Start an `async with` block; no setup is performed here.

    Returns:
        `RedisMemory`:
            This very memory instance, bound by the `as` clause.
    """
    memory = self
    return memory
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: Any,
|
||||
) -> None:
|
||||
"""Exit the async context manager and close the session.
|
||||
|
||||
Args:
|
||||
exc_type (`type[BaseException] | None`):
|
||||
The exception type if an exception was raised.
|
||||
exc_value (`BaseException | None`):
|
||||
The exception instance if an exception was raised.
|
||||
traceback (`Any`):
|
||||
The traceback object if an exception was raised.
|
||||
"""
|
||||
await self.close()
|
||||
873
src/agentscope/memory/_working_memory/_sqlalchemy_memory.py
Normal file
873
src/agentscope/memory/_working_memory/_sqlalchemy_memory.py
Normal file
@@ -0,0 +1,873 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The SQLAlchemy database storage module, which supports storing messages in
|
||||
a SQL database using SQLAlchemy ORM (e.g., SQLite, PostgreSQL, MySQL)."""
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
String,
|
||||
JSON,
|
||||
BigInteger,
|
||||
ForeignKey,
|
||||
select,
|
||||
delete,
|
||||
update,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncEngine,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
)
|
||||
from sqlalchemy.orm import declarative_base, relationship
|
||||
|
||||
from ._base import MemoryBase
|
||||
from ...message import Msg
|
||||
|
||||
Base: Any = declarative_base()
|
||||
|
||||
|
||||
class AsyncSQLAlchemyMemory(MemoryBase):
|
||||
"""The SQLAlchemy memory storage class for storing messages in a SQL
|
||||
database using SQLAlchemy ORM, such as SQLite, PostgreSQL, MySQL, etc.
|
||||
|
||||
.. note:: All the operations in this class are within a specific session
|
||||
and user context, identified by `session_id` and `user_id`. Cross-session
|
||||
or cross-user operations are not supported. For example, the
|
||||
`remove_messages` method will only remove messages that belong to the
|
||||
specified `session_id` and `user_id`.
|
||||
|
||||
"""
|
||||
|
||||
class MessageTable(Base):
    """ORM mapping of the ``message`` table that holds serialized messages."""

    __tablename__ = "message"
    """Name of the database table."""

    id = Column(String(255), primary_key=True)
    """Primary key. The memory stores
    ``f"{user_id}-{session_id}-{message_id}"`` here so that ids stay unique
    across users and sessions."""

    msg = Column(JSON, nullable=False)
    """The message content as JSON."""

    session_id = Column(String(255), ForeignKey("session.id"), nullable=False)
    """Id of the session the message belongs to (foreign key to
    ``session.id``)."""

    index = Column(BigInteger, nullable=False, index=True)
    """Ordering index, used to return messages in the order in which they
    were added."""

    session = relationship("SessionTable", back_populates="messages")
    """ORM relationship to the owning session row."""
|
||||
|
||||
class MessageMarkTable(Base):
    """ORM mapping of the ``message_mark`` table.

    Each row attaches one mark to one message; the pair
    ``(msg_id, mark)`` forms the composite primary key.
    """

    __tablename__ = "message_mark"
    """Name of the database table."""

    msg_id = Column(
        String(255),
        ForeignKey("message.id", ondelete="CASCADE"),
        primary_key=True,
    )
    """Id of the marked message (foreign key to ``message.id``, declared
    with ``ON DELETE CASCADE``)."""

    mark = Column(String(255), primary_key=True)
    """The mark attached to the message."""
|
||||
|
||||
class SessionTable(Base):
    """ORM mapping of the ``session`` table."""

    __tablename__ = "session"
    """Name of the database table."""

    id = Column(String(255), primary_key=True)
    """The session id (primary key)."""

    user_id = Column(String(255), ForeignKey("users.id"), nullable=False)
    """Id of the user owning the session (foreign key to ``users.id``)."""

    user = relationship("UserTable", back_populates="sessions")
    """ORM relationship to the owning user row."""

    messages = relationship("MessageTable", back_populates="session")
    """ORM relationship to the messages stored for this session."""
|
||||
|
||||
class UserTable(Base):
    """ORM mapping of the ``users`` table."""

    __tablename__ = "users"
    """Name of the database table."""

    id = Column(String(255), primary_key=True)
    """The user id (primary key)."""

    sessions = relationship("SessionTable", back_populates="user")
    """ORM relationship to the sessions owned by this user."""
|
||||
|
||||
def __init__(
    self,
    engine_or_session: AsyncEngine | AsyncSession,
    session_id: str | None = None,
    user_id: str | None = None,
) -> None:
    """Initialize the memory with a SQLAlchemy async engine or session.

    Args:
        engine_or_session (`AsyncEngine | AsyncSession`):
            The SQLAlchemy asynchronous engine or session used for the
            database operations. When an engine is given, the memory
            creates its own sessions from it; when a session is given, it
            is used as-is.
        session_id (`str | None`, optional):
            The session ID for the messages. If `None` (or empty),
            ``"default_session"`` is used.
        user_id (`str | None`, optional):
            The user ID for the messages. If `None` (or empty),
            ``"default_user"`` is used.

    Raises:
        `ValueError`:
            If `engine_or_session` is neither a
            `sqlalchemy.ext.asyncio.AsyncEngine` nor a
            `sqlalchemy.ext.asyncio.AsyncSession`.
    """
    super().__init__()

    self._db_session: AsyncSession | None = None

    if isinstance(engine_or_session, AsyncEngine):
        # Internal mode: sessions are created lazily by the `session`
        # property from this factory.
        self._session_factory = async_sessionmaker(
            bind=engine_or_session,
            expire_on_commit=False,
        )

    elif isinstance(engine_or_session, AsyncSession):
        # External mode: no factory, the given session is used directly.
        self._session_factory = None
        self._db_session = engine_or_session

    else:
        raise ValueError(
            "The 'engine_or_session' parameter must be an instance of "
            "sqlalchemy.ext.asyncio.AsyncEngine or "
            "sqlalchemy.ext.asyncio.AsyncSession, "
            f"but got {type(engine_or_session)}.",
        )

    self.session_id = session_id or "default_session"
    self.user_id = user_id or "default_user"

    # Flag to track if tables and records have been initialized
    self._initialized = False
|
||||
|
||||
def _make_message_id(self, msg_id: str) -> str:
|
||||
"""Generate a composite primary key for a message.
|
||||
|
||||
Args:
|
||||
msg_id (`str`):
|
||||
The original message ID.
|
||||
|
||||
Returns:
|
||||
`str`:
|
||||
The composite primary key in the format
|
||||
"{user_id}-{session_id}-{message_id}".
|
||||
"""
|
||||
return f"{self.user_id}-{self.session_id}-{msg_id}"
|
||||
|
||||
@property
def session(self) -> AsyncSession:
    """Return the database session to use, creating one when necessary.

    Returns:
        `AsyncSession`:
            The current database session.

    Note:
        - Without a session factory (a session was handed to the
          constructor) the stored session is returned unchanged.
        - With a session factory, a new session is created whenever the
          stored one is missing or reports ``is_active == False``; the
          ``_initialized`` flag is then reset so that the table setup is
          run again.
    """
    factory = self._session_factory
    if factory is None:
        # Session supplied by the caller: hand it back untouched.
        return self._db_session

    current = self._db_session
    if current is not None and current.is_active:
        return current

    # No usable session: open a fresh one and force re-initialization.
    self._db_session = factory()
    self._initialized = False
    return self._db_session
|
||||
|
||||
async def _create_table(self) -> None:
    """Create the tables and the user/session rows if they are missing.

    Does nothing once ``_initialized`` is set. Otherwise all tables
    registered on ``Base.metadata`` are created through the engine bound
    to the session, then the rows for ``user_id`` and ``session_id`` are
    inserted when absent.
    """
    if self._initialized:
        return

    # The engine is taken from the session's bind.
    bound_engine: AsyncEngine = self.session.bind
    async with bound_engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)

    db = self.session

    # User row
    user_lookup = select(self.UserTable).where(
        self.UserTable.id == self.user_id,
    )
    found_user = (await db.execute(user_lookup)).scalar_one_or_none()
    user_missing = found_user is None
    if user_missing:
        db.add(self.UserTable(id=self.user_id))

    # Session row
    session_lookup = select(self.SessionTable).where(
        self.SessionTable.id == self.session_id,
    )
    found_session = (await db.execute(session_lookup)).scalar_one_or_none()
    session_missing = found_session is None
    if session_missing:
        db.add(
            self.SessionTable(id=self.session_id, user_id=self.user_id),
        )

    # A single commit covers whichever rows were added.
    if user_missing or session_missing:
        await db.commit()

    self._initialized = True
|
||||
|
||||
async def get_memory(
    self,
    mark: str | None = None,
    exclude_mark: str | None = None,
    prepend_summary: bool = True,
    **kwargs: Any,
) -> list[Msg]:
    """Return the messages of the current session, optionally filtered
    by marks.

    .. note:: If `mark` and `exclude_mark` are both provided, both
     filters are applied.

    .. note:: `mark` and `exclude_mark` should not overlap.

    Args:
        mark (`str | None`, optional):
            Only messages carrying this mark are returned. If `None`
            (or empty), no mark is required.
        exclude_mark (`str | None`, optional):
            Messages carrying this mark are left out of the result.
        prepend_summary (`bool`, defaults to True):
            Whether to put the compressed summary, when one is set, in
            front of the result as a ``"user"`` message.

    Raises:
        `TypeError`:
            If `mark` or `exclude_mark` is neither a string nor `None`.

    Returns:
        `list[Msg]`:
            The messages, ordered by their stored index.
    """
    if not (mark is None or isinstance(mark, str)):
        raise TypeError(
            f"The mark should be a string or None, but got {type(mark)}.",
        )

    if not (exclude_mark is None or isinstance(exclude_mark, str)):
        raise TypeError(
            f"The exclude_mark should be a string or None, but got "
            f"{type(exclude_mark)}.",
        )

    await self._create_table()

    msg_table = self.MessageTable
    mark_table = self.MessageMarkTable

    # Restrict to the messages of this session.
    stmt = select(msg_table).where(msg_table.session_id == self.session_id)

    # Keep only messages that carry `mark`.
    if mark:
        stmt = stmt.join(
            mark_table,
            msg_table.id == mark_table.msg_id,
        ).where(mark_table.mark == mark)

    # Drop messages of this session that carry `exclude_mark`.
    if exclude_mark:
        session_msg_ids = select(msg_table.id).where(
            msg_table.session_id == self.session_id,
        )
        excluded_ids = (
            select(mark_table.msg_id)
            .where(
                mark_table.msg_id.in_(session_msg_ids),
                mark_table.mark == exclude_mark,
            )
            .scalar_subquery()
        )
        stmt = stmt.where(msg_table.id.notin_(excluded_ids))

    # Return messages in insertion order.
    stmt = stmt.order_by(msg_table.index)

    rows = (await self.session.execute(stmt)).scalars().all()
    history = [Msg.from_dict(row.msg) for row in rows]

    if not (prepend_summary and self._compressed_summary):
        return history

    summary_msg = Msg(
        "user",
        self._compressed_summary,
        "user",
    )
    return [summary_msg, *history]
|
||||
|
||||
async def add(
    self,
    memories: Msg | list[Msg] | None,
    marks: str | list[str] | None = None,
    skip_duplicated: bool = True,
    **kwargs: Any,
) -> None:
    """Add message(s) to the storage, optionally tagged with mark(s).

    Args:
        memories (`Msg | list[Msg] | None`):
            The message(s) to be added. `None` and an empty list are
            no-ops.
        marks (`str | list[str] | None`, optional):
            The mark(s) to associate with the message(s). If `None`, no
            mark is associated.
        skip_duplicated (`bool`, defaults to `True`):
            If `True`, messages whose ID already exists in the storage
            (or appears earlier in the same call) are skipped, as are
            repeated marks. If `False`, an `IntegrityError` is raised when
            a message with an existing ID is added.

    Raises:
        `TypeError`:
            If `memories` or `marks` has an unsupported type.
        `IntegrityError`:
            If a message with the same ID already exists in the storage
            and `skip_duplicated` is set to `False`.
    """
    if memories is None:
        return

    # Type checking
    if isinstance(memories, Msg):
        memories = [memories]
    elif not (
        isinstance(memories, list)
        and all(isinstance(_, Msg) for _ in memories)
    ):
        raise TypeError(
            "The 'memories' parameter must be a Msg instance or a list of "
            f"Msg instances, but got {type(memories)}.",
        )

    if isinstance(marks, str):
        marks = [marks]
    elif marks is not None and not (
        isinstance(marks, list) and all(isinstance(m, str) for m in marks)
    ):
        raise TypeError(
            "The 'marks' parameter must be a string or a list of strings, "
            f"but got {type(marks)}.",
        )

    # Create table if not exists
    await self._create_table()

    messages_to_add = memories
    if skip_duplicated:
        candidate_ids = [self._make_message_id(m.id) for m in memories]
        result = await self.session.execute(
            select(self.MessageTable.id).filter(
                self.MessageTable.id.in_(candidate_ids),
            ),
        )
        seen_ids = {row[0] for row in result.fetchall()}

        messages_to_add = []
        for m, composite_id in zip(memories, candidate_ids):
            if composite_id in seen_ids:
                continue
            # Remember the id so a repeat inside this batch is skipped too.
            seen_ids.add(composite_id)
            messages_to_add.append(m)

    # Nothing to store (empty input or only duplicates)
    if not messages_to_add:
        return

    # Read the starting index once; the batch gets consecutive indexes.
    start_index = await self._get_next_index()

    new_ids = []
    for offset, m in enumerate(messages_to_add):
        composite_id = self._make_message_id(m.id)
        new_ids.append(composite_id)
        self.session.add(
            self.MessageTable(
                id=composite_id,
                msg=m.to_dict(),
                session_id=self.session_id,
                index=start_index + offset,
            ),
        )

    if marks:
        mark_records = [
            {"msg_id": composite_id, "mark": mark}
            for composite_id in new_ids
            for mark in marks
        ]

        if skip_duplicated:
            # Only look at marks of the messages being inserted instead
            # of reading the whole mark table.
            result = await self.session.execute(
                select(
                    self.MessageMarkTable.msg_id,
                    self.MessageMarkTable.mark,
                ).filter(
                    self.MessageMarkTable.msg_id.in_(new_ids),
                ),
            )
            known_pairs = {(row[0], row[1]) for row in result.fetchall()}

            unique_records = []
            for record in mark_records:
                pair = (record["msg_id"], record["mark"])
                if pair in known_pairs:
                    continue
                # Also drops marks repeated in the `marks` argument.
                known_pairs.add(pair)
                unique_records.append(record)
            mark_records = unique_records

        if mark_records:
            await self.session.run_sync(
                lambda session: session.bulk_insert_mappings(
                    self.MessageMarkTable,
                    mark_records,
                ),
            )

    await self.session.commit()
|
||||
|
||||
async def _get_next_index(self) -> int:
    """Compute the index to assign to the next message of this session.

    Returns:
        `int`:
            One more than the highest index stored for the session, or
            ``0`` when the session has no messages.
    """
    msg_table = self.MessageTable
    latest_index_stmt = (
        select(msg_table.index)
        .where(msg_table.session_id == self.session_id)
        .order_by(msg_table.index.desc())
        .limit(1)
    )
    highest = (
        await self.session.execute(latest_index_stmt)
    ).scalar_one_or_none()

    if highest is None:
        return 0
    return highest + 1
|
||||
|
||||
async def size(self) -> int:
    """Get the number of messages stored for the current session.

    Returns:
        `int`:
            The number of rows in the message table whose `session_id`
            matches this memory's session.
    """
    # Make sure the tables exist, so that calling this on a fresh database
    # behaves like `get_memory` instead of failing on a missing table.
    await self._create_table()

    result = await self.session.execute(
        select(func.count(self.MessageTable.id)).filter(
            self.MessageTable.session_id == self.session_id,
        ),
    )
    return result.scalar_one()
|
||||
|
||||
async def clear(self) -> None:
    """Remove all messages of the current session, including their
    marks, from the storage."""
    # Make sure the tables exist, so that clearing a fresh database is a
    # no-op instead of failing on a missing table.
    await self._create_table()

    # Delete all marks for messages in this session first, so that no mark
    # row refers to a message that is about to be removed.
    await self.session.execute(
        delete(self.MessageMarkTable).where(
            self.MessageMarkTable.msg_id.in_(
                select(self.MessageTable.id).filter(
                    self.MessageTable.session_id == self.session_id,
                ),
            ),
        ),
    )

    # Then delete the messages themselves
    await self.session.execute(
        delete(self.MessageTable).filter(
            self.MessageTable.session_id == self.session_id,
        ),
    )

    await self.session.commit()
|
||||
|
||||
async def delete_by_mark(
    self,
    mark: str | list[str],
    **kwargs: Any,
) -> int:
    """Remove the messages of the current session that carry any of the
    given marks.

    Args:
        mark (`str | list[str]`):
            The mark(s) of the messages to be removed.

    Returns:
        `int`:
            The number of messages removed. A message carrying several of
            the given marks is counted once.
    """
    if isinstance(mark, str):
        mark = [mark]

    # Make sure the tables exist before querying them.
    await self._create_table()

    # Find the ids of the messages carrying one of the marks. The join
    # yields one row per (message, mark) pair, hence the DISTINCT.
    query = (
        select(self.MessageTable.id)
        .join(
            self.MessageMarkTable,
            self.MessageTable.id == self.MessageMarkTable.msg_id,
        )
        .filter(
            self.MessageTable.session_id == self.session_id,
            self.MessageMarkTable.mark.in_(mark),
        )
        .distinct()
    )

    result = await self.session.execute(query)
    msg_ids = [row[0] for row in result.all()]

    if not msg_ids:
        return 0

    # Delete every mark of those messages first
    await self.session.execute(
        delete(self.MessageMarkTable).filter(
            self.MessageMarkTable.msg_id.in_(msg_ids),
        ),
    )

    # Then delete the messages
    await self.session.execute(
        delete(self.MessageTable).filter(
            self.MessageTable.session_id == self.session_id,
            self.MessageTable.id.in_(msg_ids),
        ),
    )

    await self.session.commit()
    return len(msg_ids)
|
||||
|
||||
async def delete(
    self,
    msg_ids: list[str],
    **kwargs: Any,
) -> int:
    """Remove message(s) from the storage by their IDs.

    .. note:: Although MessageMarkTable has CASCADE delete on foreign key,
     we explicitly delete marks first for reliability across all database
     engines and configurations. SQLAlchemy's bulk delete bypasses
     ORM-level cascades, and SQLite requires foreign keys to be
     explicitly enabled.

    Args:
        msg_ids (`list[str]`):
            The list of message IDs to be removed.

    Returns:
        `int`:
            The number of messages removed, i.e. how many of the given
            IDs were stored for the current session. Unknown IDs are not
            counted.
    """
    if not msg_ids:
        return 0

    # Make sure the tables exist before querying them.
    await self._create_table()

    # Convert to composite keys
    composite_ids = [self._make_message_id(msg_id) for msg_id in msg_ids]

    # Look up which of the requested messages exist, so that the
    # returned count reflects what is really deleted.
    result = await self.session.execute(
        select(self.MessageTable.id).filter(
            self.MessageTable.session_id == self.session_id,
            self.MessageTable.id.in_(composite_ids),
        ),
    )
    existing_ids = [row[0] for row in result.all()]

    # Delete related marks first (explicit cleanup for reliability)
    await self.session.execute(
        delete(self.MessageMarkTable).filter(
            self.MessageMarkTable.msg_id.in_(composite_ids),
        ),
    )

    # Then delete the messages
    await self.session.execute(
        delete(self.MessageTable).filter(
            self.MessageTable.session_id == self.session_id,
            self.MessageTable.id.in_(composite_ids),
        ),
    )

    await self.session.commit()
    return len(existing_ids)
|
||||
|
||||
async def update_messages_mark(
    self,
    new_mark: str | None,
    old_mark: str | None = None,
    msg_ids: list[str] | None = None,
) -> int:
    """A unified method to update marks of messages in the storage (add,
    remove, or change marks).

    - If `msg_ids` is provided, the update is applied to the messages
      with the specified IDs.
    - If `old_mark` is provided, the update is applied to the messages
      carrying that mark. Otherwise, `new_mark` is added to all messages
      (or those filtered by `msg_ids`).
    - If `new_mark` is `None`, marks are removed from the messages:
      `old_mark` when it is given, otherwise all marks.

    Args:
        new_mark (`str | None`, optional):
            The new mark to set for the messages. If `None`, the mark
            will be removed.
        old_mark (`str | None`, optional):
            The old mark to filter messages. If `None`, this constraint
            is ignored.
        msg_ids (`list[str] | None`, optional):
            The list of message IDs to be updated. If `None`, this
            constraint is ignored.

    Raises:
        `ValueError`:
            If `new_mark` or `old_mark` is neither a string nor `None`,
            or `msg_ids` is neither a list of strings nor `None`.

    Returns:
        `int`:
            The number of messages selected for the update.
    """

    # Type checking
    if new_mark is not None and not isinstance(new_mark, str):
        raise ValueError(
            f"The 'new_mark' parameter must be a string or None, "
            f"but got {type(new_mark)}.",
        )

    if old_mark is not None and not isinstance(old_mark, str):
        raise ValueError(
            f"The 'old_mark' parameter must be a string or None, "
            f"but got {type(old_mark)}.",
        )

    if msg_ids is not None and not (
        isinstance(msg_ids, list)
        and all(isinstance(_, str) for _ in msg_ids)
    ):
        raise ValueError(
            f"The 'msg_ids' parameter must be a list of strings or None, "
            f"but got {type(msg_ids)}.",
        )

    # Make sure the tables exist before querying them.
    await self._create_table()

    # Select only the ids of this session's messages; the JSON payload is
    # not needed here.
    query = select(self.MessageTable.id).filter(
        self.MessageTable.session_id == self.session_id,
    )

    # Filter by msg_ids if provided
    if msg_ids is not None:
        # Convert to composite keys
        composite_ids = [
            self._make_message_id(msg_id) for msg_id in msg_ids
        ]
        query = query.filter(self.MessageTable.id.in_(composite_ids))

    # Filter by old_mark if provided
    if old_mark is not None:
        query = query.join(
            self.MessageMarkTable,
            self.MessageTable.id == self.MessageMarkTable.msg_id,
        ).filter(self.MessageMarkTable.mark == old_mark)

    result = await self.session.execute(query)
    msg_ids = [str(_) for _ in result.scalars().all()]

    # Return early if no messages found
    if not msg_ids:
        return 0

    if new_mark:
        if old_mark:
            # Replace old_mark with new_mark
            return await self._replace_message_mark(
                msg_ids=msg_ids,
                old_mark=old_mark,
                new_mark=new_mark,
            )

        # Add new_mark to the messages
        return await self._add_message_mark(
            msg_ids=msg_ids,
            mark=new_mark,
        )

    # Remove old_mark (or all marks when old_mark is not given)
    return await self._remove_message_mark(
        msg_ids=msg_ids,
        old_mark=old_mark,
    )
|
||||
|
||||
async def _replace_message_mark(
    self,
    msg_ids: list[str],
    old_mark: str,
    new_mark: str,
) -> int:
    """Replace the old mark with the new mark for the given messages by
    updating records in the message_mark table.

    Args:
        msg_ids (`list[str]`):
            The list of (composite) message IDs to be updated.
        old_mark (`str`):
            The old mark to be replaced.
        new_mark (`str`):
            The new mark to be set.

    Returns:
        `int`:
            The number of message IDs given.
    """
    # Replacing a mark by itself changes nothing.
    if old_mark == new_mark:
        return len(msg_ids)

    mark_table = self.MessageMarkTable

    # (msg_id, mark) is the primary key, so renaming old_mark to new_mark
    # on a message that already carries new_mark would collide. For those
    # messages the old row is simply dropped.
    result = await self.session.execute(
        select(mark_table.msg_id).filter(
            mark_table.msg_id.in_(msg_ids),
            mark_table.mark == new_mark,
        ),
    )
    already_marked = [row[0] for row in result.all()]

    if already_marked:
        await self.session.execute(
            delete(mark_table).filter(
                mark_table.msg_id.in_(already_marked),
                mark_table.mark == old_mark,
            ),
        )

    # Rename the mark on the remaining rows.
    await self.session.execute(
        update(mark_table)
        .filter(
            mark_table.msg_id.in_(msg_ids),
            mark_table.mark == old_mark,
        )
        .values(mark=new_mark),
    )
    await self.session.commit()
    return len(msg_ids)
|
||||
|
||||
async def _add_message_mark(self, msg_ids: list[str], mark: str) -> int:
    """Mark the messages with the given mark by adding records to the
    message_mark table.

    Messages that already carry the mark are left untouched.

    Args:
        msg_ids (`list[str]`):
            The list of (composite) message IDs to be marked.
        mark (`str`):
            The mark to be added to the messages.

    Returns:
        `int`:
            The number of message IDs given.
    """
    mark_table = self.MessageMarkTable

    # (msg_id, mark) is the primary key: inserting a pair that already
    # exists would fail, so look up the existing pairs first.
    already_marked: set = set()
    if msg_ids:
        result = await self.session.execute(
            select(mark_table.msg_id).filter(
                mark_table.msg_id.in_(msg_ids),
                mark_table.mark == mark,
            ),
        )
        already_marked = {row[0] for row in result.all()}

    # dict.fromkeys drops repeated ids while keeping their order.
    mark_records = [
        {"msg_id": msg_id, "mark": mark}
        for msg_id in dict.fromkeys(msg_ids)
        if msg_id not in already_marked
    ]

    # Use bulk insert for better performance
    if mark_records:
        await self.session.run_sync(
            lambda session: session.bulk_insert_mappings(
                mark_table,
                mark_records,
            ),
        )

    await self.session.commit()
    return len(msg_ids)
|
||||
|
||||
async def _remove_message_mark(
    self,
    msg_ids: list[str],
    old_mark: str | None,
) -> int:
    """Delete mark records of the given messages from the message_mark
    table.

    Args:
        msg_ids (`list[str]`):
            The list of message IDs to be unmarked.
        old_mark (`str | None`):
            The mark to remove. If `None` (or empty), every mark of the
            messages is removed.

    Returns:
        `int`:
            The number of message IDs given.
    """
    mark_table = self.MessageMarkTable

    # Always restrict to the given messages; restrict to one mark only
    # when a mark was specified.
    conditions = [mark_table.msg_id.in_(msg_ids)]
    if old_mark:
        conditions.append(mark_table.mark == old_mark)

    await self.session.execute(delete(mark_table).filter(*conditions))
    await self.session.commit()
    return len(msg_ids)
|
||||
|
||||
async def close(self) -> None:
    """Close the database session and reset the memory's session state.

    The stored session is closed whenever one exists, including when it
    reports ``is_active == False``, so that it is not dropped while still
    open. Afterwards the stored session is cleared and the initialization
    flag is reset, even if closing raised.
    """
    db_session = self._db_session
    try:
        if db_session is not None:
            await db_session.close()
    finally:
        self._db_session = None
        self._initialized = False
|
||||
|
||||
async def __aenter__(self) -> "AsyncSQLAlchemyMemory":
|
||||
"""Enter the async context manager.
|
||||
|
||||
Returns:
|
||||
`AsyncSQLAlchemyMemory`:
|
||||
The memory instance itself.
|
||||
"""
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: Any,
|
||||
) -> None:
|
||||
"""Exit the async context manager and close the session.
|
||||
|
||||
Args:
|
||||
exc_type (`type[BaseException] | None`):
|
||||
The exception type if an exception was raised.
|
||||
exc_value (`BaseException | None`):
|
||||
The exception instance if an exception was raised.
|
||||
traceback (`Any`):
|
||||
The traceback object if an exception was raised.
|
||||
"""
|
||||
await self.close()
|
||||
31
src/agentscope/message/__init__.py
Normal file
31
src/agentscope/message/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The message module in agentscope."""
|
||||
|
||||
from ._message_block import (
|
||||
ContentBlock,
|
||||
TextBlock,
|
||||
ThinkingBlock,
|
||||
ToolUseBlock,
|
||||
ToolResultBlock,
|
||||
ImageBlock,
|
||||
AudioBlock,
|
||||
VideoBlock,
|
||||
Base64Source,
|
||||
URLSource,
|
||||
)
|
||||
from ._message_base import Msg
|
||||
|
||||
|
||||
__all__ = [
|
||||
"TextBlock",
|
||||
"ThinkingBlock",
|
||||
"Base64Source",
|
||||
"URLSource",
|
||||
"ImageBlock",
|
||||
"AudioBlock",
|
||||
"VideoBlock",
|
||||
"ToolUseBlock",
|
||||
"ToolResultBlock",
|
||||
"ContentBlock",
|
||||
"Msg",
|
||||
]
|
||||
241
src/agentscope/message/_message_base.py
Normal file
241
src/agentscope/message/_message_base.py
Normal file
@@ -0,0 +1,241 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The message class in agentscope."""
|
||||
from datetime import datetime
|
||||
from typing import Literal, List, overload, Sequence
|
||||
|
||||
import shortuuid
|
||||
|
||||
from ._message_block import (
|
||||
TextBlock,
|
||||
ToolUseBlock,
|
||||
ImageBlock,
|
||||
AudioBlock,
|
||||
ContentBlock,
|
||||
VideoBlock,
|
||||
ToolResultBlock,
|
||||
ContentBlockTypes,
|
||||
)
|
||||
from ..types import JSONSerializableObject
|
||||
|
||||
|
||||
class Msg:
|
||||
"""The message class in agentscope."""
|
||||
|
||||
def __init__(
    self,
    name: str,
    content: str | Sequence[ContentBlock],
    role: Literal["user", "assistant", "system"],
    metadata: dict[str, JSONSerializableObject] | None = None,
    timestamp: str | None = None,
    invocation_id: str | None = None,
) -> None:
    """Create a message.

    Args:
        name (`str`):
            The name of the message sender.
        content (`str | list[ContentBlock]`):
            The content of the message; must be a string or a list.
        role (`Literal["user", "assistant", "system"]`):
            The role of the message sender.
        metadata (`dict[str, JSONSerializableObject] | None`, optional):
            The metadata of the message, e.g. structured output. An
            empty dict is stored when nothing is given.
        timestamp (`str | None`, optional):
            The created timestamp of the message. If not given, the
            current local time is used, formatted as
            ``"%Y-%m-%d %H:%M:%S.%f"`` and cut to milliseconds.
        invocation_id (`str | None`, optional):
            The related API invocation id, if any. This is useful for
            tracking the message in the context of an API call.
    """
    self.name = name

    content_is_valid = isinstance(content, (list, str))
    assert (
        content_is_valid
    ), "The content must be a string or a list of content blocks."
    self.content = content

    assert role in ("user", "assistant", "system")
    self.role = role

    self.metadata = metadata or {}

    # Every new message gets a freshly generated id.
    self.id = shortuuid.uuid()

    if not timestamp:
        # Dropping the last three digits of the microseconds leaves
        # millisecond precision.
        now_text = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
        timestamp = now_text[:-3]
    self.timestamp = timestamp

    self.invocation_id = invocation_id
|
||||
|
||||
def to_dict(self) -> dict:
    """Convert the message into JSON dict data.

    Returns:
        `dict`:
            A dict with the keys ``id``, ``name``, ``role``, ``content``,
            ``metadata``, ``timestamp`` and ``invocation_id``. The
            ``invocation_id`` entry is included so that `from_dict`, which
            reads that key, restores it.
    """
    return {
        "id": self.id,
        "name": self.name,
        "role": self.role,
        "content": self.content,
        "metadata": self.metadata,
        "timestamp": self.timestamp,
        "invocation_id": self.invocation_id,
    }
|
||||
|
||||
@classmethod
def from_dict(cls, json_data: dict) -> "Msg":
    """Build a message object from JSON dict data.

    Args:
        json_data (`dict`):
            The dict to load from. ``name``, ``content`` and ``role`` are
            required; ``metadata``, ``timestamp``, ``invocation_id`` and
            ``id`` are optional.

    Returns:
        `Msg`:
            The new message. It keeps its freshly generated id unless
            `json_data` contains an ``id`` entry.
    """
    optional_fields = {
        key: json_data.get(key, None)
        for key in ("metadata", "timestamp", "invocation_id")
    }
    restored = cls(
        name=json_data["name"],
        content=json_data["content"],
        role=json_data["role"],
        **optional_fields,
    )

    if "id" in json_data:
        restored.id = json_data["id"]
    return restored
|
||||
|
||||
def has_content_blocks(
|
||||
self,
|
||||
block_type: Literal[
|
||||
"text",
|
||||
"tool_use",
|
||||
"tool_result",
|
||||
"image",
|
||||
"audio",
|
||||
"video",
|
||||
]
|
||||
| None = None,
|
||||
) -> bool:
|
||||
"""Check if the message has content blocks of the given type.
|
||||
|
||||
Args:
|
||||
block_type (Literal["text", "tool_use", "tool_result", "image", \
|
||||
"audio", "video"] | None, defaults to None):
|
||||
The type of the block to be checked. If `None`, it will
|
||||
check if there are any content blocks.
|
||||
"""
|
||||
return len(self.get_content_blocks(block_type)) > 0
|
||||
|
||||
def get_text_content(self, separator: str = "\n") -> str | None:
|
||||
"""Get the pure text blocks from the message content.
|
||||
|
||||
Args:
|
||||
separator (`str`, defaults to `\n`):
|
||||
The separator to use when concatenating multiple text blocks.
|
||||
Defaults to newline character.
|
||||
|
||||
Returns:
|
||||
`str | None`:
|
||||
The concatenated text content, or `None` if there is no text
|
||||
content.
|
||||
"""
|
||||
if isinstance(self.content, str):
|
||||
return self.content
|
||||
|
||||
gathered_text = []
|
||||
for block in self.content:
|
||||
if block.get("type") == "text":
|
||||
gathered_text.append(block["text"])
|
||||
|
||||
if gathered_text:
|
||||
return separator.join(gathered_text)
|
||||
|
||||
return None
|
||||
|
||||
@overload
def get_content_blocks(
    self,
    block_type: Literal["text"],
) -> Sequence[TextBlock]:
    ...

@overload
def get_content_blocks(
    self,
    block_type: Literal["tool_use"],
) -> Sequence[ToolUseBlock]:
    ...

@overload
def get_content_blocks(
    self,
    block_type: Literal["tool_result"],
) -> Sequence[ToolResultBlock]:
    ...

@overload
def get_content_blocks(
    self,
    block_type: Literal["image"],
) -> Sequence[ImageBlock]:
    ...

@overload
def get_content_blocks(
    self,
    block_type: Literal["audio"],
) -> Sequence[AudioBlock]:
    ...

@overload
def get_content_blocks(
    self,
    block_type: Literal["video"],
) -> Sequence[VideoBlock]:
    ...

# The implementation also accepts a list of block types; this overload
# lets type checkers accept that call form.
@overload
def get_content_blocks(
    self,
    block_type: List[ContentBlockTypes],
) -> Sequence[ContentBlock]:
    ...

@overload
def get_content_blocks(
    self,
    block_type: None = None,
) -> Sequence[ContentBlock]:
    ...

def get_content_blocks(
    self,
    block_type: ContentBlockTypes | List[ContentBlockTypes] | None = None,
) -> Sequence[ContentBlock]:
    """Get the content in block format. If the content is a string,
    it will be converted to a text block.

    Args:
        block_type (`ContentBlockTypes | List[ContentBlockTypes] | None`, \
optional):
            The type (or list of types) of the blocks to be extracted.
            If `None`, all blocks will be returned.

    Returns:
        `List[ContentBlock]`:
            The content blocks. When the content is a list and no
            `block_type` is given, the message's own list is returned
            rather than a copy.
    """
    if isinstance(self.content, str):
        blocks = [TextBlock(type="text", text=self.content)]
    else:
        blocks = self.content or []

    # Single type: keep the blocks of exactly that type.
    if isinstance(block_type, str):
        return [_ for _ in blocks if _["type"] == block_type]

    # Several types: keep the blocks whose type is one of them.
    if isinstance(block_type, list):
        return [_ for _ in blocks if _["type"] in block_type]

    return blocks
|
||||
|
||||
def __repr__(self) -> str:
    """Return a constructor-like string showing every field of the
    message."""
    fields = (
        f"id='{self.id}'",
        f"name='{self.name}'",
        f"content={self.content!r}",
        f"role='{self.role}'",
        f"metadata={self.metadata!r}",
        f"timestamp='{self.timestamp}'",
        f"invocation_id='{self.invocation_id}'",
    )
    return "Msg(" + ", ".join(fields) + ")"
|
||||
128
src/agentscope/message/_message_block.py
Normal file
128
src/agentscope/message/_message_block.py
Normal file
@@ -0,0 +1,128 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# pylint: disable=R0901
|
||||
"""The content blocks of messages"""
|
||||
|
||||
from typing import Literal, List
|
||||
from typing_extensions import TypedDict, Required
|
||||
|
||||
|
||||
class TextBlock(TypedDict, total=False):
|
||||
"""The text block."""
|
||||
|
||||
type: Required[Literal["text"]]
|
||||
"""The type of the block"""
|
||||
text: str
|
||||
"""The text content"""
|
||||
|
||||
|
||||
class ThinkingBlock(TypedDict, total=False):
|
||||
"""The thinking block."""
|
||||
|
||||
type: Required[Literal["thinking"]]
|
||||
"""The type of the block"""
|
||||
thinking: str
|
||||
|
||||
|
||||
class Base64Source(TypedDict, total=False):
|
||||
"""The base64 source"""
|
||||
|
||||
type: Required[Literal["base64"]]
|
||||
"""The type of the src, must be `base64`"""
|
||||
|
||||
media_type: Required[str]
|
||||
"""The media type of the data, e.g. `image/jpeg` or `audio/mpeg`"""
|
||||
|
||||
data: Required[str]
|
||||
"""The base64 data, in format of RFC 2397"""
|
||||
|
||||
|
||||
class URLSource(TypedDict, total=False):
|
||||
"""The URL source"""
|
||||
|
||||
type: Required[Literal["url"]]
|
||||
"""The type of the src, must be `url`"""
|
||||
|
||||
url: Required[str]
|
||||
"""The URL of the image or audio"""
|
||||
|
||||
|
||||
class ImageBlock(TypedDict, total=False):
    """A content block that carries an image."""

    type: Required[Literal["image"]]
    """Block discriminator, always ``"image"``"""

    source: Required[Base64Source | URLSource]
    """Where the image comes from: inline base64 data or a URL"""
|
||||
|
||||
|
||||
class AudioBlock(TypedDict, total=False):
    """A content block that carries audio."""

    type: Required[Literal["audio"]]
    """Block discriminator, always ``"audio"``"""

    source: Required[Base64Source | URLSource]
    """Where the audio comes from: inline base64 data or a URL"""
|
||||
|
||||
|
||||
class VideoBlock(TypedDict, total=False):
    """The video block"""

    type: Required[Literal["video"]]
    """The type of the block"""

    source: Required[Base64Source | URLSource]
    """The src of the video"""
|
||||
|
||||
|
||||
class ToolUseBlock(TypedDict, total=False):
|
||||
"""The tool use block."""
|
||||
|
||||
type: Required[Literal["tool_use"]]
|
||||
"""The type of the block, must be `tool_use`"""
|
||||
id: Required[str]
|
||||
"""The identity of the tool call"""
|
||||
name: Required[str]
|
||||
"""The name of the tool"""
|
||||
input: Required[dict[str, object]]
|
||||
"""The input of the tool"""
|
||||
raw_input: str
|
||||
"""The raw string input of the tool from the model API"""
|
||||
|
||||
|
||||
class ToolResultBlock(TypedDict, total=False):
    """A content block carrying the result of a tool call."""

    type: Required[Literal["tool_result"]]
    """Block discriminator, always ``"tool_result"``"""

    id: Required[str]
    """Identifier of the tool call result"""

    output: Required[
        str | List[TextBlock | ImageBlock | AudioBlock | VideoBlock]
    ]
    """Output of the tool function: a string, or a list of text / image /
    audio / video blocks"""

    name: Required[str]
    """Name of the tool function"""
|
||||
|
||||
|
||||
# Union of every block type that may appear in a message's content.
ContentBlock = (
    ToolUseBlock
    | ToolResultBlock
    | TextBlock
    | ThinkingBlock
    | ImageBlock
    | AudioBlock
    | VideoBlock
)
|
||||
|
||||
# The `type` discriminator values of the content blocks defined above.
ContentBlockTypes = Literal[
    "text",
    "thinking",
    "tool_use",
    "tool_result",
    "image",
    "audio",
    "video",
]
|
||||
22
src/agentscope/model/__init__.py
Normal file
22
src/agentscope/model/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The model module."""
|
||||
|
||||
from ._model_base import ChatModelBase
|
||||
from ._model_response import ChatResponse
|
||||
from ._dashscope_model import DashScopeChatModel
|
||||
from ._openai_model import OpenAIChatModel
|
||||
from ._anthropic_model import AnthropicChatModel
|
||||
from ._ollama_model import OllamaChatModel
|
||||
from ._gemini_model import GeminiChatModel
|
||||
from ._trinity_model import TrinityChatModel
|
||||
|
||||
__all__ = [
|
||||
"ChatModelBase",
|
||||
"ChatResponse",
|
||||
"DashScopeChatModel",
|
||||
"OpenAIChatModel",
|
||||
"AnthropicChatModel",
|
||||
"OllamaChatModel",
|
||||
"GeminiChatModel",
|
||||
"TrinityChatModel",
|
||||
]
|
||||
590
src/agentscope/model/_anthropic_model.py
Normal file
590
src/agentscope/model/_anthropic_model.py
Normal file
@@ -0,0 +1,590 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# pylint: disable=too-many-branches, too-many-statements
|
||||
"""The Anthropic API model classes."""
|
||||
import copy
|
||||
import json
|
||||
import warnings
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
TYPE_CHECKING,
|
||||
List,
|
||||
Literal,
|
||||
Type,
|
||||
)
|
||||
from collections import OrderedDict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._model_base import ChatModelBase
|
||||
from ._model_response import ChatResponse
|
||||
from ._model_usage import ChatUsage
|
||||
from .._logging import logger
|
||||
from .._utils._common import (
|
||||
_json_loads_with_repair,
|
||||
_create_tool_from_base_model,
|
||||
)
|
||||
from ..message import TextBlock, ToolUseBlock, ThinkingBlock
|
||||
from ..tracing import trace_llm
|
||||
from ..types._json import JSONSerializableObject
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from anthropic.types.message import Message
|
||||
from anthropic import AsyncStream
|
||||
else:
|
||||
Message = "anthropic.types.message.Message"
|
||||
AsyncStream = "anthropic.AsyncStream"
|
||||
|
||||
|
||||
class AnthropicChatModel(ChatModelBase):
    """The Anthropic model wrapper for AgentScope."""

    def __init__(
        self,
        model_name: str,
        api_key: str | None = None,
        max_tokens: int = 2048,
        stream: bool = True,
        thinking: dict | None = None,
        stream_tool_parsing: bool = True,
        client_kwargs: dict[str, JSONSerializableObject] | None = None,
        generate_kwargs: dict[str, JSONSerializableObject] | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the Anthropic chat model.

        Args:
            model_name (`str`):
                The model name.
            api_key (`str | None`, default `None`):
                The anthropic API key, passed as-is to the Anthropic client.
            max_tokens (`int`, default `2048`):
                Limit the maximum token count the model can generate.
            stream (`bool`, default `True`):
                The streaming output or not
            thinking (`dict | None`, default `None`):
                Configuration for Claude's internal reasoning process.

                .. code-block:: python
                    :caption: Example of thinking

                    {
                        "type": "enabled" | "disabled",
                        "budget_tokens": 1024
                    }

            stream_tool_parsing (`bool`, default to `True`):
                Whether to parse incomplete tool use JSON during streaming
                with auto-repair. If True, partial JSON (e.g., `'{"a": "x'`)
                is repaired to valid dicts ({"a": "x"}) in real-time for
                immediate tool function input. Otherwise, the input field
                remains {} until the final chunk arrives.
            client_kwargs (`dict[str, JSONSerializableObject] | None`, \
            optional):
                The extra keyword arguments to initialize the Anthropic
                client.
            generate_kwargs (`dict[str, JSONSerializableObject] | None`, \
            optional):
                The extra keyword arguments used in Anthropic API generation,
                e.g. `temperature`, `seed`.
            **kwargs (`Any`):
                Additional keyword arguments. Only the deprecated
                `client_args` is recognised; anything else is ignored with
                a warning.

        Raises:
            `ValueError`:
                If both `client_args` and `client_kwargs` are given.
            `ImportError`:
                If the `anthropic` package is not installed.
        """

        # Handle deprecated client_args parameter from kwargs
        client_args = kwargs.pop("client_args", None)
        if client_args is not None and client_kwargs is not None:
            raise ValueError(
                "Cannot specify both 'client_args' and 'client_kwargs'. "
                "Please use only 'client_kwargs' (client_args is deprecated).",
            )

        if client_args is not None:
            logger.warning(
                "The parameter 'client_args' is deprecated and will be "
                "removed in a future version. Please use 'client_kwargs' "
                "instead. Automatically converting 'client_args' to "
                "'client_kwargs'.",
            )
            client_kwargs = client_args

        if kwargs:
            logger.warning(
                "Unknown keyword arguments: %s. These will be ignored.",
                list(kwargs.keys()),
            )

        # Imported lazily so the package is only required when this model
        # class is actually instantiated.
        try:
            import anthropic
        except ImportError as e:
            raise ImportError(
                "Please install the `anthropic` package by running "
                "`pip install anthropic`.",
            ) from e

        super().__init__(model_name, stream)

        self.client = anthropic.AsyncAnthropic(
            api_key=api_key,
            **(client_kwargs or {}),
        )
        self.max_tokens = max_tokens
        self.thinking = thinking
        self.stream_tool_parsing = stream_tool_parsing
        self.generate_kwargs = generate_kwargs or {}

    @trace_llm
    async def __call__(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict] | None = None,
        tool_choice: Literal["auto", "none", "required"] | str | None = None,
        structured_model: Type[BaseModel] | None = None,
        **generate_kwargs: Any,
    ) -> ChatResponse | AsyncGenerator[ChatResponse, None]:
        """Get the response from Anthropic chat completions API by the given
        arguments.

        Args:
            messages (`list[dict]`):
                A list of dictionaries, where `role` and `content` fields are
                required, and `name` field is optional. If the first message
                has the `system` role, its content is sent as the `system`
                argument and it is removed from the message list.
            tools (`list[dict]`, default `None`):
                The tools JSON schemas that in format of:

                .. code-block:: python
                    :caption: Example of tools JSON schemas

                    [
                        {
                            "type": "function",
                            "function": {
                                "name": "xxx",
                                "description": "xxx",
                                "parameters": {
                                    "type": "object",
                                    "properties": {
                                        "param1": {
                                            "type": "string",
                                            "description": "..."
                                        },
                                        # Add more parameters as needed
                                    },
                                    "required": ["param1"]
                                }
                        },
                        # More schemas here
                    ]

            tool_choice (`Literal["auto", "none", "required"] | str \
            | None`, default `None`):
                Controls which (if any) tool is called by the model.
                Can be "auto", "none", "required", or specific tool
                name. For more details, please refer to
                https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/implement-tool-use
            structured_model (`Type[BaseModel] | None`, default `None`):
                A Pydantic BaseModel class that defines the expected structure
                for the model's output. When provided, the model will be forced
                to return data that conforms to this schema by automatically
                converting the BaseModel to a tool function and setting
                `tool_choice` to enforce its usage. This enables structured
                output generation.

                .. note:: When `structured_model` is specified,
                    both `tools` and `tool_choice` parameters are ignored,
                    and the model will only perform structured output
                    generation without calling any other tools.

            **generate_kwargs (`Any`):
                The keyword arguments for Anthropic chat completions API,
                e.g. `temperature`, `top_p`, etc. Please
                refer to the Anthropic API documentation for more details.

        Returns:
            `ChatResponse | AsyncGenerator[ChatResponse, None]`:
                The response from the Anthropic chat completions API.
        """

        # Call-time kwargs take precedence over the instance-level defaults.
        kwargs: dict[str, Any] = {
            "model": self.model_name,
            "max_tokens": self.max_tokens,
            "stream": self.stream,
            **self.generate_kwargs,
            **generate_kwargs,
        }
        if self.thinking and "thinking" not in kwargs:
            kwargs["thinking"] = self.thinking

        if structured_model:
            # `tools` / `tool_choice` are documented as ignored in this mode,
            # so they are neither formatted nor validated.
            if tools or tool_choice:
                logger.warning(
                    "structured_model is provided. Both 'tools' and "
                    "'tool_choice' parameters will be overridden and "
                    "ignored. The model will only perform structured output "
                    "generation without calling any other tools.",
                )
            format_tool = _create_tool_from_base_model(structured_model)
            kwargs["tools"] = self._format_tools_json_schemas(
                [format_tool],
            )
            kwargs["tool_choice"] = self._format_tool_choice(
                format_tool["function"]["name"],
            )
        else:
            if tools:
                kwargs["tools"] = self._format_tools_json_schemas(tools)

            if tool_choice:
                # Handle deprecated "any" option with warning
                if tool_choice == "any":
                    warnings.warn(
                        '"any" is deprecated and will be removed in a future '
                        "version.",
                        DeprecationWarning,
                    )
                    tool_choice = "required"
                self._validate_tool_choice(tool_choice, tools)
                kwargs["tool_choice"] = self._format_tool_choice(tool_choice)

        # Extract the system message (guarding against an empty list)
        if messages and messages[0]["role"] == "system":
            kwargs["system"] = messages[0]["content"]
            messages = messages[1:]

        kwargs["messages"] = messages

        start_datetime = datetime.now()

        response = await self.client.messages.create(**kwargs)

        # Pick the parser from the `stream` value that was actually sent:
        # `generate_kwargs` may have overridden `self.stream`.
        if kwargs["stream"]:
            return self._parse_anthropic_stream_completion_response(
                start_datetime,
                response,
                structured_model,
            )

        # Non-streaming response
        return await self._parse_anthropic_completion_response(
            start_datetime,
            response,
            structured_model,
        )

    async def _parse_anthropic_completion_response(
        self,
        start_datetime: datetime,
        response: Message,
        structured_model: Type[BaseModel] | None = None,
    ) -> ChatResponse:
        """Given an Anthropic Message object, extract the content blocks and
        usages from it.

        Args:
            start_datetime (`datetime`):
                The start datetime of the response generation.
            response (`Message`):
                Anthropic Message object to parse.
            structured_model (`Type[BaseModel] | None`, default `None`):
                A Pydantic BaseModel class that defines the expected structure
                for the model's output.

        Returns:
            ChatResponse (`ChatResponse`):
                A ChatResponse object containing the content blocks and usage.

        .. note::
            If `structured_model` is not `None`, the expected structured output
            will be stored in the metadata of the `ChatResponse`.
        """
        content_blocks: List[ThinkingBlock | TextBlock | ToolUseBlock] = []
        metadata = None

        # A missing or empty `content` attribute yields no blocks.
        for content_block in getattr(response, "content", None) or []:
            block_type = getattr(content_block, "type", None)

            if block_type == "thinking":
                thinking_block = ThinkingBlock(
                    type="thinking",
                    thinking=content_block.thinking,
                )
                # Keep the signature alongside the thinking text.
                thinking_block["signature"] = content_block.signature
                content_blocks.append(thinking_block)

            elif block_type == "text":
                content_blocks.append(
                    TextBlock(
                        type="text",
                        text=content_block.text,
                    ),
                )

            elif block_type == "tool_use":
                content_blocks.append(
                    ToolUseBlock(
                        type="tool_use",
                        id=content_block.id,
                        name=content_block.name,
                        input=content_block.input,
                    ),
                )
                # In structured-output mode the tool input is the result.
                if structured_model:
                    metadata = content_block.input

        usage = None
        if response.usage:
            usage = ChatUsage(
                input_tokens=response.usage.input_tokens,
                output_tokens=response.usage.output_tokens,
                time=(datetime.now() - start_datetime).total_seconds(),
            )

        return ChatResponse(
            content=content_blocks,
            usage=usage,
            metadata=metadata,
        )

    async def _parse_anthropic_stream_completion_response(
        self,
        start_datetime: datetime,
        response: AsyncStream,
        structured_model: Type[BaseModel] | None = None,
    ) -> AsyncGenerator[ChatResponse, None]:
        """Given an Anthropic streaming response, extract the content blocks
        and usages from it and yield ChatResponse objects.

        Args:
            start_datetime (`datetime`):
                The start datetime of the response generation.
            response (`AsyncStream`):
                Anthropic AsyncStream object to parse.
            structured_model (`Type[BaseModel] | None`, default `None`):
                A Pydantic BaseModel class that defines the expected structure
                for the model's output.

        Returns:
            `AsyncGenerator[ChatResponse, None]`:
                An async generator that yields ChatResponse objects containing
                the content blocks and usage information for each chunk in
                the streaming response. Each yielded response carries the
                content accumulated so far, not only the latest delta.

        .. note::
            If `structured_model` is not `None`, the expected structured output
            will be stored in the metadata of the `ChatResponse`.
        """

        usage = None
        text_buffer = ""
        thinking_buffer = ""
        thinking_signature = ""
        # Tool calls keyed by content-block index, in arrival order; the
        # "input" entry accumulates the raw partial JSON string.
        tool_calls = OrderedDict()
        last_input_objs = {}  # Store last input_obj for each tool_call
        metadata = None

        # Record the last yielded content to parse the tools' input
        last_content = None

        async for event in response:
            content_changed = False
            thinking_changed = False

            if event.type == "message_start":
                message = event.message
                if message.usage:
                    usage = ChatUsage(
                        input_tokens=message.usage.input_tokens,
                        output_tokens=getattr(
                            message.usage,
                            "output_tokens",
                            0,
                        ),
                        time=(datetime.now() - start_datetime).total_seconds(),
                    )

            elif event.type == "content_block_start":
                if event.content_block.type == "tool_use":
                    tool_block = event.content_block
                    tool_calls[event.index] = {
                        "type": "tool_use",
                        "id": tool_block.id,
                        "name": tool_block.name,
                        "input": "",
                    }
                    content_changed = True

            elif event.type == "content_block_delta":
                block_index = event.index
                delta = event.delta
                if delta.type == "text_delta":
                    text_buffer += delta.text
                    content_changed = True
                elif delta.type == "thinking_delta":
                    thinking_buffer += delta.thinking
                    thinking_changed = True
                elif delta.type == "signature_delta":
                    # Stored only; it is emitted with the next yielded chunk.
                    thinking_signature = delta.signature
                elif (
                    delta.type == "input_json_delta"
                    and block_index in tool_calls
                ):
                    tool_calls[block_index]["input"] += (
                        delta.partial_json or ""
                    )
                    content_changed = True

            elif event.type == "message_delta":
                if event.usage and usage:
                    usage.output_tokens = event.usage.output_tokens

            # Nothing is yielded until usage is known (from message_start).
            if (thinking_changed or content_changed) and usage:
                contents: list = []
                if thinking_buffer:
                    thinking_block = ThinkingBlock(
                        type="thinking",
                        thinking=thinking_buffer,
                    )
                    thinking_block["signature"] = thinking_signature
                    contents.append(thinking_block)
                if text_buffer:
                    contents.append(
                        TextBlock(
                            type="text",
                            text=text_buffer,
                        ),
                    )
                for tool_call in tool_calls.values():
                    input_str = tool_call["input"]
                    tool_id = tool_call["id"]

                    # If parsing the tool input in streaming mode
                    if self.stream_tool_parsing:
                        repaired_input = _json_loads_with_repair(
                            input_str or "{}",
                        )
                        # If the new repaired input is shorter than one in the
                        # last chunk, use the last one to avoid regression
                        last_input = last_input_objs.get(tool_id, {})
                        if len(json.dumps(last_input)) > len(
                            json.dumps(repaired_input),
                        ):
                            repaired_input = last_input
                        last_input_objs[tool_id] = repaired_input

                    else:
                        # Parsing is deferred to the final chunk below.
                        repaired_input = {}

                    contents.append(
                        ToolUseBlock(
                            type=tool_call["type"],
                            id=tool_call["id"],
                            name=tool_call["name"],
                            input=repaired_input,
                            raw_input=input_str,
                        ),
                    )

                    if structured_model:
                        metadata = repaired_input

                if contents:
                    yield ChatResponse(
                        content=contents,
                        usage=usage,
                        metadata=metadata,
                    )
                    last_content = copy.deepcopy(contents)

        # If stream_tool_parsing is False, yield last contents
        if not self.stream_tool_parsing and last_content and tool_calls:
            metadata = None
            # Update tool use blocks in last_contents inplace
            for block in last_content:
                if block.get("type") == "tool_use":
                    block["input"] = input_obj = _json_loads_with_repair(
                        block.get("raw_input") or "{}",
                    )

                    if structured_model:
                        metadata = input_obj

            yield ChatResponse(
                content=last_content,
                usage=usage,
                metadata=metadata,
            )

    def _format_tools_json_schemas(
        self,
        schemas: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Format the JSON schemas of the tool functions to the format that
        Anthropic API expects.

        Args:
            schemas (`list[dict[str, Any]]`):
                Tool schemas, each holding a `function` dict with at least
                a `name` key.

        Returns:
            `list[dict[str, Any]]`:
                One dict per schema with the keys `name`, `description`
                (defaults to an empty string) and `input_schema` (defaults
                to an empty dict).
        """
        formatted_schemas = []
        for schema in schemas:
            assert (
                "function" in schema
            ), f"Invalid schema: {schema}, expect key 'function'."

            function = schema["function"]
            assert "name" in function, (
                f"Invalid schema: {schema}, "
                "expect key 'name' in 'function' field."
            )

            formatted_schemas.append(
                {
                    "name": function["name"],
                    "description": function.get("description", ""),
                    "input_schema": function.get("parameters", {}),
                },
            )

        return formatted_schemas

    def _format_tool_choice(
        self,
        tool_choice: Literal["auto", "none", "required"] | str | None,
    ) -> dict | None:
        """Format tool_choice parameter for API compatibility.

        Args:
            tool_choice (`Literal["auto", "none", "required"] | str \
            | None`, default `None`):
                Controls which (if any) tool is called by the model.
                Can be "auto", "none", "required", or specific tool
                name. For more details, please refer to
                https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/implement-tool-use

        Returns:
            `dict | None`:
                The formatted tool choice configuration dict, or None if
                tool_choice is None.
        """
        if tool_choice is None:
            return None

        # "required" is sent to the API as the "any" type.
        type_mapping = {
            "auto": {"type": "auto"},
            "none": {"type": "none"},
            "required": {"type": "any"},
        }
        if tool_choice in type_mapping:
            return type_mapping[tool_choice]

        # Any other value is treated as the name of a specific tool.
        return {"type": "tool", "name": tool_choice}
|
||||
632
src/agentscope/model/_dashscope_model.py
Normal file
632
src/agentscope/model/_dashscope_model.py
Normal file
@@ -0,0 +1,632 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The dashscope API model classes."""
|
||||
import copy
|
||||
import collections
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from datetime import datetime
|
||||
from http import HTTPStatus
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Generator,
|
||||
Union,
|
||||
TYPE_CHECKING,
|
||||
List,
|
||||
Literal,
|
||||
Type,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from aioitertools import iter as giter
|
||||
|
||||
from ._model_base import ChatModelBase
|
||||
from ._model_response import ChatResponse
|
||||
from ._model_usage import ChatUsage
|
||||
from .._utils._common import (
|
||||
_json_loads_with_repair,
|
||||
_create_tool_from_base_model,
|
||||
)
|
||||
from ..message import TextBlock, ToolUseBlock, ThinkingBlock
|
||||
from ..tracing import trace_llm
|
||||
from ..types import JSONSerializableObject
|
||||
from .._logging import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dashscope.api_entities.dashscope_response import GenerationResponse
|
||||
from dashscope.api_entities.dashscope_response import (
|
||||
MultiModalConversationResponse,
|
||||
)
|
||||
else:
|
||||
GenerationResponse = (
|
||||
"dashscope.api_entities.dashscope_response.GenerationResponse"
|
||||
)
|
||||
MultiModalConversationResponse = (
|
||||
"dashscope.api_entities.dashscope_response."
|
||||
"MultiModalConversationResponse"
|
||||
)
|
||||
|
||||
|
||||
class DashScopeChatModel(ChatModelBase):
|
||||
"""The DashScope chat model class, which unifies the Generation and
|
||||
MultimodalConversation APIs into one method.
|
||||
|
||||
This class provides a unified interface for DashScope API by automatically
|
||||
selecting between text-only (Generation API) and multimodal
|
||||
(MultiModalConversation API) endpoints. The `multimodality` parameter
|
||||
allows explicit control over API selection:
|
||||
|
||||
- When `multimodality=True`: Forces use of MultiModalConversation API
|
||||
for handling images, videos, and other multimodal inputs
|
||||
- When `multimodality=False`: Forces use of Generation API for
|
||||
text-only processing
|
||||
- When `multimodality=None` (default): Automatically selects the API
|
||||
based on model name (e.g., models with "-vl" suffix or starting
|
||||
with "qvq" will use MultiModalConversation API)
|
||||
|
||||
This design enables seamless switching between text and multimodal
|
||||
models without changing code structure, making it easier to work with
|
||||
DashScope's diverse model offerings.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
api_key: str,
|
||||
stream: bool = True,
|
||||
enable_thinking: bool | None = None,
|
||||
multimodality: bool | None = None,
|
||||
generate_kwargs: dict[str, JSONSerializableObject] | None = None,
|
||||
base_http_api_url: str | None = None,
|
||||
stream_tool_parsing: bool = True,
|
||||
**_kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the DashScope chat model.
|
||||
|
||||
Args:
|
||||
model_name (`str`):
|
||||
The model names.
|
||||
api_key (`str`):
|
||||
The dashscope API key.
|
||||
stream (`bool`):
|
||||
The streaming output or not
|
||||
enable_thinking (`bool | None`, optional):
|
||||
Enable thinking or not, only support Qwen3, QwQ, DeepSeek-R1.
|
||||
Refer to `DashScope documentation
|
||||
<https://help.aliyun.com/zh/model-studio/deep-thinking>`_
|
||||
for more details.
|
||||
multimodality (`bool | None`, optional):
|
||||
Whether to use multimodal conversation API. If `True`,
|
||||
it will use `dashscope.MultiModalConversation.call`
|
||||
to process multimodal inputs such as images and text. If
|
||||
`False`, it will use
|
||||
`dashscope.aigc.generation.AioGeneration.call` to process
|
||||
text inputs. If `None` (default), the choice is based on
|
||||
the model name.
|
||||
generate_kwargs (`dict[str, JSONSerializableObject] | None`, \
|
||||
optional):
|
||||
The extra keyword arguments used in DashScope API generation,
|
||||
e.g. `temperature`, `seed`.
|
||||
base_http_api_url (`str | None`, optional):
|
||||
The base URL for DashScope API requests. If not provided,
|
||||
the default base URL from the DashScope SDK will be used.
|
||||
stream_tool_parsing (`bool`, default to `True`):
|
||||
Whether to parse incomplete tool use JSON in streaming mode
|
||||
with auto-repair. If True, partial JSON (e.g., `'{"a": "x'`)
|
||||
is repaired to valid dicts (`{"a": "x"}`) in real-time for
|
||||
immediate tool function input. Otherwise, the input field
|
||||
remains {} until the final chunk arrives.
|
||||
**_kwargs (`Any`):
|
||||
Additional keyword arguments.
|
||||
"""
|
||||
if enable_thinking and not stream:
|
||||
logger.info(
|
||||
"In DashScope API, `stream` must be True when "
|
||||
"`enable_thinking` is True. ",
|
||||
)
|
||||
stream = True
|
||||
|
||||
super().__init__(model_name, stream)
|
||||
|
||||
self.api_key = api_key
|
||||
self.enable_thinking = enable_thinking
|
||||
self.multimodality = multimodality
|
||||
self.generate_kwargs = generate_kwargs or {}
|
||||
self.stream_tool_parsing = stream_tool_parsing
|
||||
|
||||
if base_http_api_url is not None:
|
||||
import dashscope
|
||||
|
||||
dashscope.base_http_api_url = base_http_api_url
|
||||
|
||||
# Load headers from environment variable if exists
|
||||
headers = os.getenv("DASHSCOPE_API_HEADERS")
|
||||
if headers:
|
||||
try:
|
||||
headers = json.loads(str(headers))
|
||||
if not isinstance(headers, dict):
|
||||
raise json.JSONDecodeError("", "", 0)
|
||||
|
||||
if self.generate_kwargs.get("headers"):
|
||||
headers.update(self.generate_kwargs["headers"])
|
||||
|
||||
self.generate_kwargs["headers"] = headers
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
"Failed to parse DASHSCOPE_API_HEADERS environment "
|
||||
"variable as JSON. It should be a JSON object.",
|
||||
)
|
||||
|
||||
@trace_llm
async def __call__(
    self,
    messages: list[dict[str, Any]],
    tools: list[dict] | None = None,
    tool_choice: Literal["auto", "none", "required"] | str | None = None,
    structured_model: Type[BaseModel] | None = None,
    **kwargs: Any,
) -> ChatResponse | AsyncGenerator[ChatResponse, None]:
    """Send a chat request to DashScope and return the parsed response.

    Depending on the model, the request goes either to
    ``dashscope.MultiModalConversation`` or to the asynchronous
    ``AioGeneration`` endpoint; both are driven by the same arguments.

    Args:
        messages (`list[dict[str, Any]]`):
            A list of dictionaries, where `role` and `content` fields are
            required.
        tools (`list[dict] | None`, default `None`):
            The tools JSON schemas that the model can use.
        tool_choice (`Literal["auto", "none", "required"] | str \
        | None`, default `None`):
            Controls which (if any) tool is called by the model.
            Can be "auto", "none", "required", or a specific tool name.
            "any" and "required" are converted to "auto" with a
            `DeprecationWarning`. For more details, please refer to
            https://help.aliyun.com/zh/model-studio/qwen-function-calling
        structured_model (`Type[BaseModel] | None`, default `None`):
            A Pydantic BaseModel class describing the expected output.
            When provided, it is converted into a single tool function
            and `tool_choice` is set to that function.

            .. note:: When `structured_model` is specified, the `tools`
             and `tool_choice` entries of the request are overwritten
             (a warning is logged if either was supplied).

        **kwargs (`Any`):
            Extra keyword arguments for the DashScope API, e.g.
            `temperature`, `max_tokens`, `top_p`. They take precedence
            over `self.generate_kwargs`. Please refer to `DashScope
            documentation
            <https://help.aliyun.com/zh/dashscope/developer-reference/api-details>`_
            for more detailed arguments.

    Returns:
        `ChatResponse | AsyncGenerator[ChatResponse, None]`:
            An async generator of responses when `self.stream` is truthy,
            otherwise a single parsed response.
    """
    import dashscope

    # Later sources override earlier ones; `result_format` and
    # `incremental_output` are always forced last.
    request: dict[str, Any] = {
        "messages": messages,
        "model": self.model_name,
        "stream": self.stream,
    }
    request.update(self.generate_kwargs)
    request.update(kwargs)
    request["result_format"] = "message"
    # In agentscope, the `incremental_output` must be `True` when
    # `self.stream` is True
    request["incremental_output"] = self.stream

    if tools:
        request["tools"] = self._format_tools_json_schemas(tools)

    if tool_choice:
        # "any" and "required" are mapped to "auto" with a warning
        if tool_choice in ("any", "required"):
            warnings.warn(
                f"'{tool_choice}' is not supported by DashScope API. "
                "It will be converted to 'auto'.",
                DeprecationWarning,
            )
            tool_choice = "auto"

        self._validate_tool_choice(tool_choice, tools)
        request["tool_choice"] = self._format_tool_choice(tool_choice)

    # An explicit per-call `enable_thinking` wins over the instance setting
    if self.enable_thinking is not None:
        request.setdefault("enable_thinking", self.enable_thinking)

    if structured_model:
        if tools or tool_choice:
            logger.warning(
                "structured_model is provided. Both 'tools' and "
                "'tool_choice' parameters will be overridden and "
                "ignored. The model will only perform structured output "
                "generation without calling any other tools.",
            )
        format_tool = _create_tool_from_base_model(structured_model)
        request["tools"] = self._format_tools_json_schemas([format_tool])
        request["tool_choice"] = self._format_tool_choice(
            format_tool["function"]["name"],
        )

    start_datetime = datetime.now()

    # An explicit `multimodality` setting decides; when it is None the
    # endpoint is guessed from the model name.
    use_multimodal = self.multimodality
    if use_multimodal is None:
        use_multimodal = (
            self.model_name.startswith("qvq") or "-vl" in self.model_name
        )

    if use_multimodal:
        response = dashscope.MultiModalConversation.call(
            api_key=self.api_key,
            **request,
        )
    else:
        response = await dashscope.aigc.generation.AioGeneration.call(
            api_key=self.api_key,
            **request,
        )

    if self.stream:
        return self._parse_dashscope_stream_response(
            start_datetime,
            response,
            structured_model,
        )

    return await self._parse_dashscope_generation_response(
        start_datetime,
        response,
        structured_model,
    )
|
||||
|
||||
# pylint: disable=too-many-branches, too-many-statements
|
||||
async def _parse_dashscope_stream_response(
    self,
    start_datetime: datetime,
    response: Union[
        AsyncGenerator[GenerationResponse, None],
        Generator[MultiModalConversationResponse, None, None],
    ],
    structured_model: Type[BaseModel] | None = None,
) -> AsyncGenerator[ChatResponse, Any]:
    """Given a DashScope streaming response generator, extract the content
    blocks and usages from it and yield ChatResponse objects.

    Every yielded response carries the content accumulated so far (not
    just the delta of the current chunk).

    Args:
        start_datetime (`datetime`):
            The start datetime of the response generation.
        response (
            `Union[AsyncGenerator[GenerationResponse, None], Generator[ \
            MultiModalConversationResponse, None, None]]`
        ):
            DashScope streaming response generator (GenerationResponse or
            MultiModalConversationResponse) to parse.
        structured_model (`Type[BaseModel] | None`, default `None`):
            A Pydantic BaseModel class that defines the expected structure
            for the model's output.

    Returns:
        AsyncGenerator[ChatResponse, Any]:
            An async generator that yields ChatResponse objects containing
            the content blocks and usage information for each chunk in the
            streaming response.

    Raises:
        `RuntimeError`:
            If a chunk's status code is not `HTTPStatus.OK`.

    .. note::
        If `structured_model` is not `None`, the expected structured output
        will be stored in the metadata of the `ChatResponse`.
    """
    acc_content, acc_thinking_content = "", ""
    acc_tool_calls = collections.defaultdict(dict)
    last_input_objs = {}  # Store last input_obj for each tool_call
    metadata = None
    last_content = None
    usage = None

    # NOTE(review): `giter` presumably lets sync and async generators be
    # consumed uniformly with `async for` - confirm against its definition.
    async for chunk in giter(response):
        if chunk.status_code != HTTPStatus.OK:
            raise RuntimeError(
                f"Failed to get response from DashScope API: {chunk}",
            )

        message = chunk.output.choices[0].message

        # Update reasoning content
        if isinstance(message.get("reasoning_content"), str):
            acc_thinking_content += message["reasoning_content"]

        # Update text content
        if isinstance(message.content, str):
            acc_content += message.content
        elif isinstance(message.content, list):
            for item in message.content:
                if isinstance(item, dict) and "text" in item:
                    acc_content += item["text"]

        # Update tool calls; `or []` also covers an explicit
        # `"tool_calls": None` in the chunk
        for tool_call in message.get("tool_calls") or []:
            index = tool_call.get("index", 0)

            # Append the id fragment unless it repeats the stored id
            if "id" in tool_call and tool_call["id"] != acc_tool_calls[
                index
            ].get("id"):
                acc_tool_calls[index]["id"] = (
                    acc_tool_calls[index].get("id", "") + tool_call["id"]
                )

            if "function" in tool_call:
                func = tool_call["function"]
                if "name" in func:
                    acc_tool_calls[index]["name"] = (
                        acc_tool_calls[index].get("name", "")
                        + func["name"]
                    )

                if "arguments" in func:
                    acc_tool_calls[index]["arguments"] = (
                        acc_tool_calls[index].get("arguments", "")
                        + func["arguments"]
                    )

        # Build content blocks (always include thinking and text)
        content_blocks: list[TextBlock | ToolUseBlock | ThinkingBlock] = []

        if acc_thinking_content:
            content_blocks.append(
                ThinkingBlock(
                    type="thinking",
                    thinking=acc_thinking_content,
                ),
            )

        if acc_content:
            content_blocks.append(
                TextBlock(
                    type="text",
                    text=acc_content,
                ),
            )

        for tool_call in acc_tool_calls.values():
            # A tool use block is emitted for every accumulated tool
            # call; only its `input` depends on stream_tool_parsing
            tool_id = tool_call.get("id", "")
            input_str = tool_call.get("arguments")

            # If parsing the tool input in streaming mode
            if self.stream_tool_parsing:
                repaired_input = _json_loads_with_repair(
                    input_str or "{}",
                )
                # If the new repaired input is shorter than one in the last
                # chunk, use the last one to avoid regression
                last_input = last_input_objs.get(tool_id, {})
                if len(json.dumps(last_input)) > len(
                    json.dumps(repaired_input),
                ):
                    repaired_input = last_input
                last_input_objs[tool_id] = repaired_input

            else:
                # Otherwise, keep input as empty dict until the final chunk
                repaired_input = {}

            content_blocks.append(
                ToolUseBlock(
                    type="tool_use",
                    id=tool_id,
                    name=tool_call.get("name", ""),
                    input=repaired_input,
                    raw_input=input_str,
                ),
            )

            if structured_model:
                metadata = repaired_input

        if chunk.usage:
            usage = ChatUsage(
                input_tokens=chunk.usage.input_tokens,
                output_tokens=chunk.usage.output_tokens,
                time=(datetime.now() - start_datetime).total_seconds(),
                metadata=chunk.usage,
            )

        if content_blocks:
            parsed_chunk = ChatResponse(
                content=content_blocks,
                usage=usage,
                metadata=metadata,
            )
            yield parsed_chunk
            # Keep a private copy so the final fix-up below cannot mutate
            # blocks that were already handed to the consumer
            last_content = copy.deepcopy(content_blocks)

    # If stream_tool_parsing is False, we need to parse the final tool
    # use inputs here
    if not self.stream_tool_parsing and last_content and acc_tool_calls:
        metadata = None
        # Update tool use blocks in last_content inplace
        for block in last_content:
            if block.get("type") == "tool_use":
                block["input"] = input_obj = _json_loads_with_repair(
                    str(block.get("raw_input") or "{}"),
                )

                if structured_model:
                    metadata = input_obj

        yield ChatResponse(
            content=last_content,
            usage=usage,
            metadata=metadata,
        )
|
||||
|
||||
async def _parse_dashscope_generation_response(
    self,
    start_datetime: datetime,
    response: Union[
        GenerationResponse,
        MultiModalConversationResponse,
    ],
    structured_model: Type[BaseModel] | None = None,
) -> ChatResponse:
    """Convert a complete (non-streaming) DashScope response into a
    ChatResponse.

    Args:
        start_datetime (`datetime`):
            The start datetime of the response generation.
        response (
            `Union[GenerationResponse, MultiModalConversationResponse]`
        ):
            Dashscope GenerationResponse | MultiModalConversationResponse
            object to parse.
        structured_model (`Type[BaseModel] | None`, default `None`):
            A Pydantic BaseModel class that defines the expected structure
            for the model's output.

    Returns:
        ChatResponse (`ChatResponse`):
            A ChatResponse object containing the content blocks and usage.

    Raises:
        `RuntimeError`:
            If the response status code is not 200.

    .. note::
        If `structured_model` is not `None`, the parsed input of the
        (last) tool call is stored in the metadata of the `ChatResponse`.
    """
    if response.status_code != 200:
        raise RuntimeError(response)

    blocks: List[TextBlock | ToolUseBlock] = []
    metadata: dict | None = None

    message = response.output.choices[0].message
    content = message.get("content")

    # Text content: either a plain value or a list of dicts with "text"
    if content not in [None, "", []]:
        if isinstance(content, list):
            for item in content:
                if isinstance(item, dict) and "text" in item:
                    blocks.append(
                        TextBlock(type="text", text=item["text"]),
                    )
        else:
            blocks.append(TextBlock(type="text", text=content))

    # Tool calls: arguments arrive as a JSON string that may need repair
    tool_calls = message.get("tool_calls")
    if tool_calls:
        for tool_call in tool_calls:
            function = tool_call["function"]
            raw_arguments = function.get("arguments", "{}") or "{}"
            parsed_input = _json_loads_with_repair(raw_arguments)
            blocks.append(
                ToolUseBlock(
                    type="tool_use",
                    name=function["name"],
                    input=parsed_input,
                    id=tool_call["id"],
                ),
            )

            if structured_model:
                metadata = parsed_input

    # Usage information
    usage = (
        ChatUsage(
            input_tokens=response.usage.input_tokens,
            output_tokens=response.usage.output_tokens,
            time=(datetime.now() - start_datetime).total_seconds(),
            metadata=response.usage,
        )
        if response.usage
        else None
    )

    return ChatResponse(
        content=blocks,
        usage=usage,
        metadata=metadata,
    )
|
||||
|
||||
def _format_tools_json_schemas(
|
||||
self,
|
||||
schemas: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Format the tools JSON schema into required format for DashScope API.
|
||||
|
||||
Args:
|
||||
schemas (`dict[str, dict[str, Any]]`):
|
||||
The tools JSON schemas.
|
||||
"""
|
||||
# Check schemas format
|
||||
for value in schemas:
|
||||
if (
|
||||
not isinstance(value, dict)
|
||||
or "type" not in value
|
||||
or value["type"] != "function"
|
||||
or "function" not in value
|
||||
):
|
||||
raise ValueError(
|
||||
f"Each schema must be a dict with 'type' as 'function' "
|
||||
f"and 'function' key, got {value}",
|
||||
)
|
||||
|
||||
return schemas
|
||||
|
||||
def _format_tool_choice(
|
||||
self,
|
||||
tool_choice: Literal["auto", "none", "required"] | str | None,
|
||||
) -> str | dict | None:
|
||||
"""Format tool_choice parameter for API compatibility.
|
||||
|
||||
Args:
|
||||
tool_choice (`Literal["auto", "none", "required"] | str \
|
||||
| None`, default `None`):
|
||||
Controls which (if any) tool is called by the model. For more
|
||||
details, please refer to
|
||||
https://help.aliyun.com/zh/model-studio/qwen-function-calling
|
||||
|
||||
Returns:
|
||||
`dict | None`:
|
||||
The formatted tool choice configuration dict, or None if
|
||||
tool_choice is None.
|
||||
"""
|
||||
if tool_choice is None:
|
||||
return None
|
||||
if tool_choice in ["auto", "none"]:
|
||||
return tool_choice
|
||||
if tool_choice == "required":
|
||||
return "auto"
|
||||
return {"type": "function", "function": {"name": tool_choice}}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user