Files
geo-setp/agent.py
2026-04-13 16:29:27 +08:00

250 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import json
import math
import os
from dotenv import load_dotenv
from openai import AsyncOpenAI
from pydantic_ai import Agent
from pydantic_ai.mcp import MCPServerHTTP, MCPServerSSE, MCPServerStreamableHTTP
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.openai import OpenAIProvider
from schemas import RoutePlanRequest, RoutePlanResult
load_dotenv()
class ConfigurationError(RuntimeError):
    """Raised when an environment-based setting is missing or has an invalid value."""
class GuardrailError(RuntimeError):
    """Raised when a request or an agent result fails one of the module's guardrail checks."""
def _required_env(name: str) -> str:
    """Return the whitespace-stripped value of environment variable *name*.

    Raises:
        ConfigurationError: If the variable is unset, empty, or only whitespace.
    """
    stripped = os.getenv(name, "").strip()
    if stripped:
        return stripped
    raise ConfigurationError(f"Missing required environment variable: {name}")
def _env_positive_int(name: str, default: int) -> int:
    """Read environment variable *name* as an integer greater than zero.

    When the variable is unset, *default* is used; the default goes through
    the same parsing and range check as an explicitly configured value.

    Raises:
        ConfigurationError: If the text is not an integer, or the integer is
            zero or negative.
    """
    text = os.getenv(name, str(default)).strip()
    try:
        number = int(text)
    except ValueError as exc:
        raise ConfigurationError(f"Environment variable {name} must be an integer") from exc
    if number > 0:
        return number
    raise ConfigurationError(f"Environment variable {name} must be greater than 0")
def _env_positive_float(name: str, default: float) -> float:
    """Read environment variable *name* as a finite float greater than zero.

    When the variable is unset, *default* is used; the default goes through
    the same parsing and range checks as an explicitly configured value.

    Raises:
        ConfigurationError: If the text is not a number, is NaN or infinite,
            or is zero or negative.
    """
    raw_value = os.getenv(name, str(default)).strip()
    try:
        parsed = float(raw_value)
    except ValueError as exc:
        raise ConfigurationError(f"Environment variable {name} must be a number") from exc
    # float() happily parses "nan", "inf" and "infinity". NaN compares False
    # against everything, so it would slip past the `<= 0` check below, and
    # +inf is "positive" but meaningless as a timeout. Reject both up front.
    if not math.isfinite(parsed):
        raise ConfigurationError(f"Environment variable {name} must be a finite number")
    if parsed <= 0:
        raise ConfigurationError(f"Environment variable {name} must be greater than 0")
    return parsed
def _build_model() -> OpenAIModel:
    """Build the chat model from the ARK_* environment variables.

    ARK_BASE_URL, ARK_API_KEY and ARK_MODEL are required;
    ARK_REQUEST_TIMEOUT_SECONDS is optional and defaults to 120.0.

    Raises:
        ConfigurationError: Propagated from the environment helpers when a
            required variable is missing or the timeout is invalid.
    """
    # Settings are read in the same order as before so that the first
    # configuration error reported stays the same.
    base_url = _required_env("ARK_BASE_URL")
    api_key = _required_env("ARK_API_KEY")
    request_timeout = _env_positive_float("ARK_REQUEST_TIMEOUT_SECONDS", 120.0)
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, timeout=request_timeout)
    provider = OpenAIProvider(openai_client=client)
    model_name = _required_env("ARK_MODEL")
    return OpenAIModel(model_name, provider=provider)
def _build_mcp_headers() -> dict[str, str] | None:
header_name = os.getenv("AMAP_MCP_AUTH_HEADER_NAME", "").strip()
header_value = os.getenv("AMAP_MCP_AUTH_HEADER_VALUE", "").strip()
if header_name and header_value:
return {header_name: header_value}
if header_name or header_value:
raise ConfigurationError(
"AMAP_MCP_AUTH_HEADER_NAME and AMAP_MCP_AUTH_HEADER_VALUE must be set together"
)
return None
def _build_mcp_server() -> MCPServerHTTP | MCPServerSSE | MCPServerStreamableHTTP:
    """Create the MCP server client selected by AMAP_MCP_TRANSPORT.

    Environment variables read:
        AMAP_MCP_TRANSPORT: ``streamable_http`` (default), ``http`` or ``sse``;
            compared after stripping and lower-casing.
        AMAP_MCP_URL: required server URL.
        AMAP_MCP_AUTH_HEADER_NAME / AMAP_MCP_AUTH_HEADER_VALUE: optional header pair.
        AMAP_MCP_TIMEOUT_SECONDS: defaults to 20.0.
        AMAP_MCP_READ_TIMEOUT_SECONDS: defaults to 60.0.

    Raises:
        ConfigurationError: If a setting is missing or invalid, or the
            transport name is not one of the three supported values.
    """
    transport = os.getenv("AMAP_MCP_TRANSPORT", "streamable_http").strip().lower()
    # All connection settings are resolved before the transport is checked,
    # so a missing URL is reported ahead of an unknown transport name.
    url = _required_env("AMAP_MCP_URL")
    headers = _build_mcp_headers()
    timeout = _env_positive_float("AMAP_MCP_TIMEOUT_SECONDS", 20.0)
    read_timeout = _env_positive_float("AMAP_MCP_READ_TIMEOUT_SECONDS", 60.0)

    server_classes = {
        "sse": MCPServerSSE,
        "http": MCPServerHTTP,
        "streamable_http": MCPServerStreamableHTTP,
    }
    server_cls = server_classes.get(transport)
    if server_cls is None:
        raise ConfigurationError(
            "AMAP_MCP_TRANSPORT must be one of: streamable_http, http, sse"
        )
    return server_cls(url=url, headers=headers, timeout=timeout, read_timeout=read_timeout)
# ---------------------------------------------------------------------------
# Agent
# ---------------------------------------------------------------------------
# System prompt handed to the Agent in create_geo_agent(). The text is in
# Chinese and is runtime data: it lists the numbered workflow rules the model
# is told to follow (resolve places with the AMap MCP tools, generate candidate
# stop orders itself, compute each leg with maps_direction_driving, pick the
# best route per route_strategy, and return structured output).
# The backslash right after the opening quotes and the one ending the first
# sentence are line continuations inside the literal, so no newline is emitted
# at those two points.
SYSTEM_PROMPT = """\
你是一个无状态的多目标地理路线规划 Agent。你的任务是使用高德地图 MCP 工具完成多目标点最优路线规划,并返回强类型结构化结果。\
你必须显式调用地图工具不得编造坐标、POI、距离、时长或 deep link。
你的工作流必须严格遵守以下规则:
1. 先解析输入,明确起点模式、终点、途经点、优化策略和输出需求。
2. 对终点和每个途经点优先使用 maps_geo 解析地址;当命中不准或为空时,使用 maps_text_search 补齐 POI必要时使用 maps_search_detail 校验。
3. 当起点模式为 fixed 时,对起点也做同样解析。
4. 当起点模式为 current_location 时,不得伪造起点坐标;如果缺少实时定位坐标,只能比较途经点内部顺序,并明确说明真实最优路线会受当前定位影响。
5. 这套工具没有单步多点最优路径工具,因此你必须自己生成候选途经点顺序。
6. 对每个候选顺序,逐段调用 maps_direction_driving 计算距离与时长,并汇总为候选路线。
7. 按 route_strategy 选择最优路线:
- shortest_distance总里程优先总时长次优
- fastest_time总时长优先总里程次优
- balanced优先选择明显不劣的 Pareto 优势路线;若出现里程更短但时间略长的情况,默认优先里程更短并说明原因
8. 如需 deep link
- 若目标是稳定导入整组点位,可使用 maps_schema_personal_map但必须说明这是点位导入型链接
- 若目标是"我的位置"出发的即时导航,应输出 route plan 类 deep link并说明不同平台协议不同
9. 最终结果必须包含resolved_origin如适用、resolved_destination、resolved_stops、candidates、best_route、deep_links、summary、warnings。
10. 如果任何地址无法可靠解析、任何 POI 无法获取、或任何路线比较存在信息缺失,必须如实说明,不得猜测。
11. 你的输入会以 RoutePlanRequest JSON 形式提供。你必须基于该 JSON 进行规划,不能擅自补造字段。
12. 在返回结果前,必须确认 best_route 来自 candidates 之一,且每一段 distance 和 duration 都来自工具调用。
你的底层返回必须是结构化数据,不是 HTML。只有在用户明确要求页面展示时才在结构化结果基础上额外生成 HTML。
"""
def _configured_max_permutations() -> int:
    """Return the service-wide cap on candidate orderings.

    The cap is read from ROUTE_MAX_PERMUTATIONS and defaults to 20; it is
    re-read from the environment on every call.

    Raises:
        ConfigurationError: If the variable is not a positive integer.
    """
    service_cap = _env_positive_int("ROUTE_MAX_PERMUTATIONS", 20)
    return service_cap
def _candidate_count(stop_count: int) -> int:
    """Return the number of distinct visiting orders for *stop_count* stops.

    This is ``stop_count!``; zero stops yields a single (empty) ordering.
    """
    orderings = math.factorial(stop_count)
    return orderings
def _prepare_request(request: RoutePlanRequest) -> RoutePlanRequest:
    """Apply the permutation guardrails and pin the limit on the request.

    The effective limit is the request's own ``max_permutations`` when it is
    truthy, otherwise the configured service limit.

    Returns:
        A copy of *request* whose ``max_permutations`` is the effective limit.

    Raises:
        GuardrailError: If the requested limit is above the service limit, or
            the number of stop orderings (factorial of the stop count) is
            above the effective limit.
        ConfigurationError: Propagated when ROUTE_MAX_PERMUTATIONS is invalid.
    """
    service_limit = _configured_max_permutations()
    limit = request.max_permutations or service_limit
    if limit > service_limit:
        raise GuardrailError(
            "Requested max_permutations exceeds the configured service limit: "
            f"requested={limit}, configured={service_limit}"
        )

    stop_total = len(request.stops)
    permutations = _candidate_count(stop_total)
    if permutations > limit:
        raise GuardrailError(
            "Candidate permutations exceed the configured limit: "
            f"stops={stop_total}, permutations={permutations}, limit={limit}"
        )

    return request.model_copy(update={"max_permutations": limit})
def _build_user_prompt(request: RoutePlanRequest) -> str:
    """Render the user message sent to the agent for *request*.

    The message is a fixed Chinese instruction preamble, a line stating the
    request's ``max_permutations``, and the request itself serialized as
    indented JSON with non-ASCII characters kept as-is.
    """
    payload = request.model_dump(mode="json")
    request_json = json.dumps(payload, ensure_ascii=False, indent=2)
    pieces = [
        "请根据下面的 RoutePlanRequest JSON 执行多目标路线规划。\n",
        "必须显式使用高德 MCP 工具完成地址解析、逐段驾车路线计算和 deep link 生成。\n",
        "如果输入中存在固定起点,则完整路线必须从起点开始;如果起点模式是 current_location",
        "则不得伪造起点坐标。\n",
        f"本次运行的候选顺序硬上限是 {request.max_permutations}\n\n",
        "RoutePlanRequest JSON:\n",
        request_json,
    ]
    return "".join(pieces)
def _same_candidate(left_candidate, right_candidate) -> bool:
    """Tell whether two candidate routes describe the same route.

    Two candidates match when their ordered labels, total distance and total
    duration are equal and they have the same number of legs. Leg contents
    are not compared. Checks run in that order and stop at the first
    mismatch.
    """
    if left_candidate.full_order_labels != right_candidate.full_order_labels:
        return False
    if left_candidate.total_distance_m != right_candidate.total_distance_m:
        return False
    if left_candidate.total_duration_s != right_candidate.total_duration_s:
        return False
    return len(left_candidate.legs) == len(right_candidate.legs)
def _validate_result(request: RoutePlanRequest, result: RoutePlanResult) -> RoutePlanResult:
    """Check the agent's structured output against the request it answered.

    Checks performed, in order:
      * ``origin_mode`` matches the request;
      * the destination has role ``destination`` and every resolved stop has
        role ``stop``, with one resolved stop per requested stop;
      * ``resolved_origin`` is present with role ``origin`` for a fixed
        origin, and absent otherwise;
      * for a successful result: at least one candidate, no more candidates
        than ``request.max_permutations``, and a ``best_route`` that matches
        one of the candidates (see ``_same_candidate``).

    Args:
        request: The prepared request that was sent to the agent.
        result: The agent's structured output.

    Returns:
        *result* unchanged when every check passes.

    Raises:
        GuardrailError: On the first check that fails.
    """
    if result.origin_mode != request.origin_mode:
        raise GuardrailError("Agent output origin_mode does not match the request")
    if result.resolved_destination.role != "destination":
        raise GuardrailError("resolved_destination.role must be 'destination'")
    if len(result.resolved_stops) != len(request.stops):
        raise GuardrailError("resolved_stops count does not match the request")
    if any(stop.role != "stop" for stop in result.resolved_stops):
        raise GuardrailError("All resolved_stops entries must have role='stop'")

    if request.origin_mode == "fixed":
        if result.resolved_origin is None:
            raise GuardrailError("resolved_origin is required when origin_mode='fixed'")
        if result.resolved_origin.role != "origin":
            raise GuardrailError("resolved_origin.role must be 'origin'")
    elif result.resolved_origin is not None:
        raise GuardrailError("resolved_origin must be null when origin_mode='current_location'")

    if result.success:
        if not result.candidates:
            raise GuardrailError("A successful result must include at least one candidate route")
        if len(result.candidates) > (request.max_permutations or 0):
            raise GuardrailError("Agent returned more candidates than allowed by max_permutations")
        best_route = result.best_route
        # Without this guard a missing best_route would surface as an
        # AttributeError from _same_candidate instead of a GuardrailError.
        if best_route is None:
            raise GuardrailError("A successful result must include best_route")
        if not any(_same_candidate(candidate, best_route) for candidate in result.candidates):
            raise GuardrailError("best_route must be one of the candidates")
    return result
def create_geo_agent() -> Agent[None, RoutePlanResult]:
    """Assemble the route-planning agent from environment configuration.

    The agent uses the model from ``_build_model``, the MCP server from
    ``_build_mcp_server`` as its only toolset, ``RoutePlanResult`` as its
    output type and ``SYSTEM_PROMPT`` as its system prompt.

    Raises:
        ConfigurationError: Propagated from the two builders.
    """
    # The model is built before the MCP server, so model configuration
    # errors are reported first.
    model = _build_model()
    mcp_server = _build_mcp_server()
    return Agent(
        model=model,
        toolsets=[mcp_server],
        output_type=RoutePlanResult,
        system_prompt=SYSTEM_PROMPT,
    )
# ---------------------------------------------------------------------------
# Public entry point
# ---------------------------------------------------------------------------
async def run_route_plan(request: RoutePlanRequest) -> RoutePlanResult:
    """Plan a route for *request* and return the validated agent output.

    The request is first passed through ``_prepare_request``; a new agent is
    then created for this call, run once inside its async context with the
    rendered user prompt, and its output is checked by ``_validate_result``.

    Raises:
        GuardrailError: If the request or the agent output fails a guardrail.
        ConfigurationError: If required environment configuration is invalid.
    """
    checked_request = _prepare_request(request)
    geo_agent = create_geo_agent()
    async with geo_agent:
        user_prompt = _build_user_prompt(checked_request)
        run = await geo_agent.run(user_prompt)
    return _validate_result(checked_request, run.output)