Files
geo-setp/agent/route_plan.py
2026-04-15 14:05:33 +08:00

732 lines
26 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import json
import math
import os
import re
from typing import Any
from dotenv import load_dotenv
from openai import AsyncOpenAI
from pydantic_ai import Agent
from pydantic_ai.exceptions import ModelRetry
from pydantic_ai.mcp import MCPServerHTTP, MCPServerSSE, MCPServerStreamableHTTP
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.openai import OpenAIProvider
from schemas import DeepLinks, ResolvedPoint, RoutePlanRequest, RoutePlanResult
load_dotenv()
class ConfigurationError(RuntimeError):
    """Signals that an environment-provided setting is missing or malformed."""
class GuardrailError(RuntimeError):
    """Signals that a request, a maps-tool response, or an agent result failed a service-side check."""
def _tool_args(**kwargs: Any) -> dict[str, Any]:
    """Collect the keyword arguments into a dict, leaving out every entry set to ``None``.

    Other falsy values (``0``, ``""``, ``False``) are kept.
    """
    present: dict[str, Any] = {}
    for arg_name, arg_value in kwargs.items():
        if arg_value is None:
            continue
        present[arg_name] = arg_value
    return present
def _required_env(name: str) -> str:
    """Return the whitespace-stripped value of environment variable *name*.

    Raises:
        ConfigurationError: If the variable is unset or contains only whitespace.
    """
    cleaned = os.environ.get(name, "").strip()
    if cleaned:
        return cleaned
    raise ConfigurationError(f"Missing required environment variable: {name}")
def _env_positive_int(name: str, default: int) -> int:
    """Read environment variable *name* as a strictly positive integer.

    When the variable is unset, *default* is used and is validated the same way.

    Raises:
        ConfigurationError: If the text is not an integer or the number is not above zero.
    """
    text = os.environ.get(name, str(default)).strip()
    try:
        number = int(text)
    except ValueError as exc:
        raise ConfigurationError(f"Environment variable {name} must be an integer") from exc
    if number > 0:
        return number
    raise ConfigurationError(f"Environment variable {name} must be greater than 0")
def _env_positive_float(name: str, default: float) -> float:
    """Read environment variable *name* as a strictly positive float.

    When the variable is unset, *default* is used and is validated the same way.

    Raises:
        ConfigurationError: If the text is not a number (including ``nan``) or the
            number is not above zero.
    """
    raw_value = os.getenv(name, str(default)).strip()
    try:
        parsed = float(raw_value)
    except ValueError as exc:
        raise ConfigurationError(f"Environment variable {name} must be a number") from exc
    # float("nan") parses successfully and "nan <= 0" is False, so without this
    # check NaN would slip through and be handed on as a timeout value.
    if math.isnan(parsed):
        raise ConfigurationError(f"Environment variable {name} must be a number")
    if parsed <= 0:
        raise ConfigurationError(f"Environment variable {name} must be greater than 0")
    return parsed
def _build_model() -> OpenAIModel:
    """Build the OpenAI-compatible chat model from the ``ARK_*`` environment variables.

    Reads ARK_BASE_URL, ARK_API_KEY, ARK_REQUEST_TIMEOUT_SECONDS (default 120.0)
    and ARK_MODEL, in that order.

    Raises:
        ConfigurationError: If a required variable is missing or the timeout is invalid.
    """
    base_url = _required_env("ARK_BASE_URL")
    api_key = _required_env("ARK_API_KEY")
    request_timeout = _env_positive_float("ARK_REQUEST_TIMEOUT_SECONDS", 120.0)
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, timeout=request_timeout)
    ark_provider = OpenAIProvider(openai_client=client)
    model_name = _required_env("ARK_MODEL")
    return OpenAIModel(model_name, provider=ark_provider)
def _build_mcp_headers() -> dict[str, str] | None:
header_name = os.getenv("AMAP_MCP_AUTH_HEADER_NAME", "").strip()
header_value = os.getenv("AMAP_MCP_AUTH_HEADER_VALUE", "").strip()
if header_name and header_value:
return {header_name: header_value}
if header_name or header_value:
raise ConfigurationError(
"AMAP_MCP_AUTH_HEADER_NAME and AMAP_MCP_AUTH_HEADER_VALUE must be set together"
)
return None
def _build_mcp_server() -> MCPServerHTTP | MCPServerSSE | MCPServerStreamableHTTP:
    """Create the MCP server client selected by AMAP_MCP_TRANSPORT.

    The transport defaults to ``streamable_http``. URL, headers and both timeouts
    are read before the transport is checked, so their configuration errors
    surface first.

    Raises:
        ConfigurationError: If a setting is missing/invalid or the transport is unknown.
    """
    transport = os.getenv("AMAP_MCP_TRANSPORT", "streamable_http").strip().lower()
    connection = {
        "url": _required_env("AMAP_MCP_URL"),
        "headers": _build_mcp_headers(),
        "timeout": _env_positive_float("AMAP_MCP_TIMEOUT_SECONDS", 20.0),
        "read_timeout": _env_positive_float("AMAP_MCP_READ_TIMEOUT_SECONDS", 60.0),
    }
    server_types = {
        "sse": MCPServerSSE,
        "http": MCPServerHTTP,
        "streamable_http": MCPServerStreamableHTTP,
    }
    server_type = server_types.get(transport)
    if server_type is None:
        raise ConfigurationError(
            "AMAP_MCP_TRANSPORT must be one of: streamable_http, http, sse"
        )
    return server_type(**connection)
# ---------------------------------------------------------------------------
# Agent
# ---------------------------------------------------------------------------
# System prompt (written in Chinese) passed to the agent in create_geo_agent().
# It tells the model to use the AMap MCP tools and never invent coordinates,
# POIs, distances, durations or deep links; to enumerate candidate stop orders
# itself and price each leg with maps_direction_driving; to choose the winner by
# route_strategy; and, when pre-resolved points are supplied, to skip the
# geocoding/search tools and leave deep_links null for the server to fill in.
# NOTE(review): several sentences look like they are missing punctuation
# (e.g. after "shortest_distance" and "deep link") — possibly dropped when the
# file was rendered; confirm against the repository copy before editing the text.
SYSTEM_PROMPT = """\
你是一个无状态的多目标地理路线规划 Agent。你的任务是使用高德地图 MCP 工具完成多目标点最优路线规划,并返回强类型结构化结果。\
你必须显式调用地图工具不得编造坐标、POI、距离、时长或 deep link。
你的工作流必须严格遵守以下规则:
1. 先解析输入,明确起点模式、终点、途经点、优化策略和输出需求。
2. 对终点和每个途经点优先使用 maps_geo 解析地址;当命中不准或为空时,使用 maps_text_search 补齐 POI必要时使用 maps_search_detail 校验。
3. 当起点模式为 fixed 时,对起点也做同样解析。
4. 当起点模式为 current_location 时,不得伪造起点坐标;如果缺少实时定位坐标,只能比较途经点内部顺序,并明确说明真实最优路线会受当前定位影响。
5. 这套工具没有单步多点最优路径工具,因此你必须自己生成候选途经点顺序。
6. 对每个候选顺序,逐段调用 maps_direction_driving 计算距离与时长,并汇总为候选路线。
7. 按 route_strategy 选择最优路线:
- shortest_distance总里程优先总时长次优
- fastest_time总时长优先总里程次优
- balanced优先选择明显不劣的 Pareto 优势路线;若出现里程更短但时间略长的情况,默认优先里程更短并说明原因
8. 如需 deep link
- 若目标是稳定导入整组点位,可使用 maps_schema_personal_map但必须说明这是点位导入型链接
- 若目标是"我的位置"出发的即时导航,应输出 route plan 类 deep link并说明不同平台协议不同
9. 最终结果必须包含resolved_origin如适用、resolved_destination、resolved_stops、candidates、best_route、deep_links、summary、warnings。
10. 如果任何地址无法可靠解析、任何 POI 无法获取、或任何路线比较存在信息缺失,必须如实说明,不得猜测。
11. 你的输入会以 RoutePlanRequest JSON 形式提供。你必须基于该 JSON 进行规划,不能擅自补造字段。
12. 在返回结果前,必须确认 best_route 来自 candidates 之一,且每一段 distance 和 duration 都来自工具调用。
13. 如果输入中已经提供“预解析点位 JSON”则这些点位是服务端已验证的唯一可信事实来源。此时不要再调用 maps_geo、maps_text_search、maps_search_detail也不要重新挑选 POI。
14. 当预解析点位已提供时,你只需要基于这些点位生成候选顺序、调用 maps_direction_driving 逐段计算、挑选 best_route并产出 summary 与 warnings。
15. 当预解析点位已提供时deep_links 字段保持为 null由服务端在结果通过校验后统一生成。
你的底层返回必须是结构化数据,不是 HTML。只有在用户明确要求页面展示时才在结构化结果基础上额外生成 HTML。
"""
# Name fragments that mark a POI as a sub-facility of a larger place (car park,
# entrance, exit, inpatient/outpatient departments, ...). _score_poi_candidate
# subtracts points when a candidate's name contains one of these tokens and the
# user's raw query does not.
_SUB_POI_TOKENS = {
    "停车场",
    "入口",
    "出口",
    "住院部",
    "门诊",
    "外科",
    "内科",
    "体检",
    "饭堂",
    "发热门诊",
    # NOTE(review): this empty token can never trigger the penalty, because
    # `"" not in raw_query` is always False. It looks like a character that was
    # lost in rendering (perhaps a full-width bracket) — confirm the intended token.
    "",
    "中心",
}
def _normalize_text(value: str | None) -> str:
if not value:
return ""
return re.sub(r"[\s\-_,,。.;:/\\()\[\]【】{}·]", "", value).casefold()
def _parse_location(location: str) -> tuple[float, float]:
    """Parse a ``"lon,lat"`` string from the maps service into a float pair.

    Returns:
        ``(lon, lat)`` in the order they appear in the string.

    Raises:
        GuardrailError: If the value is not two comma-separated finite numbers.
    """
    try:
        lon_text, lat_text = (item.strip() for item in location.split(",", 1))
        lon, lat = float(lon_text), float(lat_text)
    except (AttributeError, TypeError, ValueError) as exc:
        raise GuardrailError(f"Invalid location value returned by maps service: {location}") from exc
    # float() happily parses "nan" and "inf"; a NaN coordinate would make every
    # later distance comparison evaluate to False and silently defeat the
    # geocode-vs-POI distance guardrail, so reject non-finite values here.
    if not (math.isfinite(lon) and math.isfinite(lat)):
        raise GuardrailError(f"Invalid location value returned by maps service: {location}")
    return lon, lat
def _haversine_m(lon1: float, lat1: float, lon2: float, lat2: float) -> float:
    """Return the haversine distance between two lon/lat points (degrees), using a 6371000 radius."""
    earth_radius_m = 6371000
    lat1_rad = math.radians(lat1)
    lat2_rad = math.radians(lat2)
    lat_delta = math.radians(lat2 - lat1)
    lon_delta = math.radians(lon2 - lon1)
    lat_term = math.sin(lat_delta / 2) ** 2
    lon_term = math.cos(lat1_rad) * math.cos(lat2_rad) * math.sin(lon_delta / 2) ** 2
    chord = lat_term + lon_term
    return 2 * earth_radius_m * math.atan2(math.sqrt(chord), math.sqrt(1 - chord))
def _score_poi_candidate(
    poi: dict[str, Any],
    *,
    query_text: str,
    address_text: str,
    raw_query: str,
    index: int,
) -> int:
    """Score how well one text-search POI matches the requested place.

    Args:
        poi: A POI dict; the ``name``, ``address`` and ``id`` keys are read.
        query_text: Normalized name (or address) the caller searched for.
        address_text: Normalized input address.
        raw_query: Un-normalized query, used for the sub-POI penalty.
        index: Position of the POI in the search results; later results lose
            up to 10 points.

    Returns:
        An integer score; higher is a better match.
    """
    # Use "or" rather than a .get() default: a key that is present with a null
    # value would otherwise be stringified to the literal text "None" and scored.
    display_name = str(poi.get("name") or "")
    name = _normalize_text(display_name)
    address = _normalize_text(str(poi.get("address") or ""))
    score = 0
    if query_text and name == query_text:
        score += 120
    elif query_text and query_text in name:
        score += 90
    elif query_text and name and name in query_text:
        score += 45
    if address_text and address == address_text:
        score += 80
    elif address_text and address_text in address:
        score += 45
    if poi.get("id"):
        score += 20
    # Penalize sub-facilities (car park, entrance, ...) the user did not ask for.
    if any(token in display_name and token not in raw_query for token in _SUB_POI_TOKENS):
        score -= 25
    score -= min(index, 10)
    return score
def _select_precise_poi(
    pois: list[dict[str, Any]],
    *,
    input_name: str | None,
    input_address: str,
) -> dict[str, Any]:
    """Pick the single POI that confidently matches the requested place.

    Candidates are scored with ``_score_poi_candidate``; ties go to the earlier
    search result.

    Raises:
        GuardrailError: If there are no candidates, the best score is below 80,
            a candidate with a different id scores within 10 points of the best,
            or the chosen POI has no id.
    """
    if not pois:
        raise GuardrailError(f"No POI candidates found for address: {input_address}")
    raw_query = input_name or input_address
    query_text = _normalize_text(raw_query)
    address_text = _normalize_text(input_address)
    ranked = sorted(
        (
            (
                _score_poi_candidate(
                    poi,
                    query_text=query_text,
                    address_text=address_text,
                    raw_query=raw_query,
                    index=index,
                ),
                index,
                poi,
            )
            for index, poi in enumerate(pois)
        ),
        key=lambda item: (item[0], -item[1]),
        reverse=True,
    )
    best_score, _, best_poi = ranked[0]
    if best_score < 80:
        raise GuardrailError(
            f"POI precision is insufficient for address: {input_address}; best score={best_score}"
        )
    # Compare against the strongest candidate that is a *different* POI. Looking
    # only at the runner-up would miss a close rival whenever the runner-up is a
    # duplicate entry of the winner (same id).
    best_id = best_poi.get("id")
    rival_score = next(
        (score for score, _, poi in ranked[1:] if poi.get("id") != best_id),
        None,
    )
    if rival_score is not None and rival_score >= best_score - 10:
        raise GuardrailError(
            f"POI match is ambiguous for address: {input_address}; multiple close candidates found"
        )
    poi_id = str(best_id or "").strip()
    if not poi_id:
        raise GuardrailError(f"POI ID is missing for address: {input_address}")
    return best_poi
async def _call_mcp_tool(
    server: MCPServerHTTP | MCPServerSSE | MCPServerStreamableHTTP,
    name: str,
    args: dict[str, Any],
) -> Any:
    """Invoke MCP tool *name* on *server* and return whatever the server hands back.

    Raises:
        TimeoutError: If the call fails with a ``ModelRetry`` whose text contains
            "Timed out while waiting".
        GuardrailError: For any other ``ModelRetry``.
    """
    try:
        outcome = await server.direct_call_tool(name, args)
    except ModelRetry as exc:
        failure_text = str(exc)
        if "Timed out while waiting" not in failure_text:
            raise GuardrailError(f"MCP tool call failed for {name}: {exc}") from exc
        raise TimeoutError(failure_text) from exc
    return outcome
def _closest_geo_result(
geo_results: list[dict[str, Any]],
*,
lon: float,
lat: float,
) -> dict[str, Any] | None:
closest_result: dict[str, Any] | None = None
closest_distance: float | None = None
for item in geo_results:
location = str(item.get("location") or "").strip()
if not location:
continue
candidate_lon, candidate_lat = _parse_location(location)
distance = _haversine_m(lon, lat, candidate_lon, candidate_lat)
if closest_distance is None or distance < closest_distance:
closest_distance = distance
closest_result = item
if closest_distance is not None and closest_distance > 500:
raise GuardrailError("Geocode and POI detail locations disagree beyond acceptable precision")
return closest_result
def _mapping_or_empty(payload: Any, *, tool_name: str) -> dict[str, Any]:
    """Return *payload* if it is a dict, ``{}`` if it is falsy, and reject anything else."""
    if not payload:
        return {}
    if not isinstance(payload, dict):
        raise GuardrailError(
            f"MCP tool {tool_name} returned an unexpected payload type: {type(payload).__name__}"
        )
    return payload


async def _resolve_point(
    server: MCPServerHTTP | MCPServerSSE | MCPServerStreamableHTTP,
    *,
    role: str,
    input_name: str | None,
    input_address: str,
    city: str | None,
) -> ResolvedPoint:
    """Resolve one input place to a POI with coordinates and a poi_id.

    Calls ``maps_geo`` and ``maps_text_search``, selects one POI with
    ``_select_precise_poi``, fetches it via ``maps_search_detail``, and takes the
    district from the geocode result closest to the detail location.

    Raises:
        GuardrailError: If a tool returns an unusable payload, no POI can be
            selected confidently, the detail lacks a location/name/id, or the
            geocode and detail locations disagree.
        TimeoutError: Propagated from ``_call_mcp_tool``.
    """
    geo_response = _mapping_or_empty(
        await _call_mcp_tool(
            server,
            "maps_geo",
            _tool_args(address=input_address, city=city),
        ),
        tool_name="maps_geo",
    )
    text_search_response = _mapping_or_empty(
        await _call_mcp_tool(
            server,
            "maps_text_search",
            _tool_args(keywords=input_name or input_address, city=city),
        ),
        tool_name="maps_text_search",
    )
    geo_results = list(geo_response.get("results") or [])
    pois = list(text_search_response.get("pois") or [])
    selected_poi = _select_precise_poi(pois, input_name=input_name, input_address=input_address)
    # An empty/None detail used to crash with AttributeError on .get(); it now
    # becomes {} and is reported through the "no location" guardrail below.
    detail = _mapping_or_empty(
        await _call_mcp_tool(
            server,
            "maps_search_detail",
            {"id": selected_poi["id"]},
        ),
        tool_name="maps_search_detail",
    )
    detail_location = str(detail.get("location") or "").strip()
    if not detail_location:
        raise GuardrailError(f"Resolved POI detail has no location for address: {input_address}")
    lon, lat = _parse_location(detail_location)
    closest_geo = _closest_geo_result(geo_results, lon=lon, lat=lat) if geo_results else None
    district = None
    if closest_geo is not None:
        district = str(closest_geo.get("district") or "").strip() or None
    resolved_name = str(detail.get("name") or selected_poi.get("name") or "").strip()
    poi_id = str(detail.get("id") or selected_poi.get("id") or "").strip()
    resolved_city = str(detail.get("city") or city or "").strip() or city
    if not resolved_name:
        raise GuardrailError(f"Resolved POI has no name for address: {input_address}")
    if not poi_id:
        raise GuardrailError(f"Resolved POI has no poi_id for address: {input_address}")
    return ResolvedPoint(
        role=role,  # type: ignore[arg-type]
        input_name=input_name,
        input_address=input_address,
        resolved_name=resolved_name,
        city=resolved_city,
        district=district,
        location=detail_location,
        lon=lon,
        lat=lat,
        poi_id=poi_id,
        source="search_detail",
        confidence_note="Prevalidated exact POI with confirmed poi_id for deep link generation",
    )
async def _resolve_request_points(
    request: RoutePlanRequest,
) -> tuple[ResolvedPoint | None, list[ResolvedPoint], ResolvedPoint]:
    """Resolve the origin (fixed mode only), every stop, and the destination of *request*.

    All lookups share one MCP server session. Inputs with the same
    ``(name, address, city)`` are resolved once and copied with the caller's role.

    Returns:
        ``(resolved_origin_or_None, resolved_stops, resolved_destination)``.
    """
    resolved_by_input: dict[tuple[str | None, str, str | None], ResolvedPoint] = {}
    server = _build_mcp_server()
    async with server:

        async def resolve_cached(
            *,
            role: str,
            input_name: str | None,
            input_address: str,
            city: str | None,
        ) -> ResolvedPoint:
            lookup_key = (input_name, input_address, city)
            point = resolved_by_input.get(lookup_key)
            if point is None:
                point = await _resolve_point(
                    server,
                    role=role,
                    input_name=input_name,
                    input_address=input_address,
                    city=city,
                )
                resolved_by_input[lookup_key] = point
            # Always hand back a copy stamped with this caller's role and inputs.
            overrides = {
                "role": role,
                "input_name": input_name,
                "input_address": input_address,
            }
            return point.model_copy(update=overrides)

        resolved_origin: ResolvedPoint | None = None
        if request.origin_mode == "fixed" and request.origin_address is not None:
            resolved_origin = await resolve_cached(
                role="origin",
                input_name=request.origin_name,
                input_address=request.origin_address,
                city=request.origin_city,
            )

        resolved_stops: list[ResolvedPoint] = []
        for stop in request.stops:
            resolved_stop = await resolve_cached(
                role="stop",
                input_name=stop.name,
                input_address=stop.address,
                city=stop.city,
            )
            resolved_stops.append(resolved_stop)

        resolved_destination = await resolve_cached(
            role="destination",
            input_name=request.destination_name,
            input_address=request.destination_address,
            city=request.destination_city,
        )
    return resolved_origin, resolved_stops, resolved_destination
def _resolved_points_payload(
    *,
    resolved_origin: ResolvedPoint | None,
    resolved_stops: list[ResolvedPoint],
    resolved_destination: ResolvedPoint,
) -> dict[str, Any]:
    """Dump the resolved points to a JSON-ready dict; the origin entry is ``None`` when absent."""
    if resolved_origin:
        origin_json = resolved_origin.model_dump(mode="json")
    else:
        origin_json = None
    stops_json = []
    for stop in resolved_stops:
        stops_json.append(stop.model_dump(mode="json"))
    destination_json = resolved_destination.model_dump(mode="json")
    return {
        "resolved_origin": origin_json,
        "resolved_stops": stops_json,
        "resolved_destination": destination_json,
    }
def _point_label_candidates(point: ResolvedPoint) -> list[str]:
    """List the distinct labels a route may use for *point*.

    Order is resolved name, then input name (when set), then input address;
    repeated labels keep only their first occurrence.
    """
    labels = [point.resolved_name]
    if point.input_name:
        labels.append(point.input_name)
    labels.append(point.input_address)
    # dict.fromkeys keeps first-seen order while dropping repeats.
    return list(dict.fromkeys(labels))
def _ordered_points_for_best_route(
    request: RoutePlanRequest,
    result: RoutePlanResult,
    *,
    resolved_origin: ResolvedPoint | None,
    resolved_stops: list[ResolvedPoint],
    resolved_destination: ResolvedPoint,
) -> list[ResolvedPoint]:
    """Arrange the resolved points in the order chosen by ``result.best_route``.

    The fixed origin (if any) comes first, then the stops matched by normalized
    label against ``best_route.stop_order_labels``, then the destination. Each
    resolved stop is consumed at most once.

    Raises:
        GuardrailError: If a label matches none of the remaining stops.
    """
    route: list[ResolvedPoint] = []
    if resolved_origin is not None and request.origin_mode == "fixed":
        route.append(resolved_origin)
    unplaced = list(resolved_stops)
    for label in result.best_route.stop_order_labels:
        wanted = _normalize_text(label)
        for position, stop in enumerate(unplaced):
            aliases = {_normalize_text(text) for text in _point_label_candidates(stop)}
            if wanted in aliases:
                route.append(unplaced.pop(position))
                break
        else:
            raise GuardrailError(f"Unable to map best_route stop label back to resolved stop: {label}")
    route.append(resolved_destination)
    return route
async def _build_required_deep_links(
    request: RoutePlanRequest,
    result: RoutePlanResult,
    *,
    resolved_origin: ResolvedPoint | None,
    resolved_stops: list[ResolvedPoint],
    resolved_destination: ResolvedPoint,
) -> DeepLinks:
    """Create the personal-map deep link for the best route via ``maps_schema_personal_map``.

    Raises:
        GuardrailError: If the request does not ask for a deep link, a point on
            the route lacks a poi_id, a stop label cannot be mapped, or the tool
            does not return a non-blank string.
    """
    if not request.need_deep_link:
        raise GuardrailError("need_deep_link must be true because this service requires deep link output")
    route_points = _ordered_points_for_best_route(
        request,
        result,
        resolved_origin=resolved_origin,
        resolved_stops=resolved_stops,
        resolved_destination=resolved_destination,
    )
    lacking_id = next((point for point in route_points if not point.poi_id), None)
    if lacking_id is not None:
        raise GuardrailError(
            f"Deep link generation requires poi_id for every point; missing poi_id for {lacking_id.input_address}"
        )
    server = _build_mcp_server()
    async with server:
        point_infos = []
        for point in route_points:
            point_infos.append(
                {
                    "name": point.resolved_name,
                    "lon": point.lon,
                    "lat": point.lat,
                    "poiId": point.poi_id,
                }
            )
        line = {"title": request.task_name, "pointInfoList": point_infos}
        tool_args = {"orgName": "geo-agent", "lineList": [line]}
        personal_map = await _call_mcp_tool(server, "maps_schema_personal_map", tool_args)
    if isinstance(personal_map, str):
        uri = personal_map.strip()
        if uri:
            return DeepLinks(personal_map=uri)
    raise GuardrailError("Deep link generation failed: maps_schema_personal_map returned no usable URI")
def _configured_max_permutations() -> int:
    """Return the service-wide cap on candidate orderings (ROUTE_MAX_PERMUTATIONS, default 20)."""
    service_cap = _env_positive_int("ROUTE_MAX_PERMUTATIONS", 20)
    return service_cap
def _candidate_count(stop_count: int) -> int:
    """Return how many orderings *stop_count* stops have, i.e. ``stop_count!``."""
    # math.perm(n) with k omitted is defined as n!.
    return math.perm(stop_count)
def _prepare_request(request: RoutePlanRequest) -> RoutePlanRequest:
    """Check *request* against the service limits and pin its ``max_permutations``.

    Returns:
        A copy of *request* whose ``max_permutations`` is the effective limit
        (the requested value, or the configured cap when none was requested).

    Raises:
        GuardrailError: If deep links are not requested, the requested limit
            exceeds the configured cap, or the stops yield more orderings than
            the effective limit.
    """
    if not request.need_deep_link:
        raise GuardrailError("need_deep_link must be true because this service requires deep link output")
    service_cap = _configured_max_permutations()
    limit = request.max_permutations or service_cap
    if limit > service_cap:
        raise GuardrailError(
            "Requested max_permutations exceeds the configured service limit: "
            f"requested={limit}, configured={service_cap}"
        )
    stop_total = len(request.stops)
    orderings = _candidate_count(stop_total)
    if orderings > limit:
        raise GuardrailError(
            "Candidate permutations exceed the configured limit: "
            f"stops={stop_total}, permutations={orderings}, limit={limit}"
        )
    return request.model_copy(update={"max_permutations": limit})
def _build_user_prompt(
    request: RoutePlanRequest,
    *,
    resolved_origin: ResolvedPoint | None,
    resolved_stops: list[ResolvedPoint],
    resolved_destination: ResolvedPoint,
) -> str:
    """Compose the user message: fixed instructions, the request JSON, and the pre-resolved points JSON."""

    def pretty_json(payload: Any) -> str:
        return json.dumps(payload, ensure_ascii=False, indent=2)

    request_json = pretty_json(request.model_dump(mode="json"))
    resolved_points_json = pretty_json(
        _resolved_points_payload(
            resolved_origin=resolved_origin,
            resolved_stops=resolved_stops,
            resolved_destination=resolved_destination,
        )
    )
    pieces = [
        "请根据下面的 RoutePlanRequest JSON 执行多目标路线规划。\n",
        "所有点位已经由服务端完成严格解析和 POI 校验,且这些点位是唯一可信输入。\n",
        "不要再调用 maps_geo、maps_text_search、maps_search_detail也不要生成 deep link。\n",
        "你必须基于这些已解析点位,仅使用 maps_direction_driving 计算逐段路线,并选择 best_route。\n",
        "如果输入中存在固定起点,则完整路线必须从起点开始;如果起点模式是 current_location则不得伪造起点坐标。\n",
        f"本次运行的候选顺序硬上限是 {request.max_permutations}\n\n",
        "RoutePlanRequest JSON:\n",
        f"{request_json}\n\n",
        "Pre-resolved Points JSON:\n",
        f"{resolved_points_json}",
    ]
    return "".join(pieces)
def _same_candidate(left_candidate, right_candidate) -> bool:
    """Tell whether two candidate routes agree on order labels, totals, and number of legs.

    Individual legs are not compared, only how many there are.
    """
    if left_candidate.full_order_labels != right_candidate.full_order_labels:
        return False
    if left_candidate.total_distance_m != right_candidate.total_distance_m:
        return False
    if left_candidate.total_duration_s != right_candidate.total_duration_s:
        return False
    return len(left_candidate.legs) == len(right_candidate.legs)
def _validate_result(request: RoutePlanRequest, result: RoutePlanResult) -> RoutePlanResult:
    """Run the final consistency checks on *result* and return it unchanged.

    Checks, in order: success flag, origin mode, point roles and counts, origin
    presence per mode, candidate count against ``max_permutations``, best_route
    membership in candidates, at least one deep link, and a poi_id on every
    resolved point.

    Raises:
        GuardrailError: On the first check that fails.
    """
    if not result.success:
        raise GuardrailError("Route planning did not complete successfully")
    if result.origin_mode != request.origin_mode:
        raise GuardrailError("Agent output origin_mode does not match the request")

    destination = result.resolved_destination
    stops = result.resolved_stops
    origin = result.resolved_origin
    if destination.role != "destination":
        raise GuardrailError("resolved_destination.role must be 'destination'")
    if len(stops) != len(request.stops):
        raise GuardrailError("resolved_stops count does not match the request")
    if not all(stop.role == "stop" for stop in stops):
        raise GuardrailError("All resolved_stops entries must have role='stop'")

    if request.origin_mode != "fixed":
        if origin is not None:
            raise GuardrailError("resolved_origin must be null when origin_mode='current_location'")
    elif origin is None:
        raise GuardrailError("resolved_origin is required when origin_mode='fixed'")
    elif origin.role != "origin":
        raise GuardrailError("resolved_origin.role must be 'origin'")

    candidates = result.candidates
    if not candidates:
        raise GuardrailError("A successful result must include at least one candidate route")
    allowed = request.max_permutations or 0
    if len(candidates) > allowed:
        raise GuardrailError("Agent returned more candidates than allowed by max_permutations")
    if not any(_same_candidate(candidate, result.best_route) for candidate in candidates):
        raise GuardrailError("best_route must be one of the candidates")

    links = result.deep_links
    if links is None:
        raise GuardrailError("A successful result must include deep_links")
    if not (links.personal_map or links.android_route_plan or links.ios_route_plan):
        raise GuardrailError("A successful result must include at least one deep link")

    points_to_check = list(stops)
    points_to_check.append(destination)
    if origin is not None:
        points_to_check.append(origin)
    for point in points_to_check:
        if not point.poi_id:
            raise GuardrailError("All resolved points must include poi_id on successful results")
    return result
def create_geo_agent() -> Agent[None, RoutePlanResult]:
    """Assemble the route-planning agent: configured model, the MCP server as its toolset, and SYSTEM_PROMPT.

    The agent's output type is ``RoutePlanResult``.
    """
    model = _build_model()
    maps_server = _build_mcp_server()
    return Agent(
        model=model,
        toolsets=[maps_server],
        output_type=RoutePlanResult,
        system_prompt=SYSTEM_PROMPT,
    )
# ---------------------------------------------------------------------------
# Public entry point
# ---------------------------------------------------------------------------
async def run_route_plan(request: RoutePlanRequest) -> RoutePlanResult:
    """Plan a multi-stop route for *request* and return the validated result.

    Steps: enforce request limits, resolve all points server-side, let the agent
    pick the best route, overwrite the agent's resolved points with the
    server-resolved ones, attach the deep link, then validate.

    Raises:
        GuardrailError: If any guardrail in preparation, resolution, deep-link
            generation, or validation fails.
        ConfigurationError: If required environment settings are missing or invalid.
    """
    prepared = _prepare_request(request)
    origin, stops, destination = await _resolve_request_points(prepared)
    agent = create_geo_agent()
    async with agent:
        prompt = _build_user_prompt(
            prepared,
            resolved_origin=origin,
            resolved_stops=stops,
            resolved_destination=destination,
        )
        run = await agent.run(prompt)
    # The server-resolved points replace whatever the agent echoed back.
    plan = run.output.model_copy(
        update={
            "resolved_origin": origin,
            "resolved_stops": stops,
            "resolved_destination": destination,
        }
    )
    links = await _build_required_deep_links(
        prepared,
        plan,
        resolved_origin=origin,
        resolved_stops=stops,
        resolved_destination=destination,
    )
    plan_with_links = plan.model_copy(update={"deep_links": links})
    return _validate_result(prepared, plan_with_links)