Files
geo-setp/agent/load_plan.py
2026-04-15 14:05:33 +08:00

108 lines
3.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import json
from pydantic_ai import Agent
from pydantic_ai_skills import SkillsToolset
from openai import AsyncOpenAI
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.openai import OpenAIProvider
from schemas import LoadPlanRequest, LoadPlanResult
from .load_plan_tools import (
get_transport_vehicle_by_license_plate,
list_unshipped_shipments,
)
class LoadPlanConfigurationError(RuntimeError):
    """Raised when the load-plan agent's environment configuration is missing or invalid."""
class LoadPlanGuardrailError(RuntimeError):
    """Error type for load-plan guardrail failures.

    NOTE(review): no raise site is visible in this module; presumably it is
    raised by callers or sibling modules when a guardrail check rejects a
    load plan — confirm against actual usage.
    """
def _required_env(name: str) -> str:
import os
value = os.getenv(name, "").strip()
if not value:
raise LoadPlanConfigurationError(f"Missing required environment variable: {name}")
return value
def _env_positive_float(name: str, default: float) -> float:
import os
raw_value = os.getenv(name, str(default)).strip()
try:
parsed = float(raw_value)
except ValueError as exc:
raise LoadPlanConfigurationError(f"Environment variable {name} must be a number") from exc
if parsed <= 0:
raise LoadPlanConfigurationError(f"Environment variable {name} must be greater than 0")
return parsed
def _build_model() -> OpenAIModel:
    """Build the OpenAI-compatible model used by the load-plan agent.

    Environment variables are read in this order:
    - LOAD_PLAN_BASE_URL
    - LOAD_PLAN_API_KEY
    - LOAD_PLAN_REQUEST_TIMEOUT_SECONDS (optional, default 120.0)
    - LOAD_PLAN_MODEL

    Raises:
        LoadPlanConfigurationError: if a required variable is missing or the
            timeout is invalid.

    NOTE(review): every call creates a new AsyncOpenAI client, and nothing
    here closes it. Confirm whether pydantic-ai closes a caller-supplied
    client. Also confirm whether the pinned pydantic-ai version has renamed
    OpenAIModel (e.g. to OpenAIChatModel).
    """
    base_url = _required_env("LOAD_PLAN_BASE_URL")
    api_key = _required_env("LOAD_PLAN_API_KEY")
    timeout_seconds = _env_positive_float("LOAD_PLAN_REQUEST_TIMEOUT_SECONDS", 120.0)
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, timeout=timeout_seconds)
    provider = OpenAIProvider(openai_client=client)
    model_name = _required_env("LOAD_PLAN_MODEL")
    return OpenAIModel(model_name, provider=provider)
# System prompt for the load-planning agent, written in Chinese. In English it
# tells the model to:
# - act as a load-planning agent;
# - use merchant_id, area and license_plate from the request to call tools for
#   pending shipments and vehicle capacity;
# - strictly follow the loaded skill when selecting shipments;
# - fetch shipments before vehicle capacity;
# - count capacity in sales-item lines, not in quantity (meters);
# - load as much as possible without exceeding capacity, and allow an
#   underload only with a warning;
# - never invent sales_items, printing_job_width, fabric type, quantity or
#   capacity;
# - if the real tool results cannot determine the outcome, return an empty
#   selected_shipment_ids and explain why in summary/warnings.
#
# The fragments are joined by implicit string concatenation. Any fragment
# that continues a sentence on the next line must end with its own separator,
# otherwise the clauses run together in the final prompt.
INSTRUCTIONS = (
    "你是一个装载规划 Agent。"
    "你的任务是根据请求中的 merchant_id、area 和 license_plate,"
    "调用工具获取待出货出货单与车辆容量,并严格遵守已加载的 skill 来选择出货单。"
    "你必须先获取出货单,再获取车辆容量。"
    "车辆容量按销售品条数计算,不按 quantity 米数计算。"
    "目标是在不超过车辆容量的前提下尽量多装;若无法满载,允许欠载,但必须给出 warning。"
    "不要臆造 sales_items、printing_job_width、面料类型、数量或容量。"
    "如果无法依据真实返回判断装载结果,就返回空的 selected_shipment_ids,并在 summary/warnings 中说明。"
)
def _build_user_prompt(request: LoadPlanRequest) -> str:
    """Render the user message for a single load-planning run.

    The request is dumped in JSON mode and pretty-printed (2-space indent,
    non-ASCII characters kept as-is). It follows a fixed preamble telling the
    model to call the tools for real data and to apply the skill's rules.
    """
    payload = request.model_dump(mode="json")
    body = json.dumps(payload, ensure_ascii=False, indent=2)
    preamble = "".join(
        [
            "请根据下面的 LoadPlanRequest JSON 执行装载规划。\n",
            "你必须调用工具获取真实数据,并根据 skill 中的规则完成判断。\n\n",
            "LoadPlanRequest JSON:\n",
        ]
    )
    return f"{preamble}{body}"
def create_load_plan_agent() -> Agent[None, LoadPlanResult]:
    """Assemble a fresh load-planning agent.

    The agent is built from:
    - the model returned by ``_build_model`` (built first, so configuration
      errors surface before any tooling is set up);
    - the module-level ``INSTRUCTIONS`` prompt;
    - ``LoadPlanResult`` as the structured output type;
    - the two load-plan tools;
    - a SkillsToolset loading ``./skills`` with auto-reload enabled and the
      ``run_skill_script`` tool excluded (presumably so the model cannot
      execute skill scripts).

    NOTE(review): ``./skills`` is resolved against the process working
    directory, not this file's location. Confirm the service always starts
    from the expected directory.
    """
    model = _build_model()
    tools = [get_transport_vehicle_by_license_plate, list_unshipped_shipments]
    skills_toolset = SkillsToolset(
        directories=["./skills"],
        auto_reload=True,
        exclude_tools={"run_skill_script"},
    )
    return Agent(
        model=model,
        instructions=INSTRUCTIONS,
        output_type=LoadPlanResult,
        tools=tools,
        toolsets=[skills_toolset],
    )
async def run_load_plan(request: LoadPlanRequest) -> LoadPlanResult:
    """Run one load-planning pass for *request* and return the structured output.

    A new agent is built for each call. It is entered as an async context
    manager for the duration of the run.

    NOTE(review): because the agent (and, via ``_build_model``, its OpenAI
    client) is rebuilt per request, consider caching it. Confirm whether
    leaving the ``async with`` block releases the model's HTTP client.
    """
    agent = create_load_plan_agent()
    async with agent:
        prompt = _build_user_prompt(request)
        run_result = await agent.run(prompt)
        output = run_result.output
    return output