Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled
647 lines
24 KiB
Python
647 lines
24 KiB
Python
# -*- coding: utf-8 -*-
|
|
# mypy: disable-error-code="dict-item"
|
|
"""The Google Gemini model in agentscope."""
|
|
import base64
|
|
import copy
|
|
import warnings
|
|
from datetime import datetime
|
|
import json
|
|
from typing import (
|
|
AsyncGenerator,
|
|
Any,
|
|
TYPE_CHECKING,
|
|
AsyncIterator,
|
|
Literal,
|
|
Type,
|
|
List,
|
|
)
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from .._logging import logger
|
|
from .._utils._common import _json_loads_with_repair
|
|
from ..message import ToolUseBlock, TextBlock, ThinkingBlock
|
|
from ._model_usage import ChatUsage
|
|
from ._model_base import ChatModelBase
|
|
from ._model_response import ChatResponse
|
|
from ..tracing import trace_llm
|
|
from ..types import JSONSerializableObject
|
|
|
|
if TYPE_CHECKING:
|
|
from google.genai.types import GenerateContentResponse
|
|
else:
|
|
GenerateContentResponse = "google.genai.types.GenerateContentResponse"
|
|
|
|
|
|
def _flatten_json_schema(schema: dict) -> dict:
|
|
"""Flatten a JSON schema by resolving all $ref references.
|
|
|
|
.. note::
|
|
Gemini API does not support `$defs` and `$ref` in JSON schemas.
|
|
This function resolves all `$ref` references by inlining the
|
|
referenced definitions, producing a self-contained schema without
|
|
any references.
|
|
|
|
Args:
|
|
schema (`dict`):
|
|
The JSON schema that may contain `$defs` and `$ref` references.
|
|
|
|
Returns:
|
|
`dict`:
|
|
A flattened JSON schema with all references resolved inline.
|
|
"""
|
|
# Deep copy to avoid modifying the original schema
|
|
schema = copy.deepcopy(schema)
|
|
|
|
# Extract $defs if present
|
|
defs = schema.pop("$defs", {})
|
|
|
|
def _resolve_ref(obj: Any, visited: set | None = None) -> Any:
|
|
"""Recursively resolve $ref references in the schema."""
|
|
if visited is None:
|
|
visited = set()
|
|
|
|
if not isinstance(obj, dict):
|
|
if isinstance(obj, list):
|
|
return [_resolve_ref(item, visited.copy()) for item in obj]
|
|
return obj
|
|
|
|
# Handle $ref
|
|
if "$ref" in obj:
|
|
ref_path = obj["$ref"]
|
|
# Extract definition name from "#/$defs/DefinitionName"
|
|
if ref_path.startswith("#/$defs/"):
|
|
def_name = ref_path[len("#/$defs/") :]
|
|
|
|
# Prevent infinite recursion for circular references
|
|
if def_name in visited:
|
|
logger.warning(
|
|
"Circular reference detected for '%s' in tool schema",
|
|
def_name,
|
|
)
|
|
return {
|
|
"type": "object",
|
|
"description": f"(circular: {def_name})",
|
|
}
|
|
|
|
visited.add(def_name)
|
|
|
|
if def_name in defs:
|
|
# Recursively resolve any nested refs in the definition
|
|
resolved = _resolve_ref(
|
|
defs[def_name],
|
|
visited.copy(),
|
|
)
|
|
# Merge any additional properties from the original object
|
|
# (excluding $ref itself)
|
|
for key, value in obj.items():
|
|
if key != "$ref":
|
|
resolved[key] = _resolve_ref(value, visited.copy())
|
|
return resolved
|
|
|
|
# If we can't resolve the ref, return as-is (shouldn't happen)
|
|
return obj
|
|
|
|
# Recursively process all nested objects
|
|
result = {}
|
|
for key, value in obj.items():
|
|
result[key] = _resolve_ref(value, visited.copy())
|
|
|
|
return result
|
|
|
|
return _resolve_ref(schema)
|
|
|
|
|
|
class GeminiChatModel(ChatModelBase):
    """The Google Gemini chat model class in agentscope."""

    def __init__(
        self,
        model_name: str,
        api_key: str,
        stream: bool = True,
        thinking_config: dict | None = None,
        client_kwargs: dict[str, JSONSerializableObject] | None = None,
        generate_kwargs: dict[str, JSONSerializableObject] | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the Gemini chat model.

        Args:
            model_name (`str`):
                The name of the Gemini model to use, e.g. "gemini-2.5-flash".
            api_key (`str`):
                The API key for Google Gemini.
            stream (`bool`, default `True`):
                Whether to use streaming output or not.
            thinking_config (`dict | None`, optional):
                Thinking config, supported models are 2.5 Pro, 2.5 Flash, etc.
                Refer to https://ai.google.dev/gemini-api/docs/thinking for
                more details.

                .. code-block:: python
                    :caption: Example of thinking_config

                    {
                        "include_thoughts": True, # enable thoughts or not
                        "thinking_budget": 1024   # Max tokens for reasoning
                    }

            client_kwargs (`dict[str, JSONSerializableObject] | None`, \
             optional):
                The extra keyword arguments to initialize the Gemini client.
                The deprecated alias `client_args` is still accepted via
                `**kwargs` (with a warning), but cannot be combined with
                `client_kwargs`.
            generate_kwargs (`dict[str, JSONSerializableObject] | None`, \
             optional):
                The extra keyword arguments used in Gemini API generation,
                e.g. `temperature`, `seed`.
            **kwargs (`Any`):
                Additional keyword arguments. Apart from the deprecated
                `client_args`, they are ignored with a warning.

        Raises:
            `ValueError`:
                If both `client_args` and `client_kwargs` are specified.
            `ImportError`:
                If the `google-genai` package is not installed.
        """
        # Handle deprecated client_args parameter from kwargs
        client_args = kwargs.pop("client_args", None)
        if client_args is not None and client_kwargs is not None:
            raise ValueError(
                "Cannot specify both 'client_args' and 'client_kwargs'. "
                "Please use only 'client_kwargs' (client_args is deprecated).",
            )

        if client_args is not None:
            logger.warning(
                "The parameter 'client_args' is deprecated and will be "
                "removed in a future version. Please use 'client_kwargs' "
                "instead. Automatically converting 'client_args' to "
                "'client_kwargs'.",
            )
            client_kwargs = client_args

        if kwargs:
            logger.warning(
                "Unknown keyword arguments: %s. These will be ignored.",
                list(kwargs.keys()),
            )

        # Imported lazily so the SDK is only required when this model is used
        try:
            from google import genai
        except ImportError as e:
            raise ImportError(
                "Please install gemini Python sdk with "
                "`pip install -q -U google-genai`",
            ) from e

        super().__init__(model_name, stream)

        self.client = genai.Client(
            api_key=api_key,
            **(client_kwargs or {}),
        )
        self.thinking_config = thinking_config
        self.generate_kwargs = generate_kwargs or {}

    @trace_llm
    async def __call__(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        tool_choice: Literal["auto", "none", "required"] | str | None = None,
        structured_model: Type[BaseModel] | None = None,
        **config_kwargs: Any,
    ) -> ChatResponse | AsyncGenerator[ChatResponse, None]:
        """Call the Gemini model with the provided arguments.

        Args:
            messages (`list[dict[str, Any]]`):
                A list of dictionaries, where `role` and `content` fields are
                required.
            tools (`list[dict] | None`, default `None`):
                The tools JSON schemas that the model can use.
            tool_choice (`Literal["auto", "none", "required"] | str \
             | None`, default `None`):
                Controls which (if any) tool is called by the model.
                Can be "auto", "none", "required", or specific tool name.
                The deprecated value "any" is treated as "required" and
                emits a `DeprecationWarning`.
                For more details, please refer to
                https://ai.google.dev/gemini-api/docs/function-calling?hl=en&example=meeting#function_calling_modes
            structured_model (`Type[BaseModel] | None`, default `None`):
                A Pydantic BaseModel class that defines the expected structure
                for the model's output.

                .. note:: When `structured_model` is specified,
                    both `tools` and `tool_choice` parameters are ignored,
                    and the model will only perform structured output
                    generation without calling any other tools.

                For more details, please refer to
                https://ai.google.dev/gemini-api/docs/structured-output

            **config_kwargs (`Any`):
                The keyword arguments for Gemini chat completions API. They
                override entries of `generate_kwargs` with the same name.

        Returns:
            `ChatResponse | AsyncGenerator[ChatResponse, None]`:
                A single `ChatResponse` in non-streaming mode, otherwise an
                async generator yielding accumulated `ChatResponse` objects.
        """
        # Later entries win: per-call kwargs override instance defaults
        config: dict = {
            "thinking_config": self.thinking_config,
            **self.generate_kwargs,
            **config_kwargs,
        }

        if tools:
            config["tools"] = self._format_tools_json_schemas(tools)

        if tool_choice:
            # Handle deprecated "any" option with warning
            if tool_choice == "any":
                warnings.warn(
                    '"any" is deprecated and will be removed in a future '
                    "version.",
                    DeprecationWarning,
                )
                tool_choice = "required"
            self._validate_tool_choice(tool_choice, tools)
            config["tool_config"] = self._format_tool_choice(tool_choice)

        if structured_model:
            if tools or tool_choice:
                logger.warning(
                    "structured_model is provided. Both 'tools' and "
                    "'tool_choice' parameters will be overridden and "
                    "ignored. The model will only perform structured output "
                    "generation without calling any other tools.",
                )
            config.pop("tools", None)
            config.pop("tool_config", None)
            config["response_mime_type"] = "application/json"
            config["response_schema"] = structured_model

        # Prepare the arguments for the Gemini API call
        kwargs: dict[str, JSONSerializableObject] = {
            "model": self.model_name,
            "contents": messages,
            "config": config,
        }

        start_datetime = datetime.now()
        if self.stream:
            response = await self.client.aio.models.generate_content_stream(
                **kwargs,
            )

            return self._parse_gemini_stream_generation_response(
                start_datetime,
                response,
                structured_model,
            )

        # non-streaming
        response = await self.client.aio.models.generate_content(
            **kwargs,
        )

        return self._parse_gemini_generation_response(
            start_datetime,
            response,
            structured_model,
        )

    async def _parse_gemini_stream_generation_response(
        self,
        start_datetime: datetime,
        response: AsyncIterator[GenerateContentResponse],
        structured_model: Type[BaseModel] | None = None,
    ) -> AsyncGenerator[ChatResponse, None]:
        """Given a Gemini streaming generation response, extract the
        content blocks and usages from it and yield ChatResponse objects.

        Args:
            start_datetime (`datetime`):
                The start datetime of the response generation.
            response (`AsyncIterator[GenerateContentResponse]`):
                Gemini GenerateContentResponse async iterator to parse.
            structured_model (`Type[BaseModel] | None`, default `None`):
                A Pydantic BaseModel class that defines the expected structure
                for the model's output.

        Returns:
            `AsyncGenerator[ChatResponse, None]`:
                An async generator that yields ChatResponse objects. Each one
                carries the thinking, text and tool calls accumulated from
                all chunks received so far, plus the usage reported by the
                current chunk (if any).

        .. note::
            If `structured_model` is not `None`, the expected structured output
            will be stored in the metadata of the `ChatResponse`.
        """
        # Content is accumulated across chunks so that every yielded
        # response holds the full output received so far
        text = ""
        thinking = ""
        tool_calls: list[ToolUseBlock] = []
        metadata: dict | None = None

        async for chunk in response:
            if (
                chunk.candidates
                and chunk.candidates[0].content
                and chunk.candidates[0].content.parts
            ):
                for part in chunk.candidates[0].content.parts:
                    if part.text:
                        if part.thought:
                            thinking += part.text
                        else:
                            text += part.text

                    if part.function_call:
                        tool_calls.append(self._parse_function_call(part))

            # Re-parse the accumulated (possibly incomplete) JSON text
            if text and structured_model:
                metadata = _json_loads_with_repair(text)

            usage = self._parse_usage(chunk.usage_metadata, start_datetime)

            # The content blocks for the current chunk
            content_blocks: list = []
            if thinking:
                content_blocks.append(
                    ThinkingBlock(
                        type="thinking",
                        thinking=thinking,
                    ),
                )
            if text:
                content_blocks.append(
                    TextBlock(
                        type="text",
                        text=text,
                    ),
                )

            yield ChatResponse(
                content=content_blocks + tool_calls,
                usage=usage,
                metadata=metadata,
            )

    def _parse_gemini_generation_response(
        self,
        start_datetime: datetime,
        response: GenerateContentResponse,
        structured_model: Type[BaseModel] | None = None,
    ) -> ChatResponse:
        """Given a Gemini chat completion response object, extract the content
        blocks and usages from it.

        Args:
            start_datetime (`datetime`):
                The start datetime of the response generation.
            response (`GenerateContentResponse`):
                The Gemini generation response object to parse.
            structured_model (`Type[BaseModel] | None`, default `None`):
                A Pydantic BaseModel class that defines the expected structure
                for the model's output.

        Returns:
            ChatResponse (`ChatResponse`):
                A ChatResponse object containing the content blocks (thinking
                and text blocks in part order, followed by tool calls) and
                usage.

        .. note::
            If `structured_model` is not `None`, the expected structured output
            will be stored in the metadata of the `ChatResponse`.
        """
        content_blocks: List[TextBlock | ToolUseBlock | ThinkingBlock] = []
        tool_calls: list[ToolUseBlock] = []
        metadata: dict | None = None
        # Concatenated non-thought text; parsed for structured output, the
        # same way the streaming path does
        text = ""

        if (
            response.candidates
            and response.candidates[0].content
            and response.candidates[0].content.parts
        ):
            for part in response.candidates[0].content.parts:
                if part.text:
                    if part.thought:
                        content_blocks.append(
                            ThinkingBlock(
                                type="thinking",
                                thinking=part.text,
                            ),
                        )
                    else:
                        text += part.text
                        content_blocks.append(
                            TextBlock(
                                type="text",
                                text=part.text,
                            ),
                        )

                if part.function_call:
                    tool_calls.append(self._parse_function_call(part))

        # For the structured output case
        if text and structured_model:
            metadata = _json_loads_with_repair(text)

        usage = self._parse_usage(response.usage_metadata, start_datetime)

        return ChatResponse(
            content=content_blocks + tool_calls,
            usage=usage,
            metadata=metadata,
        )

    @staticmethod
    def _parse_function_call(part: Any) -> ToolUseBlock:
        """Convert a Gemini response part holding a function call into a
        `ToolUseBlock`.

        Args:
            part (`Any`):
                A Gemini content part whose `function_call` is set.

        Returns:
            `ToolUseBlock`:
                The tool use block; `raw_input` is the JSON dump of the
                call arguments (an empty dict when Gemini sends none).
        """
        keyword_args = part.function_call.args or {}

        # .. note:: Gemini API always returns None for function_call.id, so
        # we use thought_signature as the unique identifier for tool calls
        # when available. That may be infeasible someday, but Gemini
        # requires the thought_signature for some llms like gemini-3-pro
        if part.thought_signature:
            call_id = base64.b64encode(
                part.thought_signature,
            ).decode("utf-8")
        else:
            call_id = part.function_call.id

        return ToolUseBlock(
            type="tool_use",
            id=call_id,
            name=part.function_call.name,
            input=keyword_args,
            raw_input=json.dumps(
                keyword_args,
                ensure_ascii=False,
            ),
        )

    @staticmethod
    def _parse_usage(
        usage_metadata: Any,
        start_datetime: datetime,
    ) -> ChatUsage | None:
        """Convert Gemini usage metadata into a `ChatUsage`.

        The Gemini SDK types every token count as optional, so a missing
        count is treated as 0 instead of crashing the subtraction.

        Args:
            usage_metadata (`Any`):
                The `usage_metadata` of a Gemini response or chunk, may be
                `None`.
            start_datetime (`datetime`):
                The start datetime of the response generation.

        Returns:
            `ChatUsage | None`:
                The usage, or `None` if no usage metadata is present. Output
                tokens are the total minus the prompt tokens, never negative.
        """
        if not usage_metadata:
            return None

        prompt_tokens = usage_metadata.prompt_token_count or 0
        total_tokens = usage_metadata.total_token_count or 0
        return ChatUsage(
            input_tokens=prompt_tokens,
            output_tokens=max(total_tokens - prompt_tokens, 0),
            time=(datetime.now() - start_datetime).total_seconds(),
        )

    def _format_tools_json_schemas(
        self,
        schemas: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Format the tools JSON schema into required format for Gemini API.

        .. note:: Gemini API does not support `$defs` and `$ref` in JSON
         schemas. This function resolves all `$ref` references by inlining the
         referenced definitions, producing a self-contained schema without
         any references.

        Args:
            schemas (`list[dict[str, Any]]`):
                The tools JSON schemas. Entries without a `function` key are
                skipped.

        Returns:
            List[Dict[str, Any]]:
                A list containing a dictionary with the
                "function_declarations" key, which maps to a list of
                function definitions.

        Example:
            .. code-block:: python
                :caption: Example tool schemas of Gemini API

                # Input JSON schema
                schemas = [
                    {
                        'type': 'function',
                        'function': {
                            'name': 'execute_shell_command',
                            'description': 'xxx',
                            'parameters': {
                                'type': 'object',
                                'properties': {
                                    'command': {
                                        'type': 'string',
                                        'description': 'xxx.'
                                    },
                                    'timeout': {
                                        'type': 'integer',
                                        'default': 300
                                    }
                                },
                                'required': ['command']
                            }
                        }
                    }
                ]

                # Output format (Gemini API expected):
                [
                    {
                        'function_declarations': [
                            {
                                'name': 'execute_shell_command',
                                'description': 'xxx.',
                                'parameters': {
                                    'type': 'object',
                                    'properties': {
                                        'command': {
                                            'type': 'string',
                                            'description': 'xxx.'
                                        },
                                        'timeout': {
                                            'type': 'integer',
                                            'default': 300
                                        }
                                    },
                                    'required': ['command']
                                }
                            }
                        ]
                    }
                ]

        """
        function_declarations = []
        for schema in schemas:
            if "function" not in schema:
                continue
            # Shallow copy so the caller's schema dict is not modified
            func = schema["function"].copy()
            # Flatten the parameters schema to resolve $ref references
            if "parameters" in func:
                func["parameters"] = _flatten_json_schema(func["parameters"])
            function_declarations.append(func)

        return [{"function_declarations": function_declarations}]

    def _format_tool_choice(
        self,
        tool_choice: Literal["auto", "none", "required"] | str | None,
    ) -> dict | None:
        """Format tool_choice parameter for API compatibility.

        Args:
            tool_choice (`Literal["auto", "none", "required"] | str | None`, \
             default `None`):
                Controls which (if any) tool is called by the model.
                Can be "auto", "none", "required", or specific tool name.
                For more details, please refer to
                https://ai.google.dev/gemini-api/docs/function-calling?hl=en&example=meeting#function_calling_modes

        Returns:
            `dict | None`:
                The formatted tool choice configuration dict, or None if
                tool_choice is None. A specific tool name is mapped to the
                "ANY" mode restricted to that single function.
        """
        if tool_choice is None:
            return None

        mode_mapping = {
            "auto": "AUTO",
            "none": "NONE",
            "required": "ANY",
        }
        mode = mode_mapping.get(tool_choice)
        if mode:
            return {"function_calling_config": {"mode": mode}}
        return {
            "function_calling_config": {
                "mode": "ANY",
                "allowed_function_names": [tool_choice],
            },
        }
|