415 lines
18 KiB
Python
415 lines
18 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
AI 模型调用层
|
||
统一管理所有 LLM / Vision / Image 模型的调用逻辑
|
||
支持 Qwen (DashScope) 和 Gemini (第三方代理) 两套路由
|
||
"""
|
||
|
||
from typing import List
|
||
import json, base64, re, logging
|
||
from app.core.config import settings
|
||
from app.api.v1.ai_tools import PS_TOOLS, TOOL_DISPLAY_NAMES
|
||
|
||
log = logging.getLogger(__name__)
|
||
|
||
# ==================== Prompts ====================
|
||
|
||
# System prompt (Chinese) that call_llm_with_tools and call_gemini_with_tools
# prepend to the chat history. It describes the assistant's role inside the
# Photoshop CEP plugin and the two-stage "AI pattern fitting" workflow:
# stage 1 identifies the pieces and generates a preview that the user must
# confirm, stage 2 extracts the pattern and applies it to the pieces.
# NOTE(review): this text is sent to the model verbatim, so any whitespace or
# stray characters inside the literal are part of the prompt.
SYSTEM_PROMPT = """你是 DesignerCEP 的 AI 助手,运行在 Adobe Photoshop CEP 插件中。
|
||
|
||
你的能力:
|
||
1. 回答关于 Photoshop 操作和插件使用的问题
|
||
2. 通过工具直接操作 Photoshop(创建图层、对齐、查看文档信息等)
|
||
3. 帮用户排查操作中遇到的错误
|
||
4. **AI 智能套图**:两阶段流程 — 先生成预览确认,再提取套到裁片上
|
||
|
||
## AI 智能套图流程(两阶段)
|
||
|
||
当用户上传成衣图片并要求套图时:
|
||
|
||
### 阶段 1 — 识别裁片 + 生成预览(需用户确认)
|
||
1. 调用 **identify_pieces** — 截取画布并识别每个图层是什么裁片部位(前片、后片、袖子等)
|
||
2. 告诉用户识别结果(如 "M-1=前片, M-2=后片, M-5=左袖...")
|
||
3. 调用 **generate_garment_preview** — 会自动在裁片下方标注名称标签(如 "M-1 前片"),然后截取带标签的画布 + 成衣照片一起发给 AI 生成预览
|
||
4. 预览图显示在聊天中,每个裁片都有标签,**等用户确认 OK 后**进入阶段 2
|
||
|
||
### 阶段 2 — 提取花样 + 正式套图(用户确认后)
|
||
5. 根据 identify_pieces 分析结果中每个裁片的 type 决定处理方式:
|
||
- solid → 设 color 字段,PS 直接纯色填充
|
||
- fill_pattern → AI 提取花型铺满
|
||
- theme_pattern → 底层 PS 纯色填充(设 color)+ 上层 AI 提取主题图案(白底+正片叠底)
|
||
- mixed_pattern → 底层 AI 提取花型 + 上层 AI 提取主题图案(白底+正片叠底)
|
||
6. 调用 **extract_and_apply_all_pieces** 执行套图
|
||
7. 可选:调用 **verify_pattern_result** 验证效果
|
||
|
||
重要:
|
||
- 阶段 1 完成后必须**等用户说"可以"/"OK"**才执行阶段 2
|
||
- 用户可能会说"袖子纯色"、"后幅不要花样"等,要根据 identify_pieces 的结果对应到正确的图层名
|
||
|
||
重要规则:
|
||
- 当用户要求执行 PS 操作时,使用工具完成
|
||
- 执行操作前可以先了解当前文档和图层状态
|
||
- 用简洁的中文回答,适合在小面板中阅读
|
||
- 如果工具执行失败,向用户解释原因并建议解决方案
|
||
"""
|
||
|
||
# System prompt (Chinese) used by call_vision_llm for image analysis. It asks
# the model to act as a garment-design analyst and to cover seven aspects
# (category, fabric, colour/print, silhouette, workmanship, design evaluation,
# improvement suggestions), followed by answer-style rules.
# NOTE(review): this text is sent to the model verbatim, so any whitespace or
# stray characters inside the literal are part of the prompt.
VISION_PROMPT = """你是一位资深的服装设计分析师,同时也是 DesignerCEP 的 AI 助手。
|
||
|
||
当用户发送服装/成衣图片时,请从以下维度进行专业分析:
|
||
|
||
1. **服装类别** — 上衣/裤子/裙子/连衣裙/外套/配饰等,细分款式
|
||
2. **面料分析** — 根据视觉特征推测面料类型(棉、涤纶、丝绸、针织、牛仔、雪纺等),分析面料质感
|
||
3. **颜色与印花** — 主色调、配色方案、印花/图案类型及工艺(数码印花、丝网印刷、提花等)
|
||
4. **版型特点** — 修身/宽松/A字/H型等,分析领口、袖型、肩线、腰线、下摆处理
|
||
5. **工艺细节** — 缝线工艺、拉链/纽扣/暗扣、口袋设计、装饰细节、包边/锁边
|
||
6. **设计评价** — 设计亮点、风格定位(休闲/正装/运动/时尚等)、目标消费群体
|
||
7. **改进建议** — 如有可改进之处,给出专业建议
|
||
|
||
规则:
|
||
- 如果图片不是服装相关,也请尽力分析图片内容并给出有价值的反馈
|
||
- 如果用户同时提了文字问题,请结合图片和问题一起回答
|
||
- 用清晰、结构化的中文回答,适合在设计工作中参考
|
||
- 回答要专业但不啰嗦,突出重点信息
|
||
"""
|
||
|
||
|
||
# ==================== 路由判断 ====================
|
||
|
||
def is_gemini_model(model_name: str) -> bool:
    """Report whether ``model_name`` names a Gemini model.

    The check is a case-insensitive substring match on ``"gemini"``. An empty
    or ``None`` name is never treated as Gemini.
    """
    if not model_name:
        return False
    return "gemini" in model_name.lower()
|
||
|
||
|
||
# ==================== Qwen / OpenAI 兼容调用 ====================
|
||
|
||
def _parse_tool_arguments(raw_arguments, tool_name: str) -> dict:
    """Decode the JSON argument string of a single tool call.

    Returns ``{}`` when the string is empty, is not valid JSON, or decodes to
    something other than a JSON object. The last two cases are logged so a
    malformed model response is visible instead of silently becoming "no args".
    """
    if not raw_arguments:
        return {}
    try:
        parsed = json.loads(raw_arguments)
    except json.JSONDecodeError as exc:
        log.warning(f"[LLM] 工具参数解析失败 {tool_name}: {exc}")
        return {}
    if not isinstance(parsed, dict):
        log.warning(f"[LLM] 工具参数不是 JSON 对象 {tool_name}: {type(parsed).__name__}")
        return {}
    return parsed


def call_llm_with_tools(messages_history: List[dict], model_override: str = None):
    """Call the OpenAI-compatible chat model with the Photoshop tool schema.

    Args:
        messages_history: Prior conversation turns as chat-completion message
            dicts. ``SYSTEM_PROMPT`` is prepended here.
        model_override: Model name used instead of ``settings.AI_MODEL``.

    Returns:
        A ``(text, tool_calls)`` tuple. ``text`` is the assistant reply (``""``
        when the model produced none). ``tool_calls`` is ``None`` when no tool
        was requested, otherwise a list of dicts with the keys ``id``,
        ``name``, ``display_name``, ``args`` and ``status`` (``"pending"``).
    """
    from openai import OpenAI

    use_model = model_override or settings.AI_MODEL

    client = OpenAI(
        api_key=settings.AI_API_KEY,
        base_url=settings.AI_BASE_URL or "https://api.openai.com/v1",
    )

    messages = [{"role": "system", "content": SYSTEM_PROMPT}] + messages_history

    log.info(f"{'='*60}")
    log.info(f"[LLM] 调用工具模型: {use_model}{' (override)' if model_override else ''}")
    log.info(f"[LLM] 消息数量: {len(messages)} (system + {len(messages_history)} history)")
    recent = messages_history[-5:]
    for i, m in enumerate(recent):
        role = m['role']
        # str() keeps the preview short even for list-shaped (multimodal)
        # content, which would otherwise be logged part by part in full.
        content = str(m['content'])[:120] if m.get('content') else '(empty)'
        # Label relative to the slice actually shown, so the last entry is [-1].
        log.info(f"[LLM] history[-{len(recent)-i}] {role}: {content}")
    log.info(f"[LLM] 工具数量: {len(PS_TOOLS)}")

    completion = client.chat.completions.create(
        model=use_model,
        messages=messages,
        tools=PS_TOOLS,
        tool_choice="auto",
    )

    choice = completion.choices[0]
    message = choice.message

    log.info(f"[LLM] 响应 finish_reason={choice.finish_reason}")
    if message.content:
        log.info(f"[LLM] 回复文本: {message.content[:150]}...")
    if message.tool_calls:
        for tc in message.tool_calls:
            # arguments may be empty/None for parameterless tools.
            log.info(f"[LLM] 工具调用: {tc.function.name}({(tc.function.arguments or '')[:200]})")
    else:
        log.info("[LLM] 无工具调用")
    log.info(f"{'='*60}")

    if not message.tool_calls:
        return message.content or "", None

    tool_calls_data = [
        {
            "id": tc.id,
            "name": tc.function.name,
            "display_name": TOOL_DISPLAY_NAMES.get(tc.function.name, tc.function.name),
            "args": _parse_tool_arguments(tc.function.arguments, tc.function.name),
            "status": "pending",
        }
        for tc in message.tool_calls
    ]
    return message.content or "", tool_calls_data
|
||
|
||
|
||
# ==================== Gemini 调用(OpenAI 兼容代理) ====================
|
||
|
||
def call_gemini(messages_history: List[dict], model: str, images_b64: List[str] = None) -> str:
    """Call Gemini through the third-party proxy using the OpenAI-compatible API.

    Args:
        messages_history: Chat messages. Roles other than ``system``, ``user``
            and ``assistant`` are sent as ``user``; ``None`` content is sent
            as an empty string.
        model: Model name passed to the proxy.
        images_b64: Optional base64-encoded images. They are attached, as
            ``data:image/jpeg`` URLs, in front of the text of the last user
            message. If there is no user message, a new one carrying only the
            images is appended so the images are not dropped.

    Returns:
        The reply text, or ``""`` when the model returned no content.

    Raises:
        ValueError: ``GEMINI_API_KEY`` or ``GEMINI_BASE_URL`` is not configured.
    """
    from openai import OpenAI as _OpenAI

    if not settings.GEMINI_API_KEY or not settings.GEMINI_BASE_URL:
        raise ValueError("GEMINI_API_KEY 或 GEMINI_BASE_URL 未配置")

    client = _OpenAI(
        api_key=settings.GEMINI_API_KEY,
        # Tolerate a trailing slash in the configured URL (avoids ".../​/v1").
        base_url=f"{settings.GEMINI_BASE_URL.rstrip('/')}/v1",
    )

    messages = []
    for msg in messages_history:
        role = msg.get("role", "user")
        if role not in ("system", "user", "assistant"):
            role = "user"
        content = msg.get("content")
        messages.append({"role": role, "content": "" if content is None else content})

    if images_b64:
        image_parts = [
            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img_b64}"}}
            for img_b64 in images_b64
        ]
        last_user = next((m for m in reversed(messages) if m["role"] == "user"), None)
        if last_user is None:
            messages.append({"role": "user", "content": image_parts})
        elif isinstance(last_user["content"], list):
            # Already multimodal: prepend the images instead of nesting the
            # existing parts inside a text part.
            last_user["content"] = image_parts + list(last_user["content"])
        else:
            last_user["content"] = image_parts + [{"type": "text", "text": last_user["content"]}]

    log.info(f"{'='*60}")
    log.info(f"[Gemini] 调用模型: {model} (OpenAI 兼容)")
    log.info(f"[Gemini] 消息数: {len(messages)}, 图片数: {len(images_b64) if images_b64 else 0}")

    completion = client.chat.completions.create(model=model, messages=messages)
    result = completion.choices[0].message.content or ""
    log.info(f"[Gemini] 回复: {result[:200]}...")
    return result
|
||
|
||
|
||
def call_gemini_with_tools(messages_history: List[dict], model: str) -> tuple:
    """Run a plain Gemini chat turn (no function calling).

    ``SYSTEM_PROMPT`` is prepended to ``messages_history`` and the result is
    forwarded to ``call_gemini``.

    Returns:
        ``(reply_text, None)``. The second element is always ``None``, which
        matches the "no tool calls" shape returned by ``call_llm_with_tools``.
    """
    full_history = [{"role": "system", "content": SYSTEM_PROMPT}] + messages_history

    log.info(f"{'='*60}")
    log.info(f"[Gemini] 调用对话模型: {model}")
    recent = messages_history[-3:]
    for i, m in enumerate(recent):
        # Index relative to the slice shown, so the newest entry is [-1].
        # .get() mirrors call_gemini, which also tolerates a missing role.
        log.info(f"[Gemini] history[-{len(recent)-i}] {m.get('role', 'user')}: {str(m.get('content',''))[:120]}")

    result = call_gemini(full_history, model)
    log.info(f"[Gemini] 回复文本: {result[:150]}...")
    return result, None
|
||
|
||
|
||
# ==================== 视觉模型 ====================
|
||
|
||
def call_vision_llm(user_message: str, image_base64: str, history: List[dict], model_override: str = None) -> str:
    """Analyse one image with the vision model, routing to Gemini or the OpenAI-compatible endpoint.

    Args:
        user_message: Text that accompanies the image.
        image_base64: Base64-encoded image, sent as a ``data:image/jpeg`` URL.
        history: Earlier turns; only the last ten are forwarded and ``tool``
            turns are re-labelled as ``user``.
        model_override: Model name used instead of ``settings.AI_VISION_MODEL``.

    Returns:
        The reply text (``""`` when the OpenAI-compatible model returned none).
    """
    use_model = model_override or settings.AI_VISION_MODEL

    def _prompt_with_recent_history() -> List[dict]:
        # VISION_PROMPT followed by the last ten turns of the conversation.
        turns = [{"role": "system", "content": VISION_PROMPT}]
        for past in history[-10:]:
            turns.append({
                "role": "user" if past["role"] == "tool" else past["role"],
                "content": past["content"],
            })
        return turns

    if is_gemini_model(use_model):
        log.info(f"[Vision] 使用 Gemini 视觉模型: {use_model}")
        gemini_messages = _prompt_with_recent_history()
        gemini_messages.append({"role": "user", "content": user_message})
        return call_gemini(gemini_messages, use_model, images_b64=[image_base64])

    from openai import OpenAI

    client = OpenAI(
        api_key=settings.AI_API_KEY,
        base_url=settings.AI_BASE_URL or "https://api.openai.com/v1",
    )

    image_part = {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}}
    text_part = {"type": "text", "text": user_message}

    messages = _prompt_with_recent_history()
    messages.append({"role": "user", "content": [image_part, text_part]})

    completion = client.chat.completions.create(model=use_model, messages=messages)
    reply = completion.choices[0].message.content
    return reply or ""
|
||
|
||
|
||
# ==================== 图片编辑/生成模型 ====================
|
||
|
||
def _debug_dump_image(data, filename: str) -> None:
    """Save a debug copy of an image without ever failing the caller.

    ``data`` may be raw bytes or a base64 string. Decode and filesystem errors
    are logged and swallowed, because debug artefacts must not abort a model call.
    """
    try:
        if isinstance(data, str):
            data = base64.b64decode(data)
        _save_debug_image(data, filename)
    except (ValueError, TypeError, OSError) as exc:
        log.warning(f"[ImageModel] 调试图片保存失败 {filename}: {exc}")


def _extract_inline_image(text: str):
    """Find the first inline base64 image in a model reply.

    A markdown image (``![alt](data:image/...;base64,...)``) is preferred; a
    bare ``data:image/...;base64,...`` URI is the fallback.

    Returns:
        ``(img_format, image_b64, description)`` or ``None`` when the text
        holds no inline image. ``image_b64`` has whitespace removed and is
        padded to a multiple of four. ``description`` is the reply with all
        inline images (markdown and bare) removed.
    """
    markdown_image = r'!\[.*?\]\((data:image/(\w+);base64,([^)]+))\)'
    bare_image = r'(data:image/(\w+);base64,([A-Za-z0-9+/=]+))'

    match = re.search(markdown_image, text) or re.search(bare_image, text)
    if not match:
        return None

    img_format = match.group(2)
    # A line-wrapped payload would skew the padding computation and end up
    # inside the returned data URI, so strip whitespace first.
    image_b64 = re.sub(r'\s+', '', match.group(3))
    image_b64 += '=' * (-len(image_b64) % 4)

    # Strip bare URIs too; otherwise the fallback match leaves the whole
    # base64 payload inside the description.
    description = re.sub(markdown_image, '', text)
    description = re.sub(bare_image, '', description).strip()
    return img_format, image_b64, description


def _call_gemini_image_model(images_b64: List[str], prompt: str, use_model: str) -> tuple:
    """Generate/edit an image through the Gemini OpenAI-compatible proxy."""
    log.info(f"[ImageModel] 使用 Gemini 图片模型: {use_model}")
    if not settings.GEMINI_API_KEY or not settings.GEMINI_BASE_URL:
        raise ValueError("GEMINI_API_KEY 或 GEMINI_BASE_URL 未配置")

    from openai import OpenAI as _OpenAI

    client = _OpenAI(
        api_key=settings.GEMINI_API_KEY,
        base_url=f"{settings.GEMINI_BASE_URL.rstrip('/')}/v1",
    )

    content_parts = [{"type": "text", "text": prompt}]
    content_parts.extend(
        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img_b64}"}}
        for img_b64 in images_b64
    )

    log.info(f"[ImageModel] Gemini OpenAI 兼容, 模型: {use_model}")
    log.info(f"[ImageModel] 图片数: {len(images_b64)}, 各图: {[len(b)//1024 for b in images_b64]}KB")
    log.info(f"[ImageModel] 提示词: {prompt[:200]}")

    completion = client.chat.completions.create(
        model=use_model,
        messages=[{"role": "user", "content": content_parts}],
    )
    result_content = completion.choices[0].message.content or ""
    log.info(f"[ImageModel] Gemini 回复长度: {len(result_content)} chars")

    extracted = _extract_inline_image(result_content)
    if extracted is None:
        log.warning(f"[ImageModel] Gemini 响应中无图片,前500字: {result_content[:500]}")
        raise ValueError("Gemini 未返回图片,请检查模型是否支持图片生成")
    img_format, image_b64, description = extracted

    # Validate the payload before handing it to the caller. binascii.Error is
    # a ValueError subclass, so callers keep seeing a ValueError.
    try:
        image_bytes = base64.b64decode(image_b64)
    except ValueError as exc:
        raise ValueError("Gemini 返回的图片数据无法解码") from exc

    log.info(f"[ImageModel] Gemini 返回图片: image/{img_format}, {len(image_b64)//1024}KB")
    _debug_dump_image(image_bytes, f'gemini_output.{img_format}')

    return f"data:image/{img_format};base64,{image_b64}", description


def _call_dashscope_image_model(images_b64: List[str], prompt: str, use_model: str) -> tuple:
    """Generate/edit an image through the native DashScope multimodal-generation API."""
    import requests as http_requests

    api_url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/multimodal-generation/generation"
    content_parts = [{"image": f"data:image/jpeg;base64,{img_b64}"} for img_b64 in images_b64]
    content_parts.append({"text": prompt})

    payload = {
        "model": use_model,
        "input": {"messages": [{"role": "user", "content": content_parts}]},
        "parameters": {"n": 1, "watermark": False, "prompt_extend": True},
    }
    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {settings.AI_API_KEY}"}

    log.info(f"{'='*60}")
    log.info(f"[ImageModel] DashScope 原生 API, 模型: {use_model}")
    log.info(f"[ImageModel] 图片数: {len(images_b64)}, 各图: {[len(b)//1024 for b in images_b64]}KB")
    log.info(f"[ImageModel] 提示词: {prompt[:200]}")

    for idx, img_b64 in enumerate(images_b64):
        _debug_dump_image(img_b64, f'input_{idx}.jpg')

    resp = http_requests.post(api_url, json=payload, headers=headers, timeout=120)
    if resp.status_code != 200:
        error_data = {}
        if resp.headers.get("content-type", "").startswith("application/json"):
            # A malformed error body must not mask the real HTTP failure.
            try:
                error_data = resp.json()
            except ValueError:
                error_data = {}
        if not isinstance(error_data, dict):
            error_data = {"message": str(error_data)}
        err_msg = error_data.get("message", resp.text[:300])
        log.error(f"[ImageModel] API 错误: {resp.status_code} {err_msg}")
        if "data_inspection_failed" in str(error_data):
            raise ValueError("图片内容未通过安全审核,请更换图片")
        raise ValueError(f"图片模型调用失败({resp.status_code}): {err_msg}")

    data = resp.json()
    # "or" guards against explicit nulls in the response as well as missing keys.
    output = data.get("output") or {}
    choices = output.get("choices") or []
    if not choices:
        raise ValueError("模型未返回结果")

    content_list = (choices[0].get("message") or {}).get("content") or []
    image_url = None
    description = ""
    for item in content_list:
        if not isinstance(item, dict):
            continue
        if "image" in item:
            image_url = item["image"]
        elif "text" in item:
            description += item["text"]

    if not image_url:
        raise ValueError("模型未返回图片")

    log.info(f"[ImageModel] 输出图片 URL: {image_url[:120]}...")
    # Debug-only download of the result; failures are logged, never raised.
    try:
        out_resp = http_requests.get(image_url, timeout=60)
        if out_resp.status_code == 200:
            _debug_dump_image(out_resp.content, 'output.png')
    except Exception as exc:
        log.warning(f"[ImageModel] 调试输出图片下载失败: {exc}")

    return image_url, description


def call_image_model(images_b64: List[str], prompt: str, model_override: str = None) -> tuple:
    """Call the image edit/generation model, routing to Gemini or DashScope.

    Args:
        images_b64: Base64-encoded input images (``None`` is treated as no images).
        prompt: Instruction text for the model.
        model_override: Model name used instead of ``settings.AI_IMAGE_EDIT_MODEL``.

    Returns:
        ``(image, description)``. For Gemini ``image`` is a
        ``data:image/<fmt>;base64,...`` URI extracted from the reply; for
        DashScope it is the image value returned by the API (logged as a URL).
        ``description`` is any accompanying text.

    Raises:
        ValueError: Gemini is not configured, the API reported an error, the
            content failed the safety inspection, or no image was returned.
    """
    use_model = model_override or settings.AI_IMAGE_EDIT_MODEL
    images = list(images_b64 or [])

    if is_gemini_model(use_model):
        return _call_gemini_image_model(images, prompt, use_model)
    return _call_dashscope_image_model(images, prompt, use_model)
|
||
|
||
|
||
# ==================== 验证套图效果 ====================
|
||
|
||
def verify_pattern_result(garment_b64: str, canvas_b64: str, extra_prompt: str = None, vision_model: str = None) -> str:
    """Ask the vision model to compare the original garment photo with the fitted result.

    Args:
        garment_b64: Base64 image of the original garment (sent first).
        canvas_b64: Base64 image of the pattern-fitting result (sent second).
        extra_prompt: Optional user remark appended to the verification prompt.
        vision_model: Model name used instead of ``settings.AI_VISION_MODEL``.

    Returns:
        The model's assessment text (``""`` when the OpenAI-compatible model
        returned no content).
    """
    use_model = vision_model or settings.AI_VISION_MODEL

    log.info(f"[Verify] 模型: {use_model}, 成衣: {len(garment_b64)//1024}KB, 画布: {len(canvas_b64)//1024}KB")

    persona = "你是服装套图质量检验专家。"
    verify_prompt = (
        "请对比这两张图片:第一张是原始成衣照片,第二张是套图结果。\n"
        "验证:1. 花样还原度 2. 裁片覆盖完整度 3. 对齐质量 4. 整体效果。给出评分(1-10)和改进建议。"
    )
    if extra_prompt:
        verify_prompt += f"\n用户补充:{extra_prompt}"

    if is_gemini_model(use_model):
        # The persona goes in as a system message, matching the OpenAI branch
        # below and the other Gemini callers in this module. call_gemini
        # attaches the images to the last user message, i.e. the prompt.
        return call_gemini(
            [{"role": "system", "content": persona}, {"role": "user", "content": verify_prompt}],
            use_model, images_b64=[garment_b64, canvas_b64]
        )

    from openai import OpenAI

    client = OpenAI(api_key=settings.AI_API_KEY, base_url=settings.AI_BASE_URL or "https://api.openai.com/v1")
    completion = client.chat.completions.create(
        model=use_model,
        messages=[
            {"role": "system", "content": persona},
            {"role": "user", "content": [
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{garment_b64}"}},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{canvas_b64}"}},
                {"type": "text", "text": verify_prompt},
            ]},
        ],
    )
    return completion.choices[0].message.content or ""
|
||
|
||
|
||
# ==================== Mock ====================
|
||
|
||
def mock_reply(message: str, has_image: bool = False) -> str:
    """Return a canned reply for use when no API key is configured.

    An image always gets the "analysis needs AI_API_KEY" notice. Otherwise a
    message mentioning "套图" gets a pointer to the preset page, and anything
    else gets the generic greeting.
    """
    if has_image:
        return "【模拟分析】收到图片。AI 分析功能需要配置 AI_API_KEY。"

    mentions_pattern_fitting = "套图" in message
    return (
        "套图功能在「参数预设」页面。先选择花样组和裁片组,再添加规则,最后点击生成。"
        if mentions_pattern_fitting
        else "你好!我是 DesignerCEP AI 助手。你可以问我关于套图、裁片、对齐等功能的问题。"
    )
|
||
|
||
|
||
# ==================== 工具函数 ====================
|
||
|
||
def _save_debug_image(data: bytes, filename: str):
    """Write ``data`` to the ``debug_images`` directory on a best-effort basis.

    The directory is resolved three levels above this module's directory and
    created on demand. The file is named ``<unix-seconds>_<filename>``. Any
    failure (including failure to create the directory) is logged as a
    warning and never raised, so debug output cannot break the calling flow.
    """
    import os
    import time

    debug_dir = os.path.join(os.path.dirname(__file__), '..', '..', '..', 'debug_images')
    ts = int(time.time())
    path = os.path.join(debug_dir, f'{ts}_{filename}')
    try:
        # Directory creation is inside the try: a read-only or missing parent
        # must be reported the same way as a failed write.
        os.makedirs(debug_dir, exist_ok=True)
        with open(path, 'wb') as f:
            f.write(data)
    except Exception as e:
        log.warning(f"[Debug] 保存失败: {e}")
    else:
        log.info(f"[Debug] 图片已保存: {path}")
|