chore: initialize sandbox and overwrite remote content
Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled
Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled
This commit is contained in:
297
src/agentscope/formatter/_truncated_formatter_base.py
Normal file
297
src/agentscope/formatter/_truncated_formatter_base.py
Normal file
@@ -0,0 +1,297 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The truncated formatter base class, which allows to truncate the input
|
||||
messages."""
|
||||
from abc import ABC
|
||||
from copy import deepcopy
|
||||
from typing import (
|
||||
Any,
|
||||
Tuple,
|
||||
Literal,
|
||||
AsyncGenerator,
|
||||
)
|
||||
|
||||
from ._formatter_base import FormatterBase
|
||||
from ..message import Msg
|
||||
from ..token import TokenCounterBase
|
||||
from ..tracing import trace_format
|
||||
|
||||
|
||||
class TruncatedFormatterBase(FormatterBase, ABC):
|
||||
"""Base class for truncated formatters, which formats input messages into
|
||||
required formats with tokens under a specified limit."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
token_counter: TokenCounterBase | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> None:
|
||||
"""Initialize the TruncatedFormatterBase.
|
||||
|
||||
Args:
|
||||
token_counter (`TokenCounterBase | None`, optional):
|
||||
A token counter instance used to count tokens in the messages.
|
||||
If not provided, the formatter will format the messages
|
||||
without considering token limits.
|
||||
max_tokens (`int | None`, optional):
|
||||
The maximum number of tokens allowed in the formatted
|
||||
messages. If not provided, the formatter will not truncate
|
||||
the messages.
|
||||
"""
|
||||
self.token_counter = token_counter
|
||||
|
||||
assert (
|
||||
max_tokens is None or 0 < max_tokens
|
||||
), "max_tokens must be greater than 0"
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
@trace_format
|
||||
async def format(
|
||||
self,
|
||||
msgs: list[Msg],
|
||||
**kwargs: Any,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Format the input messages into the required format. If token
|
||||
counter and max token limit are provided, the messages will be
|
||||
truncated to fit the limit.
|
||||
|
||||
Args:
|
||||
msgs (`list[Msg]`):
|
||||
The input messages to be formatted.
|
||||
|
||||
Returns:
|
||||
`list[dict[str, Any]]`:
|
||||
The formatted messages in the required format.
|
||||
"""
|
||||
|
||||
# Check if the input messages are valid
|
||||
self.assert_list_of_msgs(msgs)
|
||||
|
||||
msgs = deepcopy(msgs)
|
||||
|
||||
while True:
|
||||
formatted_msgs = await self._format(msgs)
|
||||
n_tokens = await self._count(formatted_msgs)
|
||||
|
||||
if (
|
||||
n_tokens is None
|
||||
or self.max_tokens is None
|
||||
or n_tokens <= self.max_tokens
|
||||
):
|
||||
return formatted_msgs
|
||||
|
||||
# truncate the input messages
|
||||
msgs = await self._truncate(msgs)
|
||||
|
||||
async def _format(self, msgs: list[Msg]) -> list[dict[str, Any]]:
|
||||
"""Format the input messages into the required format. This method
|
||||
should be implemented by the subclasses."""
|
||||
|
||||
formatted_msgs = []
|
||||
start_index = 0
|
||||
if len(msgs) > 0 and msgs[0].role == "system":
|
||||
formatted_msgs.append(
|
||||
await self._format_system_message(msgs[0]),
|
||||
)
|
||||
start_index = 1
|
||||
|
||||
is_first_agent_message = True
|
||||
async for typ, group in self._group_messages(msgs[start_index:]):
|
||||
match typ:
|
||||
case "tool_sequence":
|
||||
formatted_msgs.extend(
|
||||
await self._format_tool_sequence(group),
|
||||
)
|
||||
case "agent_message":
|
||||
formatted_msgs.extend(
|
||||
await self._format_agent_message(
|
||||
group,
|
||||
is_first_agent_message,
|
||||
),
|
||||
)
|
||||
is_first_agent_message = False
|
||||
|
||||
return formatted_msgs
|
||||
|
||||
async def _format_system_message(
|
||||
self,
|
||||
msg: Msg,
|
||||
) -> dict[str, Any]:
|
||||
"""Format system message for the LLM API.
|
||||
|
||||
.. note:: This is the default implementation. For certain LLM APIs
|
||||
with specific requirements, you may need to implement a custom
|
||||
formatting function to accommodate those particular needs.
|
||||
"""
|
||||
return {
|
||||
"role": "system",
|
||||
"content": msg.get_content_blocks("text"),
|
||||
}
|
||||
|
||||
async def _format_tool_sequence(
|
||||
self,
|
||||
msgs: list[Msg],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Given a sequence of tool call/result messages, format them into
|
||||
the required format for the LLM API."""
|
||||
raise NotImplementedError(
|
||||
"_format_tool_sequence is not implemented",
|
||||
)
|
||||
|
||||
async def _format_agent_message(
|
||||
self,
|
||||
msgs: list[Msg],
|
||||
is_first: bool = True,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Given a sequence of messages without tool calls/results, format
|
||||
them into the required format for the LLM API."""
|
||||
raise NotImplementedError(
|
||||
"_format_agent_message is not implemented",
|
||||
)
|
||||
|
||||
async def _truncate(self, msgs: list[Msg]) -> list[Msg]:
|
||||
"""Truncate the input messages, so that it can fit the token limit.
|
||||
This function is called only when
|
||||
|
||||
- both `token_counter` and `max_tokens` are provided,
|
||||
- the formatted output of the input messages exceeds the token limit.
|
||||
|
||||
.. tip:: This function only provides a simple strategy, and developers
|
||||
can override this method to implement more sophisticated
|
||||
truncation strategies.
|
||||
|
||||
.. note:: The tool call message should be truncated together with
|
||||
its corresponding tool result message to satisfy the LLM API
|
||||
requirements.
|
||||
|
||||
Args:
|
||||
msgs (`list[Msg]`):
|
||||
The input messages to be truncated.
|
||||
|
||||
Raises:
|
||||
`ValueError`:
|
||||
If the system prompt message already exceeds the token limit,
|
||||
or if there are tool calls without corresponding tool results.
|
||||
|
||||
Returns:
|
||||
`list[Msg]`:
|
||||
The truncated messages.
|
||||
"""
|
||||
start_index = 0
|
||||
if len(msgs) > 0 and msgs[0].role == "system":
|
||||
if len(msgs) == 1:
|
||||
# If the system prompt already exceeds the token limit, we
|
||||
# raise an error.
|
||||
raise ValueError(
|
||||
f"The system prompt message already exceeds the token "
|
||||
f"limit ({self.max_tokens} tokens).",
|
||||
)
|
||||
|
||||
start_index = 1
|
||||
|
||||
# Create a tool call IDs queues to delete the corresponding tool
|
||||
# result message
|
||||
tool_call_ids = set()
|
||||
for i in range(start_index, len(msgs)):
|
||||
msg = msgs[i]
|
||||
for block in msg.get_content_blocks("tool_use"):
|
||||
tool_call_ids.add(block["id"])
|
||||
|
||||
for block in msg.get_content_blocks("tool_result"):
|
||||
try:
|
||||
tool_call_ids.remove(block["id"])
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
# We can stop truncating if the queue is empty
|
||||
if len(tool_call_ids) == 0:
|
||||
return msgs[:start_index] + msgs[i + 1 :]
|
||||
|
||||
if len(tool_call_ids) > 0:
|
||||
raise ValueError(
|
||||
"The input messages contains tool call(s) that do not have "
|
||||
f"the corresponding tool result(s): {tool_call_ids}. ",
|
||||
)
|
||||
|
||||
return msgs[:start_index]
|
||||
|
||||
async def _count(self, msgs: list[dict[str, Any]]) -> int | None:
|
||||
"""Count the number of tokens in the input messages. If token counter
|
||||
is not provided, `None` will be returned.
|
||||
|
||||
Args:
|
||||
msgs (`list[Msg]`):
|
||||
The input messages to count tokens for.
|
||||
"""
|
||||
if self.token_counter is None:
|
||||
return None
|
||||
|
||||
return await self.token_counter.count(msgs)
|
||||
|
||||
@staticmethod
|
||||
async def _group_messages(
|
||||
msgs: list[Msg],
|
||||
) -> AsyncGenerator[
|
||||
Tuple[Literal["tool_sequence", "agent_message"], list[Msg]],
|
||||
None,
|
||||
]:
|
||||
"""Group the input messages into two types and yield them as a
|
||||
generator. The two types are:
|
||||
|
||||
- agent message that doesn't contain tool calls/results, and
|
||||
- tool sequence that consisted of a sequence of tool calls/results
|
||||
|
||||
.. note:: The group operation is used in multi-agent scenario, where
|
||||
multiple entities are involved in the input messages. So that to be
|
||||
compatible with tools API, we have to group the messages and format
|
||||
them with different strategies.
|
||||
|
||||
Args:
|
||||
msgs (`list[Msg]`):
|
||||
The input messages to be grouped, where the system prompt
|
||||
message shouldn't be included.
|
||||
|
||||
Yields:
|
||||
`AsyncGenerator[Tuple[str, list[Msg]], None]`:
|
||||
A generator that yields tuples of group type and the list of
|
||||
messages in that group. The group type can be either
|
||||
"tool_sequence" or "agent_message".
|
||||
"""
|
||||
|
||||
group_type: Literal["tool_sequence", "agent_message"] | None = None
|
||||
group = []
|
||||
for msg in msgs:
|
||||
if group_type is None:
|
||||
if msg.has_content_blocks(
|
||||
"tool_use",
|
||||
) or msg.has_content_blocks("tool_result"):
|
||||
group_type = "tool_sequence"
|
||||
else:
|
||||
group_type = "agent_message"
|
||||
|
||||
group.append(msg)
|
||||
continue
|
||||
|
||||
# determine if this msg has the same type as the current group
|
||||
if group_type == "tool_sequence":
|
||||
if msg.has_content_blocks(
|
||||
"tool_use",
|
||||
) or msg.has_content_blocks("tool_result"):
|
||||
group.append(msg)
|
||||
|
||||
else:
|
||||
yield group_type, group
|
||||
group = [msg]
|
||||
group_type = "agent_message"
|
||||
|
||||
elif group_type == "agent_message":
|
||||
if msg.has_content_blocks(
|
||||
"tool_use",
|
||||
) or msg.has_content_blocks("tool_result"):
|
||||
yield group_type, group
|
||||
group = [msg]
|
||||
group_type = "tool_sequence"
|
||||
|
||||
else:
|
||||
group.append(msg)
|
||||
if group_type:
|
||||
yield group_type, group
|
||||
Reference in New Issue
Block a user