Files
DP/Server/app/api/v1/ai_identify.py

246 lines
8.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""
AI 裁片识别接口
功能:识别裁片部位、分析颜色花样、验证套图结果
"""
from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel
from typing import Optional
import json, logging
from app.core.config import settings
from app.core.security import get_current_user
from app.db import get_db
from sqlalchemy.orm import Session
from app.models.user import User
from app.api.v1.ai_llm import (
call_vision_messages,
get_runtime_ai_config,
verify_pattern_result,
has_ai_config,
)
log = logging.getLogger(__name__)
if not log.handlers:
    # Configure the module logger only when it has no handler yet, so that
    # repeated imports/reloads do not stack duplicate stream handlers.
    _h = logging.StreamHandler()
    _h.setFormatter(
        logging.Formatter(fmt="[%(asctime)s][%(name)s][%(levelname)s] %(message)s")
    )
    log.addHandler(_h)
    # Records are emitted by the handler above and not passed on to ancestor
    # loggers.
    log.propagate = False
    log.setLevel(logging.INFO)
# Router on which the endpoints defined in this module are registered.
router = APIRouter()
# ==================== 数据模型 ====================
class IdentifyPiecesRequest(BaseModel):
    """Request body for ``POST /ai/identify-pieces``."""

    # Image of the pattern-piece canvas (presumably base64-encoded, per the
    # field name); it is always sent to the vision model.
    canvas_base64: str
    # Optional garment photo. When provided it is sent ahead of the canvas
    # image and the longer colour/pattern analysis prompt is used.
    garment_base64: Optional[str] = None
    # Layer description text, one layer per line. The text before the first
    # "(" on each line is treated as the layer name.
    layers_info: str
    # Per-request vision model override. When empty, the user's saved model
    # and then settings.AI_VISION_MODEL are used instead.
    vision_model: Optional[str] = None
class VerifyResultRequest(BaseModel):
    """Request body for ``POST /ai/verify-result``."""

    # Image of the original garment (presumably base64-encoded, per the field
    # name); passed as the first argument to verify_pattern_result.
    garment_base64: str
    # Image of the rendered canvas result; passed as the second argument.
    canvas_base64: str
    # Optional prompt text forwarded unchanged to verify_pattern_result.
    prompt: Optional[str] = None
    # Per-request vision model override. When empty, the user's saved model
    # and then settings.AI_VISION_MODEL are used instead.
    vision_model: Optional[str] = None
# ==================== 裁片识别接口 ====================
@router.post("/ai/identify-pieces")
async def identify_pieces(
    data: IdentifyPiecesRequest,
    current_username: str = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Identify pattern pieces with the vision model and analyse their colour/pattern.

    The vision model is chosen from, in order: the request, the current user's
    saved model, and ``settings.AI_VISION_MODEL``.

    Returns:
        ``{"code": 200, "data": <identification result>}``.

    Raises:
        HTTPException: 400 when no usable AI configuration is available;
            500 when the identification itself raises any exception.
    """
    account = db.query(User).filter(User.username == current_username).first()
    ai_config = get_runtime_ai_config(account)

    # Fall through the model sources until one yields a non-empty value.
    chosen_model = data.vision_model
    if not chosen_model and account:
        chosen_model = account.ai_vision_model
    if not chosen_model:
        chosen_model = settings.AI_VISION_MODEL

    if not has_ai_config(ai_config):
        raise HTTPException(400, "请先在 API 配置页填写转发 Key或联系管理员配置默认 AI_API_KEY")

    # NOTE(review): the helper below is called synchronously from a coroutine;
    # confirm it does not block the event loop for long.
    try:
        pieces = _identify_garment_pieces(
            data.canvas_base64,
            data.layers_info,
            data.garment_base64,
            vision_model=chosen_model,
            runtime_ai_config=ai_config,
        )
    except Exception as exc:
        log.error(f"裁片识别失败: {exc}", exc_info=True)
        raise HTTPException(500, f"裁片识别失败: {str(exc)}")
    return {"code": 200, "data": pieces}
@router.post("/ai/verify-result")
async def verify_result(
    data: VerifyResultRequest,
    current_username: str = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Compare the original garment with the rendered result and return feedback.

    The vision model is chosen from, in order: the request, the current user's
    saved model, and ``settings.AI_VISION_MODEL``.

    Returns:
        ``{"code": 200, "data": {"feedback": <verification feedback>}}``.

    Raises:
        HTTPException: 400 when no usable AI configuration is available;
            500 when the verification itself raises any exception.
    """
    account = db.query(User).filter(User.username == current_username).first()
    ai_config = get_runtime_ai_config(account)

    # Fall through the model sources until one yields a non-empty value.
    chosen_model = data.vision_model
    if not chosen_model and account:
        chosen_model = account.ai_vision_model
    if not chosen_model:
        chosen_model = settings.AI_VISION_MODEL

    if not has_ai_config(ai_config):
        raise HTTPException(400, "请先在 API 配置页填写转发 Key或联系管理员配置默认 AI_API_KEY")

    # NOTE(review): the helper below is called synchronously from a coroutine;
    # confirm it does not block the event loop for long.
    try:
        verdict = verify_pattern_result(
            data.garment_base64,
            data.canvas_base64,
            data.prompt,
            vision_model=chosen_model,
            runtime_config=ai_config,
        )
    except Exception as exc:
        log.error(f"验证失败: {exc}", exc_info=True)
        raise HTTPException(500, f"验证失败: {str(exc)}")
    return {"code": 200, "data": {"feedback": verdict}}
# ==================== 核心识别逻辑 ====================
def _identify_garment_pieces(
    canvas_b64: str,
    layers_info: str,
    garment_b64: Optional[str] = None,
    vision_model: Optional[str] = None,
    runtime_ai_config=None,
) -> dict:
    """Ask the vision model to label pattern pieces and, optionally, analyse the garment.

    Args:
        canvas_b64: Image of the pattern-piece canvas; always sent to the model.
        layers_info: Layer description text, one layer per line.
        garment_b64: Optional garment photo. When given it is sent before the
            canvas (the prompt calls it "Image 1" and the canvas "Image 2")
            and the full colour/pattern analysis prompt is used.
        vision_model: Model name; falls back to ``settings.AI_VISION_MODEL``.
        runtime_ai_config: Forwarded unchanged to ``call_vision_messages``.

    Returns:
        The JSON object returned by the model. If it has no ``"mapping"`` key
        the object is wrapped as ``{"mapping": <object>}``. If the reply is not
        parsable as a JSON object, ``{"mapping": {}, "raw_response": <reply>}``.

    Raises:
        Whatever ``call_vision_messages`` raises; it is not caught here.
    """
    use_model = vision_model or settings.AI_VISION_MODEL
    log.info(f"{'=' * 60}")
    log.info(f"[IdentifyPieces] 调用视觉模型识别裁片 + 分析花样")
    log.info(f"[IdentifyPieces] 模型: {use_model}")
    log.info(
        f"[IdentifyPieces] 画布: {len(canvas_b64) // 1024}KB, 成衣: {len(garment_b64) // 1024 if garment_b64 else 0}KB"
    )
    log.info(f"[IdentifyPieces] 图层:\n{layers_info[:500]}")

    layer_names = _extract_layer_names(layers_info)
    prompt = _build_identify_prompt(layers_info, layer_names, bool(garment_b64))

    log.info(f"[IdentifyPieces] 使用豆包视觉: {use_model}")
    # Image order matters: the prompt refers to the garment as Image 1 and the
    # canvas as Image 2.
    images = [garment_b64, canvas_b64] if garment_b64 else [canvas_b64]
    content = call_vision_messages(
        prompt,
        images,
        model_override=use_model,
        runtime_config=runtime_ai_config,
    )
    log.info(f"[IdentifyPieces] 视觉模型回复: {content[:500]}")
    return _parse_identify_reply(content)


def _extract_layer_names(layers_info: str) -> list:
    """Return the non-empty layer names: the text before the first "(" on each line."""
    candidates = (
        line.split("(")[0].strip() for line in layers_info.strip().split("\n")
    )
    return [name for name in candidates if name]


def _build_identify_prompt(
    layers_info: str, layer_names: list, with_garment: bool
) -> str:
    """Build the vision prompt, using up to three real layer names as examples."""
    placeholders = ["图层名1", "图层名2", "图层名3"]
    # Use real layer names where available and pad with placeholders so the
    # JSON example in the prompt always has three keys to show.
    examples = layer_names[:3]
    examples += placeholders[len(examples):]
    ex1, ex2, ex3 = examples

    if with_garment:
        return f"""你需要完成三个任务:
**任务1识别裁片部位**
Image 2 是 PS 文档中的裁片轮廓图。以下是图层信息:
{layers_info}
请根据形状、大小识别每个**图层组**是什么服装部位(前片、后片、袖子、领口等)。
⚠️ mapping 的 key 必须是上面图层信息中的**实际图层组名称**(如 "{ex1}""{ex2}" 等),不要自己编名字!
**任务2判断花型覆盖模式 pattern_mode**
Image 1 是成衣照片。请判断花型属于哪种模式:
- "all_over":所有裁片来自同一块面料(碎花/格子/渐变等整体花型覆盖全身)
- "placement":各裁片独立处理(有的印花有的纯色,如卡通卫衣)
- "color_block":全部纯色拼接
如果是 all_over还需提供 fabric_info
- fabric_type: 面料印花类型
- pattern_direction: 花型方向
- pattern_elements: 花型元素描述
- color_transition: 色彩过渡描述
- density: 花型密度描述
**任务3分析每个裁片**
- type: solid / fill_pattern / theme_pattern / mixed_pattern / all_over
- solid 和 theme_pattern 必须给 colorhex值
- 如果某个裁片在照片中**看不到**(如后片),设置 visible: false 并推理
- 左右对称片标记 symmetric_pairs
用 JSON 回答key 用实际图层名):
```json
{{
"pattern_mode": "all_over",
"fabric_info": {{
"fabric_type": "数码印花",
"pattern_direction": "上下渐变",
"pattern_elements": "白色百合花朵 + 灰白格纹底",
"color_transition": "从黑色渐变到灰白色",
"density": "上部花朵密集大朵,下部稀疏小朵"
}},
"mapping": {{
"{ex1}": "左袖",
"{ex2}": "前片"
}},
"analysis": [
{{"layer": "{ex1}", "piece": "左袖", "type": "all_over", "visible": true, "description": "黑底白花渐变"}},
{{"layer": "{ex2}", "piece": "前片", "type": "all_over", "visible": true, "description": "上半黑底大花+下半灰白格纹"}},
{{"layer": "{ex3}", "piece": "后片", "type": "all_over", "visible": false, "inferred_from": "{ex2}", "inference_reason": "全身花型面料,后片花型与前片相同", "confidence": "high"}}
],
"symmetric_pairs": [["{ex1}", "对称片名"]],
"base_color": "#000000"
}}
```
⚠️ 重要:
- mapping 和 analysis 中的 layer 必须用**实际图层组名称**
- 看不到的部位设 visible: false 并说明推理依据
- 左右对称的片标记 symmetric_pairs
只返回 JSON不要其他文字。"""

    return f"""这是一张 PS 文档裁片截图。图层信息:
{layers_info}
识别每个图层组的服装部位mapping 的 key 必须用**实际图层名**
```json
{{"mapping": {{"{ex1}": "前片", "{ex2}": "后片"}}}}
```
只返回 JSON。"""


def _parse_identify_reply(content: str) -> dict:
    """Turn the model reply into a result dict, falling back to the raw text."""
    clean = content.strip()
    if "```" in clean:
        # Keep only the body of the first fenced block and drop a leading
        # "json" language tag if present.
        clean = clean.split("```")[1]
        if clean.startswith("json"):
            clean = clean[4:]
        clean = clean.strip()

    try:
        parsed = json.loads(clean)
    except json.JSONDecodeError:
        parsed = None

    # A reply such as `null`, `123` or a list is valid JSON but not the object
    # the prompt asks for. Testing `"mapping" in parsed` on it would raise
    # TypeError or produce a malformed mapping, so it is treated like
    # unparsable text.
    if not isinstance(parsed, dict):
        log.warning(f"[IdentifyPieces] JSON 解析失败")
        return {"mapping": {}, "raw_response": content}

    return parsed if "mapping" in parsed else {"mapping": parsed}