feat: auth
This commit is contained in:
107
agent/load_plan.py
Normal file
107
agent/load_plan.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import json
|
||||
|
||||
from pydantic_ai import Agent
|
||||
from pydantic_ai_skills import SkillsToolset
|
||||
from openai import AsyncOpenAI
|
||||
from pydantic_ai.models.openai import OpenAIModel
|
||||
from pydantic_ai.providers.openai import OpenAIProvider
|
||||
|
||||
from schemas import LoadPlanRequest, LoadPlanResult
|
||||
from .load_plan_tools import (
|
||||
get_transport_vehicle_by_license_plate,
|
||||
list_unshipped_shipments,
|
||||
)
|
||||
|
||||
|
||||
class LoadPlanConfigurationError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
class LoadPlanGuardrailError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
def _required_env(name: str) -> str:
|
||||
import os
|
||||
|
||||
value = os.getenv(name, "").strip()
|
||||
if not value:
|
||||
raise LoadPlanConfigurationError(f"Missing required environment variable: {name}")
|
||||
return value
|
||||
|
||||
|
||||
def _env_positive_float(name: str, default: float) -> float:
|
||||
import os
|
||||
|
||||
raw_value = os.getenv(name, str(default)).strip()
|
||||
try:
|
||||
parsed = float(raw_value)
|
||||
except ValueError as exc:
|
||||
raise LoadPlanConfigurationError(f"Environment variable {name} must be a number") from exc
|
||||
|
||||
if parsed <= 0:
|
||||
raise LoadPlanConfigurationError(f"Environment variable {name} must be greater than 0")
|
||||
return parsed
|
||||
|
||||
|
||||
def _build_model() -> OpenAIModel:
|
||||
openai_client = AsyncOpenAI(
|
||||
base_url=_required_env("LOAD_PLAN_BASE_URL"),
|
||||
api_key=_required_env("LOAD_PLAN_API_KEY"),
|
||||
timeout=_env_positive_float("LOAD_PLAN_REQUEST_TIMEOUT_SECONDS", 120.0),
|
||||
)
|
||||
provider = OpenAIProvider(
|
||||
openai_client=openai_client,
|
||||
)
|
||||
return OpenAIModel(
|
||||
_required_env("LOAD_PLAN_MODEL"),
|
||||
provider=provider,
|
||||
)
|
||||
|
||||
|
||||
INSTRUCTIONS = (
|
||||
"你是一个装载规划 Agent。"
|
||||
"你的任务是根据请求中的 merchant_id、area 和 license_plate,"
|
||||
"调用工具获取待出货出货单与车辆容量,并严格遵守已加载的 skill 来选择出货单。"
|
||||
"你必须先获取出货单,再获取车辆容量。"
|
||||
"车辆容量按销售品条数计算,不按 quantity 米数计算。"
|
||||
"目标是在不超过车辆容量的前提下尽量多装;若无法满载,允许欠载,但必须给出 warning。"
|
||||
"不要臆造 sales_items、printing_job_width、面料类型、数量或容量。"
|
||||
"如果无法依据真实返回判断装载结果,就返回空的 selected_shipment_ids,并在 summary/warnings 中说明。"
|
||||
)
|
||||
|
||||
|
||||
def _build_user_prompt(request: LoadPlanRequest) -> str:
|
||||
request_json = json.dumps(request.model_dump(mode="json"), ensure_ascii=False, indent=2)
|
||||
return (
|
||||
"请根据下面的 LoadPlanRequest JSON 执行装载规划。\n"
|
||||
"你必须调用工具获取真实数据,并根据 skill 中的规则完成判断。\n\n"
|
||||
"LoadPlanRequest JSON:\n"
|
||||
f"{request_json}"
|
||||
)
|
||||
|
||||
|
||||
def create_load_plan_agent() -> Agent[None, LoadPlanResult]:
|
||||
return Agent(
|
||||
model=_build_model(),
|
||||
instructions=INSTRUCTIONS,
|
||||
output_type=LoadPlanResult,
|
||||
tools=[
|
||||
get_transport_vehicle_by_license_plate,
|
||||
list_unshipped_shipments,
|
||||
],
|
||||
toolsets=[
|
||||
SkillsToolset(
|
||||
directories=["./skills"],
|
||||
auto_reload=True,
|
||||
exclude_tools={"run_skill_script"},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
async def run_load_plan(request: LoadPlanRequest) -> LoadPlanResult:
|
||||
agent = create_load_plan_agent()
|
||||
async with agent:
|
||||
result = await agent.run(_build_user_prompt(request))
|
||||
return result.output
|
||||
Reference in New Issue
Block a user