# -*- coding: utf-8 -*- # mypy: disable-error-code="dict-item" """The Google Gemini model in agentscope.""" import base64 import copy import warnings from datetime import datetime import json from typing import ( AsyncGenerator, Any, TYPE_CHECKING, AsyncIterator, Literal, Type, List, ) from pydantic import BaseModel from .._logging import logger from .._utils._common import _json_loads_with_repair from ..message import ToolUseBlock, TextBlock, ThinkingBlock from ._model_usage import ChatUsage from ._model_base import ChatModelBase from ._model_response import ChatResponse from ..tracing import trace_llm from ..types import JSONSerializableObject if TYPE_CHECKING: from google.genai.types import GenerateContentResponse else: GenerateContentResponse = "google.genai.types.GenerateContentResponse" def _flatten_json_schema(schema: dict) -> dict: """Flatten a JSON schema by resolving all $ref references. .. note:: Gemini API does not support `$defs` and `$ref` in JSON schemas. This function resolves all `$ref` references by inlining the referenced definitions, producing a self-contained schema without any references. Args: schema (`dict`): The JSON schema that may contain `$defs` and `$ref` references. Returns: `dict`: A flattened JSON schema with all references resolved inline. """ # Deep copy to avoid modifying the original schema schema = copy.deepcopy(schema) # Extract $defs if present defs = schema.pop("$defs", {}) def _resolve_ref(obj: Any, visited: set | None = None) -> Any: """Recursively resolve $ref references in the schema.""" if visited is None: visited = set() if not isinstance(obj, dict): if isinstance(obj, list): return [_resolve_ref(item, visited.copy()) for item in obj] return obj # Handle $ref if "$ref" in obj: ref_path = obj["$ref"] # Extract definition name from "#/$defs/DefinitionName" if ref_path.startswith("#/$defs/"): def_name = ref_path[len("#/$defs/") :] # Prevent infinite recursion for circular references if def_name in visited: logger.warning( "Circular reference detected for '%s' in tool schema", def_name, ) return { "type": "object", "description": f"(circular: {def_name})", } visited.add(def_name) if def_name in defs: # Recursively resolve any nested refs in the definition resolved = _resolve_ref( defs[def_name], visited.copy(), ) # Merge any additional properties from the original object # (excluding $ref itself) for key, value in obj.items(): if key != "$ref": resolved[key] = _resolve_ref(value, visited.copy()) return resolved # If we can't resolve the ref, return as-is (shouldn't happen) return obj # Recursively process all nested objects result = {} for key, value in obj.items(): result[key] = _resolve_ref(value, visited.copy()) return result return _resolve_ref(schema) class GeminiChatModel(ChatModelBase): """The Google Gemini chat model class in agentscope.""" def __init__( self, model_name: str, api_key: str, stream: bool = True, thinking_config: dict | None = None, client_kwargs: dict[str, JSONSerializableObject] | None = None, generate_kwargs: dict[str, JSONSerializableObject] | None = None, **kwargs: Any, ) -> None: """Initialize the Gemini chat model. Args: model_name (`str`): The name of the Gemini model to use, e.g. "gemini-2.5-flash". api_key (`str`): The API key for Google Gemini. stream (`bool`, default `True`): Whether to use streaming output or not. thinking_config (`dict | None`, optional): Thinking config, supported models are 2.5 Pro, 2.5 Flash, etc. Refer to https://ai.google.dev/gemini-api/docs/thinking for more details. .. code-block:: python :caption: Example of thinking_config { "include_thoughts": True, # enable thoughts or not "thinking_budget": 1024 # Max tokens for reasoning } client_kwargs (`dict[str, JSONSerializableObject] | None`, \ optional): The extra keyword arguments to initialize the Gemini client. generate_kwargs (`dict[str, JSONSerializableObject] | None`, \ optional): The extra keyword arguments used in Gemini API generation, e.g. `temperature`, `seed`. **kwargs (`Any`): Additional keyword arguments. """ # Handle deprecated client_args parameter from kwargs client_args = kwargs.pop("client_args", None) if client_args is not None and client_kwargs is not None: raise ValueError( "Cannot specify both 'client_args' and 'client_kwargs'. " "Please use only 'client_kwargs' (client_args is deprecated).", ) if client_args is not None: logger.warning( "The parameter 'client_args' is deprecated and will be " "removed in a future version. Please use 'client_kwargs' " "instead. Automatically converting 'client_args' to " "'client_kwargs'.", ) client_kwargs = client_args if kwargs: logger.warning( "Unknown keyword arguments: %s. These will be ignored.", list(kwargs.keys()), ) try: from google import genai except ImportError as e: raise ImportError( "Please install gemini Python sdk with " "`pip install -q -U google-genai`", ) from e super().__init__(model_name, stream) self.client = genai.Client( api_key=api_key, **(client_kwargs or {}), ) self.thinking_config = thinking_config self.generate_kwargs = generate_kwargs or {} @trace_llm async def __call__( self, messages: list[dict], tools: list[dict] | None = None, tool_choice: Literal["auto", "none", "required"] | str | None = None, structured_model: Type[BaseModel] | None = None, **config_kwargs: Any, ) -> ChatResponse | AsyncGenerator[ChatResponse, None]: """Call the Gemini model with the provided arguments. Args: messages (`list[dict[str, Any]]`): A list of dictionaries, where `role` and `content` fields are required. tools (`list[dict] | None`, default `None`): The tools JSON schemas that the model can use. tool_choice (`Literal["auto", "none", "required"] | str \ | None`, default `None`): Controls which (if any) tool is called by the model. Can be "auto", "none", "required", or specific tool name. For more details, please refer to https://ai.google.dev/gemini-api/docs/function-calling?hl=en&example=meeting#function_calling_modes structured_model (`Type[BaseModel] | None`, default `None`): A Pydantic BaseModel class that defines the expected structure for the model's output. .. note:: When `structured_model` is specified, both `tools` and `tool_choice` parameters are ignored, and the model will only perform structured output generation without calling any other tools. For more details, please refer to https://ai.google.dev/gemini-api/docs/structured-output **config_kwargs (`Any`): The keyword arguments for Gemini chat completions API. """ config: dict = { "thinking_config": self.thinking_config, **self.generate_kwargs, **config_kwargs, } if tools: config["tools"] = self._format_tools_json_schemas(tools) if tool_choice: # Handle deprecated "any" option with warning if tool_choice == "any": warnings.warn( '"any" is deprecated and will be removed in a future ' "version.", DeprecationWarning, ) tool_choice = "required" self._validate_tool_choice(tool_choice, tools) config["tool_config"] = self._format_tool_choice(tool_choice) if structured_model: if tools or tool_choice: logger.warning( "structured_model is provided. Both 'tools' and " "'tool_choice' parameters will be overridden and " "ignored. The model will only perform structured output " "generation without calling any other tools.", ) config.pop("tools", None) config.pop("tool_config", None) config["response_mime_type"] = "application/json" config["response_schema"] = structured_model # Prepare the arguments for the Gemini API call kwargs: dict[str, JSONSerializableObject] = { "model": self.model_name, "contents": messages, "config": config, } start_datetime = datetime.now() if self.stream: response = await self.client.aio.models.generate_content_stream( **kwargs, ) return self._parse_gemini_stream_generation_response( start_datetime, response, structured_model, ) # non-streaming response = await self.client.aio.models.generate_content( **kwargs, ) parsed_response = self._parse_gemini_generation_response( start_datetime, response, structured_model, ) return parsed_response async def _parse_gemini_stream_generation_response( self, start_datetime: datetime, response: AsyncIterator[GenerateContentResponse], structured_model: Type[BaseModel] | None = None, ) -> AsyncGenerator[ChatResponse, None]: """Given a Gemini streaming generation response, extract the content blocks and usages from it and yield ChatResponse objects. Args: start_datetime (`datetime`): The start datetime of the response generation. response (`AsyncIterator[GenerateContentResponse]`): Gemini GenerateContentResponse async iterator to parse. structured_model (`Type[BaseModel] | None`, default `None`): A Pydantic BaseModel class that defines the expected structure for the model's output. Returns: `AsyncGenerator[ChatResponse, None]`: An async generator that yields ChatResponse objects containing the content blocks and usage information for each chunk in the streaming response. .. note:: If `structured_model` is not `None`, the expected structured output will be stored in the metadata of the `ChatResponse`. """ text = "" thinking = "" tool_calls: list[ToolUseBlock] = [] metadata: dict | None = None async for chunk in response: if ( chunk.candidates and chunk.candidates[0].content and chunk.candidates[0].content.parts ): for part in chunk.candidates[0].content.parts: if part.text: if part.thought: thinking += part.text else: text += part.text if part.function_call: keyword_args = part.function_call.args or {} # .. note:: Gemini API always returns None for # function_call.id, so we use thought_signature # as the unique identifier for tool # calls when available. That maybe # infeasible someday, but Gemini # requires the thought_signature for some # llms like gemini-3-pro if part.thought_signature: call_id = base64.b64encode( part.thought_signature, ).decode("utf-8") else: call_id = part.function_call.id tool_calls.append( ToolUseBlock( type="tool_use", id=call_id, name=part.function_call.name, input=keyword_args, raw_input=json.dumps( keyword_args, ensure_ascii=False, ), ), ) # Text parts if text and structured_model: metadata = _json_loads_with_repair(text) usage = None if chunk.usage_metadata: usage = ChatUsage( input_tokens=chunk.usage_metadata.prompt_token_count, output_tokens=chunk.usage_metadata.total_token_count - chunk.usage_metadata.prompt_token_count, time=(datetime.now() - start_datetime).total_seconds(), ) # The content blocks for the current chunk content_blocks: list = [] if thinking: content_blocks.append( ThinkingBlock( type="thinking", thinking=thinking, ), ) if text: content_blocks.append( TextBlock( type="text", text=text, ), ) yield ChatResponse( content=content_blocks + tool_calls, usage=usage, metadata=metadata, ) def _parse_gemini_generation_response( self, start_datetime: datetime, response: GenerateContentResponse, structured_model: Type[BaseModel] | None = None, ) -> ChatResponse: """Given a Gemini chat completion response object, extract the content blocks and usages from it. Args: start_datetime (`datetime`): The start datetime of the response generation. response (`GenerateContentResponse`): The Gemini generation response object to parse. structured_model (`Type[BaseModel] | None`, default `None`): A Pydantic BaseModel class that defines the expected structure for the model's output. Returns: ChatResponse (`ChatResponse`): A ChatResponse object containing the content blocks and usage. .. note:: If `structured_model` is not `None`, the expected structured output will be stored in the metadata of the `ChatResponse`. """ content_blocks: List[TextBlock | ToolUseBlock | ThinkingBlock] = [] metadata: dict | None = None tool_calls: list = [] if ( response.candidates and response.candidates[0].content and response.candidates[0].content.parts ): for part in response.candidates[0].content.parts: if part.text: if part.thought: content_blocks.append( ThinkingBlock( type="thinking", thinking=part.text, ), ) else: content_blocks.append( TextBlock( type="text", text=part.text, ), ) if part.function_call: keyword_args = part.function_call.args or {} # .. note:: Gemini API always returns None for # function_call.id, so we use thought_signature # as the unique identifier for tool # calls when available. That maybe infeasible # someday, but Gemini requires the thought_signature # for some llms like gemini-3-pro if part.thought_signature: call_id = base64.b64encode( part.thought_signature, ).decode("utf-8") else: call_id = part.function_call.id tool_calls.append( ToolUseBlock( type="tool_use", id=call_id, name=part.function_call.name, input=keyword_args, raw_input=json.dumps( keyword_args, ensure_ascii=False, ), ), ) # For the structured output case if response.text and structured_model: metadata = _json_loads_with_repair(response.text) if response.usage_metadata: usage = ChatUsage( input_tokens=response.usage_metadata.prompt_token_count, output_tokens=response.usage_metadata.total_token_count - response.usage_metadata.prompt_token_count, time=(datetime.now() - start_datetime).total_seconds(), ) else: usage = None return ChatResponse( content=content_blocks + tool_calls, usage=usage, metadata=metadata, ) def _format_tools_json_schemas( self, schemas: list[dict[str, Any]], ) -> list[dict[str, Any]]: """Format the tools JSON schema into required format for Gemini API. .. note:: Gemini API does not support `$defs` and `$ref` in JSON schemas. This function resolves all `$ref` references by inlining the referenced definitions, producing a self-contained schema without any references. Args: schemas (`dict[str, Any]`): The tools JSON schemas. Returns: List[Dict[str, Any]]: A list containing a dictionary with the "function_declarations" key, which maps to a list of function definitions. Example: .. code-block:: python :caption: Example tool schemas of Gemini API # Input JSON schema schemas = [ { 'type': 'function', 'function': { 'name': 'execute_shell_command', 'description': 'xxx', 'parameters': { 'type': 'object', 'properties': { 'command': { 'type': 'string', 'description': 'xxx.' }, 'timeout': { 'type': 'integer', 'default': 300 } }, 'required': ['command'] } } } ] # Output format (Gemini API expected): [ { 'function_declarations': [ { 'name': 'execute_shell_command', 'description': 'xxx.', 'parameters': { 'type': 'object', 'properties': { 'command': { 'type': 'string', 'description': 'xxx.' }, 'timeout': { 'type': 'integer', 'default': 300 } }, 'required': ['command'] } } ] } ] """ function_declarations = [] for schema in schemas: if "function" not in schema: continue func = schema["function"].copy() # Flatten the parameters schema to resolve $ref references if "parameters" in func: func["parameters"] = _flatten_json_schema(func["parameters"]) function_declarations.append(func) return [{"function_declarations": function_declarations}] def _format_tool_choice( self, tool_choice: Literal["auto", "none", "required"] | str | None, ) -> dict | None: """Format tool_choice parameter for API compatibility. Args: tool_choice (`Literal["auto", "none", "required"] | str | None`, \ default `None`): Controls which (if any) tool is called by the model. Can be "auto", "none", "required", or specific tool name. For more details, please refer to https://ai.google.dev/gemini-api/docs/function-calling?hl=en&example=meeting#function_calling_modes Returns: `dict | None`: The formatted tool choice configuration dict, or None if tool_choice is None. """ if tool_choice is None: return None mode_mapping = { "auto": "AUTO", "none": "NONE", "required": "ANY", } mode = mode_mapping.get(tool_choice) if mode: return {"function_calling_config": {"mode": mode}} return { "function_calling_config": { "mode": "ANY", "allowed_function_names": [tool_choice], }, }