refactor: 改用 pydantic-ai 框架调用豆包,结构化输出 PlateOrder
This commit is contained in:
96
app.py
96
app.py
@@ -1,6 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
开版订单 AI 解析服务
|
开版订单 AI 解析服务
|
||||||
Flask + 豆包 API(/api/v3/responses 格式)
|
Flask + pydantic-ai + 豆包(方舟 OpenAI 兼容协议 /chat/completions)
|
||||||
|
|
||||||
功能:
|
功能:
|
||||||
1. 代理企微消息接口(解决 CORS)
|
1. 代理企微消息接口(解决 CORS)
|
||||||
@@ -17,12 +17,19 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from flask import Flask, request, jsonify
|
from flask import Flask, request, jsonify
|
||||||
from flask_cors import CORS
|
from flask_cors import CORS
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
from pydantic_ai import Agent
|
||||||
|
from pydantic_ai.models.openai import OpenAIChatModel
|
||||||
|
from pydantic_ai.providers.openai import OpenAIProvider
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
# ─── 日志 ─────────────────────────────────────────────────────────────────────
|
# ─── 日志 ─────────────────────────────────────────────────────────────────────
|
||||||
@@ -36,15 +43,10 @@ logger = logging.getLogger(__name__)
|
|||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
CORS(app, origins="*")
|
CORS(app, origins="*")
|
||||||
|
|
||||||
# ─── 豆包 AI 配置 ─────────────────────────────────────────────────────────────
|
# ─── 豆包 AI 配置(通过 pydantic-ai + OpenAI 兼容协议调用方舟)─────────────────
|
||||||
AI_API_KEY = os.getenv("AI_API_KEY", "")
|
AI_API_KEY = os.getenv("AI_API_KEY", "")
|
||||||
AI_BASE_URL = os.getenv("AI_BASE_URL", "https://ark.cn-beijing.volces.com/api/v3")
|
AI_BASE_URL = os.getenv("AI_BASE_URL", "https://ark.cn-beijing.volces.com/api/v3")
|
||||||
AI_MODEL = os.getenv("AI_MODEL", "doubao-seed-2-0-mini-260428")
|
AI_MODEL = os.getenv("AI_MODEL", "doubao-seed-2-0-mini-260428")
|
||||||
DOUBAO_URL = f"{AI_BASE_URL}/responses"
|
|
||||||
DOUBAO_HEADERS = {
|
|
||||||
"Authorization": f"Bearer {AI_API_KEY}",
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
|
|
||||||
# ─── 企微消息服务配置 ─────────────────────────────────────────────────────────
|
# ─── 企微消息服务配置 ─────────────────────────────────────────────────────────
|
||||||
WECHAT_API_BASE = "https://test.ruicaiyinhua.online"
|
WECHAT_API_BASE = "https://test.ruicaiyinhua.online"
|
||||||
@@ -331,37 +333,57 @@ def build_system_prompt() -> str:
|
|||||||
3. 一律返回 JSON 对象,不要任何其它文字"""
|
3. 一律返回 JSON 对象,不要任何其它文字"""
|
||||||
|
|
||||||
|
|
||||||
|
class PlateOrder(BaseModel):
|
||||||
|
"""开版订单解析结果。所有字段均可为 null(无法判断时返回 None)。"""
|
||||||
|
customer_name: Optional[str] = Field(None, description="客户名称(自由文本,消息开头)")
|
||||||
|
plate_type: Optional[str] = Field(None, description="起版情况枚举")
|
||||||
|
urgency_level: Optional[str] = Field(None, description="紧急程度枚举")
|
||||||
|
production_method: Optional[str] = Field(None, description="做货方式枚举")
|
||||||
|
plate_method: Optional[str] = Field(None, description="开版方式枚举")
|
||||||
|
area: Optional[str] = Field(None, description="客户区域枚举")
|
||||||
|
fabric: Optional[str] = Field(None, description="面料名称(自由文本,保留克重与布类)")
|
||||||
|
width: Optional[str] = Field(None, description="面料幅宽枚举")
|
||||||
|
fabric_source: Optional[str] = Field(None, description="面料来源枚举")
|
||||||
|
drawing_rating: Optional[str] = Field(None, description="画图评级枚举")
|
||||||
|
color_matching_rating: Optional[str] = Field(None, description="调色评级枚举")
|
||||||
|
sample_rating: Optional[str] = Field(None, description="套样评级枚举")
|
||||||
|
difficulty_rating: Optional[str] = Field(None, description="难度评级枚举")
|
||||||
|
process_name: Optional[str] = Field(None, description="关联流程名称")
|
||||||
|
plate_notes: Optional[str] = Field(None, description="打版备注(自由文本,原文保留)")
|
||||||
|
style_name: Optional[str] = Field(None, description="款号名称")
|
||||||
|
original_design_code: Optional[str] = Field(None, description="原订单ID(5-7位纯数字)")
|
||||||
|
merchandiser_name: Optional[str] = Field(None, description="跟单员(取 sender 原值)")
|
||||||
|
|
||||||
|
@field_validator("*", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _normalize_empty(cls, v):
|
||||||
|
"""模型有时会填字符串 'null'/'none'/空串,统一归一为 None"""
|
||||||
|
if isinstance(v, str) and v.strip().lower() in ("", "null", "none", "n/a", "无"):
|
||||||
|
return None
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
# pydantic-ai Agent:通过 OpenAI 兼容协议调用方舟(豆包)
|
||||||
|
_provider = OpenAIProvider(base_url=AI_BASE_URL, api_key=AI_API_KEY)
|
||||||
|
_model = OpenAIChatModel(AI_MODEL, provider=_provider)
|
||||||
|
plate_agent = Agent(
|
||||||
|
_model,
|
||||||
|
output_type=PlateOrder,
|
||||||
|
model_settings={"temperature": 0.1, "max_tokens": 4096},
|
||||||
|
retries=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@plate_agent.system_prompt
|
||||||
|
def _dynamic_system_prompt() -> str:
|
||||||
|
"""系统提示词依赖运行时加载的字典缓存,故每次调用动态生成"""
|
||||||
|
return build_system_prompt()
|
||||||
|
|
||||||
|
|
||||||
def call_doubao(user_text: str) -> dict:
|
def call_doubao(user_text: str) -> dict:
|
||||||
"""调用豆包 /v3/responses,返回解析后的 dict"""
|
"""通过 pydantic-ai 调用豆包,返回结构化结果 dict"""
|
||||||
payload = {
|
result = plate_agent.run_sync(user_text)
|
||||||
"model": AI_MODEL,
|
return result.output.model_dump()
|
||||||
"input": [
|
|
||||||
{"role": "system", "content": [{"type": "input_text", "text": build_system_prompt()}]},
|
|
||||||
{"role": "user", "content": [{"type": "input_text", "text": user_text}]},
|
|
||||||
],
|
|
||||||
"temperature": 0.1,
|
|
||||||
"max_output_tokens": 4096,
|
|
||||||
}
|
|
||||||
r = requests.post(DOUBAO_URL, headers=DOUBAO_HEADERS, json=payload, timeout=60)
|
|
||||||
r.raise_for_status()
|
|
||||||
data = r.json()
|
|
||||||
|
|
||||||
raw_text = ""
|
|
||||||
for item in data.get("output", []):
|
|
||||||
if item.get("type") == "message":
|
|
||||||
for c in item.get("content", []):
|
|
||||||
if c.get("type") == "output_text":
|
|
||||||
raw_text = c.get("text", "").strip()
|
|
||||||
break
|
|
||||||
break
|
|
||||||
if not raw_text:
|
|
||||||
raise ValueError("豆包返回内容为空")
|
|
||||||
|
|
||||||
start = raw_text.find("{")
|
|
||||||
end = raw_text.rfind("}") + 1
|
|
||||||
if start == -1 or end == 0:
|
|
||||||
raise ValueError(f"响应中未找到 JSON:{raw_text}")
|
|
||||||
return json.loads(raw_text[start:end])
|
|
||||||
|
|
||||||
|
|
||||||
# ─── 路由 ─────────────────────────────────────────────────────────────────────
|
# ─── 路由 ─────────────────────────────────────────────────────────────────────
|
||||||
|
|||||||
@@ -2,3 +2,4 @@ flask==3.0.3
|
|||||||
flask-cors==4.0.1
|
flask-cors==4.0.1
|
||||||
requests==2.32.3
|
requests==2.32.3
|
||||||
python-dotenv==1.0.1
|
python-dotenv==1.0.1
|
||||||
|
pydantic-ai-slim[openai]==1.107.0
|
||||||
|
|||||||
Reference in New Issue
Block a user