"""Multi-stop route planning agent backed by AMap MCP tools.

The module resolves every request point to a concrete POI on the server
side, lets an LLM agent compare candidate stop orders with
``maps_direction_driving``, then generates the deep link and validates the
final structured result.
"""

import json
import math
import os
import re
from typing import Any

from dotenv import load_dotenv
from openai import AsyncOpenAI
from pydantic_ai import Agent
from pydantic_ai.exceptions import ModelRetry
from pydantic_ai.mcp import MCPServerHTTP, MCPServerSSE, MCPServerStreamableHTTP
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.openai import OpenAIProvider

from schemas import DeepLinks, ResolvedPoint, RoutePlanRequest, RoutePlanResult

load_dotenv()


class ConfigurationError(RuntimeError):
    """Raised when a required setting is missing or malformed."""


class GuardrailError(RuntimeError):
    """Raised when input, tool output or agent output fails a safety check."""


# Any of the MCP client transports this module can build.
_McpServer = MCPServerHTTP | MCPServerSSE | MCPServerStreamableHTTP


def _tool_args(**kwargs: Any) -> dict[str, Any]:
    """Return the keyword arguments whose value is not ``None``."""
    return {key: value for key, value in kwargs.items() if value is not None}


def _required_env(name: str) -> str:
    """Return the stripped value of environment variable *name*.

    Raises:
        ConfigurationError: if the variable is unset or blank.
    """
    value = os.getenv(name, "").strip()
    if not value:
        raise ConfigurationError(f"Missing required environment variable: {name}")
    return value


def _env_positive_int(name: str, default: int) -> int:
    """Read environment variable *name* as an integer greater than zero.

    *default* is used when the variable is unset.

    Raises:
        ConfigurationError: if the value is not an integer or is <= 0.
    """
    raw_value = os.getenv(name, str(default)).strip()
    try:
        parsed = int(raw_value)
    except ValueError as exc:
        raise ConfigurationError(f"Environment variable {name} must be an integer") from exc
    if parsed <= 0:
        raise ConfigurationError(f"Environment variable {name} must be greater than 0")
    return parsed


def _env_positive_float(name: str, default: float) -> float:
    """Read environment variable *name* as a finite number greater than zero.

    *default* is used when the variable is unset.

    Raises:
        ConfigurationError: if the value is not a number, is not finite, or
            is <= 0.
    """
    raw_value = os.getenv(name, str(default)).strip()
    try:
        parsed = float(raw_value)
    except ValueError as exc:
        raise ConfigurationError(f"Environment variable {name} must be a number") from exc
    # float() accepts "nan" and "inf"; neither fails the `<= 0` comparison
    # below, so they have to be rejected explicitly.
    if not math.isfinite(parsed):
        raise ConfigurationError(f"Environment variable {name} must be a finite number")
    if parsed <= 0:
        raise ConfigurationError(f"Environment variable {name} must be greater than 0")
    return parsed


def _build_model() -> OpenAIModel:
    """Build the chat model from the ``ARK_*`` environment variables."""
    openai_client = AsyncOpenAI(
        base_url=_required_env("ARK_BASE_URL"),
        api_key=_required_env("ARK_API_KEY"),
        timeout=_env_positive_float("ARK_REQUEST_TIMEOUT_SECONDS", 120.0),
    )
    provider = OpenAIProvider(
        openai_client=openai_client,
    )
    return OpenAIModel(
        _required_env("ARK_MODEL"),
        provider=provider,
    )


def _build_mcp_headers() -> dict[str, str] | None:
    """Return the optional MCP auth header, or ``None`` when not configured.

    Raises:
        ConfigurationError: if only one of the header name/value is set.
    """
    header_name = os.getenv("AMAP_MCP_AUTH_HEADER_NAME", "").strip()
    header_value = os.getenv("AMAP_MCP_AUTH_HEADER_VALUE", "").strip()
    if header_name and header_value:
        return {header_name: header_value}
    if header_name or header_value:
        raise ConfigurationError(
            "AMAP_MCP_AUTH_HEADER_NAME and AMAP_MCP_AUTH_HEADER_VALUE must be set together"
        )
    return None


def _build_mcp_server() -> _McpServer:
    """Build the MCP client selected by ``AMAP_MCP_TRANSPORT``.

    Raises:
        ConfigurationError: if a setting is missing/invalid or the transport
            name is not one of ``streamable_http``, ``http``, ``sse``.
    """
    transport = os.getenv("AMAP_MCP_TRANSPORT", "streamable_http").strip().lower()
    url = _required_env("AMAP_MCP_URL")
    headers = _build_mcp_headers()
    timeout = _env_positive_float("AMAP_MCP_TIMEOUT_SECONDS", 20.0)
    read_timeout = _env_positive_float("AMAP_MCP_READ_TIMEOUT_SECONDS", 60.0)

    if transport == "sse":
        return MCPServerSSE(url=url, headers=headers, timeout=timeout, read_timeout=read_timeout)
    if transport == "http":
        return MCPServerHTTP(url=url, headers=headers, timeout=timeout, read_timeout=read_timeout)
    if transport == "streamable_http":
        return MCPServerStreamableHTTP(
            url=url,
            headers=headers,
            timeout=timeout,
            read_timeout=read_timeout,
        )
    raise ConfigurationError(
        "AMAP_MCP_TRANSPORT must be one of: streamable_http, http, sse"
    )


# ---------------------------------------------------------------------------
# Agent
# ---------------------------------------------------------------------------

SYSTEM_PROMPT = """\
你是一个无状态的多目标地理路线规划 Agent。你的任务是使用高德地图 MCP 工具完成多目标点最优路线规划,并返回强类型结构化结果。\
你必须显式调用地图工具,不得编造坐标、POI、距离、时长或 deep link。

你的工作流必须严格遵守以下规则:
1. 先解析输入,明确起点模式、终点、途经点、优化策略和输出需求。
2. 对终点和每个途经点优先使用 maps_geo 解析地址;当命中不准或为空时,使用 maps_text_search 补齐 POI;必要时使用 maps_search_detail 校验。
3. 当起点模式为 fixed 时,对起点也做同样解析。
4. 当起点模式为 current_location 时,不得伪造起点坐标;如果缺少实时定位坐标,只能比较途经点内部顺序,并明确说明真实最优路线会受当前定位影响。
5. 这套工具没有单步多点最优路径工具,因此你必须自己生成候选途经点顺序。
6. 对每个候选顺序,逐段调用 maps_direction_driving 计算距离与时长,并汇总为候选路线。
7. 按 route_strategy 选择最优路线:
   - shortest_distance:总里程优先,总时长次优
   - fastest_time:总时长优先,总里程次优
   - balanced:优先选择明显不劣的 Pareto 优势路线;若出现里程更短但时间略长的情况,默认优先里程更短并说明原因
8. 如需 deep link:
   - 若目标是稳定导入整组点位,可使用 maps_schema_personal_map,但必须说明这是点位导入型链接
   - 若目标是"我的位置"出发的即时导航,应输出 route plan 类 deep link,并说明不同平台协议不同
9. 最终结果必须包含:resolved_origin(如适用)、resolved_destination、resolved_stops、candidates、best_route、deep_links、summary、warnings。
10. 如果任何地址无法可靠解析、任何 POI 无法获取、或任何路线比较存在信息缺失,必须如实说明,不得猜测。
11. 你的输入会以 RoutePlanRequest JSON 形式提供。你必须基于该 JSON 进行规划,不能擅自补造字段。
12. 在返回结果前,必须确认 best_route 来自 candidates 之一,且每一段 distance 和 duration 都来自工具调用。
13. 如果输入中已经提供“预解析点位 JSON”,则这些点位是服务端已验证的唯一可信事实来源。此时不要再调用 maps_geo、maps_text_search、maps_search_detail,也不要重新挑选 POI。
14. 当预解析点位已提供时,你只需要基于这些点位生成候选顺序、调用 maps_direction_driving 逐段计算、挑选 best_route,并产出 summary 与 warnings。
15. 当预解析点位已提供时,deep_links 字段保持为 null,由服务端在结果通过校验后统一生成。

你的底层返回必须是结构化数据,不是 HTML。只有在用户明确要求页面展示时,才在结构化结果基础上额外生成 HTML。
"""

# Name fragments that mark a POI as a sub-facility (car park, entrance,
# ward, ...). A candidate containing one that the query did not ask for is
# penalised in `_score_poi_candidate`.
_SUB_POI_TOKENS = {
    "停车场",
    "入口",
    "出口",
    "住院部",
    "门诊",
    "外科",
    "内科",
    "体检",
    "饭堂",
    "发热门诊",
    "楼",
    "中心",
}

# Whitespace and ASCII/full-width punctuation ignored when comparing names.
_NORMALIZE_STRIP_RE = re.compile(r"[\s\-_,,。.;;::/\\()()\[\]【】{}·]")


def _normalize_text(value: str | None) -> str:
    """Return *value* case-folded with whitespace and punctuation removed.

    ``None`` and the empty string both normalize to ``""``.
    """
    if not value:
        return ""
    return _NORMALIZE_STRIP_RE.sub("", value).casefold()


def _parse_location(location: str) -> tuple[float, float]:
    """Parse a ``"lon,lat"`` string into a ``(lon, lat)`` pair of floats.

    Raises:
        GuardrailError: if the text is not two comma-separated numbers, or
            the numbers are not finite coordinates within
            [-180, 180] x [-90, 90].
    """
    try:
        lon_text, lat_text = [item.strip() for item in location.split(",", 1)]
        lon, lat = float(lon_text), float(lat_text)
    except (AttributeError, TypeError, ValueError) as exc:
        raise GuardrailError(f"Invalid location value returned by maps service: {location}") from exc
    # float() also accepts "nan"/"inf" and arbitrary magnitudes; none of
    # those is a usable coordinate (NaN fails both range comparisons).
    if not (-180.0 <= lon <= 180.0 and -90.0 <= lat <= 90.0):
        raise GuardrailError(f"Invalid location value returned by maps service: {location}")
    return lon, lat


def _haversine_m(lon1: float, lat1: float, lon2: float, lat2: float) -> float:
    """Return the haversine distance in metres between two lon/lat points."""
    radius_m = 6371000
    phi1 = math.radians(lat1)
    phi2 = math.radians(lat2)
    d_phi = math.radians(lat2 - lat1)
    d_lambda = math.radians(lon2 - lon1)
    a = math.sin(d_phi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(d_lambda / 2) ** 2
    return 2 * radius_m * math.atan2(math.sqrt(a), math.sqrt(1 - a))


def _score_poi_candidate(
    poi: dict[str, Any],
    *,
    query_text: str,
    address_text: str,
    raw_query: str,
    index: int,
) -> int:
    """Score how well a text-search POI matches the requested point.

    Args:
        poi: one POI entry from the text-search response.
        query_text: normalized name (or address) that was searched for.
        address_text: normalized input address.
        raw_query: the un-normalized query, used for the sub-POI penalty.
        index: position of *poi* in the search result list.

    Returns:
        An integer score; higher is a better match.
    """
    # `or ""` (rather than a .get default) so an explicit null does not
    # become the literal text "None" and get matched against the query.
    display_name = str(poi.get("name") or "")
    name = _normalize_text(display_name)
    address = _normalize_text(str(poi.get("address") or ""))

    score = 0
    if query_text and name == query_text:
        score += 120
    elif query_text and query_text in name:
        score += 90
    elif query_text and name and name in query_text:
        score += 45

    if address_text and address == address_text:
        score += 80
    elif address_text and address_text in address:
        score += 45

    if poi.get("id"):
        score += 20

    # Penalise sub-facilities the caller did not explicitly ask for.
    if any(token in display_name and token not in raw_query for token in _SUB_POI_TOKENS):
        score -= 25

    # Mild preference for the search service's own ranking, capped at 10.
    score -= min(index, 10)
    return score


def _select_precise_poi(
    pois: list[dict[str, Any]],
    *,
    input_name: str | None,
    input_address: str,
) -> dict[str, Any]:
    """Pick the single POI that unambiguously matches the requested point.

    Raises:
        GuardrailError: if there are no candidates, the best score is below
            80, a different POI scores within 10 points of the best one, or
            the chosen POI has no id.
    """
    if not pois:
        raise GuardrailError(f"No POI candidates found for address: {input_address}")

    raw_query = input_name or input_address
    query_text = _normalize_text(raw_query)
    address_text = _normalize_text(input_address)

    ranked = [
        (
            _score_poi_candidate(
                poi,
                query_text=query_text,
                address_text=address_text,
                raw_query=raw_query,
                index=index,
            ),
            index,
            poi,
        )
        for index, poi in enumerate(pois)
    ]
    # Highest score first; on ties the earlier search result wins.
    ranked.sort(key=lambda item: (item[0], -item[1]), reverse=True)

    best_score, _, best_poi = ranked[0]
    if best_score < 80:
        raise GuardrailError(
            f"POI precision is insufficient for address: {input_address}; best score={best_score}"
        )

    if len(ranked) > 1:
        second_score, _, second_poi = ranked[1]
        if second_poi.get("id") != best_poi.get("id") and second_score >= best_score - 10:
            raise GuardrailError(
                f"POI match is ambiguous for address: {input_address}; multiple close candidates found"
            )

    poi_id = str(best_poi.get("id") or "").strip()
    if not poi_id:
        raise GuardrailError(f"POI ID is missing for address: {input_address}")
    return best_poi


async def _call_mcp_tool(server: _McpServer, name: str, args: dict[str, Any]) -> Any:
    """Call MCP tool *name* directly and return its raw result.

    Raises:
        TimeoutError: if the tool call reported a wait timeout.
        GuardrailError: for any other ``ModelRetry`` raised by the call.
    """
    try:
        return await server.direct_call_tool(name, args)
    except ModelRetry as exc:
        if "Timed out while waiting" in str(exc):
            raise TimeoutError(str(exc)) from exc
        raise GuardrailError(f"MCP tool call failed for {name}: {exc}") from exc


def _as_mapping(value: Any) -> dict[str, Any]:
    """Return *value* if it is a dict, otherwise an empty dict."""
    return value if isinstance(value, dict) else {}


def _dict_entries(value: Any) -> list[dict[str, Any]]:
    """Return the dict items of a list/tuple *value*; ``[]`` for anything else."""
    if not isinstance(value, (list, tuple)):
        return []
    return [item for item in value if isinstance(item, dict)]


def _closest_geo_result(
    geo_results: list[dict[str, Any]],
    *,
    lon: float,
    lat: float,
) -> dict[str, Any] | None:
    """Return the geocode entry nearest to ``(lon, lat)``.

    Entries without a ``location`` are skipped; ``None`` is returned when no
    entry has one.

    Raises:
        GuardrailError: if the nearest entry is more than 500 m away, or an
            entry's location cannot be parsed.
    """
    closest_result: dict[str, Any] | None = None
    closest_distance: float | None = None
    for item in geo_results:
        location = str(item.get("location") or "").strip()
        if not location:
            continue
        candidate_lon, candidate_lat = _parse_location(location)
        distance = _haversine_m(lon, lat, candidate_lon, candidate_lat)
        if closest_distance is None or distance < closest_distance:
            closest_distance = distance
            closest_result = item

    if closest_distance is not None and closest_distance > 500:
        raise GuardrailError("Geocode and POI detail locations disagree beyond acceptable precision")
    return closest_result


async def _resolve_point(
    server: _McpServer,
    *,
    role: str,
    input_name: str | None,
    input_address: str,
    city: str | None,
) -> ResolvedPoint:
    """Resolve one request point to a POI with confirmed id and coordinates.

    Calls ``maps_geo`` and ``maps_text_search``, selects a precise POI, then
    confirms it with ``maps_search_detail`` and cross-checks the detail
    location against the geocode results.

    Raises:
        GuardrailError: if no precise POI can be selected, the detail lacks
            a location/name/id, or geocode and detail locations disagree.
        TimeoutError: if a tool call timed out.
    """
    geo_response = await _call_mcp_tool(
        server,
        "maps_geo",
        _tool_args(address=input_address, city=city),
    )
    text_search_response = await _call_mcp_tool(
        server,
        "maps_text_search",
        _tool_args(keywords=input_name or input_address, city=city),
    )

    # Tool results are untyped: a non-JSON reply can come back as a plain
    # string. Coerce defensively so that bad shapes surface as GuardrailError
    # from the checks below instead of AttributeError.
    geo_results = _dict_entries(_as_mapping(geo_response).get("results"))
    pois = _dict_entries(_as_mapping(text_search_response).get("pois"))

    selected_poi = _select_precise_poi(pois, input_name=input_name, input_address=input_address)

    detail = _as_mapping(
        await _call_mcp_tool(
            server,
            "maps_search_detail",
            {"id": selected_poi["id"]},
        )
    )
    detail_location = str(detail.get("location") or "").strip()
    if not detail_location:
        raise GuardrailError(f"Resolved POI detail has no location for address: {input_address}")

    lon, lat = _parse_location(detail_location)
    closest_geo = _closest_geo_result(geo_results, lon=lon, lat=lat) if geo_results else None

    district = None
    if closest_geo is not None:
        district = str(closest_geo.get("district") or "").strip() or None

    resolved_name = str(detail.get("name") or selected_poi.get("name") or "").strip()
    poi_id = str(detail.get("id") or selected_poi.get("id") or "").strip()
    resolved_city = str(detail.get("city") or city or "").strip() or city

    if not resolved_name:
        raise GuardrailError(f"Resolved POI has no name for address: {input_address}")
    if not poi_id:
        raise GuardrailError(f"Resolved POI has no poi_id for address: {input_address}")

    return ResolvedPoint(
        role=role,  # type: ignore[arg-type]
        input_name=input_name,
        input_address=input_address,
        resolved_name=resolved_name,
        city=resolved_city,
        district=district,
        location=detail_location,
        lon=lon,
        lat=lat,
        poi_id=poi_id,
        source="search_detail",
        confidence_note="Prevalidated exact POI with confirmed poi_id for deep link generation",
    )


async def _resolve_request_points(
    request: RoutePlanRequest,
) -> tuple[ResolvedPoint | None, list[ResolvedPoint], ResolvedPoint]:
    """Resolve origin (fixed mode only), stops and destination of *request*.

    Identical ``(name, address, city)`` inputs are resolved once and reused.

    Returns:
        ``(resolved_origin, resolved_stops, resolved_destination)``;
        ``resolved_origin`` is ``None`` unless the origin mode is ``fixed``
        and an origin address was supplied.
    """
    cache: dict[tuple[str | None, str, str | None], ResolvedPoint] = {}
    server = _build_mcp_server()

    async with server:

        async def resolve_cached(
            *,
            role: str,
            input_name: str | None,
            input_address: str,
            city: str | None,
        ) -> ResolvedPoint:
            cache_key = (input_name, input_address, city)
            cached = cache.get(cache_key)
            if cached is None:
                cached = await _resolve_point(
                    server,
                    role=role,
                    input_name=input_name,
                    input_address=input_address,
                    city=city,
                )
                cache[cache_key] = cached
            # The cache key ignores the role, so stamp the caller's role on
            # a copy rather than handing out the cached instance.
            return cached.model_copy(
                update={
                    "role": role,
                    "input_name": input_name,
                    "input_address": input_address,
                }
            )

        resolved_origin: ResolvedPoint | None = None
        if request.origin_mode == "fixed" and request.origin_address is not None:
            resolved_origin = await resolve_cached(
                role="origin",
                input_name=request.origin_name,
                input_address=request.origin_address,
                city=request.origin_city,
            )

        resolved_stops = [
            await resolve_cached(
                role="stop",
                input_name=stop.name,
                input_address=stop.address,
                city=stop.city,
            )
            for stop in request.stops
        ]

        resolved_destination = await resolve_cached(
            role="destination",
            input_name=request.destination_name,
            input_address=request.destination_address,
            city=request.destination_city,
        )

    return resolved_origin, resolved_stops, resolved_destination


def _resolved_points_payload(
    *,
    resolved_origin: ResolvedPoint | None,
    resolved_stops: list[ResolvedPoint],
    resolved_destination: ResolvedPoint,
) -> dict[str, Any]:
    """Return the resolved points as a JSON-serializable dict."""
    return {
        "resolved_origin": resolved_origin.model_dump(mode="json") if resolved_origin else None,
        "resolved_stops": [point.model_dump(mode="json") for point in resolved_stops],
        "resolved_destination": resolved_destination.model_dump(mode="json"),
    }


def _point_label_candidates(point: ResolvedPoint) -> list[str]:
    """Return the distinct labels that may refer to *point*.

    Order is resolved name, input name (if any), input address.
    """
    labels = [point.resolved_name]
    if point.input_name:
        labels.append(point.input_name)
    labels.append(point.input_address)
    # dict.fromkeys de-duplicates while keeping first-seen order.
    return list(dict.fromkeys(labels))


def _ordered_points_for_best_route(
    request: RoutePlanRequest,
    result: RoutePlanResult,
    *,
    resolved_origin: ResolvedPoint | None,
    resolved_stops: list[ResolvedPoint],
    resolved_destination: ResolvedPoint,
) -> list[ResolvedPoint]:
    """Return the resolved points in the order chosen by ``best_route``.

    The list is: origin (fixed mode only), the stops in
    ``best_route.stop_order_labels`` order, then the destination.

    Raises:
        GuardrailError: if a label matches no remaining stop, or some
            resolved stop is not referenced by any label.
    """
    ordered_points: list[ResolvedPoint] = []
    if request.origin_mode == "fixed" and resolved_origin is not None:
        ordered_points.append(resolved_origin)

    # Normalize each stop's labels once instead of once per route label.
    remaining_stops = [
        (point, {_normalize_text(candidate) for candidate in _point_label_candidates(point)})
        for point in resolved_stops
    ]
    for label in result.best_route.stop_order_labels:
        normalized_label = _normalize_text(label)
        match_index = next(
            (
                index
                for index, (_, known_labels) in enumerate(remaining_stops)
                if normalized_label in known_labels
            ),
            None,
        )
        if match_index is None:
            raise GuardrailError(f"Unable to map best_route stop label back to resolved stop: {label}")
        # Pop so two stops sharing a label are each consumed exactly once.
        matched_point, _ = remaining_stops.pop(match_index)
        ordered_points.append(matched_point)

    # A route that skips a requested stop would yield an incomplete deep
    # link; refuse it instead of silently dropping the stop.
    if remaining_stops:
        missing = ", ".join(point.resolved_name for point, _ in remaining_stops)
        raise GuardrailError(f"best_route does not cover every resolved stop; missing: {missing}")

    ordered_points.append(resolved_destination)
    return ordered_points


async def _build_required_deep_links(
    request: RoutePlanRequest,
    result: RoutePlanResult,
    *,
    resolved_origin: ResolvedPoint | None,
    resolved_stops: list[ResolvedPoint],
    resolved_destination: ResolvedPoint,
) -> DeepLinks:
    """Generate the personal-map deep link for the chosen route.

    Raises:
        GuardrailError: if deep links were not requested, a point lacks a
            ``poi_id``, the route cannot be mapped back to resolved points,
            or the tool returns no usable URI.
    """
    if not request.need_deep_link:
        raise GuardrailError("need_deep_link must be true because this service requires deep link output")

    ordered_points = _ordered_points_for_best_route(
        request,
        result,
        resolved_origin=resolved_origin,
        resolved_stops=resolved_stops,
        resolved_destination=resolved_destination,
    )

    for point in ordered_points:
        if not point.poi_id:
            raise GuardrailError(
                f"Deep link generation requires poi_id for every point; missing poi_id for {point.input_address}"
            )

    server = _build_mcp_server()
    async with server:
        personal_map = await _call_mcp_tool(
            server,
            "maps_schema_personal_map",
            {
                "orgName": "geo-agent",
                "lineList": [
                    {
                        "title": request.task_name,
                        "pointInfoList": [
                            {
                                "name": point.resolved_name,
                                "lon": point.lon,
                                "lat": point.lat,
                                "poiId": point.poi_id,
                            }
                            for point in ordered_points
                        ],
                    }
                ],
            },
        )

    if not isinstance(personal_map, str) or not personal_map.strip():
        raise GuardrailError("Deep link generation failed: maps_schema_personal_map returned no usable URI")

    return DeepLinks(personal_map=personal_map.strip())


def _configured_max_permutations() -> int:
    """Return the service-wide cap from ``ROUTE_MAX_PERMUTATIONS`` (default 20)."""
    return _env_positive_int("ROUTE_MAX_PERMUTATIONS", 20)


def _candidate_count(stop_count: int) -> int:
    """Return the number of possible orderings of *stop_count* stops."""
    return math.factorial(stop_count)


def _prepare_request(request: RoutePlanRequest) -> RoutePlanRequest:
    """Check *request* against service limits and pin ``max_permutations``.

    Returns:
        A copy of *request* whose ``max_permutations`` is the effective
        limit (the requested value, or the configured one when unset).

    Raises:
        GuardrailError: if deep links are not requested, the requested limit
            exceeds the configured one, or the stop count yields more
            orderings than the effective limit.
    """
    if not request.need_deep_link:
        raise GuardrailError("need_deep_link must be true because this service requires deep link output")

    configured_limit = _configured_max_permutations()
    effective_limit = request.max_permutations or configured_limit
    if effective_limit > configured_limit:
        raise GuardrailError(
            "Requested max_permutations exceeds the configured service limit: "
            f"requested={effective_limit}, configured={configured_limit}"
        )

    candidate_count = _candidate_count(len(request.stops))
    if candidate_count > effective_limit:
        raise GuardrailError(
            "Candidate permutations exceed the configured limit: "
            f"stops={len(request.stops)}, permutations={candidate_count}, limit={effective_limit}"
        )

    return request.model_copy(update={"max_permutations": effective_limit})


def _build_user_prompt(
    request: RoutePlanRequest,
    *,
    resolved_origin: ResolvedPoint | None,
    resolved_stops: list[ResolvedPoint],
    resolved_destination: ResolvedPoint,
) -> str:
    """Render the user prompt: instructions, request JSON, resolved points JSON."""
    request_json = json.dumps(request.model_dump(mode="json"), ensure_ascii=False, indent=2)
    resolved_points_json = json.dumps(
        _resolved_points_payload(
            resolved_origin=resolved_origin,
            resolved_stops=resolved_stops,
            resolved_destination=resolved_destination,
        ),
        ensure_ascii=False,
        indent=2,
    )
    return (
        "请根据下面的 RoutePlanRequest JSON 执行多目标路线规划。\n"
        "所有点位已经由服务端完成严格解析和 POI 校验,且这些点位是唯一可信输入。\n"
        "不要再调用 maps_geo、maps_text_search、maps_search_detail,也不要生成 deep link。\n"
        "你必须基于这些已解析点位,仅使用 maps_direction_driving 计算逐段路线,并选择 best_route。\n"
        "如果输入中存在固定起点,则完整路线必须从起点开始;如果起点模式是 current_location,则不得伪造起点坐标。\n"
        f"本次运行的候选顺序硬上限是 {request.max_permutations}。\n\n"
        "RoutePlanRequest JSON:\n"
        f"{request_json}\n\n"
        "Pre-resolved Points JSON:\n"
        f"{resolved_points_json}"
    )


def _same_candidate(left_candidate: Any, right_candidate: Any) -> bool:
    """Return whether two candidate routes agree on order, totals and leg count."""
    return (
        left_candidate.full_order_labels == right_candidate.full_order_labels
        and left_candidate.total_distance_m == right_candidate.total_distance_m
        and left_candidate.total_duration_s == right_candidate.total_duration_s
        and len(left_candidate.legs) == len(right_candidate.legs)
    )


def _validate_result(request: RoutePlanRequest, result: RoutePlanResult) -> RoutePlanResult:
    """Check the final result against *request* and return it unchanged.

    Raises:
        GuardrailError: on the first failed check (success flag, origin
            mode, point roles and counts, candidate limits, best_route
            membership, deep links, poi_id presence).
    """
    if not result.success:
        raise GuardrailError("Route planning did not complete successfully")

    if result.origin_mode != request.origin_mode:
        raise GuardrailError("Agent output origin_mode does not match the request")

    if result.resolved_destination.role != "destination":
        raise GuardrailError("resolved_destination.role must be 'destination'")

    if len(result.resolved_stops) != len(request.stops):
        raise GuardrailError("resolved_stops count does not match the request")

    if any(stop.role != "stop" for stop in result.resolved_stops):
        raise GuardrailError("All resolved_stops entries must have role='stop'")

    if request.origin_mode == "fixed":
        if result.resolved_origin is None:
            raise GuardrailError("resolved_origin is required when origin_mode='fixed'")
        if result.resolved_origin.role != "origin":
            raise GuardrailError("resolved_origin.role must be 'origin'")
    else:
        if result.resolved_origin is not None:
            raise GuardrailError("resolved_origin must be null when origin_mode='current_location'")

    if not result.candidates:
        raise GuardrailError("A successful result must include at least one candidate route")

    if len(result.candidates) > (request.max_permutations or 0):
        raise GuardrailError("Agent returned more candidates than allowed by max_permutations")

    matching_candidate = next(
        (
            candidate
            for candidate in result.candidates
            if _same_candidate(candidate, result.best_route)
        ),
        None,
    )
    if matching_candidate is None:
        raise GuardrailError("best_route must be one of the candidates")

    if result.deep_links is None:
        raise GuardrailError("A successful result must include deep_links")

    if not any(
        (
            result.deep_links.personal_map,
            result.deep_links.android_route_plan,
            result.deep_links.ios_route_plan,
        )
    ):
        raise GuardrailError("A successful result must include at least one deep link")

    all_points = [*result.resolved_stops, result.resolved_destination]
    if result.resolved_origin is not None:
        all_points.append(result.resolved_origin)
    if any(not point.poi_id for point in all_points):
        raise GuardrailError("All resolved points must include poi_id on successful results")

    return result


def create_geo_agent() -> Agent[None, RoutePlanResult]:
    """Create the route planning agent wired to the model and MCP toolset."""
    return Agent(
        model=_build_model(),
        toolsets=[_build_mcp_server()],
        output_type=RoutePlanResult,
        system_prompt=SYSTEM_PROMPT,
    )


# ---------------------------------------------------------------------------
# Public entry point
# ---------------------------------------------------------------------------


async def run_route_plan(request: RoutePlanRequest) -> RoutePlanResult:
    """Execute the route planning agent for a given request.

    Steps: check limits, resolve all points server-side, run the agent on
    the pre-resolved points, overwrite the agent's resolved points with the
    server-side ones, generate the deep link, then validate the result.

    Raises:
        ConfigurationError: if required settings are missing or invalid.
        GuardrailError: if any guardrail or validation check fails.
        TimeoutError: if a direct MCP tool call timed out.
    """
    prepared_request = _prepare_request(request)
    resolved_origin, resolved_stops, resolved_destination = await _resolve_request_points(prepared_request)

    agent = create_geo_agent()
    async with agent:
        result = await agent.run(
            _build_user_prompt(
                prepared_request,
                resolved_origin=resolved_origin,
                resolved_stops=resolved_stops,
                resolved_destination=resolved_destination,
            )
        )

    # Server-side resolution is the source of truth; discard whatever the
    # agent echoed back for these fields.
    output = result.output.model_copy(
        update={
            "resolved_origin": resolved_origin,
            "resolved_stops": resolved_stops,
            "resolved_destination": resolved_destination,
        }
    )
    output = output.model_copy(
        update={
            "deep_links": await _build_required_deep_links(
                prepared_request,
                output,
                resolved_origin=resolved_origin,
                resolved_stops=resolved_stops,
                resolved_destination=resolved_destination,
            )
        }
    )
    return _validate_result(prepared_request, output)