feat: expand AI workflow support and refresh docs
This commit is contained in:
245
Server/app/api/v1/ai_identify.py
Normal file
245
Server/app/api/v1/ai_identify.py
Normal file
@@ -0,0 +1,245 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
AI 裁片识别接口
|
||||
功能:识别裁片部位、分析颜色花样、验证套图结果
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
import json, logging
|
||||
from app.core.config import settings
|
||||
from app.core.security import get_current_user
|
||||
from app.db import get_db
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.user import User
|
||||
from app.api.v1.ai_llm import (
|
||||
call_vision_messages,
|
||||
get_runtime_ai_config,
|
||||
verify_pattern_result,
|
||||
has_ai_config,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
if not log.handlers:
|
||||
log.setLevel(logging.INFO)
|
||||
log.propagate = False
|
||||
_h = logging.StreamHandler()
|
||||
_h.setFormatter(
|
||||
logging.Formatter("[%(asctime)s][%(name)s][%(levelname)s] %(message)s")
|
||||
)
|
||||
log.addHandler(_h)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# ==================== 数据模型 ====================
|
||||
|
||||
|
||||
class IdentifyPiecesRequest(BaseModel):
|
||||
canvas_base64: str
|
||||
garment_base64: Optional[str] = None
|
||||
layers_info: str
|
||||
vision_model: Optional[str] = None
|
||||
|
||||
|
||||
class VerifyResultRequest(BaseModel):
|
||||
garment_base64: str
|
||||
canvas_base64: str
|
||||
prompt: Optional[str] = None
|
||||
vision_model: Optional[str] = None
|
||||
|
||||
|
||||
# ==================== 裁片识别接口 ====================
|
||||
|
||||
|
||||
@router.post("/ai/identify-pieces")
|
||||
async def identify_pieces(
|
||||
data: IdentifyPiecesRequest,
|
||||
current_username: str = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""用视觉模型识别裁片部位 + 分析成衣各部位的颜色花样"""
|
||||
user = db.query(User).filter(User.username == current_username).first()
|
||||
runtime_ai_config = get_runtime_ai_config(user)
|
||||
effective_vision_model = (
|
||||
data.vision_model
|
||||
or (user.ai_vision_model if user else None)
|
||||
or settings.AI_VISION_MODEL
|
||||
)
|
||||
if not has_ai_config(runtime_ai_config):
|
||||
raise HTTPException(400, "请先在 API 配置页填写转发 Key,或联系管理员配置默认 AI_API_KEY")
|
||||
|
||||
try:
|
||||
result = _identify_garment_pieces(
|
||||
data.canvas_base64,
|
||||
data.layers_info,
|
||||
data.garment_base64,
|
||||
vision_model=effective_vision_model,
|
||||
runtime_ai_config=runtime_ai_config,
|
||||
)
|
||||
return {"code": 200, "data": result}
|
||||
except Exception as e:
|
||||
log.error(f"裁片识别失败: {e}", exc_info=True)
|
||||
raise HTTPException(500, f"裁片识别失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/ai/verify-result")
|
||||
async def verify_result(
|
||||
data: VerifyResultRequest,
|
||||
current_username: str = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""对比原始成衣和套图结果,给出验证反馈"""
|
||||
user = db.query(User).filter(User.username == current_username).first()
|
||||
runtime_ai_config = get_runtime_ai_config(user)
|
||||
effective_vision_model = (
|
||||
data.vision_model
|
||||
or (user.ai_vision_model if user else None)
|
||||
or settings.AI_VISION_MODEL
|
||||
)
|
||||
if not has_ai_config(runtime_ai_config):
|
||||
raise HTTPException(400, "请先在 API 配置页填写转发 Key,或联系管理员配置默认 AI_API_KEY")
|
||||
|
||||
try:
|
||||
feedback = verify_pattern_result(
|
||||
data.garment_base64,
|
||||
data.canvas_base64,
|
||||
data.prompt,
|
||||
vision_model=effective_vision_model,
|
||||
runtime_config=runtime_ai_config,
|
||||
)
|
||||
return {"code": 200, "data": {"feedback": feedback}}
|
||||
except Exception as e:
|
||||
log.error(f"验证失败: {e}", exc_info=True)
|
||||
raise HTTPException(500, f"验证失败: {str(e)}")
|
||||
|
||||
|
||||
# ==================== 核心识别逻辑 ====================
|
||||
|
||||
|
||||
def _identify_garment_pieces(
|
||||
canvas_b64: str,
|
||||
layers_info: str,
|
||||
garment_b64: str = None,
|
||||
vision_model: str = None,
|
||||
runtime_ai_config=None,
|
||||
) -> dict:
|
||||
"""用视觉模型分析画布+成衣,识别裁片部位并判断颜色/花样"""
|
||||
use_model = vision_model or settings.AI_VISION_MODEL
|
||||
|
||||
log.info(f"{'=' * 60}")
|
||||
log.info(f"[IdentifyPieces] 调用视觉模型识别裁片 + 分析花样")
|
||||
log.info(f"[IdentifyPieces] 模型: {use_model}")
|
||||
log.info(
|
||||
f"[IdentifyPieces] 画布: {len(canvas_b64) // 1024}KB, 成衣: {len(garment_b64) // 1024 if garment_b64 else 0}KB"
|
||||
)
|
||||
log.info(f"[IdentifyPieces] 图层:\n{layers_info[:500]}")
|
||||
|
||||
layer_names = []
|
||||
for line in layers_info.strip().split("\n"):
|
||||
name = line.split("(")[0].strip()
|
||||
if name:
|
||||
layer_names.append(name)
|
||||
|
||||
ex1 = layer_names[0] if len(layer_names) > 0 else "图层名1"
|
||||
ex2 = layer_names[1] if len(layer_names) > 1 else "图层名2"
|
||||
ex3 = layer_names[2] if len(layer_names) > 2 else "图层名3"
|
||||
|
||||
if garment_b64:
|
||||
prompt = f"""你需要完成三个任务:
|
||||
|
||||
**任务1:识别裁片部位**
|
||||
Image 2 是 PS 文档中的裁片轮廓图。以下是图层信息:
|
||||
{layers_info}
|
||||
|
||||
请根据形状、大小识别每个**图层组**是什么服装部位(前片、后片、袖子、领口等)。
|
||||
⚠️ mapping 的 key 必须是上面图层信息中的**实际图层组名称**(如 "{ex1}"、"{ex2}" 等),不要自己编名字!
|
||||
|
||||
**任务2:判断花型覆盖模式 pattern_mode**
|
||||
Image 1 是成衣照片。请判断花型属于哪种模式:
|
||||
- "all_over":所有裁片来自同一块面料(碎花/格子/渐变等整体花型覆盖全身)
|
||||
- "placement":各裁片独立处理(有的印花有的纯色,如卡通卫衣)
|
||||
- "color_block":全部纯色拼接
|
||||
|
||||
如果是 all_over,还需提供 fabric_info:
|
||||
- fabric_type: 面料印花类型
|
||||
- pattern_direction: 花型方向
|
||||
- pattern_elements: 花型元素描述
|
||||
- color_transition: 色彩过渡描述
|
||||
- density: 花型密度描述
|
||||
|
||||
**任务3:分析每个裁片**
|
||||
- type: solid / fill_pattern / theme_pattern / mixed_pattern / all_over
|
||||
- solid 和 theme_pattern 必须给 color(hex值)
|
||||
- 如果某个裁片在照片中**看不到**(如后片),设置 visible: false 并推理
|
||||
- 左右对称片标记 symmetric_pairs
|
||||
|
||||
用 JSON 回答(key 用实际图层名):
|
||||
```json
|
||||
{{
|
||||
"pattern_mode": "all_over",
|
||||
"fabric_info": {{
|
||||
"fabric_type": "数码印花",
|
||||
"pattern_direction": "上下渐变",
|
||||
"pattern_elements": "白色百合花朵 + 灰白格纹底",
|
||||
"color_transition": "从黑色渐变到灰白色",
|
||||
"density": "上部花朵密集大朵,下部稀疏小朵"
|
||||
}},
|
||||
"mapping": {{
|
||||
"{ex1}": "左袖",
|
||||
"{ex2}": "前片"
|
||||
}},
|
||||
"analysis": [
|
||||
{{"layer": "{ex1}", "piece": "左袖", "type": "all_over", "visible": true, "description": "黑底白花渐变"}},
|
||||
{{"layer": "{ex2}", "piece": "前片", "type": "all_over", "visible": true, "description": "上半黑底大花+下半灰白格纹"}},
|
||||
{{"layer": "{ex3}", "piece": "后片", "type": "all_over", "visible": false, "inferred_from": "{ex2}", "inference_reason": "全身花型面料,后片花型与前片相同", "confidence": "high"}}
|
||||
],
|
||||
"symmetric_pairs": [["{ex1}", "对称片名"]],
|
||||
"base_color": "#000000"
|
||||
}}
|
||||
```
|
||||
|
||||
⚠️ 重要:
|
||||
- mapping 和 analysis 中的 layer 必须用**实际图层组名称**
|
||||
- 看不到的部位设 visible: false 并说明推理依据
|
||||
- 左右对称的片标记 symmetric_pairs
|
||||
只返回 JSON,不要其他文字。"""
|
||||
else:
|
||||
prompt = f"""这是一张 PS 文档裁片截图。图层信息:
|
||||
{layers_info}
|
||||
|
||||
识别每个图层组的服装部位,mapping 的 key 必须用**实际图层名**:
|
||||
```json
|
||||
{{"mapping": {{"{ex1}": "前片", "{ex2}": "后片"}}}}
|
||||
```
|
||||
只返回 JSON。"""
|
||||
|
||||
log.info(f"[IdentifyPieces] 使用豆包视觉: {use_model}")
|
||||
images = []
|
||||
if garment_b64:
|
||||
images.append(garment_b64)
|
||||
images.append(canvas_b64)
|
||||
content = call_vision_messages(
|
||||
prompt,
|
||||
images,
|
||||
model_override=use_model,
|
||||
runtime_config=runtime_ai_config,
|
||||
)
|
||||
log.info(f"[IdentifyPieces] 视觉模型回复: {content[:500]}")
|
||||
|
||||
try:
|
||||
clean = content.strip()
|
||||
if "```" in clean:
|
||||
clean = clean.split("```")[1]
|
||||
if clean.startswith("json"):
|
||||
clean = clean[4:]
|
||||
clean = clean.strip()
|
||||
parsed = json.loads(clean)
|
||||
|
||||
if "mapping" in parsed:
|
||||
return parsed
|
||||
else:
|
||||
return {"mapping": parsed}
|
||||
except json.JSONDecodeError:
|
||||
log.warning(f"[IdentifyPieces] JSON 解析失败")
|
||||
return {"mapping": {}, "raw_response": content}
|
||||
Reference in New Issue
Block a user