734 lines
30 KiB
Python
734 lines
30 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
AI 聊天接口
|
||
功能:对话管理 + 消息存库 + function calling(工具调用)
|
||
"""
|
||
|
||
from fastapi import APIRouter, HTTPException, Depends
|
||
from pydantic import BaseModel
|
||
from typing import List, Optional
|
||
import json, base64, re, logging
|
||
from sqlalchemy.orm import Session as DBSession
|
||
from app.core.config import settings
|
||
from app.core.security import get_current_user
|
||
from app.db import get_db
|
||
from app.models.user import User
|
||
from app.models.chat import ChatSession, ChatMessage
|
||
from app.api.v1.ai_tools import PS_TOOLS, TOOL_DISPLAY_NAMES
|
||
from app.api.v1.ai_llm import (
|
||
SYSTEM_PROMPT, VISION_PROMPT,
|
||
is_gemini_model, call_llm_with_tools, call_gemini_with_tools,
|
||
call_vision_llm, call_gemini, call_image_model,
|
||
verify_pattern_result, mock_reply,
|
||
)
|
||
|
||
log = logging.getLogger(__name__)
|
||
# 防止 uvicorn --reload 导致 handler 重复
|
||
if not log.handlers:
|
||
log.setLevel(logging.INFO)
|
||
log.propagate = False
|
||
_h = logging.StreamHandler()
|
||
_h.setFormatter(logging.Formatter('[%(asctime)s][%(name)s][%(levelname)s] %(message)s'))
|
||
log.addHandler(_h)
|
||
|
||
router = APIRouter()
|
||
|
||
# ==================== 数据模型 ====================
|
||
|
||
class ChatRequest(BaseModel):
|
||
message: str
|
||
session_id: Optional[int] = None
|
||
image_base64: Optional[str] = None
|
||
model: Optional[str] = None # 前端指定对话模型(覆盖默认)
|
||
vision_model: Optional[str] = None # 前端指定视觉模型
|
||
image_edit_model: Optional[str] = None # 前端指定图片编辑模型
|
||
|
||
class ToolResultRequest(BaseModel):
|
||
"""前端执行完工具后回传结果"""
|
||
session_id: int
|
||
tool_call_id: str
|
||
tool_name: str
|
||
result: str
|
||
model: Optional[str] = None # 前端指定对话模型
|
||
vision_model: Optional[str] = None
|
||
image_edit_model: Optional[str] = None
|
||
|
||
class GeneratePatternRequest(BaseModel):
|
||
"""从成衣图片生成面料花样"""
|
||
image_base64: str # 成衣图片 base64
|
||
canvas_base64: Optional[str] = None # 裁片轮廓截图 base64
|
||
prompt: Optional[str] = None # 可选自定义提示词
|
||
image_edit_model: Optional[str] = None
|
||
|
||
class VerifyResultRequest(BaseModel):
|
||
"""验证套图结果"""
|
||
garment_base64: str # 原始成衣图片
|
||
canvas_base64: str # 套图后的画布截图
|
||
prompt: Optional[str] = None
|
||
vision_model: Optional[str] = None
|
||
|
||
# ==================== 对话管理接口 ====================
|
||
|
||
@router.get("/ai/models")
|
||
async def get_models(current_username: str = Depends(get_current_user)):
|
||
"""获取可用模型列表和当前默认值"""
|
||
return {
|
||
"code": 200,
|
||
"data": {
|
||
"chat_model": settings.AI_MODEL,
|
||
"vision_model": settings.AI_VISION_MODEL,
|
||
"image_edit_model": settings.AI_IMAGE_EDIT_MODEL,
|
||
"chat_models": [
|
||
{"id": "qwen3-max-2026-01-23", "name": "Qwen3 Max"},
|
||
{"id": "gemini-2.5-pro-thinking", "name": "Gemini 2.5 Pro"},
|
||
],
|
||
"vision_models": [
|
||
{"id": "qwen-vl-max-latest", "name": "Qwen VL Max"},
|
||
{"id": "gemini-2.5-flash-image", "name": "Gemini 2.5 Flash"},
|
||
],
|
||
"image_edit_models": [
|
||
{"id": "qwen-image-edit-max-2026-01-16", "name": "Qwen 图片编辑"},
|
||
{"id": "gemini-3-pro-image-preview", "name": "Gemini 3 Pro 生图"},
|
||
],
|
||
}
|
||
}
|
||
|
||
|
||
@router.get("/ai/sessions")
|
||
async def list_sessions(
|
||
db: DBSession = Depends(get_db),
|
||
current_username: str = Depends(get_current_user)
|
||
):
|
||
"""获取当前用户的对话列表"""
|
||
user = db.query(User).filter(User.username == current_username).first()
|
||
if not user:
|
||
raise HTTPException(status_code=404, detail="用户不存在")
|
||
|
||
sessions = db.query(ChatSession)\
|
||
.filter(ChatSession.user_id == user.id)\
|
||
.order_by(ChatSession.updated_at.desc())\
|
||
.limit(50).all()
|
||
|
||
return {
|
||
"code": 200,
|
||
"data": [{
|
||
"id": s.id,
|
||
"title": s.title,
|
||
"created_at": s.created_at.isoformat() if s.created_at else None,
|
||
"updated_at": s.updated_at.isoformat() if s.updated_at else None,
|
||
} for s in sessions]
|
||
}
|
||
|
||
|
||
@router.get("/ai/sessions/{session_id}/messages")
|
||
async def get_session_messages(
|
||
session_id: int,
|
||
db: DBSession = Depends(get_db),
|
||
current_username: str = Depends(get_current_user)
|
||
):
|
||
"""获取某个对话的全部消息"""
|
||
user = db.query(User).filter(User.username == current_username).first()
|
||
if not user:
|
||
raise HTTPException(status_code=404, detail="用户不存在")
|
||
|
||
session = db.query(ChatSession).filter(
|
||
ChatSession.id == session_id, ChatSession.user_id == user.id
|
||
).first()
|
||
if not session:
|
||
raise HTTPException(status_code=404, detail="对话不存在")
|
||
|
||
messages = db.query(ChatMessage)\
|
||
.filter(ChatMessage.session_id == session_id)\
|
||
.order_by(ChatMessage.created_at).all()
|
||
|
||
return {
|
||
"code": 200,
|
||
"data": {
|
||
"session_id": session.id,
|
||
"title": session.title,
|
||
"messages": [{
|
||
"id": m.id, "role": m.role, "content": m.content,
|
||
"tool_calls": m.tool_calls,
|
||
"created_at": m.created_at.isoformat() if m.created_at else None,
|
||
} for m in messages]
|
||
}
|
||
}
|
||
|
||
|
||
@router.delete("/ai/sessions/{session_id}")
|
||
async def delete_session(
|
||
session_id: int,
|
||
db: DBSession = Depends(get_db),
|
||
current_username: str = Depends(get_current_user)
|
||
):
|
||
"""删除对话"""
|
||
user = db.query(User).filter(User.username == current_username).first()
|
||
if not user:
|
||
raise HTTPException(status_code=404, detail="用户不存在")
|
||
|
||
session = db.query(ChatSession).filter(
|
||
ChatSession.id == session_id, ChatSession.user_id == user.id
|
||
).first()
|
||
if not session:
|
||
raise HTTPException(status_code=404, detail="对话不存在")
|
||
|
||
db.query(ChatMessage).filter(ChatMessage.session_id == session_id).delete()
|
||
db.delete(session)
|
||
db.commit()
|
||
return {"code": 200, "message": "对话已删除"}
|
||
|
||
|
||
# ==================== 聊天接口 ====================
|
||
|
||
@router.post("/ai/chat")
|
||
async def chat(
|
||
data: ChatRequest,
|
||
db: DBSession = Depends(get_db),
|
||
current_username: str = Depends(get_current_user)
|
||
):
|
||
"""发送消息 → AI 回复(可能包含工具调用指令)"""
|
||
log.info(f"{'='*60}")
|
||
log.info(f"[Chat] 收到消息: '{data.message[:80]}' session={data.session_id} 有图片={'是' if data.image_base64 else '否'}")
|
||
user = db.query(User).filter(User.username == current_username).first()
|
||
if not user:
|
||
raise HTTPException(status_code=404, detail="用户不存在")
|
||
|
||
# 1. 获取或创建会话
|
||
session = None
|
||
if data.session_id:
|
||
session = db.query(ChatSession).filter(
|
||
ChatSession.id == data.session_id, ChatSession.user_id == user.id
|
||
).first()
|
||
|
||
if not session:
|
||
title = data.message[:30] + ("..." if len(data.message) > 30 else "")
|
||
session = ChatSession(user_id=user.id, username=current_username, title=title)
|
||
db.add(session)
|
||
db.flush()
|
||
|
||
# 2. 保存用户消息(图片消息加标记,不存 base64)
|
||
has_image = bool(data.image_base64)
|
||
save_content = ("[图片分析] " + data.message) if has_image else data.message
|
||
db.add(ChatMessage(session_id=session.id, role="user", content=save_content))
|
||
db.commit()
|
||
|
||
# 3. 加载历史
|
||
history = db.query(ChatMessage)\
|
||
.filter(ChatMessage.session_id == session.id)\
|
||
.order_by(ChatMessage.created_at).all()
|
||
# tool 角色消息转为 user(避免 LLM 验证 tool_calls 结构报错)
|
||
history_list = [{"role": m.role if m.role != "tool" else "user", "content": m.content} for m in history[-20:]]
|
||
|
||
# 4. 调用 LLM(统一走工具模型,图片信息通过文本标记传递)
|
||
try:
|
||
if not settings.AI_API_KEY and not (data.model and is_gemini_model(data.model)):
|
||
reply_content = mock_reply(data.message, has_image)
|
||
tool_calls_data = None
|
||
else:
|
||
call_history = history_list
|
||
# 有图片时:在消息中标注,让工具模型知道图片已上传可供工具使用
|
||
if has_image:
|
||
call_history = history_list[:-1] + [{
|
||
"role": "user",
|
||
"content": f"[用户上传了一张图片,已保存可供工具使用] {data.message or '请处理这张图片'}"
|
||
}]
|
||
|
||
# 根据模型类型路由
|
||
if data.model and is_gemini_model(data.model):
|
||
log.info(f"[Chat] 路由到 Gemini: {data.model}")
|
||
reply_content, tool_calls_data = call_gemini_with_tools(call_history, data.model)
|
||
else:
|
||
reply_content, tool_calls_data = call_llm_with_tools(call_history, model_override=data.model)
|
||
except Exception as e:
|
||
reply_content = f"AI 请求出错: {str(e)}"
|
||
tool_calls_data = None
|
||
|
||
# 5. 保存 AI 回复
|
||
tc_json = json.dumps(tool_calls_data, ensure_ascii=False) if tool_calls_data else None
|
||
db.add(ChatMessage(
|
||
session_id=session.id, role="assistant",
|
||
content=reply_content or "", tool_calls=tc_json
|
||
))
|
||
db.commit()
|
||
|
||
return {
|
||
"code": 200,
|
||
"data": {
|
||
"session_id": session.id,
|
||
"content": reply_content or "",
|
||
"tool_calls": tool_calls_data
|
||
}
|
||
}
|
||
|
||
|
||
@router.post("/ai/tool-result")
|
||
async def submit_tool_result(
|
||
data: ToolResultRequest,
|
||
db: DBSession = Depends(get_db),
|
||
current_username: str = Depends(get_current_user)
|
||
):
|
||
"""前端执行完工具后回传结果,后端继续让 AI 总结"""
|
||
log.info(f"{'='*60}")
|
||
log.info(f"[ToolResult] 工具: {data.tool_name}, 结果: {data.result[:150]}")
|
||
user = db.query(User).filter(User.username == current_username).first()
|
||
if not user:
|
||
raise HTTPException(status_code=404, detail="用户不存在")
|
||
|
||
session = db.query(ChatSession).filter(
|
||
ChatSession.id == data.session_id, ChatSession.user_id == user.id
|
||
).first()
|
||
if not session:
|
||
raise HTTPException(status_code=404, detail="对话不存在")
|
||
|
||
# 1. 保存工具结果为一条消息
|
||
db.add(ChatMessage(
|
||
session_id=session.id, role="tool",
|
||
content=f"[工具 {data.tool_name} 执行结果]: {data.result}"
|
||
))
|
||
db.commit()
|
||
|
||
# 2. 重新加载历史,让 AI 根据工具结果继续回答
|
||
history = db.query(ChatMessage)\
|
||
.filter(ChatMessage.session_id == session.id)\
|
||
.order_by(ChatMessage.created_at).all()
|
||
history_list = [{"role": m.role if m.role != "tool" else "user", "content": m.content} for m in history[-20:]]
|
||
|
||
# 3. 再次调用 LLM(这次不带 tools,让 AI 总结结果)
|
||
try:
|
||
if not settings.AI_API_KEY and not (data.model and is_gemini_model(data.model)):
|
||
reply_content = f"工具 {data.tool_name} 执行完成。"
|
||
tool_calls_data = None
|
||
else:
|
||
if data.model and is_gemini_model(data.model):
|
||
log.info(f"[ToolResult] 路由到 Gemini: {data.model}")
|
||
reply_content, tool_calls_data = call_gemini_with_tools(history_list, data.model)
|
||
else:
|
||
reply_content, tool_calls_data = call_llm_with_tools(history_list, model_override=data.model)
|
||
except Exception as e:
|
||
reply_content = f"AI 总结出错: {str(e)}"
|
||
tool_calls_data = None
|
||
|
||
# 4. 保存 AI 总结
|
||
tc_json = json.dumps(tool_calls_data, ensure_ascii=False) if tool_calls_data else None
|
||
db.add(ChatMessage(
|
||
session_id=session.id, role="assistant",
|
||
content=reply_content or "", tool_calls=tc_json
|
||
))
|
||
db.commit()
|
||
|
||
return {
|
||
"code": 200,
|
||
"data": {
|
||
"session_id": session.id,
|
||
"content": reply_content or "",
|
||
"tool_calls": tool_calls_data
|
||
}
|
||
}
|
||
|
||
|
||
|
||
# ==================== 图案生成 & 验证 ====================
|
||
|
||
@router.post("/ai/generate-preview")
|
||
async def generate_preview(
|
||
data: GeneratePatternRequest,
|
||
current_username: str = Depends(get_current_user)
|
||
):
|
||
"""阶段1:成衣照片 + 裁片截图 → AI 生成带轮廓的花样预览图"""
|
||
log.info(f"{'='*60}")
|
||
log.info(f"[Preview] 收到预览请求, 成衣图: {len(data.image_base64)//1024}KB, 裁片图: {len(data.canvas_base64)//1024 if data.canvas_base64 else 0}KB")
|
||
if data.prompt:
|
||
log.info(f"[Preview] 自定义提示词: {data.prompt[:150]}")
|
||
if not settings.AI_API_KEY or not settings.AI_IMAGE_EDIT_MODEL:
|
||
raise HTTPException(400, "AI_API_KEY 或 AI_IMAGE_EDIT_MODEL 未配置")
|
||
|
||
try:
|
||
# 构造图片列表:成衣 + 裁片截图(如果有)
|
||
images = [data.image_base64]
|
||
if data.canvas_base64:
|
||
images.append(data.canvas_base64)
|
||
|
||
default_prompt = (
|
||
"Image 1 是一件成衣照片,Image 2 是裁片轮廓图。"
|
||
"请将 Image 1 成衣上的面料花样提取出来,填充到 Image 2 的每个裁片轮廓内。"
|
||
"要求:保持花样颜色、比例、方向一致;保留裁片轮廓线;不要改变裁片排列位置。"
|
||
) if data.canvas_base64 else (
|
||
"请从这件成衣中提取面料花样,生成一张干净的花样平铺图。"
|
||
)
|
||
|
||
result_url, desc = call_image_model(
|
||
images_b64=images,
|
||
prompt=data.prompt or default_prompt,
|
||
model_override=data.image_edit_model,
|
||
)
|
||
return {"code": 200, "data": {"image_url": result_url, "description": desc}}
|
||
except Exception as e:
|
||
log.error(f"预览生成失败: {e}", exc_info=True)
|
||
err_msg = str(e)
|
||
if "data_inspection_failed" in err_msg or "inappropriate" in err_msg:
|
||
raise HTTPException(400, "图片内容未通过平台安全审核,请更换图片后重试(避免使用含版权角色/敏感内容的图片)")
|
||
raise HTTPException(500, f"预览生成失败: {err_msg}")
|
||
|
||
|
||
class CropPieceRequest(BaseModel):
|
||
"""从预览图中按坐标裁切指定裁片区域"""
|
||
preview_base64: str # 预览图 base64(从 URL 下载后传入)
|
||
canvas_width: int # PS 画布原始宽度
|
||
canvas_height: int # PS 画布原始高度
|
||
piece_left: float # 裁片在画布中的 left (px)
|
||
piece_top: float # 裁片在画布中的 top (px)
|
||
piece_width: float # 裁片宽度 (px)
|
||
piece_height: float # 裁片高度 (px)
|
||
piece_name: str # 裁片名称(用于日志)
|
||
padding: float = 0.15 # 出血线比例(向外扩 15%)
|
||
|
||
|
||
@router.post("/ai/crop-piece")
|
||
async def crop_piece(
|
||
data: CropPieceRequest,
|
||
current_username: str = Depends(get_current_user)
|
||
):
|
||
"""从预览图中按坐标裁切指定裁片区域(Pillow,像素级精准)"""
|
||
log.info(f"{'='*60}")
|
||
log.info(f"[CropPiece] 裁切: {data.piece_name}")
|
||
log.info(f"[CropPiece] 画布: {data.canvas_width}x{data.canvas_height}, 裁片区域: ({data.piece_left},{data.piece_top}) {data.piece_width}x{data.piece_height}")
|
||
|
||
try:
|
||
result_b64 = _crop_piece_from_preview(data)
|
||
return {"code": 200, "data": {"cropped_base64": result_b64, "piece_name": data.piece_name}}
|
||
except Exception as e:
|
||
log.error(f"裁切失败: {e}", exc_info=True)
|
||
raise HTTPException(500, f"裁切失败: {str(e)}")
|
||
|
||
|
||
class RefinePieceRequest(BaseModel):
|
||
"""AI 细化裁切后的花样图"""
|
||
cropped_base64: str # 裁切后的图片 base64
|
||
piece_name: str # 裁片名称
|
||
pattern_type: str = "fill_pattern" # fill_pattern / theme_pattern / mixed_fill / mixed_theme
|
||
aspect_ratio: Optional[str] = None # 裁片宽高比,如 "3:4"
|
||
piece_width: Optional[int] = None # 裁片像素宽
|
||
piece_height: Optional[int] = None # 裁片像素高
|
||
prompt: Optional[str] = None
|
||
image_edit_model: Optional[str] = None
|
||
|
||
|
||
@router.post("/ai/refine-piece")
|
||
async def refine_piece(
|
||
data: RefinePieceRequest,
|
||
current_username: str = Depends(get_current_user)
|
||
):
|
||
"""AI 细化裁切图:根据图案类型生成不同提示词"""
|
||
log.info(f"{'='*60}")
|
||
log.info(f"[RefinePiece] 细化: {data.piece_name}, 类型: {data.pattern_type}, 比例: {data.aspect_ratio}")
|
||
log.info(f"[RefinePiece] 裁片尺寸: {data.piece_width}x{data.piece_height}, 图片: {len(data.cropped_base64)//1024}KB")
|
||
|
||
if not settings.AI_API_KEY or not settings.AI_IMAGE_EDIT_MODEL:
|
||
raise HTTPException(400, "AI_API_KEY 或 AI_IMAGE_EDIT_MODEL 未配置")
|
||
|
||
try:
|
||
# 尺寸描述
|
||
size_hint = ""
|
||
if data.aspect_ratio:
|
||
size_hint = f"输出图片宽高比必须为 {data.aspect_ratio}。"
|
||
if data.piece_width and data.piece_height:
|
||
size_hint += f"输出尺寸至少 {int(data.piece_width * 1.15)}x{int(data.piece_height * 1.15)} 像素。"
|
||
|
||
# 根据类型生成提示词
|
||
if data.pattern_type == "fill_pattern":
|
||
prompt = data.prompt or (
|
||
f"这是从服装裁片中裁切出来的花型图案区域。"
|
||
f"请将它处理为一张干净的矩形花型图:"
|
||
f"去掉所有轮廓线、标签文字和边缘空白;"
|
||
f"花型图案要完整铺满整个矩形,无空白边距;"
|
||
f"向四周自然延展重复纹样,保持花型连续无接缝;"
|
||
f"保持花样的颜色和细节清晰。{size_hint}"
|
||
)
|
||
elif data.pattern_type == "theme_pattern":
|
||
prompt = data.prompt or (
|
||
f"这是从服装裁片中裁切出来的主题图案区域(如卡通人物、Logo、印花)。"
|
||
f"请提取其中的主题图案,输出为纯白色背景(#FFFFFF)上的主题图案:"
|
||
f"主题图案要完整清晰,保持原始颜色和细节;"
|
||
f"背景必须是纯白色(#FFFFFF),不要任何其他底色或纹理;"
|
||
f"去掉轮廓线和标签文字;主题图案在画面中居中放置。{size_hint}"
|
||
)
|
||
elif data.pattern_type == "mixed_fill":
|
||
prompt = data.prompt or (
|
||
f"这是从服装裁片中裁切出来的区域,包含主题图案和花型底纹。"
|
||
f"请只提取底部的花型/纹理图案,去掉上面的主题图案(人物/Logo):"
|
||
f"用底纹自然填充主题图案原来的位置;"
|
||
f"花型要完整铺满整个矩形,保持纹样连续;"
|
||
f"去掉轮廓线和标签文字。{size_hint}"
|
||
)
|
||
elif data.pattern_type == "mixed_theme":
|
||
prompt = data.prompt or (
|
||
f"这是从服装裁片中裁切出来的区域,包含主题图案和花型底纹。"
|
||
f"请只提取主题图案(人物/Logo/大面积印花),输出为纯白色背景(#FFFFFF):"
|
||
f"主题图案要完整清晰,保持原始颜色和细节;"
|
||
f"背景必须是纯白色(#FFFFFF),去掉所有底纹;"
|
||
f"去掉轮廓线和标签文字;主题图案在画面中居中放置。{size_hint}"
|
||
)
|
||
else:
|
||
prompt = data.prompt or f"细化这张图片,去掉轮廓线和标签文字,保持清晰。{size_hint}"
|
||
|
||
log.info(f"[RefinePiece] 提示词: {prompt[:200]}")
|
||
|
||
result_url, desc = call_image_model(
|
||
images_b64=[data.cropped_base64],
|
||
prompt=prompt,
|
||
model_override=data.image_edit_model,
|
||
)
|
||
return {"code": 200, "data": {"refined_url": result_url, "piece_name": data.piece_name, "pattern_type": data.pattern_type}}
|
||
except Exception as e:
|
||
log.error(f"细化失败: {e}", exc_info=True)
|
||
err_msg = str(e)
|
||
if "data_inspection_failed" in err_msg:
|
||
raise HTTPException(400, "图片内容未通过安全审核")
|
||
raise HTTPException(500, f"细化失败: {err_msg}")
|
||
|
||
|
||
def _crop_piece_from_preview(data: CropPieceRequest) -> str:
|
||
"""用 Pillow 从预览图中按 PS 坐标裁切指定区域(带出血线)"""
|
||
from PIL import Image
|
||
import io
|
||
|
||
# 解码预览图
|
||
img_bytes = base64.b64decode(data.preview_base64)
|
||
preview_img = Image.open(io.BytesIO(img_bytes))
|
||
pw, ph = preview_img.size
|
||
|
||
# 坐标映射:PS 画布坐标 → 预览图像素坐标
|
||
scale_x = pw / data.canvas_width
|
||
scale_y = ph / data.canvas_height
|
||
|
||
# 出血线(按裁片尺寸的百分比向外扩)
|
||
bleed_x = data.piece_width * data.padding
|
||
bleed_y = data.piece_height * data.padding
|
||
|
||
# PS 坐标(含出血)→ 预览图像素
|
||
left = max(0, (data.piece_left - bleed_x) * scale_x)
|
||
top = max(0, (data.piece_top - bleed_y) * scale_y)
|
||
right = min(pw, (data.piece_left + data.piece_width + bleed_x) * scale_x)
|
||
bottom = min(ph, (data.piece_top + data.piece_height + bleed_y) * scale_y)
|
||
|
||
log.info(f"[CropPiece] PS 画布: {data.canvas_width}x{data.canvas_height}")
|
||
log.info(f"[CropPiece] 预览图: {pw}x{ph}")
|
||
log.info(f"[CropPiece] 缩放比: x={scale_x:.4f}, y={scale_y:.4f}")
|
||
log.info(f"[CropPiece] PS 裁片区域: ({data.piece_left},{data.piece_top}) {data.piece_width}x{data.piece_height}")
|
||
log.info(f"[CropPiece] 出血线: {data.padding*100:.0f}% → x±{bleed_x:.0f}px, y±{bleed_y:.0f}px")
|
||
log.info(f"[CropPiece] 预览图裁切: ({left:.0f},{top:.0f})-({right:.0f},{bottom:.0f}) = {right-left:.0f}x{bottom-top:.0f}px")
|
||
|
||
cropped = preview_img.crop((int(left), int(top), int(right), int(bottom)))
|
||
|
||
# 调试:保存裁切结果
|
||
import os, time
|
||
debug_dir = os.path.join(os.path.dirname(__file__), '..', '..', '..', 'debug_images')
|
||
os.makedirs(debug_dir, exist_ok=True)
|
||
ts = int(time.time())
|
||
debug_path = os.path.join(debug_dir, f'{ts}_crop_{data.piece_name}.png')
|
||
cropped.save(debug_path)
|
||
log.info(f"[CropPiece] 裁切结果已保存: {debug_path}")
|
||
|
||
# 转 base64
|
||
buf = io.BytesIO()
|
||
cropped.save(buf, format='PNG')
|
||
return base64.b64encode(buf.getvalue()).decode()
|
||
|
||
|
||
class ExtractPieceRequest(BaseModel):
|
||
"""从预览图中提取单个裁片花样(旧版 AI 提取,保留兼容)"""
|
||
preview_base64: str
|
||
piece_name: str
|
||
piece_description: str = ""
|
||
prompt: Optional[str] = None
|
||
|
||
|
||
@router.post("/ai/extract-piece-pattern")
|
||
async def extract_piece_pattern(
|
||
data: ExtractPieceRequest,
|
||
current_username: str = Depends(get_current_user)
|
||
):
|
||
"""阶段2:从预览图中提取指定裁片区域 → 输出矩形花样图"""
|
||
log.info(f"{'='*60}")
|
||
log.info(f"[ExtractPiece] 裁片: {data.piece_name}, 描述: '{data.piece_description}', 预览图: {len(data.preview_base64)//1024}KB")
|
||
if not settings.AI_API_KEY or not settings.AI_IMAGE_EDIT_MODEL:
|
||
raise HTTPException(400, "AI_API_KEY 或 AI_IMAGE_EDIT_MODEL 未配置")
|
||
|
||
try:
|
||
prompt = data.prompt or (
|
||
f"这是一张服装裁片花样预览图,包含多个裁片。"
|
||
f"请提取其中【{data.piece_name}】的花样区域"
|
||
f"{'(' + data.piece_description + ')' if data.piece_description else ''},"
|
||
f"将该区域的花样输出为一张完整的矩形图片。"
|
||
f"要求:只保留花样内容,去掉轮廓线,填满整个矩形,保持清晰。"
|
||
)
|
||
result_url, desc = call_image_model(
|
||
images_b64=[data.preview_base64],
|
||
prompt=prompt,
|
||
model_override=getattr(data, 'image_edit_model', None),
|
||
)
|
||
return {"code": 200, "data": {"pattern_url": result_url, "piece_name": data.piece_name, "description": desc}}
|
||
except Exception as e:
|
||
log.error(f"裁片提取失败: {e}", exc_info=True)
|
||
raise HTTPException(500, f"裁片提取失败: {str(e)}")
|
||
|
||
|
||
class IdentifyPiecesRequest(BaseModel):
|
||
"""识别裁片部位 + 分析颜色花样"""
|
||
canvas_base64: str # 画布截图(裁片轮廓)
|
||
garment_base64: Optional[str] = None # 成衣照片(用于分析颜色花样)
|
||
layers_info: str # 图层信息
|
||
vision_model: Optional[str] = None
|
||
|
||
|
||
@router.post("/ai/identify-pieces")
|
||
async def identify_pieces(
|
||
data: IdentifyPiecesRequest,
|
||
current_username: str = Depends(get_current_user)
|
||
):
|
||
"""用视觉模型识别裁片部位 + 分析成衣各部位的颜色花样"""
|
||
if not settings.AI_API_KEY:
|
||
raise HTTPException(400, "AI_API_KEY 未配置")
|
||
|
||
try:
|
||
result = _identify_garment_pieces(data.canvas_base64, data.layers_info, data.garment_base64, vision_model=data.vision_model)
|
||
return {"code": 200, "data": result}
|
||
except Exception as e:
|
||
log.error(f"裁片识别失败: {e}", exc_info=True)
|
||
raise HTTPException(500, f"裁片识别失败: {str(e)}")
|
||
|
||
|
||
def _identify_garment_pieces(canvas_b64: str, layers_info: str, garment_b64: str = None, vision_model: str = None) -> dict:
|
||
"""用视觉模型分析画布+成衣,识别裁片部位并判断颜色/花样(自动路由 Qwen/Gemini)"""
|
||
|
||
use_model = vision_model or settings.AI_VISION_MODEL
|
||
|
||
log.info(f"{'='*60}")
|
||
log.info(f"[IdentifyPieces] 调用视觉模型识别裁片 + 分析花样")
|
||
log.info(f"[IdentifyPieces] 模型: {use_model}")
|
||
log.info(f"[IdentifyPieces] 画布: {len(canvas_b64)//1024}KB, 成衣: {len(garment_b64)//1024 if garment_b64 else 0}KB")
|
||
log.info(f"[IdentifyPieces] 图层:\n{layers_info[:500]}")
|
||
|
||
# 构造提示词
|
||
if garment_b64:
|
||
prompt = f"""你需要完成两个任务:
|
||
|
||
**任务1:识别裁片部位**
|
||
Image 2 是 PS 文档中的裁片轮廓图。以下是图层信息:
|
||
{layers_info}
|
||
|
||
请根据形状、大小识别每个图层是什么服装部位(前片、后片、袖子、领口等)。
|
||
|
||
**任务2:分析成衣颜色花样**
|
||
Image 1 是成衣照片。请分析这件衣服各个部位的视觉特征:
|
||
- 哪些部位有印花/花样/图案?描述花样内容
|
||
- 哪些部位是纯色/素色?给出颜色的 hex 值
|
||
- 衣服整体的底色是什么?
|
||
|
||
用 JSON 回答,格式:
|
||
```json
|
||
{{
|
||
"mapping": {{
|
||
"M-1": "前片",
|
||
"M-2": "后片"
|
||
}},
|
||
"analysis": [
|
||
{{"layer": "M-1", "piece": "前片", "type": "theme_pattern", "color": "#F5E6D0", "description": "卡通女孩图案,米白色纯色底"}},
|
||
{{"layer": "M-2", "piece": "后片", "type": "solid", "color": "#F5E6D0", "description": "米白色纯色"}},
|
||
{{"layer": "M-3", "piece": "左袖", "type": "fill_pattern", "description": "蓝白条纹重复图案"}}
|
||
],
|
||
"base_color": "#F5E6D0"
|
||
}}
|
||
```
|
||
|
||
type 只有四种:
|
||
- solid — 纯色,必须给 color(hex值)
|
||
- fill_pattern — 重复性花型(碎花/条纹/格子/波点),整片铺满相同纹样
|
||
- theme_pattern — 主题图案(卡通/Logo/大印花)在纯色底上,必须给 color(底色 hex值)
|
||
- mixed_pattern — 主题图案叠在花型底纹上(最复杂),需描述主题和底纹
|
||
只返回 JSON,不要其他文字。"""
|
||
else:
|
||
prompt = f"""这是一张 PS 文档裁片截图。图层信息:
|
||
{layers_info}
|
||
|
||
识别每个图层的服装部位,用 JSON 回答:
|
||
```json
|
||
{{"mapping": {{"M-1": "前片", "M-2": "后片"}}}}
|
||
```
|
||
只返回 JSON。"""
|
||
|
||
# ---------- Gemini 路由 ----------
|
||
if is_gemini_model(use_model):
|
||
log.info(f"[IdentifyPieces] 使用 Gemini 视觉: {use_model}")
|
||
images = []
|
||
if garment_b64:
|
||
images.append(garment_b64)
|
||
images.append(canvas_b64)
|
||
|
||
content = call_gemini(
|
||
[{"role": "user", "content": prompt}],
|
||
use_model,
|
||
images_b64=images
|
||
)
|
||
log.info(f"[IdentifyPieces] Gemini 回复: {content[:500]}")
|
||
else:
|
||
# ---------- Qwen / OpenAI 路由 ----------
|
||
from openai import OpenAI
|
||
|
||
client = OpenAI(
|
||
api_key=settings.AI_API_KEY,
|
||
base_url=settings.AI_BASE_URL or "https://api.openai.com/v1",
|
||
)
|
||
|
||
content_parts = []
|
||
if garment_b64:
|
||
content_parts.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{garment_b64}"}})
|
||
content_parts.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{canvas_b64}"}})
|
||
content_parts.append({"type": "text", "text": prompt})
|
||
|
||
completion = client.chat.completions.create(
|
||
model=use_model,
|
||
messages=[{"role": "user", "content": content_parts}],
|
||
)
|
||
|
||
content = completion.choices[0].message.content or ""
|
||
log.info(f"[IdentifyPieces] 视觉模型回复: {content[:500]}")
|
||
|
||
# 提取 JSON
|
||
try:
|
||
clean = content.strip()
|
||
if "```" in clean:
|
||
clean = clean.split("```")[1]
|
||
if clean.startswith("json"):
|
||
clean = clean[4:]
|
||
clean = clean.strip()
|
||
parsed = json.loads(clean)
|
||
|
||
# 兼容旧格式(纯 mapping)和新格式(mapping + analysis)
|
||
if "mapping" in parsed:
|
||
return parsed
|
||
else:
|
||
return {"mapping": parsed}
|
||
except json.JSONDecodeError:
|
||
log.warning(f"[IdentifyPieces] JSON 解析失败")
|
||
return {"mapping": {}, "raw_response": content}
|
||
|
||
|
||
@router.post("/ai/verify-result")
|
||
async def verify_result(
|
||
data: VerifyResultRequest,
|
||
current_username: str = Depends(get_current_user)
|
||
):
|
||
"""对比原始成衣和套图结果,给出验证反馈"""
|
||
if not settings.AI_API_KEY:
|
||
raise HTTPException(400, "AI_API_KEY 未配置")
|
||
|
||
try:
|
||
feedback = verify_pattern_result(data.garment_base64, data.canvas_base64, data.prompt, vision_model=data.vision_model)
|
||
return {"code": 200, "data": {"feedback": feedback}}
|
||
except Exception as e:
|
||
log.error(f"验证失败: {e}", exc_info=True)
|
||
raise HTTPException(500, f"验证失败: {str(e)}")
|
||
|
||
|