feat: auth
This commit is contained in:
28
agent/__init__.py
Normal file
28
agent/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from .load_plan import (
|
||||
LoadPlanConfigurationError,
|
||||
LoadPlanGuardrailError,
|
||||
create_load_plan_agent,
|
||||
run_load_plan,
|
||||
)
|
||||
from .load_plan_tools import (
|
||||
LoadPlanToolConfigurationError,
|
||||
LoadPlanToolRequestError,
|
||||
get_transport_vehicle_by_license_plate,
|
||||
list_unshipped_shipments,
|
||||
)
|
||||
from .route_plan import ConfigurationError, GuardrailError, create_geo_agent, run_route_plan
|
||||
|
||||
__all__ = [
|
||||
"ConfigurationError",
|
||||
"GuardrailError",
|
||||
"LoadPlanConfigurationError",
|
||||
"LoadPlanGuardrailError",
|
||||
"LoadPlanToolConfigurationError",
|
||||
"LoadPlanToolRequestError",
|
||||
"create_load_plan_agent",
|
||||
"create_geo_agent",
|
||||
"get_transport_vehicle_by_license_plate",
|
||||
"list_unshipped_shipments",
|
||||
"run_load_plan",
|
||||
"run_route_plan",
|
||||
]
|
||||
107
agent/load_plan.py
Normal file
107
agent/load_plan.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import json
|
||||
|
||||
from pydantic_ai import Agent
|
||||
from pydantic_ai_skills import SkillsToolset
|
||||
from openai import AsyncOpenAI
|
||||
from pydantic_ai.models.openai import OpenAIModel
|
||||
from pydantic_ai.providers.openai import OpenAIProvider
|
||||
|
||||
from schemas import LoadPlanRequest, LoadPlanResult
|
||||
from .load_plan_tools import (
|
||||
get_transport_vehicle_by_license_plate,
|
||||
list_unshipped_shipments,
|
||||
)
|
||||
|
||||
|
||||
class LoadPlanConfigurationError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
class LoadPlanGuardrailError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
def _required_env(name: str) -> str:
|
||||
import os
|
||||
|
||||
value = os.getenv(name, "").strip()
|
||||
if not value:
|
||||
raise LoadPlanConfigurationError(f"Missing required environment variable: {name}")
|
||||
return value
|
||||
|
||||
|
||||
def _env_positive_float(name: str, default: float) -> float:
|
||||
import os
|
||||
|
||||
raw_value = os.getenv(name, str(default)).strip()
|
||||
try:
|
||||
parsed = float(raw_value)
|
||||
except ValueError as exc:
|
||||
raise LoadPlanConfigurationError(f"Environment variable {name} must be a number") from exc
|
||||
|
||||
if parsed <= 0:
|
||||
raise LoadPlanConfigurationError(f"Environment variable {name} must be greater than 0")
|
||||
return parsed
|
||||
|
||||
|
||||
def _build_model() -> OpenAIModel:
|
||||
openai_client = AsyncOpenAI(
|
||||
base_url=_required_env("LOAD_PLAN_BASE_URL"),
|
||||
api_key=_required_env("LOAD_PLAN_API_KEY"),
|
||||
timeout=_env_positive_float("LOAD_PLAN_REQUEST_TIMEOUT_SECONDS", 120.0),
|
||||
)
|
||||
provider = OpenAIProvider(
|
||||
openai_client=openai_client,
|
||||
)
|
||||
return OpenAIModel(
|
||||
_required_env("LOAD_PLAN_MODEL"),
|
||||
provider=provider,
|
||||
)
|
||||
|
||||
|
||||
INSTRUCTIONS = (
|
||||
"你是一个装载规划 Agent。"
|
||||
"你的任务是根据请求中的 merchant_id、area 和 license_plate,"
|
||||
"调用工具获取待出货出货单与车辆容量,并严格遵守已加载的 skill 来选择出货单。"
|
||||
"你必须先获取出货单,再获取车辆容量。"
|
||||
"车辆容量按销售品条数计算,不按 quantity 米数计算。"
|
||||
"目标是在不超过车辆容量的前提下尽量多装;若无法满载,允许欠载,但必须给出 warning。"
|
||||
"不要臆造 sales_items、printing_job_width、面料类型、数量或容量。"
|
||||
"如果无法依据真实返回判断装载结果,就返回空的 selected_shipment_ids,并在 summary/warnings 中说明。"
|
||||
)
|
||||
|
||||
|
||||
def _build_user_prompt(request: LoadPlanRequest) -> str:
|
||||
request_json = json.dumps(request.model_dump(mode="json"), ensure_ascii=False, indent=2)
|
||||
return (
|
||||
"请根据下面的 LoadPlanRequest JSON 执行装载规划。\n"
|
||||
"你必须调用工具获取真实数据,并根据 skill 中的规则完成判断。\n\n"
|
||||
"LoadPlanRequest JSON:\n"
|
||||
f"{request_json}"
|
||||
)
|
||||
|
||||
|
||||
def create_load_plan_agent() -> Agent[None, LoadPlanResult]:
|
||||
return Agent(
|
||||
model=_build_model(),
|
||||
instructions=INSTRUCTIONS,
|
||||
output_type=LoadPlanResult,
|
||||
tools=[
|
||||
get_transport_vehicle_by_license_plate,
|
||||
list_unshipped_shipments,
|
||||
],
|
||||
toolsets=[
|
||||
SkillsToolset(
|
||||
directories=["./skills"],
|
||||
auto_reload=True,
|
||||
exclude_tools={"run_skill_script"},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
async def run_load_plan(request: LoadPlanRequest) -> LoadPlanResult:
|
||||
agent = create_load_plan_agent()
|
||||
async with agent:
|
||||
result = await agent.run(_build_user_prompt(request))
|
||||
return result.output
|
||||
112
agent/load_plan_tools.py
Normal file
112
agent/load_plan_tools.py
Normal file
@@ -0,0 +1,112 @@
|
||||
import os
|
||||
from urllib.parse import quote
|
||||
|
||||
import httpx
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from schemas import ShipmentPage, TransportVehicle
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class LoadPlanToolConfigurationError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
class LoadPlanToolRequestError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
def _required_env(name: str) -> str:
|
||||
value = os.getenv(name, "").strip()
|
||||
if not value:
|
||||
raise LoadPlanToolConfigurationError(f"Missing required environment variable: {name}")
|
||||
return value
|
||||
|
||||
|
||||
def _env_positive_float(name: str, default: float) -> float:
|
||||
raw_value = os.getenv(name, str(default)).strip()
|
||||
try:
|
||||
parsed = float(raw_value)
|
||||
except ValueError as exc:
|
||||
raise LoadPlanToolConfigurationError(f"Environment variable {name} must be a number") from exc
|
||||
|
||||
if parsed <= 0:
|
||||
raise LoadPlanToolConfigurationError(f"Environment variable {name} must be greater than 0")
|
||||
return parsed
|
||||
|
||||
|
||||
def _api_base_url() -> str:
|
||||
return _required_env("LOAD_PLAN_API_HOST").rstrip("/")
|
||||
|
||||
|
||||
def _api_headers() -> dict[str, str]:
|
||||
return {
|
||||
"Authorization": _required_env("LOAD_PLAN_AGENT_ACCESS_KEY"),
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
|
||||
def _build_client() -> httpx.AsyncClient:
|
||||
return httpx.AsyncClient(
|
||||
base_url=_api_base_url(),
|
||||
headers=_api_headers(),
|
||||
timeout=_env_positive_float("LOAD_PLAN_API_TIMEOUT_SECONDS", 20.0),
|
||||
)
|
||||
|
||||
|
||||
def _raise_for_response(response: httpx.Response) -> None:
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as exc:
|
||||
body = exc.response.text.strip()
|
||||
detail = body or f"status={exc.response.status_code}"
|
||||
raise LoadPlanToolRequestError(
|
||||
f"Load-plan API request failed for {exc.request.method} {exc.request.url}: {detail}"
|
||||
) from exc
|
||||
|
||||
|
||||
async def get_transport_vehicle_by_license_plate(
|
||||
*,
|
||||
merchant_id: int,
|
||||
license_plate: str,
|
||||
) -> TransportVehicle:
|
||||
"""Fetch a specific transport vehicle for one merchant by license plate."""
|
||||
encoded_plate = quote(license_plate, safe="")
|
||||
path = f"/api/v2/ai/transport-vehicles/{encoded_plate}/"
|
||||
|
||||
async with _build_client() as client:
|
||||
response = await client.get(
|
||||
path,
|
||||
params={"merchant_id": merchant_id},
|
||||
)
|
||||
|
||||
_raise_for_response(response)
|
||||
return TransportVehicle.model_validate(response.json())
|
||||
|
||||
|
||||
async def list_unshipped_shipments(
|
||||
*,
|
||||
merchant_id: int,
|
||||
area: str,
|
||||
limit: int | None = None,
|
||||
offset: int | None = None,
|
||||
) -> ShipmentPage:
|
||||
"""List unshipped shipments for one merchant and one exact area."""
|
||||
params: dict[str, int | str] = {
|
||||
"merchant_id": merchant_id,
|
||||
"area": area,
|
||||
}
|
||||
if limit is not None:
|
||||
params["limit"] = limit
|
||||
if offset is not None:
|
||||
params["offset"] = offset
|
||||
|
||||
async with _build_client() as client:
|
||||
response = await client.get(
|
||||
"/api/v2/ai/shipments/unshipped/",
|
||||
params=params,
|
||||
)
|
||||
|
||||
_raise_for_response(response)
|
||||
return ShipmentPage.model_validate(response.json())
|
||||
53
agent/load_plan_tools_cli.py
Normal file
53
agent/load_plan_tools_cli.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
from .load_plan_tools import (
|
||||
get_transport_vehicle_by_license_plate,
|
||||
list_unshipped_shipments,
|
||||
)
|
||||
|
||||
|
||||
def _build_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(description="Standalone CLI for load-plan API tools")
|
||||
subparsers = parser.add_subparsers(dest="command", required=True)
|
||||
|
||||
vehicle_parser = subparsers.add_parser("vehicle", help="Fetch a vehicle by license plate")
|
||||
vehicle_parser.add_argument("--merchant-id", type=int, required=True)
|
||||
vehicle_parser.add_argument("--license-plate", required=True)
|
||||
|
||||
shipments_parser = subparsers.add_parser("shipments", help="List unshipped shipments by area")
|
||||
shipments_parser.add_argument("--merchant-id", type=int, required=True)
|
||||
shipments_parser.add_argument("--area", required=True)
|
||||
shipments_parser.add_argument("--limit", type=int)
|
||||
shipments_parser.add_argument("--offset", type=int)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
async def _run() -> None:
|
||||
parser = _build_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command == "vehicle":
|
||||
result = await get_transport_vehicle_by_license_plate(
|
||||
merchant_id=args.merchant_id,
|
||||
license_plate=args.license_plate,
|
||||
)
|
||||
else:
|
||||
result = await list_unshipped_shipments(
|
||||
merchant_id=args.merchant_id,
|
||||
area=args.area,
|
||||
limit=args.limit,
|
||||
offset=args.offset,
|
||||
)
|
||||
|
||||
print(json.dumps(result.model_dump(mode="json"), ensure_ascii=False, indent=2))
|
||||
|
||||
|
||||
def main() -> None:
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
731
agent/route_plan.py
Normal file
731
agent/route_plan.py
Normal file
@@ -0,0 +1,731 @@
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from openai import AsyncOpenAI
|
||||
from pydantic_ai import Agent
|
||||
from pydantic_ai.exceptions import ModelRetry
|
||||
from pydantic_ai.mcp import MCPServerHTTP, MCPServerSSE, MCPServerStreamableHTTP
|
||||
from pydantic_ai.models.openai import OpenAIModel
|
||||
from pydantic_ai.providers.openai import OpenAIProvider
|
||||
|
||||
from schemas import DeepLinks, ResolvedPoint, RoutePlanRequest, RoutePlanResult
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class ConfigurationError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
class GuardrailError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
def _tool_args(**kwargs: Any) -> dict[str, Any]:
|
||||
return {key: value for key, value in kwargs.items() if value is not None}
|
||||
|
||||
|
||||
def _required_env(name: str) -> str:
|
||||
value = os.getenv(name, "").strip()
|
||||
if not value:
|
||||
raise ConfigurationError(f"Missing required environment variable: {name}")
|
||||
return value
|
||||
|
||||
|
||||
def _env_positive_int(name: str, default: int) -> int:
|
||||
raw_value = os.getenv(name, str(default)).strip()
|
||||
try:
|
||||
parsed = int(raw_value)
|
||||
except ValueError as exc:
|
||||
raise ConfigurationError(f"Environment variable {name} must be an integer") from exc
|
||||
|
||||
if parsed <= 0:
|
||||
raise ConfigurationError(f"Environment variable {name} must be greater than 0")
|
||||
return parsed
|
||||
|
||||
|
||||
def _env_positive_float(name: str, default: float) -> float:
|
||||
raw_value = os.getenv(name, str(default)).strip()
|
||||
try:
|
||||
parsed = float(raw_value)
|
||||
except ValueError as exc:
|
||||
raise ConfigurationError(f"Environment variable {name} must be a number") from exc
|
||||
|
||||
if parsed <= 0:
|
||||
raise ConfigurationError(f"Environment variable {name} must be greater than 0")
|
||||
return parsed
|
||||
|
||||
|
||||
def _build_model() -> OpenAIModel:
|
||||
openai_client = AsyncOpenAI(
|
||||
base_url=_required_env("ARK_BASE_URL"),
|
||||
api_key=_required_env("ARK_API_KEY"),
|
||||
timeout=_env_positive_float("ARK_REQUEST_TIMEOUT_SECONDS", 120.0),
|
||||
)
|
||||
provider = OpenAIProvider(
|
||||
openai_client=openai_client,
|
||||
)
|
||||
return OpenAIModel(
|
||||
_required_env("ARK_MODEL"),
|
||||
provider=provider,
|
||||
)
|
||||
|
||||
|
||||
def _build_mcp_headers() -> dict[str, str] | None:
|
||||
header_name = os.getenv("AMAP_MCP_AUTH_HEADER_NAME", "").strip()
|
||||
header_value = os.getenv("AMAP_MCP_AUTH_HEADER_VALUE", "").strip()
|
||||
|
||||
if header_name and header_value:
|
||||
return {header_name: header_value}
|
||||
if header_name or header_value:
|
||||
raise ConfigurationError(
|
||||
"AMAP_MCP_AUTH_HEADER_NAME and AMAP_MCP_AUTH_HEADER_VALUE must be set together"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _build_mcp_server() -> MCPServerHTTP | MCPServerSSE | MCPServerStreamableHTTP:
|
||||
transport = os.getenv("AMAP_MCP_TRANSPORT", "streamable_http").strip().lower()
|
||||
url = _required_env("AMAP_MCP_URL")
|
||||
headers = _build_mcp_headers()
|
||||
timeout = _env_positive_float("AMAP_MCP_TIMEOUT_SECONDS", 20.0)
|
||||
read_timeout = _env_positive_float("AMAP_MCP_READ_TIMEOUT_SECONDS", 60.0)
|
||||
|
||||
if transport == "sse":
|
||||
return MCPServerSSE(url=url, headers=headers, timeout=timeout, read_timeout=read_timeout)
|
||||
if transport == "http":
|
||||
return MCPServerHTTP(url=url, headers=headers, timeout=timeout, read_timeout=read_timeout)
|
||||
if transport == "streamable_http":
|
||||
return MCPServerStreamableHTTP(
|
||||
url=url,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
read_timeout=read_timeout,
|
||||
)
|
||||
|
||||
raise ConfigurationError(
|
||||
"AMAP_MCP_TRANSPORT must be one of: streamable_http, http, sse"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Agent
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SYSTEM_PROMPT = """\
|
||||
你是一个无状态的多目标地理路线规划 Agent。你的任务是使用高德地图 MCP 工具完成多目标点最优路线规划,并返回强类型结构化结果。\
|
||||
你必须显式调用地图工具,不得编造坐标、POI、距离、时长或 deep link。
|
||||
|
||||
你的工作流必须严格遵守以下规则:
|
||||
|
||||
1. 先解析输入,明确起点模式、终点、途经点、优化策略和输出需求。
|
||||
2. 对终点和每个途经点优先使用 maps_geo 解析地址;当命中不准或为空时,使用 maps_text_search 补齐 POI;必要时使用 maps_search_detail 校验。
|
||||
3. 当起点模式为 fixed 时,对起点也做同样解析。
|
||||
4. 当起点模式为 current_location 时,不得伪造起点坐标;如果缺少实时定位坐标,只能比较途经点内部顺序,并明确说明真实最优路线会受当前定位影响。
|
||||
5. 这套工具没有单步多点最优路径工具,因此你必须自己生成候选途经点顺序。
|
||||
6. 对每个候选顺序,逐段调用 maps_direction_driving 计算距离与时长,并汇总为候选路线。
|
||||
7. 按 route_strategy 选择最优路线:
|
||||
- shortest_distance:总里程优先,总时长次优
|
||||
- fastest_time:总时长优先,总里程次优
|
||||
- balanced:优先选择明显不劣的 Pareto 优势路线;若出现里程更短但时间略长的情况,默认优先里程更短并说明原因
|
||||
8. 如需 deep link:
|
||||
- 若目标是稳定导入整组点位,可使用 maps_schema_personal_map,但必须说明这是点位导入型链接
|
||||
- 若目标是"我的位置"出发的即时导航,应输出 route plan 类 deep link,并说明不同平台协议不同
|
||||
9. 最终结果必须包含:resolved_origin(如适用)、resolved_destination、resolved_stops、candidates、best_route、deep_links、summary、warnings。
|
||||
10. 如果任何地址无法可靠解析、任何 POI 无法获取、或任何路线比较存在信息缺失,必须如实说明,不得猜测。
|
||||
11. 你的输入会以 RoutePlanRequest JSON 形式提供。你必须基于该 JSON 进行规划,不能擅自补造字段。
|
||||
12. 在返回结果前,必须确认 best_route 来自 candidates 之一,且每一段 distance 和 duration 都来自工具调用。
|
||||
13. 如果输入中已经提供“预解析点位 JSON”,则这些点位是服务端已验证的唯一可信事实来源。此时不要再调用 maps_geo、maps_text_search、maps_search_detail,也不要重新挑选 POI。
|
||||
14. 当预解析点位已提供时,你只需要基于这些点位生成候选顺序、调用 maps_direction_driving 逐段计算、挑选 best_route,并产出 summary 与 warnings。
|
||||
15. 当预解析点位已提供时,deep_links 字段保持为 null,由服务端在结果通过校验后统一生成。
|
||||
|
||||
你的底层返回必须是结构化数据,不是 HTML。只有在用户明确要求页面展示时,才在结构化结果基础上额外生成 HTML。
|
||||
"""
|
||||
|
||||
|
||||
_SUB_POI_TOKENS = {
|
||||
"停车场",
|
||||
"入口",
|
||||
"出口",
|
||||
"住院部",
|
||||
"门诊",
|
||||
"外科",
|
||||
"内科",
|
||||
"体检",
|
||||
"饭堂",
|
||||
"发热门诊",
|
||||
"楼",
|
||||
"中心",
|
||||
}
|
||||
|
||||
|
||||
def _normalize_text(value: str | None) -> str:
|
||||
if not value:
|
||||
return ""
|
||||
return re.sub(r"[\s\-_,,。.;;::/\\()()\[\]【】{}·]", "", value).casefold()
|
||||
|
||||
|
||||
def _parse_location(location: str) -> tuple[float, float]:
|
||||
try:
|
||||
lon_text, lat_text = [item.strip() for item in location.split(",", 1)]
|
||||
return float(lon_text), float(lat_text)
|
||||
except Exception as exc:
|
||||
raise GuardrailError(f"Invalid location value returned by maps service: {location}") from exc
|
||||
|
||||
|
||||
def _haversine_m(lon1: float, lat1: float, lon2: float, lat2: float) -> float:
|
||||
radius_m = 6371000
|
||||
phi1 = math.radians(lat1)
|
||||
phi2 = math.radians(lat2)
|
||||
d_phi = math.radians(lat2 - lat1)
|
||||
d_lambda = math.radians(lon2 - lon1)
|
||||
|
||||
a = math.sin(d_phi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(d_lambda / 2) ** 2
|
||||
return 2 * radius_m * math.atan2(math.sqrt(a), math.sqrt(1 - a))
|
||||
|
||||
|
||||
def _score_poi_candidate(
|
||||
poi: dict[str, Any],
|
||||
*,
|
||||
query_text: str,
|
||||
address_text: str,
|
||||
raw_query: str,
|
||||
index: int,
|
||||
) -> int:
|
||||
name = _normalize_text(str(poi.get("name", "")))
|
||||
address = _normalize_text(str(poi.get("address", "")))
|
||||
score = 0
|
||||
|
||||
if query_text and name == query_text:
|
||||
score += 120
|
||||
elif query_text and query_text in name:
|
||||
score += 90
|
||||
elif query_text and name and name in query_text:
|
||||
score += 45
|
||||
|
||||
if address_text and address == address_text:
|
||||
score += 80
|
||||
elif address_text and address_text in address:
|
||||
score += 45
|
||||
|
||||
if poi.get("id"):
|
||||
score += 20
|
||||
|
||||
display_name = str(poi.get("name", ""))
|
||||
if any(token in display_name and token not in raw_query for token in _SUB_POI_TOKENS):
|
||||
score -= 25
|
||||
|
||||
score -= min(index, 10)
|
||||
return score
|
||||
|
||||
|
||||
def _select_precise_poi(
|
||||
pois: list[dict[str, Any]],
|
||||
*,
|
||||
input_name: str | None,
|
||||
input_address: str,
|
||||
) -> dict[str, Any]:
|
||||
if not pois:
|
||||
raise GuardrailError(f"No POI candidates found for address: {input_address}")
|
||||
|
||||
raw_query = input_name or input_address
|
||||
query_text = _normalize_text(raw_query)
|
||||
address_text = _normalize_text(input_address)
|
||||
|
||||
ranked = [
|
||||
(
|
||||
_score_poi_candidate(
|
||||
poi,
|
||||
query_text=query_text,
|
||||
address_text=address_text,
|
||||
raw_query=raw_query,
|
||||
index=index,
|
||||
),
|
||||
index,
|
||||
poi,
|
||||
)
|
||||
for index, poi in enumerate(pois)
|
||||
]
|
||||
ranked.sort(key=lambda item: (item[0], -item[1]), reverse=True)
|
||||
|
||||
best_score, _, best_poi = ranked[0]
|
||||
if best_score < 80:
|
||||
raise GuardrailError(
|
||||
f"POI precision is insufficient for address: {input_address}; best score={best_score}"
|
||||
)
|
||||
|
||||
if len(ranked) > 1:
|
||||
second_score, _, second_poi = ranked[1]
|
||||
if second_poi.get("id") != best_poi.get("id") and second_score >= best_score - 10:
|
||||
raise GuardrailError(
|
||||
f"POI match is ambiguous for address: {input_address}; multiple close candidates found"
|
||||
)
|
||||
|
||||
poi_id = str(best_poi.get("id") or "").strip()
|
||||
if not poi_id:
|
||||
raise GuardrailError(f"POI ID is missing for address: {input_address}")
|
||||
|
||||
return best_poi
|
||||
|
||||
|
||||
async def _call_mcp_tool(
|
||||
server: MCPServerHTTP | MCPServerSSE | MCPServerStreamableHTTP,
|
||||
name: str,
|
||||
args: dict[str, Any],
|
||||
) -> Any:
|
||||
try:
|
||||
return await server.direct_call_tool(name, args)
|
||||
except ModelRetry as exc:
|
||||
if "Timed out while waiting" in str(exc):
|
||||
raise TimeoutError(str(exc)) from exc
|
||||
raise GuardrailError(f"MCP tool call failed for {name}: {exc}") from exc
|
||||
|
||||
|
||||
def _closest_geo_result(
|
||||
geo_results: list[dict[str, Any]],
|
||||
*,
|
||||
lon: float,
|
||||
lat: float,
|
||||
) -> dict[str, Any] | None:
|
||||
closest_result: dict[str, Any] | None = None
|
||||
closest_distance: float | None = None
|
||||
|
||||
for item in geo_results:
|
||||
location = str(item.get("location") or "").strip()
|
||||
if not location:
|
||||
continue
|
||||
candidate_lon, candidate_lat = _parse_location(location)
|
||||
distance = _haversine_m(lon, lat, candidate_lon, candidate_lat)
|
||||
if closest_distance is None or distance < closest_distance:
|
||||
closest_distance = distance
|
||||
closest_result = item
|
||||
|
||||
if closest_distance is not None and closest_distance > 500:
|
||||
raise GuardrailError("Geocode and POI detail locations disagree beyond acceptable precision")
|
||||
|
||||
return closest_result
|
||||
|
||||
|
||||
async def _resolve_point(
|
||||
server: MCPServerHTTP | MCPServerSSE | MCPServerStreamableHTTP,
|
||||
*,
|
||||
role: str,
|
||||
input_name: str | None,
|
||||
input_address: str,
|
||||
city: str | None,
|
||||
) -> ResolvedPoint:
|
||||
geo_response = await _call_mcp_tool(
|
||||
server,
|
||||
"maps_geo",
|
||||
_tool_args(address=input_address, city=city),
|
||||
)
|
||||
text_search_response = await _call_mcp_tool(
|
||||
server,
|
||||
"maps_text_search",
|
||||
_tool_args(keywords=input_name or input_address, city=city),
|
||||
)
|
||||
|
||||
geo_results = list((geo_response or {}).get("results") or [])
|
||||
pois = list((text_search_response or {}).get("pois") or [])
|
||||
selected_poi = _select_precise_poi(pois, input_name=input_name, input_address=input_address)
|
||||
|
||||
detail = await _call_mcp_tool(
|
||||
server,
|
||||
"maps_search_detail",
|
||||
{"id": selected_poi["id"]},
|
||||
)
|
||||
|
||||
detail_location = str(detail.get("location") or "").strip()
|
||||
if not detail_location:
|
||||
raise GuardrailError(f"Resolved POI detail has no location for address: {input_address}")
|
||||
|
||||
lon, lat = _parse_location(detail_location)
|
||||
closest_geo = _closest_geo_result(geo_results, lon=lon, lat=lat) if geo_results else None
|
||||
district = None
|
||||
if closest_geo is not None:
|
||||
district = str(closest_geo.get("district") or "").strip() or None
|
||||
|
||||
resolved_name = str(detail.get("name") or selected_poi.get("name") or "").strip()
|
||||
poi_id = str(detail.get("id") or selected_poi.get("id") or "").strip()
|
||||
resolved_city = str(detail.get("city") or city or "").strip() or city
|
||||
|
||||
if not resolved_name:
|
||||
raise GuardrailError(f"Resolved POI has no name for address: {input_address}")
|
||||
if not poi_id:
|
||||
raise GuardrailError(f"Resolved POI has no poi_id for address: {input_address}")
|
||||
|
||||
return ResolvedPoint(
|
||||
role=role, # type: ignore[arg-type]
|
||||
input_name=input_name,
|
||||
input_address=input_address,
|
||||
resolved_name=resolved_name,
|
||||
city=resolved_city,
|
||||
district=district,
|
||||
location=detail_location,
|
||||
lon=lon,
|
||||
lat=lat,
|
||||
poi_id=poi_id,
|
||||
source="search_detail",
|
||||
confidence_note="Prevalidated exact POI with confirmed poi_id for deep link generation",
|
||||
)
|
||||
|
||||
|
||||
async def _resolve_request_points(
|
||||
request: RoutePlanRequest,
|
||||
) -> tuple[ResolvedPoint | None, list[ResolvedPoint], ResolvedPoint]:
|
||||
cache: dict[tuple[str | None, str, str | None], ResolvedPoint] = {}
|
||||
server = _build_mcp_server()
|
||||
|
||||
async with server:
|
||||
async def resolve_cached(
|
||||
*,
|
||||
role: str,
|
||||
input_name: str | None,
|
||||
input_address: str,
|
||||
city: str | None,
|
||||
) -> ResolvedPoint:
|
||||
cache_key = (input_name, input_address, city)
|
||||
cached = cache.get(cache_key)
|
||||
if cached is None:
|
||||
cached = await _resolve_point(
|
||||
server,
|
||||
role=role,
|
||||
input_name=input_name,
|
||||
input_address=input_address,
|
||||
city=city,
|
||||
)
|
||||
cache[cache_key] = cached
|
||||
|
||||
return cached.model_copy(
|
||||
update={
|
||||
"role": role,
|
||||
"input_name": input_name,
|
||||
"input_address": input_address,
|
||||
}
|
||||
)
|
||||
|
||||
resolved_origin: ResolvedPoint | None = None
|
||||
if request.origin_mode == "fixed" and request.origin_address is not None:
|
||||
resolved_origin = await resolve_cached(
|
||||
role="origin",
|
||||
input_name=request.origin_name,
|
||||
input_address=request.origin_address,
|
||||
city=request.origin_city,
|
||||
)
|
||||
|
||||
resolved_stops = [
|
||||
await resolve_cached(
|
||||
role="stop",
|
||||
input_name=stop.name,
|
||||
input_address=stop.address,
|
||||
city=stop.city,
|
||||
)
|
||||
for stop in request.stops
|
||||
]
|
||||
|
||||
resolved_destination = await resolve_cached(
|
||||
role="destination",
|
||||
input_name=request.destination_name,
|
||||
input_address=request.destination_address,
|
||||
city=request.destination_city,
|
||||
)
|
||||
|
||||
return resolved_origin, resolved_stops, resolved_destination
|
||||
|
||||
|
||||
def _resolved_points_payload(
|
||||
*,
|
||||
resolved_origin: ResolvedPoint | None,
|
||||
resolved_stops: list[ResolvedPoint],
|
||||
resolved_destination: ResolvedPoint,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"resolved_origin": resolved_origin.model_dump(mode="json") if resolved_origin else None,
|
||||
"resolved_stops": [point.model_dump(mode="json") for point in resolved_stops],
|
||||
"resolved_destination": resolved_destination.model_dump(mode="json"),
|
||||
}
|
||||
|
||||
|
||||
def _point_label_candidates(point: ResolvedPoint) -> list[str]:
|
||||
labels = [point.resolved_name]
|
||||
if point.input_name:
|
||||
labels.append(point.input_name)
|
||||
labels.append(point.input_address)
|
||||
deduplicated: list[str] = []
|
||||
for label in labels:
|
||||
if label not in deduplicated:
|
||||
deduplicated.append(label)
|
||||
return deduplicated
|
||||
|
||||
|
||||
def _ordered_points_for_best_route(
|
||||
request: RoutePlanRequest,
|
||||
result: RoutePlanResult,
|
||||
*,
|
||||
resolved_origin: ResolvedPoint | None,
|
||||
resolved_stops: list[ResolvedPoint],
|
||||
resolved_destination: ResolvedPoint,
|
||||
) -> list[ResolvedPoint]:
|
||||
ordered_points: list[ResolvedPoint] = []
|
||||
if request.origin_mode == "fixed" and resolved_origin is not None:
|
||||
ordered_points.append(resolved_origin)
|
||||
|
||||
remaining_stops = resolved_stops.copy()
|
||||
for label in result.best_route.stop_order_labels:
|
||||
normalized_label = _normalize_text(label)
|
||||
match_index = next(
|
||||
(
|
||||
index
|
||||
for index, point in enumerate(remaining_stops)
|
||||
if normalized_label in {_normalize_text(candidate) for candidate in _point_label_candidates(point)}
|
||||
),
|
||||
None,
|
||||
)
|
||||
if match_index is None:
|
||||
raise GuardrailError(f"Unable to map best_route stop label back to resolved stop: {label}")
|
||||
ordered_points.append(remaining_stops.pop(match_index))
|
||||
|
||||
ordered_points.append(resolved_destination)
|
||||
return ordered_points
|
||||
|
||||
|
||||
async def _build_required_deep_links(
|
||||
request: RoutePlanRequest,
|
||||
result: RoutePlanResult,
|
||||
*,
|
||||
resolved_origin: ResolvedPoint | None,
|
||||
resolved_stops: list[ResolvedPoint],
|
||||
resolved_destination: ResolvedPoint,
|
||||
) -> DeepLinks:
|
||||
if not request.need_deep_link:
|
||||
raise GuardrailError("need_deep_link must be true because this service requires deep link output")
|
||||
|
||||
ordered_points = _ordered_points_for_best_route(
|
||||
request,
|
||||
result,
|
||||
resolved_origin=resolved_origin,
|
||||
resolved_stops=resolved_stops,
|
||||
resolved_destination=resolved_destination,
|
||||
)
|
||||
|
||||
for point in ordered_points:
|
||||
if not point.poi_id:
|
||||
raise GuardrailError(
|
||||
f"Deep link generation requires poi_id for every point; missing poi_id for {point.input_address}"
|
||||
)
|
||||
|
||||
server = _build_mcp_server()
|
||||
async with server:
|
||||
personal_map = await _call_mcp_tool(
|
||||
server,
|
||||
"maps_schema_personal_map",
|
||||
{
|
||||
"orgName": "geo-agent",
|
||||
"lineList": [
|
||||
{
|
||||
"title": request.task_name,
|
||||
"pointInfoList": [
|
||||
{
|
||||
"name": point.resolved_name,
|
||||
"lon": point.lon,
|
||||
"lat": point.lat,
|
||||
"poiId": point.poi_id,
|
||||
}
|
||||
for point in ordered_points
|
||||
],
|
||||
}
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
if not isinstance(personal_map, str) or not personal_map.strip():
|
||||
raise GuardrailError("Deep link generation failed: maps_schema_personal_map returned no usable URI")
|
||||
|
||||
return DeepLinks(personal_map=personal_map.strip())
|
||||
|
||||
|
||||
def _configured_max_permutations() -> int:
|
||||
return _env_positive_int("ROUTE_MAX_PERMUTATIONS", 20)
|
||||
|
||||
|
||||
def _candidate_count(stop_count: int) -> int:
|
||||
return math.factorial(stop_count)
|
||||
|
||||
|
||||
def _prepare_request(request: RoutePlanRequest) -> RoutePlanRequest:
|
||||
if not request.need_deep_link:
|
||||
raise GuardrailError("need_deep_link must be true because this service requires deep link output")
|
||||
|
||||
configured_limit = _configured_max_permutations()
|
||||
effective_limit = request.max_permutations or configured_limit
|
||||
|
||||
if effective_limit > configured_limit:
|
||||
raise GuardrailError(
|
||||
"Requested max_permutations exceeds the configured service limit: "
|
||||
f"requested={effective_limit}, configured={configured_limit}"
|
||||
)
|
||||
|
||||
candidate_count = _candidate_count(len(request.stops))
|
||||
if candidate_count > effective_limit:
|
||||
raise GuardrailError(
|
||||
"Candidate permutations exceed the configured limit: "
|
||||
f"stops={len(request.stops)}, permutations={candidate_count}, limit={effective_limit}"
|
||||
)
|
||||
|
||||
return request.model_copy(update={"max_permutations": effective_limit})
|
||||
|
||||
|
||||
def _build_user_prompt(
|
||||
request: RoutePlanRequest,
|
||||
*,
|
||||
resolved_origin: ResolvedPoint | None,
|
||||
resolved_stops: list[ResolvedPoint],
|
||||
resolved_destination: ResolvedPoint,
|
||||
) -> str:
|
||||
request_json = json.dumps(request.model_dump(mode="json"), ensure_ascii=False, indent=2)
|
||||
resolved_points_json = json.dumps(
|
||||
_resolved_points_payload(
|
||||
resolved_origin=resolved_origin,
|
||||
resolved_stops=resolved_stops,
|
||||
resolved_destination=resolved_destination,
|
||||
),
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
return (
|
||||
"请根据下面的 RoutePlanRequest JSON 执行多目标路线规划。\n"
|
||||
"所有点位已经由服务端完成严格解析和 POI 校验,且这些点位是唯一可信输入。\n"
|
||||
"不要再调用 maps_geo、maps_text_search、maps_search_detail,也不要生成 deep link。\n"
|
||||
"你必须基于这些已解析点位,仅使用 maps_direction_driving 计算逐段路线,并选择 best_route。\n"
|
||||
"如果输入中存在固定起点,则完整路线必须从起点开始;如果起点模式是 current_location,则不得伪造起点坐标。\n"
|
||||
f"本次运行的候选顺序硬上限是 {request.max_permutations}。\n\n"
|
||||
"RoutePlanRequest JSON:\n"
|
||||
f"{request_json}\n\n"
|
||||
"Pre-resolved Points JSON:\n"
|
||||
f"{resolved_points_json}"
|
||||
)
|
||||
|
||||
|
||||
def _same_candidate(left_candidate, right_candidate) -> bool:
|
||||
return (
|
||||
left_candidate.full_order_labels == right_candidate.full_order_labels
|
||||
and left_candidate.total_distance_m == right_candidate.total_distance_m
|
||||
and left_candidate.total_duration_s == right_candidate.total_duration_s
|
||||
and len(left_candidate.legs) == len(right_candidate.legs)
|
||||
)
|
||||
|
||||
|
||||
def _validate_result(request: RoutePlanRequest, result: RoutePlanResult) -> RoutePlanResult:
|
||||
if not result.success:
|
||||
raise GuardrailError("Route planning did not complete successfully")
|
||||
|
||||
if result.origin_mode != request.origin_mode:
|
||||
raise GuardrailError("Agent output origin_mode does not match the request")
|
||||
|
||||
if result.resolved_destination.role != "destination":
|
||||
raise GuardrailError("resolved_destination.role must be 'destination'")
|
||||
|
||||
if len(result.resolved_stops) != len(request.stops):
|
||||
raise GuardrailError("resolved_stops count does not match the request")
|
||||
|
||||
if any(stop.role != "stop" for stop in result.resolved_stops):
|
||||
raise GuardrailError("All resolved_stops entries must have role='stop'")
|
||||
|
||||
if request.origin_mode == "fixed":
|
||||
if result.resolved_origin is None:
|
||||
raise GuardrailError("resolved_origin is required when origin_mode='fixed'")
|
||||
if result.resolved_origin.role != "origin":
|
||||
raise GuardrailError("resolved_origin.role must be 'origin'")
|
||||
else:
|
||||
if result.resolved_origin is not None:
|
||||
raise GuardrailError("resolved_origin must be null when origin_mode='current_location'")
|
||||
|
||||
if not result.candidates:
|
||||
raise GuardrailError("A successful result must include at least one candidate route")
|
||||
|
||||
if len(result.candidates) > (request.max_permutations or 0):
|
||||
raise GuardrailError("Agent returned more candidates than allowed by max_permutations")
|
||||
|
||||
matching_candidate = next(
|
||||
(
|
||||
candidate
|
||||
for candidate in result.candidates
|
||||
if _same_candidate(candidate, result.best_route)
|
||||
),
|
||||
None,
|
||||
)
|
||||
if matching_candidate is None:
|
||||
raise GuardrailError("best_route must be one of the candidates")
|
||||
|
||||
if result.deep_links is None:
|
||||
raise GuardrailError("A successful result must include deep_links")
|
||||
|
||||
if not any(
|
||||
[
|
||||
result.deep_links.personal_map,
|
||||
result.deep_links.android_route_plan,
|
||||
result.deep_links.ios_route_plan,
|
||||
]
|
||||
):
|
||||
raise GuardrailError("A successful result must include at least one deep link")
|
||||
|
||||
all_points = [*result.resolved_stops, result.resolved_destination]
|
||||
if result.resolved_origin is not None:
|
||||
all_points.append(result.resolved_origin)
|
||||
if any(not point.poi_id for point in all_points):
|
||||
raise GuardrailError("All resolved points must include poi_id on successful results")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def create_geo_agent() -> Agent[None, RoutePlanResult]:
|
||||
return Agent(
|
||||
model=_build_model(),
|
||||
toolsets=[_build_mcp_server()],
|
||||
output_type=RoutePlanResult,
|
||||
system_prompt=SYSTEM_PROMPT,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def run_route_plan(request: RoutePlanRequest) -> RoutePlanResult:
|
||||
"""Execute the route planning agent for a given request."""
|
||||
prepared_request = _prepare_request(request)
|
||||
resolved_origin, resolved_stops, resolved_destination = await _resolve_request_points(prepared_request)
|
||||
agent = create_geo_agent()
|
||||
async with agent:
|
||||
result = await agent.run(
|
||||
_build_user_prompt(
|
||||
prepared_request,
|
||||
resolved_origin=resolved_origin,
|
||||
resolved_stops=resolved_stops,
|
||||
resolved_destination=resolved_destination,
|
||||
)
|
||||
)
|
||||
|
||||
output = result.output.model_copy(
|
||||
update={
|
||||
"resolved_origin": resolved_origin,
|
||||
"resolved_stops": resolved_stops,
|
||||
"resolved_destination": resolved_destination,
|
||||
}
|
||||
)
|
||||
output = output.model_copy(
|
||||
update={
|
||||
"deep_links": await _build_required_deep_links(
|
||||
prepared_request,
|
||||
output,
|
||||
resolved_origin=resolved_origin,
|
||||
resolved_stops=resolved_stops,
|
||||
resolved_destination=resolved_destination,
|
||||
)
|
||||
}
|
||||
)
|
||||
return _validate_result(prepared_request, output)
|
||||
Reference in New Issue
Block a user