1271 lines
51 KiB
Python
1271 lines
51 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
AI 聊天接口
|
||
功能:对话管理 + 消息存库 + function calling(工具调用)
|
||
"""
|
||
|
||
from fastapi import APIRouter, HTTPException, Depends
|
||
from pydantic import BaseModel
|
||
from typing import List, Optional
|
||
import json, base64, re, logging
|
||
from sqlalchemy.orm import Session as DBSession
|
||
from app.core.config import settings
|
||
from app.core.security import get_current_user
|
||
from app.db import get_db
|
||
from app.models.user import User
|
||
from app.models.chat import ChatSession, ChatMessage
|
||
from app.api.v1.ai_tools import PS_TOOLS, TOOL_DISPLAY_NAMES
|
||
|
||
log = logging.getLogger(__name__)
# Configure the module logger only once: uvicorn --reload re-imports this
# module and would otherwise stack duplicate handlers.
if not log.handlers:
    _h = logging.StreamHandler()
    _h.setFormatter(
        logging.Formatter('[%(asctime)s][%(name)s][%(levelname)s] %(message)s')
    )
    log.addHandler(_h)
    # Do not hand records to ancestor loggers as well.
    log.propagate = False
    log.setLevel(logging.INFO)
|
||
|
||
# Router on which every AI endpoint declared in this module is registered.
router = APIRouter()
|
||
|
||
# ==================== 数据模型 ====================
|
||
|
||
# Request body of POST /ai/chat.
class ChatRequest(BaseModel):
    # Text typed by the user.
    message: str
    # Conversation to continue; the chat endpoint creates one when it is missing.
    session_id: Optional[int] = None
    # Optional image attached to the message, base64 encoded.
    image_base64: Optional[str] = None
    # Chat model picked by the frontend (overrides the default).
    model: Optional[str] = None
    # Vision model picked by the frontend.
    vision_model: Optional[str] = None
    # Image-edit model picked by the frontend.
    image_edit_model: Optional[str] = None
|
||
|
||
class ToolResultRequest(BaseModel):
    """Result reported back by the frontend after it executed a tool."""

    # Conversation the tool call belongs to.
    session_id: int
    # Identifier of the tool call being answered.
    tool_call_id: str
    # Name of the tool that was executed.
    tool_name: str
    # Textual result of the tool.
    result: str
    # Chat model picked by the frontend.
    model: Optional[str] = None
    vision_model: Optional[str] = None
    image_edit_model: Optional[str] = None
|
||
|
||
class GeneratePatternRequest(BaseModel):
    """Generate a fabric pattern from a garment photo."""

    # Garment photo, base64 encoded.
    image_base64: str
    # Screenshot of the pattern-piece outlines, base64 encoded.
    canvas_base64: Optional[str] = None
    # Optional custom prompt.
    prompt: Optional[str] = None
    image_edit_model: Optional[str] = None
|
||
|
||
class VerifyResultRequest(BaseModel):
    """Verify the result of applying a pattern to the pieces."""

    # Original garment photo.
    garment_base64: str
    # Canvas screenshot taken after the pattern was applied.
    canvas_base64: str
    prompt: Optional[str] = None
    vision_model: Optional[str] = None
|
||
|
||
# ==================== System Prompt ====================
|
||
|
||
# System prompt for the tool-calling chat model.
# The last rule used to demand running "generate -> apply -> verify" in one go
# without ever pausing, which contradicted the mandatory wait for user
# confirmation after stage 1. It is now scoped to stage 2 only.
SYSTEM_PROMPT = """你是 DesignerCEP 的 AI 助手,运行在 Adobe Photoshop CEP 插件中。

你的能力:
1. 回答关于 Photoshop 操作和插件使用的问题
2. 通过工具直接操作 Photoshop(创建图层、对齐、查看文档信息等)
3. 帮用户排查操作中遇到的错误
4. **AI 智能套图**:两阶段流程 — 先生成预览确认,再提取套到裁片上

## AI 智能套图流程(两阶段)

当用户上传成衣图片并要求套图时:

### 阶段 1 — 识别裁片 + 生成预览(需用户确认)
1. 调用 **identify_pieces** — 截取画布并识别每个图层是什么裁片部位(前片、后片、袖子等)
2. 告诉用户识别结果(如 "M-1=前片, M-2=后片, M-5=左袖...")
3. 调用 **generate_garment_preview** — 会自动在裁片下方标注名称标签(如 "M-1 前片"),然后截取带标签的画布 + 成衣照片一起发给 AI 生成预览
4. 预览图显示在聊天中,每个裁片都有标签,**等用户确认 OK 后**进入阶段 2

### 阶段 2 — 提取花样 + 正式套图(用户确认后)
5. 根据用户反馈决定每个裁片的处理方式:
- 需要花样的裁片 → extract_and_apply_all_pieces 中不设 color 字段
- 纯色的裁片(如后幅) → 设置 color 字段为对应颜色值(如 "#F5E6D0"),直接 PS 填充不调 AI
6. 调用 **extract_and_apply_all_pieces** 执行套图
7. 可选:调用 **verify_pattern_result** 验证效果

重要:
- 阶段 1 完成后必须**等用户说"可以"/"OK"**才执行阶段 2
- 用户可能会说"袖子纯色"、"后幅不要花样"等,要根据 identify_pieces 的结果对应到正确的图层名
- 纯色裁片用 color 字段直接填充,不要浪费 AI 提取调用

重要规则:
- 当用户要求执行 PS 操作时,使用工具完成
- 执行操作前可以先了解当前文档和图层状态
- 用简洁的中文回答,适合在小面板中阅读
- 如果工具执行失败,向用户解释原因并建议解决方案
- 用户确认预览后,阶段 2 应一次性完成「提取 → 套图 → 验证」全流程,不要中途停下等待用户指令
"""
|
||
|
||
# System prompt for the vision model: asks for a structured, professional
# analysis of a garment photo along seven dimensions.
VISION_PROMPT = """你是一位资深的服装设计分析师,同时也是 DesignerCEP 的 AI 助手。

当用户发送服装/成衣图片时,请从以下维度进行专业分析:

1. **服装类别** — 上衣/裤子/裙子/连衣裙/外套/配饰等,细分款式
2. **面料分析** — 根据视觉特征推测面料类型(棉、涤纶、丝绸、针织、牛仔、雪纺等),分析面料质感
3. **颜色与印花** — 主色调、配色方案、印花/图案类型及工艺(数码印花、丝网印刷、提花等)
4. **版型特点** — 修身/宽松/A字/H型等,分析领口、袖型、肩线、腰线、下摆处理
5. **工艺细节** — 缝线工艺、拉链/纽扣/暗扣、口袋设计、装饰细节、包边/锁边
6. **设计评价** — 设计亮点、风格定位(休闲/正装/运动/时尚等)、目标消费群体
7. **改进建议** — 如有可改进之处,给出专业建议

规则:
- 如果图片不是服装相关,也请尽力分析图片内容并给出有价值的反馈
- 如果用户同时提了文字问题,请结合图片和问题一起回答
- 用清晰、结构化的中文回答,适合在设计工作中参考
- 回答要专业但不啰嗦,突出重点信息
"""
|
||
|
||
# ==================== 对话管理接口 ====================
|
||
|
||
@router.get("/ai/models")
async def get_models(current_username: str = Depends(get_current_user)):
    """Return the selectable models together with the configured defaults.

    ``current_username`` is not read; the dependency is declared so that
    ``get_current_user`` runs for this route.
    """
    chat_options = [
        {"id": "qwen3-max-2026-01-23", "name": "Qwen3 Max"},
        {"id": "gemini-2.5-pro-thinking", "name": "Gemini 2.5 Pro"},
    ]
    vision_options = [
        {"id": "qwen-vl-max-latest", "name": "Qwen VL Max"},
        {"id": "gemini-2.5-flash-image", "name": "Gemini 2.5 Flash"},
    ]
    image_edit_options = [
        {"id": "qwen-image-edit-max-2026-01-16", "name": "Qwen 图片编辑"},
        {"id": "gemini-3-pro-image-preview", "name": "Gemini 3 Pro 生图"},
    ]

    payload = {
        # Defaults come from the application settings.
        "chat_model": settings.AI_MODEL,
        "vision_model": settings.AI_VISION_MODEL,
        "image_edit_model": settings.AI_IMAGE_EDIT_MODEL,
        "chat_models": chat_options,
        "vision_models": vision_options,
        "image_edit_models": image_edit_options,
    }
    return {"code": 200, "data": payload}
|
||
|
||
|
||
@router.get("/ai/sessions")
async def list_sessions(
    db: DBSession = Depends(get_db),
    current_username: str = Depends(get_current_user)
):
    """List the current user's conversations, most recently updated first.

    At most 50 conversations are returned.

    Raises:
        HTTPException: 404 when the authenticated user has no database row.
    """
    user = db.query(User).filter(User.username == current_username).first()
    if not user:
        raise HTTPException(status_code=404, detail="用户不存在")

    def _iso(ts):
        # Timestamps may be NULL; serialise those as None.
        return ts.isoformat() if ts else None

    own_sessions = db.query(ChatSession).filter(ChatSession.user_id == user.id)
    recent = own_sessions.order_by(ChatSession.updated_at.desc()).limit(50).all()

    rows = []
    for item in recent:
        rows.append({
            "id": item.id,
            "title": item.title,
            "created_at": _iso(item.created_at),
            "updated_at": _iso(item.updated_at),
        })
    return {"code": 200, "data": rows}
|
||
|
||
|
||
@router.get("/ai/sessions/{session_id}/messages")
async def get_session_messages(
    session_id: int,
    db: DBSession = Depends(get_db),
    current_username: str = Depends(get_current_user)
):
    """Return every message of one conversation owned by the current user.

    Messages are ordered by creation time; the primary key is used as a
    tiebreak so that messages sharing a timestamp keep insertion order.

    Raises:
        HTTPException: 404 when the user or the conversation does not exist
            (a conversation owned by somebody else counts as missing).
    """
    user = db.query(User).filter(User.username == current_username).first()
    if not user:
        raise HTTPException(status_code=404, detail="用户不存在")

    session = db.query(ChatSession).filter(
        ChatSession.id == session_id, ChatSession.user_id == user.id
    ).first()
    if not session:
        raise HTTPException(status_code=404, detail="对话不存在")

    messages = (
        db.query(ChatMessage)
        .filter(ChatMessage.session_id == session_id)
        # created_at alone is not a total order (messages written within the
        # same timestamp tick), so fall back to the id.
        .order_by(ChatMessage.created_at, ChatMessage.id)
        .all()
    )

    return {
        "code": 200,
        "data": {
            "session_id": session.id,
            "title": session.title,
            "messages": [{
                "id": m.id, "role": m.role, "content": m.content,
                "tool_calls": m.tool_calls,
                "created_at": m.created_at.isoformat() if m.created_at else None,
            } for m in messages]
        }
    }
|
||
|
||
|
||
@router.delete("/ai/sessions/{session_id}")
async def delete_session(
    session_id: int,
    db: DBSession = Depends(get_db),
    current_username: str = Depends(get_current_user)
):
    """Delete one conversation of the current user together with its messages.

    Raises:
        HTTPException: 404 when the user or the conversation does not exist.
    """
    user = db.query(User).filter(User.username == current_username).first()
    if not user:
        raise HTTPException(status_code=404, detail="用户不存在")

    # Filtering on user_id ensures only the owner can delete the conversation.
    owned = db.query(ChatSession).filter(
        ChatSession.id == session_id, ChatSession.user_id == user.id
    ).first()
    if not owned:
        raise HTTPException(status_code=404, detail="对话不存在")

    # Remove the messages first, then the conversation row, in one commit.
    message_rows = db.query(ChatMessage).filter(ChatMessage.session_id == session_id)
    message_rows.delete()
    db.delete(owned)
    db.commit()
    return {"code": 200, "message": "对话已删除"}
|
||
|
||
|
||
# ==================== 聊天接口 ====================
|
||
|
||
def _load_recent_history(db: DBSession, session_id: int, limit: int = 20) -> List[dict]:
    """Return the newest ``limit`` messages of a conversation, oldest first.

    Only the needed rows are fetched. The id is used as a tiebreak so that
    messages sharing a ``created_at`` value keep insertion order. Rows with
    role "tool" are reported as "user" so the LLM API does not reject the
    history for lacking a matching tool_calls structure.
    """
    rows = (
        db.query(ChatMessage)
        .filter(ChatMessage.session_id == session_id)
        .order_by(ChatMessage.created_at.desc(), ChatMessage.id.desc())
        .limit(limit)
        .all()
    )
    rows.reverse()
    return [
        {"role": "user" if m.role == "tool" else m.role, "content": m.content}
        for m in rows
    ]


async def _generate_reply(call_history: List[dict], model: Optional[str],
                          mock_text: str, error_prefix: str, log_tag: str) -> tuple:
    """Ask the configured LLM for the next assistant turn.

    Returns ``(content, tool_calls_data)``. When no AI_API_KEY is configured
    and the requested model is not a Gemini model, ``mock_text`` is returned.
    Any exception from the LLM call is logged and turned into a reply that
    starts with ``error_prefix``, so the endpoint still answers with code 200.
    """
    # The OpenAI client is synchronous; run it in a worker thread so the
    # event loop is not blocked for the duration of the request.
    from fastapi.concurrency import run_in_threadpool

    use_gemini = bool(model) and bool(_is_gemini_model(model))
    try:
        if not settings.AI_API_KEY and not use_gemini:
            return mock_text, None
        if use_gemini:
            log.info(f"[{log_tag}] 路由到 Gemini: {model}")
            return await run_in_threadpool(_call_gemini_with_tools, call_history, model)
        return await run_in_threadpool(
            _call_llm_with_tools, call_history, model_override=model
        )
    except Exception as e:
        log.exception(f"[{log_tag}] LLM 调用失败")
        return f"{error_prefix}: {str(e)}", None


def _save_assistant_reply(db: DBSession, session_id: int, reply_content: str,
                          tool_calls_data: Optional[list]) -> dict:
    """Persist the assistant turn and build the endpoint's response payload."""
    tc_json = json.dumps(tool_calls_data, ensure_ascii=False) if tool_calls_data else None
    db.add(ChatMessage(
        session_id=session_id, role="assistant",
        content=reply_content or "", tool_calls=tc_json
    ))
    db.commit()
    return {
        "code": 200,
        "data": {
            "session_id": session_id,
            "content": reply_content or "",
            "tool_calls": tool_calls_data
        }
    }


@router.post("/ai/chat")
async def chat(
    data: ChatRequest,
    db: DBSession = Depends(get_db),
    current_username: str = Depends(get_current_user)
):
    """Store a user message and return the AI reply (possibly with tool calls).

    Raises:
        HTTPException: 404 when the authenticated user has no database row.
    """
    log.info(f"{'='*60}")
    log.info(f"[Chat] 收到消息: '{data.message[:80]}' session={data.session_id} 有图片={'是' if data.image_base64 else '否'}")
    user = db.query(User).filter(User.username == current_username).first()
    if not user:
        raise HTTPException(status_code=404, detail="用户不存在")

    # 1. Fetch the conversation, or create one when the id is missing or is
    #    not owned by this user.
    session = None
    if data.session_id:
        session = db.query(ChatSession).filter(
            ChatSession.id == data.session_id, ChatSession.user_id == user.id
        ).first()

    if not session:
        title = data.message[:30] + ("..." if len(data.message) > 30 else "")
        session = ChatSession(user_id=user.id, username=current_username, title=title)
        db.add(session)
        db.flush()

    # 2. Store the user message; images are only marked, base64 is not stored.
    has_image = bool(data.image_base64)
    save_content = ("[图片分析] " + data.message) if has_image else data.message
    db.add(ChatMessage(session_id=session.id, role="user", content=save_content))
    db.commit()
    session_id = session.id

    # 3. Load the recent history (the message saved above is its last entry).
    call_history = _load_recent_history(db, session_id)
    if has_image:
        # Replace the stored marker text with a hint that tells the tool
        # model an uploaded image is available to the tools.
        call_history = call_history[:-1] + [{
            "role": "user",
            "content": f"[用户上传了一张图片,已保存可供工具使用] {data.message or '请处理这张图片'}"
        }]

    # 4. Call the LLM (always the tool-calling model).
    reply_content, tool_calls_data = await _generate_reply(
        call_history,
        data.model,
        mock_text=_mock_reply(data.message, has_image),
        error_prefix="AI 请求出错",
        log_tag="Chat",
    )

    # 5. Store and return the assistant reply.
    return _save_assistant_reply(db, session_id, reply_content, tool_calls_data)


@router.post("/ai/tool-result")
async def submit_tool_result(
    data: ToolResultRequest,
    db: DBSession = Depends(get_db),
    current_username: str = Depends(get_current_user)
):
    """Store a tool result sent by the frontend and let the AI continue.

    Raises:
        HTTPException: 404 when the user or the conversation does not exist.
    """
    log.info(f"{'='*60}")
    log.info(f"[ToolResult] 工具: {data.tool_name}, 结果: {data.result[:150]}")
    user = db.query(User).filter(User.username == current_username).first()
    if not user:
        raise HTTPException(status_code=404, detail="用户不存在")

    session = db.query(ChatSession).filter(
        ChatSession.id == data.session_id, ChatSession.user_id == user.id
    ).first()
    if not session:
        raise HTTPException(status_code=404, detail="对话不存在")

    # 1. Store the tool result as its own message.
    db.add(ChatMessage(
        session_id=session.id, role="tool",
        content=f"[工具 {data.tool_name} 执行结果]: {data.result}"
    ))
    db.commit()
    session_id = session.id

    # 2. Reload the history so the AI sees the tool result.
    history_list = _load_recent_history(db, session_id)

    # 3. Call the LLM again. Tools stay enabled, so the reply may carry
    #    further tool calls.
    reply_content, tool_calls_data = await _generate_reply(
        history_list,
        data.model,
        mock_text=f"工具 {data.tool_name} 执行完成。",
        error_prefix="AI 总结出错",
        log_tag="ToolResult",
    )

    # 4. Store and return the assistant reply.
    return _save_assistant_reply(db, session_id, reply_content, tool_calls_data)
|
||
|
||
|
||
# ==================== LLM 调用 ====================
|
||
|
||
def _call_llm_with_tools(messages_history: List[dict], model_override: Optional[str] = None) -> tuple:
    """Call the OpenAI-compatible chat model with function calling enabled.

    Args:
        messages_history: Prior turns as ``{"role", "content"}`` dicts; the
            system prompt is prepended here.
        model_override: Model to use instead of ``settings.AI_MODEL``.

    Returns:
        ``(content, tool_calls_data)``. ``tool_calls_data`` is ``None`` for a
        plain text reply, otherwise a list of dicts with the keys ``id``,
        ``name``, ``display_name``, ``args`` and ``status``.
    """
    from openai import OpenAI

    use_model = model_override or settings.AI_MODEL

    client = OpenAI(
        api_key=settings.AI_API_KEY,
        base_url=settings.AI_BASE_URL or "https://api.openai.com/v1",
    )

    messages = [{"role": "system", "content": SYSTEM_PROMPT}] + messages_history

    # ---- Log the request ----
    log.info(f"{'='*60}")
    log.info(f"[LLM] 调用工具模型: {use_model}{' (override)' if model_override else ''}")
    log.info(f"[LLM] 消息数量: {len(messages)} (system + {len(messages_history)} history)")
    tail = messages_history[-5:]  # only the five most recent turns
    for i, m in enumerate(tail):
        content = m['content'][:120] if m.get('content') else '(empty)'
        # Label each entry with its real negative index in the history.
        log.info(f"[LLM] history[-{len(tail) - i}] {m['role']}: {content}")
    log.info(f"[LLM] 工具数量: {len(PS_TOOLS)}")

    completion = client.chat.completions.create(
        model=use_model,
        messages=messages,
        tools=PS_TOOLS,
        tool_choice="auto",
    )

    choice = completion.choices[0]
    message = choice.message

    # ---- Log the response ----
    log.info(f"[LLM] 响应 finish_reason={choice.finish_reason}")
    if message.content:
        log.info(f"[LLM] 回复文本: {message.content[:150]}...")
    if message.tool_calls:
        for tc in message.tool_calls:
            # arguments may be empty/None; do not crash while logging.
            log.info(f"[LLM] 工具调用: {tc.function.name}({(tc.function.arguments or '')[:200]})")
    else:
        log.info(f"[LLM] 无工具调用")
    log.info(f"{'='*60}")

    # Plain text reply.
    if not message.tool_calls:
        return message.content or "", None

    tool_calls_data = []
    for tc in message.tool_calls:
        args = {}
        if tc.function.arguments:
            try:
                args = json.loads(tc.function.arguments)
            except json.JSONDecodeError:
                # Keep going with empty args, but leave a trace in the log.
                log.warning(f"[LLM] 工具参数不是合法 JSON: {tc.function.name}")

        tool_calls_data.append({
            "id": tc.id,
            "name": tc.function.name,
            "display_name": TOOL_DISPLAY_NAMES.get(tc.function.name, tc.function.name),
            "args": args,
            "status": "pending"
        })

    return message.content or "", tool_calls_data
|
||
|
||
|
||
# ==================== Gemini 调用 ====================
|
||
|
||
def _is_gemini_model(model_name: str) -> bool:
|
||
"""判断是否是 Gemini 模型"""
|
||
return model_name and "gemini" in model_name.lower()
|
||
|
||
|
||
def _call_gemini(messages_history: List[dict], model: str, images_b64: Optional[List[str]] = None) -> str:
    """Call Gemini through the OpenAI-compatible proxy and return its text.

    Args:
        messages_history: Turns as ``{"role", "content"}`` dicts. Roles other
            than system / user / assistant are sent as "user".
        model: Model id to request.
        images_b64: Optional base64 images; they are attached, before the
            text, to the last user message.

    Raises:
        ValueError: when GEMINI_API_KEY or GEMINI_BASE_URL is not configured.
    """
    from openai import OpenAI as _OpenAI

    if not settings.GEMINI_API_KEY or not settings.GEMINI_BASE_URL:
        raise ValueError("GEMINI_API_KEY 或 GEMINI_BASE_URL 未配置")

    client = _OpenAI(
        api_key=settings.GEMINI_API_KEY,
        # Tolerate a trailing slash in the configured URL (avoids ".../​/v1").
        base_url=f"{settings.GEMINI_BASE_URL.rstrip('/')}/v1",
    )

    # Build OpenAI-format messages (fresh dicts, the input is not mutated).
    allowed_roles = ("system", "user", "assistant")
    messages = []
    for msg in messages_history:
        role = msg.get("role", "user")
        messages.append({
            "role": role if role in allowed_roles else "user",
            "content": msg.get("content", ""),
        })

    if images_b64:
        # Turn the last user message into multimodal content.
        last_user_idx = next(
            (i for i in range(len(messages) - 1, -1, -1) if messages[i]["role"] == "user"),
            None,
        )
        if last_user_idx is not None:
            text_content = messages[last_user_idx]["content"]
            multimodal_content = [
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img_b64}"}}
                for img_b64 in images_b64
            ]
            multimodal_content.append({"type": "text", "text": text_content})
            messages[last_user_idx]["content"] = multimodal_content

    log.info(f"{'='*60}")
    log.info(f"[Gemini] 调用模型: {model} (OpenAI 兼容)")
    log.info(f"[Gemini] 消息数: {len(messages)}, 图片数: {len(images_b64) if images_b64 else 0}")

    completion = client.chat.completions.create(
        model=model,
        messages=messages,
    )

    result = completion.choices[0].message.content or ""
    log.info(f"[Gemini] 回复: {result[:200]}...")
    return result
|
||
|
||
|
||
def _call_gemini_with_tools(messages_history: List[dict], model: str) -> tuple:
    """Run a chat turn on Gemini and return ``(content, None)``.

    No tool definitions are sent on this path, so the second element is
    always ``None``; the system prompt is the only guidance about tools.
    """
    # Prepend the system prompt.
    full_history = [{"role": "system", "content": SYSTEM_PROMPT}] + messages_history

    log.info(f"{'='*60}")
    log.info(f"[Gemini] 调用对话模型: {model}")
    tail = messages_history[-3:]
    for i, m in enumerate(tail):
        # Label each entry with its real negative index in the history.
        log.info(f"[Gemini] history[-{len(tail) - i}] {m['role']}: {str(m.get('content',''))[:120]}")

    result = _call_gemini(full_history, model)

    log.info(f"[Gemini] 回复文本: {result[:150]}...")
    log.info(f"[Gemini] 无工具调用(Gemini 不支持 function calling)")

    return result, None
|
||
|
||
|
||
def _call_vision_llm(user_message: str, image_base64: str, history: List[dict], model_override: str = None) -> str:
    """Analyse an image with the vision model (Gemini or OpenAI-compatible).

    Args:
        user_message: Text that accompanies the image.
        image_base64: The image, base64 encoded.
        history: Earlier turns; only the last ten are sent, as plain text.
        model_override: Model to use instead of ``settings.AI_VISION_MODEL``.

    Returns:
        The model's text answer ("" when the non-Gemini route returns no content).
    """
    use_model = model_override or settings.AI_VISION_MODEL

    def _prompt_and_recent_turns():
        # System prompt followed by the last ten turns; "tool" becomes "user".
        turns = [{"role": "system", "content": VISION_PROMPT}]
        for past in history[-10:]:
            speaker = "user" if past["role"] == "tool" else past["role"]
            turns.append({"role": speaker, "content": past["content"]})
        return turns

    # ---------- Gemini route ----------
    if _is_gemini_model(use_model):
        log.info(f"[Vision] 使用 Gemini 视觉模型: {use_model}")
        gemini_messages = _prompt_and_recent_turns()
        gemini_messages.append({"role": "user", "content": user_message})
        return _call_gemini(gemini_messages, use_model, images_b64=[image_base64])

    # ---------- Qwen / OpenAI route ----------
    from openai import OpenAI

    client = OpenAI(
        api_key=settings.AI_API_KEY,
        base_url=settings.AI_BASE_URL or "https://api.openai.com/v1",
    )

    # Current user turn: image first, then the text.
    image_part = {
        "type": "image_url",
        "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}
    }
    text_part = {"type": "text", "text": user_message}

    messages = _prompt_and_recent_turns()
    messages.append({"role": "user", "content": [image_part, text_part]})

    completion = client.chat.completions.create(
        model=use_model,
        messages=messages,
    )
    answer = completion.choices[0].message.content
    return answer or ""
|
||
|
||
|
||
def _mock_reply(message: str, has_image: bool = False) -> str:
|
||
"""未配置 API Key 时的模拟回复"""
|
||
if has_image:
|
||
return "【模拟分析】收到图片。AI 分析功能需要配置 AI_API_KEY 和 AI_VISION_MODEL。\n\n配置后可以分析:服装类别、面料、颜色印花、版型、工艺细节等。"
|
||
if "套图" in message:
|
||
return "套图功能在「参数预设」页面。先选择花样组和裁片组,再添加规则,最后点击生成。"
|
||
elif "对齐" in message:
|
||
return "图层对齐功能在「参数预设」页面顶部。支持上下左右居中对齐、领口对齐等。"
|
||
elif "裁片" in message or "PLT" in message:
|
||
return "PLT 裁片处理在「PLT 裁片处理」页面。上传 PLT 文件,选择尺码,点击开始处理。"
|
||
else:
|
||
return "你好!我是 DesignerCEP AI 助手。你可以问我关于套图、裁片、对齐等功能的问题。"
|
||
|
||
|
||
# ==================== 图案生成 & 验证 ====================
|
||
|
||
@router.post("/ai/generate-preview")
async def generate_preview(
    data: GeneratePatternRequest,
    current_username: str = Depends(get_current_user)
):
    """Stage 1: garment photo (+ optional piece screenshot) -> AI preview image.

    Raises:
        HTTPException: 400 when AI_API_KEY / AI_IMAGE_EDIT_MODEL is missing or
            the upstream error looks like a content-safety rejection; 500 for
            any other failure.
    """
    log.info("=" * 60)
    canvas_kb = len(data.canvas_base64)//1024 if data.canvas_base64 else 0
    log.info(f"[Preview] 收到预览请求, 成衣图: {len(data.image_base64)//1024}KB, 裁片图: {canvas_kb}KB")
    if data.prompt:
        log.info(f"[Preview] 自定义提示词: {data.prompt[:150]}")
    if not (settings.AI_API_KEY and settings.AI_IMAGE_EDIT_MODEL):
        raise HTTPException(400, "AI_API_KEY 或 AI_IMAGE_EDIT_MODEL 未配置")

    try:
        # Garment photo first, then the piece screenshot when there is one.
        images = [data.image_base64] + ([data.canvas_base64] if data.canvas_base64 else [])

        if data.canvas_base64:
            default_prompt = (
                "Image 1 是一件成衣照片,Image 2 是裁片轮廓图。"
                "请将 Image 1 成衣上的面料花样提取出来,填充到 Image 2 的每个裁片轮廓内。"
                "要求:保持花样颜色、比例、方向一致;保留裁片轮廓线;不要改变裁片排列位置。"
            )
        else:
            default_prompt = "请从这件成衣中提取面料花样,生成一张干净的花样平铺图。"

        result_url, desc = _call_image_model(
            images_b64=images,
            prompt=data.prompt or default_prompt,
            model_override=data.image_edit_model,
        )
        return {"code": 200, "data": {"image_url": result_url, "description": desc}}
    except Exception as e:
        log.error(f"预览生成失败: {e}", exc_info=True)
        err_msg = str(e)
        # These markers in the upstream error text are reported as a 400.
        rejected = any(marker in err_msg for marker in ("data_inspection_failed", "inappropriate"))
        if rejected:
            raise HTTPException(400, "图片内容未通过平台安全审核,请更换图片后重试(避免使用含版权角色/敏感内容的图片)")
        raise HTTPException(500, f"预览生成失败: {err_msg}")
|
||
|
||
|
||
class CropPieceRequest(BaseModel):
    """Crop one piece's region out of the preview image by coordinates."""

    # Preview image, base64 encoded (downloaded from its URL by the caller).
    preview_base64: str
    # Original width of the Photoshop canvas.
    canvas_width: int
    # Original height of the Photoshop canvas.
    canvas_height: int
    # Left edge of the piece on the canvas (px).
    piece_left: float
    # Top edge of the piece on the canvas (px).
    piece_top: float
    # Width of the piece (px).
    piece_width: float
    # Height of the piece (px).
    piece_height: float
    # Piece name (used for logging).
    piece_name: str
    # Bleed ratio: the region is grown outwards by this share (15% by default).
    padding: float = 0.15
|
||
|
||
|
||
@router.post("/ai/crop-piece")
async def crop_piece(
    data: CropPieceRequest,
    current_username: str = Depends(get_current_user)
):
    """Crop one piece's region out of the preview image (done with Pillow).

    Raises:
        HTTPException: 500 when cropping fails for any reason.
    """
    log.info("=" * 60)
    log.info(f"[CropPiece] 裁切: {data.piece_name}")
    log.info(f"[CropPiece] 画布: {data.canvas_width}x{data.canvas_height}, 裁片区域: ({data.piece_left},{data.piece_top}) {data.piece_width}x{data.piece_height}")

    try:
        cropped_b64 = _crop_piece_from_preview(data)
    except Exception as e:
        log.error(f"裁切失败: {e}", exc_info=True)
        raise HTTPException(500, f"裁切失败: {str(e)}")

    payload = {"cropped_base64": cropped_b64, "piece_name": data.piece_name}
    return {"code": 200, "data": payload}
|
||
|
||
|
||
class RefinePieceRequest(BaseModel):
    """Let the AI refine a cropped pattern image."""

    # Cropped image, base64 encoded.
    cropped_base64: str
    # Piece name.
    piece_name: str
    # One of: fill_pattern / theme_pattern / mixed_fill / mixed_theme.
    pattern_type: str = "fill_pattern"
    # Aspect ratio of the piece, e.g. "3:4".
    aspect_ratio: Optional[str] = None
    # Pixel width of the piece.
    piece_width: Optional[int] = None
    # Pixel height of the piece.
    piece_height: Optional[int] = None
    prompt: Optional[str] = None
    image_edit_model: Optional[str] = None
|
||
|
||
|
||
@router.post("/ai/refine-piece")
async def refine_piece(
    data: RefinePieceRequest,
    current_username: str = Depends(get_current_user)
):
    """Refine a cropped image with a prompt chosen by ``pattern_type``.

    A caller-supplied ``data.prompt`` replaces the built-in prompt entirely.

    Raises:
        HTTPException: 400 when AI_API_KEY / AI_IMAGE_EDIT_MODEL is missing or
            the upstream error contains "data_inspection_failed"; 500 otherwise.
    """
    log.info("=" * 60)
    log.info(f"[RefinePiece] 细化: {data.piece_name}, 类型: {data.pattern_type}, 比例: {data.aspect_ratio}")
    log.info(f"[RefinePiece] 裁片尺寸: {data.piece_width}x{data.piece_height}, 图片: {len(data.cropped_base64)//1024}KB")

    if not (settings.AI_API_KEY and settings.AI_IMAGE_EDIT_MODEL):
        raise HTTPException(400, "AI_API_KEY 或 AI_IMAGE_EDIT_MODEL 未配置")

    try:
        # Size requirements appended to every built-in prompt.
        hints = []
        if data.aspect_ratio:
            hints.append(f"输出图片宽高比必须为 {data.aspect_ratio}。")
        if data.piece_width and data.piece_height:
            hints.append(f"输出尺寸至少 {int(data.piece_width * 1.15)}x{int(data.piece_height * 1.15)} 像素。")
        size_hint = "".join(hints)

        # Built-in prompt per pattern type.
        templates = {
            "fill_pattern": (
                "这是从服装裁片中裁切出来的花型图案区域。"
                "请将它处理为一张干净的矩形花型图:"
                "去掉所有轮廓线、标签文字和边缘空白;"
                "花型图案要完整铺满整个矩形,无空白边距;"
                "向四周自然延展重复纹样,保持花型连续无接缝;"
                "保持花样的颜色和细节清晰。"
            ),
            "theme_pattern": (
                "这是从服装裁片中裁切出来的主题图案区域(如卡通人物、Logo、印花)。"
                "请提取其中的主题图案,输出为纯白色背景(#FFFFFF)上的主题图案:"
                "主题图案要完整清晰,保持原始颜色和细节;"
                "背景必须是纯白色(#FFFFFF),不要任何其他底色或纹理;"
                "去掉轮廓线和标签文字;主题图案在画面中居中放置。"
            ),
            "mixed_fill": (
                "这是从服装裁片中裁切出来的区域,包含主题图案和花型底纹。"
                "请只提取底部的花型/纹理图案,去掉上面的主题图案(人物/Logo):"
                "用底纹自然填充主题图案原来的位置;"
                "花型要完整铺满整个矩形,保持纹样连续;"
                "去掉轮廓线和标签文字。"
            ),
            "mixed_theme": (
                "这是从服装裁片中裁切出来的区域,包含主题图案和花型底纹。"
                "请只提取主题图案(人物/Logo/大面积印花),输出为纯白色背景(#FFFFFF):"
                "主题图案要完整清晰,保持原始颜色和细节;"
                "背景必须是纯白色(#FFFFFF),去掉所有底纹;"
                "去掉轮廓线和标签文字;主题图案在画面中居中放置。"
            ),
        }
        fallback = "细化这张图片,去掉轮廓线和标签文字,保持清晰。"
        prompt = data.prompt or (templates.get(data.pattern_type, fallback) + size_hint)

        log.info(f"[RefinePiece] 提示词: {prompt[:200]}")

        result_url, desc = _call_image_model(
            images_b64=[data.cropped_base64],
            prompt=prompt,
            model_override=data.image_edit_model,
        )
        return {"code": 200, "data": {"refined_url": result_url, "piece_name": data.piece_name, "pattern_type": data.pattern_type}}
    except Exception as e:
        log.error(f"细化失败: {e}", exc_info=True)
        err_msg = str(e)
        if "data_inspection_failed" in err_msg:
            raise HTTPException(400, "图片内容未通过安全审核")
        raise HTTPException(500, f"细化失败: {err_msg}")
|
||
|
||
|
||
def _crop_piece_from_preview(data: CropPieceRequest) -> str:
    """Crop a piece's region (plus bleed) out of the preview image with Pillow.

    The piece rectangle is given in Photoshop canvas coordinates; it is grown
    by ``data.padding`` of its own size on every side, scaled to the preview
    image's pixel grid and clamped to the image bounds.

    Returns:
        The cropped region as a base64-encoded PNG.

    Raises:
        ValueError: when the canvas size is not positive or the resulting
            crop rectangle is empty.
    """
    from PIL import Image
    import io
    import os
    import time

    # Guard against division by zero in the scale computation below.
    if data.canvas_width <= 0 or data.canvas_height <= 0:
        raise ValueError("canvas_width 和 canvas_height 必须大于 0")

    # Decode the preview image.
    img_bytes = base64.b64decode(data.preview_base64)
    preview_img = Image.open(io.BytesIO(img_bytes))
    pw, ph = preview_img.size

    # Map Photoshop canvas coordinates to preview pixels.
    scale_x = pw / data.canvas_width
    scale_y = ph / data.canvas_height

    # Bleed: grow the rectangle by a share of the piece size.
    bleed_x = data.piece_width * data.padding
    bleed_y = data.piece_height * data.padding

    left = max(0, (data.piece_left - bleed_x) * scale_x)
    top = max(0, (data.piece_top - bleed_y) * scale_y)
    right = min(pw, (data.piece_left + data.piece_width + bleed_x) * scale_x)
    bottom = min(ph, (data.piece_top + data.piece_height + bleed_y) * scale_y)

    log.info(f"[CropPiece] PS 画布: {data.canvas_width}x{data.canvas_height}")
    log.info(f"[CropPiece] 预览图: {pw}x{ph}")
    log.info(f"[CropPiece] 缩放比: x={scale_x:.4f}, y={scale_y:.4f}")
    log.info(f"[CropPiece] PS 裁片区域: ({data.piece_left},{data.piece_top}) {data.piece_width}x{data.piece_height}")
    log.info(f"[CropPiece] 出血线: {data.padding*100:.0f}% → x±{bleed_x:.0f}px, y±{bleed_y:.0f}px")
    log.info(f"[CropPiece] 预览图裁切: ({left:.0f},{top:.0f})-({right:.0f},{bottom:.0f}) = {right-left:.0f}x{bottom-top:.0f}px")

    box = (int(left), int(top), int(right), int(bottom))
    if box[2] <= box[0] or box[3] <= box[1]:
        # The piece lies outside the preview (or has no area).
        raise ValueError(f"裁切区域为空: {box}")
    cropped = preview_img.crop(box)

    # Debug dump of the crop. Best effort: a failure here must not fail the
    # request. piece_name comes from the client, so strip everything that is
    # not a word character or '-' before using it in a file name.
    try:
        debug_dir = os.path.join(os.path.dirname(__file__), '..', '..', '..', 'debug_images')
        os.makedirs(debug_dir, exist_ok=True)
        safe_name = re.sub(r'[^\w\-]+', '_', data.piece_name)
        debug_path = os.path.join(debug_dir, f'{int(time.time())}_crop_{safe_name}.png')
        cropped.save(debug_path)
        log.info(f"[CropPiece] 裁切结果已保存: {debug_path}")
    except OSError as e:
        log.warning(f"[CropPiece] 调试图片保存失败: {e}")

    # Encode as base64 PNG.
    buf = io.BytesIO()
    cropped.save(buf, format='PNG')
    return base64.b64encode(buf.getvalue()).decode()
|
||
|
||
|
||
class ExtractPieceRequest(BaseModel):
    """Extract one piece's pattern from the preview (legacy AI extraction, kept for compatibility)."""

    # Preview image, base64 encoded.
    preview_base64: str
    # Name of the piece to extract.
    piece_name: str
    # Optional description that is added to the built-in prompt.
    piece_description: str = ""
    # Optional prompt that replaces the built-in one.
    prompt: Optional[str] = None
    # Image-edit model picked by the frontend. Without this field the value
    # sent by the client was dropped and the endpoint always used the default.
    image_edit_model: Optional[str] = None


@router.post("/ai/extract-piece-pattern")
async def extract_piece_pattern(
    data: ExtractPieceRequest,
    current_username: str = Depends(get_current_user)
):
    """Stage 2: extract the named piece from the preview as a rectangular pattern image.

    Raises:
        HTTPException: 400 when AI_API_KEY / AI_IMAGE_EDIT_MODEL is missing;
            500 when the image model call fails.
    """
    log.info(f"{'='*60}")
    log.info(f"[ExtractPiece] 裁片: {data.piece_name}, 描述: '{data.piece_description}', 预览图: {len(data.preview_base64)//1024}KB")
    if not settings.AI_API_KEY or not settings.AI_IMAGE_EDIT_MODEL:
        raise HTTPException(400, "AI_API_KEY 或 AI_IMAGE_EDIT_MODEL 未配置")

    try:
        prompt = data.prompt or (
            f"这是一张服装裁片花样预览图,包含多个裁片。"
            f"请提取其中【{data.piece_name}】的花样区域"
            f"{'(' + data.piece_description + ')' if data.piece_description else ''},"
            f"将该区域的花样输出为一张完整的矩形图片。"
            f"要求:只保留花样内容,去掉轮廓线,填满整个矩形,保持清晰。"
        )
        result_url, desc = _call_image_model(
            images_b64=[data.preview_base64],
            prompt=prompt,
            model_override=data.image_edit_model,
        )
        return {"code": 200, "data": {"pattern_url": result_url, "piece_name": data.piece_name, "description": desc}}
    except Exception as e:
        log.error(f"裁片提取失败: {e}", exc_info=True)
        raise HTTPException(500, f"裁片提取失败: {str(e)}")
|
||
|
||
|
||
class IdentifyPiecesRequest(BaseModel):
    """Identify garment pieces and analyse colours / patterns."""

    # Canvas screenshot showing the piece outlines.
    canvas_base64: str
    # Garment photo, used for the colour / pattern analysis.
    garment_base64: Optional[str] = None
    # Layer information, inserted into the prompt as text.
    layers_info: str
    vision_model: Optional[str] = None
|
||
|
||
|
||
@router.post("/ai/identify-pieces")
async def identify_pieces(
    data: IdentifyPiecesRequest,
    current_username: str = Depends(get_current_user)
):
    """Identify the garment pieces on the canvas and analyse colours / patterns.

    Raises:
        HTTPException: 400 when a non-Gemini vision model is selected and
            AI_API_KEY is missing; 500 when the identification fails.
    """
    # AI_API_KEY is only needed on the non-Gemini route; a Gemini vision
    # model is called with the Gemini credentials instead.
    use_model = data.vision_model or settings.AI_VISION_MODEL
    if not _is_gemini_model(use_model) and not settings.AI_API_KEY:
        raise HTTPException(400, "AI_API_KEY 未配置")

    try:
        result = _identify_garment_pieces(data.canvas_base64, data.layers_info, data.garment_base64, vision_model=data.vision_model)
        return {"code": 200, "data": result}
    except Exception as e:
        log.error(f"裁片识别失败: {e}", exc_info=True)
        raise HTTPException(500, f"裁片识别失败: {str(e)}")
|
||
|
||
|
||
def _identify_garment_pieces(canvas_b64: str, layers_info: str, garment_b64: str = None, vision_model: str = None) -> dict:
    """Ask a vision model to label canvas layers and, optionally, analyse the garment.

    Args:
        canvas_b64: base64 screenshot of the piece outlines (always sent).
        layers_info: textual layer listing interpolated into the prompt.
        garment_b64: optional base64 garment photo; when given, the prompt also
            requests a per-piece colour/pattern analysis and the garment is
            sent as the first image.
        vision_model: overrides ``settings.AI_VISION_MODEL`` when truthy.

    Returns:
        The parsed JSON object when it contains a ``"mapping"`` key, otherwise
        ``{"mapping": <parsed object>}``. If no JSON object can be extracted,
        ``{"mapping": {}, "raw_response": <model text>}``.
    """
    use_model = vision_model or settings.AI_VISION_MODEL

    log.info(f"{'='*60}")
    log.info("[IdentifyPieces] 调用视觉模型识别裁片 + 分析花样")
    log.info(f"[IdentifyPieces] 模型: {use_model}")
    log.info(f"[IdentifyPieces] 画布: {len(canvas_b64)//1024}KB, 成衣: {len(garment_b64)//1024 if garment_b64 else 0}KB")
    log.info(f"[IdentifyPieces] 图层:\n{layers_info[:500]}")

    # Build the prompt. With a garment photo the model gets two tasks
    # (label the layers + analyse colours); without it only the labelling task.
    if garment_b64:
        prompt = f"""你需要完成两个任务:

**任务1:识别裁片部位**
Image 2 是 PS 文档中的裁片轮廓图。以下是图层信息:
{layers_info}

请根据形状、大小识别每个图层是什么服装部位(前片、后片、袖子、领口等)。

**任务2:分析成衣颜色花样**
Image 1 是成衣照片。请分析这件衣服各个部位的视觉特征:
- 哪些部位有印花/花样/图案?描述花样内容
- 哪些部位是纯色/素色?给出颜色的 hex 值
- 衣服整体的底色是什么?

用 JSON 回答,格式:
```json
{{
"mapping": {{
"M-1": "前片",
"M-2": "后片"
}},
"analysis": [
{{"layer": "M-1", "piece": "前片", "type": "theme_pattern", "color": "#F5E6D0", "description": "卡通女孩图案,米白色纯色底"}},
{{"layer": "M-2", "piece": "后片", "type": "solid", "color": "#F5E6D0", "description": "米白色纯色"}},
{{"layer": "M-3", "piece": "左袖", "type": "fill_pattern", "description": "蓝白条纹重复图案"}}
],
"base_color": "#F5E6D0"
}}
```

type 只有四种:
- solid — 纯色,必须给 color(hex值)
- fill_pattern — 重复性花型(碎花/条纹/格子/波点),整片铺满相同纹样
- theme_pattern — 主题图案(卡通/Logo/大印花)在纯色底上,必须给 color(底色 hex值)
- mixed_pattern — 主题图案叠在花型底纹上(最复杂),需描述主题和底纹
只返回 JSON,不要其他文字。"""
    else:
        prompt = f"""这是一张 PS 文档裁片截图。图层信息:
{layers_info}

识别每个图层的服装部位,用 JSON 回答:
```json
{{"mapping": {{"M-1": "前片", "M-2": "后片"}}}}
```
只返回 JSON。"""

    # Image order matters: the prompt calls the garment "Image 1" and the
    # canvas "Image 2", so the garment (when present) must be sent first.
    images = [garment_b64, canvas_b64] if garment_b64 else [canvas_b64]

    if _is_gemini_model(use_model):
        # ---------- Gemini route ----------
        log.info(f"[IdentifyPieces] 使用 Gemini 视觉: {use_model}")
        # Guard against a None reply so the slicing/parsing below cannot crash.
        content = _call_gemini(
            [{"role": "user", "content": prompt}],
            use_model,
            images_b64=images
        ) or ""
        log.info(f"[IdentifyPieces] Gemini 回复: {content[:500]}")
    else:
        # ---------- Qwen / OpenAI-compatible route ----------
        from openai import OpenAI

        client = OpenAI(
            api_key=settings.AI_API_KEY,
            base_url=settings.AI_BASE_URL or "https://api.openai.com/v1",
        )

        content_parts = [
            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img}"}}
            for img in images
        ]
        content_parts.append({"type": "text", "text": prompt})

        completion = client.chat.completions.create(
            model=use_model,
            messages=[{"role": "user", "content": content_parts}],
        )

        content = completion.choices[0].message.content or ""
        log.info(f"[IdentifyPieces] 视觉模型回复: {content[:500]}")

    return _parse_pieces_response(content)


def _parse_pieces_response(content: str) -> dict:
    """Extract the JSON object from a model reply produced for piece identification.

    Tries, in order: the first fenced block (or the whole text when there is no
    fence), then the span from the first ``{`` to the last ``}`` — which covers
    replies with prose around unfenced JSON or an unexpected fence tag. Only a
    JSON *object* is accepted; anything else is treated as a parse failure so
    callers always receive a dict-shaped ``mapping``.
    """
    text = (content or "").strip()

    candidates = []
    if "```" in text:
        fenced = text.split("```")[1]
        if fenced.startswith("json"):
            fenced = fenced[4:]
        candidates.append(fenced.strip())
    else:
        candidates.append(text)

    first_brace, last_brace = text.find("{"), text.rfind("}")
    if 0 <= first_brace < last_brace:
        candidates.append(text[first_brace:last_brace + 1])

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            # Accept both the legacy shape (bare mapping) and the current
            # shape (mapping + analysis).
            return parsed if "mapping" in parsed else {"mapping": parsed}

    log.warning("[IdentifyPieces] JSON 解析失败")
    return {"mapping": {}, "raw_response": content}
|
||
|
||
|
||
@router.post("/ai/verify-result")
async def verify_result(
    data: VerifyResultRequest,
    current_username: str = Depends(get_current_user)
):
    """Compare the original garment with the rendered result and return model feedback.

    Returns ``{"code": 200, "data": {"feedback": <text>}}``.

    Raises:
        HTTPException(400): ``settings.AI_API_KEY`` is not configured.
        HTTPException(500): any exception raised while calling the model.
    """
    # Local import keeps this block self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    if not settings.AI_API_KEY:
        raise HTTPException(400, "AI_API_KEY 未配置")

    try:
        # The verification helper does blocking network I/O; run it in the
        # threadpool so the event loop stays responsive while the model answers.
        feedback = await run_in_threadpool(
            _verify_pattern_result,
            data.garment_base64,
            data.canvas_base64,
            data.prompt,
            vision_model=data.vision_model,
        )
    except Exception as e:
        log.error(f"验证失败: {e}", exc_info=True)
        raise HTTPException(500, f"验证失败: {str(e)}") from e

    return {"code": 200, "data": {"feedback": feedback}}
|
||
|
||
|
||
# ==================== 公共:图片模型调用 + 响应解析 ====================
|
||
|
||
def _call_image_model(images_b64: List[str], prompt: str, model_override: str = None) -> tuple:
    """Call the image edit/generation model, routing to Gemini or DashScope.

    Args:
        images_b64: base64-encoded input images (sent as ``image/jpeg`` data URIs).
        prompt: instruction text for the model.
        model_override: model name to use instead of ``settings.AI_IMAGE_EDIT_MODEL``.

    Returns:
        ``(image, description)`` — ``image`` is a ``data:image/...;base64,...``
        URI on the Gemini route and the image URL reported by the API on the
        DashScope route; ``description`` is any accompanying text.

    Raises:
        ValueError: missing Gemini configuration, an API error, a failed
            content-safety check, or a response that contains no image.
    """
    use_model = model_override or settings.AI_IMAGE_EDIT_MODEL

    if _is_gemini_model(use_model):
        return _image_model_via_gemini(images_b64, prompt, use_model)
    return _image_model_via_dashscope(images_b64, prompt, use_model)


def _dump_debug_image(filename: str, payload, saved_msg: str, failed_msg: str) -> None:
    """Best-effort write of one image into the ``debug_images`` directory.

    ``payload`` is raw bytes or a base64 string. Every failure (directory
    creation, decoding, writing) is logged and swallowed: a debug dump must
    never abort an otherwise successful model call.
    """
    import os

    try:
        debug_dir = os.path.join(os.path.dirname(__file__), '..', '..', '..', 'debug_images')
        os.makedirs(debug_dir, exist_ok=True)
        debug_path = os.path.join(debug_dir, filename)
        raw = base64.b64decode(payload) if isinstance(payload, str) else payload
        with open(debug_path, 'wb') as f:
            f.write(raw)
        log.info(f"[ImageModel] {saved_msg}: {debug_path}")
    except Exception as e:
        log.warning(f"[ImageModel] {failed_msg}: {e}")


def _image_model_via_gemini(images_b64: List[str], prompt: str, use_model: str) -> tuple:
    """Generate an image through the OpenAI-compatible Gemini proxy; returns (data URI, text)."""
    import time

    log.info(f"[ImageModel] 使用 Gemini 图片模型: {use_model}")

    if not settings.GEMINI_API_KEY or not settings.GEMINI_BASE_URL:
        raise ValueError("GEMINI_API_KEY 或 GEMINI_BASE_URL 未配置")

    from openai import OpenAI as _OpenAI

    # The third-party proxy exposes the OpenAI-compatible endpoint
    # (/v1/chat/completions). Strip a trailing slash so a base URL configured
    # as ".../" does not turn into "...//v1".
    client = _OpenAI(
        api_key=settings.GEMINI_API_KEY,
        base_url=f"{str(settings.GEMINI_BASE_URL).rstrip('/')}/v1",
    )

    # OpenAI-style multimodal message: the text prompt first, then every image.
    content_parts = [{"type": "text", "text": prompt}]
    content_parts.extend(
        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img_b64}"}}
        for img_b64 in images_b64
    )

    log.info("[ImageModel] Gemini OpenAI 兼容模式")
    log.info(f"[ImageModel] 模型: {use_model}")
    log.info(f"[ImageModel] 图片数: {len(images_b64)}, 各图大小: {[len(b)//1024 for b in images_b64]}KB")
    log.info(f"[ImageModel] 提示词: {prompt[:200]}")

    completion = client.chat.completions.create(
        model=use_model,
        messages=[{"role": "user", "content": content_parts}],
    )

    # An empty choices list means there is no image either; report it with the
    # same error as a text-only reply instead of failing with an IndexError.
    if not completion.choices:
        raise ValueError("Gemini 未返回图片,请检查模型是否支持图片生成")

    result_content = completion.choices[0].message.content or ""
    log.info(f"[ImageModel] Gemini 回复长度: {len(result_content)} chars")

    # The image comes back embedded in the text, either as markdown
    # ![...](data:image/...;base64,...) or as a bare data URI.
    match = re.search(r'!\[.*?\]\((data:image/(\w+);base64,([^)]+))\)', result_content)
    if not match:
        match = re.search(r'(data:image/(\w+);base64,([A-Za-z0-9+/=]+))', result_content)

    if not match:
        log.warning(f"[ImageModel] Gemini 响应中无图片,前500字: {result_content[:500]}")
        raise ValueError("Gemini 未返回图片,请检查模型是否支持图片生成")

    img_format = match.group(2)
    # The markdown pattern can capture line breaks/spaces inside the payload;
    # drop them before computing padding so the length check is meaningful.
    image_b64 = "".join(match.group(3).split())
    image_b64 += '=' * (-len(image_b64) % 4)

    log.info(f"[ImageModel] Gemini 返回图片: image/{img_format}, {len(image_b64)//1024}KB")

    ts = int(time.time())
    _dump_debug_image(
        f'{ts}_gemini_output.{img_format}',
        image_b64,
        "Gemini 输出图片已保存",
        "保存调试图片失败",
    )

    # Text description = reply minus the image. Strip both the markdown form
    # and a bare data URI, otherwise the whole base64 payload would leak into
    # the description when the reply is not markdown-formatted.
    description = re.sub(r'!\[.*?\]\(data:image/[^)]+\)', '', result_content)
    description = re.sub(r'data:image/\w+;base64,[A-Za-z0-9+/=]+', '', description).strip()

    # Return a data URI so the frontend can use it directly as <img src>.
    image_data_uri = f"data:image/{img_format};base64,{image_b64}"
    log.info("[ImageModel] Gemini 生成完成")
    return image_data_uri, description


def _image_model_via_dashscope(images_b64: List[str], prompt: str, use_model: str) -> tuple:
    """Edit/generate an image through the native DashScope API; returns (image URL, text)."""
    import time
    import requests as http_requests

    # DashScope image-edit models use the native endpoint, not /compatible-mode.
    api_url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/multimodal-generation/generation"

    # Native DashScope message format: all images first, then the text prompt.
    content_parts = [{"image": f"data:image/jpeg;base64,{img_b64}"} for img_b64 in images_b64]
    content_parts.append({"text": prompt})

    payload = {
        "model": use_model,
        "input": {
            "messages": [{
                "role": "user",
                "content": content_parts
            }]
        },
        "parameters": {
            "n": 1,
            "watermark": False,
            "prompt_extend": True,
        }
    }

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {settings.AI_API_KEY}",
    }

    log.info(f"{'='*60}")
    log.info("[ImageModel] 调用 DashScope 原生 API")
    log.info(f"[ImageModel] 模型: {use_model}")
    log.info(f"[ImageModel] 图片数: {len(images_b64)}, 各图大小: {[len(b)//1024 for b in images_b64]}KB")
    log.info(f"[ImageModel] 提示词: {prompt[:200]}")
    log.info(f"[ImageModel] 端点: {api_url}")

    # Debug aid: keep a copy of every image that is sent to the model.
    ts = int(time.time())
    for idx, img_b64 in enumerate(images_b64):
        _dump_debug_image(f'{ts}_input_{idx}.jpg', img_b64, "调试图片已保存", "保存调试图片失败")

    resp = http_requests.post(api_url, json=payload, headers=headers, timeout=120)

    if resp.status_code != 200:
        # Parse the error body defensively: it may be non-JSON, invalid JSON,
        # or JSON that is not an object.
        error_data = {}
        if resp.headers.get("content-type", "").startswith("application/json"):
            try:
                parsed_error = resp.json()
            except ValueError:
                parsed_error = None
            if isinstance(parsed_error, dict):
                error_data = parsed_error
        err_msg = error_data.get("message", resp.text[:300])
        err_code = error_data.get("code", "")
        log.error(f"[ImageModel] API 错误: {resp.status_code} {err_code} {err_msg}")
        if "data_inspection_failed" in str(error_data):
            raise ValueError("图片内容未通过安全审核,请更换图片")
        raise ValueError(f"图片模型调用失败({resp.status_code}): {err_msg}")

    data = resp.json()
    log.info(f"[ImageModel] 响应 keys: {list(data.keys())}")

    # Response layout: output.choices[0].message.content[] with "image"/"text" items.
    output = data.get("output") or {}
    choices = output.get("choices") or []

    if not choices:
        log.warning(f"[ImageModel] 无 choices: {str(data)[:500]}")
        raise ValueError("模型未返回结果")

    content_list = choices[0].get("message", {}).get("content", [])
    description = ""
    image_url = None

    for item in content_list:
        if not isinstance(item, dict):
            continue
        if "image" in item:
            # When several image items are present the last one wins.
            image_url = item["image"]
            log.info(f"[ImageModel] 获取到图片 URL: {image_url[:100]}...")
        elif "text" in item:
            description += item["text"]

    if not image_url:
        log.warning(f"[ImageModel] content 中无图片: {str(content_list)[:500]}")
        raise ValueError("模型未返回图片")

    # Debug aid: download the result and keep a local copy (best effort).
    log.info(f"[ImageModel] 输出图片 URL: {image_url[:120]}...")
    ts = int(time.time())
    try:
        out_resp = http_requests.get(image_url, timeout=60)
    except Exception as e:
        log.warning(f"[ImageModel] 保存输出图片失败: {e}")
    else:
        if out_resp.status_code == 200:
            _dump_debug_image(f'{ts}_output.png', out_resp.content, "输出图片已保存", "保存输出图片失败")

    return image_url, description
|
||
|
||
|
||
def _verify_pattern_result(garment_b64: str, canvas_b64: str, extra_prompt: str = None, vision_model: str = None) -> str:
    """Have a vision model compare the original garment with the rendered result.

    The garment photo is sent as the first image and the canvas render as the
    second; the model is asked for a 1-10 score plus improvement suggestions.
    ``vision_model`` overrides ``settings.AI_VISION_MODEL`` when truthy, and
    ``extra_prompt`` (if any) is appended to the instruction text. Returns the
    model's reply text.
    """
    use_model = vision_model or settings.AI_VISION_MODEL
    persona = "你是服装套图质量检验专家。"

    log.info(f"{'='*60}")
    log.info("[Verify] 调用视觉模型验证套图效果")
    log.info(f"[Verify] 模型: {use_model}")
    log.info(f"[Verify] 成衣图: {len(garment_b64)//1024}KB, 画布图: {len(canvas_b64)//1024}KB")

    verify_prompt = (
        "请对比这两张图片:第一张是原始成衣照片,第二张是套图结果。\n"
        "验证:1. 花样还原度(颜色/比例/方向)2. 裁片覆盖完整度 "
        "3. 对齐质量 4. 整体效果。给出评分(1-10)和具体改进建议。"
    )
    if extra_prompt:
        log.info(f"[Verify] 用户补充: {extra_prompt[:100]}")
        verify_prompt += f"\n用户补充:{extra_prompt}"

    # Order is significant: the prompt refers to "first" and "second" image.
    ordered_images = [garment_b64, canvas_b64]

    # ---------- Gemini route ----------
    if _is_gemini_model(use_model):
        log.info(f"[Verify] 使用 Gemini 视觉: {use_model}")
        # NOTE(review): the persona is sent with role "user" here but as
        # "system" on the OpenAI-compatible route below — confirm whether
        # _call_gemini requires this.
        gemini_messages = [
            {"role": "user", "content": persona},
            {"role": "user", "content": verify_prompt},
        ]
        return _call_gemini(gemini_messages, use_model, images_b64=ordered_images)

    # ---------- Qwen / OpenAI-compatible route ----------
    from openai import OpenAI

    endpoint = settings.AI_BASE_URL or "https://api.openai.com/v1"
    client = OpenAI(api_key=settings.AI_API_KEY, base_url=endpoint)

    user_parts = [
        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded}"}}
        for encoded in ordered_images
    ]
    user_parts.append({"type": "text", "text": verify_prompt})

    completion = client.chat.completions.create(
        model=use_model,
        messages=[
            {"role": "system", "content": persona},
            {"role": "user", "content": user_parts},
        ],
    )

    reply = completion.choices[0].message.content
    return reply or ""
|