Files
tw2/examples/agent/realtime_voice_agent/run_server.py
codex-bot a64378956a
Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled
chore: initialize sandbox and overwrite remote content
2026-03-02 22:32:27 +08:00

189 lines
5.6 KiB
Python

# -*- coding: utf-8 -*-
"""A test server"""
import asyncio
import os
import traceback
from pathlib import Path

import uvicorn
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import FileResponse

from agentscope import logger
from agentscope.agent import RealtimeAgent
from agentscope.realtime import (
    DashScopeRealtimeModel,
    GeminiRealtimeModel,
    OpenAIRealtimeModel,
    ClientEvents,
    ServerEvents,
    ClientEventType,
)
from agentscope.tool import (
    Toolkit,
    execute_python_code,
    execute_shell_command,
    view_text_file,
)
# ASGI application instance; the __main__ block starts it via uvicorn as
# the import string "run_server:app".
app = FastAPI()
@app.get("/")
async def get() -> FileResponse:
    """Return the bundled ``chatbot.html`` test page.

    The page is located next to this module's own file, not looked up by
    name in the current working directory.
    """
    page = Path(__file__).with_name("chatbot.html")
    return FileResponse(page)
@app.get("/api/check-models")
async def check_models() -> dict:
    """Report which realtime model providers have an API key configured.

    Returns:
        A dict mapping each provider name (``dashscope``, ``gemini``,
        ``openai``) to ``True`` when its ``*_API_KEY`` environment
        variable is set to a non-empty value, else ``False``.
    """
    key_vars = (
        ("dashscope", "DASHSCOPE_API_KEY"),
        ("gemini", "GEMINI_API_KEY"),
        ("openai", "OPENAI_API_KEY"),
    )
    return {provider: bool(os.getenv(var)) for provider, var in key_vars}
async def frontend_receive(
    websocket: WebSocket,
    frontend_queue: asyncio.Queue,
) -> None:
    """Relay server events queued by the agent to the frontend socket.

    Each queued event is serialized with ``model_dump()`` and sent as
    JSON. The loop runs until an ``Exception`` is raised (for example by
    a failed send); the error and its traceback are printed and the
    coroutine returns. Task cancellation is not an ``Exception`` and
    therefore propagates.

    Args:
        websocket: The connected frontend WebSocket.
        frontend_queue: Queue of ``ServerEvents`` produced by the agent.
    """
    try:
        while True:
            event: ServerEvents.EventBase = await frontend_queue.get()
            payload = event.model_dump()
            await websocket.send_json(payload)
    except Exception as err:
        print(f"[ERROR] frontend_receive error: {err}")
        traceback.print_exc()
def _build_toolkit(model_provider: str) -> Toolkit | None:
    """Return the demo toolkit for ``model_provider``, if it gets one.

    Only the ``"gemini"`` and ``"openai"`` providers are given tools in
    this example; every other provider gets ``None``.
    """
    if model_provider not in ("gemini", "openai"):
        return None
    toolkit = Toolkit()
    for tool_func in (
        execute_python_code,
        execute_shell_command,
        view_text_file,
    ):
        toolkit.register_tool_function(tool_func)
    return toolkit


def _build_model(model_provider: str):
    """Instantiate the realtime model for ``model_provider``.

    The API key is read from the provider's ``*_API_KEY`` environment
    variable.

    Raises:
        ValueError: If ``model_provider`` is not ``"dashscope"``,
            ``"gemini"`` or ``"openai"``.
    """
    if model_provider == "dashscope":
        return DashScopeRealtimeModel(
            model_name="qwen3-omni-flash-realtime",
            api_key=os.getenv("DASHSCOPE_API_KEY"),
        )
    if model_provider == "gemini":
        return GeminiRealtimeModel(
            model_name="gemini-2.5-flash-native-audio-preview-09-2025",
            api_key=os.getenv("GEMINI_API_KEY"),
        )
    if model_provider == "openai":
        return OpenAIRealtimeModel(
            model_name="gpt-4o-realtime-preview",
            api_key=os.getenv("OPENAI_API_KEY"),
        )
    raise ValueError(
        f"Unsupported model provider: {model_provider}",
    )


@app.websocket("/ws/{user_id}/{session_id}")
async def single_agent_endpoint(
    websocket: WebSocket,
    user_id: str,
    session_id: str,
) -> None:
    """WebSocket endpoint driving one realtime agent per connection.

    Incoming JSON messages are parsed with ``ClientEvents.from_json``:

    - a ``ClientSessionCreateEvent`` (re)creates the agent from its
      ``config`` (``instructions``, ``agent_name``, ``model_provider``,
      defaulting to ``"dashscope"``), starts it, and replies with a
      ``ServerSessionCreatedEvent``; any previous agent is stopped first;
    - a ``CLIENT_SESSION_END`` event stops the current agent;
    - any other event is forwarded to the agent via ``handle_input``, or
      ignored with a warning when no session is active.

    Events produced by the agent are put on a queue that a background
    ``frontend_receive`` task relays to the client. When the connection
    ends, for any reason, that task is cancelled and the agent stopped.

    Raises:
        ValueError: If the session config names an unsupported model
            provider (re-raised after being printed).
    """
    agent = None
    forward_task = None
    try:
        await websocket.accept()
        logger.info(
            "Connected to WebSocket: user_id=%s, session_id=%s",
            user_id,
            session_id,
        )
        # Relay agent -> frontend events from a background task. Keep the
        # handle: the event loop references tasks only weakly, and we must
        # cancel the task when this connection ends.
        frontend_queue = asyncio.Queue()
        forward_task = asyncio.create_task(
            frontend_receive(websocket, frontend_queue),
        )
        while True:
            data = await websocket.receive_json()
            client_event = ClientEvents.from_json(data)
            if isinstance(
                client_event,
                ClientEvents.ClientSessionCreateEvent,
            ):
                # A new session replaces the old one; stop the previous
                # agent so its model connection is not leaked.
                if agent is not None:
                    previous, agent = agent, None
                    await previous.stop()
                config = client_event.config
                model_provider = config.get("model_provider", "dashscope")
                agent = RealtimeAgent(
                    name=config.get("agent_name", "Friday"),
                    sys_prompt=config.get(
                        "instructions",
                        "You're a helpful assistant.",
                    ),
                    model=_build_model(model_provider),
                    toolkit=_build_toolkit(model_provider),
                )
                await agent.start(frontend_queue)
                await websocket.send_json(
                    ServerEvents.ServerSessionCreatedEvent(
                        session_id=session_id,
                    ).model_dump(),
                )
                print(
                    f"Session created successfully: {session_id}",
                )
            elif client_event.type == ClientEventType.CLIENT_SESSION_END:
                if agent is not None:
                    previous, agent = agent, None
                    await previous.stop()
            elif agent is None:
                # No session yet (or it already ended): there is no agent
                # to receive this input, so skip it instead of failing.
                logger.warning(
                    "Ignoring client event %s without an active session: "
                    "session_id=%s",
                    client_event.type,
                    session_id,
                )
            else:
                await agent.handle_input(client_event)
    except WebSocketDisconnect:
        # The client closing the socket is the normal way a session ends.
        logger.info(
            "WebSocket disconnected: user_id=%s, session_id=%s",
            user_id,
            session_id,
        )
    except Exception as e:
        print(f"[ERROR] WebSocket endpoint error: {e}")
        traceback.print_exc()
        raise
    finally:
        if forward_task is not None:
            forward_task.cancel()
        if agent is not None:
            await agent.stop()
if __name__ == "__main__":
    # The app is passed as an import string, which uvicorn requires when
    # reload is enabled.
    server_options = {
        "host": "localhost",
        "port": 8000,
        "reload": True,
        "log_level": "info",
    }
    uvicorn.run("run_server:app", **server_options)