Files
tw2/src/agentscope/agent/_agent_meta.py
jimi a842f1861f
Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled
chore: initial import of standalone agentscope project
2026-03-02 18:21:40 +08:00

181 lines
5.6 KiB
Python

# -*- coding: utf-8 -*-
"""The metaclass for agents in agentscope."""
import inspect
from copy import deepcopy
from functools import wraps
from typing import (
Any,
Dict,
TYPE_CHECKING,
Callable,
)
from .._utils._common import _execute_async_or_sync_func
if not TYPE_CHECKING:
    # At runtime the real class is not imported (presumably to avoid a
    # circular import with ``._agent_base`` -- confirm); the name is bound
    # to a plain string so annotations such as ``self: AgentBase`` below
    # can still be evaluated.
    AgentBase = "AgentBase"
else:
    # Static type checkers see the real class.
    from ._agent_base import AgentBase
def _normalize_to_kwargs(
    func: Callable,
    self: Any,
    *args: Any,
    **kwargs: Any,
) -> dict:
    """Normalize the provided positional and keyword arguments into a
    keyword arguments dictionary that matches the function signature.

    ``self`` is bound to the first parameter of ``func`` and that parameter
    is removed from the result, so the returned dictionary only contains
    the remaining parameters. Parameters the caller did not provide are
    filled with their default values. Variadic parameters appear under
    their own names (a tuple for ``*args``-style, a dict for
    ``**kwargs``-style parameters).

    Args:
        func (`Callable`):
            The function whose signature is used for binding. Its first
            parameter receives ``self``.
        self (`Any`):
            The object bound to the first parameter of ``func``.
        *args (`Any`):
            Positional arguments provided by the caller.
        **kwargs (`Any`):
            Keyword arguments provided by the caller.

    Returns:
        `dict`:
            Parameter name to value for every parameter of ``func`` except
            the first one.

    Raises:
        `TypeError`:
            If the provided arguments cannot be bound to the signature of
            ``func``.
    """
    sig = inspect.signature(func)
    try:
        # Bind the provided arguments to the function signature
        bound = sig.bind(self, *args, **kwargs)
    except TypeError as e:
        # If failed to bind, we raise a TypeError with more context.
        # Not every callable has ``__name__`` (e.g. functools.partial), so
        # fall back to its repr instead of masking the real error.
        func_label = getattr(func, "__name__", repr(func))
        param_names = list(sig.parameters.keys())
        raise TypeError(
            f"Failed to bind parameters for function '{func_label}': {e}\n"
            f"Expected parameters: {param_names}\n"
            f"Provided {len(args)} positional args and kwargs: "
            f"{list(kwargs.keys())}",
        ) from e

    # Apply the default values for parameters that were not provided
    bound.apply_defaults()
    res = dict(bound.arguments)
    # ``self`` was bound to the first parameter; drop it by its actual name
    # rather than assuming that parameter is literally called "self".
    res.pop(next(iter(sig.parameters)))
    return res
def _build_call_arguments(
    params: list,
    normalized_kwargs: Dict[str, Any],
    func_label: str,
) -> tuple:
    """Rebuild the positional and keyword arguments of a call from a
    normalized keyword-arguments dictionary.

    Each entry is routed according to the kind of its matching parameter,
    so the resulting call is valid for positional-only parameters and for
    functions that receive extra variadic positional arguments.

    Args:
        params (`list`):
            The ``inspect.Parameter`` objects of the target function,
            excluding its first parameter (which receives the agent).
        normalized_kwargs (`Dict[str, Any]`):
            Parameter name to value, as produced by normalization and
            possibly replaced by pre-hooks.
        func_label (`str`):
            The function name used in error messages.

    Returns:
        `tuple`:
            An ``(args, kwargs)`` pair to pass after the agent argument.

    Raises:
        `TypeError`:
            If the same keyword would be passed twice.
    """
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    remaining = dict(normalized_kwargs)
    call_args: list = []
    call_kwargs: Dict[str, Any] = {}
    var_keyword: Dict[str, Any] = {}
    # Once a positional parameter is missing (e.g. removed by a pre-hook so
    # that its default applies), later ones are passed by keyword so they
    # do not shift into the wrong slot.
    pass_positionally = True
    for param in params:
        if param.name not in remaining:
            if param.kind in positional_kinds:
                pass_positionally = False
            continue
        value = remaining.pop(param.name)
        if param.kind is inspect.Parameter.VAR_POSITIONAL:
            call_args.extend(value)
        elif param.kind is inspect.Parameter.VAR_KEYWORD:
            var_keyword = dict(value)
        elif param.kind is inspect.Parameter.POSITIONAL_ONLY or (
            param.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD
            and pass_positionally
        ):
            call_args.append(value)
        else:
            call_kwargs[param.name] = value
    # Keys unknown to the signature (e.g. added by a pre-hook) and the
    # contents of the ``**kwargs``-style parameter are forwarded as keyword
    # arguments.
    for source in (remaining, var_keyword):
        for key, value in source.items():
            if key in call_kwargs:
                raise TypeError(
                    f"{func_label}() got multiple values for keyword "
                    f"argument '{key}'",
                )
            call_kwargs[key] = value
    return call_args, call_kwargs


def _wrap_with_hooks(
    original_func: Callable,
) -> Callable:
    """A decorator to wrap the original async function with pre- and
    post-hooks.

    The hook registries are looked up by a name derived from the wrapped
    function with every underscore removed (e.g. ``_reasoning`` becomes
    ``reasoning``). The wrapper reads ``_instance_pre_<name>_hooks``,
    ``_class_pre_<name>_hooks``, ``_instance_post_<name>_hooks`` and
    ``_class_post_<name>_hooks``, and takes the hook callables from their
    ``.values()``.

    Args:
        original_func (`Callable`):
            The original async function to be wrapped with hooks. Its first
            parameter receives the agent instance.

    Returns:
        `Callable`:
            An async wrapper that runs the pre-hooks, awaits the original
            function, runs the post-hooks and returns the (possibly
            replaced) output.
    """
    func_name = original_func.__name__.replace("_", "")
    # Every parameter except the first (which receives the agent). These
    # are used to turn the normalized keyword arguments back into a valid
    # call.
    call_params = list(
        inspect.signature(original_func).parameters.values(),
    )[1:]

    @wraps(original_func)
    async def async_wrapper(
        self: AgentBase,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """The wrapped function, which calls the pre- and post-hooks before
        and after the original function."""
        # Unify all positional and keyword arguments into one dictionary
        # keyed by parameter name
        current_normalized_kwargs = _normalize_to_kwargs(
            original_func,
            self,
            *args,
            **kwargs,
        )

        assert (
            hasattr(self, f"_instance_pre_{func_name}_hooks")
            and hasattr(self, f"_instance_post_{func_name}_hooks")
            and hasattr(self.__class__, f"_class_pre_{func_name}_hooks")
            and hasattr(self.__class__, f"_class_post_{func_name}_hooks")
        ), f"Hooks for {func_name} not found in {self.__class__.__name__}"

        # pre-hooks: instance-level hooks run before class-level ones. Each
        # hook receives a deep copy of the current arguments and may return
        # a replacement dict; returning None keeps the arguments unchanged.
        pre_hooks = list(
            getattr(self, f"_instance_pre_{func_name}_hooks").values(),
        ) + list(
            getattr(self, f"_class_pre_{func_name}_hooks").values(),
        )
        for pre_hook in pre_hooks:
            # presumably awaits async hooks and calls sync ones directly
            modified_keywords = await _execute_async_or_sync_func(
                pre_hook,
                self,
                deepcopy(current_normalized_kwargs),
            )
            if modified_keywords is not None:
                assert isinstance(modified_keywords, dict), (
                    f"Pre-hook must return a dict of keyword arguments, rather"
                    f" than {type(modified_keywords)} from hook "
                    f"{getattr(pre_hook, '__name__', repr(pre_hook))}"
                )
                current_normalized_kwargs = modified_keywords

        # original function: rebuild the call from the parameter kinds
        # rather than the literal keys "args"/"kwargs", so extra positional
        # arguments no longer collide with named parameters passed by
        # keyword, and positional-only parameters stay positional.
        call_args, call_kwargs = _build_call_arguments(
            call_params,
            current_normalized_kwargs,
            original_func.__name__,
        )
        current_output = await original_func(
            self,
            *call_args,
            **call_kwargs,
        )

        # post-hooks: same ordering as the pre-hooks. Each receives deep
        # copies of the final arguments and the current output; a non-None
        # return value replaces the output.
        post_hooks = list(
            getattr(self, f"_instance_post_{func_name}_hooks").values(),
        ) + list(
            getattr(self, f"_class_post_{func_name}_hooks").values(),
        )
        for post_hook in post_hooks:
            modified_output = await _execute_async_or_sync_func(
                post_hook,
                self,
                deepcopy(current_normalized_kwargs),
                deepcopy(current_output),
            )
            if modified_output is not None:
                current_output = modified_output

        return current_output

    return async_wrapper
class _AgentMeta(type):
    """Metaclass for agents.

    At class-creation time, each of ``reply``, ``print`` and ``observe``
    that is defined directly in the class body is replaced by a version
    wrapped with pre- and post-hooks via ``_wrap_with_hooks``. Only names
    present in the class body namespace are considered.
    """

    def __new__(mcs, name: Any, bases: Any, attrs: Dict) -> Any:
        """Swap the hookable functions found in ``attrs`` for their
        hook-wrapped versions, then build the class."""
        hooked = {
            attr_name: _wrap_with_hooks(attrs[attr_name])
            for attr_name in ("reply", "print", "observe")
            if attr_name in attrs
        }
        # Overwrite in place; existing keys keep their positions.
        attrs.update(hooked)
        return super().__new__(mcs, name, bases, attrs)
class _ReActAgentMeta(_AgentMeta):
    """Metaclass for ReAct agents.

    In addition to what the parent metaclass does, ``_reasoning`` and
    ``_acting`` are wrapped with pre- and post-hooks when they are defined
    directly in the class body.
    """

    def __new__(mcs, name: Any, bases: Any, attrs: Dict) -> Any:
        """Hook-wrap ``_reasoning``/``_acting`` if present, then defer to
        the parent metaclass."""
        present = [
            attr_name
            for attr_name in ("_reasoning", "_acting")
            if attr_name in attrs
        ]
        for attr_name in present:
            attrs[attr_name] = _wrap_with_hooks(attrs[attr_name])
        return super().__new__(mcs, name, bases, attrs)