Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled
320 lines
10 KiB
Python
320 lines
10 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""Utility functions for A2UI agent integration."""
|
|
import json
|
|
from typing import Any
|
|
from pydantic import BaseModel
|
|
from pydantic import Field
|
|
|
|
from a2a.types import (
|
|
DataPart,
|
|
TextPart,
|
|
Message,
|
|
Part,
|
|
)
|
|
from a2ui.extension.a2ui_extension import (
|
|
A2UI_MIME_TYPE,
|
|
MIME_TYPE_KEY,
|
|
A2UI_EXTENSION_URI,
|
|
)
|
|
|
|
from agentscope._logging import logger
|
|
|
|
|
|
class A2UIResponse(BaseModel):
    """Response model for A2UI formatted output."""

    # NOTE: pydantic copies the class docstring above and the field
    # ``description`` below into the model's JSON schema, so both texts are
    # runtime data, not only documentation. Presumably this model is handed
    # to the agent as a structured-output target — TODO confirm at call site.
    # No default is passed to ``Field``, so the field is required.
    response_with_a2ui: str = Field(
        description="The response with A2UI JSON",
    )
|
|
|
|
|
|
def check_a2ui_extension(*args: Any) -> bool:
    """Check whether the A2UI extension was requested and, if so, activate it.

    Args:
        *args: Positional arguments forwarded by the caller. Only the first
            one is inspected; it is expected to be the request context
            (typically an A2A ``ServerCallContext``) exposing
            ``requested_extensions`` and, optionally,
            ``add_activated_extension``.

    Returns:
        True if ``A2UI_EXTENSION_URI`` is among the requested extensions
        (even when the context cannot record the activation), False
        otherwise, including when no usable context is supplied.
    """
    if not args:
        logger.warning("check_a2ui_extension: No context provided in args")
        return False

    context = args[0]
    if not hasattr(context, "requested_extensions"):
        logger.warning(
            "check_a2ui_extension: Context does not have "
            "requested_extensions attribute",
        )
        return False

    requested = context.requested_extensions
    # A missing (None) collection means nothing was requested; guarding it
    # avoids a TypeError from the membership test below.
    if not requested or A2UI_EXTENSION_URI not in requested:
        return False

    # Record the activation on the context when it supports doing so.
    if hasattr(context, "add_activated_extension"):
        context.add_activated_extension(A2UI_EXTENSION_URI)
        logger.info("A2UI extension activated: %s", A2UI_EXTENSION_URI)
    else:
        logger.warning(
            "check_a2ui_extension: Context does not have "
            "add_activated_extension method",
        )
    return True
|
|
|
|
|
|
def transfer_ui_event_to_query(ui_event_part: dict) -> str:
    """Transfer an A2UI user-action event into a natural-language query.

    Args:
        ui_event_part: The ``userAction`` payload of an A2UI client event.
            ``actionName`` selects the query template and ``context``
            supplies the values substituted into it.

    Returns:
        A query string describing the action. ``book_restaurant`` /
        ``select_item`` and ``submit_booking`` use dedicated templates with
        placeholder defaults for missing context values; any other action
        embeds the whole event payload.
    """
    action = ui_event_part.get("actionName")
    # ``context`` may be absent, explicitly null, or not an object (e.g. when
    # a data binding could not be resolved). Fall back to an empty dict so
    # the lookups below use their defaults instead of raising.
    ctx = ui_event_part.get("context")
    if not isinstance(ctx, dict):
        ctx = {}

    if action in ("book_restaurant", "select_item"):
        restaurant_name = ctx.get("restaurantName", "Unknown Restaurant")
        address = ctx.get("address", "Address not provided")
        image_url = ctx.get("imageUrl", "")
        query = (
            f"USER_WANTS_TO_BOOK: {restaurant_name}, "
            f"Address: {address}, ImageURL: {image_url}"
        )
    elif action == "submit_booking":
        restaurant_name = ctx.get("restaurantName", "Unknown Restaurant")
        party_size = ctx.get("partySize", "Unknown Size")
        reservation_time = ctx.get("reservationTime", "Unknown Time")
        dietary_reqs = ctx.get("dietary", "None")
        image_url = ctx.get("imageUrl", "")
        query = (
            f"User submitted a booking for {restaurant_name} "
            f"for {party_size} people at {reservation_time} "
            f"with dietary requirements: {dietary_reqs}. "
            f"The image URL is {image_url}"
        )
    else:
        # Note: The A2UI original example uses `ctx` as the data source.
        # However, in generated UI components, the `ctx` field may be empty
        # when the databinding path cannot be resolved. To ensure we capture
        # all available event data, we use the entire `ui_event_part` instead.
        query = f"User submitted an event: {action} with data: {ui_event_part}"

    return query
|
|
|
|
|
|
def pre_process_request_with_ui_event(message: Message) -> Any:
    """Replace A2UI user-action data parts with plain-text query parts.

    Every ``DataPart`` whose data holds a ``userAction`` entry is swapped, in
    place, for a ``TextPart`` whose text is produced by
    ``transfer_ui_event_to_query``. All other parts are left untouched.

    Args:
        message: The incoming A2A message; may be falsy or have no parts.

    Returns:
        The same ``message`` object (mutated in place), or the falsy input
        unchanged.
    """
    if not (message and message.parts):
        return message

    logger.info(
        "--- AGENT_EXECUTOR: Processing %s message parts ---",
        len(message.parts),
    )

    for index, part in enumerate(message.parts):
        root = part.root
        if not isinstance(root, DataPart) or "userAction" not in root.data:
            continue

        user_action = root.data["userAction"]
        logger.info(
            " Part %s: Found a2ui UI ClientEvent payload: %s",
            index,
            json.dumps(user_action, indent=4),
        )
        # Overwriting by index keeps the list length stable while iterating.
        message.parts[index] = Part(
            root=TextPart(text=transfer_ui_event_to_query(user_action)),
        )

    return message
|
|
|
|
|
|
def _find_json_end(json_string: str) -> int:
|
|
"""Find the end position of a JSON array or object.
|
|
|
|
Finds the end by matching brackets/braces.
|
|
|
|
Args:
|
|
json_string: The JSON string to search.
|
|
|
|
Returns:
|
|
The end position (index + 1) of the JSON structure.
|
|
"""
|
|
if json_string.startswith("["):
|
|
# Find matching closing bracket
|
|
bracket_count = 0
|
|
for i, char in enumerate(json_string):
|
|
if char == "[":
|
|
bracket_count += 1
|
|
elif char == "]":
|
|
bracket_count -= 1
|
|
if bracket_count == 0:
|
|
return i + 1
|
|
elif json_string.startswith("{"):
|
|
# Find matching closing brace
|
|
brace_count = 0
|
|
for i, char in enumerate(json_string):
|
|
if char == "{":
|
|
brace_count += 1
|
|
elif char == "}":
|
|
brace_count -= 1
|
|
if brace_count == 0:
|
|
return i + 1
|
|
return len(json_string)
|
|
|
|
|
|
def extract_ui_json_from_text(content_str: str) -> tuple[str, Any]:
    """Split a reply into its leading text and its trailing A2UI JSON.

    The reply is expected to look like ``<text>---a2ui_JSON---<json>``,
    where ``<json>`` may be wrapped in a Markdown code fence (optionally
    tagged ``json``) and may be followed by extra text.

    Args:
        content_str: The full reply text.

    Returns:
        ``(text_content, json_data)`` where ``text_content`` is the text
        before the marker and ``json_data`` is the decoded JSON value, or
        ``None`` when only whitespace follows the marker. If the marker is
        missing or the JSON cannot be decoded, ``(content_str, None)`` is
        returned so the caller can keep the original text unchanged.
    """
    text_content, marker, json_string = content_str.partition(
        "---a2ui_JSON---",
    )
    if not marker:
        # No marker at all: nothing to extract.
        return content_str, None

    cleaned = json_string.strip()
    if not cleaned:
        return text_content, None

    # Remove an optional Markdown fence. ``str.lstrip("```json")`` strips a
    # *set* of characters rather than a prefix, so drop the opening fence
    # and language tag explicitly, then any trailing backticks.
    cleaned = cleaned.removeprefix("```")
    if cleaned[:4].lower() == "json":
        cleaned = cleaned[4:]
    cleaned = cleaned.rstrip("`").strip()

    try:
        if cleaned.startswith(("[", "{")):
            # raw_decode stops at the end of the first complete JSON value
            # (honouring string literals), so trailing prose is ignored.
            json_data, _ = json.JSONDecoder().raw_decode(cleaned)
        else:
            json_data = json.loads(cleaned)
    except json.JSONDecodeError as e:
        logger.error("Failed to parse UI JSON: %s", e)
        # On error, keep the JSON as text content
        return content_str, None
    return text_content, json_data
|
|
|
|
|
|
def check_a2ui_json_in_message(
    part: Part,
    is_final: bool,
) -> tuple[bool, str | None]:
    """Check whether an A2A part carries text containing A2UI JSON.

    Two shapes are recognised:

    * a final ``TextPart`` whose text contains the ``---a2ui_JSON---``
      marker; and
    * a non-final ``DataPart`` describing a ``generate_response`` tool use
      whose ``input["response_with_a2ui"]`` string contains the marker.

    Args:
        part: The A2A part to inspect.
        is_final: Whether ``part`` belongs to the final message.

    Returns:
        ``(True, text)`` with the marker-bearing string when found,
        otherwise ``(False, None)``.
    """
    root = part.root

    # for the case when ReActAgent max iters is reached, the message will be
    # the last complete message, with text message.
    if (
        is_final
        and isinstance(root, TextPart)
        and "---a2ui_JSON---" in root.text
    ):
        logger.info(
            "--- Found A2UI JSON in the message: %s ---",
            root.text,
        )
        return True, root.text

    # for the case when ReActAgent max iters is not reached, if it contains
    # tool use block with name "generate_response" and type "tool_use", and
    # the response_with_a2ui contains "---a2ui_JSON---", then return True,
    # response_with_a2ui.
    if (
        not is_final
        and isinstance(root, DataPart)
        and root.data.get("name") == "generate_response"
        and root.data.get("type") == "tool_use"
    ):
        input_data = root.data.get("input")
        if isinstance(input_data, dict):
            response_with_a2ui = input_data.get("response_with_a2ui")
            # Only a string can satisfy the ``str | None`` contract; other
            # JSON values would make ``in`` raise or return a non-str.
            if (
                isinstance(response_with_a2ui, str)
                and "---a2ui_JSON---" in response_with_a2ui
            ):
                return True, response_with_a2ui

    return False, None
|
|
|
|
|
|
def _expand_a2ui_text_part(part: Part) -> list[Part]:
    """Return the part(s) that should replace ``part`` in the message.

    Non-text parts, text parts without the ``---a2ui_JSON---`` marker, and
    text parts whose JSON cannot be extracted are returned unchanged as a
    one-element list. Otherwise the text before the marker becomes a
    ``TextPart``, followed by one ``DataPart`` per item of the JSON payload,
    each tagged with the A2UI MIME type in its metadata.
    """
    if not isinstance(part.root, TextPart):
        return [part]

    text_content_str = part.root.text
    if "---a2ui_JSON---" not in text_content_str:
        return [part]

    text_content, json_data = extract_ui_json_from_text(text_content_str)
    if not json_data:
        # If JSON extraction failed, keep the original text block
        return [part]

    # A single A2UI message object may be emitted instead of a list of
    # them; iterating a dict would yield its keys, so wrap it instead.
    if isinstance(json_data, dict):
        json_data = [json_data]

    try:
        expanded = [Part(root=TextPart(text=text_content))]
        # A fresh metadata dict per DataPart, so parts never share state.
        expanded.extend(
            Part(
                root=DataPart(
                    data=item,
                    metadata={MIME_TYPE_KEY: A2UI_MIME_TYPE},
                ),
            )
            for item in json_data
        )
    except Exception as e:
        logger.error(
            "Error processing a2ui JSON parts: %s",
            e,
            exc_info=True,
        )
        raise
    return expanded


def post_process_a2a_message_for_ui(
    message: Message,
) -> Message:
    """Split A2UI JSON out of text parts into A2UI-tagged data parts.

    Each text part containing the ``---a2ui_JSON---`` marker is replaced by
    a ``TextPart`` holding the text before the marker plus one ``DataPart``
    per A2UI message in the JSON after it. All other parts are kept as-is.

    Args:
        message: The transferred A2A message; its ``parts`` list is
            replaced in place.

    Returns:
        The same ``message`` object with its parts post-processed.

    Raises:
        Exception: Re-raises any error raised while building the
            replacement parts, after logging it.
    """
    new_parts: list[Part] = []
    for part in message.parts:
        new_parts.extend(_expand_a2ui_text_part(part))

    message.parts = new_parts
    return message
|