chore: initial import of standalone agentscope project
Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled
Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled
This commit is contained in:
22
src/agentscope/model/__init__.py
Normal file
22
src/agentscope/model/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The model module."""
|
||||
|
||||
from ._model_base import ChatModelBase
|
||||
from ._model_response import ChatResponse
|
||||
from ._dashscope_model import DashScopeChatModel
|
||||
from ._openai_model import OpenAIChatModel
|
||||
from ._anthropic_model import AnthropicChatModel
|
||||
from ._ollama_model import OllamaChatModel
|
||||
from ._gemini_model import GeminiChatModel
|
||||
from ._trinity_model import TrinityChatModel
|
||||
|
||||
__all__ = [
|
||||
"ChatModelBase",
|
||||
"ChatResponse",
|
||||
"DashScopeChatModel",
|
||||
"OpenAIChatModel",
|
||||
"AnthropicChatModel",
|
||||
"OllamaChatModel",
|
||||
"GeminiChatModel",
|
||||
"TrinityChatModel",
|
||||
]
|
||||
590
src/agentscope/model/_anthropic_model.py
Normal file
590
src/agentscope/model/_anthropic_model.py
Normal file
@@ -0,0 +1,590 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# pylint: disable=too-many-branches, too-many-statements
|
||||
"""The Anthropic API model classes."""
|
||||
import copy
|
||||
import json
|
||||
import warnings
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
TYPE_CHECKING,
|
||||
List,
|
||||
Literal,
|
||||
Type,
|
||||
)
|
||||
from collections import OrderedDict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._model_base import ChatModelBase
|
||||
from ._model_response import ChatResponse
|
||||
from ._model_usage import ChatUsage
|
||||
from .._logging import logger
|
||||
from .._utils._common import (
|
||||
_json_loads_with_repair,
|
||||
_create_tool_from_base_model,
|
||||
)
|
||||
from ..message import TextBlock, ToolUseBlock, ThinkingBlock
|
||||
from ..tracing import trace_llm
|
||||
from ..types._json import JSONSerializableObject
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from anthropic.types.message import Message
|
||||
from anthropic import AsyncStream
|
||||
else:
|
||||
Message = "anthropic.types.message.Message"
|
||||
AsyncStream = "anthropic.AsyncStream"
|
||||
|
||||
|
||||
class AnthropicChatModel(ChatModelBase):
|
||||
"""The Anthropic model wrapper for AgentScope."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
api_key: str | None = None,
|
||||
max_tokens: int = 2048,
|
||||
stream: bool = True,
|
||||
thinking: dict | None = None,
|
||||
stream_tool_parsing: bool = True,
|
||||
client_kwargs: dict[str, JSONSerializableObject] | None = None,
|
||||
generate_kwargs: dict[str, JSONSerializableObject] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the Anthropic chat model.
|
||||
|
||||
Args:
|
||||
model_name (`str`):
|
||||
The model names.
|
||||
api_key (`str`):
|
||||
The anthropic API key.
|
||||
stream (`bool`):
|
||||
The streaming output or not
|
||||
max_tokens (`int`):
|
||||
Limit the maximum token count the model can generate.
|
||||
thinking (`dict | None`, default `None`):
|
||||
Configuration for Claude's internal reasoning process.
|
||||
|
||||
.. code-block:: python
|
||||
:caption: Example of thinking
|
||||
|
||||
{
|
||||
"type": "enabled" | "disabled",
|
||||
"budget_tokens": 1024
|
||||
}
|
||||
|
||||
stream_tool_parsing (`bool`, default to `True`):
|
||||
Whether to parse incomplete tool use JSON during streaming
|
||||
with auto-repair. If True, partial JSON (e.g., `'{"a": "x'`)
|
||||
is repaired to valid dicts ({"a": "x"}) in real-time for
|
||||
immediate tool function input. Otherwise, the input field
|
||||
remains {} until the final chunk arrives.
|
||||
client_kwargs (`dict[str, JSONSerializableObject] | None`, \
|
||||
optional):
|
||||
The extra keyword arguments to initialize the Anthropic client.
|
||||
generate_kwargs (`dict[str, JSONSerializableObject] | None`, \
|
||||
optional):
|
||||
The extra keyword arguments used in Anthropic API generation,
|
||||
e.g. `temperature`, `seed`.
|
||||
**kwargs (`Any`):
|
||||
Additional keyword arguments.
|
||||
"""
|
||||
|
||||
# Handle deprecated client_args parameter from kwargs
|
||||
client_args = kwargs.pop("client_args", None)
|
||||
if client_args is not None and client_kwargs is not None:
|
||||
raise ValueError(
|
||||
"Cannot specify both 'client_args' and 'client_kwargs'. "
|
||||
"Please use only 'client_kwargs' (client_args is deprecated).",
|
||||
)
|
||||
|
||||
if client_args is not None:
|
||||
logger.warning(
|
||||
"The parameter 'client_args' is deprecated and will be "
|
||||
"removed in a future version. Please use 'client_kwargs' "
|
||||
"instead. Automatically converting 'client_args' to "
|
||||
"'client_kwargs'.",
|
||||
)
|
||||
client_kwargs = client_args
|
||||
|
||||
if kwargs:
|
||||
logger.warning(
|
||||
"Unknown keyword arguments: %s. These will be ignored.",
|
||||
list(kwargs.keys()),
|
||||
)
|
||||
|
||||
try:
|
||||
import anthropic
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Please install the `anthropic` package by running "
|
||||
"`pip install anthropic`.",
|
||||
) from e
|
||||
|
||||
super().__init__(model_name, stream)
|
||||
|
||||
self.client = anthropic.AsyncAnthropic(
|
||||
api_key=api_key,
|
||||
**(client_kwargs or {}),
|
||||
)
|
||||
self.max_tokens = max_tokens
|
||||
self.thinking = thinking
|
||||
self.stream_tool_parsing = stream_tool_parsing
|
||||
self.generate_kwargs = generate_kwargs or {}
|
||||
|
||||
@trace_llm
|
||||
async def __call__(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: Literal["auto", "none", "required"] | str | None = None,
|
||||
structured_model: Type[BaseModel] | None = None,
|
||||
**generate_kwargs: Any,
|
||||
) -> ChatResponse | AsyncGenerator[ChatResponse, None]:
|
||||
"""Get the response from Anthropic chat completions API by the given
|
||||
arguments.
|
||||
|
||||
Args:
|
||||
messages (`list[dict]`):
|
||||
A list of dictionaries, where `role` and `content` fields are
|
||||
required, and `name` field is optional.
|
||||
tools (`list[dict]`, default `None`):
|
||||
The tools JSON schemas that in format of:
|
||||
|
||||
.. code-block:: python
|
||||
:caption: Example of tools JSON schemas
|
||||
|
||||
[
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "xxx",
|
||||
"description": "xxx",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"param1": {
|
||||
"type": "string",
|
||||
"description": "..."
|
||||
},
|
||||
# Add more parameters as needed
|
||||
},
|
||||
"required": ["param1"]
|
||||
}
|
||||
},
|
||||
# More schemas here
|
||||
]
|
||||
|
||||
tool_choice (`Literal["auto", "none", "required"] | str \
|
||||
| None`, default `None`):
|
||||
Controls which (if any) tool is called by the model.
|
||||
Can be "auto", "none", "required", or specific tool
|
||||
name. For more details, please refer to
|
||||
https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/implement-tool-use
|
||||
structured_model (`Type[BaseModel] | None`, default `None`):
|
||||
A Pydantic BaseModel class that defines the expected structure
|
||||
for the model's output. When provided, the model will be forced
|
||||
to return data that conforms to this schema by automatically
|
||||
converting the BaseModel to a tool function and setting
|
||||
`tool_choice` to enforce its usage. This enables structured
|
||||
output generation.
|
||||
|
||||
.. note:: When `structured_model` is specified,
|
||||
both `tools` and `tool_choice` parameters are ignored,
|
||||
and the model will only perform structured output
|
||||
generation without calling any other tools.
|
||||
|
||||
**generate_kwargs (`Any`):
|
||||
The keyword arguments for Anthropic chat completions API,
|
||||
e.g. `temperature`, `top_p`, etc. Please
|
||||
refer to the Anthropic API documentation for more details.
|
||||
|
||||
Returns:
|
||||
`ChatResponse | AsyncGenerator[ChatResponse, None]`:
|
||||
The response from the Anthropic chat completions API."""
|
||||
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": self.model_name,
|
||||
"max_tokens": self.max_tokens,
|
||||
"stream": self.stream,
|
||||
**self.generate_kwargs,
|
||||
**generate_kwargs,
|
||||
}
|
||||
if self.thinking and "thinking" not in kwargs:
|
||||
kwargs["thinking"] = self.thinking
|
||||
|
||||
if tools:
|
||||
kwargs["tools"] = self._format_tools_json_schemas(tools)
|
||||
|
||||
if tool_choice:
|
||||
# Handle deprecated "any" option with warning
|
||||
if tool_choice == "any":
|
||||
warnings.warn(
|
||||
'"any" is deprecated and will be removed in a future '
|
||||
"version.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
tool_choice = "required"
|
||||
self._validate_tool_choice(tool_choice, tools)
|
||||
kwargs["tool_choice"] = self._format_tool_choice(tool_choice)
|
||||
|
||||
if structured_model:
|
||||
if tools or tool_choice:
|
||||
logger.warning(
|
||||
"structured_model is provided. Both 'tools' and "
|
||||
"'tool_choice' parameters will be overridden and "
|
||||
"ignored. The model will only perform structured output "
|
||||
"generation without calling any other tools.",
|
||||
)
|
||||
format_tool = _create_tool_from_base_model(structured_model)
|
||||
kwargs["tools"] = self._format_tools_json_schemas(
|
||||
[format_tool],
|
||||
)
|
||||
kwargs["tool_choice"] = self._format_tool_choice(
|
||||
format_tool["function"]["name"],
|
||||
)
|
||||
|
||||
# Extract the system message
|
||||
if messages[0]["role"] == "system":
|
||||
kwargs["system"] = messages[0]["content"]
|
||||
messages = messages[1:]
|
||||
|
||||
kwargs["messages"] = messages
|
||||
|
||||
start_datetime = datetime.now()
|
||||
|
||||
response = await self.client.messages.create(**kwargs)
|
||||
|
||||
if self.stream:
|
||||
return self._parse_anthropic_stream_completion_response(
|
||||
start_datetime,
|
||||
response,
|
||||
structured_model,
|
||||
)
|
||||
|
||||
# Non-streaming response
|
||||
parsed_response = await self._parse_anthropic_completion_response(
|
||||
start_datetime,
|
||||
response,
|
||||
structured_model,
|
||||
)
|
||||
|
||||
return parsed_response
|
||||
|
||||
async def _parse_anthropic_completion_response(
|
||||
self,
|
||||
start_datetime: datetime,
|
||||
response: Message,
|
||||
structured_model: Type[BaseModel] | None = None,
|
||||
) -> ChatResponse:
|
||||
"""Given an Anthropic Message object, extract the content blocks and
|
||||
usages from it.
|
||||
|
||||
Args:
|
||||
start_datetime (`datetime`):
|
||||
The start datetime of the response generation.
|
||||
response (`Message`):
|
||||
Anthropic Message object to parse.
|
||||
structured_model (`Type[BaseModel] | None`, default `None`):
|
||||
A Pydantic BaseModel class that defines the expected structure
|
||||
for the model's output.
|
||||
|
||||
Returns:
|
||||
ChatResponse (`ChatResponse`):
|
||||
A ChatResponse object containing the content blocks and usage.
|
||||
|
||||
.. note::
|
||||
If `structured_model` is not `None`, the expected structured output
|
||||
will be stored in the metadata of the `ChatResponse`.
|
||||
"""
|
||||
content_blocks: List[ThinkingBlock | TextBlock | ToolUseBlock] = []
|
||||
metadata = None
|
||||
|
||||
if hasattr(response, "content") and response.content:
|
||||
for content_block in response.content:
|
||||
if (
|
||||
hasattr(content_block, "type")
|
||||
and content_block.type == "thinking"
|
||||
):
|
||||
thinking_block = ThinkingBlock(
|
||||
type="thinking",
|
||||
thinking=content_block.thinking,
|
||||
)
|
||||
thinking_block["signature"] = content_block.signature
|
||||
content_blocks.append(thinking_block)
|
||||
|
||||
elif (
|
||||
hasattr(content_block, "type")
|
||||
and content_block.type == "text"
|
||||
):
|
||||
content_blocks.append(
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=content_block.text,
|
||||
),
|
||||
)
|
||||
|
||||
elif (
|
||||
hasattr(content_block, "type")
|
||||
and content_block.type == "tool_use"
|
||||
):
|
||||
content_blocks.append(
|
||||
ToolUseBlock(
|
||||
type="tool_use",
|
||||
id=content_block.id,
|
||||
name=content_block.name,
|
||||
input=content_block.input,
|
||||
),
|
||||
)
|
||||
if structured_model:
|
||||
metadata = content_block.input
|
||||
|
||||
usage = None
|
||||
if response.usage:
|
||||
usage = ChatUsage(
|
||||
input_tokens=response.usage.input_tokens,
|
||||
output_tokens=response.usage.output_tokens,
|
||||
time=(datetime.now() - start_datetime).total_seconds(),
|
||||
)
|
||||
|
||||
parsed_response = ChatResponse(
|
||||
content=content_blocks,
|
||||
usage=usage,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
return parsed_response
|
||||
|
||||
async def _parse_anthropic_stream_completion_response(
|
||||
self,
|
||||
start_datetime: datetime,
|
||||
response: AsyncStream,
|
||||
structured_model: Type[BaseModel] | None = None,
|
||||
) -> AsyncGenerator[ChatResponse, None]:
|
||||
"""Given an Anthropic streaming response, extract the content blocks
|
||||
and usages from it and yield ChatResponse objects.
|
||||
|
||||
Args:
|
||||
start_datetime (`datetime`):
|
||||
The start datetime of the response generation.
|
||||
response (`AsyncStream`):
|
||||
Anthropic AsyncStream object to parse.
|
||||
structured_model (`Type[BaseModel] | None`, default `None`):
|
||||
A Pydantic BaseModel class that defines the expected structure
|
||||
for the model's output.
|
||||
|
||||
Returns:
|
||||
`AsyncGenerator[ChatResponse, None]`:
|
||||
An async generator that yields ChatResponse objects containing
|
||||
the content blocks and usage information for each chunk in
|
||||
the streaming response.
|
||||
|
||||
.. note::
|
||||
If `structured_model` is not `None`, the expected structured output
|
||||
will be stored in the metadata of the `ChatResponse`.
|
||||
"""
|
||||
|
||||
usage = None
|
||||
text_buffer = ""
|
||||
thinking_buffer = ""
|
||||
thinking_signature = ""
|
||||
tool_calls = OrderedDict()
|
||||
tool_call_buffers = {}
|
||||
last_input_objs = {} # Store last input_obj for each tool_call
|
||||
res = None
|
||||
metadata = None
|
||||
|
||||
# Record the last yielded content to parse the tools' input
|
||||
last_content = None
|
||||
|
||||
async for event in response:
|
||||
content_changed = False
|
||||
thinking_changed = False
|
||||
|
||||
if event.type == "message_start":
|
||||
message = event.message
|
||||
if message.usage:
|
||||
usage = ChatUsage(
|
||||
input_tokens=message.usage.input_tokens,
|
||||
output_tokens=getattr(
|
||||
message.usage,
|
||||
"output_tokens",
|
||||
0,
|
||||
),
|
||||
time=(datetime.now() - start_datetime).total_seconds(),
|
||||
)
|
||||
|
||||
elif event.type == "content_block_start":
|
||||
if event.content_block.type == "tool_use":
|
||||
block_index = event.index
|
||||
tool_block = event.content_block
|
||||
tool_calls[block_index] = {
|
||||
"type": "tool_use",
|
||||
"id": tool_block.id,
|
||||
"name": tool_block.name,
|
||||
"input": "",
|
||||
}
|
||||
tool_call_buffers[block_index] = ""
|
||||
content_changed = True
|
||||
|
||||
elif event.type == "content_block_delta":
|
||||
block_index = event.index
|
||||
delta = event.delta
|
||||
if delta.type == "text_delta":
|
||||
text_buffer += delta.text
|
||||
content_changed = True
|
||||
elif delta.type == "thinking_delta":
|
||||
thinking_buffer += delta.thinking
|
||||
thinking_changed = True
|
||||
elif delta.type == "signature_delta":
|
||||
thinking_signature = delta.signature
|
||||
elif (
|
||||
delta.type == "input_json_delta"
|
||||
and block_index in tool_calls
|
||||
):
|
||||
tool_call_buffers[block_index] += delta.partial_json or ""
|
||||
tool_calls[block_index]["input"] = tool_call_buffers[
|
||||
block_index
|
||||
]
|
||||
content_changed = True
|
||||
|
||||
elif event.type == "message_delta":
|
||||
if event.usage and usage:
|
||||
usage.output_tokens = event.usage.output_tokens
|
||||
|
||||
if (thinking_changed or content_changed) and usage:
|
||||
contents: list = []
|
||||
if thinking_buffer:
|
||||
thinking_block = ThinkingBlock(
|
||||
type="thinking",
|
||||
thinking=thinking_buffer,
|
||||
)
|
||||
thinking_block["signature"] = thinking_signature
|
||||
contents.append(thinking_block)
|
||||
if text_buffer:
|
||||
contents.append(
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=text_buffer,
|
||||
),
|
||||
)
|
||||
for block_index, tool_call in tool_calls.items():
|
||||
input_str = tool_call["input"]
|
||||
tool_id = tool_call["id"]
|
||||
|
||||
# If parsing the tool input in streaming mode
|
||||
if self.stream_tool_parsing:
|
||||
repaired_input = _json_loads_with_repair(
|
||||
input_str or "{}",
|
||||
)
|
||||
# If the new repaired input is shorter than one in the
|
||||
# last chunk, use the last one to avoid regression
|
||||
last_input = last_input_objs.get(tool_id, {})
|
||||
if len(json.dumps(last_input)) > len(
|
||||
json.dumps(repaired_input),
|
||||
):
|
||||
repaired_input = last_input
|
||||
last_input_objs[tool_id] = repaired_input
|
||||
|
||||
else:
|
||||
repaired_input = {}
|
||||
|
||||
contents.append(
|
||||
ToolUseBlock(
|
||||
type=tool_call["type"],
|
||||
id=tool_call["id"],
|
||||
name=tool_call["name"],
|
||||
input=repaired_input,
|
||||
raw_input=input_str,
|
||||
),
|
||||
)
|
||||
|
||||
if structured_model:
|
||||
metadata = repaired_input
|
||||
|
||||
if contents:
|
||||
res = ChatResponse(
|
||||
content=contents,
|
||||
usage=usage,
|
||||
metadata=metadata,
|
||||
)
|
||||
yield res
|
||||
last_content = copy.deepcopy(contents)
|
||||
|
||||
# If stream_tool_parsing is False, yield last contents
|
||||
if not self.stream_tool_parsing and last_content and tool_calls:
|
||||
metadata = None
|
||||
# Update tool use blocks in last_contents inplace
|
||||
for block in last_content:
|
||||
if block.get("type") == "tool_use":
|
||||
block["input"] = input_obj = _json_loads_with_repair(
|
||||
block.get("raw_input") or "{}",
|
||||
)
|
||||
|
||||
if structured_model:
|
||||
metadata = input_obj
|
||||
|
||||
yield ChatResponse(
|
||||
content=last_content,
|
||||
usage=usage,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def _format_tools_json_schemas(
|
||||
self,
|
||||
schemas: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Format the JSON schemas of the tool functions to the format that
|
||||
Anthropic API expects."""
|
||||
formatted_schemas = []
|
||||
for schema in schemas:
|
||||
assert (
|
||||
"function" in schema
|
||||
), f"Invalid schema: {schema}, expect key 'function'."
|
||||
|
||||
assert "name" in schema["function"], (
|
||||
f"Invalid schema: {schema}, "
|
||||
"expect key 'name' in 'function' field."
|
||||
)
|
||||
|
||||
formatted_schemas.append(
|
||||
{
|
||||
"name": schema["function"]["name"],
|
||||
"description": schema["function"].get("description", ""),
|
||||
"input_schema": schema["function"].get("parameters", {}),
|
||||
},
|
||||
)
|
||||
|
||||
return formatted_schemas
|
||||
|
||||
def _format_tool_choice(
|
||||
self,
|
||||
tool_choice: Literal["auto", "none", "required"] | str | None,
|
||||
) -> dict | None:
|
||||
"""Format tool_choice parameter for API compatibility.
|
||||
|
||||
Args:
|
||||
tool_choice (`Literal["auto", "none", "required"] | str \
|
||||
| None`, default `None`):
|
||||
Controls which (if any) tool is called by the model.
|
||||
Can be "auto", "none", "required", or specific tool
|
||||
name. For more details, please refer to
|
||||
https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/implement-tool-use
|
||||
Returns:
|
||||
`dict | None`:
|
||||
The formatted tool choice configuration dict, or None if
|
||||
tool_choice is None.
|
||||
"""
|
||||
if tool_choice is None:
|
||||
return None
|
||||
|
||||
type_mapping = {
|
||||
"auto": {"type": "auto"},
|
||||
"none": {"type": "none"},
|
||||
"required": {"type": "any"},
|
||||
}
|
||||
if tool_choice in type_mapping:
|
||||
return type_mapping[tool_choice]
|
||||
|
||||
return {"type": "tool", "name": tool_choice}
|
||||
632
src/agentscope/model/_dashscope_model.py
Normal file
632
src/agentscope/model/_dashscope_model.py
Normal file
@@ -0,0 +1,632 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The dashscope API model classes."""
|
||||
import copy
|
||||
import collections
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from datetime import datetime
|
||||
from http import HTTPStatus
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Generator,
|
||||
Union,
|
||||
TYPE_CHECKING,
|
||||
List,
|
||||
Literal,
|
||||
Type,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from aioitertools import iter as giter
|
||||
|
||||
from ._model_base import ChatModelBase
|
||||
from ._model_response import ChatResponse
|
||||
from ._model_usage import ChatUsage
|
||||
from .._utils._common import (
|
||||
_json_loads_with_repair,
|
||||
_create_tool_from_base_model,
|
||||
)
|
||||
from ..message import TextBlock, ToolUseBlock, ThinkingBlock
|
||||
from ..tracing import trace_llm
|
||||
from ..types import JSONSerializableObject
|
||||
from .._logging import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dashscope.api_entities.dashscope_response import GenerationResponse
|
||||
from dashscope.api_entities.dashscope_response import (
|
||||
MultiModalConversationResponse,
|
||||
)
|
||||
else:
|
||||
GenerationResponse = (
|
||||
"dashscope.api_entities.dashscope_response.GenerationResponse"
|
||||
)
|
||||
MultiModalConversationResponse = (
|
||||
"dashscope.api_entities.dashscope_response."
|
||||
"MultiModalConversationResponse"
|
||||
)
|
||||
|
||||
|
||||
class DashScopeChatModel(ChatModelBase):
|
||||
"""The DashScope chat model class, which unifies the Generation and
|
||||
MultimodalConversation APIs into one method.
|
||||
|
||||
This class provides a unified interface for DashScope API by automatically
|
||||
selecting between text-only (Generation API) and multimodal
|
||||
(MultiModalConversation API) endpoints. The `multimodality` parameter
|
||||
allows explicit control over API selection:
|
||||
|
||||
- When `multimodality=True`: Forces use of MultiModalConversation API
|
||||
for handling images, videos, and other multimodal inputs
|
||||
- When `multimodality=False`: Forces use of Generation API for
|
||||
text-only processing
|
||||
- When `multimodality=None` (default): Automatically selects the API
|
||||
based on model name (e.g., models with "-vl" suffix or starting
|
||||
with "qvq" will use MultiModalConversation API)
|
||||
|
||||
This design enables seamless switching between text and multimodal
|
||||
models without changing code structure, making it easier to work with
|
||||
DashScope's diverse model offerings.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
api_key: str,
|
||||
stream: bool = True,
|
||||
enable_thinking: bool | None = None,
|
||||
multimodality: bool | None = None,
|
||||
generate_kwargs: dict[str, JSONSerializableObject] | None = None,
|
||||
base_http_api_url: str | None = None,
|
||||
stream_tool_parsing: bool = True,
|
||||
**_kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the DashScope chat model.
|
||||
|
||||
Args:
|
||||
model_name (`str`):
|
||||
The model names.
|
||||
api_key (`str`):
|
||||
The dashscope API key.
|
||||
stream (`bool`):
|
||||
The streaming output or not
|
||||
enable_thinking (`bool | None`, optional):
|
||||
Enable thinking or not, only support Qwen3, QwQ, DeepSeek-R1.
|
||||
Refer to `DashScope documentation
|
||||
<https://help.aliyun.com/zh/model-studio/deep-thinking>`_
|
||||
for more details.
|
||||
multimodality (`bool | None`, optional):
|
||||
Whether to use multimodal conversation API. If `True`,
|
||||
it will use `dashscope.MultiModalConversation.call`
|
||||
to process multimodal inputs such as images and text. If
|
||||
`False`, it will use
|
||||
`dashscope.aigc.generation.AioGeneration.call` to process
|
||||
text inputs. If `None` (default), the choice is based on
|
||||
the model name.
|
||||
generate_kwargs (`dict[str, JSONSerializableObject] | None`, \
|
||||
optional):
|
||||
The extra keyword arguments used in DashScope API generation,
|
||||
e.g. `temperature`, `seed`.
|
||||
base_http_api_url (`str | None`, optional):
|
||||
The base URL for DashScope API requests. If not provided,
|
||||
the default base URL from the DashScope SDK will be used.
|
||||
stream_tool_parsing (`bool`, default to `True`):
|
||||
Whether to parse incomplete tool use JSON in streaming mode
|
||||
with auto-repair. If True, partial JSON (e.g., `'{"a": "x'`)
|
||||
is repaired to valid dicts (`{"a": "x"}`) in real-time for
|
||||
immediate tool function input. Otherwise, the input field
|
||||
remains {} until the final chunk arrives.
|
||||
**_kwargs (`Any`):
|
||||
Additional keyword arguments.
|
||||
"""
|
||||
if enable_thinking and not stream:
|
||||
logger.info(
|
||||
"In DashScope API, `stream` must be True when "
|
||||
"`enable_thinking` is True. ",
|
||||
)
|
||||
stream = True
|
||||
|
||||
super().__init__(model_name, stream)
|
||||
|
||||
self.api_key = api_key
|
||||
self.enable_thinking = enable_thinking
|
||||
self.multimodality = multimodality
|
||||
self.generate_kwargs = generate_kwargs or {}
|
||||
self.stream_tool_parsing = stream_tool_parsing
|
||||
|
||||
if base_http_api_url is not None:
|
||||
import dashscope
|
||||
|
||||
dashscope.base_http_api_url = base_http_api_url
|
||||
|
||||
# Load headers from environment variable if exists
|
||||
headers = os.getenv("DASHSCOPE_API_HEADERS")
|
||||
if headers:
|
||||
try:
|
||||
headers = json.loads(str(headers))
|
||||
if not isinstance(headers, dict):
|
||||
raise json.JSONDecodeError("", "", 0)
|
||||
|
||||
if self.generate_kwargs.get("headers"):
|
||||
headers.update(self.generate_kwargs["headers"])
|
||||
|
||||
self.generate_kwargs["headers"] = headers
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
"Failed to parse DASHSCOPE_API_HEADERS environment "
|
||||
"variable as JSON. It should be a JSON object.",
|
||||
)
|
||||
|
||||
@trace_llm
|
||||
async def __call__(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: Literal["auto", "none", "required"] | str | None = None,
|
||||
structured_model: Type[BaseModel] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResponse | AsyncGenerator[ChatResponse, None]:
|
||||
"""Get the response from the dashscope
|
||||
Generation/MultimodalConversation API by the given arguments.
|
||||
|
||||
.. note:: We unify the dashscope generation and multimodal conversation
|
||||
APIs into one method, since they support similar arguments and share
|
||||
the same functionality.
|
||||
|
||||
Args:
|
||||
messages (`list[dict[str, Any]]`):
|
||||
A list of dictionaries, where `role` and `content` fields are
|
||||
required.
|
||||
tools (`list[dict] | None`, default `None`):
|
||||
The tools JSON schemas that the model can use.
|
||||
tool_choice (`Literal["auto", "none", "required"] | str \
|
||||
| None`, default `None`):
|
||||
Controls which (if any) tool is called by the model.
|
||||
Can be "auto", "none", "required", or specific tool name.
|
||||
Note: DashScope API only supports "auto" and "none", so
|
||||
"required" will be converted to "auto".
|
||||
For more details, please refer to
|
||||
https://help.aliyun.com/zh/model-studio/qwen-function-calling
|
||||
structured_model (`Type[BaseModel] | None`, default `None`):
|
||||
A Pydantic BaseModel class that defines the expected structure
|
||||
for the model's output. When provided, the model will be forced
|
||||
to return data that conforms to this schema by automatically
|
||||
converting the BaseModel to a tool function and setting
|
||||
`tool_choice` to enforce its usage. This enables structured
|
||||
output generation.
|
||||
|
||||
.. note:: When `structured_model` is specified,
|
||||
both `tools` and `tool_choice` parameters are ignored,
|
||||
and the model will only perform structured output
|
||||
generation without calling any other tools.
|
||||
|
||||
**kwargs (`Any`):
|
||||
The keyword arguments for DashScope chat completions API,
|
||||
e.g. `temperature`, `max_tokens`, `top_p`, etc. Please
|
||||
refer to `DashScope documentation
|
||||
<https://help.aliyun.com/zh/dashscope/developer-reference/api-details>`_
|
||||
for more detailed arguments.
|
||||
"""
|
||||
import dashscope
|
||||
|
||||
kwargs = {
|
||||
"messages": messages,
|
||||
"model": self.model_name,
|
||||
"stream": self.stream,
|
||||
**self.generate_kwargs,
|
||||
**kwargs,
|
||||
"result_format": "message",
|
||||
# In agentscope, the `incremental_output` must be `True` when
|
||||
# `self.stream` is True
|
||||
"incremental_output": self.stream,
|
||||
}
|
||||
|
||||
if tools:
|
||||
kwargs["tools"] = self._format_tools_json_schemas(tools)
|
||||
|
||||
if tool_choice:
|
||||
# Handle deprecated "any" option with warning
|
||||
if tool_choice in ["any", "required"]:
|
||||
warnings.warn(
|
||||
f"'{tool_choice}' is not supported by DashScope API. "
|
||||
"It will be converted to 'auto'.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
tool_choice = "auto"
|
||||
|
||||
self._validate_tool_choice(tool_choice, tools)
|
||||
kwargs["tool_choice"] = self._format_tool_choice(tool_choice)
|
||||
|
||||
if (
|
||||
self.enable_thinking is not None
|
||||
and "enable_thinking" not in kwargs
|
||||
):
|
||||
kwargs["enable_thinking"] = self.enable_thinking
|
||||
|
||||
if structured_model:
|
||||
if tools or tool_choice:
|
||||
logger.warning(
|
||||
"structured_model is provided. Both 'tools' and "
|
||||
"'tool_choice' parameters will be overridden and "
|
||||
"ignored. The model will only perform structured output "
|
||||
"generation without calling any other tools.",
|
||||
)
|
||||
format_tool = _create_tool_from_base_model(structured_model)
|
||||
kwargs["tools"] = self._format_tools_json_schemas(
|
||||
[format_tool],
|
||||
)
|
||||
kwargs["tool_choice"] = self._format_tool_choice(
|
||||
format_tool["function"]["name"],
|
||||
)
|
||||
|
||||
start_datetime = datetime.now()
|
||||
if self.multimodality or (
|
||||
self.multimodality is None
|
||||
and (
|
||||
self.model_name.startswith(
|
||||
"qvq",
|
||||
)
|
||||
or "-vl" in self.model_name
|
||||
)
|
||||
):
|
||||
response = dashscope.MultiModalConversation.call(
|
||||
api_key=self.api_key,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
response = await dashscope.aigc.generation.AioGeneration.call(
|
||||
api_key=self.api_key,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if self.stream:
|
||||
return self._parse_dashscope_stream_response(
|
||||
start_datetime,
|
||||
response,
|
||||
structured_model,
|
||||
)
|
||||
|
||||
parsed_response = await self._parse_dashscope_generation_response(
|
||||
start_datetime,
|
||||
response,
|
||||
structured_model,
|
||||
)
|
||||
|
||||
return parsed_response
|
||||
|
||||
# pylint: disable=too-many-branches, too-many-statements
|
||||
async def _parse_dashscope_stream_response(
|
||||
self,
|
||||
start_datetime: datetime,
|
||||
response: Union[
|
||||
AsyncGenerator[GenerationResponse, None],
|
||||
Generator[MultiModalConversationResponse, None, None],
|
||||
],
|
||||
structured_model: Type[BaseModel] | None = None,
|
||||
) -> AsyncGenerator[ChatResponse, Any]:
|
||||
"""Given a DashScope streaming response generator, extract the content
|
||||
blocks and usages from it and yield ChatResponse objects.
|
||||
|
||||
Args:
|
||||
start_datetime (`datetime`):
|
||||
The start datetime of the response generation.
|
||||
response (
|
||||
`Union[AsyncGenerator[GenerationResponse, None], Generator[ \
|
||||
MultiModalConversationResponse, None, None]]`
|
||||
):
|
||||
DashScope streaming response generator (GenerationResponse or
|
||||
MultiModalConversationResponse) to parse.
|
||||
structured_model (`Type[BaseModel] | None`, default `None`):
|
||||
A Pydantic BaseModel class that defines the expected structure
|
||||
for the model's output.
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[ChatResponse, Any]:
|
||||
An async generator that yields ChatResponse objects containing
|
||||
the content blocks and usage information for each chunk in the
|
||||
streaming response.
|
||||
|
||||
.. note::
|
||||
If `structured_model` is not `None`, the expected structured output
|
||||
will be stored in the metadata of the `ChatResponse`.
|
||||
"""
|
||||
acc_content, acc_thinking_content = "", ""
|
||||
acc_tool_calls = collections.defaultdict(dict)
|
||||
last_input_objs = {} # Store last input_obj for each tool_call
|
||||
metadata = None
|
||||
last_content = None
|
||||
usage = None
|
||||
|
||||
async for chunk in giter(response):
|
||||
if chunk.status_code != HTTPStatus.OK:
|
||||
raise RuntimeError(
|
||||
f"Failed to get response from _ API: {chunk}",
|
||||
)
|
||||
|
||||
message = chunk.output.choices[0].message
|
||||
|
||||
# Update reasoning content
|
||||
if isinstance(message.get("reasoning_content"), str):
|
||||
acc_thinking_content += message["reasoning_content"]
|
||||
|
||||
# Update text content
|
||||
if isinstance(message.content, str):
|
||||
acc_content += message.content
|
||||
elif isinstance(message.content, list):
|
||||
for item in message.content:
|
||||
if isinstance(item, dict) and "text" in item:
|
||||
acc_content += item["text"]
|
||||
|
||||
# Update tool calls
|
||||
for tool_call in message.get("tool_calls", []):
|
||||
index = tool_call.get("index", 0)
|
||||
|
||||
if "id" in tool_call and tool_call["id"] != acc_tool_calls[
|
||||
index
|
||||
].get("id"):
|
||||
acc_tool_calls[index]["id"] = (
|
||||
acc_tool_calls[index].get("id", "") + tool_call["id"]
|
||||
)
|
||||
|
||||
if "function" in tool_call:
|
||||
func = tool_call["function"]
|
||||
if "name" in func:
|
||||
acc_tool_calls[index]["name"] = (
|
||||
acc_tool_calls[index].get("name", "")
|
||||
+ func["name"]
|
||||
)
|
||||
|
||||
if "arguments" in func:
|
||||
acc_tool_calls[index]["arguments"] = (
|
||||
acc_tool_calls[index].get("arguments", "")
|
||||
+ func["arguments"]
|
||||
)
|
||||
|
||||
# Build content blocks (always include thinking and text)
|
||||
content_blocks: list[TextBlock | ToolUseBlock | ThinkingBlock] = []
|
||||
|
||||
if acc_thinking_content:
|
||||
content_blocks.append(
|
||||
ThinkingBlock(
|
||||
type="thinking",
|
||||
thinking=acc_thinking_content,
|
||||
),
|
||||
)
|
||||
|
||||
if acc_content:
|
||||
content_blocks.append(
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=acc_content,
|
||||
),
|
||||
)
|
||||
|
||||
for tool_call in acc_tool_calls.values():
|
||||
# Only add intermediate tool use blocks if
|
||||
# stream_tool_parsing is True
|
||||
tool_id = tool_call.get("id", "")
|
||||
input_str = tool_call.get("arguments")
|
||||
|
||||
# If parsing the tool input in streaming mode
|
||||
if self.stream_tool_parsing:
|
||||
repaired_input = _json_loads_with_repair(
|
||||
input_str or "{}",
|
||||
)
|
||||
# If the new repaired input is shorter than one in the last
|
||||
# chunk, use the last one to avoid regression
|
||||
last_input = last_input_objs.get(tool_id, {})
|
||||
if len(json.dumps(last_input)) > len(
|
||||
json.dumps(repaired_input),
|
||||
):
|
||||
repaired_input = last_input
|
||||
last_input_objs[tool_id] = repaired_input
|
||||
|
||||
else:
|
||||
# Otherwise, keep input as empty dict until the final chunk
|
||||
repaired_input = {}
|
||||
|
||||
content_blocks.append(
|
||||
ToolUseBlock(
|
||||
type="tool_use",
|
||||
id=tool_id,
|
||||
name=tool_call.get("name", ""),
|
||||
input=repaired_input,
|
||||
raw_input=input_str,
|
||||
),
|
||||
)
|
||||
|
||||
if structured_model:
|
||||
metadata = repaired_input
|
||||
|
||||
if chunk.usage:
|
||||
usage = ChatUsage(
|
||||
input_tokens=chunk.usage.input_tokens,
|
||||
output_tokens=chunk.usage.output_tokens,
|
||||
time=(datetime.now() - start_datetime).total_seconds(),
|
||||
metadata=chunk.usage,
|
||||
)
|
||||
|
||||
if content_blocks:
|
||||
parsed_chunk = ChatResponse(
|
||||
content=content_blocks,
|
||||
usage=usage,
|
||||
metadata=metadata,
|
||||
)
|
||||
yield parsed_chunk
|
||||
last_content = copy.deepcopy(content_blocks)
|
||||
|
||||
# If stream_tool_parsing is False, we need to parse the final tool
|
||||
# use inputs here
|
||||
if not self.stream_tool_parsing and last_content and acc_tool_calls:
|
||||
metadata = None
|
||||
# Update tool use blocks in last_contents inplace
|
||||
for block in last_content:
|
||||
if block.get("type") == "tool_use":
|
||||
block["input"] = input_obj = _json_loads_with_repair(
|
||||
str(block.get("raw_input") or "{}"),
|
||||
)
|
||||
|
||||
if structured_model:
|
||||
metadata = input_obj
|
||||
|
||||
yield ChatResponse(
|
||||
content=last_content,
|
||||
usage=usage,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
async def _parse_dashscope_generation_response(
|
||||
self,
|
||||
start_datetime: datetime,
|
||||
response: Union[
|
||||
GenerationResponse,
|
||||
MultiModalConversationResponse,
|
||||
],
|
||||
structured_model: Type[BaseModel] | None = None,
|
||||
) -> ChatResponse:
|
||||
"""Given a DashScope GenerationResponse object, extract the content
|
||||
blocks and usages from it.
|
||||
|
||||
Args:
|
||||
start_datetime (`datetime`):
|
||||
The start datetime of the response generation.
|
||||
response (
|
||||
`Union[GenerationResponse, MultiModalConversationResponse]`
|
||||
):
|
||||
Dashscope GenerationResponse | MultiModalConversationResponse
|
||||
object to parse.
|
||||
structured_model (`Type[BaseModel] | None`, default `None`):
|
||||
A Pydantic BaseModel class that defines the expected structure
|
||||
for the model's output.
|
||||
|
||||
Returns:
|
||||
ChatResponse (`ChatResponse`):
|
||||
A ChatResponse object containing the content blocks and usage.
|
||||
|
||||
.. note::
|
||||
If `structured_model` is not `None`, the expected structured output
|
||||
will be stored in the metadata of the `ChatResponse`.
|
||||
"""
|
||||
# Collect the content blocks from the response.
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError(response)
|
||||
|
||||
content_blocks: List[TextBlock | ToolUseBlock] = []
|
||||
metadata: dict | None = None
|
||||
|
||||
message = response.output.choices[0].message
|
||||
content = message.get("content")
|
||||
|
||||
if response.output.choices[0].message.get("content") not in [
|
||||
None,
|
||||
"",
|
||||
[],
|
||||
]:
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if isinstance(item, dict) and "text" in item:
|
||||
content_blocks.append(
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=item["text"],
|
||||
),
|
||||
)
|
||||
else:
|
||||
content_blocks.append(
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=content,
|
||||
),
|
||||
)
|
||||
|
||||
if message.get("tool_calls"):
|
||||
for tool_call in message["tool_calls"]:
|
||||
input_ = _json_loads_with_repair(
|
||||
tool_call["function"].get(
|
||||
"arguments",
|
||||
"{}",
|
||||
)
|
||||
or "{}",
|
||||
)
|
||||
content_blocks.append(
|
||||
ToolUseBlock(
|
||||
type="tool_use",
|
||||
name=tool_call["function"]["name"],
|
||||
input=input_,
|
||||
id=tool_call["id"],
|
||||
),
|
||||
)
|
||||
|
||||
if structured_model:
|
||||
metadata = input_
|
||||
|
||||
# Usage information
|
||||
usage = None
|
||||
if response.usage:
|
||||
usage = ChatUsage(
|
||||
input_tokens=response.usage.input_tokens,
|
||||
output_tokens=response.usage.output_tokens,
|
||||
time=(datetime.now() - start_datetime).total_seconds(),
|
||||
metadata=response.usage,
|
||||
)
|
||||
|
||||
parsed_response = ChatResponse(
|
||||
content=content_blocks,
|
||||
usage=usage,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
return parsed_response
|
||||
|
||||
def _format_tools_json_schemas(
|
||||
self,
|
||||
schemas: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Format the tools JSON schema into required format for DashScope API.
|
||||
|
||||
Args:
|
||||
schemas (`dict[str, dict[str, Any]]`):
|
||||
The tools JSON schemas.
|
||||
"""
|
||||
# Check schemas format
|
||||
for value in schemas:
|
||||
if (
|
||||
not isinstance(value, dict)
|
||||
or "type" not in value
|
||||
or value["type"] != "function"
|
||||
or "function" not in value
|
||||
):
|
||||
raise ValueError(
|
||||
f"Each schema must be a dict with 'type' as 'function' "
|
||||
f"and 'function' key, got {value}",
|
||||
)
|
||||
|
||||
return schemas
|
||||
|
||||
def _format_tool_choice(
|
||||
self,
|
||||
tool_choice: Literal["auto", "none", "required"] | str | None,
|
||||
) -> str | dict | None:
|
||||
"""Format tool_choice parameter for API compatibility.
|
||||
|
||||
Args:
|
||||
tool_choice (`Literal["auto", "none", "required"] | str \
|
||||
| None`, default `None`):
|
||||
Controls which (if any) tool is called by the model. For more
|
||||
details, please refer to
|
||||
https://help.aliyun.com/zh/model-studio/qwen-function-calling
|
||||
|
||||
Returns:
|
||||
`dict | None`:
|
||||
The formatted tool choice configuration dict, or None if
|
||||
tool_choice is None.
|
||||
"""
|
||||
if tool_choice is None:
|
||||
return None
|
||||
if tool_choice in ["auto", "none"]:
|
||||
return tool_choice
|
||||
if tool_choice == "required":
|
||||
return "auto"
|
||||
return {"type": "function", "function": {"name": tool_choice}}
|
||||
646
src/agentscope/model/_gemini_model.py
Normal file
646
src/agentscope/model/_gemini_model.py
Normal file
@@ -0,0 +1,646 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# mypy: disable-error-code="dict-item"
|
||||
"""The Google Gemini model in agentscope."""
|
||||
import base64
|
||||
import copy
|
||||
import warnings
|
||||
from datetime import datetime
|
||||
import json
|
||||
from typing import (
|
||||
AsyncGenerator,
|
||||
Any,
|
||||
TYPE_CHECKING,
|
||||
AsyncIterator,
|
||||
Literal,
|
||||
Type,
|
||||
List,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .._logging import logger
|
||||
from .._utils._common import _json_loads_with_repair
|
||||
from ..message import ToolUseBlock, TextBlock, ThinkingBlock
|
||||
from ._model_usage import ChatUsage
|
||||
from ._model_base import ChatModelBase
|
||||
from ._model_response import ChatResponse
|
||||
from ..tracing import trace_llm
|
||||
from ..types import JSONSerializableObject
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from google.genai.types import GenerateContentResponse
|
||||
else:
|
||||
GenerateContentResponse = "google.genai.types.GenerateContentResponse"
|
||||
|
||||
|
||||
def _flatten_json_schema(schema: dict) -> dict:
|
||||
"""Flatten a JSON schema by resolving all $ref references.
|
||||
|
||||
.. note::
|
||||
Gemini API does not support `$defs` and `$ref` in JSON schemas.
|
||||
This function resolves all `$ref` references by inlining the
|
||||
referenced definitions, producing a self-contained schema without
|
||||
any references.
|
||||
|
||||
Args:
|
||||
schema (`dict`):
|
||||
The JSON schema that may contain `$defs` and `$ref` references.
|
||||
|
||||
Returns:
|
||||
`dict`:
|
||||
A flattened JSON schema with all references resolved inline.
|
||||
"""
|
||||
# Deep copy to avoid modifying the original schema
|
||||
schema = copy.deepcopy(schema)
|
||||
|
||||
# Extract $defs if present
|
||||
defs = schema.pop("$defs", {})
|
||||
|
||||
def _resolve_ref(obj: Any, visited: set | None = None) -> Any:
|
||||
"""Recursively resolve $ref references in the schema."""
|
||||
if visited is None:
|
||||
visited = set()
|
||||
|
||||
if not isinstance(obj, dict):
|
||||
if isinstance(obj, list):
|
||||
return [_resolve_ref(item, visited.copy()) for item in obj]
|
||||
return obj
|
||||
|
||||
# Handle $ref
|
||||
if "$ref" in obj:
|
||||
ref_path = obj["$ref"]
|
||||
# Extract definition name from "#/$defs/DefinitionName"
|
||||
if ref_path.startswith("#/$defs/"):
|
||||
def_name = ref_path[len("#/$defs/") :]
|
||||
|
||||
# Prevent infinite recursion for circular references
|
||||
if def_name in visited:
|
||||
logger.warning(
|
||||
"Circular reference detected for '%s' in tool schema",
|
||||
def_name,
|
||||
)
|
||||
return {
|
||||
"type": "object",
|
||||
"description": f"(circular: {def_name})",
|
||||
}
|
||||
|
||||
visited.add(def_name)
|
||||
|
||||
if def_name in defs:
|
||||
# Recursively resolve any nested refs in the definition
|
||||
resolved = _resolve_ref(
|
||||
defs[def_name],
|
||||
visited.copy(),
|
||||
)
|
||||
# Merge any additional properties from the original object
|
||||
# (excluding $ref itself)
|
||||
for key, value in obj.items():
|
||||
if key != "$ref":
|
||||
resolved[key] = _resolve_ref(value, visited.copy())
|
||||
return resolved
|
||||
|
||||
# If we can't resolve the ref, return as-is (shouldn't happen)
|
||||
return obj
|
||||
|
||||
# Recursively process all nested objects
|
||||
result = {}
|
||||
for key, value in obj.items():
|
||||
result[key] = _resolve_ref(value, visited.copy())
|
||||
|
||||
return result
|
||||
|
||||
return _resolve_ref(schema)
|
||||
|
||||
|
||||
class GeminiChatModel(ChatModelBase):
|
||||
"""The Google Gemini chat model class in agentscope."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
api_key: str,
|
||||
stream: bool = True,
|
||||
thinking_config: dict | None = None,
|
||||
client_kwargs: dict[str, JSONSerializableObject] | None = None,
|
||||
generate_kwargs: dict[str, JSONSerializableObject] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the Gemini chat model.
|
||||
|
||||
Args:
|
||||
model_name (`str`):
|
||||
The name of the Gemini model to use, e.g. "gemini-2.5-flash".
|
||||
api_key (`str`):
|
||||
The API key for Google Gemini.
|
||||
stream (`bool`, default `True`):
|
||||
Whether to use streaming output or not.
|
||||
thinking_config (`dict | None`, optional):
|
||||
Thinking config, supported models are 2.5 Pro, 2.5 Flash, etc.
|
||||
Refer to https://ai.google.dev/gemini-api/docs/thinking for
|
||||
more details.
|
||||
|
||||
.. code-block:: python
|
||||
:caption: Example of thinking_config
|
||||
|
||||
{
|
||||
"include_thoughts": True, # enable thoughts or not
|
||||
"thinking_budget": 1024 # Max tokens for reasoning
|
||||
}
|
||||
|
||||
client_kwargs (`dict[str, JSONSerializableObject] | None`, \
|
||||
optional):
|
||||
The extra keyword arguments to initialize the Gemini client.
|
||||
generate_kwargs (`dict[str, JSONSerializableObject] | None`, \
|
||||
optional):
|
||||
The extra keyword arguments used in Gemini API generation,
|
||||
e.g. `temperature`, `seed`.
|
||||
**kwargs (`Any`):
|
||||
Additional keyword arguments.
|
||||
"""
|
||||
|
||||
# Handle deprecated client_args parameter from kwargs
|
||||
client_args = kwargs.pop("client_args", None)
|
||||
if client_args is not None and client_kwargs is not None:
|
||||
raise ValueError(
|
||||
"Cannot specify both 'client_args' and 'client_kwargs'. "
|
||||
"Please use only 'client_kwargs' (client_args is deprecated).",
|
||||
)
|
||||
|
||||
if client_args is not None:
|
||||
logger.warning(
|
||||
"The parameter 'client_args' is deprecated and will be "
|
||||
"removed in a future version. Please use 'client_kwargs' "
|
||||
"instead. Automatically converting 'client_args' to "
|
||||
"'client_kwargs'.",
|
||||
)
|
||||
client_kwargs = client_args
|
||||
|
||||
if kwargs:
|
||||
logger.warning(
|
||||
"Unknown keyword arguments: %s. These will be ignored.",
|
||||
list(kwargs.keys()),
|
||||
)
|
||||
|
||||
try:
|
||||
from google import genai
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Please install gemini Python sdk with "
|
||||
"`pip install -q -U google-genai`",
|
||||
) from e
|
||||
|
||||
super().__init__(model_name, stream)
|
||||
|
||||
self.client = genai.Client(
|
||||
api_key=api_key,
|
||||
**(client_kwargs or {}),
|
||||
)
|
||||
self.thinking_config = thinking_config
|
||||
self.generate_kwargs = generate_kwargs or {}
|
||||
|
||||
@trace_llm
|
||||
async def __call__(
|
||||
self,
|
||||
messages: list[dict],
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: Literal["auto", "none", "required"] | str | None = None,
|
||||
structured_model: Type[BaseModel] | None = None,
|
||||
**config_kwargs: Any,
|
||||
) -> ChatResponse | AsyncGenerator[ChatResponse, None]:
|
||||
"""Call the Gemini model with the provided arguments.
|
||||
|
||||
Args:
|
||||
messages (`list[dict[str, Any]]`):
|
||||
A list of dictionaries, where `role` and `content` fields are
|
||||
required.
|
||||
tools (`list[dict] | None`, default `None`):
|
||||
The tools JSON schemas that the model can use.
|
||||
tool_choice (`Literal["auto", "none", "required"] | str \
|
||||
| None`, default `None`):
|
||||
Controls which (if any) tool is called by the model.
|
||||
Can be "auto", "none", "required", or specific tool name.
|
||||
For more details, please refer to
|
||||
https://ai.google.dev/gemini-api/docs/function-calling?hl=en&example=meeting#function_calling_modes
|
||||
structured_model (`Type[BaseModel] | None`, default `None`):
|
||||
A Pydantic BaseModel class that defines the expected structure
|
||||
for the model's output.
|
||||
|
||||
.. note:: When `structured_model` is specified,
|
||||
both `tools` and `tool_choice` parameters are ignored,
|
||||
and the model will only perform structured output
|
||||
generation without calling any other tools.
|
||||
|
||||
For more details, please refer to
|
||||
https://ai.google.dev/gemini-api/docs/structured-output
|
||||
|
||||
**config_kwargs (`Any`):
|
||||
The keyword arguments for Gemini chat completions API.
|
||||
"""
|
||||
|
||||
config: dict = {
|
||||
"thinking_config": self.thinking_config,
|
||||
**self.generate_kwargs,
|
||||
**config_kwargs,
|
||||
}
|
||||
|
||||
if tools:
|
||||
config["tools"] = self._format_tools_json_schemas(tools)
|
||||
|
||||
if tool_choice:
|
||||
# Handle deprecated "any" option with warning
|
||||
if tool_choice == "any":
|
||||
warnings.warn(
|
||||
'"any" is deprecated and will be removed in a future '
|
||||
"version.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
tool_choice = "required"
|
||||
self._validate_tool_choice(tool_choice, tools)
|
||||
config["tool_config"] = self._format_tool_choice(tool_choice)
|
||||
|
||||
if structured_model:
|
||||
if tools or tool_choice:
|
||||
logger.warning(
|
||||
"structured_model is provided. Both 'tools' and "
|
||||
"'tool_choice' parameters will be overridden and "
|
||||
"ignored. The model will only perform structured output "
|
||||
"generation without calling any other tools.",
|
||||
)
|
||||
config.pop("tools", None)
|
||||
config.pop("tool_config", None)
|
||||
config["response_mime_type"] = "application/json"
|
||||
config["response_schema"] = structured_model
|
||||
|
||||
# Prepare the arguments for the Gemini API call
|
||||
kwargs: dict[str, JSONSerializableObject] = {
|
||||
"model": self.model_name,
|
||||
"contents": messages,
|
||||
"config": config,
|
||||
}
|
||||
|
||||
start_datetime = datetime.now()
|
||||
if self.stream:
|
||||
response = await self.client.aio.models.generate_content_stream(
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return self._parse_gemini_stream_generation_response(
|
||||
start_datetime,
|
||||
response,
|
||||
structured_model,
|
||||
)
|
||||
|
||||
# non-streaming
|
||||
response = await self.client.aio.models.generate_content(
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
parsed_response = self._parse_gemini_generation_response(
|
||||
start_datetime,
|
||||
response,
|
||||
structured_model,
|
||||
)
|
||||
|
||||
return parsed_response
|
||||
|
||||
async def _parse_gemini_stream_generation_response(
|
||||
self,
|
||||
start_datetime: datetime,
|
||||
response: AsyncIterator[GenerateContentResponse],
|
||||
structured_model: Type[BaseModel] | None = None,
|
||||
) -> AsyncGenerator[ChatResponse, None]:
|
||||
"""Given a Gemini streaming generation response, extract the
|
||||
content blocks and usages from it and yield ChatResponse objects.
|
||||
|
||||
Args:
|
||||
start_datetime (`datetime`):
|
||||
The start datetime of the response generation.
|
||||
response (`AsyncIterator[GenerateContentResponse]`):
|
||||
Gemini GenerateContentResponse async iterator to parse.
|
||||
structured_model (`Type[BaseModel] | None`, default `None`):
|
||||
A Pydantic BaseModel class that defines the expected structure
|
||||
for the model's output.
|
||||
|
||||
Returns:
|
||||
`AsyncGenerator[ChatResponse, None]`:
|
||||
An async generator that yields ChatResponse objects containing
|
||||
the content blocks and usage information for each chunk in the
|
||||
streaming response.
|
||||
|
||||
.. note::
|
||||
If `structured_model` is not `None`, the expected structured output
|
||||
will be stored in the metadata of the `ChatResponse`.
|
||||
"""
|
||||
|
||||
text = ""
|
||||
thinking = ""
|
||||
tool_calls: list[ToolUseBlock] = []
|
||||
metadata: dict | None = None
|
||||
async for chunk in response:
|
||||
if (
|
||||
chunk.candidates
|
||||
and chunk.candidates[0].content
|
||||
and chunk.candidates[0].content.parts
|
||||
):
|
||||
for part in chunk.candidates[0].content.parts:
|
||||
if part.text:
|
||||
if part.thought:
|
||||
thinking += part.text
|
||||
else:
|
||||
text += part.text
|
||||
|
||||
if part.function_call:
|
||||
keyword_args = part.function_call.args or {}
|
||||
# .. note:: Gemini API always returns None for
|
||||
# function_call.id, so we use thought_signature
|
||||
# as the unique identifier for tool
|
||||
# calls when available. That maybe
|
||||
# infeasible someday, but Gemini
|
||||
# requires the thought_signature for some
|
||||
# llms like gemini-3-pro
|
||||
|
||||
if part.thought_signature:
|
||||
call_id = base64.b64encode(
|
||||
part.thought_signature,
|
||||
).decode("utf-8")
|
||||
else:
|
||||
call_id = part.function_call.id
|
||||
|
||||
tool_calls.append(
|
||||
ToolUseBlock(
|
||||
type="tool_use",
|
||||
id=call_id,
|
||||
name=part.function_call.name,
|
||||
input=keyword_args,
|
||||
raw_input=json.dumps(
|
||||
keyword_args,
|
||||
ensure_ascii=False,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
# Text parts
|
||||
if text and structured_model:
|
||||
metadata = _json_loads_with_repair(text)
|
||||
|
||||
usage = None
|
||||
if chunk.usage_metadata:
|
||||
usage = ChatUsage(
|
||||
input_tokens=chunk.usage_metadata.prompt_token_count,
|
||||
output_tokens=chunk.usage_metadata.total_token_count
|
||||
- chunk.usage_metadata.prompt_token_count,
|
||||
time=(datetime.now() - start_datetime).total_seconds(),
|
||||
)
|
||||
|
||||
# The content blocks for the current chunk
|
||||
content_blocks: list = []
|
||||
|
||||
if thinking:
|
||||
content_blocks.append(
|
||||
ThinkingBlock(
|
||||
type="thinking",
|
||||
thinking=thinking,
|
||||
),
|
||||
)
|
||||
|
||||
if text:
|
||||
content_blocks.append(
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=text,
|
||||
),
|
||||
)
|
||||
|
||||
yield ChatResponse(
|
||||
content=content_blocks + tool_calls,
|
||||
usage=usage,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def _parse_gemini_generation_response(
|
||||
self,
|
||||
start_datetime: datetime,
|
||||
response: GenerateContentResponse,
|
||||
structured_model: Type[BaseModel] | None = None,
|
||||
) -> ChatResponse:
|
||||
"""Given a Gemini chat completion response object, extract the content
|
||||
blocks and usages from it.
|
||||
|
||||
Args:
|
||||
start_datetime (`datetime`):
|
||||
The start datetime of the response generation.
|
||||
response (`GenerateContentResponse`):
|
||||
The Gemini generation response object to parse.
|
||||
structured_model (`Type[BaseModel] | None`, default `None`):
|
||||
A Pydantic BaseModel class that defines the expected structure
|
||||
for the model's output.
|
||||
|
||||
Returns:
|
||||
ChatResponse (`ChatResponse`):
|
||||
A ChatResponse object containing the content blocks and usage.
|
||||
|
||||
.. note::
|
||||
If `structured_model` is not `None`, the expected structured output
|
||||
will be stored in the metadata of the `ChatResponse`.
|
||||
"""
|
||||
content_blocks: List[TextBlock | ToolUseBlock | ThinkingBlock] = []
|
||||
metadata: dict | None = None
|
||||
tool_calls: list = []
|
||||
|
||||
if (
|
||||
response.candidates
|
||||
and response.candidates[0].content
|
||||
and response.candidates[0].content.parts
|
||||
):
|
||||
for part in response.candidates[0].content.parts:
|
||||
if part.text:
|
||||
if part.thought:
|
||||
content_blocks.append(
|
||||
ThinkingBlock(
|
||||
type="thinking",
|
||||
thinking=part.text,
|
||||
),
|
||||
)
|
||||
else:
|
||||
content_blocks.append(
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=part.text,
|
||||
),
|
||||
)
|
||||
|
||||
if part.function_call:
|
||||
keyword_args = part.function_call.args or {}
|
||||
# .. note:: Gemini API always returns None for
|
||||
# function_call.id, so we use thought_signature
|
||||
# as the unique identifier for tool
|
||||
# calls when available. That maybe infeasible
|
||||
# someday, but Gemini requires the thought_signature
|
||||
# for some llms like gemini-3-pro
|
||||
|
||||
if part.thought_signature:
|
||||
call_id = base64.b64encode(
|
||||
part.thought_signature,
|
||||
).decode("utf-8")
|
||||
else:
|
||||
call_id = part.function_call.id
|
||||
|
||||
tool_calls.append(
|
||||
ToolUseBlock(
|
||||
type="tool_use",
|
||||
id=call_id,
|
||||
name=part.function_call.name,
|
||||
input=keyword_args,
|
||||
raw_input=json.dumps(
|
||||
keyword_args,
|
||||
ensure_ascii=False,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
# For the structured output case
|
||||
if response.text and structured_model:
|
||||
metadata = _json_loads_with_repair(response.text)
|
||||
|
||||
if response.usage_metadata:
|
||||
usage = ChatUsage(
|
||||
input_tokens=response.usage_metadata.prompt_token_count,
|
||||
output_tokens=response.usage_metadata.total_token_count
|
||||
- response.usage_metadata.prompt_token_count,
|
||||
time=(datetime.now() - start_datetime).total_seconds(),
|
||||
)
|
||||
|
||||
else:
|
||||
usage = None
|
||||
|
||||
return ChatResponse(
|
||||
content=content_blocks + tool_calls,
|
||||
usage=usage,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def _format_tools_json_schemas(
|
||||
self,
|
||||
schemas: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Format the tools JSON schema into required format for Gemini API.
|
||||
|
||||
.. note:: Gemini API does not support `$defs` and `$ref` in JSON
|
||||
schemas. This function resolves all `$ref` references by inlining the
|
||||
referenced definitions, producing a self-contained schema without
|
||||
any references.
|
||||
|
||||
Args:
|
||||
schemas (`dict[str, Any]`):
|
||||
The tools JSON schemas.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]:
|
||||
A list containing a dictionary with the
|
||||
"function_declarations" key, which maps to a list of
|
||||
function definitions.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
:caption: Example tool schemas of Gemini API
|
||||
|
||||
# Input JSON schema
|
||||
schemas = [
|
||||
{
|
||||
'type': 'function',
|
||||
'function': {
|
||||
'name': 'execute_shell_command',
|
||||
'description': 'xxx',
|
||||
'parameters': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'command': {
|
||||
'type': 'string',
|
||||
'description': 'xxx.'
|
||||
},
|
||||
'timeout': {
|
||||
'type': 'integer',
|
||||
'default': 300
|
||||
}
|
||||
},
|
||||
'required': ['command']
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
# Output format (Gemini API expected):
|
||||
[
|
||||
{
|
||||
'function_declarations': [
|
||||
{
|
||||
'name': 'execute_shell_command',
|
||||
'description': 'xxx.',
|
||||
'parameters': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'command': {
|
||||
'type': 'string',
|
||||
'description': 'xxx.'
|
||||
},
|
||||
'timeout': {
|
||||
'type': 'integer',
|
||||
'default': 300
|
||||
}
|
||||
},
|
||||
'required': ['command']
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
"""
|
||||
function_declarations = []
|
||||
for schema in schemas:
|
||||
if "function" not in schema:
|
||||
continue
|
||||
func = schema["function"].copy()
|
||||
# Flatten the parameters schema to resolve $ref references
|
||||
if "parameters" in func:
|
||||
func["parameters"] = _flatten_json_schema(func["parameters"])
|
||||
function_declarations.append(func)
|
||||
|
||||
return [{"function_declarations": function_declarations}]
|
||||
|
||||
def _format_tool_choice(
|
||||
self,
|
||||
tool_choice: Literal["auto", "none", "required"] | str | None,
|
||||
) -> dict | None:
|
||||
"""Format tool_choice parameter for API compatibility.
|
||||
|
||||
Args:
|
||||
tool_choice (`Literal["auto", "none", "required"] | str | None`, \
|
||||
default `None`):
|
||||
Controls which (if any) tool is called by the model.
|
||||
Can be "auto", "none", "required", or specific tool name.
|
||||
For more details, please refer to
|
||||
https://ai.google.dev/gemini-api/docs/function-calling?hl=en&example=meeting#function_calling_modes
|
||||
|
||||
Returns:
|
||||
`dict | None`:
|
||||
The formatted tool choice configuration dict, or None if
|
||||
tool_choice is None.
|
||||
"""
|
||||
if tool_choice is None:
|
||||
return None
|
||||
|
||||
mode_mapping = {
|
||||
"auto": "AUTO",
|
||||
"none": "NONE",
|
||||
"required": "ANY",
|
||||
}
|
||||
mode = mode_mapping.get(tool_choice)
|
||||
if mode:
|
||||
return {"function_calling_config": {"mode": mode}}
|
||||
return {
|
||||
"function_calling_config": {
|
||||
"mode": "ANY",
|
||||
"allowed_function_names": [tool_choice],
|
||||
},
|
||||
}
|
||||
77
src/agentscope/model/_model_base.py
Normal file
77
src/agentscope/model/_model_base.py
Normal file
@@ -0,0 +1,77 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The chat model base class."""
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import AsyncGenerator, Any
|
||||
|
||||
from ._model_response import ChatResponse
|
||||
|
||||
|
||||
_TOOL_CHOICE_MODES = ["auto", "none", "required"]
|
||||
|
||||
|
||||
class ChatModelBase:
|
||||
"""Base class for chat models."""
|
||||
|
||||
model_name: str
|
||||
"""The model name"""
|
||||
|
||||
stream: bool
|
||||
"""Is the model output streaming or not"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
stream: bool,
|
||||
) -> None:
|
||||
"""Initialize the chat model base class.
|
||||
|
||||
Args:
|
||||
model_name (`str`):
|
||||
The name of the model
|
||||
stream (`bool`):
|
||||
Whether the model output is streaming or not
|
||||
"""
|
||||
self.model_name = model_name
|
||||
self.stream = stream
|
||||
|
||||
@abstractmethod
|
||||
async def __call__(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> ChatResponse | AsyncGenerator[ChatResponse, None]:
|
||||
pass
|
||||
|
||||
def _validate_tool_choice(
|
||||
self,
|
||||
tool_choice: str,
|
||||
tools: list[dict] | None,
|
||||
) -> None:
|
||||
"""
|
||||
Validate tool_choice parameter.
|
||||
|
||||
Args:
|
||||
tool_choice (`str`):
|
||||
Tool choice mode or function name
|
||||
tools (`list[dict] | None`):
|
||||
Available tools list
|
||||
Raises:
|
||||
TypeError: If tool_choice is not string
|
||||
ValueError: If tool_choice is invalid
|
||||
"""
|
||||
if not isinstance(tool_choice, str):
|
||||
raise TypeError(
|
||||
f"tool_choice must be str, got {type(tool_choice)}",
|
||||
)
|
||||
if tool_choice in _TOOL_CHOICE_MODES:
|
||||
return
|
||||
|
||||
available_functions = [tool["function"]["name"] for tool in tools]
|
||||
|
||||
if tool_choice not in available_functions:
|
||||
all_options = _TOOL_CHOICE_MODES + available_functions
|
||||
raise ValueError(
|
||||
f"Invalid tool_choice '{tool_choice}'. "
|
||||
f"Available options: {', '.join(sorted(all_options))}",
|
||||
)
|
||||
42
src/agentscope/model/_model_response.py
Normal file
42
src/agentscope/model/_model_response.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The model response module."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, Sequence
|
||||
|
||||
from ._model_usage import ChatUsage
|
||||
from .._utils._common import _get_timestamp
|
||||
from .._utils._mixin import DictMixin
|
||||
from ..message import (
|
||||
TextBlock,
|
||||
ToolUseBlock,
|
||||
ThinkingBlock,
|
||||
AudioBlock,
|
||||
)
|
||||
from ..types import JSONSerializableObject
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatResponse(DictMixin):
|
||||
"""The response of chat models."""
|
||||
|
||||
content: Sequence[TextBlock | ToolUseBlock | ThinkingBlock | AudioBlock]
|
||||
"""The content of the chat response, which can include text blocks,
|
||||
tool use blocks, or thinking blocks."""
|
||||
|
||||
id: str = field(default_factory=lambda: _get_timestamp(True))
|
||||
"""The unique identifier formatter """
|
||||
|
||||
created_at: str = field(default_factory=_get_timestamp)
|
||||
"""When the response was created"""
|
||||
|
||||
type: Literal["chat"] = field(default_factory=lambda: "chat")
|
||||
"""The type of the response, which is always 'chat'."""
|
||||
|
||||
usage: ChatUsage | None = field(default_factory=lambda: None)
|
||||
"""The usage information of the chat response, if available."""
|
||||
|
||||
metadata: dict[str, JSONSerializableObject] | None = field(
|
||||
default_factory=lambda: None,
|
||||
)
|
||||
"""The metadata of the chat response"""
|
||||
26
src/agentscope/model/_model_usage.py
Normal file
26
src/agentscope/model/_model_usage.py
Normal file
@@ -0,0 +1,26 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The model usage class in agentscope."""
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, Any
|
||||
|
||||
from .._utils._mixin import DictMixin
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatUsage(DictMixin):
|
||||
"""The usage of a chat model API invocation."""
|
||||
|
||||
input_tokens: int
|
||||
"""The number of input tokens."""
|
||||
|
||||
output_tokens: int
|
||||
"""The number of output tokens."""
|
||||
|
||||
time: float
|
||||
"""The time used in seconds."""
|
||||
|
||||
type: Literal["chat"] = field(default_factory=lambda: "chat")
|
||||
"""The type of the usage, must be `chat`."""
|
||||
|
||||
metadata: dict[str, Any] | None = field(default_factory=lambda: None)
|
||||
"""The metadata of the usage."""
|
||||
355
src/agentscope/model/_ollama_model.py
Normal file
355
src/agentscope/model/_ollama_model.py
Normal file
@@ -0,0 +1,355 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Model wrapper for Ollama models."""
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
Any,
|
||||
TYPE_CHECKING,
|
||||
List,
|
||||
AsyncGenerator,
|
||||
AsyncIterator,
|
||||
Literal,
|
||||
Type,
|
||||
)
|
||||
from collections import OrderedDict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from . import ChatResponse
|
||||
from ._model_base import ChatModelBase
|
||||
from ._model_usage import ChatUsage
|
||||
from .._logging import logger
|
||||
from .._utils._common import _json_loads_with_repair
|
||||
from ..message import ToolUseBlock, TextBlock, ThinkingBlock
|
||||
from ..tracing import trace_llm
|
||||
from ..types import JSONSerializableObject
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ollama._types import ChatResponse as OllamaChatResponse
|
||||
else:
|
||||
OllamaChatResponse = "ollama._types.ChatResponse"
|
||||
|
||||
|
||||
class OllamaChatModel(ChatModelBase):
|
||||
"""The Ollama chat model class in agentscope."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
stream: bool = False,
|
||||
options: dict = None,
|
||||
keep_alive: str = "5m",
|
||||
enable_thinking: bool | None = None,
|
||||
host: str | None = None,
|
||||
client_kwargs: dict[str, JSONSerializableObject] | None = None,
|
||||
generate_kwargs: dict[str, JSONSerializableObject] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the Ollama chat model.
|
||||
|
||||
Args:
|
||||
model_name (`str`):
|
||||
The name of the model.
|
||||
stream (`bool`, default `True`):
|
||||
Streaming mode or not.
|
||||
options (`dict`, default `None`):
|
||||
Additional parameters to pass to the Ollama API. These can
|
||||
include temperature etc.
|
||||
keep_alive (`str`, default `"5m"`):
|
||||
Duration to keep the model loaded in memory. The format is a
|
||||
number followed by a unit suffix (s for seconds, m for minutes
|
||||
, h for hours).
|
||||
enable_thinking (`bool | None`, default `None`)
|
||||
Whether enable thinking or not, only for models such as qwen3,
|
||||
deepseek-r1, etc. For more details, please refer to
|
||||
https://ollama.com/search?c=thinking
|
||||
host (`str | None`, default `None`):
|
||||
The host address of the Ollama server. If None, uses the
|
||||
default address (typically http://localhost:11434).
|
||||
client_kwargs (`dict[str, JSONSerializableObject] | None`, \
|
||||
optional):
|
||||
The extra keyword arguments to initialize the Ollama client.
|
||||
generate_kwargs (`dict[str, JSONSerializableObject] | None`, \
|
||||
optional):
|
||||
The extra keyword arguments used in Ollama API generation.
|
||||
**kwargs (`Any`):
|
||||
Additional keyword arguments to pass to the base chat model
|
||||
class.
|
||||
"""
|
||||
|
||||
try:
|
||||
import ollama
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"The package ollama is not found. Please install it by "
|
||||
'running command `pip install "ollama>=0.1.7"`',
|
||||
) from e
|
||||
|
||||
super().__init__(model_name, stream)
|
||||
|
||||
self.client = ollama.AsyncClient(
|
||||
host=host,
|
||||
**(client_kwargs or {}),
|
||||
**kwargs,
|
||||
)
|
||||
self.options = options
|
||||
self.keep_alive = keep_alive
|
||||
self.think = enable_thinking
|
||||
self.generate_kwargs = generate_kwargs or {}
|
||||
|
||||
@trace_llm
|
||||
async def __call__(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: Literal["auto", "none", "required"] | str | None = None,
|
||||
structured_model: Type[BaseModel] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResponse | AsyncGenerator[ChatResponse, None]:
|
||||
"""Get the response from Ollama chat completions API by the given
|
||||
arguments.
|
||||
|
||||
Args:
|
||||
messages (`list[dict]`):
|
||||
A list of dictionaries, where `role` and `content` fields are
|
||||
required, and `name` field is optional.
|
||||
tools (`list[dict]`, default `None`):
|
||||
The tools JSON schemas that the model can use.
|
||||
tool_choice (`Literal["auto", "none", "required"] | str \
|
||||
| None`, default `None`):
|
||||
Ollama doesn't support `tool_choice` argument yet.
|
||||
structured_model (`Type[BaseModel] | None`, default `None`):
|
||||
A Pydantic BaseModel class that defines the expected structure
|
||||
for the model's output.
|
||||
**kwargs (`Any`):
|
||||
The keyword arguments for Ollama chat completions API,
|
||||
e.g. `think`etc. Please refer to the Ollama API
|
||||
documentation for more details.
|
||||
|
||||
Returns:
|
||||
`ChatResponse | AsyncGenerator[ChatResponse, None]`:
|
||||
The response from the Ollama chat completions API.
|
||||
"""
|
||||
|
||||
kwargs = {
|
||||
"model": self.model_name,
|
||||
"messages": messages,
|
||||
"stream": self.stream,
|
||||
"options": self.options,
|
||||
"keep_alive": self.keep_alive,
|
||||
**self.generate_kwargs,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
if self.think is not None and "think" not in kwargs:
|
||||
kwargs["think"] = self.think
|
||||
|
||||
if tools:
|
||||
kwargs["tools"] = self._format_tools_json_schemas(tools)
|
||||
|
||||
if tool_choice:
|
||||
logger.warning("Ollama does not support tool_choice yet, ignored.")
|
||||
|
||||
if structured_model:
|
||||
kwargs["format"] = structured_model.model_json_schema()
|
||||
|
||||
start_datetime = datetime.now()
|
||||
response = await self.client.chat(**kwargs)
|
||||
|
||||
if self.stream:
|
||||
return self._parse_ollama_stream_completion_response(
|
||||
start_datetime,
|
||||
response,
|
||||
structured_model,
|
||||
)
|
||||
|
||||
parsed_response = await self._parse_ollama_completion_response(
|
||||
start_datetime,
|
||||
response,
|
||||
structured_model,
|
||||
)
|
||||
|
||||
return parsed_response
|
||||
|
||||
async def _parse_ollama_stream_completion_response(
|
||||
self,
|
||||
start_datetime: datetime,
|
||||
response: AsyncIterator[OllamaChatResponse],
|
||||
structured_model: Type[BaseModel] | None = None,
|
||||
) -> AsyncGenerator[ChatResponse, None]:
|
||||
"""Given an Ollama streaming completion response, extract the
|
||||
content blocks and usages from it and yield ChatResponse objects.
|
||||
|
||||
Args:
|
||||
start_datetime (`datetime`):
|
||||
The start datetime of the response generation.
|
||||
response (`AsyncIterator[OllamaChatResponse]`):
|
||||
Ollama streaming response async iterator to parse.
|
||||
structured_model (`Type[BaseModel] | None`, default `None`):
|
||||
A Pydantic BaseModel class that defines the expected structure
|
||||
for the model's output.
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[ChatResponse, None]:
|
||||
An async generator that yields ChatResponse objects containing
|
||||
the content blocks and usage information for each chunk in the
|
||||
streaming response.
|
||||
|
||||
.. note::
|
||||
If `structured_model` is not `None`, the expected structured output
|
||||
will be stored in the metadata of the `ChatResponse`.
|
||||
|
||||
"""
|
||||
accumulated_text = ""
|
||||
acc_thinking_content = ""
|
||||
tool_calls = OrderedDict() # Store tool calls
|
||||
metadata: dict | None = None
|
||||
|
||||
async for chunk in response:
|
||||
# Handle text content
|
||||
msg = chunk.message
|
||||
acc_thinking_content += msg.thinking or ""
|
||||
accumulated_text += msg.content or ""
|
||||
|
||||
# Handle tool calls
|
||||
for idx, tool_call in enumerate(msg.tool_calls or []):
|
||||
function = tool_call.function
|
||||
tool_id = f"{idx}_{function.name}"
|
||||
tool_calls[tool_id] = {
|
||||
"type": "tool_use",
|
||||
"id": tool_id,
|
||||
"name": function.name,
|
||||
"input": function.arguments,
|
||||
"raw_input": json.dumps(function.arguments),
|
||||
}
|
||||
# Calculate usage statistics
|
||||
current_time = (datetime.now() - start_datetime).total_seconds()
|
||||
usage = ChatUsage(
|
||||
input_tokens=getattr(chunk, "prompt_eval_count", 0) or 0,
|
||||
output_tokens=getattr(chunk, "eval_count", 0) or 0,
|
||||
time=current_time,
|
||||
)
|
||||
# Create content blocks
|
||||
contents: list = []
|
||||
|
||||
if acc_thinking_content:
|
||||
contents.append(
|
||||
ThinkingBlock(
|
||||
type="thinking",
|
||||
thinking=acc_thinking_content,
|
||||
),
|
||||
)
|
||||
|
||||
if accumulated_text:
|
||||
contents.append(TextBlock(type="text", text=accumulated_text))
|
||||
if structured_model:
|
||||
metadata = _json_loads_with_repair(accumulated_text)
|
||||
|
||||
# Add tool call blocks
|
||||
for tool_call in tool_calls.values():
|
||||
try:
|
||||
input_data = tool_call["input"]
|
||||
if isinstance(input_data, str):
|
||||
input_data = _json_loads_with_repair(input_data)
|
||||
contents.append(
|
||||
ToolUseBlock(
|
||||
type=tool_call["type"],
|
||||
id=tool_call["id"],
|
||||
name=tool_call["name"],
|
||||
input=input_data,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error parsing tool call input: {e}")
|
||||
|
||||
# Generate response when there's new content or at final chunk
|
||||
if chunk.done or contents:
|
||||
res = ChatResponse(
|
||||
content=contents,
|
||||
usage=usage,
|
||||
metadata=metadata,
|
||||
)
|
||||
yield res
|
||||
|
||||
async def _parse_ollama_completion_response(
|
||||
self,
|
||||
start_datetime: datetime,
|
||||
response: OllamaChatResponse,
|
||||
structured_model: Type[BaseModel] | None = None,
|
||||
) -> ChatResponse:
|
||||
"""Given an Ollama chat completion response object, extract the content
|
||||
blocks and usages from it.
|
||||
|
||||
Args:
|
||||
start_datetime (`datetime`):
|
||||
The start datetime of the response generation.
|
||||
response (`OllamaChatResponse`):
|
||||
Ollama OllamaChatResponse object to parse.
|
||||
structured_model (`Type[BaseModel] | None`, default `None`):
|
||||
A Pydantic BaseModel class that defines the expected structure
|
||||
for the model's output.
|
||||
|
||||
Returns:
|
||||
`ChatResponse`:
|
||||
A ChatResponse object containing the content blocks and usage.
|
||||
|
||||
.. note::
|
||||
If `structured_model` is not `None`, the expected structured output
|
||||
will be stored in the metadata of the `ChatResponse`.
|
||||
"""
|
||||
content_blocks: List[TextBlock | ToolUseBlock | ThinkingBlock] = []
|
||||
metadata: dict | None = None
|
||||
|
||||
if response.message.thinking:
|
||||
content_blocks.append(
|
||||
ThinkingBlock(
|
||||
type="thinking",
|
||||
thinking=response.message.thinking,
|
||||
),
|
||||
)
|
||||
|
||||
if response.message.content:
|
||||
content_blocks.append(
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=response.message.content,
|
||||
),
|
||||
)
|
||||
if structured_model:
|
||||
metadata = _json_loads_with_repair(
|
||||
response.message.content,
|
||||
)
|
||||
|
||||
for idx, tool_call in enumerate(response.message.tool_calls or []):
|
||||
content_blocks.append(
|
||||
ToolUseBlock(
|
||||
type="tool_use",
|
||||
id=f"{idx}_{tool_call.function.name}",
|
||||
name=tool_call.function.name,
|
||||
input=tool_call.function.arguments,
|
||||
raw_input=json.dumps(tool_call.function.arguments),
|
||||
),
|
||||
)
|
||||
|
||||
usage = None
|
||||
if "prompt_eval_count" in response and "eval_count" in response:
|
||||
usage = ChatUsage(
|
||||
input_tokens=response.get("prompt_eval_count", 0),
|
||||
output_tokens=response.get("eval_count", 0),
|
||||
time=(datetime.now() - start_datetime).total_seconds(),
|
||||
)
|
||||
|
||||
parsed_response = ChatResponse(
|
||||
content=content_blocks,
|
||||
usage=usage,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
return parsed_response
|
||||
|
||||
def _format_tools_json_schemas(
|
||||
self,
|
||||
schemas: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Format the tools JSON schemas to the Ollama format."""
|
||||
return schemas
|
||||
650
src/agentscope/model/_openai_model.py
Normal file
650
src/agentscope/model/_openai_model.py
Normal file
@@ -0,0 +1,650 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# pylint: disable=too-many-branches
|
||||
"""OpenAI Chat model class."""
|
||||
import copy
|
||||
import json
|
||||
import warnings
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
Any,
|
||||
TYPE_CHECKING,
|
||||
List,
|
||||
AsyncGenerator,
|
||||
Literal,
|
||||
Type,
|
||||
)
|
||||
from collections import OrderedDict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from . import ChatResponse
|
||||
from ._model_base import ChatModelBase
|
||||
from ._model_usage import ChatUsage
|
||||
from .._logging import logger
|
||||
from .._utils._common import _json_loads_with_repair
|
||||
from ..message import (
|
||||
ToolUseBlock,
|
||||
TextBlock,
|
||||
ThinkingBlock,
|
||||
AudioBlock,
|
||||
Base64Source,
|
||||
)
|
||||
from ..tracing import trace_llm
|
||||
from ..types import JSONSerializableObject
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openai.types.chat import ChatCompletion
|
||||
from openai import AsyncStream
|
||||
else:
|
||||
ChatCompletion = "openai.types.chat.ChatCompletion"
|
||||
AsyncStream = "openai.types.chat.AsyncStream"
|
||||
|
||||
|
||||
def _format_audio_data_for_qwen_omni(messages: list[dict]) -> None:
|
||||
"""Qwen-omni uses OpenAI-compatible API but requires different audio
|
||||
data format than OpenAI with "data:;base64," prefix.
|
||||
Refer to `Qwen-omni documentation
|
||||
<https://bailian.console.aliyun.com/?tab=doc#/doc/?type=model&url=2867839>`_
|
||||
for more details.
|
||||
|
||||
Args:
|
||||
messages (`list[dict]`):
|
||||
The list of message dictionaries from OpenAI formatter.
|
||||
"""
|
||||
for msg in messages:
|
||||
if isinstance(msg.get("content"), list):
|
||||
for block in msg["content"]:
|
||||
if (
|
||||
isinstance(block, dict)
|
||||
and "input_audio" in block
|
||||
and isinstance(block["input_audio"].get("data"), str)
|
||||
):
|
||||
if not block["input_audio"]["data"].startswith("http"):
|
||||
block["input_audio"]["data"] = (
|
||||
"data:;base64," + block["input_audio"]["data"]
|
||||
)
|
||||
|
||||
|
||||
class OpenAIChatModel(ChatModelBase):
|
||||
"""The OpenAI chat model class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
api_key: str | None = None,
|
||||
stream: bool = True,
|
||||
reasoning_effort: Literal["low", "medium", "high"] | None = None,
|
||||
organization: str = None,
|
||||
stream_tool_parsing: bool = True,
|
||||
client_type: Literal["openai", "azure"] = "openai",
|
||||
client_kwargs: dict[str, JSONSerializableObject] | None = None,
|
||||
generate_kwargs: dict[str, JSONSerializableObject] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the openai client.
|
||||
|
||||
Args:
|
||||
model_name (`str`, default `None`):
|
||||
The name of the model to use in OpenAI API.
|
||||
api_key (`str`, default `None`):
|
||||
The API key for OpenAI API. If not specified, it will
|
||||
be read from the environment variable `OPENAI_API_KEY`.
|
||||
stream (`bool`, default `True`):
|
||||
Whether to use streaming output or not.
|
||||
reasoning_effort (`Literal["low", "medium", "high"] | None`, \
|
||||
optional):
|
||||
Reasoning effort, supported for o3, o4, etc. Please refer to
|
||||
`OpenAI documentation
|
||||
<https://platform.openai.com/docs/guides/reasoning?api-mode=chat>`_
|
||||
for more details.
|
||||
organization (`str`, default `None`):
|
||||
The organization ID for OpenAI API. If not specified, it will
|
||||
be read from the environment variable `OPENAI_ORGANIZATION`.
|
||||
stream_tool_parsing (`bool`, default to `True`):
|
||||
Whether to parse incomplete tool use JSON during streaming
|
||||
with auto-repair. If True, partial JSON (e.g., `'{"a": "x'`)
|
||||
is repaired to valid dicts ({"a": "x"}) in real-time for
|
||||
immediate tool function input. Otherwise, the input field
|
||||
remains {} until the final chunk arrives.
|
||||
client_type (`Literal["openai", "azure"]`, default `openai`):
|
||||
Selects which OpenAI-compatible client to initialize.
|
||||
client_kwargs (`dict[str, JSONSerializableObject] | None`, \
|
||||
optional):
|
||||
The extra keyword arguments to initialize the OpenAI client.
|
||||
generate_kwargs (`dict[str, JSONSerializableObject] | None`, \
|
||||
optional):
|
||||
The extra keyword arguments used in OpenAI API generation,
|
||||
e.g. `temperature`, `seed`.
|
||||
**kwargs (`Any`):
|
||||
Additional keyword arguments.
|
||||
"""
|
||||
|
||||
# Handle deprecated client_args parameter from kwargs
|
||||
client_args = kwargs.pop("client_args", None)
|
||||
if client_args is not None and client_kwargs is not None:
|
||||
raise ValueError(
|
||||
"Cannot specify both 'client_args' and 'client_kwargs'. "
|
||||
"Please use only 'client_kwargs' (client_args is deprecated).",
|
||||
)
|
||||
|
||||
if client_args is not None:
|
||||
logger.warning(
|
||||
"The parameter 'client_args' is deprecated and will be "
|
||||
"removed in a future version. Please use 'client_kwargs' "
|
||||
"instead. Automatically converting 'client_args' to "
|
||||
"'client_kwargs'.",
|
||||
)
|
||||
client_kwargs = client_args
|
||||
|
||||
if kwargs:
|
||||
logger.warning(
|
||||
"Unknown keyword arguments: %s. These will be ignored.",
|
||||
list(kwargs.keys()),
|
||||
)
|
||||
|
||||
super().__init__(model_name, stream)
|
||||
|
||||
import openai
|
||||
|
||||
if client_type not in ("openai", "azure"):
|
||||
raise ValueError(
|
||||
"Invalid client_type. Supported values: 'openai', 'azure'.",
|
||||
)
|
||||
|
||||
if client_type == "azure":
|
||||
self.client = openai.AsyncAzureOpenAI(
|
||||
api_key=api_key,
|
||||
organization=organization,
|
||||
**(client_kwargs or {}),
|
||||
)
|
||||
else:
|
||||
self.client = openai.AsyncClient(
|
||||
api_key=api_key,
|
||||
organization=organization,
|
||||
**(client_kwargs or {}),
|
||||
)
|
||||
|
||||
self.reasoning_effort = reasoning_effort
|
||||
self.stream_tool_parsing = stream_tool_parsing
|
||||
self.generate_kwargs = generate_kwargs or {}
|
||||
|
||||
@trace_llm
|
||||
async def __call__(
|
||||
self,
|
||||
messages: list[dict],
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: Literal["auto", "none", "required"] | str | None = None,
|
||||
structured_model: Type[BaseModel] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResponse | AsyncGenerator[ChatResponse, None]:
|
||||
"""Get the response from OpenAI chat completions API by the given
|
||||
arguments.
|
||||
|
||||
Args:
|
||||
messages (`list[dict]`):
|
||||
A list of dictionaries, where `role` and `content` fields are
|
||||
required, and `name` field is optional.
|
||||
tools (`list[dict]`, default `None`):
|
||||
The tools JSON schemas that the model can use.
|
||||
tool_choice (`Literal["auto", "none", "required"] | str \
|
||||
| None`, default `None`):
|
||||
Controls which (if any) tool is called by the model.
|
||||
Can be "auto", "none", "required", or specific tool
|
||||
name. For more details, please refer to
|
||||
https://platform.openai.com/docs/api-reference/responses/create#responses_create-tool_choice
|
||||
structured_model (`Type[BaseModel] | None`, default `None`):
|
||||
A Pydantic BaseModel class that defines the expected structure
|
||||
for the model's output. When provided, the model will be forced
|
||||
to return data that conforms to this schema by automatically
|
||||
converting the BaseModel to a tool function and setting
|
||||
`tool_choice` to enforce its usage. This enables structured
|
||||
output generation.
|
||||
|
||||
.. note:: When `structured_model` is specified,
|
||||
both `tools` and `tool_choice` parameters are ignored,
|
||||
and the model will only perform structured output
|
||||
generation without calling any other tools.
|
||||
|
||||
For more details, please refer to the `official document
|
||||
<https://platform.openai.com/docs/guides/structured-outputs>`_
|
||||
|
||||
**kwargs (`Any`):
|
||||
The keyword arguments for OpenAI chat completions API,
|
||||
e.g. `temperature`, `max_tokens`, `top_p`, etc. Please
|
||||
refer to the OpenAI API documentation for more details.
|
||||
|
||||
Returns:
|
||||
`ChatResponse | AsyncGenerator[ChatResponse, None]`:
|
||||
The response from the OpenAI chat completions API.
|
||||
"""
|
||||
|
||||
# checking messages
|
||||
if not isinstance(messages, list):
|
||||
raise ValueError(
|
||||
"OpenAI `messages` field expected type `list`, "
|
||||
f"got `{type(messages)}` instead.",
|
||||
)
|
||||
if not all("role" in msg and "content" in msg for msg in messages):
|
||||
raise ValueError(
|
||||
"Each message in the 'messages' list must contain a 'role' "
|
||||
"and 'content' key for OpenAI API.",
|
||||
)
|
||||
|
||||
# Qwen-omni requires different base64 audio format from openai
|
||||
if "omni" in self.model_name.lower():
|
||||
_format_audio_data_for_qwen_omni(messages)
|
||||
|
||||
kwargs = {
|
||||
"model": self.model_name,
|
||||
"messages": messages,
|
||||
"stream": self.stream,
|
||||
**self.generate_kwargs,
|
||||
**kwargs,
|
||||
}
|
||||
if self.reasoning_effort and "reasoning_effort" not in kwargs:
|
||||
kwargs["reasoning_effort"] = self.reasoning_effort
|
||||
|
||||
if tools:
|
||||
kwargs["tools"] = self._format_tools_json_schemas(tools)
|
||||
|
||||
if tool_choice:
|
||||
# Handle deprecated "any" option with warning
|
||||
if tool_choice == "any":
|
||||
warnings.warn(
|
||||
'"any" is deprecated and will be removed in a future '
|
||||
"version.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
tool_choice = "required"
|
||||
self._validate_tool_choice(tool_choice, tools)
|
||||
kwargs["tool_choice"] = self._format_tool_choice(tool_choice)
|
||||
|
||||
if self.stream:
|
||||
kwargs["stream_options"] = {"include_usage": True}
|
||||
|
||||
start_datetime = datetime.now()
|
||||
|
||||
if structured_model:
|
||||
if tools or tool_choice:
|
||||
logger.warning(
|
||||
"structured_model is provided. Both 'tools' and "
|
||||
"'tool_choice' parameters will be overridden and "
|
||||
"ignored. The model will only perform structured output "
|
||||
"generation without calling any other tools.",
|
||||
)
|
||||
kwargs.pop("stream", None)
|
||||
kwargs.pop("tools", None)
|
||||
kwargs.pop("tool_choice", None)
|
||||
kwargs["response_format"] = structured_model
|
||||
if not self.stream:
|
||||
response = await self.client.chat.completions.parse(**kwargs)
|
||||
else:
|
||||
response = self.client.chat.completions.stream(**kwargs)
|
||||
return self._parse_openai_stream_response(
|
||||
start_datetime,
|
||||
response,
|
||||
structured_model,
|
||||
)
|
||||
else:
|
||||
response = await self.client.chat.completions.create(**kwargs)
|
||||
|
||||
if self.stream:
|
||||
return self._parse_openai_stream_response(
|
||||
start_datetime,
|
||||
response,
|
||||
structured_model,
|
||||
)
|
||||
|
||||
# Non-streaming response
|
||||
parsed_response = self._parse_openai_completion_response(
|
||||
start_datetime,
|
||||
response,
|
||||
structured_model,
|
||||
)
|
||||
|
||||
return parsed_response
|
||||
|
||||
# pylint: disable=too-many-statements
|
||||
async def _parse_openai_stream_response(
|
||||
self,
|
||||
start_datetime: datetime,
|
||||
response: AsyncStream,
|
||||
structured_model: Type[BaseModel] | None = None,
|
||||
) -> AsyncGenerator[ChatResponse, None]:
|
||||
"""Given an OpenAI streaming completion response, extract the content
|
||||
blocks and usages from it and yield ChatResponse objects.
|
||||
|
||||
Args:
|
||||
start_datetime (`datetime`):
|
||||
The start datetime of the response generation.
|
||||
response (`AsyncStream`):
|
||||
OpenAI AsyncStream object to parse.
|
||||
structured_model (`Type[BaseModel] | None`, default `None`):
|
||||
A Pydantic BaseModel class that defines the expected structure
|
||||
for the model's output.
|
||||
|
||||
Returns:
|
||||
`AsyncGenerator[ChatResponse, None]`:
|
||||
An async generator that yields ChatResponse objects containing
|
||||
the content blocks and usage information for each chunk in
|
||||
the streaming response.
|
||||
|
||||
.. note::
|
||||
If `structured_model` is not `None`, the expected structured output
|
||||
will be stored in the metadata of the `ChatResponse`.
|
||||
"""
|
||||
usage, res = None, None
|
||||
text = ""
|
||||
thinking = ""
|
||||
audio = ""
|
||||
tool_calls = OrderedDict()
|
||||
last_input_objs = {} # Store last input_obj for each tool_call
|
||||
metadata: dict | None = None
|
||||
contents: List[
|
||||
TextBlock | ToolUseBlock | ThinkingBlock | AudioBlock
|
||||
] = []
|
||||
last_contents = None
|
||||
|
||||
async with response as stream:
|
||||
async for item in stream:
|
||||
if structured_model:
|
||||
if item.type != "chunk":
|
||||
continue
|
||||
chunk = item.chunk
|
||||
else:
|
||||
chunk = item
|
||||
|
||||
if chunk.usage:
|
||||
usage = ChatUsage(
|
||||
input_tokens=chunk.usage.prompt_tokens,
|
||||
output_tokens=chunk.usage.completion_tokens,
|
||||
time=(datetime.now() - start_datetime).total_seconds(),
|
||||
metadata=chunk.usage,
|
||||
)
|
||||
|
||||
if not chunk.choices:
|
||||
if usage and contents:
|
||||
res = ChatResponse(
|
||||
content=contents,
|
||||
usage=usage,
|
||||
metadata=metadata,
|
||||
)
|
||||
yield res
|
||||
continue
|
||||
|
||||
choice = chunk.choices[0]
|
||||
|
||||
thinking += (
|
||||
getattr(choice.delta, "reasoning_content", None) or ""
|
||||
)
|
||||
text += getattr(choice.delta, "content", None) or ""
|
||||
|
||||
if (
|
||||
hasattr(choice.delta, "audio")
|
||||
and "data" in choice.delta.audio
|
||||
):
|
||||
audio += choice.delta.audio["data"]
|
||||
if (
|
||||
hasattr(choice.delta, "audio")
|
||||
and "transcript" in choice.delta.audio
|
||||
):
|
||||
text += choice.delta.audio["transcript"]
|
||||
|
||||
for tool_call in (
|
||||
getattr(choice.delta, "tool_calls", None) or []
|
||||
):
|
||||
if tool_call.index in tool_calls:
|
||||
if tool_call.function.arguments is not None:
|
||||
tool_calls[tool_call.index][
|
||||
"input"
|
||||
] += tool_call.function.arguments
|
||||
|
||||
else:
|
||||
tool_calls[tool_call.index] = {
|
||||
"type": "tool_use",
|
||||
"id": tool_call.id,
|
||||
"name": tool_call.function.name,
|
||||
"input": tool_call.function.arguments or "",
|
||||
}
|
||||
|
||||
contents = []
|
||||
|
||||
if thinking:
|
||||
contents.append(
|
||||
ThinkingBlock(
|
||||
type="thinking",
|
||||
thinking=thinking,
|
||||
),
|
||||
)
|
||||
|
||||
if audio:
|
||||
media_type = self.generate_kwargs.get("audio", {}).get(
|
||||
"format",
|
||||
"wav",
|
||||
)
|
||||
contents.append(
|
||||
AudioBlock(
|
||||
type="audio",
|
||||
source=Base64Source(
|
||||
data=audio,
|
||||
media_type=f"audio/{media_type}",
|
||||
type="base64",
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
if text:
|
||||
contents.append(
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=text,
|
||||
),
|
||||
)
|
||||
|
||||
if structured_model:
|
||||
metadata = _json_loads_with_repair(text)
|
||||
|
||||
for tool_call in tool_calls.values():
|
||||
input_str = tool_call["input"]
|
||||
tool_id = tool_call["id"]
|
||||
|
||||
# If parsing the tool input in streaming mode
|
||||
if self.stream_tool_parsing:
|
||||
repaired_input = _json_loads_with_repair(
|
||||
input_str or "{}",
|
||||
)
|
||||
# If the new repaired input is shorter than one in the
|
||||
# last chunk, use the last one to avoid regression
|
||||
last_input = last_input_objs.get(tool_id, {})
|
||||
if len(json.dumps(last_input)) > len(
|
||||
json.dumps(repaired_input),
|
||||
):
|
||||
repaired_input = last_input
|
||||
last_input_objs[tool_id] = repaired_input
|
||||
|
||||
else:
|
||||
# Otherwise, keep input as empty dict until the final
|
||||
# chunk
|
||||
repaired_input = {}
|
||||
|
||||
contents.append(
|
||||
ToolUseBlock(
|
||||
type=tool_call["type"],
|
||||
id=tool_id,
|
||||
name=tool_call["name"],
|
||||
input=repaired_input,
|
||||
raw_input=input_str,
|
||||
),
|
||||
)
|
||||
|
||||
if contents:
|
||||
res = ChatResponse(
|
||||
content=contents,
|
||||
usage=usage,
|
||||
metadata=metadata,
|
||||
)
|
||||
yield res
|
||||
last_contents = copy.deepcopy(contents)
|
||||
|
||||
# If stream_tool_parsing is False, yield last contents
|
||||
if not self.stream_tool_parsing and tool_calls and last_contents:
|
||||
metadata = None
|
||||
# Update tool use blocks in last_contents inplace
|
||||
for block in last_contents:
|
||||
if block.get("type") == "tool_use":
|
||||
block["input"] = input_obj = _json_loads_with_repair(
|
||||
str(block.get("raw_input") or "{}"),
|
||||
)
|
||||
|
||||
if structured_model:
|
||||
metadata = input_obj
|
||||
|
||||
yield ChatResponse(
|
||||
content=last_contents,
|
||||
usage=usage,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def _parse_openai_completion_response(
|
||||
self,
|
||||
start_datetime: datetime,
|
||||
response: ChatCompletion,
|
||||
structured_model: Type[BaseModel] | None = None,
|
||||
) -> ChatResponse:
|
||||
"""Given an OpenAI chat completion response object, extract the content
|
||||
blocks and usages from it.
|
||||
|
||||
Args:
|
||||
start_datetime (`datetime`):
|
||||
The start datetime of the response generation.
|
||||
response (`ChatCompletion`):
|
||||
OpenAI ChatCompletion object to parse.
|
||||
structured_model (`Type[BaseModel] | None`, default `None`):
|
||||
A Pydantic BaseModel class that defines the expected structure
|
||||
for the model's output.
|
||||
|
||||
Returns:
|
||||
ChatResponse (`ChatResponse`):
|
||||
A ChatResponse object containing the content blocks and usage.
|
||||
|
||||
.. note::
|
||||
If `structured_model` is not `None`, the expected structured output
|
||||
will be stored in the metadata of the `ChatResponse`.
|
||||
"""
|
||||
content_blocks: List[
|
||||
TextBlock | ToolUseBlock | ThinkingBlock | AudioBlock
|
||||
] = []
|
||||
metadata: dict | None = None
|
||||
|
||||
if response.choices:
|
||||
choice = response.choices[0]
|
||||
if (
|
||||
hasattr(choice.message, "reasoning_content")
|
||||
and choice.message.reasoning_content is not None
|
||||
):
|
||||
content_blocks.append(
|
||||
ThinkingBlock(
|
||||
type="thinking",
|
||||
thinking=response.choices[0].message.reasoning_content,
|
||||
),
|
||||
)
|
||||
|
||||
if choice.message.content:
|
||||
content_blocks.append(
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=response.choices[0].message.content,
|
||||
),
|
||||
)
|
||||
if choice.message.audio:
|
||||
media_type = self.generate_kwargs.get("audio", {}).get(
|
||||
"format",
|
||||
"mp3",
|
||||
)
|
||||
content_blocks.append(
|
||||
AudioBlock(
|
||||
type="audio",
|
||||
source=Base64Source(
|
||||
data=choice.message.audio.data,
|
||||
media_type=f"audio/{media_type}",
|
||||
type="base64",
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
if choice.message.audio.transcript:
|
||||
content_blocks.append(
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=choice.message.audio.transcript,
|
||||
),
|
||||
)
|
||||
|
||||
for tool_call in choice.message.tool_calls or []:
|
||||
content_blocks.append(
|
||||
ToolUseBlock(
|
||||
type="tool_use",
|
||||
id=tool_call.id,
|
||||
name=tool_call.function.name,
|
||||
input=_json_loads_with_repair(
|
||||
tool_call.function.arguments,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
if structured_model:
|
||||
metadata = choice.message.parsed.model_dump()
|
||||
|
||||
usage = None
|
||||
if response.usage:
|
||||
usage = ChatUsage(
|
||||
input_tokens=response.usage.prompt_tokens,
|
||||
output_tokens=response.usage.completion_tokens,
|
||||
time=(datetime.now() - start_datetime).total_seconds(),
|
||||
metadata=response.usage,
|
||||
)
|
||||
|
||||
parsed_response = ChatResponse(
|
||||
content=content_blocks,
|
||||
usage=usage,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
return parsed_response
|
||||
|
||||
def _format_tools_json_schemas(
|
||||
self,
|
||||
schemas: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Format the tools JSON schemas to the OpenAI format."""
|
||||
return schemas
|
||||
|
||||
def _format_tool_choice(
|
||||
self,
|
||||
tool_choice: Literal["auto", "none", "required"] | str | None,
|
||||
) -> str | dict | None:
|
||||
"""Format tool_choice parameter for API compatibility.
|
||||
|
||||
Args:
|
||||
tool_choice (`Literal["auto", "none", "required"] | str \
|
||||
| None`, default `None`):
|
||||
Controls which (if any) tool is called by the model.
|
||||
Can be "auto", "none", "required", or specific tool name.
|
||||
For more details, please refer to
|
||||
https://platform.openai.com/docs/api-reference/responses/create#responses_create-tool_choice
|
||||
Returns:
|
||||
`dict | None`:
|
||||
The formatted tool choice configuration dict, or None if
|
||||
tool_choice is None.
|
||||
"""
|
||||
if tool_choice is None:
|
||||
return None
|
||||
|
||||
mode_mapping = {
|
||||
"auto": "auto",
|
||||
"none": "none",
|
||||
"required": "required",
|
||||
}
|
||||
if tool_choice in mode_mapping:
|
||||
return mode_mapping[tool_choice]
|
||||
return {"type": "function", "function": {"name": tool_choice}}
|
||||
67
src/agentscope/model/_trinity_model.py
Normal file
67
src/agentscope/model/_trinity_model.py
Normal file
@@ -0,0 +1,67 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""A model class for RL Training with Trinity-RFT."""
|
||||
from typing import (
|
||||
Optional,
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
from typing_extensions import deprecated
|
||||
from ._openai_model import OpenAIChatModel
|
||||
from ..types import JSONSerializableObject
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openai import AsyncOpenAI
|
||||
else:
|
||||
AsyncOpenAI = "openai.AsyncOpenAI"
|
||||
|
||||
|
||||
@deprecated(
|
||||
"TrinityChatModel is deprecated. Please use OpenAIChatModel directly.",
|
||||
)
|
||||
class TrinityChatModel(OpenAIChatModel):
|
||||
"""A model class for RL Training with Trinity-RFT."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
openai_async_client: AsyncOpenAI,
|
||||
generate_kwargs: dict[str, JSONSerializableObject] | None = None,
|
||||
enable_thinking: Optional[bool] = None,
|
||||
) -> None:
|
||||
"""Initialize the Trinity model class.
|
||||
|
||||
Args:
|
||||
openai_async_client (`AsyncOpenAI`):
|
||||
The OpenAI async client instance provided by Trinity-RFT.
|
||||
generate_kwargs (`dict[str, JSONSerializableObject] | None`, \
|
||||
optional):
|
||||
Additional keyword arguments to pass to the model's generate
|
||||
method. Defaults to None.
|
||||
enable_thinking (`bool`, optional):
|
||||
Whether to enable the model's thinking capability. Only
|
||||
applicable for Qwen3 series models. Defaults to None.
|
||||
"""
|
||||
model_name = getattr(openai_async_client, "model_path", None)
|
||||
if model_name is None:
|
||||
raise ValueError(
|
||||
"The provided openai_async_client does not have a "
|
||||
"`model_path` attribute. Please ensure you are using "
|
||||
"the instance provided by Trinity-RFT.",
|
||||
)
|
||||
super().__init__(
|
||||
model_name=model_name,
|
||||
api_key="EMPTY",
|
||||
generate_kwargs=generate_kwargs,
|
||||
stream=False, # RL training does not support streaming
|
||||
)
|
||||
if enable_thinking is not None:
|
||||
if "chat_template_kwargs" not in self.generate_kwargs:
|
||||
self.generate_kwargs["chat_template_kwargs"] = {}
|
||||
assert isinstance(
|
||||
self.generate_kwargs["chat_template_kwargs"],
|
||||
dict,
|
||||
), "chat_template_kwargs must be a dictionary."
|
||||
self.generate_kwargs["chat_template_kwargs"][
|
||||
"enable_thinking"
|
||||
] = enable_thinking
|
||||
# change the client instance to the provided one
|
||||
self.client = openai_async_client
|
||||
Reference in New Issue
Block a user