From 321c550a660e994fab5ec86adefa1d534f82c78b Mon Sep 17 00:00:00 2001 From: stupid_run Date: Mon, 13 Apr 2026 19:48:44 +0800 Subject: [PATCH] feat: poi id required --- agent.py | 515 +++++++++++++++++++++++++++++++++++-- docs/2026-04-13_summary.md | 1 + docs/agent_design.md | 24 ++ docs/frontend_api.md | 8 +- 4 files changed, 524 insertions(+), 24 deletions(-) diff --git a/agent.py b/agent.py index 40e7a2e..75c03e2 100644 --- a/agent.py +++ b/agent.py @@ -1,15 +1,18 @@ import json import math import os +import re +from typing import Any from dotenv import load_dotenv from openai import AsyncOpenAI from pydantic_ai import Agent +from pydantic_ai.exceptions import ModelRetry from pydantic_ai.mcp import MCPServerHTTP, MCPServerSSE, MCPServerStreamableHTTP from pydantic_ai.models.openai import OpenAIModel from pydantic_ai.providers.openai import OpenAIProvider -from schemas import RoutePlanRequest, RoutePlanResult +from schemas import DeepLinks, ResolvedPoint, RoutePlanRequest, RoutePlanResult load_dotenv() @@ -21,6 +24,10 @@ class ConfigurationError(RuntimeError): class GuardrailError(RuntimeError): pass + +def _tool_args(**kwargs: Any) -> dict[str, Any]: + return {key: value for key, value in kwargs.items() if value is not None} + def _required_env(name: str) -> str: value = os.getenv(name, "").strip() if not value: @@ -130,11 +137,405 @@ SYSTEM_PROMPT = """\ 10. 如果任何地址无法可靠解析、任何 POI 无法获取、或任何路线比较存在信息缺失,必须如实说明,不得猜测。 11. 你的输入会以 RoutePlanRequest JSON 形式提供。你必须基于该 JSON 进行规划,不能擅自补造字段。 12. 在返回结果前,必须确认 best_route 来自 candidates 之一,且每一段 distance 和 duration 都来自工具调用。 +13. 如果输入中已经提供“预解析点位 JSON”,则这些点位是服务端已验证的唯一可信事实来源。此时不要再调用 maps_geo、maps_text_search、maps_search_detail,也不要重新挑选 POI。 +14. 当预解析点位已提供时,你只需要基于这些点位生成候选顺序、调用 maps_direction_driving 逐段计算、挑选 best_route,并产出 summary 与 warnings。 +15. 当预解析点位已提供时,deep_links 字段保持为 null,由服务端在结果通过校验后统一生成。 你的底层返回必须是结构化数据,不是 HTML。只有在用户明确要求页面展示时,才在结构化结果基础上额外生成 HTML。 """ +_SUB_POI_TOKENS = { + "停车场", + "入口", + "出口", + "住院部", + "门诊", + "外科", + "内科", + "体检", + "饭堂", + "发热门诊", + "楼", + "中心", +} + + +def _normalize_text(value: str | None) -> str: + if not value: + return "" + return re.sub(r"[\s\-_,,。.;;::/\\()()\[\]【】{}·]", "", value).casefold() + + +def _parse_location(location: str) -> tuple[float, float]: + try: + lon_text, lat_text = [item.strip() for item in location.split(",", 1)] + return float(lon_text), float(lat_text) + except Exception as exc: + raise GuardrailError(f"Invalid location value returned by maps service: {location}") from exc + + +def _haversine_m(lon1: float, lat1: float, lon2: float, lat2: float) -> float: + radius_m = 6371000 + phi1 = math.radians(lat1) + phi2 = math.radians(lat2) + d_phi = math.radians(lat2 - lat1) + d_lambda = math.radians(lon2 - lon1) + + a = math.sin(d_phi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(d_lambda / 2) ** 2 + return 2 * radius_m * math.atan2(math.sqrt(a), math.sqrt(1 - a)) + + +def _score_poi_candidate( + poi: dict[str, Any], + *, + query_text: str, + address_text: str, + raw_query: str, + index: int, +) -> int: + name = _normalize_text(str(poi.get("name", ""))) + address = _normalize_text(str(poi.get("address", ""))) + score = 0 + + if query_text and name == query_text: + score += 120 + elif query_text and query_text in name: + score += 90 + elif query_text and name and name in query_text: + score += 45 + + if address_text and address == address_text: + score += 80 + elif address_text and address_text in address: + score += 45 + + if poi.get("id"): + score += 20 + + display_name = str(poi.get("name", "")) + if any(token in display_name and token not in raw_query for token in _SUB_POI_TOKENS): + score -= 25 + + score -= min(index, 10) + return score + + +def _select_precise_poi( + pois: list[dict[str, Any]], + *, + input_name: str | None, + input_address: str, +) -> dict[str, Any]: + if not pois: + raise GuardrailError(f"No POI candidates found for address: {input_address}") + + raw_query = input_name or input_address + query_text = _normalize_text(raw_query) + address_text = _normalize_text(input_address) + + ranked = [ + ( + _score_poi_candidate( + poi, + query_text=query_text, + address_text=address_text, + raw_query=raw_query, + index=index, + ), + index, + poi, + ) + for index, poi in enumerate(pois) + ] + ranked.sort(key=lambda item: (item[0], -item[1]), reverse=True) + + best_score, _, best_poi = ranked[0] + if best_score < 80: + raise GuardrailError( + f"POI precision is insufficient for address: {input_address}; best score={best_score}" + ) + + if len(ranked) > 1: + second_score, _, second_poi = ranked[1] + if second_poi.get("id") != best_poi.get("id") and second_score >= best_score - 10: + raise GuardrailError( + f"POI match is ambiguous for address: {input_address}; multiple close candidates found" + ) + + poi_id = str(best_poi.get("id") or "").strip() + if not poi_id: + raise GuardrailError(f"POI ID is missing for address: {input_address}") + + return best_poi + + +async def _call_mcp_tool(server: MCPServerHTTP | MCPServerSSE | MCPServerStreamableHTTP, name: str, args: dict[str, Any]) -> Any: + try: + return await server.direct_call_tool(name, args) + except ModelRetry as exc: + if "Timed out while waiting" in str(exc): + raise TimeoutError(str(exc)) from exc + raise GuardrailError(f"MCP tool call failed for {name}: {exc}") from exc + + +def _closest_geo_result( + geo_results: list[dict[str, Any]], + *, + lon: float, + lat: float, +) -> dict[str, Any] | None: + closest_result: dict[str, Any] | None = None + closest_distance: float | None = None + + for item in geo_results: + location = str(item.get("location") or "").strip() + if not location: + continue + candidate_lon, candidate_lat = _parse_location(location) + distance = _haversine_m(lon, lat, candidate_lon, candidate_lat) + if closest_distance is None or distance < closest_distance: + closest_distance = distance + closest_result = item + + if closest_distance is not None and closest_distance > 500: + raise GuardrailError("Geocode and POI detail locations disagree beyond acceptable precision") + + return closest_result + + +async def _resolve_point( + server: MCPServerHTTP | MCPServerSSE | MCPServerStreamableHTTP, + *, + role: str, + input_name: str | None, + input_address: str, + city: str | None, +) -> ResolvedPoint: + geo_response = await _call_mcp_tool( + server, + "maps_geo", + _tool_args(address=input_address, city=city), + ) + text_search_response = await _call_mcp_tool( + server, + "maps_text_search", + _tool_args(keywords=input_name or input_address, city=city), + ) + + geo_results = list((geo_response or {}).get("results") or []) + pois = list((text_search_response or {}).get("pois") or []) + selected_poi = _select_precise_poi(pois, input_name=input_name, input_address=input_address) + + detail = await _call_mcp_tool( + server, + "maps_search_detail", + {"id": selected_poi["id"]}, + ) + + detail_location = str(detail.get("location") or "").strip() + if not detail_location: + raise GuardrailError(f"Resolved POI detail has no location for address: {input_address}") + + lon, lat = _parse_location(detail_location) + closest_geo = _closest_geo_result(geo_results, lon=lon, lat=lat) if geo_results else None + district = None + if closest_geo is not None: + district = str(closest_geo.get("district") or "").strip() or None + + resolved_name = str(detail.get("name") or selected_poi.get("name") or "").strip() + poi_id = str(detail.get("id") or selected_poi.get("id") or "").strip() + resolved_city = str(detail.get("city") or city or "").strip() or city + + if not resolved_name: + raise GuardrailError(f"Resolved POI has no name for address: {input_address}") + if not poi_id: + raise GuardrailError(f"Resolved POI has no poi_id for address: {input_address}") + + return ResolvedPoint( + role=role, # type: ignore[arg-type] + input_name=input_name, + input_address=input_address, + resolved_name=resolved_name, + city=resolved_city, + district=district, + location=detail_location, + lon=lon, + lat=lat, + poi_id=poi_id, + source="search_detail", + confidence_note="Prevalidated exact POI with confirmed poi_id for deep link generation", + ) + + +async def _resolve_request_points( + request: RoutePlanRequest, +) -> tuple[ResolvedPoint | None, list[ResolvedPoint], ResolvedPoint]: + cache: dict[tuple[str | None, str, str | None], ResolvedPoint] = {} + server = _build_mcp_server() + + async with server: + async def resolve_cached(*, role: str, input_name: str | None, input_address: str, city: str | None) -> ResolvedPoint: + cache_key = (input_name, input_address, city) + cached = cache.get(cache_key) + if cached is None: + cached = await _resolve_point( + server, + role=role, + input_name=input_name, + input_address=input_address, + city=city, + ) + cache[cache_key] = cached + + return cached.model_copy( + update={ + "role": role, + "input_name": input_name, + "input_address": input_address, + } + ) + + resolved_origin: ResolvedPoint | None = None + if request.origin_mode == "fixed" and request.origin_address is not None: + resolved_origin = await resolve_cached( + role="origin", + input_name=request.origin_name, + input_address=request.origin_address, + city=request.origin_city, + ) + + resolved_stops = [ + await resolve_cached( + role="stop", + input_name=stop.name, + input_address=stop.address, + city=stop.city, + ) + for stop in request.stops + ] + + resolved_destination = await resolve_cached( + role="destination", + input_name=request.destination_name, + input_address=request.destination_address, + city=request.destination_city, + ) + + return resolved_origin, resolved_stops, resolved_destination + + +def _resolved_points_payload( + *, + resolved_origin: ResolvedPoint | None, + resolved_stops: list[ResolvedPoint], + resolved_destination: ResolvedPoint, +) -> dict[str, Any]: + return { + "resolved_origin": resolved_origin.model_dump(mode="json") if resolved_origin else None, + "resolved_stops": [point.model_dump(mode="json") for point in resolved_stops], + "resolved_destination": resolved_destination.model_dump(mode="json"), + } + + +def _point_label_candidates(point: ResolvedPoint) -> list[str]: + labels = [point.resolved_name] + if point.input_name: + labels.append(point.input_name) + labels.append(point.input_address) + deduplicated: list[str] = [] + for label in labels: + if label not in deduplicated: + deduplicated.append(label) + return deduplicated + + +def _ordered_points_for_best_route( + request: RoutePlanRequest, + result: RoutePlanResult, + *, + resolved_origin: ResolvedPoint | None, + resolved_stops: list[ResolvedPoint], + resolved_destination: ResolvedPoint, +) -> list[ResolvedPoint]: + ordered_points: list[ResolvedPoint] = [] + if request.origin_mode == "fixed" and resolved_origin is not None: + ordered_points.append(resolved_origin) + + remaining_stops = resolved_stops.copy() + for label in result.best_route.stop_order_labels: + normalized_label = _normalize_text(label) + match_index = next( + ( + index + for index, point in enumerate(remaining_stops) + if normalized_label in {_normalize_text(candidate) for candidate in _point_label_candidates(point)} + ), + None, + ) + if match_index is None: + raise GuardrailError(f"Unable to map best_route stop label back to resolved stop: {label}") + ordered_points.append(remaining_stops.pop(match_index)) + + ordered_points.append(resolved_destination) + return ordered_points + + +async def _build_required_deep_links( + request: RoutePlanRequest, + result: RoutePlanResult, + *, + resolved_origin: ResolvedPoint | None, + resolved_stops: list[ResolvedPoint], + resolved_destination: ResolvedPoint, +) -> DeepLinks: + if not request.need_deep_link: + raise GuardrailError("need_deep_link must be true because this service requires deep link output") + + ordered_points = _ordered_points_for_best_route( + request, + result, + resolved_origin=resolved_origin, + resolved_stops=resolved_stops, + resolved_destination=resolved_destination, + ) + + for point in ordered_points: + if not point.poi_id: + raise GuardrailError( + f"Deep link generation requires poi_id for every point; missing poi_id for {point.input_address}" + ) + + server = _build_mcp_server() + async with server: + personal_map = await _call_mcp_tool( + server, + "maps_schema_personal_map", + { + "orgName": "geo-agent", + "lineList": [ + { + "title": request.task_name, + "pointInfoList": [ + { + "name": point.resolved_name, + "lon": point.lon, + "lat": point.lat, + "poiId": point.poi_id, + } + for point in ordered_points + ], + } + ], + }, + ) + + if not isinstance(personal_map, str) or not personal_map.strip(): + raise GuardrailError("Deep link generation failed: maps_schema_personal_map returned no usable URI") + + return DeepLinks(personal_map=personal_map.strip()) + + def _configured_max_permutations() -> int: return _env_positive_int("ROUTE_MAX_PERMUTATIONS", 20) @@ -144,6 +545,9 @@ def _candidate_count(stop_count: int) -> int: def _prepare_request(request: RoutePlanRequest) -> RoutePlanRequest: + if not request.need_deep_link: + raise GuardrailError("need_deep_link must be true because this service requires deep link output") + configured_limit = _configured_max_permutations() effective_limit = request.max_permutations or configured_limit @@ -163,16 +567,34 @@ def _prepare_request(request: RoutePlanRequest) -> RoutePlanRequest: return request.model_copy(update={"max_permutations": effective_limit}) -def _build_user_prompt(request: RoutePlanRequest) -> str: +def _build_user_prompt( + request: RoutePlanRequest, + *, + resolved_origin: ResolvedPoint | None, + resolved_stops: list[ResolvedPoint], + resolved_destination: ResolvedPoint, +) -> str: request_json = json.dumps(request.model_dump(mode="json"), ensure_ascii=False, indent=2) + resolved_points_json = json.dumps( + _resolved_points_payload( + resolved_origin=resolved_origin, + resolved_stops=resolved_stops, + resolved_destination=resolved_destination, + ), + ensure_ascii=False, + indent=2, + ) return ( "请根据下面的 RoutePlanRequest JSON 执行多目标路线规划。\n" - "必须显式使用高德 MCP 工具完成地址解析、逐段驾车路线计算和 deep link 生成。\n" - "如果输入中存在固定起点,则完整路线必须从起点开始;如果起点模式是 current_location," - "则不得伪造起点坐标。\n" + "所有点位已经由服务端完成严格解析和 POI 校验,且这些点位是唯一可信输入。\n" + "不要再调用 maps_geo、maps_text_search、maps_search_detail,也不要生成 deep link。\n" + "你必须基于这些已解析点位,仅使用 maps_direction_driving 计算逐段路线,并选择 best_route。\n" + "如果输入中存在固定起点,则完整路线必须从起点开始;如果起点模式是 current_location,则不得伪造起点坐标。\n" f"本次运行的候选顺序硬上限是 {request.max_permutations}。\n\n" "RoutePlanRequest JSON:\n" - f"{request_json}" + f"{request_json}\n\n" + "Pre-resolved Points JSON:\n" + f"{resolved_points_json}" ) @@ -186,6 +608,9 @@ def _same_candidate(left_candidate, right_candidate) -> bool: def _validate_result(request: RoutePlanRequest, result: RoutePlanResult) -> RoutePlanResult: + if not result.success: + raise GuardrailError("Route planning did not complete successfully") + if result.origin_mode != request.origin_mode: raise GuardrailError("Agent output origin_mode does not match the request") @@ -207,23 +632,40 @@ def _validate_result(request: RoutePlanRequest, result: RoutePlanResult) -> Rout if result.resolved_origin is not None: raise GuardrailError("resolved_origin must be null when origin_mode='current_location'") - if result.success: - if not result.candidates: - raise GuardrailError("A successful result must include at least one candidate route") + if not result.candidates: + raise GuardrailError("A successful result must include at least one candidate route") - if len(result.candidates) > (request.max_permutations or 0): - raise GuardrailError("Agent returned more candidates than allowed by max_permutations") + if len(result.candidates) > (request.max_permutations or 0): + raise GuardrailError("Agent returned more candidates than allowed by max_permutations") - matching_candidate = next( - ( - candidate - for candidate in result.candidates - if _same_candidate(candidate, result.best_route) - ), - None, - ) - if matching_candidate is None: - raise GuardrailError("best_route must be one of the candidates") + matching_candidate = next( + ( + candidate + for candidate in result.candidates + if _same_candidate(candidate, result.best_route) + ), + None, + ) + if matching_candidate is None: + raise GuardrailError("best_route must be one of the candidates") + + if result.deep_links is None: + raise GuardrailError("A successful result must include deep_links") + + if not any( + [ + result.deep_links.personal_map, + result.deep_links.android_route_plan, + result.deep_links.ios_route_plan, + ] + ): + raise GuardrailError("A successful result must include at least one deep link") + + all_points = [*result.resolved_stops, result.resolved_destination] + if result.resolved_origin is not None: + all_points.append(result.resolved_origin) + if any(not point.poi_id for point in all_points): + raise GuardrailError("All resolved points must include poi_id on successful results") return result @@ -243,7 +685,34 @@ def create_geo_agent() -> Agent[None, RoutePlanResult]: async def run_route_plan(request: RoutePlanRequest) -> RoutePlanResult: """Execute the route planning agent for a given request.""" prepared_request = _prepare_request(request) + resolved_origin, resolved_stops, resolved_destination = await _resolve_request_points(prepared_request) agent = create_geo_agent() async with agent: - result = await agent.run(_build_user_prompt(prepared_request)) - return _validate_result(prepared_request, result.output) + result = await agent.run( + _build_user_prompt( + prepared_request, + resolved_origin=resolved_origin, + resolved_stops=resolved_stops, + resolved_destination=resolved_destination, + ) + ) + + output = result.output.model_copy( + update={ + "resolved_origin": resolved_origin, + "resolved_stops": resolved_stops, + "resolved_destination": resolved_destination, + } + ) + output = output.model_copy( + update={ + "deep_links": await _build_required_deep_links( + prepared_request, + output, + resolved_origin=resolved_origin, + resolved_stops=resolved_stops, + resolved_destination=resolved_destination, + ) + } + ) + return _validate_result(prepared_request, output) diff --git a/docs/2026-04-13_summary.md b/docs/2026-04-13_summary.md index 2651fc1..d9d5bc6 100644 --- a/docs/2026-04-13_summary.md +++ b/docs/2026-04-13_summary.md @@ -38,6 +38,7 @@ - 已补充上游 timeout 配置和 504 错误映射,避免外部超时被混淆为普通 500。 - 已修正 `stops` 非空校验,并更新前端文档中 `deep_links` 与 `summary` 的语义边界说明。 - 已为 FastAPI 增加可配置 CORS 中间件,默认允许本地 `localhost/127.0.0.1` 任意端口联调。 +- 已切换到严格 deep-link 模式:所有点必须先完成 POI 校验并拿到 `poi_id`,成功结果必须包含 deep link,否则直接失败。 ## 下一步建议 diff --git a/docs/agent_design.md b/docs/agent_design.md index 87ffd71..55f0a8a 100644 --- a/docs/agent_design.md +++ b/docs/agent_design.md @@ -119,6 +119,23 @@ 这样做的目的是把“长期规则”和“本次任务上下文”拆开,减少 prompt 污染。 +### 5.3 当前严格执行流 + +当前版本已不再把“地址解析是否足够精确”和“deep link 是否能生成”完全交给模型决定。 + +当前执行流如下: + +1. 服务端先对起点、终点、所有途经点做前置解析 +2. 每个点都必须成功完成: + - `maps_text_search` 命中精确 POI + - `maps_search_detail` 返回稳定坐标 + - 取得 `poi_id` +3. 任意一个点解析模糊、缺少 `poi_id`、或地理结果交叉校验失败,直接返回错误 +4. 只有在所有点都通过后,才把“预解析点位 JSON”交给 Agent +5. Agent 只负责候选顺序、逐段驾车计算、最佳路线选择、summary 和 warnings +6. 服务端最后再直接调用 `maps_schema_personal_map` 生成 deep link +7. 如果 deep link 生成失败,则整个请求失败 + ## 6. 已实现的代码护栏 ### 6.1 输入护栏 @@ -132,6 +149,7 @@ - `origin_mode=fixed` 时必须提供 `origin_address` - 终点地址不能同时出现在 `stops` - `max_permutations` 如果传入,必须大于 0 +- `need_deep_link` 必须为 `true` ### 6.2 执行规模护栏 @@ -157,6 +175,9 @@ - `origin_mode=current_location` 时禁止返回固定 `resolved_origin` - 成功结果必须至少有一个 candidate - `best_route` 必须能在 `candidates` 中找到对应项 +- `success` 必须为 `true` +- 成功结果必须包含至少一个 deep link +- 成功结果中的所有点必须带有 `poi_id` ### 6.4 配置护栏 @@ -188,6 +209,8 @@ - 高德远程 MCP 连通性已验证 - 单个途经点请求可成功返回结构化结果 - 超限请求可返回 422,并中止模型执行 +- 已改为严格 deep-link 模式:成功结果必须包含 deep link,否则直接失败 +- 已增加前置点位解析与 POI 校验阶段,缺少 `poi_id` 或命中模糊时直接失败 当前尚未完成: @@ -227,6 +250,7 @@ - `current_location` 模式尚未做专门增强 - `need_html` 目前尚未实现独立展示层 - 没有缓存机制,请求成本与工具调用次数直接相关 +- 当前严格策略可能会拒绝一部分“人看上去可接受、但程序判断为不够精确”的地址输入 ## 10. 下一阶段 TODO diff --git a/docs/frontend_api.md b/docs/frontend_api.md index 2e48263..7227760 100644 --- a/docs/frontend_api.md +++ b/docs/frontend_api.md @@ -111,7 +111,8 @@ Content-Type: application/json - 当前固定为 `driving` - `need_deep_link` - 可选 - - 是否需要生成 deep link + - 当前必须为 `true` + - 该服务的成功结果必须包含 deep link - `deep_link_mode` - 可选 - 可选值:`personal_map`、`route_plan`、`auto` @@ -260,6 +261,7 @@ Content-Type: application/json - `deep_links` - 给前端做按钮跳转使用 - 这是唯一应被前端当作链接处理的字段 + - 当前成功结果至少会包含 `personal_map` - `summary` - 可直接展示给用户的简要说明 - 这是纯展示文案,不是结构化链接字段,也不应被前端解析为跳转地址 @@ -324,6 +326,7 @@ Content-Type: application/json 补充说明: - `deep_links` 中可能同时存在多个字段,也可能只有其中一个字段有值 +- 当前实现中,成功结果会强制生成 `personal_map` - 前端应只根据 `deep_links` 的字段值控制按钮展示,不要依赖 `summary` 推断应展示哪个按钮 - `summary` 里可能会提到“个人地图链接”或“导航链接”,但这里只是说明文字,不保证包含真实 URL - 如果 `personal_map` 存在,表示当前更适合导入点位方案 @@ -343,8 +346,11 @@ Content-Type: application/json - 请求结构不合法 - `stops` 为空 +- `need_deep_link=false` - 固定起点缺少 `origin_address` - 终点同时出现在 `stops` +- 任意点未能解析到足够精确的 POI +- 任意点缺少 `poi_id` - 请求候选上限超过服务上限 - 实际排列数超过上限