chore: initial import of standalone agentscope project
Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled

This commit is contained in:
2026-03-02 18:21:40 +08:00
commit a842f1861f
561 changed files with 91892 additions and 0 deletions

View File

@@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
"""The model module."""
from ._model_base import ChatModelBase
from ._model_response import ChatResponse
from ._dashscope_model import DashScopeChatModel
from ._openai_model import OpenAIChatModel
from ._anthropic_model import AnthropicChatModel
from ._ollama_model import OllamaChatModel
from ._gemini_model import GeminiChatModel
from ._trinity_model import TrinityChatModel
__all__ = [
"ChatModelBase",
"ChatResponse",
"DashScopeChatModel",
"OpenAIChatModel",
"AnthropicChatModel",
"OllamaChatModel",
"GeminiChatModel",
"TrinityChatModel",
]

View File

@@ -0,0 +1,590 @@
# -*- coding: utf-8 -*-
# pylint: disable=too-many-branches, too-many-statements
"""The Anthropic API model classes."""
import copy
import json
import warnings
from datetime import datetime
from typing import (
Any,
AsyncGenerator,
TYPE_CHECKING,
List,
Literal,
Type,
)
from collections import OrderedDict
from pydantic import BaseModel
from ._model_base import ChatModelBase
from ._model_response import ChatResponse
from ._model_usage import ChatUsage
from .._logging import logger
from .._utils._common import (
_json_loads_with_repair,
_create_tool_from_base_model,
)
from ..message import TextBlock, ToolUseBlock, ThinkingBlock
from ..tracing import trace_llm
from ..types._json import JSONSerializableObject
if TYPE_CHECKING:
from anthropic.types.message import Message
from anthropic import AsyncStream
else:
Message = "anthropic.types.message.Message"
AsyncStream = "anthropic.AsyncStream"
class AnthropicChatModel(ChatModelBase):
"""The Anthropic model wrapper for AgentScope."""
def __init__(
self,
model_name: str,
api_key: str | None = None,
max_tokens: int = 2048,
stream: bool = True,
thinking: dict | None = None,
stream_tool_parsing: bool = True,
client_kwargs: dict[str, JSONSerializableObject] | None = None,
generate_kwargs: dict[str, JSONSerializableObject] | None = None,
**kwargs: Any,
) -> None:
"""Initialize the Anthropic chat model.
Args:
model_name (`str`):
The model names.
api_key (`str`):
The anthropic API key.
stream (`bool`):
The streaming output or not
max_tokens (`int`):
Limit the maximum token count the model can generate.
thinking (`dict | None`, default `None`):
Configuration for Claude's internal reasoning process.
.. code-block:: python
:caption: Example of thinking
{
"type": "enabled" | "disabled",
"budget_tokens": 1024
}
stream_tool_parsing (`bool`, default to `True`):
Whether to parse incomplete tool use JSON during streaming
with auto-repair. If True, partial JSON (e.g., `'{"a": "x'`)
is repaired to valid dicts ({"a": "x"}) in real-time for
immediate tool function input. Otherwise, the input field
remains {} until the final chunk arrives.
client_kwargs (`dict[str, JSONSerializableObject] | None`, \
optional):
The extra keyword arguments to initialize the Anthropic client.
generate_kwargs (`dict[str, JSONSerializableObject] | None`, \
optional):
The extra keyword arguments used in Anthropic API generation,
e.g. `temperature`, `seed`.
**kwargs (`Any`):
Additional keyword arguments.
"""
# Handle deprecated client_args parameter from kwargs
client_args = kwargs.pop("client_args", None)
if client_args is not None and client_kwargs is not None:
raise ValueError(
"Cannot specify both 'client_args' and 'client_kwargs'. "
"Please use only 'client_kwargs' (client_args is deprecated).",
)
if client_args is not None:
logger.warning(
"The parameter 'client_args' is deprecated and will be "
"removed in a future version. Please use 'client_kwargs' "
"instead. Automatically converting 'client_args' to "
"'client_kwargs'.",
)
client_kwargs = client_args
if kwargs:
logger.warning(
"Unknown keyword arguments: %s. These will be ignored.",
list(kwargs.keys()),
)
try:
import anthropic
except ImportError as e:
raise ImportError(
"Please install the `anthropic` package by running "
"`pip install anthropic`.",
) from e
super().__init__(model_name, stream)
self.client = anthropic.AsyncAnthropic(
api_key=api_key,
**(client_kwargs or {}),
)
self.max_tokens = max_tokens
self.thinking = thinking
self.stream_tool_parsing = stream_tool_parsing
self.generate_kwargs = generate_kwargs or {}
@trace_llm
async def __call__(
self,
messages: list[dict[str, Any]],
tools: list[dict] | None = None,
tool_choice: Literal["auto", "none", "required"] | str | None = None,
structured_model: Type[BaseModel] | None = None,
**generate_kwargs: Any,
) -> ChatResponse | AsyncGenerator[ChatResponse, None]:
"""Get the response from Anthropic chat completions API by the given
arguments.
Args:
messages (`list[dict]`):
A list of dictionaries, where `role` and `content` fields are
required, and `name` field is optional.
tools (`list[dict]`, default `None`):
The tools JSON schemas that in format of:
.. code-block:: python
:caption: Example of tools JSON schemas
[
{
"type": "function",
"function": {
"name": "xxx",
"description": "xxx",
"parameters": {
"type": "object",
"properties": {
"param1": {
"type": "string",
"description": "..."
},
# Add more parameters as needed
},
"required": ["param1"]
}
},
# More schemas here
]
tool_choice (`Literal["auto", "none", "required"] | str \
| None`, default `None`):
Controls which (if any) tool is called by the model.
Can be "auto", "none", "required", or specific tool
name. For more details, please refer to
https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/implement-tool-use
structured_model (`Type[BaseModel] | None`, default `None`):
A Pydantic BaseModel class that defines the expected structure
for the model's output. When provided, the model will be forced
to return data that conforms to this schema by automatically
converting the BaseModel to a tool function and setting
`tool_choice` to enforce its usage. This enables structured
output generation.
.. note:: When `structured_model` is specified,
both `tools` and `tool_choice` parameters are ignored,
and the model will only perform structured output
generation without calling any other tools.
**generate_kwargs (`Any`):
The keyword arguments for Anthropic chat completions API,
e.g. `temperature`, `top_p`, etc. Please
refer to the Anthropic API documentation for more details.
Returns:
`ChatResponse | AsyncGenerator[ChatResponse, None]`:
The response from the Anthropic chat completions API."""
kwargs: dict[str, Any] = {
"model": self.model_name,
"max_tokens": self.max_tokens,
"stream": self.stream,
**self.generate_kwargs,
**generate_kwargs,
}
if self.thinking and "thinking" not in kwargs:
kwargs["thinking"] = self.thinking
if tools:
kwargs["tools"] = self._format_tools_json_schemas(tools)
if tool_choice:
# Handle deprecated "any" option with warning
if tool_choice == "any":
warnings.warn(
'"any" is deprecated and will be removed in a future '
"version.",
DeprecationWarning,
)
tool_choice = "required"
self._validate_tool_choice(tool_choice, tools)
kwargs["tool_choice"] = self._format_tool_choice(tool_choice)
if structured_model:
if tools or tool_choice:
logger.warning(
"structured_model is provided. Both 'tools' and "
"'tool_choice' parameters will be overridden and "
"ignored. The model will only perform structured output "
"generation without calling any other tools.",
)
format_tool = _create_tool_from_base_model(structured_model)
kwargs["tools"] = self._format_tools_json_schemas(
[format_tool],
)
kwargs["tool_choice"] = self._format_tool_choice(
format_tool["function"]["name"],
)
# Extract the system message
if messages[0]["role"] == "system":
kwargs["system"] = messages[0]["content"]
messages = messages[1:]
kwargs["messages"] = messages
start_datetime = datetime.now()
response = await self.client.messages.create(**kwargs)
if self.stream:
return self._parse_anthropic_stream_completion_response(
start_datetime,
response,
structured_model,
)
# Non-streaming response
parsed_response = await self._parse_anthropic_completion_response(
start_datetime,
response,
structured_model,
)
return parsed_response
async def _parse_anthropic_completion_response(
self,
start_datetime: datetime,
response: Message,
structured_model: Type[BaseModel] | None = None,
) -> ChatResponse:
"""Given an Anthropic Message object, extract the content blocks and
usages from it.
Args:
start_datetime (`datetime`):
The start datetime of the response generation.
response (`Message`):
Anthropic Message object to parse.
structured_model (`Type[BaseModel] | None`, default `None`):
A Pydantic BaseModel class that defines the expected structure
for the model's output.
Returns:
ChatResponse (`ChatResponse`):
A ChatResponse object containing the content blocks and usage.
.. note::
If `structured_model` is not `None`, the expected structured output
will be stored in the metadata of the `ChatResponse`.
"""
content_blocks: List[ThinkingBlock | TextBlock | ToolUseBlock] = []
metadata = None
if hasattr(response, "content") and response.content:
for content_block in response.content:
if (
hasattr(content_block, "type")
and content_block.type == "thinking"
):
thinking_block = ThinkingBlock(
type="thinking",
thinking=content_block.thinking,
)
thinking_block["signature"] = content_block.signature
content_blocks.append(thinking_block)
elif (
hasattr(content_block, "type")
and content_block.type == "text"
):
content_blocks.append(
TextBlock(
type="text",
text=content_block.text,
),
)
elif (
hasattr(content_block, "type")
and content_block.type == "tool_use"
):
content_blocks.append(
ToolUseBlock(
type="tool_use",
id=content_block.id,
name=content_block.name,
input=content_block.input,
),
)
if structured_model:
metadata = content_block.input
usage = None
if response.usage:
usage = ChatUsage(
input_tokens=response.usage.input_tokens,
output_tokens=response.usage.output_tokens,
time=(datetime.now() - start_datetime).total_seconds(),
)
parsed_response = ChatResponse(
content=content_blocks,
usage=usage,
metadata=metadata,
)
return parsed_response
async def _parse_anthropic_stream_completion_response(
self,
start_datetime: datetime,
response: AsyncStream,
structured_model: Type[BaseModel] | None = None,
) -> AsyncGenerator[ChatResponse, None]:
"""Given an Anthropic streaming response, extract the content blocks
and usages from it and yield ChatResponse objects.
Args:
start_datetime (`datetime`):
The start datetime of the response generation.
response (`AsyncStream`):
Anthropic AsyncStream object to parse.
structured_model (`Type[BaseModel] | None`, default `None`):
A Pydantic BaseModel class that defines the expected structure
for the model's output.
Returns:
`AsyncGenerator[ChatResponse, None]`:
An async generator that yields ChatResponse objects containing
the content blocks and usage information for each chunk in
the streaming response.
.. note::
If `structured_model` is not `None`, the expected structured output
will be stored in the metadata of the `ChatResponse`.
"""
usage = None
text_buffer = ""
thinking_buffer = ""
thinking_signature = ""
tool_calls = OrderedDict()
tool_call_buffers = {}
last_input_objs = {} # Store last input_obj for each tool_call
res = None
metadata = None
# Record the last yielded content to parse the tools' input
last_content = None
async for event in response:
content_changed = False
thinking_changed = False
if event.type == "message_start":
message = event.message
if message.usage:
usage = ChatUsage(
input_tokens=message.usage.input_tokens,
output_tokens=getattr(
message.usage,
"output_tokens",
0,
),
time=(datetime.now() - start_datetime).total_seconds(),
)
elif event.type == "content_block_start":
if event.content_block.type == "tool_use":
block_index = event.index
tool_block = event.content_block
tool_calls[block_index] = {
"type": "tool_use",
"id": tool_block.id,
"name": tool_block.name,
"input": "",
}
tool_call_buffers[block_index] = ""
content_changed = True
elif event.type == "content_block_delta":
block_index = event.index
delta = event.delta
if delta.type == "text_delta":
text_buffer += delta.text
content_changed = True
elif delta.type == "thinking_delta":
thinking_buffer += delta.thinking
thinking_changed = True
elif delta.type == "signature_delta":
thinking_signature = delta.signature
elif (
delta.type == "input_json_delta"
and block_index in tool_calls
):
tool_call_buffers[block_index] += delta.partial_json or ""
tool_calls[block_index]["input"] = tool_call_buffers[
block_index
]
content_changed = True
elif event.type == "message_delta":
if event.usage and usage:
usage.output_tokens = event.usage.output_tokens
if (thinking_changed or content_changed) and usage:
contents: list = []
if thinking_buffer:
thinking_block = ThinkingBlock(
type="thinking",
thinking=thinking_buffer,
)
thinking_block["signature"] = thinking_signature
contents.append(thinking_block)
if text_buffer:
contents.append(
TextBlock(
type="text",
text=text_buffer,
),
)
for block_index, tool_call in tool_calls.items():
input_str = tool_call["input"]
tool_id = tool_call["id"]
# If parsing the tool input in streaming mode
if self.stream_tool_parsing:
repaired_input = _json_loads_with_repair(
input_str or "{}",
)
# If the new repaired input is shorter than one in the
# last chunk, use the last one to avoid regression
last_input = last_input_objs.get(tool_id, {})
if len(json.dumps(last_input)) > len(
json.dumps(repaired_input),
):
repaired_input = last_input
last_input_objs[tool_id] = repaired_input
else:
repaired_input = {}
contents.append(
ToolUseBlock(
type=tool_call["type"],
id=tool_call["id"],
name=tool_call["name"],
input=repaired_input,
raw_input=input_str,
),
)
if structured_model:
metadata = repaired_input
if contents:
res = ChatResponse(
content=contents,
usage=usage,
metadata=metadata,
)
yield res
last_content = copy.deepcopy(contents)
# If stream_tool_parsing is False, yield last contents
if not self.stream_tool_parsing and last_content and tool_calls:
metadata = None
# Update tool use blocks in last_contents inplace
for block in last_content:
if block.get("type") == "tool_use":
block["input"] = input_obj = _json_loads_with_repair(
block.get("raw_input") or "{}",
)
if structured_model:
metadata = input_obj
yield ChatResponse(
content=last_content,
usage=usage,
metadata=metadata,
)
def _format_tools_json_schemas(
self,
schemas: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Format the JSON schemas of the tool functions to the format that
Anthropic API expects."""
formatted_schemas = []
for schema in schemas:
assert (
"function" in schema
), f"Invalid schema: {schema}, expect key 'function'."
assert "name" in schema["function"], (
f"Invalid schema: {schema}, "
"expect key 'name' in 'function' field."
)
formatted_schemas.append(
{
"name": schema["function"]["name"],
"description": schema["function"].get("description", ""),
"input_schema": schema["function"].get("parameters", {}),
},
)
return formatted_schemas
def _format_tool_choice(
self,
tool_choice: Literal["auto", "none", "required"] | str | None,
) -> dict | None:
"""Format tool_choice parameter for API compatibility.
Args:
tool_choice (`Literal["auto", "none", "required"] | str \
| None`, default `None`):
Controls which (if any) tool is called by the model.
Can be "auto", "none", "required", or specific tool
name. For more details, please refer to
https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/implement-tool-use
Returns:
`dict | None`:
The formatted tool choice configuration dict, or None if
tool_choice is None.
"""
if tool_choice is None:
return None
type_mapping = {
"auto": {"type": "auto"},
"none": {"type": "none"},
"required": {"type": "any"},
}
if tool_choice in type_mapping:
return type_mapping[tool_choice]
return {"type": "tool", "name": tool_choice}

View File

@@ -0,0 +1,632 @@
# -*- coding: utf-8 -*-
"""The dashscope API model classes."""
import copy
import collections
import json
import os
import warnings
from datetime import datetime
from http import HTTPStatus
from typing import (
Any,
AsyncGenerator,
Generator,
Union,
TYPE_CHECKING,
List,
Literal,
Type,
)
from pydantic import BaseModel
from aioitertools import iter as giter
from ._model_base import ChatModelBase
from ._model_response import ChatResponse
from ._model_usage import ChatUsage
from .._utils._common import (
_json_loads_with_repair,
_create_tool_from_base_model,
)
from ..message import TextBlock, ToolUseBlock, ThinkingBlock
from ..tracing import trace_llm
from ..types import JSONSerializableObject
from .._logging import logger
if TYPE_CHECKING:
from dashscope.api_entities.dashscope_response import GenerationResponse
from dashscope.api_entities.dashscope_response import (
MultiModalConversationResponse,
)
else:
GenerationResponse = (
"dashscope.api_entities.dashscope_response.GenerationResponse"
)
MultiModalConversationResponse = (
"dashscope.api_entities.dashscope_response."
"MultiModalConversationResponse"
)
class DashScopeChatModel(ChatModelBase):
"""The DashScope chat model class, which unifies the Generation and
MultimodalConversation APIs into one method.
This class provides a unified interface for DashScope API by automatically
selecting between text-only (Generation API) and multimodal
(MultiModalConversation API) endpoints. The `multimodality` parameter
allows explicit control over API selection:
- When `multimodality=True`: Forces use of MultiModalConversation API
for handling images, videos, and other multimodal inputs
- When `multimodality=False`: Forces use of Generation API for
text-only processing
- When `multimodality=None` (default): Automatically selects the API
based on model name (e.g., models with "-vl" suffix or starting
with "qvq" will use MultiModalConversation API)
This design enables seamless switching between text and multimodal
models without changing code structure, making it easier to work with
DashScope's diverse model offerings.
"""
def __init__(
self,
model_name: str,
api_key: str,
stream: bool = True,
enable_thinking: bool | None = None,
multimodality: bool | None = None,
generate_kwargs: dict[str, JSONSerializableObject] | None = None,
base_http_api_url: str | None = None,
stream_tool_parsing: bool = True,
**_kwargs: Any,
) -> None:
"""Initialize the DashScope chat model.
Args:
model_name (`str`):
The model names.
api_key (`str`):
The dashscope API key.
stream (`bool`):
The streaming output or not
enable_thinking (`bool | None`, optional):
Enable thinking or not, only support Qwen3, QwQ, DeepSeek-R1.
Refer to `DashScope documentation
<https://help.aliyun.com/zh/model-studio/deep-thinking>`_
for more details.
multimodality (`bool | None`, optional):
Whether to use multimodal conversation API. If `True`,
it will use `dashscope.MultiModalConversation.call`
to process multimodal inputs such as images and text. If
`False`, it will use
`dashscope.aigc.generation.AioGeneration.call` to process
text inputs. If `None` (default), the choice is based on
the model name.
generate_kwargs (`dict[str, JSONSerializableObject] | None`, \
optional):
The extra keyword arguments used in DashScope API generation,
e.g. `temperature`, `seed`.
base_http_api_url (`str | None`, optional):
The base URL for DashScope API requests. If not provided,
the default base URL from the DashScope SDK will be used.
stream_tool_parsing (`bool`, default to `True`):
Whether to parse incomplete tool use JSON in streaming mode
with auto-repair. If True, partial JSON (e.g., `'{"a": "x'`)
is repaired to valid dicts (`{"a": "x"}`) in real-time for
immediate tool function input. Otherwise, the input field
remains {} until the final chunk arrives.
**_kwargs (`Any`):
Additional keyword arguments.
"""
if enable_thinking and not stream:
logger.info(
"In DashScope API, `stream` must be True when "
"`enable_thinking` is True. ",
)
stream = True
super().__init__(model_name, stream)
self.api_key = api_key
self.enable_thinking = enable_thinking
self.multimodality = multimodality
self.generate_kwargs = generate_kwargs or {}
self.stream_tool_parsing = stream_tool_parsing
if base_http_api_url is not None:
import dashscope
dashscope.base_http_api_url = base_http_api_url
# Load headers from environment variable if exists
headers = os.getenv("DASHSCOPE_API_HEADERS")
if headers:
try:
headers = json.loads(str(headers))
if not isinstance(headers, dict):
raise json.JSONDecodeError("", "", 0)
if self.generate_kwargs.get("headers"):
headers.update(self.generate_kwargs["headers"])
self.generate_kwargs["headers"] = headers
except json.JSONDecodeError:
logger.warning(
"Failed to parse DASHSCOPE_API_HEADERS environment "
"variable as JSON. It should be a JSON object.",
)
@trace_llm
async def __call__(
self,
messages: list[dict[str, Any]],
tools: list[dict] | None = None,
tool_choice: Literal["auto", "none", "required"] | str | None = None,
structured_model: Type[BaseModel] | None = None,
**kwargs: Any,
) -> ChatResponse | AsyncGenerator[ChatResponse, None]:
"""Get the response from the dashscope
Generation/MultimodalConversation API by the given arguments.
.. note:: We unify the dashscope generation and multimodal conversation
APIs into one method, since they support similar arguments and share
the same functionality.
Args:
messages (`list[dict[str, Any]]`):
A list of dictionaries, where `role` and `content` fields are
required.
tools (`list[dict] | None`, default `None`):
The tools JSON schemas that the model can use.
tool_choice (`Literal["auto", "none", "required"] | str \
| None`, default `None`):
Controls which (if any) tool is called by the model.
Can be "auto", "none", "required", or specific tool name.
Note: DashScope API only supports "auto" and "none", so
"required" will be converted to "auto".
For more details, please refer to
https://help.aliyun.com/zh/model-studio/qwen-function-calling
structured_model (`Type[BaseModel] | None`, default `None`):
A Pydantic BaseModel class that defines the expected structure
for the model's output. When provided, the model will be forced
to return data that conforms to this schema by automatically
converting the BaseModel to a tool function and setting
`tool_choice` to enforce its usage. This enables structured
output generation.
.. note:: When `structured_model` is specified,
both `tools` and `tool_choice` parameters are ignored,
and the model will only perform structured output
generation without calling any other tools.
**kwargs (`Any`):
The keyword arguments for DashScope chat completions API,
e.g. `temperature`, `max_tokens`, `top_p`, etc. Please
refer to `DashScope documentation
<https://help.aliyun.com/zh/dashscope/developer-reference/api-details>`_
for more detailed arguments.
"""
import dashscope
kwargs = {
"messages": messages,
"model": self.model_name,
"stream": self.stream,
**self.generate_kwargs,
**kwargs,
"result_format": "message",
# In agentscope, the `incremental_output` must be `True` when
# `self.stream` is True
"incremental_output": self.stream,
}
if tools:
kwargs["tools"] = self._format_tools_json_schemas(tools)
if tool_choice:
# Handle deprecated "any" option with warning
if tool_choice in ["any", "required"]:
warnings.warn(
f"'{tool_choice}' is not supported by DashScope API. "
"It will be converted to 'auto'.",
DeprecationWarning,
)
tool_choice = "auto"
self._validate_tool_choice(tool_choice, tools)
kwargs["tool_choice"] = self._format_tool_choice(tool_choice)
if (
self.enable_thinking is not None
and "enable_thinking" not in kwargs
):
kwargs["enable_thinking"] = self.enable_thinking
if structured_model:
if tools or tool_choice:
logger.warning(
"structured_model is provided. Both 'tools' and "
"'tool_choice' parameters will be overridden and "
"ignored. The model will only perform structured output "
"generation without calling any other tools.",
)
format_tool = _create_tool_from_base_model(structured_model)
kwargs["tools"] = self._format_tools_json_schemas(
[format_tool],
)
kwargs["tool_choice"] = self._format_tool_choice(
format_tool["function"]["name"],
)
start_datetime = datetime.now()
if self.multimodality or (
self.multimodality is None
and (
self.model_name.startswith(
"qvq",
)
or "-vl" in self.model_name
)
):
response = dashscope.MultiModalConversation.call(
api_key=self.api_key,
**kwargs,
)
else:
response = await dashscope.aigc.generation.AioGeneration.call(
api_key=self.api_key,
**kwargs,
)
if self.stream:
return self._parse_dashscope_stream_response(
start_datetime,
response,
structured_model,
)
parsed_response = await self._parse_dashscope_generation_response(
start_datetime,
response,
structured_model,
)
return parsed_response
# pylint: disable=too-many-branches, too-many-statements
async def _parse_dashscope_stream_response(
self,
start_datetime: datetime,
response: Union[
AsyncGenerator[GenerationResponse, None],
Generator[MultiModalConversationResponse, None, None],
],
structured_model: Type[BaseModel] | None = None,
) -> AsyncGenerator[ChatResponse, Any]:
"""Given a DashScope streaming response generator, extract the content
blocks and usages from it and yield ChatResponse objects.
Args:
start_datetime (`datetime`):
The start datetime of the response generation.
response (
`Union[AsyncGenerator[GenerationResponse, None], Generator[ \
MultiModalConversationResponse, None, None]]`
):
DashScope streaming response generator (GenerationResponse or
MultiModalConversationResponse) to parse.
structured_model (`Type[BaseModel] | None`, default `None`):
A Pydantic BaseModel class that defines the expected structure
for the model's output.
Returns:
AsyncGenerator[ChatResponse, Any]:
An async generator that yields ChatResponse objects containing
the content blocks and usage information for each chunk in the
streaming response.
.. note::
If `structured_model` is not `None`, the expected structured output
will be stored in the metadata of the `ChatResponse`.
"""
acc_content, acc_thinking_content = "", ""
acc_tool_calls = collections.defaultdict(dict)
last_input_objs = {} # Store last input_obj for each tool_call
metadata = None
last_content = None
usage = None
async for chunk in giter(response):
if chunk.status_code != HTTPStatus.OK:
raise RuntimeError(
f"Failed to get response from _ API: {chunk}",
)
message = chunk.output.choices[0].message
# Update reasoning content
if isinstance(message.get("reasoning_content"), str):
acc_thinking_content += message["reasoning_content"]
# Update text content
if isinstance(message.content, str):
acc_content += message.content
elif isinstance(message.content, list):
for item in message.content:
if isinstance(item, dict) and "text" in item:
acc_content += item["text"]
# Update tool calls
for tool_call in message.get("tool_calls", []):
index = tool_call.get("index", 0)
if "id" in tool_call and tool_call["id"] != acc_tool_calls[
index
].get("id"):
acc_tool_calls[index]["id"] = (
acc_tool_calls[index].get("id", "") + tool_call["id"]
)
if "function" in tool_call:
func = tool_call["function"]
if "name" in func:
acc_tool_calls[index]["name"] = (
acc_tool_calls[index].get("name", "")
+ func["name"]
)
if "arguments" in func:
acc_tool_calls[index]["arguments"] = (
acc_tool_calls[index].get("arguments", "")
+ func["arguments"]
)
# Build content blocks (always include thinking and text)
content_blocks: list[TextBlock | ToolUseBlock | ThinkingBlock] = []
if acc_thinking_content:
content_blocks.append(
ThinkingBlock(
type="thinking",
thinking=acc_thinking_content,
),
)
if acc_content:
content_blocks.append(
TextBlock(
type="text",
text=acc_content,
),
)
for tool_call in acc_tool_calls.values():
# Only add intermediate tool use blocks if
# stream_tool_parsing is True
tool_id = tool_call.get("id", "")
input_str = tool_call.get("arguments")
# If parsing the tool input in streaming mode
if self.stream_tool_parsing:
repaired_input = _json_loads_with_repair(
input_str or "{}",
)
# If the new repaired input is shorter than one in the last
# chunk, use the last one to avoid regression
last_input = last_input_objs.get(tool_id, {})
if len(json.dumps(last_input)) > len(
json.dumps(repaired_input),
):
repaired_input = last_input
last_input_objs[tool_id] = repaired_input
else:
# Otherwise, keep input as empty dict until the final chunk
repaired_input = {}
content_blocks.append(
ToolUseBlock(
type="tool_use",
id=tool_id,
name=tool_call.get("name", ""),
input=repaired_input,
raw_input=input_str,
),
)
if structured_model:
metadata = repaired_input
if chunk.usage:
usage = ChatUsage(
input_tokens=chunk.usage.input_tokens,
output_tokens=chunk.usage.output_tokens,
time=(datetime.now() - start_datetime).total_seconds(),
metadata=chunk.usage,
)
if content_blocks:
parsed_chunk = ChatResponse(
content=content_blocks,
usage=usage,
metadata=metadata,
)
yield parsed_chunk
last_content = copy.deepcopy(content_blocks)
# If stream_tool_parsing is False, we need to parse the final tool
# use inputs here
if not self.stream_tool_parsing and last_content and acc_tool_calls:
metadata = None
# Update tool use blocks in last_contents inplace
for block in last_content:
if block.get("type") == "tool_use":
block["input"] = input_obj = _json_loads_with_repair(
str(block.get("raw_input") or "{}"),
)
if structured_model:
metadata = input_obj
yield ChatResponse(
content=last_content,
usage=usage,
metadata=metadata,
)
async def _parse_dashscope_generation_response(
self,
start_datetime: datetime,
response: Union[
GenerationResponse,
MultiModalConversationResponse,
],
structured_model: Type[BaseModel] | None = None,
) -> ChatResponse:
"""Given a DashScope GenerationResponse object, extract the content
blocks and usages from it.
Args:
start_datetime (`datetime`):
The start datetime of the response generation.
response (
`Union[GenerationResponse, MultiModalConversationResponse]`
):
Dashscope GenerationResponse | MultiModalConversationResponse
object to parse.
structured_model (`Type[BaseModel] | None`, default `None`):
A Pydantic BaseModel class that defines the expected structure
for the model's output.
Returns:
ChatResponse (`ChatResponse`):
A ChatResponse object containing the content blocks and usage.
.. note::
If `structured_model` is not `None`, the expected structured output
will be stored in the metadata of the `ChatResponse`.
"""
# Collect the content blocks from the response.
if response.status_code != 200:
raise RuntimeError(response)
content_blocks: List[TextBlock | ToolUseBlock] = []
metadata: dict | None = None
message = response.output.choices[0].message
content = message.get("content")
if response.output.choices[0].message.get("content") not in [
None,
"",
[],
]:
if isinstance(content, list):
for item in content:
if isinstance(item, dict) and "text" in item:
content_blocks.append(
TextBlock(
type="text",
text=item["text"],
),
)
else:
content_blocks.append(
TextBlock(
type="text",
text=content,
),
)
if message.get("tool_calls"):
for tool_call in message["tool_calls"]:
input_ = _json_loads_with_repair(
tool_call["function"].get(
"arguments",
"{}",
)
or "{}",
)
content_blocks.append(
ToolUseBlock(
type="tool_use",
name=tool_call["function"]["name"],
input=input_,
id=tool_call["id"],
),
)
if structured_model:
metadata = input_
# Usage information
usage = None
if response.usage:
usage = ChatUsage(
input_tokens=response.usage.input_tokens,
output_tokens=response.usage.output_tokens,
time=(datetime.now() - start_datetime).total_seconds(),
metadata=response.usage,
)
parsed_response = ChatResponse(
content=content_blocks,
usage=usage,
metadata=metadata,
)
return parsed_response
def _format_tools_json_schemas(
self,
schemas: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Format the tools JSON schema into required format for DashScope API.
Args:
schemas (`dict[str, dict[str, Any]]`):
The tools JSON schemas.
"""
# Check schemas format
for value in schemas:
if (
not isinstance(value, dict)
or "type" not in value
or value["type"] != "function"
or "function" not in value
):
raise ValueError(
f"Each schema must be a dict with 'type' as 'function' "
f"and 'function' key, got {value}",
)
return schemas
def _format_tool_choice(
self,
tool_choice: Literal["auto", "none", "required"] | str | None,
) -> str | dict | None:
"""Format tool_choice parameter for API compatibility.
Args:
tool_choice (`Literal["auto", "none", "required"] | str \
| None`, default `None`):
Controls which (if any) tool is called by the model. For more
details, please refer to
https://help.aliyun.com/zh/model-studio/qwen-function-calling
Returns:
`dict | None`:
The formatted tool choice configuration dict, or None if
tool_choice is None.
"""
if tool_choice is None:
return None
if tool_choice in ["auto", "none"]:
return tool_choice
if tool_choice == "required":
return "auto"
return {"type": "function", "function": {"name": tool_choice}}

View File

@@ -0,0 +1,646 @@
# -*- coding: utf-8 -*-
# mypy: disable-error-code="dict-item"
"""The Google Gemini model in agentscope."""
import base64
import copy
import warnings
from datetime import datetime
import json
from typing import (
AsyncGenerator,
Any,
TYPE_CHECKING,
AsyncIterator,
Literal,
Type,
List,
)
from pydantic import BaseModel
from .._logging import logger
from .._utils._common import _json_loads_with_repair
from ..message import ToolUseBlock, TextBlock, ThinkingBlock
from ._model_usage import ChatUsage
from ._model_base import ChatModelBase
from ._model_response import ChatResponse
from ..tracing import trace_llm
from ..types import JSONSerializableObject
if TYPE_CHECKING:
from google.genai.types import GenerateContentResponse
else:
GenerateContentResponse = "google.genai.types.GenerateContentResponse"
def _flatten_json_schema(schema: dict) -> dict:
"""Flatten a JSON schema by resolving all $ref references.
.. note::
Gemini API does not support `$defs` and `$ref` in JSON schemas.
This function resolves all `$ref` references by inlining the
referenced definitions, producing a self-contained schema without
any references.
Args:
schema (`dict`):
The JSON schema that may contain `$defs` and `$ref` references.
Returns:
`dict`:
A flattened JSON schema with all references resolved inline.
"""
# Deep copy to avoid modifying the original schema
schema = copy.deepcopy(schema)
# Extract $defs if present
defs = schema.pop("$defs", {})
def _resolve_ref(obj: Any, visited: set | None = None) -> Any:
"""Recursively resolve $ref references in the schema."""
if visited is None:
visited = set()
if not isinstance(obj, dict):
if isinstance(obj, list):
return [_resolve_ref(item, visited.copy()) for item in obj]
return obj
# Handle $ref
if "$ref" in obj:
ref_path = obj["$ref"]
# Extract definition name from "#/$defs/DefinitionName"
if ref_path.startswith("#/$defs/"):
def_name = ref_path[len("#/$defs/") :]
# Prevent infinite recursion for circular references
if def_name in visited:
logger.warning(
"Circular reference detected for '%s' in tool schema",
def_name,
)
return {
"type": "object",
"description": f"(circular: {def_name})",
}
visited.add(def_name)
if def_name in defs:
# Recursively resolve any nested refs in the definition
resolved = _resolve_ref(
defs[def_name],
visited.copy(),
)
# Merge any additional properties from the original object
# (excluding $ref itself)
for key, value in obj.items():
if key != "$ref":
resolved[key] = _resolve_ref(value, visited.copy())
return resolved
# If we can't resolve the ref, return as-is (shouldn't happen)
return obj
# Recursively process all nested objects
result = {}
for key, value in obj.items():
result[key] = _resolve_ref(value, visited.copy())
return result
return _resolve_ref(schema)
class GeminiChatModel(ChatModelBase):
"""The Google Gemini chat model class in agentscope."""
def __init__(
self,
model_name: str,
api_key: str,
stream: bool = True,
thinking_config: dict | None = None,
client_kwargs: dict[str, JSONSerializableObject] | None = None,
generate_kwargs: dict[str, JSONSerializableObject] | None = None,
**kwargs: Any,
) -> None:
"""Initialize the Gemini chat model.
Args:
model_name (`str`):
The name of the Gemini model to use, e.g. "gemini-2.5-flash".
api_key (`str`):
The API key for Google Gemini.
stream (`bool`, default `True`):
Whether to use streaming output or not.
thinking_config (`dict | None`, optional):
Thinking config, supported models are 2.5 Pro, 2.5 Flash, etc.
Refer to https://ai.google.dev/gemini-api/docs/thinking for
more details.
.. code-block:: python
:caption: Example of thinking_config
{
"include_thoughts": True, # enable thoughts or not
"thinking_budget": 1024 # Max tokens for reasoning
}
client_kwargs (`dict[str, JSONSerializableObject] | None`, \
optional):
The extra keyword arguments to initialize the Gemini client.
generate_kwargs (`dict[str, JSONSerializableObject] | None`, \
optional):
The extra keyword arguments used in Gemini API generation,
e.g. `temperature`, `seed`.
**kwargs (`Any`):
Additional keyword arguments.
"""
# Handle deprecated client_args parameter from kwargs
client_args = kwargs.pop("client_args", None)
if client_args is not None and client_kwargs is not None:
raise ValueError(
"Cannot specify both 'client_args' and 'client_kwargs'. "
"Please use only 'client_kwargs' (client_args is deprecated).",
)
if client_args is not None:
logger.warning(
"The parameter 'client_args' is deprecated and will be "
"removed in a future version. Please use 'client_kwargs' "
"instead. Automatically converting 'client_args' to "
"'client_kwargs'.",
)
client_kwargs = client_args
if kwargs:
logger.warning(
"Unknown keyword arguments: %s. These will be ignored.",
list(kwargs.keys()),
)
try:
from google import genai
except ImportError as e:
raise ImportError(
"Please install gemini Python sdk with "
"`pip install -q -U google-genai`",
) from e
super().__init__(model_name, stream)
self.client = genai.Client(
api_key=api_key,
**(client_kwargs or {}),
)
self.thinking_config = thinking_config
self.generate_kwargs = generate_kwargs or {}
@trace_llm
async def __call__(
self,
messages: list[dict],
tools: list[dict] | None = None,
tool_choice: Literal["auto", "none", "required"] | str | None = None,
structured_model: Type[BaseModel] | None = None,
**config_kwargs: Any,
) -> ChatResponse | AsyncGenerator[ChatResponse, None]:
"""Call the Gemini model with the provided arguments.
Args:
messages (`list[dict[str, Any]]`):
A list of dictionaries, where `role` and `content` fields are
required.
tools (`list[dict] | None`, default `None`):
The tools JSON schemas that the model can use.
tool_choice (`Literal["auto", "none", "required"] | str \
| None`, default `None`):
Controls which (if any) tool is called by the model.
Can be "auto", "none", "required", or specific tool name.
For more details, please refer to
https://ai.google.dev/gemini-api/docs/function-calling?hl=en&example=meeting#function_calling_modes
structured_model (`Type[BaseModel] | None`, default `None`):
A Pydantic BaseModel class that defines the expected structure
for the model's output.
.. note:: When `structured_model` is specified,
both `tools` and `tool_choice` parameters are ignored,
and the model will only perform structured output
generation without calling any other tools.
For more details, please refer to
https://ai.google.dev/gemini-api/docs/structured-output
**config_kwargs (`Any`):
The keyword arguments for Gemini chat completions API.
"""
config: dict = {
"thinking_config": self.thinking_config,
**self.generate_kwargs,
**config_kwargs,
}
if tools:
config["tools"] = self._format_tools_json_schemas(tools)
if tool_choice:
# Handle deprecated "any" option with warning
if tool_choice == "any":
warnings.warn(
'"any" is deprecated and will be removed in a future '
"version.",
DeprecationWarning,
)
tool_choice = "required"
self._validate_tool_choice(tool_choice, tools)
config["tool_config"] = self._format_tool_choice(tool_choice)
if structured_model:
if tools or tool_choice:
logger.warning(
"structured_model is provided. Both 'tools' and "
"'tool_choice' parameters will be overridden and "
"ignored. The model will only perform structured output "
"generation without calling any other tools.",
)
config.pop("tools", None)
config.pop("tool_config", None)
config["response_mime_type"] = "application/json"
config["response_schema"] = structured_model
# Prepare the arguments for the Gemini API call
kwargs: dict[str, JSONSerializableObject] = {
"model": self.model_name,
"contents": messages,
"config": config,
}
start_datetime = datetime.now()
if self.stream:
response = await self.client.aio.models.generate_content_stream(
**kwargs,
)
return self._parse_gemini_stream_generation_response(
start_datetime,
response,
structured_model,
)
# non-streaming
response = await self.client.aio.models.generate_content(
**kwargs,
)
parsed_response = self._parse_gemini_generation_response(
start_datetime,
response,
structured_model,
)
return parsed_response
async def _parse_gemini_stream_generation_response(
self,
start_datetime: datetime,
response: AsyncIterator[GenerateContentResponse],
structured_model: Type[BaseModel] | None = None,
) -> AsyncGenerator[ChatResponse, None]:
"""Given a Gemini streaming generation response, extract the
content blocks and usages from it and yield ChatResponse objects.
Args:
start_datetime (`datetime`):
The start datetime of the response generation.
response (`AsyncIterator[GenerateContentResponse]`):
Gemini GenerateContentResponse async iterator to parse.
structured_model (`Type[BaseModel] | None`, default `None`):
A Pydantic BaseModel class that defines the expected structure
for the model's output.
Returns:
`AsyncGenerator[ChatResponse, None]`:
An async generator that yields ChatResponse objects containing
the content blocks and usage information for each chunk in the
streaming response.
.. note::
If `structured_model` is not `None`, the expected structured output
will be stored in the metadata of the `ChatResponse`.
"""
text = ""
thinking = ""
tool_calls: list[ToolUseBlock] = []
metadata: dict | None = None
async for chunk in response:
if (
chunk.candidates
and chunk.candidates[0].content
and chunk.candidates[0].content.parts
):
for part in chunk.candidates[0].content.parts:
if part.text:
if part.thought:
thinking += part.text
else:
text += part.text
if part.function_call:
keyword_args = part.function_call.args or {}
# .. note:: Gemini API always returns None for
# function_call.id, so we use thought_signature
# as the unique identifier for tool
# calls when available. That maybe
# infeasible someday, but Gemini
# requires the thought_signature for some
# llms like gemini-3-pro
if part.thought_signature:
call_id = base64.b64encode(
part.thought_signature,
).decode("utf-8")
else:
call_id = part.function_call.id
tool_calls.append(
ToolUseBlock(
type="tool_use",
id=call_id,
name=part.function_call.name,
input=keyword_args,
raw_input=json.dumps(
keyword_args,
ensure_ascii=False,
),
),
)
# Text parts
if text and structured_model:
metadata = _json_loads_with_repair(text)
usage = None
if chunk.usage_metadata:
usage = ChatUsage(
input_tokens=chunk.usage_metadata.prompt_token_count,
output_tokens=chunk.usage_metadata.total_token_count
- chunk.usage_metadata.prompt_token_count,
time=(datetime.now() - start_datetime).total_seconds(),
)
# The content blocks for the current chunk
content_blocks: list = []
if thinking:
content_blocks.append(
ThinkingBlock(
type="thinking",
thinking=thinking,
),
)
if text:
content_blocks.append(
TextBlock(
type="text",
text=text,
),
)
yield ChatResponse(
content=content_blocks + tool_calls,
usage=usage,
metadata=metadata,
)
def _parse_gemini_generation_response(
self,
start_datetime: datetime,
response: GenerateContentResponse,
structured_model: Type[BaseModel] | None = None,
) -> ChatResponse:
"""Given a Gemini chat completion response object, extract the content
blocks and usages from it.
Args:
start_datetime (`datetime`):
The start datetime of the response generation.
response (`GenerateContentResponse`):
The Gemini generation response object to parse.
structured_model (`Type[BaseModel] | None`, default `None`):
A Pydantic BaseModel class that defines the expected structure
for the model's output.
Returns:
ChatResponse (`ChatResponse`):
A ChatResponse object containing the content blocks and usage.
.. note::
If `structured_model` is not `None`, the expected structured output
will be stored in the metadata of the `ChatResponse`.
"""
content_blocks: List[TextBlock | ToolUseBlock | ThinkingBlock] = []
metadata: dict | None = None
tool_calls: list = []
if (
response.candidates
and response.candidates[0].content
and response.candidates[0].content.parts
):
for part in response.candidates[0].content.parts:
if part.text:
if part.thought:
content_blocks.append(
ThinkingBlock(
type="thinking",
thinking=part.text,
),
)
else:
content_blocks.append(
TextBlock(
type="text",
text=part.text,
),
)
if part.function_call:
keyword_args = part.function_call.args or {}
# .. note:: Gemini API always returns None for
# function_call.id, so we use thought_signature
# as the unique identifier for tool
# calls when available. That maybe infeasible
# someday, but Gemini requires the thought_signature
# for some llms like gemini-3-pro
if part.thought_signature:
call_id = base64.b64encode(
part.thought_signature,
).decode("utf-8")
else:
call_id = part.function_call.id
tool_calls.append(
ToolUseBlock(
type="tool_use",
id=call_id,
name=part.function_call.name,
input=keyword_args,
raw_input=json.dumps(
keyword_args,
ensure_ascii=False,
),
),
)
# For the structured output case
if response.text and structured_model:
metadata = _json_loads_with_repair(response.text)
if response.usage_metadata:
usage = ChatUsage(
input_tokens=response.usage_metadata.prompt_token_count,
output_tokens=response.usage_metadata.total_token_count
- response.usage_metadata.prompt_token_count,
time=(datetime.now() - start_datetime).total_seconds(),
)
else:
usage = None
return ChatResponse(
content=content_blocks + tool_calls,
usage=usage,
metadata=metadata,
)
def _format_tools_json_schemas(
self,
schemas: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Format the tools JSON schema into required format for Gemini API.
.. note:: Gemini API does not support `$defs` and `$ref` in JSON
schemas. This function resolves all `$ref` references by inlining the
referenced definitions, producing a self-contained schema without
any references.
Args:
schemas (`dict[str, Any]`):
The tools JSON schemas.
Returns:
List[Dict[str, Any]]:
A list containing a dictionary with the
"function_declarations" key, which maps to a list of
function definitions.
Example:
.. code-block:: python
:caption: Example tool schemas of Gemini API
# Input JSON schema
schemas = [
{
'type': 'function',
'function': {
'name': 'execute_shell_command',
'description': 'xxx',
'parameters': {
'type': 'object',
'properties': {
'command': {
'type': 'string',
'description': 'xxx.'
},
'timeout': {
'type': 'integer',
'default': 300
}
},
'required': ['command']
}
}
}
]
# Output format (Gemini API expected):
[
{
'function_declarations': [
{
'name': 'execute_shell_command',
'description': 'xxx.',
'parameters': {
'type': 'object',
'properties': {
'command': {
'type': 'string',
'description': 'xxx.'
},
'timeout': {
'type': 'integer',
'default': 300
}
},
'required': ['command']
}
}
]
}
]
"""
function_declarations = []
for schema in schemas:
if "function" not in schema:
continue
func = schema["function"].copy()
# Flatten the parameters schema to resolve $ref references
if "parameters" in func:
func["parameters"] = _flatten_json_schema(func["parameters"])
function_declarations.append(func)
return [{"function_declarations": function_declarations}]
def _format_tool_choice(
self,
tool_choice: Literal["auto", "none", "required"] | str | None,
) -> dict | None:
"""Format tool_choice parameter for API compatibility.
Args:
tool_choice (`Literal["auto", "none", "required"] | str | None`, \
default `None`):
Controls which (if any) tool is called by the model.
Can be "auto", "none", "required", or specific tool name.
For more details, please refer to
https://ai.google.dev/gemini-api/docs/function-calling?hl=en&example=meeting#function_calling_modes
Returns:
`dict | None`:
The formatted tool choice configuration dict, or None if
tool_choice is None.
"""
if tool_choice is None:
return None
mode_mapping = {
"auto": "AUTO",
"none": "NONE",
"required": "ANY",
}
mode = mode_mapping.get(tool_choice)
if mode:
return {"function_calling_config": {"mode": mode}}
return {
"function_calling_config": {
"mode": "ANY",
"allowed_function_names": [tool_choice],
},
}

View File

@@ -0,0 +1,77 @@
# -*- coding: utf-8 -*-
"""The chat model base class."""
from abc import abstractmethod
from typing import AsyncGenerator, Any
from ._model_response import ChatResponse
_TOOL_CHOICE_MODES = ["auto", "none", "required"]
class ChatModelBase:
"""Base class for chat models."""
model_name: str
"""The model name"""
stream: bool
"""Is the model output streaming or not"""
def __init__(
self,
model_name: str,
stream: bool,
) -> None:
"""Initialize the chat model base class.
Args:
model_name (`str`):
The name of the model
stream (`bool`):
Whether the model output is streaming or not
"""
self.model_name = model_name
self.stream = stream
@abstractmethod
async def __call__(
self,
*args: Any,
**kwargs: Any,
) -> ChatResponse | AsyncGenerator[ChatResponse, None]:
pass
def _validate_tool_choice(
self,
tool_choice: str,
tools: list[dict] | None,
) -> None:
"""
Validate tool_choice parameter.
Args:
tool_choice (`str`):
Tool choice mode or function name
tools (`list[dict] | None`):
Available tools list
Raises:
TypeError: If tool_choice is not string
ValueError: If tool_choice is invalid
"""
if not isinstance(tool_choice, str):
raise TypeError(
f"tool_choice must be str, got {type(tool_choice)}",
)
if tool_choice in _TOOL_CHOICE_MODES:
return
available_functions = [tool["function"]["name"] for tool in tools]
if tool_choice not in available_functions:
all_options = _TOOL_CHOICE_MODES + available_functions
raise ValueError(
f"Invalid tool_choice '{tool_choice}'. "
f"Available options: {', '.join(sorted(all_options))}",
)

View File

@@ -0,0 +1,42 @@
# -*- coding: utf-8 -*-
"""The model response module."""
from dataclasses import dataclass, field
from typing import Literal, Sequence
from ._model_usage import ChatUsage
from .._utils._common import _get_timestamp
from .._utils._mixin import DictMixin
from ..message import (
TextBlock,
ToolUseBlock,
ThinkingBlock,
AudioBlock,
)
from ..types import JSONSerializableObject
@dataclass
class ChatResponse(DictMixin):
"""The response of chat models."""
content: Sequence[TextBlock | ToolUseBlock | ThinkingBlock | AudioBlock]
"""The content of the chat response, which can include text blocks,
tool use blocks, or thinking blocks."""
id: str = field(default_factory=lambda: _get_timestamp(True))
"""The unique identifier formatter """
created_at: str = field(default_factory=_get_timestamp)
"""When the response was created"""
type: Literal["chat"] = field(default_factory=lambda: "chat")
"""The type of the response, which is always 'chat'."""
usage: ChatUsage | None = field(default_factory=lambda: None)
"""The usage information of the chat response, if available."""
metadata: dict[str, JSONSerializableObject] | None = field(
default_factory=lambda: None,
)
"""The metadata of the chat response"""

View File

@@ -0,0 +1,26 @@
# -*- coding: utf-8 -*-
"""The model usage class in agentscope."""
from dataclasses import dataclass, field
from typing import Literal, Any
from .._utils._mixin import DictMixin
@dataclass
class ChatUsage(DictMixin):
"""The usage of a chat model API invocation."""
input_tokens: int
"""The number of input tokens."""
output_tokens: int
"""The number of output tokens."""
time: float
"""The time used in seconds."""
type: Literal["chat"] = field(default_factory=lambda: "chat")
"""The type of the usage, must be `chat`."""
metadata: dict[str, Any] | None = field(default_factory=lambda: None)
"""The metadata of the usage."""

View File

@@ -0,0 +1,355 @@
# -*- coding: utf-8 -*-
"""Model wrapper for Ollama models."""
import json
from datetime import datetime
from typing import (
Any,
TYPE_CHECKING,
List,
AsyncGenerator,
AsyncIterator,
Literal,
Type,
)
from collections import OrderedDict
from pydantic import BaseModel
from . import ChatResponse
from ._model_base import ChatModelBase
from ._model_usage import ChatUsage
from .._logging import logger
from .._utils._common import _json_loads_with_repair
from ..message import ToolUseBlock, TextBlock, ThinkingBlock
from ..tracing import trace_llm
from ..types import JSONSerializableObject
if TYPE_CHECKING:
from ollama._types import ChatResponse as OllamaChatResponse
else:
OllamaChatResponse = "ollama._types.ChatResponse"
class OllamaChatModel(ChatModelBase):
"""The Ollama chat model class in agentscope."""
def __init__(
self,
model_name: str,
stream: bool = False,
options: dict = None,
keep_alive: str = "5m",
enable_thinking: bool | None = None,
host: str | None = None,
client_kwargs: dict[str, JSONSerializableObject] | None = None,
generate_kwargs: dict[str, JSONSerializableObject] | None = None,
**kwargs: Any,
) -> None:
"""Initialize the Ollama chat model.
Args:
model_name (`str`):
The name of the model.
stream (`bool`, default `True`):
Streaming mode or not.
options (`dict`, default `None`):
Additional parameters to pass to the Ollama API. These can
include temperature etc.
keep_alive (`str`, default `"5m"`):
Duration to keep the model loaded in memory. The format is a
number followed by a unit suffix (s for seconds, m for minutes
, h for hours).
enable_thinking (`bool | None`, default `None`)
Whether enable thinking or not, only for models such as qwen3,
deepseek-r1, etc. For more details, please refer to
https://ollama.com/search?c=thinking
host (`str | None`, default `None`):
The host address of the Ollama server. If None, uses the
default address (typically http://localhost:11434).
client_kwargs (`dict[str, JSONSerializableObject] | None`, \
optional):
The extra keyword arguments to initialize the Ollama client.
generate_kwargs (`dict[str, JSONSerializableObject] | None`, \
optional):
The extra keyword arguments used in Ollama API generation.
**kwargs (`Any`):
Additional keyword arguments to pass to the base chat model
class.
"""
try:
import ollama
except ImportError as e:
raise ImportError(
"The package ollama is not found. Please install it by "
'running command `pip install "ollama>=0.1.7"`',
) from e
super().__init__(model_name, stream)
self.client = ollama.AsyncClient(
host=host,
**(client_kwargs or {}),
**kwargs,
)
self.options = options
self.keep_alive = keep_alive
self.think = enable_thinking
self.generate_kwargs = generate_kwargs or {}
@trace_llm
async def __call__(
self,
messages: list[dict[str, Any]],
tools: list[dict] | None = None,
tool_choice: Literal["auto", "none", "required"] | str | None = None,
structured_model: Type[BaseModel] | None = None,
**kwargs: Any,
) -> ChatResponse | AsyncGenerator[ChatResponse, None]:
"""Get the response from Ollama chat completions API by the given
arguments.
Args:
messages (`list[dict]`):
A list of dictionaries, where `role` and `content` fields are
required, and `name` field is optional.
tools (`list[dict]`, default `None`):
The tools JSON schemas that the model can use.
tool_choice (`Literal["auto", "none", "required"] | str \
| None`, default `None`):
Ollama doesn't support `tool_choice` argument yet.
structured_model (`Type[BaseModel] | None`, default `None`):
A Pydantic BaseModel class that defines the expected structure
for the model's output.
**kwargs (`Any`):
The keyword arguments for Ollama chat completions API,
e.g. `think`etc. Please refer to the Ollama API
documentation for more details.
Returns:
`ChatResponse | AsyncGenerator[ChatResponse, None]`:
The response from the Ollama chat completions API.
"""
kwargs = {
"model": self.model_name,
"messages": messages,
"stream": self.stream,
"options": self.options,
"keep_alive": self.keep_alive,
**self.generate_kwargs,
**kwargs,
}
if self.think is not None and "think" not in kwargs:
kwargs["think"] = self.think
if tools:
kwargs["tools"] = self._format_tools_json_schemas(tools)
if tool_choice:
logger.warning("Ollama does not support tool_choice yet, ignored.")
if structured_model:
kwargs["format"] = structured_model.model_json_schema()
start_datetime = datetime.now()
response = await self.client.chat(**kwargs)
if self.stream:
return self._parse_ollama_stream_completion_response(
start_datetime,
response,
structured_model,
)
parsed_response = await self._parse_ollama_completion_response(
start_datetime,
response,
structured_model,
)
return parsed_response
async def _parse_ollama_stream_completion_response(
self,
start_datetime: datetime,
response: AsyncIterator[OllamaChatResponse],
structured_model: Type[BaseModel] | None = None,
) -> AsyncGenerator[ChatResponse, None]:
"""Given an Ollama streaming completion response, extract the
content blocks and usages from it and yield ChatResponse objects.
Args:
start_datetime (`datetime`):
The start datetime of the response generation.
response (`AsyncIterator[OllamaChatResponse]`):
Ollama streaming response async iterator to parse.
structured_model (`Type[BaseModel] | None`, default `None`):
A Pydantic BaseModel class that defines the expected structure
for the model's output.
Returns:
AsyncGenerator[ChatResponse, None]:
An async generator that yields ChatResponse objects containing
the content blocks and usage information for each chunk in the
streaming response.
.. note::
If `structured_model` is not `None`, the expected structured output
will be stored in the metadata of the `ChatResponse`.
"""
accumulated_text = ""
acc_thinking_content = ""
tool_calls = OrderedDict() # Store tool calls
metadata: dict | None = None
async for chunk in response:
# Handle text content
msg = chunk.message
acc_thinking_content += msg.thinking or ""
accumulated_text += msg.content or ""
# Handle tool calls
for idx, tool_call in enumerate(msg.tool_calls or []):
function = tool_call.function
tool_id = f"{idx}_{function.name}"
tool_calls[tool_id] = {
"type": "tool_use",
"id": tool_id,
"name": function.name,
"input": function.arguments,
"raw_input": json.dumps(function.arguments),
}
# Calculate usage statistics
current_time = (datetime.now() - start_datetime).total_seconds()
usage = ChatUsage(
input_tokens=getattr(chunk, "prompt_eval_count", 0) or 0,
output_tokens=getattr(chunk, "eval_count", 0) or 0,
time=current_time,
)
# Create content blocks
contents: list = []
if acc_thinking_content:
contents.append(
ThinkingBlock(
type="thinking",
thinking=acc_thinking_content,
),
)
if accumulated_text:
contents.append(TextBlock(type="text", text=accumulated_text))
if structured_model:
metadata = _json_loads_with_repair(accumulated_text)
# Add tool call blocks
for tool_call in tool_calls.values():
try:
input_data = tool_call["input"]
if isinstance(input_data, str):
input_data = _json_loads_with_repair(input_data)
contents.append(
ToolUseBlock(
type=tool_call["type"],
id=tool_call["id"],
name=tool_call["name"],
input=input_data,
),
)
except Exception as e:
print(f"Error parsing tool call input: {e}")
# Generate response when there's new content or at final chunk
if chunk.done or contents:
res = ChatResponse(
content=contents,
usage=usage,
metadata=metadata,
)
yield res
async def _parse_ollama_completion_response(
self,
start_datetime: datetime,
response: OllamaChatResponse,
structured_model: Type[BaseModel] | None = None,
) -> ChatResponse:
"""Given an Ollama chat completion response object, extract the content
blocks and usages from it.
Args:
start_datetime (`datetime`):
The start datetime of the response generation.
response (`OllamaChatResponse`):
Ollama OllamaChatResponse object to parse.
structured_model (`Type[BaseModel] | None`, default `None`):
A Pydantic BaseModel class that defines the expected structure
for the model's output.
Returns:
`ChatResponse`:
A ChatResponse object containing the content blocks and usage.
.. note::
If `structured_model` is not `None`, the expected structured output
will be stored in the metadata of the `ChatResponse`.
"""
content_blocks: List[TextBlock | ToolUseBlock | ThinkingBlock] = []
metadata: dict | None = None
if response.message.thinking:
content_blocks.append(
ThinkingBlock(
type="thinking",
thinking=response.message.thinking,
),
)
if response.message.content:
content_blocks.append(
TextBlock(
type="text",
text=response.message.content,
),
)
if structured_model:
metadata = _json_loads_with_repair(
response.message.content,
)
for idx, tool_call in enumerate(response.message.tool_calls or []):
content_blocks.append(
ToolUseBlock(
type="tool_use",
id=f"{idx}_{tool_call.function.name}",
name=tool_call.function.name,
input=tool_call.function.arguments,
raw_input=json.dumps(tool_call.function.arguments),
),
)
usage = None
if "prompt_eval_count" in response and "eval_count" in response:
usage = ChatUsage(
input_tokens=response.get("prompt_eval_count", 0),
output_tokens=response.get("eval_count", 0),
time=(datetime.now() - start_datetime).total_seconds(),
)
parsed_response = ChatResponse(
content=content_blocks,
usage=usage,
metadata=metadata,
)
return parsed_response
def _format_tools_json_schemas(
self,
schemas: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Format the tools JSON schemas to the Ollama format."""
return schemas

View File

@@ -0,0 +1,650 @@
# -*- coding: utf-8 -*-
# pylint: disable=too-many-branches
"""OpenAI Chat model class."""
import copy
import json
import warnings
from datetime import datetime
from typing import (
Any,
TYPE_CHECKING,
List,
AsyncGenerator,
Literal,
Type,
)
from collections import OrderedDict
from pydantic import BaseModel
from . import ChatResponse
from ._model_base import ChatModelBase
from ._model_usage import ChatUsage
from .._logging import logger
from .._utils._common import _json_loads_with_repair
from ..message import (
ToolUseBlock,
TextBlock,
ThinkingBlock,
AudioBlock,
Base64Source,
)
from ..tracing import trace_llm
from ..types import JSONSerializableObject
if TYPE_CHECKING:
from openai.types.chat import ChatCompletion
from openai import AsyncStream
else:
ChatCompletion = "openai.types.chat.ChatCompletion"
AsyncStream = "openai.types.chat.AsyncStream"
def _format_audio_data_for_qwen_omni(messages: list[dict]) -> None:
"""Qwen-omni uses OpenAI-compatible API but requires different audio
data format than OpenAI with "data:;base64," prefix.
Refer to `Qwen-omni documentation
<https://bailian.console.aliyun.com/?tab=doc#/doc/?type=model&url=2867839>`_
for more details.
Args:
messages (`list[dict]`):
The list of message dictionaries from OpenAI formatter.
"""
for msg in messages:
if isinstance(msg.get("content"), list):
for block in msg["content"]:
if (
isinstance(block, dict)
and "input_audio" in block
and isinstance(block["input_audio"].get("data"), str)
):
if not block["input_audio"]["data"].startswith("http"):
block["input_audio"]["data"] = (
"data:;base64," + block["input_audio"]["data"]
)
class OpenAIChatModel(ChatModelBase):
"""The OpenAI chat model class."""
def __init__(
self,
model_name: str,
api_key: str | None = None,
stream: bool = True,
reasoning_effort: Literal["low", "medium", "high"] | None = None,
organization: str = None,
stream_tool_parsing: bool = True,
client_type: Literal["openai", "azure"] = "openai",
client_kwargs: dict[str, JSONSerializableObject] | None = None,
generate_kwargs: dict[str, JSONSerializableObject] | None = None,
**kwargs: Any,
) -> None:
"""Initialize the openai client.
Args:
model_name (`str`, default `None`):
The name of the model to use in OpenAI API.
api_key (`str`, default `None`):
The API key for OpenAI API. If not specified, it will
be read from the environment variable `OPENAI_API_KEY`.
stream (`bool`, default `True`):
Whether to use streaming output or not.
reasoning_effort (`Literal["low", "medium", "high"] | None`, \
optional):
Reasoning effort, supported for o3, o4, etc. Please refer to
`OpenAI documentation
<https://platform.openai.com/docs/guides/reasoning?api-mode=chat>`_
for more details.
organization (`str`, default `None`):
The organization ID for OpenAI API. If not specified, it will
be read from the environment variable `OPENAI_ORGANIZATION`.
stream_tool_parsing (`bool`, default to `True`):
Whether to parse incomplete tool use JSON during streaming
with auto-repair. If True, partial JSON (e.g., `'{"a": "x'`)
is repaired to valid dicts ({"a": "x"}) in real-time for
immediate tool function input. Otherwise, the input field
remains {} until the final chunk arrives.
client_type (`Literal["openai", "azure"]`, default `openai`):
Selects which OpenAI-compatible client to initialize.
client_kwargs (`dict[str, JSONSerializableObject] | None`, \
optional):
The extra keyword arguments to initialize the OpenAI client.
generate_kwargs (`dict[str, JSONSerializableObject] | None`, \
optional):
The extra keyword arguments used in OpenAI API generation,
e.g. `temperature`, `seed`.
**kwargs (`Any`):
Additional keyword arguments.
"""
# Handle deprecated client_args parameter from kwargs
client_args = kwargs.pop("client_args", None)
if client_args is not None and client_kwargs is not None:
raise ValueError(
"Cannot specify both 'client_args' and 'client_kwargs'. "
"Please use only 'client_kwargs' (client_args is deprecated).",
)
if client_args is not None:
logger.warning(
"The parameter 'client_args' is deprecated and will be "
"removed in a future version. Please use 'client_kwargs' "
"instead. Automatically converting 'client_args' to "
"'client_kwargs'.",
)
client_kwargs = client_args
if kwargs:
logger.warning(
"Unknown keyword arguments: %s. These will be ignored.",
list(kwargs.keys()),
)
super().__init__(model_name, stream)
import openai
if client_type not in ("openai", "azure"):
raise ValueError(
"Invalid client_type. Supported values: 'openai', 'azure'.",
)
if client_type == "azure":
self.client = openai.AsyncAzureOpenAI(
api_key=api_key,
organization=organization,
**(client_kwargs or {}),
)
else:
self.client = openai.AsyncClient(
api_key=api_key,
organization=organization,
**(client_kwargs or {}),
)
self.reasoning_effort = reasoning_effort
self.stream_tool_parsing = stream_tool_parsing
self.generate_kwargs = generate_kwargs or {}
@trace_llm
async def __call__(
self,
messages: list[dict],
tools: list[dict] | None = None,
tool_choice: Literal["auto", "none", "required"] | str | None = None,
structured_model: Type[BaseModel] | None = None,
**kwargs: Any,
) -> ChatResponse | AsyncGenerator[ChatResponse, None]:
"""Get the response from OpenAI chat completions API by the given
arguments.
Args:
messages (`list[dict]`):
A list of dictionaries, where `role` and `content` fields are
required, and `name` field is optional.
tools (`list[dict]`, default `None`):
The tools JSON schemas that the model can use.
tool_choice (`Literal["auto", "none", "required"] | str \
| None`, default `None`):
Controls which (if any) tool is called by the model.
Can be "auto", "none", "required", or specific tool
name. For more details, please refer to
https://platform.openai.com/docs/api-reference/responses/create#responses_create-tool_choice
structured_model (`Type[BaseModel] | None`, default `None`):
A Pydantic BaseModel class that defines the expected structure
for the model's output. When provided, the model will be forced
to return data that conforms to this schema by automatically
converting the BaseModel to a tool function and setting
`tool_choice` to enforce its usage. This enables structured
output generation.
.. note:: When `structured_model` is specified,
both `tools` and `tool_choice` parameters are ignored,
and the model will only perform structured output
generation without calling any other tools.
For more details, please refer to the `official document
<https://platform.openai.com/docs/guides/structured-outputs>`_
**kwargs (`Any`):
The keyword arguments for OpenAI chat completions API,
e.g. `temperature`, `max_tokens`, `top_p`, etc. Please
refer to the OpenAI API documentation for more details.
Returns:
`ChatResponse | AsyncGenerator[ChatResponse, None]`:
The response from the OpenAI chat completions API.
"""
# checking messages
if not isinstance(messages, list):
raise ValueError(
"OpenAI `messages` field expected type `list`, "
f"got `{type(messages)}` instead.",
)
if not all("role" in msg and "content" in msg for msg in messages):
raise ValueError(
"Each message in the 'messages' list must contain a 'role' "
"and 'content' key for OpenAI API.",
)
# Qwen-omni requires different base64 audio format from openai
if "omni" in self.model_name.lower():
_format_audio_data_for_qwen_omni(messages)
kwargs = {
"model": self.model_name,
"messages": messages,
"stream": self.stream,
**self.generate_kwargs,
**kwargs,
}
if self.reasoning_effort and "reasoning_effort" not in kwargs:
kwargs["reasoning_effort"] = self.reasoning_effort
if tools:
kwargs["tools"] = self._format_tools_json_schemas(tools)
if tool_choice:
# Handle deprecated "any" option with warning
if tool_choice == "any":
warnings.warn(
'"any" is deprecated and will be removed in a future '
"version.",
DeprecationWarning,
)
tool_choice = "required"
self._validate_tool_choice(tool_choice, tools)
kwargs["tool_choice"] = self._format_tool_choice(tool_choice)
if self.stream:
kwargs["stream_options"] = {"include_usage": True}
start_datetime = datetime.now()
if structured_model:
if tools or tool_choice:
logger.warning(
"structured_model is provided. Both 'tools' and "
"'tool_choice' parameters will be overridden and "
"ignored. The model will only perform structured output "
"generation without calling any other tools.",
)
kwargs.pop("stream", None)
kwargs.pop("tools", None)
kwargs.pop("tool_choice", None)
kwargs["response_format"] = structured_model
if not self.stream:
response = await self.client.chat.completions.parse(**kwargs)
else:
response = self.client.chat.completions.stream(**kwargs)
return self._parse_openai_stream_response(
start_datetime,
response,
structured_model,
)
else:
response = await self.client.chat.completions.create(**kwargs)
if self.stream:
return self._parse_openai_stream_response(
start_datetime,
response,
structured_model,
)
# Non-streaming response
parsed_response = self._parse_openai_completion_response(
start_datetime,
response,
structured_model,
)
return parsed_response
# pylint: disable=too-many-statements
async def _parse_openai_stream_response(
self,
start_datetime: datetime,
response: AsyncStream,
structured_model: Type[BaseModel] | None = None,
) -> AsyncGenerator[ChatResponse, None]:
"""Given an OpenAI streaming completion response, extract the content
blocks and usages from it and yield ChatResponse objects.
Args:
start_datetime (`datetime`):
The start datetime of the response generation.
response (`AsyncStream`):
OpenAI AsyncStream object to parse.
structured_model (`Type[BaseModel] | None`, default `None`):
A Pydantic BaseModel class that defines the expected structure
for the model's output.
Returns:
`AsyncGenerator[ChatResponse, None]`:
An async generator that yields ChatResponse objects containing
the content blocks and usage information for each chunk in
the streaming response.
.. note::
If `structured_model` is not `None`, the expected structured output
will be stored in the metadata of the `ChatResponse`.
"""
usage, res = None, None
text = ""
thinking = ""
audio = ""
tool_calls = OrderedDict()
last_input_objs = {} # Store last input_obj for each tool_call
metadata: dict | None = None
contents: List[
TextBlock | ToolUseBlock | ThinkingBlock | AudioBlock
] = []
last_contents = None
async with response as stream:
async for item in stream:
if structured_model:
if item.type != "chunk":
continue
chunk = item.chunk
else:
chunk = item
if chunk.usage:
usage = ChatUsage(
input_tokens=chunk.usage.prompt_tokens,
output_tokens=chunk.usage.completion_tokens,
time=(datetime.now() - start_datetime).total_seconds(),
metadata=chunk.usage,
)
if not chunk.choices:
if usage and contents:
res = ChatResponse(
content=contents,
usage=usage,
metadata=metadata,
)
yield res
continue
choice = chunk.choices[0]
thinking += (
getattr(choice.delta, "reasoning_content", None) or ""
)
text += getattr(choice.delta, "content", None) or ""
if (
hasattr(choice.delta, "audio")
and "data" in choice.delta.audio
):
audio += choice.delta.audio["data"]
if (
hasattr(choice.delta, "audio")
and "transcript" in choice.delta.audio
):
text += choice.delta.audio["transcript"]
for tool_call in (
getattr(choice.delta, "tool_calls", None) or []
):
if tool_call.index in tool_calls:
if tool_call.function.arguments is not None:
tool_calls[tool_call.index][
"input"
] += tool_call.function.arguments
else:
tool_calls[tool_call.index] = {
"type": "tool_use",
"id": tool_call.id,
"name": tool_call.function.name,
"input": tool_call.function.arguments or "",
}
contents = []
if thinking:
contents.append(
ThinkingBlock(
type="thinking",
thinking=thinking,
),
)
if audio:
media_type = self.generate_kwargs.get("audio", {}).get(
"format",
"wav",
)
contents.append(
AudioBlock(
type="audio",
source=Base64Source(
data=audio,
media_type=f"audio/{media_type}",
type="base64",
),
),
)
if text:
contents.append(
TextBlock(
type="text",
text=text,
),
)
if structured_model:
metadata = _json_loads_with_repair(text)
for tool_call in tool_calls.values():
input_str = tool_call["input"]
tool_id = tool_call["id"]
# If parsing the tool input in streaming mode
if self.stream_tool_parsing:
repaired_input = _json_loads_with_repair(
input_str or "{}",
)
# If the new repaired input is shorter than one in the
# last chunk, use the last one to avoid regression
last_input = last_input_objs.get(tool_id, {})
if len(json.dumps(last_input)) > len(
json.dumps(repaired_input),
):
repaired_input = last_input
last_input_objs[tool_id] = repaired_input
else:
# Otherwise, keep input as empty dict until the final
# chunk
repaired_input = {}
contents.append(
ToolUseBlock(
type=tool_call["type"],
id=tool_id,
name=tool_call["name"],
input=repaired_input,
raw_input=input_str,
),
)
if contents:
res = ChatResponse(
content=contents,
usage=usage,
metadata=metadata,
)
yield res
last_contents = copy.deepcopy(contents)
# If stream_tool_parsing is False, yield last contents
if not self.stream_tool_parsing and tool_calls and last_contents:
metadata = None
# Update tool use blocks in last_contents inplace
for block in last_contents:
if block.get("type") == "tool_use":
block["input"] = input_obj = _json_loads_with_repair(
str(block.get("raw_input") or "{}"),
)
if structured_model:
metadata = input_obj
yield ChatResponse(
content=last_contents,
usage=usage,
metadata=metadata,
)
def _parse_openai_completion_response(
self,
start_datetime: datetime,
response: ChatCompletion,
structured_model: Type[BaseModel] | None = None,
) -> ChatResponse:
"""Given an OpenAI chat completion response object, extract the content
blocks and usages from it.
Args:
start_datetime (`datetime`):
The start datetime of the response generation.
response (`ChatCompletion`):
OpenAI ChatCompletion object to parse.
structured_model (`Type[BaseModel] | None`, default `None`):
A Pydantic BaseModel class that defines the expected structure
for the model's output.
Returns:
ChatResponse (`ChatResponse`):
A ChatResponse object containing the content blocks and usage.
.. note::
If `structured_model` is not `None`, the expected structured output
will be stored in the metadata of the `ChatResponse`.
"""
content_blocks: List[
TextBlock | ToolUseBlock | ThinkingBlock | AudioBlock
] = []
metadata: dict | None = None
if response.choices:
choice = response.choices[0]
if (
hasattr(choice.message, "reasoning_content")
and choice.message.reasoning_content is not None
):
content_blocks.append(
ThinkingBlock(
type="thinking",
thinking=response.choices[0].message.reasoning_content,
),
)
if choice.message.content:
content_blocks.append(
TextBlock(
type="text",
text=response.choices[0].message.content,
),
)
if choice.message.audio:
media_type = self.generate_kwargs.get("audio", {}).get(
"format",
"mp3",
)
content_blocks.append(
AudioBlock(
type="audio",
source=Base64Source(
data=choice.message.audio.data,
media_type=f"audio/{media_type}",
type="base64",
),
),
)
if choice.message.audio.transcript:
content_blocks.append(
TextBlock(
type="text",
text=choice.message.audio.transcript,
),
)
for tool_call in choice.message.tool_calls or []:
content_blocks.append(
ToolUseBlock(
type="tool_use",
id=tool_call.id,
name=tool_call.function.name,
input=_json_loads_with_repair(
tool_call.function.arguments,
),
),
)
if structured_model:
metadata = choice.message.parsed.model_dump()
usage = None
if response.usage:
usage = ChatUsage(
input_tokens=response.usage.prompt_tokens,
output_tokens=response.usage.completion_tokens,
time=(datetime.now() - start_datetime).total_seconds(),
metadata=response.usage,
)
parsed_response = ChatResponse(
content=content_blocks,
usage=usage,
metadata=metadata,
)
return parsed_response
def _format_tools_json_schemas(
self,
schemas: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Format the tools JSON schemas to the OpenAI format."""
return schemas
def _format_tool_choice(
self,
tool_choice: Literal["auto", "none", "required"] | str | None,
) -> str | dict | None:
"""Format tool_choice parameter for API compatibility.
Args:
tool_choice (`Literal["auto", "none", "required"] | str \
| None`, default `None`):
Controls which (if any) tool is called by the model.
Can be "auto", "none", "required", or specific tool name.
For more details, please refer to
https://platform.openai.com/docs/api-reference/responses/create#responses_create-tool_choice
Returns:
`dict | None`:
The formatted tool choice configuration dict, or None if
tool_choice is None.
"""
if tool_choice is None:
return None
mode_mapping = {
"auto": "auto",
"none": "none",
"required": "required",
}
if tool_choice in mode_mapping:
return mode_mapping[tool_choice]
return {"type": "function", "function": {"name": tool_choice}}

View File

@@ -0,0 +1,67 @@
# -*- coding: utf-8 -*-
"""A model class for RL Training with Trinity-RFT."""
from typing import (
Optional,
TYPE_CHECKING,
)
from typing_extensions import deprecated
from ._openai_model import OpenAIChatModel
from ..types import JSONSerializableObject
if TYPE_CHECKING:
from openai import AsyncOpenAI
else:
AsyncOpenAI = "openai.AsyncOpenAI"
@deprecated(
"TrinityChatModel is deprecated. Please use OpenAIChatModel directly.",
)
class TrinityChatModel(OpenAIChatModel):
"""A model class for RL Training with Trinity-RFT."""
def __init__(
self,
openai_async_client: AsyncOpenAI,
generate_kwargs: dict[str, JSONSerializableObject] | None = None,
enable_thinking: Optional[bool] = None,
) -> None:
"""Initialize the Trinity model class.
Args:
openai_async_client (`AsyncOpenAI`):
The OpenAI async client instance provided by Trinity-RFT.
generate_kwargs (`dict[str, JSONSerializableObject] | None`, \
optional):
Additional keyword arguments to pass to the model's generate
method. Defaults to None.
enable_thinking (`bool`, optional):
Whether to enable the model's thinking capability. Only
applicable for Qwen3 series models. Defaults to None.
"""
model_name = getattr(openai_async_client, "model_path", None)
if model_name is None:
raise ValueError(
"The provided openai_async_client does not have a "
"`model_path` attribute. Please ensure you are using "
"the instance provided by Trinity-RFT.",
)
super().__init__(
model_name=model_name,
api_key="EMPTY",
generate_kwargs=generate_kwargs,
stream=False, # RL training does not support streaming
)
if enable_thinking is not None:
if "chat_template_kwargs" not in self.generate_kwargs:
self.generate_kwargs["chat_template_kwargs"] = {}
assert isinstance(
self.generate_kwargs["chat_template_kwargs"],
dict,
), "chat_template_kwargs must be a dictionary."
self.generate_kwargs["chat_template_kwargs"][
"enable_thinking"
] = enable_thinking
# change the client instance to the provided one
self.client = openai_async_client