108 lines
3.4 KiB
Python
108 lines
3.4 KiB
Python
import json
|
||
|
||
from pydantic_ai import Agent
|
||
from pydantic_ai_skills import SkillsToolset
|
||
from openai import AsyncOpenAI
|
||
from pydantic_ai.models.openai import OpenAIModel
|
||
from pydantic_ai.providers.openai import OpenAIProvider
|
||
|
||
from schemas import LoadPlanRequest, LoadPlanResult
|
||
from .load_plan_tools import (
|
||
get_transport_vehicle_by_license_plate,
|
||
list_unshipped_shipments,
|
||
)
|
||
|
||
|
||
class LoadPlanConfigurationError(RuntimeError):
    """Raised when the load-plan agent cannot be configured from the environment.

    ``_required_env`` raises it for a missing or blank variable, and
    ``_env_positive_float`` raises it for a non-numeric or non-positive value.
    """
|
||
|
||
|
||
class LoadPlanGuardrailError(RuntimeError):
    """Error type for load-plan guardrail violations.

    NOTE(review): nothing in this module raises it. Presumably other code
    (e.g. validation of the agent's ``LoadPlanResult``) raises it; confirm
    against the callers.
    """
|
||
|
||
|
||
def _required_env(name: str) -> str:
|
||
import os
|
||
|
||
value = os.getenv(name, "").strip()
|
||
if not value:
|
||
raise LoadPlanConfigurationError(f"Missing required environment variable: {name}")
|
||
return value
|
||
|
||
|
||
def _env_positive_float(name: str, default: float) -> float:
|
||
import os
|
||
|
||
raw_value = os.getenv(name, str(default)).strip()
|
||
try:
|
||
parsed = float(raw_value)
|
||
except ValueError as exc:
|
||
raise LoadPlanConfigurationError(f"Environment variable {name} must be a number") from exc
|
||
|
||
if parsed <= 0:
|
||
raise LoadPlanConfigurationError(f"Environment variable {name} must be greater than 0")
|
||
return parsed
|
||
|
||
|
||
def _build_model() -> OpenAIModel:
    """Build the OpenAI-compatible chat model used by the load-plan agent.

    Configuration comes from the environment:

    - ``LOAD_PLAN_BASE_URL`` (required): ``base_url`` for ``AsyncOpenAI``.
    - ``LOAD_PLAN_API_KEY`` (required): ``api_key`` for ``AsyncOpenAI``.
    - ``LOAD_PLAN_REQUEST_TIMEOUT_SECONDS`` (optional, default ``120.0``):
      ``timeout`` for ``AsyncOpenAI``.
    - ``LOAD_PLAN_MODEL`` (required): model name given to ``OpenAIModel``.

    Raises:
        LoadPlanConfigurationError: if a required variable is missing or
            blank, or the timeout is not a valid positive number.
    """
    # Settings are read in the same order they are used, so the first
    # misconfigured variable is the one reported.
    base_url = _required_env("LOAD_PLAN_BASE_URL")
    api_key = _required_env("LOAD_PLAN_API_KEY")
    timeout_seconds = _env_positive_float("LOAD_PLAN_REQUEST_TIMEOUT_SECONDS", 120.0)

    # NOTE(review): every call creates a brand-new AsyncOpenAI client; confirm
    # something closes it, or reuse one client across runs.
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, timeout=timeout_seconds)
    openai_provider = OpenAIProvider(openai_client=client)

    # NOTE(review): newer pydantic-ai releases rename OpenAIModel to
    # OpenAIChatModel; check against the installed version.
    model_name = _required_env("LOAD_PLAN_MODEL")
    return OpenAIModel(model_name, provider=openai_provider)
|
||
|
||
|
||
# System instructions for the load-plan agent. The prompt text is Chinese on
# purpose; each fragment has an English summary above it. Fragments are joined
# with no separator, so every sentence carries its own punctuation.
INSTRUCTIONS = "".join(
    (
        # Role: you are a load-planning agent.
        "你是一个装载规划 Agent。",
        # Task: based on merchant_id, area and license_plate in the request, ...
        "你的任务是根据请求中的 merchant_id、area 和 license_plate,",
        # ... call the tools to fetch pending shipments and vehicle capacity,
        # and strictly follow the loaded skill when selecting shipments.
        "调用工具获取待出货出货单与车辆容量,并严格遵守已加载的 skill 来选择出货单。",
        # Order: fetch shipments first, then the vehicle capacity.
        "你必须先获取出货单,再获取车辆容量。",
        # Capacity is counted in sales-item lines, not in quantity (meters).
        "车辆容量按销售品条数计算,不按 quantity 米数计算。",
        # Goal: load as much as possible without exceeding capacity; under-loading
        # is allowed but must be reported as a warning.
        "目标是在不超过车辆容量的前提下尽量多装;若无法满载,允许欠载,但必须给出 warning。",
        # Never invent sales_items, printing_job_width, fabric type, quantity or capacity.
        "不要臆造 sales_items、printing_job_width、面料类型、数量或容量。",
        # If the real tool data cannot decide the result, return empty
        # selected_shipment_ids and explain why in summary/warnings.
        "如果无法依据真实返回判断装载结果,就返回空的 selected_shipment_ids,并在 summary/warnings 中说明。",
    )
)
|
||
|
||
|
||
def _build_user_prompt(request: LoadPlanRequest) -> str:
|
||
request_json = json.dumps(request.model_dump(mode="json"), ensure_ascii=False, indent=2)
|
||
return (
|
||
"请根据下面的 LoadPlanRequest JSON 执行装载规划。\n"
|
||
"你必须调用工具获取真实数据,并根据 skill 中的规则完成判断。\n\n"
|
||
"LoadPlanRequest JSON:\n"
|
||
f"{request_json}"
|
||
)
|
||
|
||
|
||
def create_load_plan_agent() -> Agent[None, LoadPlanResult]:
    """Create a load-plan agent with structured ``LoadPlanResult`` output.

    The agent gets:

    - a model built by ``_build_model()``, so the environment is re-read on
      every call;
    - ``INSTRUCTIONS`` as its instructions;
    - the two data tools (vehicle lookup by license plate, unshipped
      shipments);
    - a ``SkillsToolset`` loaded from ``./skills``, with its
      ``run_skill_script`` tool excluded.

    Raises:
        LoadPlanConfigurationError: propagated from ``_build_model()`` on
            missing or invalid configuration.
    """
    # The model is built first, matching the original evaluation order, so
    # configuration errors surface before any toolset is constructed.
    model = _build_model()
    data_tools = [get_transport_vehicle_by_license_plate, list_unshipped_shipments]
    # NOTE(review): "./skills" is a relative path, presumably resolved against
    # the process working directory rather than this module; confirm the
    # service always starts from the project root. auto_reload=True may also
    # be a development-only setting; check its cost in production.
    skills_toolset = SkillsToolset(
        directories=["./skills"],
        auto_reload=True,
        exclude_tools={"run_skill_script"},
    )
    return Agent(
        model=model,
        instructions=INSTRUCTIONS,
        output_type=LoadPlanResult,
        tools=data_tools,
        toolsets=[skills_toolset],
    )
|
||
|
||
|
||
async def run_load_plan(request: LoadPlanRequest) -> LoadPlanResult:
    """Run one load-planning request end to end and return its structured output.

    A fresh agent is created for this call. Its async context is entered for
    the duration of the run, and the ``output`` of the run result is
    returned.

    NOTE(review): each call builds a new agent and, through
    ``_build_model()``, a new AsyncOpenAI client. Whether leaving
    ``async with agent`` closes that client depends on the pydantic-ai
    version; if it does not, reuse the agent or close the client explicitly.
    """
    load_plan_agent = create_load_plan_agent()
    async with load_plan_agent:
        run_result = await load_plan_agent.run(_build_user_prompt(request))
        return run_result.output
|