chore: initialize sandbox and overwrite remote content
Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled

This commit is contained in:
codex-bot
2026-03-02 22:32:27 +08:00
commit a64378956a
584 changed files with 93604 additions and 0 deletions

View File

@@ -0,0 +1,48 @@
# -*- coding: utf-8 -*-
"""The formatter module in agentscope."""
from ._formatter_base import FormatterBase
from ._truncated_formatter_base import TruncatedFormatterBase
from ._dashscope_formatter import (
DashScopeChatFormatter,
DashScopeMultiAgentFormatter,
)
from ._anthropic_formatter import (
AnthropicChatFormatter,
AnthropicMultiAgentFormatter,
)
from ._openai_formatter import (
OpenAIChatFormatter,
OpenAIMultiAgentFormatter,
)
from ._gemini_formatter import (
GeminiChatFormatter,
GeminiMultiAgentFormatter,
)
from ._ollama_formatter import (
OllamaChatFormatter,
OllamaMultiAgentFormatter,
)
from ._deepseek_formatter import (
DeepSeekChatFormatter,
DeepSeekMultiAgentFormatter,
)
from ._a2a_formatter import A2AChatFormatter
# Names re-exported as the public API of this package: the two base classes
# followed by the chat / multi-agent formatter pair of each provider imported
# above (A2A only has a chat formatter).
__all__ = [
"FormatterBase",
"TruncatedFormatterBase",
"DashScopeChatFormatter",
"DashScopeMultiAgentFormatter",
"OpenAIChatFormatter",
"OpenAIMultiAgentFormatter",
"AnthropicChatFormatter",
"AnthropicMultiAgentFormatter",
"GeminiChatFormatter",
"GeminiMultiAgentFormatter",
"OllamaChatFormatter",
"OllamaMultiAgentFormatter",
"DeepSeekChatFormatter",
"DeepSeekMultiAgentFormatter",
"A2AChatFormatter",
]

View File

@@ -0,0 +1,364 @@
# -*- coding: utf-8 -*-
"""The A2A message formatter class."""
import mimetypes
import uuid
from typing import Literal, TYPE_CHECKING
from .._logging import logger
from ._formatter_base import FormatterBase
from ..message import (
Msg,
TextBlock,
URLSource,
Base64Source,
ContentBlock,
)
# At runtime the A2A types are only string placeholders, so this module can
# be imported without ``a2a`` (presumably an optional dependency -- the
# methods below import it lazily). Static type checkers see the real classes.
if not TYPE_CHECKING:
    Message = "a2a.types.Message"
    Task = "a2a.types.Task"
    Part = "a2a.types.Part"
else:
    from a2a.types import (
        Message,
        Task,
        Part,
    )
class A2AChatFormatter(FormatterBase):
    """A2A message formatter class, which converts AgentScope messages into
    the A2A message format, and converts A2A messages, tasks and parts back
    into AgentScope messages."""

    async def format(self, msgs: list[Msg]) -> Message:
        """Convert AgentScope messages into a single A2A message object.

        .. note:: The A2A protocol receives a single message per request,
         so multi-message inputs are merged into one A2A Message with role
         'user'.

        Args:
            msgs (`list[Msg]`):
                List of AgentScope Msg objects to be converted.

        Returns:
            `Message`:
                The converted A2A Message object.

        Raises:
            `ValueError`:
                If an image/video/audio block carries a source whose type
                is neither "url" nor "base64".
        """
        from a2a.types import (
            Part,
            TextPart,
            FilePart,
            FileWithUri,
            FileWithBytes,
            DataPart,
            Role,
            Message,
        )

        self.assert_list_of_msgs(msgs)

        parts = []
        for msg in msgs:
            for block in msg.get_content_blocks():
                block_type = block.get("type")

                if block_type == "text" and block.get("text"):
                    parts.append(
                        Part(root=TextPart(text=block.get("text"))),
                    )

                elif block_type == "thinking" and block.get("thinking"):
                    # Thinking content is forwarded as a plain text part
                    parts.append(
                        Part(root=TextPart(text=block.get("thinking"))),
                    )

                elif block_type in [
                    "image",
                    "video",
                    "audio",
                ] and block.get("source"):
                    source = block.get("source", {})
                    source_type = source.get("type")

                    if source_type == "url":
                        file_obj = FileWithUri(uri=source.get("url"))
                    elif source_type == "base64":
                        file_obj = FileWithBytes(
                            bytes=source.get("data"),
                            mime_type=source.get("media_type"),
                        )
                    else:
                        raise ValueError(
                            f"Unsupported source type: {source_type}",
                        )
                    parts.append(Part(root=FilePart(file=file_obj)))

                elif block_type in ["tool_use", "tool_result"]:
                    # Tool blocks travel as raw data so that they can be
                    # recognized again in `_format_a2a_part`
                    parts.append(Part(root=DataPart(data=block)))

                else:
                    # Also reached by text/thinking/media blocks whose
                    # payload is empty or missing
                    logger.error(
                        "Unsupported block type %s in A2AFormatter.",
                        block_type,
                    )

        return Message(
            message_id=str(uuid.uuid4()),
            role=Role.user,
            parts=parts,
        )

    async def format_a2a_message(self, name: str, message: Message) -> Msg:
        """Convert an A2A Message object back to an AgentScope Msg.

        Args:
            name (`str`):
                The name of the message sender.
            message (`Message`):
                The A2A Message object to be converted.

        Returns:
            `Msg`:
                The converted AgentScope Msg object, whose role is "user"
                for A2A user messages and "assistant" for A2A agent
                messages.

        Raises:
            `ValueError`:
                If the message role is neither user nor agent, or if one of
                the parts cannot be converted.
        """
        from a2a.types import Role

        content = [
            await self._format_a2a_part(part) for part in message.parts
        ]

        if message.role == Role.user:
            role: Literal["user", "assistant"] = "user"
        elif message.role == Role.agent:
            role = "assistant"
        else:
            raise ValueError(
                f"Unsupported role: {message.role} in A2A message.",
            )

        return Msg(
            name=name,
            role=role,
            content=content,
            metadata=None,
        )

    @staticmethod
    def _guess_type(
        uri: str | None = None,
        mime_type: str | None = None,
    ) -> Literal["image", "video", "audio", "unknown"]:
        """Guess the content type from the uri or mime type.

        Args:
            uri (`str | None`, optional):
                The uri of the content. Only used to guess the mime type
                when `mime_type` is not given.
            mime_type (`str | None`, optional):
                The mime type of the content.

        Returns:
            `Literal["image", "video", "audio", "unknown"]`:
                The guessed content type, "unknown" if the mime type cannot
                be determined or is not an image/video/audio type.

        Raises:
            `ValueError`:
                If both `uri` and `mime_type` are `None`.
        """
        if mime_type is None and uri is None:
            raise ValueError(
                "Either uri or mime_type must be provided to guess the"
                " content type.",
            )

        if mime_type is None:
            mime_type, _encoding = mimetypes.guess_type(uri or "")

        if isinstance(mime_type, str):
            for kind in ("image", "video", "audio"):
                if mime_type.startswith(kind + "/"):
                    return kind  # type: ignore[return-value]

        return "unknown"

    async def format_a2a_task(self, name: str, task: Task) -> list[Msg]:
        """Convert an A2A Task object back to AgentScope Msg format.

        The status message (if any) becomes the first message. The content
        of every artifact is appended to the last message when that message
        is an assistant message, otherwise it starts a new assistant
        message. The metadata of the message is replaced by the metadata of
        the artifact merged last.

        Args:
            name (`str`):
                The name of the message sender.
            task (`Task`):
                The A2A Task object to be converted.

        Returns:
            `list[Msg]`:
                Converted AgentScope Msg objects.
        """
        merged_msgs = []

        if task.status and task.status.message:
            merged_msgs.append(
                await self.format_a2a_message(name, task.status.message),
            )

        for artifact in task.artifacts or []:
            artifact_content = [
                await self._format_a2a_part(part) for part in artifact.parts
            ]

            if merged_msgs and merged_msgs[-1].role == "assistant":
                merged_msgs[-1].content.extend(artifact_content)
                merged_msgs[-1].metadata = artifact.metadata
            else:
                merged_msgs.append(
                    Msg(
                        name=name,
                        role="assistant",
                        content=artifact_content,
                        metadata=artifact.metadata,
                    ),
                )

        return merged_msgs

    async def _format_a2a_part(self, part: Part) -> ContentBlock:
        """Convert a single A2A Part object into AgentScope ContentBlock.

        .. note:: We will try to convert the `DataPart` into tool use and
         tool result blocks if possible. Any other data part is converted
         into a text block holding the string form of the data.

        Args:
            part (`Part`):
                The A2A Part object to be converted.

        Returns:
            `ContentBlock`:
                The converted AgentScope ContentBlock.

        Raises:
            `ValueError`:
                If the part, or the file inside a file part, is of an
                unsupported type.
        """
        from a2a.types import (
            TextPart,
            FilePart,
            FileWithUri,
            FileWithBytes,
            DataPart,
        )

        root = part.root

        if isinstance(root, TextPart):
            return TextBlock(
                type="text",
                text=root.text,
            )

        if isinstance(root, FilePart):
            file_obj = root.file

            if isinstance(file_obj, FileWithUri):
                return {  # type: ignore[return-value, misc]
                    "type": self._guess_type(
                        file_obj.uri,
                        file_obj.mime_type,
                    ),
                    "source": URLSource(
                        type="url",
                        url=file_obj.uri,
                    ),
                }

            if isinstance(file_obj, FileWithBytes):
                # Apply the fallback before guessing the type: without it
                # `_guess_type` raises when the mime type is missing, since
                # there is no uri to guess from either.
                mime_type = file_obj.mime_type or "application/octet-stream"
                return {  # type: ignore[return-value, misc]
                    "type": self._guess_type(mime_type=mime_type),
                    "source": Base64Source(
                        type="base64",
                        media_type=mime_type,
                        data=file_obj.bytes,
                    ),
                }

            raise ValueError(
                f"Unsupported File type: {type(file_obj)} in A2A message.",
            )

        if isinstance(root, DataPart):
            data = root.data

            # Maybe the tool use and tool result blocks
            if {
                "type",
                "name",
                "input",
                "id",
            } <= data.keys() and data["type"] == "tool_use":
                return data

            if {
                "type",
                "name",
                "output",
                "id",
            } <= data.keys() and data["type"] == "tool_result":
                return data

            # TODO: what about the other data parts?
            return TextBlock(
                type="text",
                text=str(data),
            )

        raise ValueError(
            f"Unsupported Part type: {type(root)} in A2A message"
            f": {root}",
        )

View File

@@ -0,0 +1,253 @@
# -*- coding: utf-8 -*-
# pylint: disable=too-many-branches
"""The Anthropic formatter module."""
from typing import Any
from ._truncated_formatter_base import TruncatedFormatterBase
from .._logging import logger
from ..message import Msg, TextBlock, ImageBlock, ToolUseBlock, ToolResultBlock
from ..token import TokenCounterBase
class AnthropicChatFormatter(TruncatedFormatterBase):
    """The Anthropic formatter class for chatbot scenario, where only a user
    and an agent are involved. We use the `role` field to identify different
    entities in the conversation.
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = False
    """Whether support multi-agent conversations"""

    support_vision: bool = True
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        # Multimodal
        ImageBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    async def _format(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Format message objects into Anthropic API format.

        Every tool result block is emitted as its own "user" message,
        placed before the message built from the remaining blocks of the
        same `Msg`. Messages left without any content block are dropped.

        Args:
            msgs (`list[Msg]`):
                The list of message objects to format.

        Returns:
            `list[dict[str, Any]]`:
                The formatted messages as a list of dictionaries.

        .. note:: Anthropic suggests always passing all previous thinking
         blocks back to the API in subsequent calls to maintain reasoning
         continuity. For more details, please refer to
         `Anthropic's documentation
         <https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#preserving-thinking-blocks>`_.
        """
        self.assert_list_of_msgs(msgs)

        messages: list[dict] = []
        for index, msg in enumerate(msgs):
            content_blocks = []

            for block in msg.get_content_blocks():
                typ = block.get("type")

                if typ in ["thinking", "text", "image"]:
                    # These blocks are passed through as shallow copies
                    content_blocks.append({**block})

                elif typ == "tool_use":
                    content_blocks.append(
                        {
                            "id": block.get("id"),
                            "type": "tool_use",
                            "name": block.get("name"),
                            "input": block.get("input", {}),
                        },
                    )

                elif typ == "tool_result":
                    output = block.get("output")
                    if output is None:
                        content_value = [{"type": "text", "text": None}]
                    elif isinstance(output, list):
                        content_value = output
                    else:
                        content_value = [
                            {"type": "text", "text": str(output)},
                        ]

                    messages.append(
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "tool_result",
                                    "tool_use_id": block.get("id"),
                                    "content": content_value,
                                },
                            ],
                        },
                    )

                else:
                    logger.warning(
                        "Unsupported block type %s in the message, skipped.",
                        typ,
                    )

            # Claude only allows the first message to be a system message
            if msg.role == "system" and index != 0:
                role = "user"
            else:
                role = msg.role

            # Skip messages without any content block, e.g. messages that
            # only carried tool results (already emitted above)
            if content_blocks:
                messages.append(
                    {
                        "role": role,
                        "content": content_blocks,
                    },
                )

        return messages
class AnthropicMultiAgentFormatter(TruncatedFormatterBase):
    """
    Anthropic formatter for multi-agent conversations, where more than
    a user and an agent are involved.
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = True
    """Whether support multi-agent conversations"""

    support_vision: bool = True
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        # Multimodal
        ImageBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    def __init__(
        self,
        conversation_history_prompt: str = (
            "# Conversation History\n"
            "The content between <history></history> tags contains "
            "your conversation history\n"
        ),
        token_counter: TokenCounterBase | None = None,
        max_tokens: int | None = None,
    ) -> None:
        """Initialize the Anthropic multi-agent formatter.

        Args:
            conversation_history_prompt (`str`):
                The prompt to use for the conversation history section.
            token_counter (`TokenCounterBase | None`, optional):
                The token counter forwarded to the base formatter.
            max_tokens (`int | None`, optional):
                The maximum number of tokens, forwarded to the base
                formatter.
        """
        super().__init__(token_counter=token_counter, max_tokens=max_tokens)
        self.conversation_history_prompt = conversation_history_prompt

    async def _format_tool_sequence(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Given a sequence of tool call/result messages, format them into
        the required format for the Anthropic API.

        Args:
            msgs (`list[Msg]`):
                The list of messages containing tool calls/results.

        Returns:
            `list[dict[str, Any]]`:
                The messages formatted by a default-constructed
                `AnthropicChatFormatter`.
        """
        return await AnthropicChatFormatter().format(msgs)

    async def _format_agent_message(
        self,
        msgs: list[Msg],
        is_first: bool = True,
    ) -> list[dict[str, Any]]:
        """Given a sequence of messages without tool calls/results, format
        them into a single user message wrapped in <history></history> tags.

        Args:
            msgs (`list[Msg]`):
                A list of Msg objects to be formatted.
            is_first (`bool`, defaults to `True`):
                Whether this is the first agent message in the
                conversation. If `True`, the conversation history prompt
                is placed in front of the opening history tag.

        Returns:
            `list[dict[str, Any]]`:
                A list holding one user message, or an empty list if the
                input contains no text or image block.
        """
        if is_first:
            conversation_history_prompt = self.conversation_history_prompt
        else:
            conversation_history_prompt = ""

        # Text is accumulated as "name: text" lines and flushed into one
        # text block whenever an image is met, so images stay in position.
        conversation_blocks: list = []
        accumulated_text = []
        for msg in msgs:
            for block in msg.get_content_blocks():
                if block["type"] == "text":
                    accumulated_text.append(f"{msg.name}: {block['text']}")

                elif block["type"] == "image":
                    # Handle the accumulated text as a single block
                    if accumulated_text:
                        conversation_blocks.append(
                            {
                                "text": "\n".join(accumulated_text),
                                "type": "text",
                            },
                        )
                        accumulated_text.clear()
                    conversation_blocks.append({**block})

        if accumulated_text:
            conversation_blocks.append(
                {
                    "text": "\n".join(accumulated_text),
                    "type": "text",
                },
            )

        if not conversation_blocks:
            return []

        # Open the history section in (or in front of) the first block
        if conversation_blocks[0].get("text"):
            conversation_blocks[0]["text"] = (
                conversation_history_prompt
                + "<history>\n"
                + conversation_blocks[0]["text"]
            )
        else:
            conversation_blocks.insert(
                0,
                {
                    "type": "text",
                    "text": conversation_history_prompt + "<history>\n",
                },
            )

        # Close the history section in (or after) the last block
        if conversation_blocks[-1].get("text"):
            conversation_blocks[-1]["text"] += "\n</history>"
        else:
            conversation_blocks.append(
                {"type": "text", "text": "</history>"},
            )

        return [
            {
                "role": "user",
                "content": conversation_blocks,
            },
        ]

View File

@@ -0,0 +1,639 @@
# -*- coding: utf-8 -*-
# pylint: disable=too-many-branches
"""The dashscope formatter module."""
import json
import os.path
from typing import Any
from ._truncated_formatter_base import TruncatedFormatterBase
from .._logging import logger
from .._utils._common import _is_accessible_local_file
from ..message import (
Msg,
TextBlock,
ImageBlock,
AudioBlock,
VideoBlock,
ToolUseBlock,
ToolResultBlock,
URLSource,
)
from ..token import TokenCounterBase
def _format_dashscope_media_block(
block: ImageBlock | AudioBlock,
) -> dict[str, str]:
"""Format an image or audio block for DashScope API.
Args:
block (`ImageBlock` | `AudioBlock`):
The image or audio block to format.
Returns:
`dict[str, str]`:
A dictionary with "image" or "audio" key and the formatted URL or
data URI as value.
Raises:
`NotImplementedError`:
If the source type is not supported.
"""
typ = block["type"]
source = block["source"]
if source["type"] == "url":
url = source["url"]
if _is_accessible_local_file(url):
return {typ: "file://" + os.path.abspath(url)}
else:
# treat as web url
return {typ: url}
elif source["type"] == "base64":
media_type = source["media_type"]
base64_data = source["data"]
return {
typ: f"data:{media_type};base64,{base64_data}",
}
else:
raise NotImplementedError(
f"Unsupported source type '{source.get('type')}' "
f"for {typ} block.",
)
def _reformat_messages(
messages: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Reformat the content to be compatible with HuggingFaceTokenCounter.
This function processes a list of messages and converts multi-part
text content into single string content when all parts are plain text.
This is necessary for compatibility with HuggingFaceTokenCounter which
expects simple string content rather than structured content with
multiple parts.
Args:
messages (list[dict[str, Any]]):
A list of message dictionaries where each message may contain a
"content" field. The content can be either:
- A string (unchanged)
- A list of content items, where each item is a dict that may
contain "text", "type", and other fields
Returns:
list[dict[str, Any]]:
A list of reformatted messages. For messages where all content
items are plain text (have "text" field and either no "type"
field or "type" == "text"), the content list is converted to a
single newline-joined string. Other messages remain unchanged.
Example:
.. code-block:: python
# Case 1: All text content - will be converted
messages = [
{
"role": "user",
"content": [
{"text": "Hello", "type": "text"},
{"text": "World", "type": "text"}
]
}
]
result = _reformat_messages(messages)
print(result[0]["content"])
# Output: "Hello\nWorld"
# Case 2: Mixed content - will remain unchanged
messages = [
{
"role": "user",
"content": [
{"text": "Hello", "type": "text"},
{"image_url": "...", "type": "image"}
]
}
]
result = _reformat_messages(messages) # remain unchanged
print(type(result[0]["content"]))
# Output: <class 'list'>
"""
for message in messages:
content = message.get("content", [])
is_all_text = True
texts = []
for item in content:
if not isinstance(item, dict) or "text" not in item:
is_all_text = False
break
if "type" in item and item["type"] != "text":
is_all_text = False
break
if item["text"]:
texts.append(item["text"])
if is_all_text and texts:
message["content"] = "\n".join(texts)
return messages
class DashScopeChatFormatter(TruncatedFormatterBase):
    """The DashScope formatter class for chatbot scenario, where only a user
    and an agent are involved. We use the `role` field to identify different
    entities in the conversation.

    .. warning::
        Known Issues with DashScope API:

        1. **Missing content field**: When messages lack the 'content' field,
           qwen-vl-max models will raise ``KeyError: 'content'``.
        2. **None content value**: When content is ``None``, qwen-vl-max
           models will raise
           ``TypeError: 'NoneType' object is not iterable``.
        3. **Empty text in content**: When content contains
           ``[{"text": None}]``, qwen3-max may repeatedly invoke tools
           multiple times. Note that when qwen3-max initiates tool calls,
           the returned message contains ``"content": ""``.

        To avoid these issues, this formatter assigns content as an empty
        list ``[]`` for messages without valid content blocks.
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = False
    """Whether support multi-agent conversations"""

    support_vision: bool = True
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        ImageBlock,
        AudioBlock,
        VideoBlock,
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    def __init__(
        self,
        promote_tool_result_images: bool = False,
        promote_tool_result_audios: bool = False,
        promote_tool_result_videos: bool = False,
        token_counter: TokenCounterBase | None = None,
        max_tokens: int | None = None,
    ) -> None:
        """Initialize the DashScope chat formatter.

        Args:
            promote_tool_result_images (`bool`, defaults to `False`):
                Whether to promote images from tool results to user messages.
                Most LLM APIs don't support images in tool result blocks, but
                do support them in user message blocks. When `True`, images
                are extracted and appended as a separate user message with
                explanatory text indicating their source.
            promote_tool_result_audios (`bool`, defaults to `False`):
                Whether to promote audios from tool results to user messages.
                Most LLM APIs don't support audios in tool result blocks, but
                do support them in user message blocks. When `True`, audios
                are extracted and appended as a separate user message with
                explanatory text indicating their source.
            promote_tool_result_videos (`bool`, defaults to `False`):
                Whether to promote videos from tool results to user messages.
                Most LLM APIs don't support videos in tool result blocks, but
                do support them in user message blocks. When `True`, videos
                are extracted and appended as a separate user message with
                explanatory text indicating their source.
            token_counter (`TokenCounterBase | None`, optional):
                A token counter instance used to count tokens in the
                messages. If not provided, the formatter will format the
                messages without considering token limits.
            max_tokens (`int | None`, optional):
                The maximum number of tokens allowed in the formatted
                messages. If not provided, the formatter will not truncate
                the messages.
        """
        super().__init__(token_counter, max_tokens)
        self.promote_tool_result_images = promote_tool_result_images
        self.promote_tool_result_audios = promote_tool_result_audios
        self.promote_tool_result_videos = promote_tool_result_videos

    def _collect_promoted_media(self, multimodal_data: list) -> list:
        """Build the caption/media block pairs for the tool result media
        whose type this formatter is configured to promote.

        Args:
            multimodal_data (`list`):
                Pairs of `(url, multimodal_block)` as returned by
                `convert_tool_result_to_string`.

        Returns:
            `list`:
                For every promoted item, a text block naming the url
                followed by a url-sourced media block of the same type.
                Empty if nothing is promoted.
        """
        promotion = {
            "image": (self.promote_tool_result_images, ImageBlock),
            "audio": (self.promote_tool_result_audios, AudioBlock),
            "video": (self.promote_tool_result_videos, VideoBlock),
        }

        promoted: list = []
        for url, multimodal_block in multimodal_data:
            media_type = multimodal_block["type"]
            enabled, block_cls = promotion.get(media_type, (False, None))
            if not enabled:
                continue
            promoted.extend(
                [
                    TextBlock(
                        type="text",
                        text=f"\n- The {media_type} from '{url}': ",
                    ),
                    block_cls(
                        type=media_type,
                        source=URLSource(
                            type="url",
                            url=url,
                        ),
                    ),
                ],
            )
        return promoted

    async def _format(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Format message objects into DashScope API format.

        Every tool result block is emitted as its own "tool" message. When
        promotion is enabled for the media it contains, an extra user
        message carrying that media is formatted right after the current
        message. The input list itself is not modified.

        Args:
            msgs (`list[Msg]`):
                The list of message objects to format.

        Returns:
            `list[dict[str, Any]]`:
                The formatted messages as a list of dictionaries.
        """
        self.assert_list_of_msgs(msgs)

        # Work on a shallow copy: promoted messages are inserted into the
        # list while it is being walked, and that must neither leak into
        # the caller's list nor pile up when the same list is formatted
        # again.
        msgs = list(msgs)

        formatted_msgs: list[dict] = []
        i = 0
        while i < len(msgs):
            msg = msgs[i]
            content_blocks: list[dict[str, Any]] = []
            tool_calls = []

            for block in msg.get_content_blocks():
                typ = block.get("type")

                if typ == "text":
                    content_blocks.append(
                        {
                            "text": block.get("text"),
                        },
                    )

                elif typ in ["image", "audio", "video"]:
                    content_blocks.append(
                        _format_dashscope_media_block(
                            block,  # type: ignore[arg-type]
                        ),
                    )

                elif typ == "tool_use":
                    tool_calls.append(
                        {
                            "id": block.get("id"),
                            "type": "function",
                            "function": {
                                "name": block.get("name"),
                                "arguments": json.dumps(
                                    block.get("input", {}),
                                    ensure_ascii=False,
                                ),
                            },
                        },
                    )

                elif typ == "tool_result":
                    (
                        textual_output,
                        multimodal_data,
                    ) = self.convert_tool_result_to_string(block["output"])

                    # First add the tool result message in DashScope API
                    # format
                    formatted_msgs.append(
                        {
                            "role": "tool",
                            "tool_call_id": block.get("id"),
                            "content": textual_output,
                            "name": block.get("name"),
                        },
                    )

                    # Then, handle the multimodal data if any
                    promoted_blocks = self._collect_promoted_media(
                        multimodal_data,
                    )
                    if promoted_blocks:
                        # Queue the promoted blocks as a new user message
                        # that is formatted in the next iteration
                        msgs.insert(
                            i + 1,
                            Msg(
                                name="user",
                                content=[
                                    TextBlock(
                                        type="text",
                                        text="<system-info>The following are "
                                        f"the media contents from the tool "
                                        f"result of '{block['name']}':",
                                    ),
                                    *promoted_blocks,
                                    TextBlock(
                                        type="text",
                                        text="</system-info>",
                                    ),
                                ],
                                role="user",
                            ),
                        )

                else:
                    logger.warning(
                        "Unsupported block type %s in the message, skipped.",
                        typ,
                    )

            msg_dashscope = {
                "role": msg.role,
                "content": content_blocks,
            }

            if tool_calls:
                msg_dashscope["tool_calls"] = tool_calls

            # Messages with neither content nor tool calls are dropped
            if msg_dashscope["content"] or msg_dashscope.get("tool_calls"):
                formatted_msgs.append(msg_dashscope)

            # Move to next message
            i += 1

        return _reformat_messages(formatted_msgs)
class DashScopeMultiAgentFormatter(TruncatedFormatterBase):
    """DashScope formatter for multi-agent conversations, where more than
    a user and an agent are involved.

    .. note:: This formatter will combine previous messages (except tool
     calls/results) into a history section in the first system message with
     the conversation history prompt.

    .. note:: For tool calls/results, they will be presented as separate
     messages as required by the DashScope API. Therefore, the tool calls/
     results messages are expected to be placed at the end of the input
     messages.

    .. tip:: Telling the assistant's name in the system prompt is very
     important in multi-agent conversations. So that LLM can know who it
     is playing as.
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = True
    """Whether support multi-agent conversations"""

    support_vision: bool = True
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        # Multimodal
        ImageBlock,
        AudioBlock,
        VideoBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    def __init__(
        self,
        conversation_history_prompt: str = (
            "# Conversation History\n"
            "The content between <history></history> tags contains "
            "your conversation history\n"
        ),
        promote_tool_result_images: bool = False,
        promote_tool_result_audios: bool = False,
        promote_tool_result_videos: bool = False,
        token_counter: TokenCounterBase | None = None,
        max_tokens: int | None = None,
    ) -> None:
        """Initialize the DashScope multi-agent formatter.

        Args:
            conversation_history_prompt (`str`):
                The prompt to use for the conversation history section.
            promote_tool_result_images (`bool`, defaults to `False`):
                Whether to promote images from tool results to user messages.
                Most LLM APIs don't support images in tool result blocks, but
                do support them in user message blocks. When `True`, images
                are extracted and appended as a separate user message with
                explanatory text indicating their source.
            promote_tool_result_audios (`bool`, defaults to `False`):
                Whether to promote audios from tool results to user messages.
                Most LLM APIs don't support audios in tool result blocks, but
                do support them in user message blocks. When `True`, audios
                are extracted and appended as a separate user message with
                explanatory text indicating their source.
            promote_tool_result_videos (`bool`, defaults to `False`):
                Whether to promote videos from tool results to user messages.
                Most LLM APIs don't support videos in tool result blocks, but
                do support them in user message blocks. When `True`, videos
                are extracted and appended as a separate user message with
                explanatory text indicating their source.
            token_counter (`TokenCounterBase | None`, optional):
                The token counter used for truncation.
            max_tokens (`int | None`, optional):
                The maximum number of tokens allowed in the formatted
                messages. If `None`, no truncation will be applied.
        """
        super().__init__(token_counter=token_counter, max_tokens=max_tokens)
        self.conversation_history_prompt = conversation_history_prompt
        self.promote_tool_result_images = promote_tool_result_images
        self.promote_tool_result_audios = promote_tool_result_audios
        self.promote_tool_result_videos = promote_tool_result_videos

    async def _format_tool_sequence(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Given a sequence of tool call/result messages, format them into
        the required format for the DashScope API.

        Args:
            msgs (`list[Msg]`):
                The list of messages containing tool calls/results to format.

        Returns:
            `list[dict[str, Any]]`:
                A list of dictionaries formatted for the DashScope API.
        """
        # Delegate to a chat formatter that shares the promotion settings
        chat_formatter = DashScopeChatFormatter(
            promote_tool_result_images=self.promote_tool_result_images,
            promote_tool_result_audios=self.promote_tool_result_audios,
            promote_tool_result_videos=self.promote_tool_result_videos,
        )
        return await chat_formatter.format(msgs)

    async def _format_agent_message(
        self,
        msgs: list[Msg],
        is_first: bool = True,
    ) -> list[dict[str, Any]]:
        """Given a sequence of messages without tool calls/results, format
        them into a user message with conversation history tags. For the
        first agent message, it will include the conversation history prompt.

        Args:
            msgs (`list[Msg]`):
                A list of Msg objects to be formatted.
            is_first (`bool`, defaults to `True`):
                Whether this is the first agent message in the conversation.
                If `True`, the conversation history prompt will be included.

        Returns:
            `list[dict[str, Any]]`:
                A list of dictionaries formatted for the DashScope API.
        """
        if is_first:
            conversation_history_prompt = self.conversation_history_prompt
        else:
            conversation_history_prompt = ""

        # Text is accumulated as "name: text" lines and flushed into one
        # text item whenever a media block is met, so media stays in
        # position.
        conversation_blocks = []
        accumulated_text = []
        for msg in msgs:
            for block in msg.get_content_blocks():
                if block["type"] == "text":
                    accumulated_text.append(f"{msg.name}: {block['text']}")

                elif block["type"] in ["image", "audio", "video"]:
                    # Handle the accumulated text as a single block
                    if accumulated_text:
                        conversation_blocks.append(
                            {"text": "\n".join(accumulated_text)},
                        )
                        accumulated_text.clear()

                    if block["source"]["type"] in ("url", "base64"):
                        # Same url / data URI formatting as the chat
                        # formatter
                        conversation_blocks.append(
                            _format_dashscope_media_block(
                                block,  # type: ignore[arg-type]
                            ),
                        )
                    else:
                        # Media with any other source type is skipped
                        logger.warning(
                            "Unsupported block type %s in the message, "
                            "skipped.",
                            block["type"],
                        )

        if accumulated_text:
            conversation_blocks.append({"text": "\n".join(accumulated_text)})

        formatted_msgs: list[dict] = []
        if conversation_blocks:
            # Open the history section in (or in front of) the first item
            if conversation_blocks[0].get("text"):
                conversation_blocks[0]["text"] = (
                    conversation_history_prompt
                    + "<history>\n"
                    + conversation_blocks[0]["text"]
                )
            else:
                conversation_blocks.insert(
                    0,
                    {
                        "text": conversation_history_prompt + "<history>\n",
                    },
                )

            # Close the history section in (or after) the last item
            if conversation_blocks[-1].get("text"):
                conversation_blocks[-1]["text"] += "\n</history>"
            else:
                conversation_blocks.append({"text": "</history>"})

            formatted_msgs.append(
                {
                    "role": "user",
                    "content": conversation_blocks,
                },
            )

        return _reformat_messages(formatted_msgs)

    async def _format_system_message(
        self,
        msg: Msg,
    ) -> dict[str, Any]:
        """Format system message for DashScope API.

        Args:
            msg (`Msg`):
                The system message to format.

        Returns:
            `dict[str, Any]`:
                A "system" role message whose content is the text content
                of the given message.
        """
        return {
            "role": "system",
            "content": msg.get_text_content(),
        }

View File

@@ -0,0 +1,265 @@
# -*- coding: utf-8 -*-
# pylint: disable=too-many-branches
"""The DeepSeek formatter module."""
import json
from typing import Any
from ._truncated_formatter_base import TruncatedFormatterBase
from .._logging import logger
from ..message import Msg, TextBlock, ToolUseBlock, ToolResultBlock
from ..token import TokenCounterBase
class DeepSeekChatFormatter(TruncatedFormatterBase):
    """The DeepSeek formatter class for chatbot scenario, where only a user
    and an agent are involved. We use the `role` field to identify different
    entities in the conversation.
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = False
    """Whether support multi-agent conversations"""

    support_vision: bool = False
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    async def _format(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Format message objects into DeepSeek API format.

        Text blocks of a message are joined with newlines into its
        "content" and thinking blocks into its "reasoning_content". Every
        tool result block is emitted as its own "tool" message. Messages
        with neither text content nor tool calls are dropped.

        Args:
            msgs (`list[Msg]`):
                The list of message objects to format.

        Returns:
            `list[dict[str, Any]]`:
                The formatted messages as a list of dictionaries.
        """
        self.assert_list_of_msgs(msgs)

        formatted: list[dict] = []
        for msg in msgs:
            texts: list = []
            thoughts: list = []
            tool_calls = []

            for block in msg.get_content_blocks():
                typ = block.get("type")

                if typ == "text":
                    texts.append(block.get("text", ""))

                elif typ == "thinking":
                    thoughts.append(block.get("thinking", ""))

                elif typ == "tool_use":
                    arguments = json.dumps(
                        block.get("input", {}),
                        ensure_ascii=False,
                    )
                    tool_calls.append(
                        {
                            "id": block.get("id"),
                            "type": "function",
                            "function": {
                                "name": block.get("name"),
                                "arguments": arguments,
                            },
                        },
                    )

                elif typ == "tool_result":
                    # Only the textual part of the tool result is kept
                    textual_output, _ = self.convert_tool_result_to_string(
                        block.get("output"),  # type: ignore[arg-type]
                    )
                    formatted.append(
                        {
                            "role": "tool",
                            "tool_call_id": block.get("id"),
                            "content": textual_output,
                            "name": block.get("name"),
                        },
                    )

                else:
                    logger.warning(
                        "Unsupported block type %s in the message, skipped.",
                        typ,
                    )

            content_msg = "\n".join(texts)
            reasoning_msg = "\n".join(thoughts)

            msg_deepseek = {
                "role": msg.role,
                "content": content_msg or None,
            }
            if reasoning_msg:
                msg_deepseek["reasoning_content"] = reasoning_msg
            if tool_calls:
                msg_deepseek["tool_calls"] = tool_calls

            # A message needs text content or tool calls to be kept
            if content_msg or tool_calls:
                formatted.append(msg_deepseek)

        return formatted
class DeepSeekMultiAgentFormatter(TruncatedFormatterBase):
    """
    DeepSeek formatter for conversations that involve more participants
    than one user and one agent.
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = True
    """Whether support multi-agent conversations"""

    support_vision: bool = False
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    def __init__(
        self,
        conversation_history_prompt: str = (
            "# Conversation History\n"
            "The content between <history></history> tags contains "
            "your conversation history\n"
        ),
        token_counter: TokenCounterBase | None = None,
        max_tokens: int | None = None,
    ) -> None:
        """Create the DeepSeek multi-agent formatter.

        Args:
            conversation_history_prompt (`str`):
                The text placed in front of the conversation history
                section.
            token_counter (`TokenCounterBase | None`, optional):
                The token counter handed to the base class.
            max_tokens (`int | None`, optional):
                The token limit handed to the base class.
        """
        super().__init__(token_counter=token_counter, max_tokens=max_tokens)
        self.conversation_history_prompt = conversation_history_prompt

    async def _format_tool_sequence(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Format a run of tool call/result messages by delegating to a
        fresh `DeepSeekChatFormatter`.

        Args:
            msgs (`list[Msg]`):
                The list of messages containing tool calls/results to format.

        Returns:
            `list[dict[str, Any]]`:
                A list of dictionaries formatted for the DeepSeek API.
        """
        chat_formatter = DeepSeekChatFormatter()
        return await chat_formatter.format(msgs)

    async def _format_agent_message(
        self,
        msgs: list[Msg],
        is_first: bool = True,
    ) -> list[dict[str, Any]]:
        """Merge a run of messages without tool calls/results into a single
        user message wrapped in ``<history>`` tags.

        Only text blocks contribute; each becomes a ``"<name>: <text>"``
        line.

        Args:
            msgs (`list[Msg]`):
                A list of Msg objects to be formatted.
            is_first (`bool`, defaults to `True`):
                Whether this is the first agent message in the conversation.
                If `True`, the conversation history prompt is prepended.

        Returns:
            `list[dict[str, Any]]`:
                A list holding one user message, or an empty list when the
                messages contain no text block.
        """
        prompt = self.conversation_history_prompt if is_first else ""

        lines = [
            f"{msg.name}: {block['text']}"
            for msg in msgs
            for block in msg.get_content_blocks()
            if block["type"] == "text"
        ]
        if not lines:
            return []

        history = prompt + "<history>\n" + "\n".join(lines) + "\n</history>"
        return [{"role": "user", "content": history}]

View File

@@ -0,0 +1,129 @@
# -*- coding: utf-8 -*-
"""The formatter module."""
from abc import abstractmethod
from typing import Any, List, Tuple, Sequence
from .._utils._common import _save_base64_data
from ..message import Msg, AudioBlock, ImageBlock, TextBlock, VideoBlock
class FormatterBase:
    """The base class for formatters."""

    @abstractmethod
    async def format(self, *args: Any, **kwargs: Any) -> list[dict[str, Any]]:
        """Format the Msg objects to a list of dictionaries that satisfy the
        API requirements."""

    @staticmethod
    def assert_list_of_msgs(msgs: list[Msg]) -> None:
        """Check that the input is a list of Msg objects.

        Args:
            msgs (`list[Msg]`):
                A list of Msg objects to be validated.

        Raises:
            `TypeError`:
                If `msgs` is not a list, or any element is not a `Msg`.
        """
        if not isinstance(msgs, list):
            raise TypeError("Input must be a list of Msg objects.")

        for msg in msgs:
            if not isinstance(msg, Msg):
                raise TypeError(
                    f"Expected Msg object, got {type(msg)} instead.",
                )

    @staticmethod
    def convert_tool_result_to_string(
        output: str | List[TextBlock | ImageBlock | AudioBlock | VideoBlock],
    ) -> tuple[
        str,
        Sequence[
            Tuple[
                str,
                ImageBlock | AudioBlock | TextBlock | VideoBlock,
            ]
        ],
    ]:
        """Turn the tool result list into a textual output to be compatible
        with the LLM API that doesn't support multimodal data in the tool
        result.

        For URL-based media, the URL is included in the returned list. For
        base64-encoded media, the value returned by `_save_base64_data` is
        included instead (presumably the local path the data was saved to).

        Args:
            output (`str | List[TextBlock | ImageBlock | AudioBlock | \
            VideoBlock]`):
                The output of the tool response, including text and multimodal
                data like images and audio.

        Returns:
            `tuple[str, list[Tuple[str, ImageBlock | AudioBlock | \
            VideoBlock | TextBlock]]]`:
                A tuple containing the textual representation of the tool
                result and a list of tuples. The first element of each tuple
                is the local file path or URL of the multimodal data, and the
                second element is the corresponding block. A single textual
                item is returned as is; otherwise the items are joined as a
                ``"- "`` bullet list.

        Raises:
            `ValueError`:
                If a block type or a media source type is not supported.
        """
        if isinstance(output, str):
            return output, []

        textual_output = []
        multimodal_data = []
        for block in output:
            assert isinstance(block, dict) and "type" in block, (
                f"Invalid block: {block}, a TextBlock, ImageBlock, "
                f"AudioBlock, or VideoBlock is expected."
            )

            if block["type"] == "text":
                textual_output.append(block["text"])

            elif block["type"] in ["image", "audio", "video"]:
                assert "source" in block, (
                    f"Invalid {block['type']} block: {block}, 'source' key "
                    "is required."
                )
                source = block["source"]

                if source["type"] == "url":
                    # URL sources are referenced directly
                    textual_output.append(
                        f"The returned {block['type']} can be found "
                        f"at: {source['url']}",
                    )
                    path_multimodal_file = source["url"]

                elif source["type"] == "base64":
                    # Save the data locally and refer to the saved file
                    path_multimodal_file = _save_base64_data(
                        source["media_type"],
                        source["data"],
                    )
                    textual_output.append(
                        f"The returned {block['type']} can be found "
                        f"at: {path_multimodal_file}",
                    )

                else:
                    # Name the actual block type: this branch is reached for
                    # audio and video blocks as well as images.
                    raise ValueError(
                        f"Invalid {block['type']} source: {block['source']}, "
                        "expected 'url' or 'base64'.",
                    )

                multimodal_data.append(
                    (path_multimodal_file, block),
                )

            else:
                raise ValueError(
                    f"Unsupported block type: {block['type']}, "
                    "expected 'text', 'image', 'audio', or 'video'.",
                )

        if len(textual_output) == 1:
            return textual_output[0], multimodal_data

        return "\n".join("- " + item for item in textual_output), multimodal_data

View File

@@ -0,0 +1,507 @@
# -*- coding: utf-8 -*-
# pylint: disable=too-many-branches
"""Google gemini API formatter in agentscope."""
import base64
import os
from typing import Any
from urllib.parse import urlparse
from ._truncated_formatter_base import TruncatedFormatterBase
from .._utils._common import _get_bytes_from_web_url
from ..message import (
Msg,
TextBlock,
ImageBlock,
AudioBlock,
ToolUseBlock,
ToolResultBlock,
VideoBlock,
URLSource,
)
from .._logging import logger
from ..token import TokenCounterBase
def _format_gemini_media_block(
media_block: ImageBlock | AudioBlock | VideoBlock,
) -> dict[str, Any]:
"""Format an image/audio/video block for Gemini API.
Args:
media_block (`ImageBlock | AudioBlock | VideoBlock`):
The media block to format.
Returns:
`dict[str, Any]`:
A dictionary with "inline_data" key in Gemini format.
Raises:
`ValueError`:
If the source type is not supported.
"""
source = media_block["source"]
if source["type"] == "base64":
return {
"inline_data": {
"data": source["data"],
"mime_type": source["media_type"],
},
}
elif source["type"] == "url":
return {
"inline_data": _to_gemini_inline_data(source["url"]),
}
else:
raise ValueError(
f"Unsupported source type: {source['type']}",
)
def _to_gemini_inline_data(url: str) -> dict:
    """Build the Gemini ``inline_data`` payload for a web URL or local file.

    The media kind (image/audio/video) is looked up from the text after the
    last ``.`` in `url` using `GeminiChatFormatter.supported_extensions`,
    and the mime type is reported as ``"<kind>/<extension>"``.

    Args:
        url (`str`):
            A web URL or the path of an existing local file.

    Returns:
        `dict`:
            A dictionary with the "data" and "mime_type" keys.

    Raises:
        `TypeError`:
            If the extension is not one of the supported extensions.
        `ValueError`:
            If `url` is neither an existing local path nor a string with a
            URL scheme.
    """
    parsed_url = urlparse(url)
    extension = url.split(".")[-1].lower()

    # Media kind derived from the extension, None when unsupported
    media_kind = next(
        (
            kind
            for kind, extensions in (
                GeminiChatFormatter.supported_extensions.items()
            )
            if extension in extensions
        ),
        None,
    )

    is_local_file = os.path.exists(url)
    is_web_url = not is_local_file and parsed_url.scheme != ""

    if not is_local_file and not is_web_url:
        raise ValueError(
            f"The URL `{url}` is not a valid image URL or local file.",
        )

    if media_kind is None:
        raise TypeError(
            f"Unsupported file extension: {extension}, expected "
            f"{GeminiChatFormatter.supported_extensions}",
        )

    if is_web_url:
        # NOTE(review): the helper's result is used as the payload directly,
        # so it presumably returns base64 text already -- confirm.
        data = _get_bytes_from_web_url(url)
    else:
        with open(url, "rb") as media_file:
            data = base64.b64encode(media_file.read()).decode("utf-8")

    return {
        "data": data,
        "mime_type": f"{media_kind}/{extension}",
    }
class GeminiChatFormatter(TruncatedFormatterBase):
    """The Gemini formatter class for chatbot scenario, where only a user
    and an agent are involved. We use the `role` field to identify different
    entities in the conversation.
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = False
    """Whether support multi-agent conversations"""

    support_vision: bool = True
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        # Multimodal
        ImageBlock,
        VideoBlock,
        AudioBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    # File extensions accepted per media kind; `_to_gemini_inline_data` uses
    # this mapping to derive the mime type of URL/local-file sources.
    supported_extensions: dict[str, list[str]] = {
        "image": ["png", "jpeg", "webp", "heic", "heif"],
        "video": [
            "mp4",
            "mpeg",
            "mov",
            "avi",
            "x-flv",
            "mpg",
            "webm",
            "wmv",
            "3gpp",
        ],
        "audio": ["mp3", "wav", "aiff", "aac", "ogg", "flac"],
    }

    def __init__(
        self,
        promote_tool_result_images: bool = False,
        token_counter: TokenCounterBase | None = None,
        max_tokens: int | None = None,
    ) -> None:
        """Initialize the Gemini chat formatter.

        Args:
            promote_tool_result_images (`bool`, defaults to `False`):
                Whether to promote images from tool results to user messages.
                When `True`, images found in a tool result are extracted and
                added as a separate user message with explanatory text
                indicating their source.
            token_counter (`TokenCounterBase | None`, optional):
                A token counter instance, passed to the base class.
            max_tokens (`int | None`, optional):
                The maximum number of tokens allowed in the formatted
                messages, passed to the base class.
        """
        super().__init__(token_counter, max_tokens)
        self.promote_tool_result_images = promote_tool_result_images

    async def _format(
        self,
        msgs: list[Msg],
    ) -> list[dict]:
        """Format message objects into Gemini API required format.

        Tool results are emitted as separate ``user`` messages carrying a
        ``function_response`` part, before the entry of the message that
        contains them. When `promote_tool_result_images` is enabled, the
        images of a tool result are queued as an extra user message right
        after the current message.

        Args:
            msgs (`list[Msg]`):
                The list of message objects to format. The list itself is
                not modified.

        Returns:
            `list[dict]`:
                The formatted messages, each with "role" and "parts" keys.
        """
        self.assert_list_of_msgs(msgs)

        # Promoted image messages are queued by inserting them after the
        # current position. Work on a shallow copy so the caller's list is
        # left untouched and repeated calls do not accumulate duplicates.
        queue = list(msgs)

        messages: list = []
        i = 0
        while i < len(queue):
            msg = queue[i]
            parts = []

            for block in msg.get_content_blocks():
                typ = block.get("type")

                if typ == "text":
                    parts.append(
                        {
                            "text": block.get("text"),
                        },
                    )

                elif typ == "tool_use":
                    parts.append(
                        {
                            "function_call": {
                                "id": None,
                                "name": block["name"],
                                "args": block["input"],
                            },
                            "thought_signature": block.get("id", None),
                        },
                    )

                elif typ == "tool_result":
                    (
                        textual_output,
                        multimodal_data,
                    ) = self.convert_tool_result_to_string(block["output"])

                    # First add the tool result message in Gemini API format
                    messages.append(
                        {
                            "role": "user",
                            "parts": [
                                {
                                    "function_response": {
                                        "id": block["id"],
                                        "name": block["name"],
                                        "response": {
                                            "output": textual_output,
                                        },
                                    },
                                },
                            ],
                        },
                    )

                    # Then collect the images to promote, if enabled
                    promoted_blocks: list = []
                    for url, multimodal_block in multimodal_data:
                        if (
                            multimodal_block["type"] == "image"
                            and self.promote_tool_result_images
                        ):
                            promoted_blocks.extend(
                                [
                                    TextBlock(
                                        type="text",
                                        text=f"\n- The image from '{url}': ",
                                    ),
                                    ImageBlock(
                                        type="image",
                                        source=URLSource(
                                            type="url",
                                            url=url,
                                        ),
                                    ),
                                ],
                            )

                    if promoted_blocks:
                        # Queue the promoted blocks as a new user message
                        promoted_blocks = [
                            TextBlock(
                                type="text",
                                text="<system-info>The following are "
                                "the image contents from the tool "
                                f"result of '{block['name']}':",
                            ),
                            *promoted_blocks,
                            TextBlock(
                                type="text",
                                text="</system-info>",
                            ),
                        ]
                        queue.insert(
                            i + 1,
                            Msg(
                                name="user",
                                content=promoted_blocks,
                                role="user",
                            ),
                        )

                elif typ in ["image", "audio", "video"]:
                    parts.append(
                        _format_gemini_media_block(
                            block,  # type: ignore[arg-type]
                        ),
                    )

                else:
                    logger.warning(
                        "Unsupported block type: %s in the message, skipped. ",
                        typ,
                    )

            # Every role other than "assistant" is sent as "user"
            role = "model" if msg.role == "assistant" else "user"

            if parts:
                messages.append(
                    {
                        "role": role,
                        "parts": parts,
                    },
                )

            # Move to next message (including queued messages, which will
            # be processed in subsequent iterations)
            i += 1

        return messages
class GeminiMultiAgentFormatter(TruncatedFormatterBase):
    """The multi-agent formatter for Google Gemini API, where more than a
    user and an agent are involved.

    .. note:: This formatter merges each run of messages without tool
     calls/results into one user message, wrapped in ``<history>`` tags
     and, for the first run, prefixed with the conversation history prompt.

    .. note:: Tool calls/results are formatted as separate messages by
     `GeminiChatFormatter`.

    .. tip:: Telling the assistant's name in the system prompt is very
     important in multi-agent conversations. So that LLM can know who it
     is playing as.
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = True
    """Whether support multi-agent conversations"""

    support_vision: bool = True
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        # Multimodal
        ImageBlock,
        VideoBlock,
        AudioBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    def __init__(
        self,
        conversation_history_prompt: str = (
            "# Conversation History\n"
            "The content between <history></history> tags contains "
            "your conversation history\n"
        ),
        promote_tool_result_images: bool = False,
        token_counter: TokenCounterBase | None = None,
        max_tokens: int | None = None,
    ) -> None:
        """Create the Gemini multi-agent formatter.

        Args:
            conversation_history_prompt (`str`):
                The text placed in front of the conversation history section.
            promote_tool_result_images (`bool`, defaults to `False`):
                Forwarded to the `GeminiChatFormatter` that formats tool
                call/result sequences.
            token_counter (`TokenCounterBase | None`, optional):
                The token counter handed to the base class.
            max_tokens (`int | None`, optional):
                The token limit handed to the base class.
        """
        super().__init__(token_counter=token_counter, max_tokens=max_tokens)
        self.conversation_history_prompt = conversation_history_prompt
        self.promote_tool_result_images = promote_tool_result_images

    async def _format_system_message(
        self,
        msg: Msg,
    ) -> dict[str, Any]:
        """Format the system message as a ``user`` message with a single
        text part holding the text content of `msg`."""
        text_part = {"text": msg.get_text_content()}
        return {"role": "user", "parts": [text_part]}

    async def _format_tool_sequence(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Format a run of tool call/result messages by delegating to a
        `GeminiChatFormatter` that shares this formatter's
        `promote_tool_result_images` setting.

        Args:
            msgs (`list[Msg]`):
                The list of messages containing tool calls/results to format.

        Returns:
            `list[dict[str, Any]]`:
                A list of dictionaries formatted for the Gemini API.
        """
        chat_formatter = GeminiChatFormatter(
            promote_tool_result_images=self.promote_tool_result_images,
        )
        return await chat_formatter.format(msgs)

    async def _format_agent_message(
        self,
        msgs: list[Msg],
        is_first: bool = True,
    ) -> list[dict[str, Any]]:
        """Merge a run of messages without tool calls/results into one user
        message whose parts are wrapped in ``<history>`` tags.

        Consecutive text blocks are gathered into one text part as
        ``"<name>: <text>"`` lines; every image/video/audio block closes the
        pending text part and becomes a part of its own.

        Args:
            msgs (`list[Msg]`):
                A list of Msg objects to be formatted.
            is_first (`bool`, defaults to `True`):
                Whether this is the first agent message in the conversation.
                If `True`, the conversation history prompt is prepended.

        Returns:
            `list[dict[str, Any]]`:
                A list holding one user message, or an empty list when no
                part was produced.
        """
        prompt = self.conversation_history_prompt if is_first else ""

        parts: list = []
        pending_lines = []
        for msg in msgs:
            for block in msg.get_content_blocks():
                block_type = block["type"]

                if block_type == "text":
                    pending_lines.append(f"{msg.name}: {block['text']}")
                    continue

                if block_type not in ["image", "video", "audio"]:
                    continue

                # Close the text gathered so far before the media part
                if pending_lines:
                    parts.append({"text": "\n".join(pending_lines)})
                    pending_lines.clear()

                parts.append(
                    _format_gemini_media_block(
                        block,  # type: ignore[arg-type]
                    ),
                )

        if pending_lines:
            parts.append({"text": "\n".join(pending_lines)})

        if not parts:
            return []

        # Opening tag: merged into a leading text part when there is one
        opening = prompt + "<history>"
        if parts[0].get("text"):
            parts[0]["text"] = opening + parts[0]["text"]
        else:
            parts.insert(0, {"text": opening})

        # Closing tag: merged into a trailing text part when there is one
        if parts[-1].get("text"):
            parts[-1]["text"] += "\n</history>"
        else:
            parts.append({"text": "</history>"})

        return [{"role": "user", "parts": parts}]

View File

@@ -0,0 +1,441 @@
# -*- coding: utf-8 -*-
# pylint: disable=too-many-branches
"""The Ollama formatter module."""
import base64
import os
from typing import Any
from urllib.parse import urlparse
from ._truncated_formatter_base import TruncatedFormatterBase
from .._logging import logger
from .._utils._common import _get_bytes_from_web_url
from ..message import (
Msg,
TextBlock,
ImageBlock,
ToolUseBlock,
ToolResultBlock,
URLSource,
)
from ..token import TokenCounterBase
def _format_ollama_image_block(
image_block: ImageBlock,
) -> str:
"""Format an image block for Ollama API.
Args:
image_block (`ImageBlock`):
The image block to format.
Returns:
`str`:
Base64 encoded image data as a string.
Raises:
`ValueError`:
If the source type is not supported.
"""
source = image_block["source"]
if source["type"] == "url":
return _convert_ollama_image_url_to_base64_data(source["url"])
elif source["type"] == "base64":
return source["data"]
else:
raise ValueError(
f"Unsupported image source type: {source['type']}",
)
def _convert_ollama_image_url_to_base64_data(url: str) -> str:
"""Convert image url to base64."""
parsed_url = urlparse(url)
if not os.path.exists(url) and parsed_url.scheme != "":
# Web url
data = _get_bytes_from_web_url(url)
return data
if os.path.exists(url):
# Local file
with open(url, "rb") as f:
data = base64.b64encode(f.read()).decode("utf-8")
return data
raise ValueError(
f"The URL `{url}` is not a valid image URL or local file.",
)
class OllamaChatFormatter(TruncatedFormatterBase):
    """The Ollama formatter class for chatbot scenario, where only a user
    and an agent are involved. We use the `role` field to identify different
    participants in the conversation.
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = False
    """Whether support multi-agent conversations"""

    support_vision: bool = True
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        # Multimodal
        ImageBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    def __init__(
        self,
        promote_tool_result_images: bool = False,
        token_counter: TokenCounterBase | None = None,
        max_tokens: int | None = None,
    ) -> None:
        """Initialize the Ollama chat formatter.

        Args:
            promote_tool_result_images (`bool`, defaults to `False`):
                Whether to promote images from tool results to user messages.
                When `True`, images found in a tool result are extracted and
                added as a separate user message with explanatory text
                indicating their source.
            token_counter (`TokenCounterBase | None`, optional):
                A token counter instance, passed to the base class.
            max_tokens (`int | None`, optional):
                The maximum number of tokens allowed in the formatted
                messages, passed to the base class.
        """
        super().__init__(token_counter, max_tokens)
        self.promote_tool_result_images = promote_tool_result_images

    async def _format(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Format message objects into Ollama API format.

        Tool results are emitted as separate ``tool`` messages, before the
        entry of the message that contains them. When
        `promote_tool_result_images` is enabled, the images of a tool result
        are queued as an extra user message right after the current message.

        Args:
            msgs (`list[Msg]`):
                The list of message objects to format. The list itself is
                not modified.

        Returns:
            `list[dict[str, Any]]`:
                The formatted messages as a list of dictionaries.
        """
        self.assert_list_of_msgs(msgs)

        # Promoted image messages are queued by inserting them after the
        # current position. Work on a shallow copy so the caller's list is
        # left untouched and repeated calls do not accumulate duplicates.
        queue = list(msgs)

        messages: list = []
        i = 0
        while i < len(queue):
            msg = queue[i]
            content_blocks: list = []
            tool_calls = []
            images = []

            for block in msg.get_content_blocks():
                typ = block.get("type")

                if typ == "text":
                    content_blocks.append({**block})

                elif typ == "tool_use":
                    tool_calls.append(
                        {
                            "id": block.get("id"),
                            "type": "function",
                            "function": {
                                "name": block.get("name"),
                                "arguments": block.get("input", {}),
                            },
                        },
                    )

                elif typ == "tool_result":
                    (
                        textual_output,
                        multimodal_data,
                    ) = self.convert_tool_result_to_string(block["output"])

                    messages.append(
                        {
                            "role": "tool",
                            "tool_call_id": block.get("id"),
                            "content": textual_output,
                            "name": block.get("name"),
                        },
                    )

                    # Then, handle the multimodal data if any
                    promoted_blocks: list = []
                    for url, multimodal_block in multimodal_data:
                        if (
                            multimodal_block["type"] == "image"
                            and self.promote_tool_result_images
                        ):
                            promoted_blocks.extend(
                                [
                                    TextBlock(
                                        type="text",
                                        text=f"\n- The image from '{url}': ",
                                    ),
                                    ImageBlock(
                                        type="image",
                                        source=URLSource(
                                            type="url",
                                            url=url,
                                        ),
                                    ),
                                ],
                            )

                    if promoted_blocks:
                        # Queue the promoted blocks as a new user message
                        promoted_blocks = [
                            TextBlock(
                                type="text",
                                text="<system-info>The following are "
                                "the image contents from the tool "
                                f"result of '{block['name']}':",
                            ),
                            *promoted_blocks,
                            TextBlock(
                                type="text",
                                text="</system-info>",
                            ),
                        ]
                        queue.insert(
                            i + 1,
                            Msg(
                                name="user",
                                content=promoted_blocks,
                                role="user",
                            ),
                        )

                elif typ == "image":
                    images.append(
                        _format_ollama_image_block(
                            block,  # type: ignore[arg-type]
                        ),
                    )

                else:
                    logger.warning(
                        "Unsupported block type %s in the message, skipped.",
                        typ,
                    )

            content_msg = "\n".join(
                content.get("text", "") for content in content_blocks
            )
            msg_ollama = {
                "role": msg.role,
                "content": content_msg or None,
            }
            if tool_calls:
                msg_ollama["tool_calls"] = tool_calls
            if images:
                msg_ollama["images"] = images

            # Skip messages that carry no text, no image and no tool call
            if content_msg or images or tool_calls:
                messages.append(msg_ollama)

            # Move to next message (including queued messages)
            i += 1

        return messages
class OllamaMultiAgentFormatter(TruncatedFormatterBase):
    """
    Ollama formatter for conversations that involve more participants than
    one user and one agent.
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = True
    """Whether support multi-agent conversations"""

    support_vision: bool = True
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        # Multimodal
        ImageBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    def __init__(
        self,
        conversation_history_prompt: str = (
            "# Conversation History\n"
            "The content between <history></history> tags contains "
            "your conversation history\n"
        ),
        promote_tool_result_images: bool = False,
        token_counter: TokenCounterBase | None = None,
        max_tokens: int | None = None,
    ) -> None:
        """Create the Ollama multi-agent formatter.

        Args:
            conversation_history_prompt (`str`):
                The text placed in front of the conversation history section.
            promote_tool_result_images (`bool`, defaults to `False`):
                Forwarded to the `OllamaChatFormatter` that formats tool
                call/result sequences.
            token_counter (`TokenCounterBase | None`, optional):
                The token counter handed to the base class.
            max_tokens (`int | None`, optional):
                The token limit handed to the base class.
        """
        super().__init__(token_counter=token_counter, max_tokens=max_tokens)
        self.conversation_history_prompt = conversation_history_prompt
        self.promote_tool_result_images = promote_tool_result_images

    async def _format_system_message(
        self,
        msg: Msg,
    ) -> dict[str, Any]:
        """Format the system message as a ``system`` entry holding the text
        content of `msg`."""
        system_text = msg.get_text_content()
        return {"role": "system", "content": system_text}

    async def _format_tool_sequence(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Format a run of tool call/result messages by delegating to an
        `OllamaChatFormatter` that shares this formatter's
        `promote_tool_result_images` setting.

        Args:
            msgs (`list[Msg]`):
                The list of messages containing tool calls/results to format.

        Returns:
            `list[dict[str, Any]]`:
                A list of dictionaries formatted for the Ollama API.
        """
        chat_formatter = OllamaChatFormatter(
            promote_tool_result_images=self.promote_tool_result_images,
        )
        return await chat_formatter.format(msgs)

    async def _format_agent_message(
        self,
        msgs: list[Msg],
        is_first: bool = True,
    ) -> list[dict[str, Any]]:
        """Merge a run of messages without tool calls/results into one user
        message whose content is wrapped in ``<history>`` tags.

        Text blocks become ``"<name>: <text>"`` lines. Image blocks are
        collected into the "images" field; each of them also leaves an empty
        line in the joined content at the position where it occurred.

        Args:
            msgs (`list[Msg]`):
                A list of Msg objects to be formatted.
            is_first (`bool`, defaults to `True`):
                Whether this is the first agent message in the conversation.
                If `True`, the conversation history prompt is prepended.

        Returns:
            `list[dict[str, Any]]`:
                A list holding one user message, or an empty list when the
                messages contain neither text nor image blocks.
        """
        prompt = self.conversation_history_prompt if is_first else ""

        segments: list = []
        pending_lines = []
        images = []
        for msg in msgs:
            for block in msg.get_content_blocks():
                block_type = block["type"]

                if block_type == "text":
                    pending_lines.append(f"{msg.name}: {block['text']}")

                elif block_type == "image":
                    # Close the text gathered so far before the image
                    if pending_lines:
                        segments.append({"text": "\n".join(pending_lines)})
                        pending_lines.clear()

                    images.append(_format_ollama_image_block(block))
                    segments.append({**block})

        if pending_lines:
            segments.append({"text": "\n".join(pending_lines)})

        if not segments:
            return []

        # Opening tag: merged into a leading text segment when there is one
        opening = prompt + "<history>\n"
        if segments[0].get("text"):
            segments[0]["text"] = opening + segments[0]["text"]
        else:
            segments.insert(0, {"text": opening})

        # Closing tag: merged into a trailing text segment when there is one
        if segments[-1].get("text"):
            segments[-1]["text"] += "\n</history>"
        else:
            segments.append({"text": "</history>"})

        # Segments without a "text" key (the images) contribute ""
        content = "\n".join(segment.get("text", "") for segment in segments)

        user_message = {"role": "user", "content": content}
        if images:
            user_message["images"] = images

        return [user_message]

View File

@@ -0,0 +1,530 @@
# -*- coding: utf-8 -*-
# pylint: disable=too-many-branches, too-many-nested-blocks
"""The OpenAI formatter for agentscope."""
import base64
import json
import os
from typing import Any
from urllib.parse import urlparse
import requests
from ._truncated_formatter_base import TruncatedFormatterBase
from .._logging import logger
from ..message import (
Msg,
URLSource,
TextBlock,
ImageBlock,
AudioBlock,
Base64Source,
ToolUseBlock,
ToolResultBlock,
)
from ..token import TokenCounterBase
def _format_openai_image_block(
image_block: ImageBlock,
) -> dict[str, Any]:
"""Format an image block for OpenAI API.
Args:
image_block (`ImageBlock`):
The image block to format.
Returns:
`dict[str, Any]`:
A dictionary with "type" and "image_url" keys in OpenAI format.
Raises:
`ValueError`:
If the source type is not supported.
"""
source = image_block["source"]
if source["type"] == "url":
url = _to_openai_image_url(source["url"])
elif source["type"] == "base64":
data = source["data"]
media_type = source["media_type"]
url = f"data:{media_type};base64,{data}"
else:
raise ValueError(
f"Unsupported image source type: {source['type']}",
)
return {
"type": "image_url",
"image_url": {
"url": url,
},
}
def _to_openai_image_url(url: str) -> str:
"""Convert an image url to openai format. If the given url is a local
file, it will be converted to base64 format. Otherwise, it will be
returned directly.
Args:
url (`str`):
The local or public url of the image.
"""
# See https://platform.openai.com/docs/guides/vision for details of
# support image extensions.
support_image_extensions = (
".png",
".jpg",
".jpeg",
".gif",
".webp",
)
parsed_url = urlparse(url)
lower_url = url.lower()
# Web url
if not os.path.exists(url) and parsed_url.scheme != "":
path_lower = parsed_url.path if parsed_url.path else parsed_url.netloc
if any(path_lower.endswith(_) for _ in support_image_extensions):
return url
# Check if it is a local file
elif os.path.exists(url) and os.path.isfile(url):
if any(lower_url.endswith(_) for _ in support_image_extensions):
with open(url, "rb") as image_file:
base64_image = base64.b64encode(image_file.read()).decode(
"utf-8",
)
extension = parsed_url.path.lower().split(".")[-1]
mime_type = f"image/{extension}"
return f"data:{mime_type};base64,{base64_image}"
raise TypeError(f'"{url}" should end with {support_image_extensions}.')
def _to_openai_audio_data(source: URLSource | Base64Source) -> dict:
    """Convert an audio source to the OpenAI ``input_audio`` payload.

    Args:
        source (`URLSource | Base64Source`):
            The audio source. A ``"url"`` source may point to a local file
            or to a web URL, which is downloaded. A ``"base64"`` source
            carries the encoded data directly.

    Returns:
        `dict`:
            A dictionary with the base64 encoded audio under ``"data"`` and
            the audio format (``"wav"`` or ``"mp3"``) under ``"format"``.

    Raises:
        `TypeError`:
            If the extension / media type is not wav or mp3, or the source
            type is neither ``"url"`` nor ``"base64"``.
        `ValueError`:
            If a ``"url"`` source is neither an existing local file nor a
            url with a scheme.
    """
    supported_formats = ("wav", "mp3")
    # Seconds passed to `requests.get` as its timeout; without one the
    # request can block forever on an unresponsive server.
    download_timeout = 30

    if source["type"] == "url":
        url = source["url"]
        parsed_url = urlparse(url)
        is_local = os.path.exists(url)
        is_web = not is_local and parsed_url.scheme != ""

        extension = url.split(".")[-1].lower()
        if extension not in supported_formats and is_web:
            # A query string or fragment may hide the real extension, e.g.
            # "https://host/clip.wav?token=x"; fall back to the path only.
            path_extension = parsed_url.path.split(".")[-1].lower()
            if path_extension in supported_formats:
                extension = path_extension

        if extension not in supported_formats:
            raise TypeError(
                f"Unsupported audio file extension: {extension}, "
                "wav and mp3 are supported.",
            )

        if is_local:
            with open(url, "rb") as audio_file:
                data = base64.b64encode(audio_file.read()).decode("utf-8")

        # web url
        elif is_web:
            response = requests.get(url, timeout=download_timeout)
            response.raise_for_status()
            data = base64.b64encode(response.content).decode("utf-8")

        else:
            raise ValueError(
                f"Unsupported audio source: {source['url']}, "
                "it should be a local file or a web URL.",
            )

        return {
            "data": data,
            "format": extension,
        }

    if source["type"] == "base64":
        data = source["data"]
        media_type = source["media_type"]

        if media_type not in ["audio/wav", "audio/mp3"]:
            raise TypeError(
                f"Unsupported audio media type: {media_type}, "
                "only audio/wav and audio/mp3 are supported.",
            )

        return {
            "data": data,
            "format": media_type.split("/")[-1],
        }

    raise TypeError(f"Unsupported audio source: {source['type']}.")
class OpenAIChatFormatter(TruncatedFormatterBase):
    """The OpenAI formatter class for chatbot scenario, where only a user
    and an agent are involved. We use the `name` field in OpenAI API to
    identify different entities in the conversation.
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = True
    """Whether support multi-agent conversation"""

    support_vision: bool = True
    """Whether support vision models"""

    supported_blocks: list[type] = [
        TextBlock,
        ImageBlock,
        AudioBlock,
        ToolUseBlock,
        ToolResultBlock,
    ]
    """Supported message blocks for OpenAI API"""

    def __init__(
        self,
        promote_tool_result_images: bool = False,
        token_counter: TokenCounterBase | None = None,
        max_tokens: int | None = None,
    ) -> None:
        """Initialize the OpenAI chat formatter.

        Args:
            promote_tool_result_images (`bool`, defaults to `False`):
                Whether to promote images from tool results to user messages.
                Most LLM APIs don't support images in tool result blocks, but
                do support them in user message blocks. When `True`, images
                are extracted and appended as a separate user message with
                explanatory text indicating their source.
            token_counter (`TokenCounterBase | None`, optional):
                A token counter instance used to count tokens in the messages.
                If not provided, the formatter will format the messages
                without considering token limits.
            max_tokens (`int | None`, optional):
                The maximum number of tokens allowed in the formatted
                messages. If not provided, the formatter will not truncate
                the messages.
        """
        super().__init__(token_counter=token_counter, max_tokens=max_tokens)
        self.promote_tool_result_images = promote_tool_result_images

    async def _format(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Format message objects into OpenAI API required format.

        The input list is never modified: promoted image messages are queued
        on an internal copy only.

        Args:
            msgs (`list[Msg]`):
                The list of Msg objects to format.

        Returns:
            `list[dict[str, Any]]`:
                A list of dictionaries. Regular messages carry "role",
                "name" and "content" (plus "tool_calls" when the message
                contains tool-use blocks). Each tool result becomes a
                separate message with role "tool" and a "tool_call_id".
        """
        self.assert_list_of_msgs(msgs)

        # Promoted image messages are spliced into the queue while we walk
        # it. Work on a shallow copy so the caller's list stays untouched;
        # otherwise `format` would see the promoted messages again after
        # truncation and every re-format pass would duplicate them.
        msgs = list(msgs)

        messages: list[dict] = []
        i = 0
        while i < len(msgs):
            msg = msgs[i]
            content_blocks = []
            tool_calls = []
            # Position at which the next promoted message is queued. It
            # advances after each insertion so that several tool results in
            # one message keep their original order.
            insert_at = i + 1

            for block in msg.get_content_blocks():
                typ = block.get("type")
                if typ == "text":
                    content_blocks.append({**block})

                elif typ == "tool_use":
                    tool_calls.append(
                        {
                            "id": block.get("id"),
                            "type": "function",
                            "function": {
                                "name": block.get("name"),
                                "arguments": json.dumps(
                                    block.get("input", {}),
                                    ensure_ascii=False,
                                ),
                            },
                        },
                    )

                elif typ == "tool_result":
                    (
                        textual_output,
                        multimodal_data,
                    ) = self.convert_tool_result_to_string(block["output"])

                    # The tool message is emitted immediately, i.e. before
                    # the message assembled from the other blocks of `msg`.
                    messages.append(
                        {
                            "role": "tool",
                            "tool_call_id": block.get("id"),
                            "content": (  # type: ignore[arg-type]
                                textual_output
                            ),
                            "name": block.get("name"),
                        },
                    )

                    # Then, handle the multimodal data if any
                    promoted_blocks: list = []
                    for url, multimodal_block in multimodal_data:
                        if (
                            multimodal_block["type"] == "image"
                            and self.promote_tool_result_images
                        ):
                            promoted_blocks.extend(
                                [
                                    TextBlock(
                                        type="text",
                                        text=f"\n- The image from '{url}': ",
                                    ),
                                    ImageBlock(
                                        type="image",
                                        source=URLSource(
                                            type="url",
                                            url=url,
                                        ),
                                    ),
                                ],
                            )

                    if promoted_blocks:
                        # Queue the promoted blocks as a new user message
                        # that is formatted in a following iteration.
                        promoted_blocks = [
                            TextBlock(
                                type="text",
                                text="<system-info>The following are "
                                "the image contents from the tool "
                                f"result of '{block['name']}':",
                            ),
                            *promoted_blocks,
                            TextBlock(
                                type="text",
                                text="</system-info>",
                            ),
                        ]

                        msgs.insert(
                            insert_at,
                            Msg(
                                name="user",
                                content=promoted_blocks,
                                role="user",
                            ),
                        )
                        insert_at += 1

                elif typ == "image":
                    content_blocks.append(
                        _format_openai_image_block(
                            block,  # type: ignore[arg-type]
                        ),
                    )

                elif typ == "audio":
                    # Filter out audio content when the multimodal model
                    # outputs both text and audio, to prevent errors in
                    # subsequent model calls
                    if msg.role == "assistant":
                        continue
                    input_audio = _to_openai_audio_data(block["source"])
                    content_blocks.append(
                        {
                            "type": "input_audio",
                            "input_audio": input_audio,
                        },
                    )

                else:
                    logger.warning(
                        "Unsupported block type %s in the message, skipped.",
                        typ,
                    )

            msg_openai = {
                "role": msg.role,
                "name": msg.name,
                "content": content_blocks or None,
            }

            if tool_calls:
                msg_openai["tool_calls"] = tool_calls

            # When both content and tool_calls are None, skipped
            if msg_openai["content"] or msg_openai.get("tool_calls"):
                messages.append(msg_openai)

            # Move to next message
            i += 1

        return messages
class OpenAIMultiAgentFormatter(TruncatedFormatterBase):
    """
    OpenAI formatter for multi-agent conversations, where more than
    a user and an agent are involved.

    .. tip:: This formatter is compatible with OpenAI API and
     OpenAI-compatible services like vLLM, Azure OpenAI, and others.
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = True
    """Whether support multi-agent conversation"""

    support_vision: bool = True
    """Whether support vision models"""

    supported_blocks: list[type] = [
        TextBlock,
        ImageBlock,
        AudioBlock,
        ToolUseBlock,
        ToolResultBlock,
    ]
    """Supported message blocks for OpenAI API"""

    def __init__(
        self,
        conversation_history_prompt: str = (
            "# Conversation History\n"
            "The content between <history></history> tags contains "
            "your conversation history\n"
        ),
        promote_tool_result_images: bool = False,
        token_counter: TokenCounterBase | None = None,
        max_tokens: int | None = None,
    ) -> None:
        """Create the multi-agent formatter for the OpenAI API.

        Args:
            conversation_history_prompt (`str`):
                The text placed in front of the ``<history>`` section of
                the first agent-message group.
            promote_tool_result_images (`bool`, defaults to `False`):
                Forwarded to the `OpenAIChatFormatter` that formats tool
                sequences. When `True`, images found in tool results are
                emitted as a separate user message with explanatory text.
            token_counter (`TokenCounterBase | None`, optional):
                The counter used to measure the formatted messages. Without
                it, token limits are not considered.
            max_tokens (`int | None`, optional):
                The token budget for the formatted messages. Without it, no
                truncation takes place.
        """
        super().__init__(token_counter=token_counter, max_tokens=max_tokens)
        self.conversation_history_prompt = conversation_history_prompt
        self.promote_tool_result_images = promote_tool_result_images

    async def _format_tool_sequence(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Format a run of tool call/result messages by delegating to a
        fresh `OpenAIChatFormatter` with the same image-promotion setting."""
        delegate = OpenAIChatFormatter(
            promote_tool_result_images=self.promote_tool_result_images,
        )
        return await delegate.format(msgs)

    async def _format_agent_message(
        self,
        msgs: list[Msg],
        is_first: bool = True,
    ) -> list[dict[str, Any]]:
        """Merge a run of messages without tool calls/results into a single
        OpenAI user message.

        Text blocks are collected as ``"<name>: <text>"`` lines and wrapped
        in ``<history>`` tags; images and audios follow as separate content
        parts.

        Args:
            msgs (`list[Msg]`):
                The messages of the group.
            is_first (`bool`, defaults to `True`):
                Whether this is the first agent-message group. Only then is
                the conversation history prompt prepended.

        Returns:
            `list[dict[str, Any]]`:
                A list with one user message, or an empty list when the
                group produced no content.
        """
        prompt = self.conversation_history_prompt if is_first else ""

        speaker_lines: list = []
        image_parts: list = []
        audio_parts: list = []

        for msg in msgs:
            for block in msg.get_content_blocks():
                kind = block["type"]
                if kind == "text":
                    speaker_lines.append(f"{msg.name}: {block['text']}")
                elif kind == "image":
                    image_parts.append(_format_openai_image_block(block))
                elif kind == "audio" and msg.role != "assistant":
                    # Audio produced by the assistant is dropped: when the
                    # multimodal model outputs both text and audio, sending
                    # the audio back causes errors in subsequent model calls
                    audio_payload = _to_openai_audio_data(block["source"])
                    audio_parts.append(
                        {
                            "type": "input_audio",
                            "input_audio": audio_payload,
                        },
                    )

        content_list: list[dict[str, Any]] = []

        # The history section exists only when there is at least one line
        # of text; it always precedes the images and audios.
        if speaker_lines:
            history_text = (
                prompt
                + "<history>\n"
                + "\n".join(speaker_lines)
                + "\n</history>"
            )
            content_list.append(
                {
                    "type": "text",
                    "text": history_text,
                },
            )

        content_list.extend(image_parts)
        content_list.extend(audio_parts)

        if not content_list:
            return []

        return [
            {
                "role": "user",
                "content": content_list,
            },
        ]

View File

@@ -0,0 +1,297 @@
# -*- coding: utf-8 -*-
"""The truncated formatter base class, which allows to truncate the input
messages."""
from abc import ABC
from copy import deepcopy
from typing import (
Any,
Tuple,
Literal,
AsyncGenerator,
)
from ._formatter_base import FormatterBase
from ..message import Msg
from ..token import TokenCounterBase
from ..tracing import trace_format
class TruncatedFormatterBase(FormatterBase, ABC):
"""Base class for truncated formatters, which formats input messages into
required formats with tokens under a specified limit."""
def __init__(
self,
token_counter: TokenCounterBase | None = None,
max_tokens: int | None = None,
) -> None:
"""Initialize the TruncatedFormatterBase.
Args:
token_counter (`TokenCounterBase | None`, optional):
A token counter instance used to count tokens in the messages.
If not provided, the formatter will format the messages
without considering token limits.
max_tokens (`int | None`, optional):
The maximum number of tokens allowed in the formatted
messages. If not provided, the formatter will not truncate
the messages.
"""
self.token_counter = token_counter
assert (
max_tokens is None or 0 < max_tokens
), "max_tokens must be greater than 0"
self.max_tokens = max_tokens
@trace_format
async def format(
self,
msgs: list[Msg],
**kwargs: Any,
) -> list[dict[str, Any]]:
"""Format the input messages into the required format. If token
counter and max token limit are provided, the messages will be
truncated to fit the limit.
Args:
msgs (`list[Msg]`):
The input messages to be formatted.
Returns:
`list[dict[str, Any]]`:
The formatted messages in the required format.
"""
# Check if the input messages are valid
self.assert_list_of_msgs(msgs)
msgs = deepcopy(msgs)
while True:
formatted_msgs = await self._format(msgs)
n_tokens = await self._count(formatted_msgs)
if (
n_tokens is None
or self.max_tokens is None
or n_tokens <= self.max_tokens
):
return formatted_msgs
# truncate the input messages
msgs = await self._truncate(msgs)
async def _format(self, msgs: list[Msg]) -> list[dict[str, Any]]:
"""Format the input messages into the required format. This method
should be implemented by the subclasses."""
formatted_msgs = []
start_index = 0
if len(msgs) > 0 and msgs[0].role == "system":
formatted_msgs.append(
await self._format_system_message(msgs[0]),
)
start_index = 1
is_first_agent_message = True
async for typ, group in self._group_messages(msgs[start_index:]):
match typ:
case "tool_sequence":
formatted_msgs.extend(
await self._format_tool_sequence(group),
)
case "agent_message":
formatted_msgs.extend(
await self._format_agent_message(
group,
is_first_agent_message,
),
)
is_first_agent_message = False
return formatted_msgs
async def _format_system_message(
self,
msg: Msg,
) -> dict[str, Any]:
"""Format system message for the LLM API.
.. note:: This is the default implementation. For certain LLM APIs
with specific requirements, you may need to implement a custom
formatting function to accommodate those particular needs.
"""
return {
"role": "system",
"content": msg.get_content_blocks("text"),
}
async def _format_tool_sequence(
self,
msgs: list[Msg],
) -> list[dict[str, Any]]:
"""Given a sequence of tool call/result messages, format them into
the required format for the LLM API."""
raise NotImplementedError(
"_format_tool_sequence is not implemented",
)
async def _format_agent_message(
self,
msgs: list[Msg],
is_first: bool = True,
) -> list[dict[str, Any]]:
"""Given a sequence of messages without tool calls/results, format
them into the required format for the LLM API."""
raise NotImplementedError(
"_format_agent_message is not implemented",
)
async def _truncate(self, msgs: list[Msg]) -> list[Msg]:
"""Truncate the input messages, so that it can fit the token limit.
This function is called only when
- both `token_counter` and `max_tokens` are provided,
- the formatted output of the input messages exceeds the token limit.
.. tip:: This function only provides a simple strategy, and developers
can override this method to implement more sophisticated
truncation strategies.
.. note:: The tool call message should be truncated together with
its corresponding tool result message to satisfy the LLM API
requirements.
Args:
msgs (`list[Msg]`):
The input messages to be truncated.
Raises:
`ValueError`:
If the system prompt message already exceeds the token limit,
or if there are tool calls without corresponding tool results.
Returns:
`list[Msg]`:
The truncated messages.
"""
start_index = 0
if len(msgs) > 0 and msgs[0].role == "system":
if len(msgs) == 1:
# If the system prompt already exceeds the token limit, we
# raise an error.
raise ValueError(
f"The system prompt message already exceeds the token "
f"limit ({self.max_tokens} tokens).",
)
start_index = 1
# Create a tool call IDs queues to delete the corresponding tool
# result message
tool_call_ids = set()
for i in range(start_index, len(msgs)):
msg = msgs[i]
for block in msg.get_content_blocks("tool_use"):
tool_call_ids.add(block["id"])
for block in msg.get_content_blocks("tool_result"):
try:
tool_call_ids.remove(block["id"])
except KeyError:
pass
# We can stop truncating if the queue is empty
if len(tool_call_ids) == 0:
return msgs[:start_index] + msgs[i + 1 :]
if len(tool_call_ids) > 0:
raise ValueError(
"The input messages contains tool call(s) that do not have "
f"the corresponding tool result(s): {tool_call_ids}. ",
)
return msgs[:start_index]
async def _count(self, msgs: list[dict[str, Any]]) -> int | None:
"""Count the number of tokens in the input messages. If token counter
is not provided, `None` will be returned.
Args:
msgs (`list[Msg]`):
The input messages to count tokens for.
"""
if self.token_counter is None:
return None
return await self.token_counter.count(msgs)
@staticmethod
async def _group_messages(
msgs: list[Msg],
) -> AsyncGenerator[
Tuple[Literal["tool_sequence", "agent_message"], list[Msg]],
None,
]:
"""Group the input messages into two types and yield them as a
generator. The two types are:
- agent message that doesn't contain tool calls/results, and
- tool sequence that consisted of a sequence of tool calls/results
.. note:: The group operation is used in multi-agent scenario, where
multiple entities are involved in the input messages. So that to be
compatible with tools API, we have to group the messages and format
them with different strategies.
Args:
msgs (`list[Msg]`):
The input messages to be grouped, where the system prompt
message shouldn't be included.
Yields:
`AsyncGenerator[Tuple[str, list[Msg]], None]`:
A generator that yields tuples of group type and the list of
messages in that group. The group type can be either
"tool_sequence" or "agent_message".
"""
group_type: Literal["tool_sequence", "agent_message"] | None = None
group = []
for msg in msgs:
if group_type is None:
if msg.has_content_blocks(
"tool_use",
) or msg.has_content_blocks("tool_result"):
group_type = "tool_sequence"
else:
group_type = "agent_message"
group.append(msg)
continue
# determine if this msg has the same type as the current group
if group_type == "tool_sequence":
if msg.has_content_blocks(
"tool_use",
) or msg.has_content_blocks("tool_result"):
group.append(msg)
else:
yield group_type, group
group = [msg]
group_type = "agent_message"
elif group_type == "agent_message":
if msg.has_content_blocks(
"tool_use",
) or msg.has_content_blocks("tool_result"):
yield group_type, group
group = [msg]
group_type = "tool_sequence"
else:
group.append(msg)
if group_type:
yield group_type, group