init
This commit is contained in:
249
agent.py
Normal file
249
agent.py
Normal file
@@ -0,0 +1,249 @@
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from openai import AsyncOpenAI
|
||||
from pydantic_ai import Agent
|
||||
from pydantic_ai.mcp import MCPServerHTTP, MCPServerSSE, MCPServerStreamableHTTP
|
||||
from pydantic_ai.models.openai import OpenAIModel
|
||||
from pydantic_ai.providers.openai import OpenAIProvider
|
||||
|
||||
from schemas import RoutePlanRequest, RoutePlanResult
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class ConfigurationError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
class GuardrailError(RuntimeError):
|
||||
pass
|
||||
|
||||
def _required_env(name: str) -> str:
|
||||
value = os.getenv(name, "").strip()
|
||||
if not value:
|
||||
raise ConfigurationError(f"Missing required environment variable: {name}")
|
||||
return value
|
||||
|
||||
|
||||
def _env_positive_int(name: str, default: int) -> int:
|
||||
raw_value = os.getenv(name, str(default)).strip()
|
||||
try:
|
||||
parsed = int(raw_value)
|
||||
except ValueError as exc:
|
||||
raise ConfigurationError(f"Environment variable {name} must be an integer") from exc
|
||||
|
||||
if parsed <= 0:
|
||||
raise ConfigurationError(f"Environment variable {name} must be greater than 0")
|
||||
return parsed
|
||||
|
||||
|
||||
def _env_positive_float(name: str, default: float) -> float:
|
||||
raw_value = os.getenv(name, str(default)).strip()
|
||||
try:
|
||||
parsed = float(raw_value)
|
||||
except ValueError as exc:
|
||||
raise ConfigurationError(f"Environment variable {name} must be a number") from exc
|
||||
|
||||
if parsed <= 0:
|
||||
raise ConfigurationError(f"Environment variable {name} must be greater than 0")
|
||||
return parsed
|
||||
|
||||
|
||||
def _build_model() -> OpenAIModel:
|
||||
openai_client = AsyncOpenAI(
|
||||
base_url=_required_env("ARK_BASE_URL"),
|
||||
api_key=_required_env("ARK_API_KEY"),
|
||||
timeout=_env_positive_float("ARK_REQUEST_TIMEOUT_SECONDS", 120.0),
|
||||
)
|
||||
provider = OpenAIProvider(
|
||||
openai_client=openai_client,
|
||||
)
|
||||
return OpenAIModel(
|
||||
_required_env("ARK_MODEL"),
|
||||
provider=provider,
|
||||
)
|
||||
|
||||
|
||||
def _build_mcp_headers() -> dict[str, str] | None:
|
||||
header_name = os.getenv("AMAP_MCP_AUTH_HEADER_NAME", "").strip()
|
||||
header_value = os.getenv("AMAP_MCP_AUTH_HEADER_VALUE", "").strip()
|
||||
|
||||
if header_name and header_value:
|
||||
return {header_name: header_value}
|
||||
if header_name or header_value:
|
||||
raise ConfigurationError(
|
||||
"AMAP_MCP_AUTH_HEADER_NAME and AMAP_MCP_AUTH_HEADER_VALUE must be set together"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _build_mcp_server() -> MCPServerHTTP | MCPServerSSE | MCPServerStreamableHTTP:
|
||||
transport = os.getenv("AMAP_MCP_TRANSPORT", "streamable_http").strip().lower()
|
||||
url = _required_env("AMAP_MCP_URL")
|
||||
headers = _build_mcp_headers()
|
||||
timeout = _env_positive_float("AMAP_MCP_TIMEOUT_SECONDS", 20.0)
|
||||
read_timeout = _env_positive_float("AMAP_MCP_READ_TIMEOUT_SECONDS", 60.0)
|
||||
|
||||
if transport == "sse":
|
||||
return MCPServerSSE(url=url, headers=headers, timeout=timeout, read_timeout=read_timeout)
|
||||
if transport == "http":
|
||||
return MCPServerHTTP(url=url, headers=headers, timeout=timeout, read_timeout=read_timeout)
|
||||
if transport == "streamable_http":
|
||||
return MCPServerStreamableHTTP(
|
||||
url=url,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
read_timeout=read_timeout,
|
||||
)
|
||||
|
||||
raise ConfigurationError(
|
||||
"AMAP_MCP_TRANSPORT must be one of: streamable_http, http, sse"
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Agent
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SYSTEM_PROMPT = """\
|
||||
你是一个无状态的多目标地理路线规划 Agent。你的任务是使用高德地图 MCP 工具完成多目标点最优路线规划,并返回强类型结构化结果。\
|
||||
你必须显式调用地图工具,不得编造坐标、POI、距离、时长或 deep link。
|
||||
|
||||
你的工作流必须严格遵守以下规则:
|
||||
|
||||
1. 先解析输入,明确起点模式、终点、途经点、优化策略和输出需求。
|
||||
2. 对终点和每个途经点优先使用 maps_geo 解析地址;当命中不准或为空时,使用 maps_text_search 补齐 POI;必要时使用 maps_search_detail 校验。
|
||||
3. 当起点模式为 fixed 时,对起点也做同样解析。
|
||||
4. 当起点模式为 current_location 时,不得伪造起点坐标;如果缺少实时定位坐标,只能比较途经点内部顺序,并明确说明真实最优路线会受当前定位影响。
|
||||
5. 这套工具没有单步多点最优路径工具,因此你必须自己生成候选途经点顺序。
|
||||
6. 对每个候选顺序,逐段调用 maps_direction_driving 计算距离与时长,并汇总为候选路线。
|
||||
7. 按 route_strategy 选择最优路线:
|
||||
- shortest_distance:总里程优先,总时长次优
|
||||
- fastest_time:总时长优先,总里程次优
|
||||
- balanced:优先选择明显不劣的 Pareto 优势路线;若出现里程更短但时间略长的情况,默认优先里程更短并说明原因
|
||||
8. 如需 deep link:
|
||||
- 若目标是稳定导入整组点位,可使用 maps_schema_personal_map,但必须说明这是点位导入型链接
|
||||
- 若目标是"我的位置"出发的即时导航,应输出 route plan 类 deep link,并说明不同平台协议不同
|
||||
9. 最终结果必须包含:resolved_origin(如适用)、resolved_destination、resolved_stops、candidates、best_route、deep_links、summary、warnings。
|
||||
10. 如果任何地址无法可靠解析、任何 POI 无法获取、或任何路线比较存在信息缺失,必须如实说明,不得猜测。
|
||||
11. 你的输入会以 RoutePlanRequest JSON 形式提供。你必须基于该 JSON 进行规划,不能擅自补造字段。
|
||||
12. 在返回结果前,必须确认 best_route 来自 candidates 之一,且每一段 distance 和 duration 都来自工具调用。
|
||||
|
||||
你的底层返回必须是结构化数据,不是 HTML。只有在用户明确要求页面展示时,才在结构化结果基础上额外生成 HTML。
|
||||
"""
|
||||
|
||||
|
||||
def _configured_max_permutations() -> int:
|
||||
return _env_positive_int("ROUTE_MAX_PERMUTATIONS", 20)
|
||||
|
||||
|
||||
def _candidate_count(stop_count: int) -> int:
|
||||
return math.factorial(stop_count)
|
||||
|
||||
|
||||
def _prepare_request(request: RoutePlanRequest) -> RoutePlanRequest:
|
||||
configured_limit = _configured_max_permutations()
|
||||
effective_limit = request.max_permutations or configured_limit
|
||||
|
||||
if effective_limit > configured_limit:
|
||||
raise GuardrailError(
|
||||
"Requested max_permutations exceeds the configured service limit: "
|
||||
f"requested={effective_limit}, configured={configured_limit}"
|
||||
)
|
||||
|
||||
candidate_count = _candidate_count(len(request.stops))
|
||||
if candidate_count > effective_limit:
|
||||
raise GuardrailError(
|
||||
"Candidate permutations exceed the configured limit: "
|
||||
f"stops={len(request.stops)}, permutations={candidate_count}, limit={effective_limit}"
|
||||
)
|
||||
|
||||
return request.model_copy(update={"max_permutations": effective_limit})
|
||||
|
||||
|
||||
def _build_user_prompt(request: RoutePlanRequest) -> str:
|
||||
request_json = json.dumps(request.model_dump(mode="json"), ensure_ascii=False, indent=2)
|
||||
return (
|
||||
"请根据下面的 RoutePlanRequest JSON 执行多目标路线规划。\n"
|
||||
"必须显式使用高德 MCP 工具完成地址解析、逐段驾车路线计算和 deep link 生成。\n"
|
||||
"如果输入中存在固定起点,则完整路线必须从起点开始;如果起点模式是 current_location,"
|
||||
"则不得伪造起点坐标。\n"
|
||||
f"本次运行的候选顺序硬上限是 {request.max_permutations}。\n\n"
|
||||
"RoutePlanRequest JSON:\n"
|
||||
f"{request_json}"
|
||||
)
|
||||
|
||||
|
||||
def _same_candidate(left_candidate, right_candidate) -> bool:
|
||||
return (
|
||||
left_candidate.full_order_labels == right_candidate.full_order_labels
|
||||
and left_candidate.total_distance_m == right_candidate.total_distance_m
|
||||
and left_candidate.total_duration_s == right_candidate.total_duration_s
|
||||
and len(left_candidate.legs) == len(right_candidate.legs)
|
||||
)
|
||||
|
||||
|
||||
def _validate_result(request: RoutePlanRequest, result: RoutePlanResult) -> RoutePlanResult:
|
||||
if result.origin_mode != request.origin_mode:
|
||||
raise GuardrailError("Agent output origin_mode does not match the request")
|
||||
|
||||
if result.resolved_destination.role != "destination":
|
||||
raise GuardrailError("resolved_destination.role must be 'destination'")
|
||||
|
||||
if len(result.resolved_stops) != len(request.stops):
|
||||
raise GuardrailError("resolved_stops count does not match the request")
|
||||
|
||||
if any(stop.role != "stop" for stop in result.resolved_stops):
|
||||
raise GuardrailError("All resolved_stops entries must have role='stop'")
|
||||
|
||||
if request.origin_mode == "fixed":
|
||||
if result.resolved_origin is None:
|
||||
raise GuardrailError("resolved_origin is required when origin_mode='fixed'")
|
||||
if result.resolved_origin.role != "origin":
|
||||
raise GuardrailError("resolved_origin.role must be 'origin'")
|
||||
else:
|
||||
if result.resolved_origin is not None:
|
||||
raise GuardrailError("resolved_origin must be null when origin_mode='current_location'")
|
||||
|
||||
if result.success:
|
||||
if not result.candidates:
|
||||
raise GuardrailError("A successful result must include at least one candidate route")
|
||||
|
||||
if len(result.candidates) > (request.max_permutations or 0):
|
||||
raise GuardrailError("Agent returned more candidates than allowed by max_permutations")
|
||||
|
||||
matching_candidate = next(
|
||||
(
|
||||
candidate
|
||||
for candidate in result.candidates
|
||||
if _same_candidate(candidate, result.best_route)
|
||||
),
|
||||
None,
|
||||
)
|
||||
if matching_candidate is None:
|
||||
raise GuardrailError("best_route must be one of the candidates")
|
||||
|
||||
return result
|
||||
|
||||
def create_geo_agent() -> Agent[None, RoutePlanResult]:
|
||||
return Agent(
|
||||
model=_build_model(),
|
||||
toolsets=[_build_mcp_server()],
|
||||
output_type=RoutePlanResult,
|
||||
system_prompt=SYSTEM_PROMPT,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def run_route_plan(request: RoutePlanRequest) -> RoutePlanResult:
|
||||
"""Execute the route planning agent for a given request."""
|
||||
prepared_request = _prepare_request(request)
|
||||
agent = create_geo_agent()
|
||||
async with agent:
|
||||
result = await agent.run(_build_user_prompt(prepared_request))
|
||||
return _validate_result(prepared_request, result.output)
|
||||
Reference in New Issue
Block a user