chore: initialize sandbox and overwrite remote content
Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled
Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled
This commit is contained in:
48
src/agentscope/formatter/__init__.py
Normal file
48
src/agentscope/formatter/__init__.py
Normal file
@@ -0,0 +1,48 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The formatter module in agentscope."""
|
||||
|
||||
from ._formatter_base import FormatterBase
|
||||
from ._truncated_formatter_base import TruncatedFormatterBase
|
||||
from ._dashscope_formatter import (
|
||||
DashScopeChatFormatter,
|
||||
DashScopeMultiAgentFormatter,
|
||||
)
|
||||
from ._anthropic_formatter import (
|
||||
AnthropicChatFormatter,
|
||||
AnthropicMultiAgentFormatter,
|
||||
)
|
||||
from ._openai_formatter import (
|
||||
OpenAIChatFormatter,
|
||||
OpenAIMultiAgentFormatter,
|
||||
)
|
||||
from ._gemini_formatter import (
|
||||
GeminiChatFormatter,
|
||||
GeminiMultiAgentFormatter,
|
||||
)
|
||||
from ._ollama_formatter import (
|
||||
OllamaChatFormatter,
|
||||
OllamaMultiAgentFormatter,
|
||||
)
|
||||
from ._deepseek_formatter import (
|
||||
DeepSeekChatFormatter,
|
||||
DeepSeekMultiAgentFormatter,
|
||||
)
|
||||
from ._a2a_formatter import A2AChatFormatter
|
||||
|
||||
# Public API of the formatter package, grouped by backend. The order of the
# entries is kept exactly as it was.
__all__ = [
    # Abstract bases
    "FormatterBase",
    "TruncatedFormatterBase",
    # DashScope
    "DashScopeChatFormatter",
    "DashScopeMultiAgentFormatter",
    # OpenAI
    "OpenAIChatFormatter",
    "OpenAIMultiAgentFormatter",
    # Anthropic
    "AnthropicChatFormatter",
    "AnthropicMultiAgentFormatter",
    # Gemini
    "GeminiChatFormatter",
    "GeminiMultiAgentFormatter",
    # Ollama
    "OllamaChatFormatter",
    "OllamaMultiAgentFormatter",
    # DeepSeek
    "DeepSeekChatFormatter",
    "DeepSeekMultiAgentFormatter",
    # A2A
    "A2AChatFormatter",
]
|
||||
364
src/agentscope/formatter/_a2a_formatter.py
Normal file
364
src/agentscope/formatter/_a2a_formatter.py
Normal file
@@ -0,0 +1,364 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The A2A message formatter class."""
|
||||
import mimetypes
|
||||
import uuid
|
||||
from typing import Literal, TYPE_CHECKING
|
||||
|
||||
|
||||
from .._logging import logger
|
||||
from ._formatter_base import FormatterBase
|
||||
from ..message import (
|
||||
Msg,
|
||||
TextBlock,
|
||||
URLSource,
|
||||
Base64Source,
|
||||
ContentBlock,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import (
|
||||
Message,
|
||||
Task,
|
||||
Part,
|
||||
)
|
||||
else:
|
||||
Message = "a2a.types.Message"
|
||||
Task = "a2a.types.Task"
|
||||
Part = "a2a.types.Part"
|
||||
|
||||
|
||||
class A2AChatFormatter(FormatterBase):
    """A2A message formatter class, which converts AgentScope messages into
    the A2A message format, and A2A messages/tasks back into AgentScope
    messages."""

    async def format(self, msgs: list[Msg]) -> Message:
        """Convert AgentScope messages into a single A2A message object.

        .. note:: The A2A protocol receives a single message per request,
         so multi-message inputs will be merged into one A2A Message with
         role 'user'.

        Text and thinking blocks are sent as ``TextPart``, image/video/audio
        blocks as ``FilePart`` (URL or base64 source), and tool use / tool
        result blocks are passed through unchanged as ``DataPart``.

        Args:
            msgs (`list[Msg]`):
                List of AgentScope Msg objects to be converted.

        Returns:
            `Message`:
                The converted A2A Message object.

        Raises:
            `ValueError`:
                If a media block carries a source whose type is neither
                "url" nor "base64".
        """
        # Imported inside the method rather than at module import time.
        from a2a.types import (
            Part,
            TextPart,
            FilePart,
            FileWithUri,
            FileWithBytes,
            DataPart,
            Role,
            Message,
        )

        self.assert_list_of_msgs(msgs)

        parts = []
        for msg in msgs:
            for block in msg.get_content_blocks():
                block_type = block.get("type")

                if block_type in ("text", "thinking"):
                    # The payload key has the same name as the block type.
                    # Empty blocks carry no information, so they are dropped
                    # instead of being reported as "unsupported".
                    text = block.get(block_type)
                    if text:
                        parts.append(Part(root=TextPart(text=text)))
                    continue

                if block_type in ("image", "video", "audio") and block.get(
                    "source",
                ):
                    source = block.get("source", {})
                    source_type = source.get("type")

                    if source_type == "url":
                        file = FileWithUri(uri=source.get("url"))
                    elif source_type == "base64":
                        file = FileWithBytes(
                            bytes=source.get("data"),
                            mime_type=source.get("media_type"),
                        )
                    else:
                        raise ValueError(
                            f"Unsupported source type: {source_type}",
                        )

                    parts.append(Part(root=FilePart(file=file)))

                elif block_type in ("tool_use", "tool_result"):
                    # Passed through as-is so that the receiving side can
                    # recognise the block by its keys.
                    parts.append(Part(root=DataPart(data=block)))

                else:
                    logger.error(
                        "Unsupported block type %s in A2AFormatter.",
                        block_type,
                    )

        return Message(
            message_id=str(uuid.uuid4()),
            role=Role.user,
            parts=parts,
        )

    async def format_a2a_message(self, name: str, message: Message) -> Msg:
        """Convert an A2A Message object back to an AgentScope Msg.

        Args:
            name (`str`):
                The name of the message sender.
            message (`Message`):
                The A2A Message object to be converted.

        Returns:
            `Msg`:
                The converted AgentScope Msg object. A2A role ``user`` maps
                to "user" and ``agent`` maps to "assistant".

        Raises:
            `ValueError`:
                If the A2A message has a role other than user/agent, or if
                one of its parts cannot be converted.
        """
        from a2a.types import Role

        # Parts are converted first, so a malformed part is reported before
        # an unsupported role.
        content = [
            await self._format_a2a_part(part) for part in message.parts
        ]

        if message.role == Role.user:
            role: Literal["user", "assistant"] = "user"
        elif message.role == Role.agent:
            role = "assistant"
        else:
            raise ValueError(
                f"Unsupported role: {message.role} in A2A message.",
            )

        return Msg(
            name=name,
            role=role,
            content=content,
            metadata=None,
        )

    @staticmethod
    def _guess_type(
        uri: str | None = None,
        mime_type: str | None = None,
    ) -> Literal["image", "video", "audio", "unknown"]:
        """Guess the content type from the uri or mime type.

        An explicit ``mime_type`` takes precedence; the uri is only used to
        guess a mime type when none is given.

        Args:
            uri (`str | None`, optional):
                The uri of the content.
            mime_type (`str | None`, optional):
                The mime type of the content.

        Returns:
            `Literal["image", "video", "audio", "unknown"]`:
                The guessed content type, "unknown" when the mime type is
                missing or belongs to none of the three media families.

        Raises:
            `ValueError`:
                If both ``uri`` and ``mime_type`` are ``None``.
        """
        if mime_type is None and uri is None:
            raise ValueError(
                "Either uri or mime_type must be provided to guess the"
                " content type.",
            )

        if mime_type is None:
            mime_type, _encoding = mimetypes.guess_type(uri or "")

        if isinstance(mime_type, str):
            if mime_type.startswith("image/"):
                return "image"
            if mime_type.startswith("video/"):
                return "video"
            if mime_type.startswith("audio/"):
                return "audio"

        return "unknown"

    async def format_a2a_task(self, name: str, task: Task) -> list[Msg]:
        """Convert an A2A Task object back to AgentScope Msg objects.

        The status message (if any) comes first. Artifact parts are appended
        to the last message when it is an assistant message, otherwise a new
        assistant message is created for them.

        Args:
            name (`str`):
                The name of the message sender.
            task (`Task`):
                The A2A Task object to be converted.

        Returns:
            `list[Msg]`:
                Converted AgentScope Msg objects (possibly empty).
        """
        merged_msgs: list[Msg] = []
        if task.status and task.status.message:
            merged_msgs.append(
                await self.format_a2a_message(name, task.status.message),
            )

        for artifact in task.artifacts or []:
            artifact_content = [
                await self._format_a2a_part(part) for part in artifact.parts
            ]

            if merged_msgs and merged_msgs[-1].role == "assistant":
                merged_msgs[-1].content.extend(artifact_content)
                # NOTE(review): with several artifacts the metadata of the
                # last one wins (even if it is None) - confirm intended.
                merged_msgs[-1].metadata = artifact.metadata
            else:
                merged_msgs.append(
                    Msg(
                        name=name,
                        role="assistant",
                        content=artifact_content,
                        metadata=artifact.metadata,
                    ),
                )

        return merged_msgs

    async def _format_a2a_part(self, part: Part) -> ContentBlock:
        """Convert a single A2A Part object into an AgentScope ContentBlock.

        .. note:: We will try to convert the `DataPart` into tool use and
         tool result blocks if possible; any other data part is rendered as
         a text block holding ``str(data)``.

        Args:
            part (`Part`):
                The A2A Part object to be converted.

        Returns:
            `ContentBlock`:
                The converted AgentScope ContentBlock.

        Raises:
            `ValueError`:
                If the part (or the file inside a file part) has a type
                that is not handled here.
        """
        from a2a.types import (
            TextPart,
            FilePart,
            FileWithUri,
            FileWithBytes,
            DataPart,
        )

        root = part.root

        if isinstance(root, TextPart):
            return TextBlock(
                type="text",
                text=root.text,
            )

        if isinstance(root, FilePart):
            file = root.file

            if isinstance(file, FileWithUri):
                return {  # type: ignore[return-value, misc]
                    "type": self._guess_type(file.uri, file.mime_type),
                    "source": URLSource(
                        type="url",
                        url=file.uri,
                    ),
                }

            if isinstance(file, FileWithBytes):
                return {  # type: ignore[return-value, misc]
                    "type": self._guess_type(mime_type=file.mime_type),
                    "source": Base64Source(
                        type="base64",
                        media_type=file.mime_type
                        or "application/octet-stream",
                        data=file.bytes,
                    ),
                }

            raise ValueError(
                f"Unsupported File type: {type(file)} in A2A message.",
            )

        if isinstance(root, DataPart):
            data = root.data

            # Maybe the tool use and tool result blocks
            if {"type", "name", "input", "id"} <= data.keys() and data[
                "type"
            ] == "tool_use":
                return data

            if {"type", "name", "output", "id"} <= data.keys() and data[
                "type"
            ] == "tool_result":
                return data

            # TODO: what about the other data parts?
            return TextBlock(
                type="text",
                text=str(data),
            )

        raise ValueError(
            f"Unsupported Part type: {type(root)} in A2A message"
            f": {root}",
        )
|
||||
253
src/agentscope/formatter/_anthropic_formatter.py
Normal file
253
src/agentscope/formatter/_anthropic_formatter.py
Normal file
@@ -0,0 +1,253 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# pylint: disable=too-many-branches
|
||||
"""The Anthropic formatter module."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from ._truncated_formatter_base import TruncatedFormatterBase
|
||||
from .._logging import logger
|
||||
from ..message import Msg, TextBlock, ImageBlock, ToolUseBlock, ToolResultBlock
|
||||
from ..token import TokenCounterBase
|
||||
|
||||
|
||||
class AnthropicChatFormatter(TruncatedFormatterBase):
    """The Anthropic formatter class for chatbot scenario, where only a user
    and an agent are involved. We use the `role` field to identify different
    entities in the conversation.
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = False
    """Whether support multi-agent conversations"""

    support_vision: bool = True
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        # Multimodal
        ImageBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    async def _format(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Format message objects into Anthropic API format.

        Thinking, text and image blocks are copied as-is, tool use blocks
        are reduced to the ``id``/``type``/``name``/``input`` keys, and every
        tool result block is emitted as its own ``user`` message. Messages
        that end up without any content block are skipped.

        Args:
            msgs (`list[Msg]`):
                The list of message objects to format.

        Returns:
            `list[dict[str, Any]]`:
                The formatted messages as a list of dictionaries.

        .. note:: Anthropic suggests always passing all previous thinking
         blocks back to the API in subsequent calls to maintain reasoning
         continuity. For more details, please refer to
         `Anthropic's documentation
         <https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#preserving-thinking-blocks>`_.
        """
        self.assert_list_of_msgs(msgs)

        messages: list[dict] = []
        for index, msg in enumerate(msgs):
            content_blocks = []

            for block in msg.get_content_blocks():
                typ = block.get("type")

                if typ in ["thinking", "text", "image"]:
                    # Shallow copy so the caller's block is not shared.
                    content_blocks.append({**block})

                elif typ == "tool_use":
                    content_blocks.append(
                        {
                            "id": block.get("id"),
                            "type": "tool_use",
                            "name": block.get("name"),
                            "input": block.get("input", {}),
                        },
                    )

                elif typ == "tool_result":
                    output = block.get("output")
                    if output is None:
                        # NOTE(review): a text block whose text is None is
                        # probably not a valid API payload - confirm what
                        # should be sent for an empty tool result.
                        content_value = [{"type": "text", "text": None}]
                    elif isinstance(output, list):
                        content_value = output
                    else:
                        content_value = [
                            {"type": "text", "text": str(output)},
                        ]

                    # Tool results go out immediately as their own user
                    # message, i.e. before the remaining blocks of `msg`.
                    messages.append(
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "tool_result",
                                    "tool_use_id": block.get("id"),
                                    "content": content_value,
                                },
                            ],
                        },
                    )

                else:
                    logger.warning(
                        "Unsupported block type %s in the message, skipped.",
                        typ,
                    )

            # Nothing but tool results (or nothing supported): no message.
            if not content_blocks:
                continue

            # Claude only allow the first message to be system message
            if msg.role == "system" and index != 0:
                role = "user"
            else:
                role = msg.role

            messages.append(
                {
                    "role": role,
                    "content": content_blocks,
                },
            )

        return messages
|
||||
|
||||
|
||||
class AnthropicMultiAgentFormatter(TruncatedFormatterBase):
    """
    Anthropic formatter for multi-agent conversations, where more than
    a user and an agent are involved.
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = True
    """Whether support multi-agent conversations"""

    support_vision: bool = True
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        # Multimodal
        ImageBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    def __init__(
        self,
        conversation_history_prompt: str = (
            "# Conversation History\n"
            "The content between <history></history> tags contains "
            "your conversation history\n"
        ),
        token_counter: TokenCounterBase | None = None,
        max_tokens: int | None = None,
    ) -> None:
        """Initialize the Anthropic multi-agent formatter.

        Args:
            conversation_history_prompt (`str`):
                The prompt to use for the conversation history section.
            token_counter (`TokenCounterBase | None`, optional):
                The token counter forwarded to the base formatter.
            max_tokens (`int | None`, optional):
                The token limit forwarded to the base formatter.
        """
        super().__init__(token_counter=token_counter, max_tokens=max_tokens)
        self.conversation_history_prompt = conversation_history_prompt

    async def _format_tool_sequence(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Given a sequence of tool call/result messages, format them into
        the required format for the Anthropic API.

        The work is delegated to a fresh `AnthropicChatFormatter`.
        """
        return await AnthropicChatFormatter().format(msgs)

    async def _format_agent_message(
        self,
        msgs: list[Msg],
        is_first: bool = True,
    ) -> list[dict[str, Any]]:
        """Given a sequence of messages without tool calls/results, format
        them into the required format for the Anthropic API.

        All messages are merged into one ``user`` message whose content is
        wrapped in ``<history>`` tags. Consecutive text blocks are joined
        (one ``"name: text"`` line each) and images are kept in place
        between the text runs.

        Args:
            msgs (`list[Msg]`):
                The messages to merge.
            is_first (`bool`, defaults to `True`):
                Whether to prepend the conversation history prompt.

        Returns:
            `list[dict[str, Any]]`:
                A list with the single merged user message, or an empty
                list when the messages hold no text or image block.
        """
        if is_first:
            conversation_history_prompt = self.conversation_history_prompt
        else:
            conversation_history_prompt = ""

        conversation_blocks: list = []
        accumulated_text: list[str] = []

        def _flush_text() -> None:
            """Emit the pending text lines as one text block."""
            if accumulated_text:
                conversation_blocks.append(
                    {
                        "text": "\n".join(accumulated_text),
                        "type": "text",
                    },
                )
                accumulated_text.clear()

        for msg in msgs:
            for block in msg.get_content_blocks():
                if block["type"] == "text":
                    accumulated_text.append(f"{msg.name}: {block['text']}")

                elif block["type"] == "image":
                    # Close the running text block so that the image keeps
                    # its position relative to the surrounding text.
                    _flush_text()
                    conversation_blocks.append({**block})

        _flush_text()

        if not conversation_blocks:
            return []

        # Opening tag: merge into a leading text block, or add one when the
        # history starts with an image.
        if conversation_blocks[0].get("text"):
            conversation_blocks[0]["text"] = (
                conversation_history_prompt
                + "<history>\n"
                + conversation_blocks[0]["text"]
            )
        else:
            conversation_blocks.insert(
                0,
                {
                    "type": "text",
                    "text": conversation_history_prompt + "<history>\n",
                },
            )

        # Closing tag: same idea at the other end.
        if conversation_blocks[-1].get("text"):
            conversation_blocks[-1]["text"] += "\n</history>"
        else:
            conversation_blocks.append(
                {"type": "text", "text": "</history>"},
            )

        return [
            {
                "role": "user",
                "content": conversation_blocks,
            },
        ]
|
||||
639
src/agentscope/formatter/_dashscope_formatter.py
Normal file
639
src/agentscope/formatter/_dashscope_formatter.py
Normal file
@@ -0,0 +1,639 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# pylint: disable=too-many-branches
|
||||
"""The dashscope formatter module."""
|
||||
|
||||
import json
|
||||
import os.path
|
||||
from typing import Any
|
||||
|
||||
from ._truncated_formatter_base import TruncatedFormatterBase
|
||||
from .._logging import logger
|
||||
from .._utils._common import _is_accessible_local_file
|
||||
from ..message import (
|
||||
Msg,
|
||||
TextBlock,
|
||||
ImageBlock,
|
||||
AudioBlock,
|
||||
VideoBlock,
|
||||
ToolUseBlock,
|
||||
ToolResultBlock,
|
||||
URLSource,
|
||||
)
|
||||
from ..token import TokenCounterBase
|
||||
|
||||
|
||||
def _format_dashscope_media_block(
    block: ImageBlock | AudioBlock | VideoBlock,
) -> dict[str, str]:
    """Format an image, audio or video block for DashScope API.

    Args:
        block (`ImageBlock` | `AudioBlock` | `VideoBlock`):
            The media block to format.

    Returns:
        `dict[str, str]`:
            A dictionary whose single key is the block type ("image",
            "audio" or "video") and whose value is a ``file://`` path for
            an accessible local file, the unchanged URL otherwise, or a
            ``data:`` URI for base64 sources.

    Raises:
        `NotImplementedError`:
            If the source type is not supported (including a source
            without a type).
    """
    typ = block["type"]
    source = block["source"]
    # `.get` so that a source without a type ends up in the documented
    # NotImplementedError below rather than in a KeyError.
    source_type = source.get("type")

    if source_type == "url":
        url = source["url"]
        if _is_accessible_local_file(url):
            return {typ: "file://" + os.path.abspath(url)}
        # treat as web url
        return {typ: url}

    if source_type == "base64":
        media_type = source["media_type"]
        base64_data = source["data"]
        return {
            typ: f"data:{media_type};base64,{base64_data}",
        }

    raise NotImplementedError(
        f"Unsupported source type '{source_type}' "
        f"for {typ} block.",
    )
|
||||
|
||||
|
||||
def _reformat_messages(
    messages: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Collapse text-only content lists into plain strings.

    HuggingFaceTokenCounter expects simple string content rather than
    structured multi-part content. For every message whose content items
    are all plain text, the content is replaced by the non-empty texts
    joined with newlines. The messages are modified in place and the same
    list is returned.

    Args:
        messages (list[dict[str, Any]]):
            A list of message dictionaries where each message may contain a
            "content" field. The content can be either:
            - A string (unchanged)
            - A list of content items, where each item is a dict that may
              contain "text", "type", and other fields

    Returns:
        list[dict[str, Any]]:
            The given list. A content item counts as plain text when it is
            a dict with a "text" field and either no "type" field or
            "type" == "text". Messages with any other item, or without any
            non-empty text, keep their content unchanged.

    Example:
        .. code-block:: python

            # Case 1: All text content - will be converted
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"text": "Hello", "type": "text"},
                        {"text": "World", "type": "text"}
                    ]
                }
            ]
            result = _reformat_messages(messages)
            print(result[0]["content"])
            # Output: "Hello\nWorld"

            # Case 2: Mixed content - will remain unchanged
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"text": "Hello", "type": "text"},
                        {"image_url": "...", "type": "image"}
                    ]
                }
            ]
            result = _reformat_messages(messages)  # remain unchanged
            print(type(result[0]["content"]))
            # Output: <class 'list'>

    """

    def _is_plain_text(part: Any) -> bool:
        """Tell whether a content item is an (optionally typed) text."""
        return (
            isinstance(part, dict)
            and "text" in part
            and part.get("type", "text") == "text"
        )

    for message in messages:
        pieces = []
        for part in message.get("content", []):
            if not _is_plain_text(part):
                # One non-text item is enough to leave the message alone.
                break
            if part["text"]:
                pieces.append(part["text"])
        else:
            # Only reached when every item was plain text.
            if pieces:
                message["content"] = "\n".join(pieces)

    return messages
|
||||
|
||||
|
||||
class DashScopeChatFormatter(TruncatedFormatterBase):
    """The DashScope formatter class for chatbot scenario, where only a user
    and an agent are involved. We use the `role` field to identify different
    entities in the conversation.

    .. warning::
        Known Issues with DashScope API:

        1. **Missing content field**: When messages lack the 'content' field,
           qwen-vl-max models will raise ``KeyError: 'content'``.

        2. **None content value**: When content is ``None``, qwen-vl-max models
           will raise ``TypeError: 'NoneType' object is not iterable``.

        3. **Empty text in content**: When content contains
           ``[{"text": None}]``, qwen3-max may repeatedly invoke tools
           multiple times. Note that when qwen3-max initiates tool calls,
           the returned message contains ``"content": ""``.

        To avoid these issues, this formatter assigns content as an empty
        list ``[]`` for messages without valid content blocks.

    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = False
    """Whether support multi-agent conversations"""

    support_vision: bool = True
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        ImageBlock,
        AudioBlock,
        VideoBlock,
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    def __init__(
        self,
        promote_tool_result_images: bool = False,
        promote_tool_result_audios: bool = False,
        promote_tool_result_videos: bool = False,
        token_counter: TokenCounterBase | None = None,
        max_tokens: int | None = None,
    ) -> None:
        """Initialize the DashScope chat formatter.

        Args:
            promote_tool_result_images (`bool`, defaults to `False`):
                Whether to promote images from tool results to user messages.
                Most LLM APIs don't support images in tool result blocks, but
                do support them in user message blocks. When `True`, images are
                extracted and appended as a separate user message with
                explanatory text indicating their source.
            promote_tool_result_audios (`bool`, defaults to `False`):
                Whether to promote audios from tool results to user messages.
                Most LLM APIs don't support audios in tool result blocks, but
                do support them in user message blocks. When `True`, audios are
                extracted and appended as a separate user message with
                explanatory text indicating their source.
            promote_tool_result_videos (`bool`, defaults to `False`):
                Whether to promote videos from tool results to user messages.
                Most LLM APIs don't support videos in tool result blocks, but
                do support them in user message blocks. When `True`, videos are
                extracted and appended as a separate user message with
                explanatory text indicating their source.
            token_counter (`TokenCounterBase | None`, optional):
                A token counter instance used to count tokens in the messages.
                If not provided, the formatter will format the messages
                without considering token limits.
            max_tokens (`int | None`, optional):
                The maximum number of tokens allowed in the formatted
                messages. If not provided, the formatter will not truncate
                the messages.
        """
        super().__init__(token_counter=token_counter, max_tokens=max_tokens)
        self.promote_tool_result_images = promote_tool_result_images
        self.promote_tool_result_audios = promote_tool_result_audios
        self.promote_tool_result_videos = promote_tool_result_videos

    def _promote_tool_result_media(
        self,
        block: dict[str, Any],
        multimodal_data: list,
    ) -> list:
        """Build the content of the user message that carries the media of
        a tool result.

        Args:
            block (`dict[str, Any]`):
                The tool result block the media belongs to; its "name" is
                quoted in the explanatory text.
            multimodal_data (`list`):
                The ``(url, block)`` pairs returned by
                ``convert_tool_result_to_string``.

        Returns:
            `list`:
                The promoted blocks wrapped in ``<system-info>`` text
                blocks, or an empty list when nothing is to be promoted.
        """
        # Media type -> (promotion enabled?, block class to build).
        promotable: dict[str, tuple[bool, Any]] = {
            "image": (self.promote_tool_result_images, ImageBlock),
            "audio": (self.promote_tool_result_audios, AudioBlock),
            "video": (self.promote_tool_result_videos, VideoBlock),
        }

        promoted: list = []
        for url, multimodal_block in multimodal_data:
            kind = multimodal_block["type"]
            enabled, block_cls = promotable.get(kind, (False, None))
            if not enabled:
                continue
            promoted.extend(
                [
                    TextBlock(
                        type="text",
                        text=f"\n- The {kind} from '{url}': ",
                    ),
                    block_cls(
                        type=kind,
                        source=URLSource(
                            type="url",
                            url=url,
                        ),
                    ),
                ],
            )

        if not promoted:
            return []

        return [
            TextBlock(
                type="text",
                text="<system-info>The following are "
                f"the media contents from the tool "
                f"result of '{block['name']}':",
            ),
            *promoted,
            TextBlock(
                type="text",
                text="</system-info>",
            ),
        ]

    async def _format(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Format message objects into DashScope API format.

        Tool results are emitted as ``tool`` messages; when promotion is
        enabled, their media is additionally sent in a ``user`` message
        that is formatted right after the current message.

        Args:
            msgs (`list[Msg]`):
                The list of message objects to format. The list itself is
                not modified.

        Returns:
            `list[dict[str, Any]]`:
                The formatted messages as a list of dictionaries.
        """
        self.assert_list_of_msgs(msgs)

        # Promoted media messages are queued into the list being walked.
        # Work on a copy so the insertions never leak into the caller's
        # list (and are not promoted again if it is formatted once more).
        pending = list(msgs)

        formatted_msgs: list[dict] = []

        i = 0
        while i < len(pending):
            msg = pending[i]
            content_blocks: list[dict[str, Any]] = []
            tool_calls = []

            for block in msg.get_content_blocks():
                typ = block.get("type")

                if typ == "text":
                    content_blocks.append(
                        {
                            "text": block.get("text"),
                        },
                    )

                elif typ in ["image", "audio", "video"]:
                    content_blocks.append(
                        _format_dashscope_media_block(
                            block,  # type: ignore[arg-type]
                        ),
                    )

                elif typ == "tool_use":
                    tool_calls.append(
                        {
                            "id": block.get("id"),
                            "type": "function",
                            "function": {
                                "name": block.get("name"),
                                "arguments": json.dumps(
                                    block.get("input", {}),
                                    ensure_ascii=False,
                                ),
                            },
                        },
                    )

                elif typ == "tool_result":
                    (
                        textual_output,
                        multimodal_data,
                    ) = self.convert_tool_result_to_string(block["output"])

                    # First add the tool result message in DashScope API
                    # format
                    formatted_msgs.append(
                        {
                            "role": "tool",
                            "tool_call_id": block.get("id"),
                            "content": textual_output,
                            "name": block.get("name"),
                        },
                    )

                    # Then, handle the multimodal data if any
                    promoted_blocks = self._promote_tool_result_media(
                        block,
                        multimodal_data,
                    )
                    if promoted_blocks:
                        # Queue the promoted blocks as the next message
                        pending.insert(
                            i + 1,
                            Msg(
                                name="user",
                                content=promoted_blocks,
                                role="user",
                            ),
                        )

                else:
                    logger.warning(
                        "Unsupported block type %s in the message, skipped.",
                        typ,
                    )

            msg_dashscope = {
                "role": msg.role,
                "content": content_blocks,
            }

            if tool_calls:
                msg_dashscope["tool_calls"] = tool_calls

            # Messages with neither content nor tool calls are dropped
            if msg_dashscope["content"] or msg_dashscope.get("tool_calls"):
                formatted_msgs.append(msg_dashscope)

            # Move to next message
            i += 1

        return _reformat_messages(formatted_msgs)
|
||||
|
||||
|
||||
class DashScopeMultiAgentFormatter(TruncatedFormatterBase):
    """DashScope formatter for multi-agent conversations, i.e. conversations
    that involve more participants than one user and one agent.

    .. note:: Messages without tool calls/results are merged into a single
     user message. Its content is wrapped in ``<history></history>`` tags
     and, for the first such message, prefixed with the conversation
     history prompt.

    .. note:: Tool calls/results are formatted by
     ``DashScopeChatFormatter`` and stay separate messages, as required by
     the DashScope API. Therefore, the tool calls/results messages are
     expected to be placed at the end of the input messages.

    .. tip:: Telling the assistant's name in the system prompt is very
     important in multi-agent conversations, so that the LLM knows who it
     is playing as.

    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = True
    """Whether support multi-agent conversations"""

    support_vision: bool = True
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        # Multimodal
        ImageBlock,
        AudioBlock,
        VideoBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    def __init__(
        self,
        conversation_history_prompt: str = (
            "# Conversation History\n"
            "The content between <history></history> tags contains "
            "your conversation history\n"
        ),
        promote_tool_result_images: bool = False,
        promote_tool_result_audios: bool = False,
        promote_tool_result_videos: bool = False,
        token_counter: TokenCounterBase | None = None,
        max_tokens: int | None = None,
    ) -> None:
        """Create a DashScope multi-agent formatter.

        Args:
            conversation_history_prompt (`str`):
                Text placed in front of the opening ``<history>`` tag of
                the first history message.
            promote_tool_result_images (`bool`, defaults to `False`):
                Forwarded to ``DashScopeChatFormatter`` when formatting
                tool sequences; when `True`, images found in tool results
                are re-sent in an extra user message.
            promote_tool_result_audios (`bool`, defaults to `False`):
                Same as above, for audios in tool results.
            promote_tool_result_videos (`bool`, defaults to `False`):
                Same as above, for videos in tool results.
            token_counter (`TokenCounterBase | None`, optional):
                The token counter handed to the base class for truncation.
            max_tokens (`int | None`, optional):
                The token budget handed to the base class. If `None`, no
                truncation will be applied.
        """
        super().__init__(token_counter=token_counter, max_tokens=max_tokens)
        self.conversation_history_prompt = conversation_history_prompt
        self.promote_tool_result_images = promote_tool_result_images
        self.promote_tool_result_audios = promote_tool_result_audios
        self.promote_tool_result_videos = promote_tool_result_videos

    async def _format_tool_sequence(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Format a run of tool call/result messages for the DashScope API.

        The work is delegated to a ``DashScopeChatFormatter`` configured
        with this formatter's promotion flags.

        Args:
            msgs (`list[Msg]`):
                The tool call/result messages to format.

        Returns:
            `list[dict[str, Any]]`:
                The messages formatted for the DashScope API.
        """
        chat_formatter = DashScopeChatFormatter(
            promote_tool_result_images=self.promote_tool_result_images,
            promote_tool_result_audios=self.promote_tool_result_audios,
            promote_tool_result_videos=self.promote_tool_result_videos,
        )
        return await chat_formatter.format(msgs)

    async def _format_agent_message(
        self,
        msgs: list[Msg],
        is_first: bool = True,
    ) -> list[dict[str, Any]]:
        """Merge messages without tool calls/results into one user message
        whose content is wrapped in ``<history></history>`` tags.

        Args:
            msgs (`list[Msg]`):
                The messages to merge.
            is_first (`bool`, defaults to `True`):
                Whether this is the first agent message in the
                conversation. Only then is the conversation history prompt
                placed before the opening tag.

        Returns:
            `list[dict[str, Any]]`:
                The result of ``_reformat_messages`` applied to either an
                empty list (nothing to format) or the single user message.
        """
        history_prompt = self.conversation_history_prompt if is_first else ""

        blocks: list = []
        pending_lines: list = []

        for msg in msgs:
            for block in msg.get_content_blocks():
                if block["type"] == "text":
                    pending_lines.append(f"{msg.name}: {block['text']}")
                    continue

                if block["type"] not in ["image", "audio", "video"]:
                    # Any other block type is skipped without a warning
                    continue

                # Emit the text gathered so far as one block, so that the
                # media block keeps its position in the conversation
                if pending_lines:
                    blocks.append({"text": "\n".join(pending_lines)})
                    pending_lines.clear()

                source_type = block["source"]["type"]
                if source_type == "base64":
                    media_type = block["source"]["media_type"]
                    base64_data = block["source"]["data"]
                    data_uri = f"data:{media_type};base64,{base64_data}"
                    blocks.append({block["type"]: data_uri})

                elif source_type == "url":
                    url = block["source"]["url"]
                    if _is_accessible_local_file(url):
                        # Local files are passed as absolute file:// URIs
                        location = "file://" + os.path.abspath(url)
                    else:
                        location = url
                    blocks.append({block["type"]: location})

                else:
                    logger.warning(
                        "Unsupported block type %s in the message, "
                        "skipped.",
                        block["type"],
                    )

        if pending_lines:
            blocks.append({"text": "\n".join(pending_lines)})

        if not blocks:
            return _reformat_messages([])

        # Opening tag: merge into a leading text block, or add a new one
        # when the conversation starts with a media block
        opening = history_prompt + "<history>\n"
        if blocks[0].get("text"):
            blocks[0]["text"] = opening + blocks[0]["text"]
        else:
            blocks.insert(0, {"text": opening})

        # Closing tag: same treatment at the other end
        if blocks[-1].get("text"):
            blocks[-1]["text"] += "\n</history>"
        else:
            blocks.append({"text": "</history>"})

        history_message = {
            "role": "user",
            "content": blocks,
        }
        return _reformat_messages([history_message])

    async def _format_system_message(
        self,
        msg: Msg,
    ) -> dict[str, Any]:
        """Return the DashScope system message built from the text content
        of ``msg``."""
        system_text = msg.get_text_content()
        return {
            "role": "system",
            "content": system_text,
        }
|
||||
265
src/agentscope/formatter/_deepseek_formatter.py
Normal file
265
src/agentscope/formatter/_deepseek_formatter.py
Normal file
@@ -0,0 +1,265 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# pylint: disable=too-many-branches
|
||||
"""The DeepSeek formatter module."""
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from ._truncated_formatter_base import TruncatedFormatterBase
|
||||
from .._logging import logger
|
||||
from ..message import Msg, TextBlock, ToolUseBlock, ToolResultBlock
|
||||
from ..token import TokenCounterBase
|
||||
|
||||
|
||||
class DeepSeekChatFormatter(TruncatedFormatterBase):
    """DeepSeek formatter for the chatbot scenario with exactly one user and
    one agent. The participants are told apart by the `role` field.
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = False
    """Whether support multi-agent conversations"""

    support_vision: bool = False
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    async def _format(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Convert message objects into DeepSeek API dictionaries.

        Text blocks of one message are joined with newlines into
        ``content``; thinking blocks are joined into ``reasoning_content``;
        tool-use blocks become ``tool_calls``; every tool-result block is
        emitted as its own ``tool`` message. A message is only emitted when
        it has text content or tool calls.

        Args:
            msgs (`list[Msg]`):
                The message objects to convert.

        Returns:
            `list[dict[str, Any]]`:
                The formatted messages.
        """
        self.assert_list_of_msgs(msgs)

        formatted: list[dict] = []
        for msg in msgs:
            text_parts: list = []
            thinking_parts: list = []
            calls: list = []

            for block in msg.get_content_blocks():
                kind = block.get("type")

                if kind == "text":
                    text_parts.append({**block})

                elif kind == "thinking":
                    thinking_parts.append({**block})

                elif kind == "tool_use":
                    arguments = json.dumps(
                        block.get("input", {}),
                        ensure_ascii=False,
                    )
                    calls.append(
                        {
                            "id": block.get("id"),
                            "type": "function",
                            "function": {
                                "name": block.get("name"),
                                "arguments": arguments,
                            },
                        },
                    )

                elif kind == "tool_result":
                    # Only the textual form is sent; the multimodal data
                    # returned alongside it is not used here
                    result_text, _ = self.convert_tool_result_to_string(
                        block.get("output"),  # type: ignore[arg-type]
                    )
                    formatted.append(
                        {
                            "role": "tool",
                            "tool_call_id": block.get("id"),
                            "content": result_text,
                            "name": block.get("name"),
                        },
                    )

                else:
                    logger.warning(
                        "Unsupported block type %s in the message, skipped.",
                        kind,
                    )

            content_text = "\n".join(
                part.get("text", "") for part in text_parts
            )
            reasoning_text = "\n".join(
                part.get("thinking", "") for part in thinking_parts
            )

            entry = {
                "role": msg.role,
                "content": content_text or None,
            }
            if reasoning_text:
                entry["reasoning_content"] = reasoning_text
            if calls:
                entry["tool_calls"] = calls

            # Messages that carry neither text nor tool calls are dropped
            if content_text or calls:
                formatted.append(entry)

        return formatted
|
||||
|
||||
|
||||
class DeepSeekMultiAgentFormatter(TruncatedFormatterBase):
    """
    DeepSeek formatter for multi-agent conversations, where more than
    a user and an agent are involved.
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = True
    """Whether support multi-agent conversations"""

    support_vision: bool = False
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    def __init__(
        self,
        conversation_history_prompt: str = (
            "# Conversation History\n"
            "The content between <history></history> tags contains "
            "your conversation history\n"
        ),
        token_counter: TokenCounterBase | None = None,
        max_tokens: int | None = None,
    ) -> None:
        """Initialize the DeepSeek multi-agent formatter.

        Args:
            conversation_history_prompt (`str`):
                The prompt placed before the opening ``<history>`` tag of
                the first history message.
            token_counter (`TokenCounterBase | None`, optional):
                A token counter instance used to count tokens in the
                messages. If not provided, the formatter will format the
                messages without considering token limits.
            max_tokens (`int | None`, optional):
                The maximum number of tokens allowed in the formatted
                messages. If not provided, the formatter will not truncate
                the messages.
        """
        super().__init__(token_counter=token_counter, max_tokens=max_tokens)
        self.conversation_history_prompt = conversation_history_prompt

    async def _format_tool_sequence(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Given a sequence of tool call/result messages, format them into
        the required format for the DeepSeek API.

        The work is delegated to a default-constructed
        ``DeepSeekChatFormatter``.

        Args:
            msgs (`list[Msg]`):
                The list of messages containing tool calls/results to
                format.

        Returns:
            `list[dict[str, Any]]`:
                A list of dictionaries formatted for the DeepSeek API.
        """
        return await DeepSeekChatFormatter().format(msgs)

    async def _format_agent_message(
        self,
        msgs: list[Msg],
        is_first: bool = True,
    ) -> list[dict[str, Any]]:
        """Given a sequence of messages without tool calls/results, merge
        their text blocks into one user message wrapped in
        ``<history></history>`` tags.

        Only text blocks are used; every text block becomes one
        ``"<name>: <text>"`` line.

        Args:
            msgs (`list[Msg]`):
                A list of Msg objects to be formatted.
            is_first (`bool`, defaults to `True`):
                Whether this is the first agent message in the
                conversation. If `True`, the conversation history prompt
                will be included.

        Returns:
            `list[dict[str, Any]]`:
                A list holding the single user message, or an empty list
                when the input contains no text block.
        """
        history_lines = [
            f"{msg.name}: {block['text']}"
            for msg in msgs
            for block in msg.get_content_blocks()
            if block["type"] == "text"
        ]

        # Without any text there is no history to report
        if not history_lines:
            return []

        if is_first:
            conversation_history_prompt = self.conversation_history_prompt
        else:
            conversation_history_prompt = ""

        # Every line contains at least "<name>: ", so the joined history is
        # never empty and the tags can always be attached to it directly.
        content = (
            conversation_history_prompt
            + "<history>\n"
            + "\n".join(history_lines)
            + "\n</history>"
        )

        return [
            {
                "role": "user",
                "content": content,
            },
        ]
|
||||
129
src/agentscope/formatter/_formatter_base.py
Normal file
129
src/agentscope/formatter/_formatter_base.py
Normal file
@@ -0,0 +1,129 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The formatter module."""
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import Any, List, Tuple, Sequence
|
||||
|
||||
from .._utils._common import _save_base64_data
|
||||
from ..message import Msg, AudioBlock, ImageBlock, TextBlock, VideoBlock
|
||||
|
||||
|
||||
class FormatterBase:
    """The base class for formatters."""

    @abstractmethod
    async def format(self, *args: Any, **kwargs: Any) -> list[dict[str, Any]]:
        """Format the Msg objects to a list of dictionaries that satisfy the
        API requirements."""

    @staticmethod
    def assert_list_of_msgs(msgs: list[Msg]) -> None:
        """Assert that the input is a list of Msg objects.

        Args:
            msgs (`list[Msg]`):
                A list of Msg objects to be validated.

        Raises:
            `TypeError`:
                If `msgs` is not a list, or if any of its elements is not
                a Msg object.
        """
        if not isinstance(msgs, list):
            raise TypeError("Input must be a list of Msg objects.")

        for msg in msgs:
            if not isinstance(msg, Msg):
                raise TypeError(
                    f"Expected Msg object, got {type(msg)} instead.",
                )

    @staticmethod
    def convert_tool_result_to_string(
        output: str | List[TextBlock | ImageBlock | AudioBlock | VideoBlock],
    ) -> tuple[
        str,
        Sequence[
            Tuple[
                str,
                ImageBlock | AudioBlock | TextBlock | VideoBlock,
            ]
        ],
    ]:
        """Turn the tool result list into a textual output to be compatible
        with the LLM API that doesn't support multimodal data in the tool
        result.

        For URL-based media, the URL is included in the returned list. For
        base64-encoded media, the data is stored through
        ``_save_base64_data`` and the value it returns is included instead.

        Args:
            output (`str | List[TextBlock | ImageBlock | AudioBlock | \
            VideoBlock]`):
                The output of the tool response, including text and
                multimodal data like images, audio and video.

        Returns:
            `tuple[str, list[Tuple[str, ImageBlock | AudioBlock | \
            VideoBlock | TextBlock]]]`:
                A tuple containing the textual representation of the tool
                result and a list of tuples. The first element of each
                tuple is the local file path or URL of the multimodal
                data, and the second element is the corresponding block.
                A string input is returned unchanged with an empty list.
                A single textual entry is returned as is; several entries
                are rendered as a ``- `` bulleted list.

        Raises:
            `AssertionError`:
                If a block is not a dict with a "type" key, or a media
                block has no "source" key.
            `ValueError`:
                If a block type or a media source type is not supported.
        """

        if isinstance(output, str):
            return output, []

        textual_output = []
        multimodal_data = []
        for block in output:
            assert isinstance(block, dict) and "type" in block, (
                f"Invalid block: {block}, a TextBlock, ImageBlock, "
                f"AudioBlock, or VideoBlock is expected."
            )
            if block["type"] == "text":
                textual_output.append(block["text"])

            elif block["type"] in ["image", "audio", "video"]:
                assert "source" in block, (
                    f"Invalid {block['type']} block: {block}, 'source' key "
                    "is required."
                )
                source = block["source"]

                if source["type"] == "url":
                    # URL sources are referenced as they are
                    path_multimodal_file = source["url"]

                elif source["type"] == "base64":
                    # Store the base64 data and refer to the stored location
                    path_multimodal_file = _save_base64_data(
                        source["media_type"],
                        source["data"],
                    )

                else:
                    # Name the actual block type: this branch is reached
                    # for audio and video blocks as well as images
                    raise ValueError(
                        f"Invalid {block['type']} source: "
                        f"{block['source']}, expected 'url' or 'base64'.",
                    )

                textual_output.append(
                    f"The returned {block['type']} can be found "
                    f"at: {path_multimodal_file}",
                )
                multimodal_data.append(
                    (path_multimodal_file, block),
                )

            else:
                raise ValueError(
                    f"Unsupported block type: {block['type']}, "
                    "expected 'text', 'image', 'audio', or 'video'.",
                )

        if len(textual_output) == 1:
            return textual_output[0], multimodal_data

        return "\n".join("- " + _ for _ in textual_output), multimodal_data
|
||||
507
src/agentscope/formatter/_gemini_formatter.py
Normal file
507
src/agentscope/formatter/_gemini_formatter.py
Normal file
@@ -0,0 +1,507 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# pylint: disable=too-many-branches
|
||||
"""Google gemini API formatter in agentscope."""
|
||||
import base64
|
||||
import os
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from ._truncated_formatter_base import TruncatedFormatterBase
|
||||
from .._utils._common import _get_bytes_from_web_url
|
||||
from ..message import (
|
||||
Msg,
|
||||
TextBlock,
|
||||
ImageBlock,
|
||||
AudioBlock,
|
||||
ToolUseBlock,
|
||||
ToolResultBlock,
|
||||
VideoBlock,
|
||||
URLSource,
|
||||
)
|
||||
from .._logging import logger
|
||||
from ..token import TokenCounterBase
|
||||
|
||||
|
||||
def _format_gemini_media_block(
    media_block: ImageBlock | AudioBlock | VideoBlock,
) -> dict[str, Any]:
    """Convert an image/audio/video block into a Gemini ``inline_data`` part.

    Args:
        media_block (`ImageBlock | AudioBlock | VideoBlock`):
            The media block to convert.

    Returns:
        `dict[str, Any]`:
            A dictionary with a single "inline_data" key. Base64 sources
            are copied over directly; URL sources are resolved through
            ``_to_gemini_inline_data``.

    Raises:
        `ValueError`:
            If the source type is neither "base64" nor "url".
    """
    source = media_block["source"]
    source_type = source["type"]

    if source_type == "url":
        inline_data = _to_gemini_inline_data(source["url"])
        return {"inline_data": inline_data}

    if source_type == "base64":
        inline_data = {
            "data": source["data"],
            "mime_type": source["media_type"],
        }
        return {"inline_data": inline_data}

    raise ValueError(
        f"Unsupported source type: {source['type']}",
    )
|
||||
|
||||
|
||||
def _to_gemini_inline_data(url: str) -> dict:
    """Convert a web URL or local file path into the Gemini API required
    inline data format.

    The media type (image/audio/video) is derived from the file extension
    using ``GeminiChatFormatter.supported_extensions``. For web URLs the
    extension is read from the URL path first, so a query string or
    fragment (e.g. ``https://host/a.png?size=large``) no longer hides it;
    the extension of the raw string is still tried afterwards, so URLs
    that were accepted before stay accepted.

    Args:
        url (`str`):
            A web URL or the path of an existing local file.

    Returns:
        `dict`:
            A dictionary with the keys "data" and "mime_type". For local
            files "data" is the base64-encoded file content; for web URLs
            it is whatever ``_get_bytes_from_web_url`` returns.

    Raises:
        `TypeError`:
            If the extension is not listed in
            ``GeminiChatFormatter.supported_extensions``.
        `ValueError`:
            If `url` is neither an existing local file nor a string with a
            URL scheme.
    """
    parsed_url = urlparse(url)
    is_local_file = os.path.exists(url)
    # An existing local path wins, even if it happens to parse with a scheme
    is_web_url = not is_local_file and parsed_url.scheme != ""

    if not is_local_file and not is_web_url:
        raise ValueError(
            f"The URL `{url}` is not a valid image URL or local file.",
        )

    # Candidate extensions, most reliable first
    candidates = []
    if is_web_url:
        candidates.append(parsed_url.path.split(".")[-1].lower())
    candidates.append(url.split(".")[-1].lower())

    typ = None
    extension = candidates[0]
    for candidate in candidates:
        for k, v in GeminiChatFormatter.supported_extensions.items():
            if candidate in v:
                typ = k
                break
        if typ is not None:
            extension = candidate
            break

    if typ is None:
        raise TypeError(
            f"Unsupported file extension: {extension}, expected "
            f"{GeminiChatFormatter.supported_extensions}",
        )

    if is_web_url:
        data = _get_bytes_from_web_url(url)
    else:
        with open(url, "rb") as f:
            data = base64.b64encode(f.read()).decode("utf-8")

    return {
        "data": data,
        "mime_type": f"{typ}/{extension}",
    }
|
||||
|
||||
|
||||
class GeminiChatFormatter(TruncatedFormatterBase):
    """The Gemini formatter class for chatbot scenario, where only a user
    and an agent are involved. We use the `role` field to identify different
    entities in the conversation.
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = False
    """Whether support multi-agent conversations"""

    support_vision: bool = True
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        # Multimodal
        ImageBlock,
        VideoBlock,
        AudioBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    # File extensions accepted per media type; the extension is also used
    # as the subtype of the generated mime type ("<type>/<extension>").
    supported_extensions: dict[str, list[str]] = {
        "image": ["png", "jpeg", "webp", "heic", "heif"],
        "video": [
            "mp4",
            "mpeg",
            "mov",
            "avi",
            "x-flv",
            "mpg",
            "webm",
            "wmv",
            "3gpp",
        ],
        "audio": ["mp3", "wav", "aiff", "aac", "ogg", "flac"],
    }

    def __init__(
        self,
        promote_tool_result_images: bool = False,
        token_counter: TokenCounterBase | None = None,
        max_tokens: int | None = None,
    ) -> None:
        """Initialize the Gemini chat formatter.

        Args:
            promote_tool_result_images (`bool`, defaults to `False`):
                Whether to promote images from tool results to user
                messages. Most LLM APIs don't support images in tool
                result blocks, but do support them in user message blocks.
                When `True`, images are extracted and appended as a
                separate user message with explanatory text indicating
                their source.
            token_counter (`TokenCounterBase | None`, optional):
                A token counter instance used to count tokens in the
                messages. If not provided, the formatter will format the
                messages without considering token limits.
            max_tokens (`int | None`, optional):
                The maximum number of tokens allowed in the formatted
                messages. If not provided, the formatter will not truncate
                the messages.
        """
        super().__init__(token_counter=token_counter, max_tokens=max_tokens)
        self.promote_tool_result_images = promote_tool_result_images

    def _build_promoted_msg(
        self,
        tool_name: str,
        multimodal_data: Any,
    ) -> Msg | None:
        """Build the user message that re-sends the images of a tool
        result, or return `None` when there is nothing to promote."""
        promoted_blocks: list = []
        for url, multimodal_block in multimodal_data:
            if (
                multimodal_block["type"] == "image"
                and self.promote_tool_result_images
            ):
                promoted_blocks.extend(
                    [
                        TextBlock(
                            type="text",
                            text=f"\n- The image from '{url}': ",
                        ),
                        ImageBlock(
                            type="image",
                            source=URLSource(
                                type="url",
                                url=url,
                            ),
                        ),
                    ],
                )

        if not promoted_blocks:
            return None

        return Msg(
            name="user",
            content=[
                TextBlock(
                    type="text",
                    text="<system-info>The following are "
                    "the image contents from the tool "
                    f"result of '{tool_name}':",
                ),
                *promoted_blocks,
                TextBlock(
                    type="text",
                    text="</system-info>",
                ),
            ],
            role="user",
        )

    async def _format(
        self,
        msgs: list[Msg],
    ) -> list[dict]:
        """Format message objects into Gemini API required format.

        Args:
            msgs (`list[Msg]`):
                The message objects to format. The list itself is not
                modified.

        Returns:
            `list[dict]`:
                The formatted messages, each with a "role" ("user" or
                "model") and a list of "parts".
        """
        self.assert_list_of_msgs(msgs)

        # Promoted image messages are inserted into the list while it is
        # being walked. Work on a shallow copy so that the caller's list
        # does not grow, and so that formatting the same list again does
        # not promote the same images twice.
        msgs = list(msgs)

        messages: list = []
        i = 0
        while i < len(msgs):
            msg = msgs[i]
            parts = []

            for block in msg.get_content_blocks():
                typ = block.get("type")
                if typ == "text":
                    parts.append(
                        {
                            "text": block.get("text"),
                        },
                    )

                elif typ == "tool_use":
                    parts.append(
                        {
                            "function_call": {
                                "id": None,
                                "name": block["name"],
                                "args": block["input"],
                            },
                            "thought_signature": block.get("id", None),
                        },
                    )

                elif typ == "tool_result":
                    (
                        textual_output,
                        multimodal_data,
                    ) = self.convert_tool_result_to_string(block["output"])

                    # First add the tool result message in Gemini API
                    # format
                    messages.append(
                        {
                            "role": "user",
                            "parts": [
                                {
                                    "function_response": {
                                        "id": block["id"],
                                        "name": block["name"],
                                        "response": {
                                            "output": textual_output,
                                        },
                                    },
                                },
                            ],
                        },
                    )

                    # Then queue the promoted images right after the
                    # current message; they are formatted in a later
                    # iteration of the outer loop
                    promoted_msg = self._build_promoted_msg(
                        block["name"],
                        multimodal_data,
                    )
                    if promoted_msg is not None:
                        msgs.insert(i + 1, promoted_msg)

                elif typ in ["image", "audio", "video"]:
                    parts.append(
                        _format_gemini_media_block(
                            block,  # type: ignore[arg-type]
                        ),
                    )

                else:
                    logger.warning(
                        "Unsupported block type: %s in the message, skipped. ",
                        typ,
                    )

            role = "model" if msg.role == "assistant" else "user"

            if parts:
                messages.append(
                    {
                        "role": role,
                        "parts": parts,
                    },
                )

            # Move to next message (including inserted messages, which will
            # be processed in subsequent iterations)
            i += 1

        return messages
|
||||
|
||||
|
||||
class GeminiMultiAgentFormatter(TruncatedFormatterBase):
|
||||
"""The multi-agent formatter for Google Gemini API, where more than a
|
||||
user and an agent are involved.
|
||||
|
||||
.. note:: This formatter will combine previous messages (except tool
|
||||
calls/results) into a history section in the first system message with
|
||||
the conversation history prompt.
|
||||
|
||||
.. note:: For tool calls/results, they will be presented as separate
|
||||
messages as required by the Gemini API. Therefore, the tool calls/
|
||||
results messages are expected to be placed at the end of the input
|
||||
messages.
|
||||
|
||||
.. tip:: Telling the assistant's name in the system prompt is very
|
||||
important in multi-agent conversations. So that LLM can know who it
|
||||
is playing as.
|
||||
|
||||
"""
|
||||
|
||||
support_tools_api: bool = True
|
||||
"""Whether support tools API"""
|
||||
|
||||
support_multiagent: bool = True
|
||||
"""Whether support multi-agent conversations"""
|
||||
|
||||
support_vision: bool = True
|
||||
"""Whether support vision data"""
|
||||
|
||||
supported_blocks: list[type] = [
|
||||
TextBlock,
|
||||
# Multimodal
|
||||
ImageBlock,
|
||||
VideoBlock,
|
||||
AudioBlock,
|
||||
# Tool use
|
||||
ToolUseBlock,
|
||||
ToolResultBlock,
|
||||
]
|
||||
"""The list of supported message blocks"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
conversation_history_prompt: str = (
|
||||
"# Conversation History\n"
|
||||
"The content between <history></history> tags contains "
|
||||
"your conversation history\n"
|
||||
),
|
||||
promote_tool_result_images: bool = False,
|
||||
token_counter: TokenCounterBase | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> None:
|
||||
"""Initialize the Gemini multi-agent formatter.
|
||||
|
||||
Args:
|
||||
conversation_history_prompt (`str`):
|
||||
The prompt to be used for the conversation history section.
|
||||
promote_tool_result_images (`bool`, defaults to `False`):
|
||||
Whether to promote images from tool results to user messages.
|
||||
Most LLM APIs don't support images in tool result blocks, but
|
||||
do support them in user message blocks. When `True`, images are
|
||||
extracted and appended as a separate user message with
|
||||
explanatory text indicating their source.
|
||||
token_counter (`TokenCounterBase | None`, optional):
|
||||
The token counter used for truncation.
|
||||
max_tokens (`int | None`, optional):
|
||||
The maximum number of tokens allowed in the formatted
|
||||
messages. If `None`, no truncation will be applied.
|
||||
"""
|
||||
super().__init__(token_counter=token_counter, max_tokens=max_tokens)
|
||||
self.conversation_history_prompt = conversation_history_prompt
|
||||
self.promote_tool_result_images = promote_tool_result_images
|
||||
|
||||
async def _format_system_message(
|
||||
self,
|
||||
msg: Msg,
|
||||
) -> dict[str, Any]:
|
||||
"""Format system message for the Gemini API."""
|
||||
return {
|
||||
"role": "user",
|
||||
"parts": [
|
||||
{
|
||||
"text": msg.get_text_content(),
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
async def _format_tool_sequence(
    self,
    msgs: list[Msg],
) -> list[dict[str, Any]]:
    """Format a run of tool call/result messages for the Gemini API.

    The work is delegated to a freshly created `GeminiChatFormatter`
    that shares this formatter's `promote_tool_result_images` setting.

    Args:
        msgs (`list[Msg]`):
            The list of messages containing tool calls/results to format.

    Returns:
        `list[dict[str, Any]]`:
            Whatever the chat formatter's `format` produces for `msgs`.
    """
    chat_formatter = GeminiChatFormatter(
        promote_tool_result_images=self.promote_tool_result_images,
    )
    return await chat_formatter.format(msgs)
|
||||
|
||||
async def _format_agent_message(
|
||||
self,
|
||||
msgs: list[Msg],
|
||||
is_first: bool = True,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Given a sequence of messages without tool calls/results, format
|
||||
them into the required format for the Gemini API.
|
||||
|
||||
Args:
|
||||
msgs (`list[Msg]`):
|
||||
A list of Msg objects to be formatted.
|
||||
is_first (`bool`, defaults to `True`):
|
||||
Whether this is the first agent message in the conversation.
|
||||
If `True`, the conversation history prompt will be included.
|
||||
|
||||
Returns:
|
||||
`list[dict[str, Any]]`:
|
||||
A list of dictionaries formatted for the Gemini API.
|
||||
"""
|
||||
|
||||
if is_first:
|
||||
conversation_history_prompt = self.conversation_history_prompt
|
||||
else:
|
||||
conversation_history_prompt = ""
|
||||
|
||||
# Format into Gemini API required format
|
||||
formatted_msgs: list = []
|
||||
|
||||
# Collect the multimodal files
|
||||
conversation_parts: list = []
|
||||
accumulated_text = []
|
||||
for msg in msgs:
|
||||
for block in msg.get_content_blocks():
|
||||
if block["type"] == "text":
|
||||
accumulated_text.append(f"{msg.name}: {block['text']}")
|
||||
|
||||
elif block["type"] in ["image", "video", "audio"]:
|
||||
# handle the accumulated text as a single part if exists
|
||||
if accumulated_text:
|
||||
conversation_parts.append(
|
||||
{
|
||||
"text": "\n".join(accumulated_text),
|
||||
},
|
||||
)
|
||||
accumulated_text.clear()
|
||||
|
||||
# handle the multimodal data
|
||||
conversation_parts.append(
|
||||
_format_gemini_media_block(
|
||||
block, # type: ignore[arg-type]
|
||||
),
|
||||
)
|
||||
|
||||
if accumulated_text:
|
||||
conversation_parts.append(
|
||||
{
|
||||
"text": "\n".join(accumulated_text),
|
||||
},
|
||||
)
|
||||
|
||||
# Add prompt and <history></history> tags around conversation history
|
||||
if conversation_parts:
|
||||
if conversation_parts[0].get("text"):
|
||||
conversation_parts[0]["text"] = (
|
||||
conversation_history_prompt
|
||||
+ "<history>"
|
||||
+ conversation_parts[0]["text"]
|
||||
)
|
||||
|
||||
else:
|
||||
conversation_parts.insert(
|
||||
0,
|
||||
{"text": conversation_history_prompt + "<history>"},
|
||||
)
|
||||
|
||||
if conversation_parts[-1].get("text"):
|
||||
conversation_parts[-1]["text"] += "\n</history>"
|
||||
|
||||
else:
|
||||
conversation_parts.append(
|
||||
{"text": "</history>"},
|
||||
)
|
||||
|
||||
formatted_msgs.append(
|
||||
{
|
||||
"role": "user",
|
||||
"parts": conversation_parts,
|
||||
},
|
||||
)
|
||||
|
||||
return formatted_msgs
|
||||
441
src/agentscope/formatter/_ollama_formatter.py
Normal file
441
src/agentscope/formatter/_ollama_formatter.py
Normal file
@@ -0,0 +1,441 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# pylint: disable=too-many-branches
|
||||
"""The Ollama formatter module."""
|
||||
import base64
|
||||
import os
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from ._truncated_formatter_base import TruncatedFormatterBase
|
||||
from .._logging import logger
|
||||
from .._utils._common import _get_bytes_from_web_url
|
||||
from ..message import (
|
||||
Msg,
|
||||
TextBlock,
|
||||
ImageBlock,
|
||||
ToolUseBlock,
|
||||
ToolResultBlock,
|
||||
URLSource,
|
||||
)
|
||||
from ..token import TokenCounterBase
|
||||
|
||||
|
||||
def _format_ollama_image_block(
|
||||
image_block: ImageBlock,
|
||||
) -> str:
|
||||
"""Format an image block for Ollama API.
|
||||
|
||||
Args:
|
||||
image_block (`ImageBlock`):
|
||||
The image block to format.
|
||||
|
||||
Returns:
|
||||
`str`:
|
||||
Base64 encoded image data as a string.
|
||||
|
||||
Raises:
|
||||
`ValueError`:
|
||||
If the source type is not supported.
|
||||
"""
|
||||
source = image_block["source"]
|
||||
if source["type"] == "url":
|
||||
return _convert_ollama_image_url_to_base64_data(source["url"])
|
||||
elif source["type"] == "base64":
|
||||
return source["data"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported image source type: {source['type']}",
|
||||
)
|
||||
|
||||
|
||||
def _convert_ollama_image_url_to_base64_data(url: str) -> str:
|
||||
"""Convert image url to base64."""
|
||||
parsed_url = urlparse(url)
|
||||
|
||||
if not os.path.exists(url) and parsed_url.scheme != "":
|
||||
# Web url
|
||||
data = _get_bytes_from_web_url(url)
|
||||
return data
|
||||
if os.path.exists(url):
|
||||
# Local file
|
||||
with open(url, "rb") as f:
|
||||
data = base64.b64encode(f.read()).decode("utf-8")
|
||||
|
||||
return data
|
||||
|
||||
raise ValueError(
|
||||
f"The URL `{url}` is not a valid image URL or local file.",
|
||||
)
|
||||
|
||||
|
||||
class OllamaChatFormatter(TruncatedFormatterBase):
    """The Ollama formatter class for chatbot scenario, where only a user
    and an agent are involved. We use the `role` field to identify different
    participants in the conversation.
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = False
    """Whether support multi-agent conversations"""

    support_vision: bool = True
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        # Multimodal
        ImageBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    def __init__(
        self,
        promote_tool_result_images: bool = False,
        token_counter: TokenCounterBase | None = None,
        max_tokens: int | None = None,
    ) -> None:
        """Initialize the Ollama chat formatter.

        Args:
            promote_tool_result_images (`bool`, defaults to `False`):
                Whether to promote images from tool results to user messages.
                Most LLM APIs don't support images in tool result blocks, but
                do support them in user message blocks. When `True`, images are
                extracted and appended as a separate user message with
                explanatory text indicating their source.
            token_counter (`TokenCounterBase | None`, optional):
                A token counter instance used to count tokens in the messages.
                If not provided, the formatter will format the messages
                without considering token limits.
            max_tokens (`int | None`, optional):
                The maximum number of tokens allowed in the formatted
                messages. If not provided, the formatter will not truncate
                the messages.
        """
        super().__init__(token_counter, max_tokens)
        self.promote_tool_result_images = promote_tool_result_images

    async def _format(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Format message objects into Ollama API format.

        Per message, text blocks are joined with newlines into "content",
        tool-use blocks become "tool_calls" entries, image blocks are
        collected under "images", and every tool-result block is emitted
        immediately as its own "tool" message. Messages that end up with
        no content, images or tool calls are dropped.

        Args:
            msgs (`list[Msg]`):
                The list of message objects to format. The list itself is
                not modified.

        Returns:
            `list[dict[str, Any]]`:
                The formatted messages as a list of dictionaries.
        """
        self.assert_list_of_msgs(msgs)

        # Promoted image messages are spliced into the working list below.
        # Operate on a shallow copy so the caller's list does not grow as a
        # side effect of formatting (which would also duplicate promoted
        # messages if the same list were formatted again).
        msgs = list(msgs)

        messages: list = []
        i = 0
        while i < len(msgs):
            msg = msgs[i]
            content_blocks: list = []
            tool_calls = []
            images = []

            for block in msg.get_content_blocks():
                typ = block.get("type")
                if typ == "text":
                    content_blocks.append({**block})

                elif typ == "tool_use":
                    tool_calls.append(
                        {
                            "id": block.get("id"),
                            "type": "function",
                            "function": {
                                "name": block.get("name"),
                                "arguments": block.get("input", {}),
                            },
                        },
                    )

                elif typ == "tool_result":
                    (
                        textual_output,
                        multimodal_data,
                    ) = self.convert_tool_result_to_string(block["output"])

                    # The tool message goes out right away, i.e. before
                    # the message built from the surrounding blocks.
                    messages.append(
                        {
                            "role": "tool",
                            "tool_call_id": block.get("id"),
                            "content": textual_output,
                            "name": block.get("name"),
                        },
                    )

                    # Then, handle the multimodal data if any
                    promoted_blocks: list = []
                    for url, multimodal_block in multimodal_data:
                        if (
                            multimodal_block["type"] == "image"
                            and self.promote_tool_result_images
                        ):
                            promoted_blocks.extend(
                                [
                                    TextBlock(
                                        type="text",
                                        text=f"\n- The image from '{url}': ",
                                    ),
                                    ImageBlock(
                                        type="image",
                                        source=URLSource(
                                            type="url",
                                            url=url,
                                        ),
                                    ),
                                ],
                            )

                    if promoted_blocks:
                        # Wrap the promoted blocks and queue them as a new
                        # user message right after the current one, so the
                        # loop formats it on its next iteration.
                        promoted_blocks = [
                            TextBlock(
                                type="text",
                                text="<system-info>The following are "
                                "the image contents from the tool "
                                f"result of '{block['name']}':",
                            ),
                            *promoted_blocks,
                            TextBlock(
                                type="text",
                                text="</system-info>",
                            ),
                        ]

                        msgs.insert(
                            i + 1,
                            Msg(
                                name="user",
                                content=promoted_blocks,
                                role="user",
                            ),
                        )

                elif typ == "image":
                    images.append(
                        _format_ollama_image_block(
                            block,  # type: ignore[arg-type]
                        ),
                    )

                else:
                    logger.warning(
                        "Unsupported block type %s in the message, skipped.",
                        typ,
                    )

            content_msg = "\n".join(
                content.get("text", "") for content in content_blocks
            )
            msg_ollama = {
                "role": msg.role,
                "content": content_msg or None,
            }

            if tool_calls:
                msg_ollama["tool_calls"] = tool_calls

            if images:
                msg_ollama["images"] = images

            # Skip messages that carry nothing at all
            if (
                msg_ollama["content"]
                or msg_ollama.get("images")
                or msg_ollama.get("tool_calls")
            ):
                messages.append(msg_ollama)

            # Move to next message
            i += 1

        return messages
|
||||
|
||||
|
||||
class OllamaMultiAgentFormatter(TruncatedFormatterBase):
    """
    Ollama formatter for multi-agent conversations, where more than
    a user and an agent are involved.
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = True
    """Whether support multi-agent conversations"""

    support_vision: bool = True
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        # Multimodal
        ImageBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    def __init__(
        self,
        conversation_history_prompt: str = (
            "# Conversation History\n"
            "The content between <history></history> tags contains "
            "your conversation history\n"
        ),
        promote_tool_result_images: bool = False,
        token_counter: TokenCounterBase | None = None,
        max_tokens: int | None = None,
    ) -> None:
        """Set up the Ollama multi-agent formatter.

        Args:
            conversation_history_prompt (`str`):
                Text placed in front of the opening <history> tag of the
                first merged agent message.
            promote_tool_result_images (`bool`, defaults to `False`):
                Forwarded to the `OllamaChatFormatter` that handles tool
                call/result sequences. When `True`, images found in tool
                results are re-sent as a separate user message with
                explanatory text, since most LLM APIs accept images in
                user messages but not in tool results.
            token_counter (`TokenCounterBase | None`, optional):
                The token counter handed to the base class for truncation.
            max_tokens (`int | None`, optional):
                The token budget handed to the base class. If `None`, no
                truncation will be applied.
        """
        super().__init__(max_tokens=max_tokens, token_counter=token_counter)
        self.promote_tool_result_images = promote_tool_result_images
        self.conversation_history_prompt = conversation_history_prompt

    async def _format_system_message(
        self,
        msg: Msg,
    ) -> dict[str, Any]:
        """Render the system message as an Ollama `system` message.

        Args:
            msg (`Msg`):
                The system message; only its text content is used.

        Returns:
            `dict[str, Any]`:
                A dictionary with role "system" and the text as content.
        """
        system_text = msg.get_text_content()
        return {"role": "system", "content": system_text}

    async def _format_tool_sequence(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Format a run of tool call/result messages for the Ollama API.

        The work is delegated to a freshly created `OllamaChatFormatter`
        that shares this formatter's `promote_tool_result_images` setting.

        Args:
            msgs (`list[Msg]`):
                The list of messages containing tool calls/results to
                format.

        Returns:
            `list[dict[str, Any]]`:
                Whatever the chat formatter's `format` produces for `msgs`.
        """
        chat_formatter = OllamaChatFormatter(
            promote_tool_result_images=self.promote_tool_result_images,
        )
        return await chat_formatter.format(msgs)

    async def _format_agent_message(
        self,
        msgs: list[Msg],
        is_first: bool = True,
    ) -> list[dict[str, Any]]:
        """Merge messages without tool calls/results into one user message.

        Text blocks become "<name>: <text>" lines. Image blocks are
        converted with `_format_ollama_image_block` and gathered under the
        message's "images" key; each image also leaves a text-less
        placeholder in the segment list, which contributes an empty string
        when the segments are joined into the final content.

        Args:
            msgs (`list[Msg]`):
                A list of Msg objects to be formatted.
            is_first (`bool`, defaults to `True`):
                Whether this is the first agent message in the
                conversation. If `True`, the conversation history prompt
                is placed before the opening <history> tag.

        Returns:
            `list[dict[str, Any]]`:
                A one-element list holding the merged "user" message, or
                an empty list when no text or image block was found.
        """
        history_prefix = self.conversation_history_prompt if is_first else ""

        segments: list = []
        pending_lines: list = []
        encoded_images: list = []

        for msg in msgs:
            for block in msg.get_content_blocks():
                block_type = block["type"]

                if block_type == "text":
                    pending_lines.append(f"{msg.name}: {block['text']}")
                    continue

                if block_type != "image":
                    # Other block types are ignored here
                    continue

                # Close the running text segment before the image
                if pending_lines:
                    segments.append({"text": "\n".join(pending_lines)})
                    pending_lines.clear()

                encoded_images.append(_format_ollama_image_block(block))
                segments.append({**block})

        if pending_lines:
            segments.append({"text": "\n".join(pending_lines)})

        if not segments:
            return []

        # Opening tag: merge into a leading text segment, otherwise add one
        leading = segments[0]
        if leading.get("text"):
            leading["text"] = history_prefix + "<history>\n" + leading["text"]
        else:
            segments.insert(0, {"text": history_prefix + "<history>\n"})

        # Closing tag: same treatment for the trailing segment
        trailing = segments[-1]
        if trailing.get("text"):
            trailing["text"] += "\n</history>"
        else:
            segments.append({"text": "</history>"})

        user_message = {
            "role": "user",
            "content": "\n".join(
                segment.get("text", "") for segment in segments
            ),
        }
        if encoded_images:
            user_message["images"] = encoded_images

        return [user_message]
|
||||
530
src/agentscope/formatter/_openai_formatter.py
Normal file
530
src/agentscope/formatter/_openai_formatter.py
Normal file
@@ -0,0 +1,530 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# pylint: disable=too-many-branches, too-many-nested-blocks
|
||||
"""The OpenAI formatter for agentscope."""
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
|
||||
from ._truncated_formatter_base import TruncatedFormatterBase
|
||||
from .._logging import logger
|
||||
from ..message import (
|
||||
Msg,
|
||||
URLSource,
|
||||
TextBlock,
|
||||
ImageBlock,
|
||||
AudioBlock,
|
||||
Base64Source,
|
||||
ToolUseBlock,
|
||||
ToolResultBlock,
|
||||
)
|
||||
from ..token import TokenCounterBase
|
||||
|
||||
|
||||
def _format_openai_image_block(
|
||||
image_block: ImageBlock,
|
||||
) -> dict[str, Any]:
|
||||
"""Format an image block for OpenAI API.
|
||||
|
||||
Args:
|
||||
image_block (`ImageBlock`):
|
||||
The image block to format.
|
||||
|
||||
Returns:
|
||||
`dict[str, Any]`:
|
||||
A dictionary with "type" and "image_url" keys in OpenAI format.
|
||||
|
||||
Raises:
|
||||
`ValueError`:
|
||||
If the source type is not supported.
|
||||
"""
|
||||
source = image_block["source"]
|
||||
if source["type"] == "url":
|
||||
url = _to_openai_image_url(source["url"])
|
||||
elif source["type"] == "base64":
|
||||
data = source["data"]
|
||||
media_type = source["media_type"]
|
||||
url = f"data:{media_type};base64,{data}"
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported image source type: {source['type']}",
|
||||
)
|
||||
|
||||
return {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": url,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _to_openai_image_url(url: str) -> str:
|
||||
"""Convert an image url to openai format. If the given url is a local
|
||||
file, it will be converted to base64 format. Otherwise, it will be
|
||||
returned directly.
|
||||
|
||||
Args:
|
||||
url (`str`):
|
||||
The local or public url of the image.
|
||||
"""
|
||||
# See https://platform.openai.com/docs/guides/vision for details of
|
||||
# support image extensions.
|
||||
support_image_extensions = (
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".gif",
|
||||
".webp",
|
||||
)
|
||||
|
||||
parsed_url = urlparse(url)
|
||||
|
||||
lower_url = url.lower()
|
||||
|
||||
# Web url
|
||||
if not os.path.exists(url) and parsed_url.scheme != "":
|
||||
path_lower = parsed_url.path if parsed_url.path else parsed_url.netloc
|
||||
if any(path_lower.endswith(_) for _ in support_image_extensions):
|
||||
return url
|
||||
|
||||
# Check if it is a local file
|
||||
elif os.path.exists(url) and os.path.isfile(url):
|
||||
if any(lower_url.endswith(_) for _ in support_image_extensions):
|
||||
with open(url, "rb") as image_file:
|
||||
base64_image = base64.b64encode(image_file.read()).decode(
|
||||
"utf-8",
|
||||
)
|
||||
extension = parsed_url.path.lower().split(".")[-1]
|
||||
mime_type = f"image/{extension}"
|
||||
return f"data:{mime_type};base64,{base64_image}"
|
||||
|
||||
raise TypeError(f'"{url}" should end with {support_image_extensions}.')
|
||||
|
||||
|
||||
def _to_openai_audio_data(source: URLSource | Base64Source) -> dict:
|
||||
"""Covert an audio source to OpenAI format."""
|
||||
if source["type"] == "url":
|
||||
extension = source["url"].split(".")[-1].lower()
|
||||
if extension not in ["wav", "mp3"]:
|
||||
raise TypeError(
|
||||
f"Unsupported audio file extension: {extension}, "
|
||||
"wav and mp3 are supported.",
|
||||
)
|
||||
|
||||
parsed_url = urlparse(source["url"])
|
||||
|
||||
if os.path.exists(source["url"]):
|
||||
with open(source["url"], "rb") as audio_file:
|
||||
data = base64.b64encode(audio_file.read()).decode("utf-8")
|
||||
|
||||
# web url
|
||||
elif parsed_url.scheme != "":
|
||||
response = requests.get(source["url"])
|
||||
response.raise_for_status()
|
||||
data = base64.b64encode(response.content).decode("utf-8")
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported audio source: {source['url']}, "
|
||||
"it should be a local file or a web URL.",
|
||||
)
|
||||
|
||||
return {
|
||||
"data": data,
|
||||
"format": extension,
|
||||
}
|
||||
|
||||
if source["type"] == "base64":
|
||||
data = source["data"]
|
||||
media_type = source["media_type"]
|
||||
|
||||
if media_type not in ["audio/wav", "audio/mp3"]:
|
||||
raise TypeError(
|
||||
f"Unsupported audio media type: {media_type}, "
|
||||
"only audio/wav and audio/mp3 are supported.",
|
||||
)
|
||||
|
||||
return {
|
||||
"data": data,
|
||||
"format": media_type.split("/")[-1],
|
||||
}
|
||||
|
||||
raise TypeError(f"Unsupported audio source: {source['type']}.")
|
||||
|
||||
|
||||
class OpenAIChatFormatter(TruncatedFormatterBase):
    """The OpenAI formatter class for chatbot scenario, where only a user
    and an agent are involved. We use the `name` field in OpenAI API to
    identify different entities in the conversation.
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = True
    """Whether support multi-agent conversation"""

    support_vision: bool = True
    """Whether support vision models"""

    supported_blocks: list[type] = [
        TextBlock,
        ImageBlock,
        AudioBlock,
        ToolUseBlock,
        ToolResultBlock,
    ]
    """Supported message blocks for OpenAI API"""

    def __init__(
        self,
        promote_tool_result_images: bool = False,
        token_counter: TokenCounterBase | None = None,
        max_tokens: int | None = None,
    ) -> None:
        """Initialize the OpenAI chat formatter.

        Args:
            promote_tool_result_images (`bool`, defaults to `False`):
                Whether to promote images from tool results to user messages.
                Most LLM APIs don't support images in tool result blocks, but
                do support them in user message blocks. When `True`, images are
                extracted and appended as a separate user message with
                explanatory text indicating their source.
            token_counter (`TokenCounterBase | None`, optional):
                A token counter instance used to count tokens in the messages.
                If not provided, the formatter will format the messages
                without considering token limits.
            max_tokens (`int | None`, optional):
                The maximum number of tokens allowed in the formatted
                messages. If not provided, the formatter will not truncate
                the messages.
        """
        super().__init__(token_counter=token_counter, max_tokens=max_tokens)
        self.promote_tool_result_images = promote_tool_result_images

    async def _format(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Format message objects into OpenAI API required format.

        Per message, text, image and (non-assistant) audio blocks are
        collected into "content", tool-use blocks become "tool_calls"
        entries with JSON-encoded arguments, and every tool-result block
        is emitted immediately as its own "tool" message. Messages that
        end up with neither content nor tool calls are dropped.

        Args:
            msgs (`list[Msg]`):
                The list of Msg objects to format. The list itself is not
                modified.

        Returns:
            `list[dict[str, Any]]`:
                A list of dictionaries, where each dictionary has "name",
                "role", and "content" keys.
        """
        self.assert_list_of_msgs(msgs)

        # Promoted image messages are spliced into the working list below.
        # Operate on a shallow copy so the caller's list does not grow as a
        # side effect of formatting (which would also duplicate promoted
        # messages if the same list were formatted again).
        msgs = list(msgs)

        messages: list[dict] = []
        i = 0
        while i < len(msgs):
            msg = msgs[i]
            content_blocks = []
            tool_calls = []

            for block in msg.get_content_blocks():
                typ = block.get("type")
                if typ == "text":
                    content_blocks.append({**block})

                elif typ == "tool_use":
                    tool_calls.append(
                        {
                            "id": block.get("id"),
                            "type": "function",
                            "function": {
                                "name": block.get("name"),
                                "arguments": json.dumps(
                                    block.get("input", {}),
                                    ensure_ascii=False,
                                ),
                            },
                        },
                    )

                elif typ == "tool_result":
                    (
                        textual_output,
                        multimodal_data,
                    ) = self.convert_tool_result_to_string(block["output"])

                    # The tool message goes out right away, i.e. before
                    # the message built from the surrounding blocks.
                    messages.append(
                        {
                            "role": "tool",
                            "tool_call_id": block.get("id"),
                            "content": (  # type: ignore[arg-type]
                                textual_output
                            ),
                            "name": block.get("name"),
                        },
                    )

                    # Then, handle the multimodal data if any
                    promoted_blocks: list = []
                    for url, multimodal_block in multimodal_data:
                        if (
                            multimodal_block["type"] == "image"
                            and self.promote_tool_result_images
                        ):
                            promoted_blocks.extend(
                                [
                                    TextBlock(
                                        type="text",
                                        text=f"\n- The image from '{url}': ",
                                    ),
                                    ImageBlock(
                                        type="image",
                                        source=URLSource(
                                            type="url",
                                            url=url,
                                        ),
                                    ),
                                ],
                            )

                    if promoted_blocks:
                        # Wrap the promoted blocks and queue them as a new
                        # user message right after the current one, so the
                        # loop formats it on its next iteration.
                        promoted_blocks = [
                            TextBlock(
                                type="text",
                                text="<system-info>The following are "
                                "the image contents from the tool "
                                f"result of '{block['name']}':",
                            ),
                            *promoted_blocks,
                            TextBlock(
                                type="text",
                                text="</system-info>",
                            ),
                        ]

                        msgs.insert(
                            i + 1,
                            Msg(
                                name="user",
                                content=promoted_blocks,
                                role="user",
                            ),
                        )

                elif typ == "image":
                    content_blocks.append(
                        _format_openai_image_block(
                            block,  # type: ignore[arg-type]
                        ),
                    )

                elif typ == "audio":
                    # Filter out audio content when the multimodal model
                    # outputs both text and audio, to prevent errors in
                    # subsequent model calls
                    if msg.role == "assistant":
                        continue
                    input_audio = _to_openai_audio_data(block["source"])
                    content_blocks.append(
                        {
                            "type": "input_audio",
                            "input_audio": input_audio,
                        },
                    )

                else:
                    logger.warning(
                        "Unsupported block type %s in the message, skipped.",
                        typ,
                    )

            msg_openai = {
                "role": msg.role,
                "name": msg.name,
                "content": content_blocks or None,
            }

            if tool_calls:
                msg_openai["tool_calls"] = tool_calls

            # When both content and tool_calls are None, skipped
            if msg_openai["content"] or msg_openai.get("tool_calls"):
                messages.append(msg_openai)

            # Move to next message
            i += 1

        return messages
|
||||
|
||||
|
||||
class OpenAIMultiAgentFormatter(TruncatedFormatterBase):
    """
    OpenAI formatter for multi-agent conversations, where more than
    a user and an agent are involved.
    .. tip:: This formatter is compatible with OpenAI API and
    OpenAI-compatible services like vLLM, Azure OpenAI, and others.
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = True
    """Whether support multi-agent conversation"""

    support_vision: bool = True
    """Whether support vision models"""

    supported_blocks: list[type] = [
        TextBlock,
        ImageBlock,
        AudioBlock,
        ToolUseBlock,
        ToolResultBlock,
    ]
    """Supported message blocks for OpenAI API"""

    def __init__(
        self,
        conversation_history_prompt: str = (
            "# Conversation History\n"
            "The content between <history></history> tags contains "
            "your conversation history\n"
        ),
        promote_tool_result_images: bool = False,
        token_counter: TokenCounterBase | None = None,
        max_tokens: int | None = None,
    ) -> None:
        """Set up the OpenAI multi-agent formatter.

        Args:
            conversation_history_prompt (`str`):
                Text placed in front of the opening <history> tag of the
                first merged agent message.
            promote_tool_result_images (`bool`, defaults to `False`):
                Forwarded to the `OpenAIChatFormatter` that handles tool
                call/result sequences. When `True`, images found in tool
                results are re-sent as a separate user message with
                explanatory text, since most LLM APIs accept images in
                user messages but not in tool results.
            token_counter (`TokenCounterBase | None`, optional):
                The token counter handed to the base class. If not
                provided, messages are formatted without considering
                token limits.
            max_tokens (`int | None`, optional):
                The token budget handed to the base class. If not
                provided, the messages are not truncated.
        """
        super().__init__(max_tokens=max_tokens, token_counter=token_counter)
        self.promote_tool_result_images = promote_tool_result_images
        self.conversation_history_prompt = conversation_history_prompt

    async def _format_tool_sequence(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Format a run of tool call/result messages for the OpenAI API by
        delegating to an `OpenAIChatFormatter` that shares this
        formatter's `promote_tool_result_images` setting."""
        chat_formatter = OpenAIChatFormatter(
            promote_tool_result_images=self.promote_tool_result_images,
        )
        return await chat_formatter.format(msgs)

    async def _format_agent_message(
        self,
        msgs: list[Msg],
        is_first: bool = True,
    ) -> list[dict[str, Any]]:
        """Merge messages without tool calls/results into one user message.

        Text blocks become "<name>: <text>" lines that are joined and
        wrapped in <history></history> tags as a single text item. Image
        items follow, then audio items; audio blocks of assistant
        messages are left out.

        Args:
            msgs (`list[Msg]`):
                A list of Msg objects to be formatted.
            is_first (`bool`, defaults to `True`):
                Whether this is the first agent message in the
                conversation. If `True`, the conversation history prompt
                is placed before the opening <history> tag.

        Returns:
            `list[dict[str, Any]]`:
                A one-element list holding the merged "user" message, or
                an empty list when no content item was produced.
        """
        history_prefix = self.conversation_history_prompt if is_first else ""

        transcript_lines: list = []
        image_items: list = []
        audio_items: list = []

        for msg in msgs:
            for block in msg.get_content_blocks():
                block_type = block["type"]

                if block_type == "text":
                    transcript_lines.append(f"{msg.name}: {block['text']}")

                elif block_type == "image":
                    image_items.append(_format_openai_image_block(block))

                elif block_type == "audio" and msg.role != "assistant":
                    # Audio produced by the assistant is dropped: a
                    # multimodal model may output both text and audio, and
                    # echoing the audio back breaks subsequent model calls
                    audio_items.append(
                        {
                            "type": "input_audio",
                            "input_audio": _to_openai_audio_data(
                                block["source"],
                            ),
                        },
                    )

        content_list: list[dict[str, Any]] = []

        # The history tags only exist when there is text to wrap
        if transcript_lines:
            content_list.append(
                {
                    "type": "text",
                    "text": (
                        history_prefix
                        + "<history>\n"
                        + "\n".join(transcript_lines)
                        + "\n</history>"
                    ),
                },
            )

        content_list.extend(image_items)
        content_list.extend(audio_items)

        if not content_list:
            return []

        return [{"role": "user", "content": content_list}]
|
||||
297
src/agentscope/formatter/_truncated_formatter_base.py
Normal file
297
src/agentscope/formatter/_truncated_formatter_base.py
Normal file
@@ -0,0 +1,297 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The truncated formatter base class, which allows truncating the input
|
||||
messages."""
|
||||
from abc import ABC
|
||||
from copy import deepcopy
|
||||
from typing import (
|
||||
Any,
|
||||
Tuple,
|
||||
Literal,
|
||||
AsyncGenerator,
|
||||
)
|
||||
|
||||
from ._formatter_base import FormatterBase
|
||||
from ..message import Msg
|
||||
from ..token import TokenCounterBase
|
||||
from ..tracing import trace_format
|
||||
|
||||
|
||||
class TruncatedFormatterBase(FormatterBase, ABC):
|
||||
"""Base class for truncated formatters, which formats input messages into
|
||||
required formats with tokens under a specified limit."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
token_counter: TokenCounterBase | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> None:
|
||||
"""Initialize the TruncatedFormatterBase.
|
||||
|
||||
Args:
|
||||
token_counter (`TokenCounterBase | None`, optional):
|
||||
A token counter instance used to count tokens in the messages.
|
||||
If not provided, the formatter will format the messages
|
||||
without considering token limits.
|
||||
max_tokens (`int | None`, optional):
|
||||
The maximum number of tokens allowed in the formatted
|
||||
messages. If not provided, the formatter will not truncate
|
||||
the messages.
|
||||
"""
|
||||
self.token_counter = token_counter
|
||||
|
||||
assert (
|
||||
max_tokens is None or 0 < max_tokens
|
||||
), "max_tokens must be greater than 0"
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
@trace_format
|
||||
async def format(
|
||||
self,
|
||||
msgs: list[Msg],
|
||||
**kwargs: Any,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Format the input messages into the required format. If token
|
||||
counter and max token limit are provided, the messages will be
|
||||
truncated to fit the limit.
|
||||
|
||||
Args:
|
||||
msgs (`list[Msg]`):
|
||||
The input messages to be formatted.
|
||||
|
||||
Returns:
|
||||
`list[dict[str, Any]]`:
|
||||
The formatted messages in the required format.
|
||||
"""
|
||||
|
||||
# Check if the input messages are valid
|
||||
self.assert_list_of_msgs(msgs)
|
||||
|
||||
msgs = deepcopy(msgs)
|
||||
|
||||
while True:
|
||||
formatted_msgs = await self._format(msgs)
|
||||
n_tokens = await self._count(formatted_msgs)
|
||||
|
||||
if (
|
||||
n_tokens is None
|
||||
or self.max_tokens is None
|
||||
or n_tokens <= self.max_tokens
|
||||
):
|
||||
return formatted_msgs
|
||||
|
||||
# truncate the input messages
|
||||
msgs = await self._truncate(msgs)
|
||||
|
||||
async def _format(self, msgs: list[Msg]) -> list[dict[str, Any]]:
|
||||
"""Format the input messages into the required format. This method
|
||||
should be implemented by the subclasses."""
|
||||
|
||||
formatted_msgs = []
|
||||
start_index = 0
|
||||
if len(msgs) > 0 and msgs[0].role == "system":
|
||||
formatted_msgs.append(
|
||||
await self._format_system_message(msgs[0]),
|
||||
)
|
||||
start_index = 1
|
||||
|
||||
is_first_agent_message = True
|
||||
async for typ, group in self._group_messages(msgs[start_index:]):
|
||||
match typ:
|
||||
case "tool_sequence":
|
||||
formatted_msgs.extend(
|
||||
await self._format_tool_sequence(group),
|
||||
)
|
||||
case "agent_message":
|
||||
formatted_msgs.extend(
|
||||
await self._format_agent_message(
|
||||
group,
|
||||
is_first_agent_message,
|
||||
),
|
||||
)
|
||||
is_first_agent_message = False
|
||||
|
||||
return formatted_msgs
|
||||
|
||||
async def _format_system_message(
|
||||
self,
|
||||
msg: Msg,
|
||||
) -> dict[str, Any]:
|
||||
"""Format system message for the LLM API.
|
||||
|
||||
.. note:: This is the default implementation. For certain LLM APIs
|
||||
with specific requirements, you may need to implement a custom
|
||||
formatting function to accommodate those particular needs.
|
||||
"""
|
||||
return {
|
||||
"role": "system",
|
||||
"content": msg.get_content_blocks("text"),
|
||||
}
|
||||
|
||||
async def _format_tool_sequence(
|
||||
self,
|
||||
msgs: list[Msg],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Given a sequence of tool call/result messages, format them into
|
||||
the required format for the LLM API."""
|
||||
raise NotImplementedError(
|
||||
"_format_tool_sequence is not implemented",
|
||||
)
|
||||
|
||||
async def _format_agent_message(
|
||||
self,
|
||||
msgs: list[Msg],
|
||||
is_first: bool = True,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Given a sequence of messages without tool calls/results, format
|
||||
them into the required format for the LLM API."""
|
||||
raise NotImplementedError(
|
||||
"_format_agent_message is not implemented",
|
||||
)
|
||||
|
||||
async def _truncate(self, msgs: list[Msg]) -> list[Msg]:
|
||||
"""Truncate the input messages, so that it can fit the token limit.
|
||||
This function is called only when
|
||||
|
||||
- both `token_counter` and `max_tokens` are provided,
|
||||
- the formatted output of the input messages exceeds the token limit.
|
||||
|
||||
.. tip:: This function only provides a simple strategy, and developers
|
||||
can override this method to implement more sophisticated
|
||||
truncation strategies.
|
||||
|
||||
.. note:: The tool call message should be truncated together with
|
||||
its corresponding tool result message to satisfy the LLM API
|
||||
requirements.
|
||||
|
||||
Args:
|
||||
msgs (`list[Msg]`):
|
||||
The input messages to be truncated.
|
||||
|
||||
Raises:
|
||||
`ValueError`:
|
||||
If the system prompt message already exceeds the token limit,
|
||||
or if there are tool calls without corresponding tool results.
|
||||
|
||||
Returns:
|
||||
`list[Msg]`:
|
||||
The truncated messages.
|
||||
"""
|
||||
start_index = 0
|
||||
if len(msgs) > 0 and msgs[0].role == "system":
|
||||
if len(msgs) == 1:
|
||||
# If the system prompt already exceeds the token limit, we
|
||||
# raise an error.
|
||||
raise ValueError(
|
||||
f"The system prompt message already exceeds the token "
|
||||
f"limit ({self.max_tokens} tokens).",
|
||||
)
|
||||
|
||||
start_index = 1
|
||||
|
||||
# Create a tool call IDs queues to delete the corresponding tool
|
||||
# result message
|
||||
tool_call_ids = set()
|
||||
for i in range(start_index, len(msgs)):
|
||||
msg = msgs[i]
|
||||
for block in msg.get_content_blocks("tool_use"):
|
||||
tool_call_ids.add(block["id"])
|
||||
|
||||
for block in msg.get_content_blocks("tool_result"):
|
||||
try:
|
||||
tool_call_ids.remove(block["id"])
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
# We can stop truncating if the queue is empty
|
||||
if len(tool_call_ids) == 0:
|
||||
return msgs[:start_index] + msgs[i + 1 :]
|
||||
|
||||
if len(tool_call_ids) > 0:
|
||||
raise ValueError(
|
||||
"The input messages contains tool call(s) that do not have "
|
||||
f"the corresponding tool result(s): {tool_call_ids}. ",
|
||||
)
|
||||
|
||||
return msgs[:start_index]
|
||||
|
||||
async def _count(self, msgs: list[dict[str, Any]]) -> int | None:
|
||||
"""Count the number of tokens in the input messages. If token counter
|
||||
is not provided, `None` will be returned.
|
||||
|
||||
Args:
|
||||
msgs (`list[Msg]`):
|
||||
The input messages to count tokens for.
|
||||
"""
|
||||
if self.token_counter is None:
|
||||
return None
|
||||
|
||||
return await self.token_counter.count(msgs)
|
||||
|
||||
@staticmethod
|
||||
async def _group_messages(
|
||||
msgs: list[Msg],
|
||||
) -> AsyncGenerator[
|
||||
Tuple[Literal["tool_sequence", "agent_message"], list[Msg]],
|
||||
None,
|
||||
]:
|
||||
"""Group the input messages into two types and yield them as a
|
||||
generator. The two types are:
|
||||
|
||||
- agent message that doesn't contain tool calls/results, and
|
||||
- tool sequence that consisted of a sequence of tool calls/results
|
||||
|
||||
.. note:: The group operation is used in multi-agent scenario, where
|
||||
multiple entities are involved in the input messages. So that to be
|
||||
compatible with tools API, we have to group the messages and format
|
||||
them with different strategies.
|
||||
|
||||
Args:
|
||||
msgs (`list[Msg]`):
|
||||
The input messages to be grouped, where the system prompt
|
||||
message shouldn't be included.
|
||||
|
||||
Yields:
|
||||
`AsyncGenerator[Tuple[str, list[Msg]], None]`:
|
||||
A generator that yields tuples of group type and the list of
|
||||
messages in that group. The group type can be either
|
||||
"tool_sequence" or "agent_message".
|
||||
"""
|
||||
|
||||
group_type: Literal["tool_sequence", "agent_message"] | None = None
|
||||
group = []
|
||||
for msg in msgs:
|
||||
if group_type is None:
|
||||
if msg.has_content_blocks(
|
||||
"tool_use",
|
||||
) or msg.has_content_blocks("tool_result"):
|
||||
group_type = "tool_sequence"
|
||||
else:
|
||||
group_type = "agent_message"
|
||||
|
||||
group.append(msg)
|
||||
continue
|
||||
|
||||
# determine if this msg has the same type as the current group
|
||||
if group_type == "tool_sequence":
|
||||
if msg.has_content_blocks(
|
||||
"tool_use",
|
||||
) or msg.has_content_blocks("tool_result"):
|
||||
group.append(msg)
|
||||
|
||||
else:
|
||||
yield group_type, group
|
||||
group = [msg]
|
||||
group_type = "agent_message"
|
||||
|
||||
elif group_type == "agent_message":
|
||||
if msg.has_content_blocks(
|
||||
"tool_use",
|
||||
) or msg.has_content_blocks("tool_result"):
|
||||
yield group_type, group
|
||||
group = [msg]
|
||||
group_type = "tool_sequence"
|
||||
|
||||
else:
|
||||
group.append(msg)
|
||||
if group_type:
|
||||
yield group_type, group
|
||||
Reference in New Issue
Block a user