feat: expand AI workflow support and refresh docs
This commit is contained in:
@@ -1,17 +1,32 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
AI 模型调用层
|
||||
统一管理所有 LLM / Vision / Image 模型的调用逻辑
|
||||
支持 Qwen (DashScope) 和 Gemini (第三方代理) 两套路由
|
||||
统一管理所有 AI 模型的调用逻辑
|
||||
当前默认 provider 为 Ark,配置项保持通用命名,并支持技能化执行策略
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
import json, base64, re, logging
|
||||
from typing import List, Optional
|
||||
from dataclasses import dataclass
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from app.core.config import settings
|
||||
from app.api.v1.ai_tools import PS_TOOLS, TOOL_DISPLAY_NAMES
|
||||
from app.api.v1.ai_skills import AiSkill, get_ai_skill, resolve_ai_skill
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RuntimeAiConfig:
|
||||
provider: str
|
||||
api_key: str
|
||||
base_url: str
|
||||
chat_base_url: str = ""
|
||||
vision_base_url: str = ""
|
||||
image_base_url: str = ""
|
||||
source: str = "server"
|
||||
|
||||
# ==================== Prompts ====================
|
||||
|
||||
SYSTEM_PROMPT = """你是 DesignerCEP 的 AI 助手,运行在 Adobe Photoshop CEP 插件中。
|
||||
@@ -20,395 +35,634 @@ SYSTEM_PROMPT = """你是 DesignerCEP 的 AI 助手,运行在 Adobe Photoshop
|
||||
1. 回答关于 Photoshop 操作和插件使用的问题
|
||||
2. 通过工具直接操作 Photoshop(创建图层、对齐、查看文档信息等)
|
||||
3. 帮用户排查操作中遇到的错误
|
||||
4. **AI 智能套图**:两阶段流程 — 先生成预览确认,再提取套到裁片上
|
||||
4. AI 智能套图:两阶段流程,先生成预览确认,再提取套到裁片上
|
||||
|
||||
## AI 智能套图流程(两阶段)
|
||||
## AI 智能套图流程
|
||||
|
||||
当用户上传成衣图片并要求套图时:
|
||||
|
||||
### 阶段 1 — 识别裁片 + 生成预览(需用户确认)
|
||||
1. 调用 **identify_pieces** — 截取画布并识别每个图层是什么裁片部位(前片、后片、袖子等)
|
||||
2. 告诉用户识别结果(如 "M-1=前片, M-2=后片, M-5=左袖...")
|
||||
3. 调用 **generate_garment_preview** — 会自动在裁片下方标注名称标签(如 "M-1 前片"),然后截取带标签的画布 + 成衣照片一起发给 AI 生成预览
|
||||
4. 预览图显示在聊天中,每个裁片都有标签,**等用户确认 OK 后**进入阶段 2
|
||||
|
||||
### 阶段 2 — 提取花样 + 正式套图(用户确认后)
|
||||
5. 根据 identify_pieces 分析结果中每个裁片的 type 决定处理方式:
|
||||
- solid → 设 color 字段,PS 直接纯色填充
|
||||
- fill_pattern → AI 提取花型铺满
|
||||
- theme_pattern → 底层 PS 纯色填充(设 color)+ 上层 AI 提取主题图案(白底+正片叠底)
|
||||
- mixed_pattern → 底层 AI 提取花型 + 上层 AI 提取主题图案(白底+正片叠底)
|
||||
6. 调用 **extract_and_apply_all_pieces** 执行套图
|
||||
7. 可选:调用 **verify_pattern_result** 验证效果
|
||||
1. 先调用 identify_pieces 识别每个图层是什么裁片部位
|
||||
2. 告诉用户识别结果
|
||||
3. 调用 generate_garment_preview 生成带标签的裁片预览图
|
||||
4. 等用户确认 OK 后,再调用 extract_and_apply_all_pieces 正式套图
|
||||
|
||||
重要:
|
||||
- 阶段 1 完成后必须**等用户说"可以"/"OK"**才执行阶段 2
|
||||
- 用户可能会说"袖子纯色"、"后幅不要花样"等,要根据 identify_pieces 的结果对应到正确的图层名
|
||||
|
||||
重要规则:
|
||||
- 当用户要求执行 PS 操作时,使用工具完成
|
||||
- 执行操作前可以先了解当前文档和图层状态
|
||||
- 用简洁的中文回答,适合在小面板中阅读
|
||||
- 如果工具执行失败,向用户解释原因并建议解决方案
|
||||
- 预览生成后必须等用户确认,不能自己直接进入正式套图
|
||||
- 每个工具最多调用 1 次,除非上次执行失败
|
||||
- 用户要求执行 Photoshop 操作时,优先使用工具完成
|
||||
- 用户如果是在问“不会怎么做 / 为什么不行 / 下一步怎么操作”,要优先结合当前 Photoshop 上下文答疑,不要只给脱离现场的教程
|
||||
- 答疑、查询当前信息、执行操作可以在同一轮里完成,不要把它们人为拆成两种模式
|
||||
- 如果用户的问题和当前文档、图层、选区、参考图有关,优先先查 1 个最相关的上下文工具,再给结论
|
||||
- 如果用户意图已经明确且风险低,可以边解释边执行;如果涉及删除、覆盖、关闭、合并等风险操作,先确认
|
||||
- 不要只讲原理不动手,也不要只闷头执行不解释;默认输出“判断 + 当前检查/执行结果 + 下一步”
|
||||
- 用简洁中文回答,适合在小面板中阅读
|
||||
- 复杂任务先给 2-4 条简短思路,再开始执行
|
||||
- 默认按 观察 -> 判断 -> 执行 -> 验证 的顺序做事
|
||||
- 如果任务包含多步骤,不要一上来同时调很多工具,先完成当前阶段再进入下一阶段
|
||||
- 回答里尽量显式说明“当前阶段”和“下一步”
|
||||
"""
|
||||
|
||||
VISION_PROMPT = """你是一位资深的服装设计分析师,同时也是 DesignerCEP 的 AI 助手。
|
||||
|
||||
当用户发送服装/成衣图片时,请从以下维度进行专业分析:
|
||||
|
||||
1. **服装类别** — 上衣/裤子/裙子/连衣裙/外套/配饰等,细分款式
|
||||
2. **面料分析** — 根据视觉特征推测面料类型(棉、涤纶、丝绸、针织、牛仔、雪纺等),分析面料质感
|
||||
3. **颜色与印花** — 主色调、配色方案、印花/图案类型及工艺(数码印花、丝网印刷、提花等)
|
||||
4. **版型特点** — 修身/宽松/A字/H型等,分析领口、袖型、肩线、腰线、下摆处理
|
||||
5. **工艺细节** — 缝线工艺、拉链/纽扣/暗扣、口袋设计、装饰细节、包边/锁边
|
||||
6. **设计评价** — 设计亮点、风格定位(休闲/正装/运动/时尚等)、目标消费群体
|
||||
7. **改进建议** — 如有可改进之处,给出专业建议
|
||||
请结合图片和用户问题,从服装类别、面料、颜色印花、版型特点、工艺细节、设计评价、改进建议等维度做专业分析。
|
||||
|
||||
规则:
|
||||
- 如果图片不是服装相关,也请尽力分析图片内容并给出有价值的反馈
|
||||
- 如果用户同时提了文字问题,请结合图片和问题一起回答
|
||||
- 用清晰、结构化的中文回答,适合在设计工作中参考
|
||||
- 回答要专业但不啰嗦,突出重点信息
|
||||
- 如果图片不是服装相关,也尽力分析图片内容
|
||||
- 用清晰、结构化的中文回答
|
||||
- 回答要专业但不啰嗦,突出重点
|
||||
"""
|
||||
|
||||
TOOL_BY_NAME = {
|
||||
tool["function"]["name"]: tool
|
||||
for tool in PS_TOOLS
|
||||
if tool.get("type") == "function" and tool.get("function", {}).get("name")
|
||||
}
|
||||
|
||||
|
||||
# ==================== 客户端工厂 ====================
|
||||
|
||||
|
||||
def get_runtime_ai_config(user=None) -> RuntimeAiConfig:
|
||||
user_provider = (getattr(user, "ai_provider", "") or "").strip().lower()
|
||||
user_api_key = (getattr(user, "ai_api_key", "") or "").strip()
|
||||
user_base_url = (getattr(user, "ai_base_url", "") or "").strip()
|
||||
user_chat_base_url = (getattr(user, "ai_chat_base_url", "") or "").strip()
|
||||
user_vision_base_url = (getattr(user, "ai_vision_base_url", "") or "").strip()
|
||||
user_image_base_url = (getattr(user, "ai_image_base_url", "") or "").strip()
|
||||
|
||||
default_base_url = (
|
||||
settings.AI_BASE_URL
|
||||
or settings.ARK_BASE_URL
|
||||
or "https://ark.cn-beijing.volces.com/api/v3"
|
||||
)
|
||||
|
||||
if user_api_key:
|
||||
return RuntimeAiConfig(
|
||||
provider=user_provider or (settings.AI_PROVIDER or "ark").strip().lower(),
|
||||
api_key=user_api_key,
|
||||
base_url=user_base_url or default_base_url,
|
||||
chat_base_url=user_chat_base_url or settings.AI_CHAT_BASE_URL or "",
|
||||
vision_base_url=user_vision_base_url or settings.AI_VISION_BASE_URL or "",
|
||||
image_base_url=user_image_base_url or settings.AI_IMAGE_BASE_URL or "",
|
||||
source="user",
|
||||
)
|
||||
|
||||
return RuntimeAiConfig(
|
||||
provider=(settings.AI_PROVIDER or "ark").strip().lower(),
|
||||
api_key=(
|
||||
settings.AI_API_KEY
|
||||
or settings.ARK_API_KEY
|
||||
or os.getenv("AI_API_KEY", "")
|
||||
or os.getenv("ARK_API_KEY", "")
|
||||
),
|
||||
base_url=default_base_url,
|
||||
chat_base_url=settings.AI_CHAT_BASE_URL or "",
|
||||
vision_base_url=settings.AI_VISION_BASE_URL or "",
|
||||
image_base_url=settings.AI_IMAGE_BASE_URL or "",
|
||||
source="server",
|
||||
)
|
||||
|
||||
def get_ai_provider(runtime_config: Optional[RuntimeAiConfig] = None) -> str:
|
||||
return (runtime_config.provider if runtime_config else settings.AI_PROVIDER or "ark").strip().lower()
|
||||
|
||||
|
||||
def get_ai_api_key(runtime_config: Optional[RuntimeAiConfig] = None) -> str:
|
||||
return (runtime_config.api_key if runtime_config else get_runtime_ai_config().api_key).strip()
|
||||
|
||||
|
||||
def get_ai_base_url(
|
||||
runtime_config: Optional[RuntimeAiConfig] = None, capability: str = "chat"
|
||||
) -> str:
|
||||
config = runtime_config or get_runtime_ai_config()
|
||||
if capability == "vision":
|
||||
return (config.vision_base_url or config.base_url).strip()
|
||||
if capability == "image":
|
||||
return (config.image_base_url or config.base_url).strip()
|
||||
return (config.chat_base_url or config.base_url).strip()
|
||||
|
||||
|
||||
def has_ai_config(runtime_config: Optional[RuntimeAiConfig] = None) -> bool:
|
||||
return bool(get_ai_api_key(runtime_config))
|
||||
|
||||
|
||||
def get_ai_client(
|
||||
runtime_config: Optional[RuntimeAiConfig] = None, capability: str = "chat"
|
||||
):
|
||||
"""获取当前 provider 的客户端。"""
|
||||
from volcenginesdkarkruntime import Ark
|
||||
|
||||
provider = get_ai_provider(runtime_config)
|
||||
if provider != "ark":
|
||||
raise ValueError(f"当前暂不支持的 AI_PROVIDER: {provider}")
|
||||
|
||||
api_key = get_ai_api_key(runtime_config)
|
||||
if not api_key:
|
||||
raise ValueError("AI_API_KEY 未配置")
|
||||
|
||||
return Ark(
|
||||
base_url=get_ai_base_url(runtime_config, capability=capability),
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
|
||||
def is_gemini_model(model_name: str) -> bool:
|
||||
"""兼容旧代码,当前已统一走豆包。"""
|
||||
return False
|
||||
|
||||
|
||||
def _tool_schemas_for_skill(skill: AiSkill) -> List[dict]:
|
||||
if not skill.allowed_tools:
|
||||
return []
|
||||
return [TOOL_BY_NAME[name] for name in skill.allowed_tools if name in TOOL_BY_NAME]
|
||||
|
||||
|
||||
def _numbered_block(items: tuple[str, ...]) -> str:
|
||||
if not items:
|
||||
return "无"
|
||||
return "\n".join(f"{idx + 1}. {item}" for idx, item in enumerate(items))
|
||||
|
||||
|
||||
def _bullet_block(items: tuple[str, ...]) -> str:
|
||||
if not items:
|
||||
return "无"
|
||||
return "\n".join(f"- {item}" for item in items)
|
||||
|
||||
|
||||
def _build_skill_prompt(skill: AiSkill, has_image: bool = False) -> str:
|
||||
workflow = _numbered_block(skill.workflow)
|
||||
guardrails = _bullet_block(skill.guardrails)
|
||||
success_criteria = _bullet_block(skill.success_criteria)
|
||||
triage_questions = _numbered_block(skill.triage_questions)
|
||||
deliverables = _bullet_block(skill.deliverables)
|
||||
execution_notes = _bullet_block(skill.execution_notes)
|
||||
tool_names = (
|
||||
"、".join(TOOL_DISPLAY_NAMES.get(name, name) for name in skill.allowed_tools)
|
||||
if skill.allowed_tools
|
||||
else "当前技能以分析和建议为主,不主动调用 Photoshop 工具。"
|
||||
)
|
||||
origin = skill.origin or "项目内置技能"
|
||||
origin_url = skill.origin_url or "无"
|
||||
upstream_skill = skill.upstream_skill or "无"
|
||||
|
||||
return f"""当前启用技能:{skill.name}({skill.id})
|
||||
技能定位:{skill.description}
|
||||
执行模式:{skill.mode}
|
||||
规划风格:{skill.planning_style}
|
||||
图片提示:{skill.image_hint or '无'}
|
||||
来源:{origin}
|
||||
上游技能:{upstream_skill}
|
||||
来源地址:{origin_url}
|
||||
|
||||
推荐工作流:
|
||||
{workflow}
|
||||
|
||||
预判问题:
|
||||
{triage_questions}
|
||||
|
||||
期望交付:
|
||||
{deliverables}
|
||||
|
||||
执行边界:
|
||||
{guardrails}
|
||||
|
||||
执行备注:
|
||||
{execution_notes}
|
||||
|
||||
成功标准:
|
||||
{success_criteria}
|
||||
|
||||
本技能优先工具:
|
||||
{tool_names}
|
||||
|
||||
本轮补充要求:
|
||||
- 回复保持中文、简洁、能落地。
|
||||
- 涉及具体执行时,先根据当前技能给出简短思路,再调用工具。
|
||||
- 根据用户问题动态混合答疑、查询与执行;如果用户是在求助“怎么做/为什么失败/下一步怎么弄”,不要把答疑和执行拆成两轮。
|
||||
- 能读取当前 Photoshop 上下文时,优先读当前信息后再回答;不要只给泛化教程。
|
||||
- 优先先想清楚计划,再动手执行;如果任务复杂,先向用户同步 2 到 4 条阶段计划。
|
||||
- 每轮优先调用最相关的 1 个工具,避免一口气乱调很多工具。
|
||||
- 如果图片信息不足、文档上下文不足或风险较高,要先说明再操作。
|
||||
- 每完成一个阶段,要用一句中文总结已完成内容,并指向下一步。
|
||||
- {'本轮带有图片,请结合图片内容和技能策略做判断。' if has_image else '本轮没有新图片,优先结合聊天历史和 Photoshop 上下文。'}
|
||||
"""
|
||||
|
||||
|
||||
# ==================== 路由判断 ====================
|
||||
|
||||
def is_gemini_model(model_name: str) -> bool:
|
||||
"""判断是否是 Gemini 模型"""
|
||||
return bool(model_name) and "gemini" in model_name.lower()
|
||||
def _chat_system_prompt(skill: AiSkill, has_image: bool = False) -> str:
|
||||
return f"{SYSTEM_PROMPT}\n\n{_build_skill_prompt(skill, has_image)}"
|
||||
|
||||
|
||||
# ==================== Qwen / OpenAI 兼容调用 ====================
|
||||
def _vision_instructions(skill: AiSkill) -> str:
|
||||
return f"{VISION_PROMPT}\n\n{_build_skill_prompt(skill, has_image=True)}"
|
||||
|
||||
def call_llm_with_tools(messages_history: List[dict], model_override: str = None):
|
||||
"""调用 LLM,支持 function calling"""
|
||||
from openai import OpenAI
|
||||
|
||||
def _messages_with_system(
|
||||
messages_history: List[dict], skill: AiSkill, has_image: bool = False
|
||||
) -> List[dict]:
|
||||
return [{"role": "system", "content": _chat_system_prompt(skill, has_image)}] + messages_history
|
||||
|
||||
|
||||
def _extract_response_text(response) -> str:
|
||||
texts: List[str] = []
|
||||
for output_item in getattr(response, "output", []) or []:
|
||||
for content_item in getattr(output_item, "content", []) or []:
|
||||
if getattr(content_item, "type", "") == "output_text":
|
||||
text_value = getattr(content_item, "text", "")
|
||||
if text_value:
|
||||
texts.append(text_value)
|
||||
return "\n".join(texts).strip()
|
||||
|
||||
|
||||
def _image_data_url(image_base64: str) -> str:
|
||||
return f"data:image/jpeg;base64,{image_base64}"
|
||||
|
||||
|
||||
# ==================== 聊天 / 工具调用 ====================
|
||||
|
||||
|
||||
def call_llm_with_tools(
|
||||
messages_history: List[dict],
|
||||
model_override: str = None,
|
||||
skill_id: Optional[str] = None,
|
||||
has_image: bool = False,
|
||||
runtime_config: Optional[RuntimeAiConfig] = None,
|
||||
):
|
||||
"""调用当前配置的对话模型,支持 function calling。"""
|
||||
use_model = model_override or settings.AI_MODEL
|
||||
|
||||
client = OpenAI(
|
||||
api_key=settings.AI_API_KEY,
|
||||
base_url=settings.AI_BASE_URL or "https://api.openai.com/v1",
|
||||
client = get_ai_client(runtime_config, capability="chat")
|
||||
skill = get_ai_skill(skill_id) or resolve_ai_skill(
|
||||
message=str(messages_history[-1].get("content", "")) if messages_history else "",
|
||||
history_messages=messages_history[:-1],
|
||||
requested_skill_id=skill_id,
|
||||
has_image=has_image,
|
||||
)
|
||||
messages = _messages_with_system(messages_history, skill, has_image)
|
||||
tools = _tool_schemas_for_skill(skill)
|
||||
|
||||
messages = [{"role": "system", "content": SYSTEM_PROMPT}] + messages_history
|
||||
log.info(f"{'=' * 60}")
|
||||
log.info(f"[LLM] provider={get_ai_provider(runtime_config)} model={use_model} source={(runtime_config.source if runtime_config else 'server')}")
|
||||
log.info(f"[LLM] skill={skill.id}")
|
||||
log.info(f"[LLM] 消息数量: {len(messages)}")
|
||||
log.info(f"[LLM] 工具数量: {len(tools)}")
|
||||
|
||||
log.info(f"{'='*60}")
|
||||
log.info(f"[LLM] 调用工具模型: {use_model}{' (override)' if model_override else ''}")
|
||||
log.info(f"[LLM] 消息数量: {len(messages)} (system + {len(messages_history)} history)")
|
||||
for i, m in enumerate(messages_history[-5:]):
|
||||
role = m['role']
|
||||
content = m['content'][:120] if m.get('content') else '(empty)'
|
||||
log.info(f"[LLM] history[-{len(messages_history)-i}] {role}: {content}")
|
||||
log.info(f"[LLM] 工具数量: {len(PS_TOOLS)}")
|
||||
request_kwargs = {
|
||||
"model": use_model,
|
||||
"messages": messages,
|
||||
"temperature": 0.3,
|
||||
}
|
||||
if tools:
|
||||
request_kwargs.update(
|
||||
{
|
||||
"tools": tools,
|
||||
"tool_choice": "auto",
|
||||
"parallel_tool_calls": False,
|
||||
}
|
||||
)
|
||||
|
||||
completion = client.chat.completions.create(
|
||||
model=use_model,
|
||||
messages=messages,
|
||||
tools=PS_TOOLS,
|
||||
tool_choice="auto",
|
||||
)
|
||||
completion = client.chat.completions.create(**request_kwargs)
|
||||
|
||||
choice = completion.choices[0]
|
||||
message = choice.message
|
||||
tool_calls = getattr(message, "tool_calls", None) or []
|
||||
|
||||
log.info(f"[LLM] 响应 finish_reason={choice.finish_reason}")
|
||||
if message.content:
|
||||
log.info(f"[LLM] 回复文本: {message.content[:150]}...")
|
||||
if message.tool_calls:
|
||||
for tc in message.tool_calls:
|
||||
log.info(f"[LLM] 工具调用: {tc.function.name}({tc.function.arguments[:200]})")
|
||||
else:
|
||||
log.info(f"[LLM] 无工具调用")
|
||||
log.info(f"{'='*60}")
|
||||
|
||||
if message.tool_calls and len(message.tool_calls) > 0:
|
||||
if tool_calls:
|
||||
tool_calls_data = []
|
||||
for tc in message.tool_calls:
|
||||
for tc in tool_calls:
|
||||
args = {}
|
||||
if tc.function.arguments:
|
||||
function = getattr(tc, "function", None)
|
||||
raw_args = getattr(function, "arguments", "") if function else ""
|
||||
if raw_args:
|
||||
try:
|
||||
args = json.loads(tc.function.arguments)
|
||||
args = json.loads(raw_args)
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
tool_calls_data.append({
|
||||
"id": tc.id,
|
||||
"name": tc.function.name,
|
||||
"display_name": TOOL_DISPLAY_NAMES.get(tc.function.name, tc.function.name),
|
||||
"args": args,
|
||||
"status": "pending"
|
||||
})
|
||||
return message.content or "", tool_calls_data
|
||||
function_name = getattr(function, "name", "") if function else ""
|
||||
tool_calls_data.append(
|
||||
{
|
||||
"id": getattr(tc, "id", function_name),
|
||||
"name": function_name,
|
||||
"display_name": TOOL_DISPLAY_NAMES.get(function_name, function_name),
|
||||
"args": args,
|
||||
"status": "pending",
|
||||
}
|
||||
)
|
||||
return getattr(message, "content", "") or "", tool_calls_data, skill.to_public_dict()
|
||||
|
||||
return message.content or "", None
|
||||
return getattr(message, "content", "") or "", None, skill.to_public_dict()
|
||||
|
||||
|
||||
# ==================== Gemini 调用(OpenAI 兼容代理) ====================
|
||||
|
||||
def call_gemini(messages_history: List[dict], model: str, images_b64: List[str] = None) -> str:
|
||||
"""调用 Gemini(通过第三方代理,OpenAI 兼容格式)"""
|
||||
from openai import OpenAI as _OpenAI
|
||||
|
||||
if not settings.GEMINI_API_KEY or not settings.GEMINI_BASE_URL:
|
||||
raise ValueError("GEMINI_API_KEY 或 GEMINI_BASE_URL 未配置")
|
||||
|
||||
client = _OpenAI(
|
||||
api_key=settings.GEMINI_API_KEY,
|
||||
base_url=f"{settings.GEMINI_BASE_URL}/v1",
|
||||
)
|
||||
|
||||
messages = []
|
||||
def call_gemini(messages_history: List[dict], model: str, images_b64: List[str] = None):
|
||||
"""兼容旧接口,内部统一转到当前视觉模型。"""
|
||||
prompt = ""
|
||||
for msg in messages_history:
|
||||
role = msg.get("role", "user")
|
||||
content = msg.get("content", "")
|
||||
if role not in ("system", "user", "assistant"):
|
||||
role = "user"
|
||||
messages.append({"role": role, "content": content})
|
||||
|
||||
if msg.get("role") == "user" and msg.get("content"):
|
||||
prompt = str(msg["content"])
|
||||
if images_b64:
|
||||
last_user_idx = None
|
||||
for i in range(len(messages) - 1, -1, -1):
|
||||
if messages[i]["role"] == "user":
|
||||
last_user_idx = i
|
||||
break
|
||||
if last_user_idx is not None:
|
||||
text_content = messages[last_user_idx]["content"]
|
||||
multimodal_content = []
|
||||
for img_b64 in images_b64:
|
||||
multimodal_content.append({
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/jpeg;base64,{img_b64}"}
|
||||
})
|
||||
multimodal_content.append({"type": "text", "text": text_content})
|
||||
messages[last_user_idx]["content"] = multimodal_content
|
||||
return call_vision_messages(prompt or "请分析这张图片。", images_b64, model)
|
||||
|
||||
log.info(f"{'='*60}")
|
||||
log.info(f"[Gemini] 调用模型: {model} (OpenAI 兼容)")
|
||||
log.info(f"[Gemini] 消息数: {len(messages)}, 图片数: {len(images_b64) if images_b64 else 0}")
|
||||
|
||||
completion = client.chat.completions.create(model=model, messages=messages)
|
||||
result = completion.choices[0].message.content or ""
|
||||
log.info(f"[Gemini] 回复: {result[:200]}...")
|
||||
return result
|
||||
reply, _, _ = call_llm_with_tools(messages_history, model_override=model)
|
||||
return reply
|
||||
|
||||
|
||||
def call_gemini_with_tools(messages_history: List[dict], model: str) -> tuple:
|
||||
"""用 Gemini 做对话(不支持 function calling)"""
|
||||
full_history = [{"role": "system", "content": SYSTEM_PROMPT}] + messages_history
|
||||
|
||||
log.info(f"{'='*60}")
|
||||
log.info(f"[Gemini] 调用对话模型: {model}")
|
||||
for i, m in enumerate(messages_history[-3:]):
|
||||
log.info(f"[Gemini] history[-{len(messages_history)-i}] {m['role']}: {str(m.get('content',''))[:120]}")
|
||||
|
||||
result = call_gemini(full_history, model)
|
||||
log.info(f"[Gemini] 回复文本: {result[:150]}...")
|
||||
return result, None
|
||||
"""兼容旧接口,内部统一转到当前工具调用。"""
|
||||
return call_llm_with_tools(messages_history, model_override=model)
|
||||
|
||||
|
||||
# ==================== 视觉模型 ====================
|
||||
|
||||
def call_vision_llm(user_message: str, image_base64: str, history: List[dict], model_override: str = None) -> str:
|
||||
"""调用视觉模型分析图片(自动路由 Qwen / Gemini)"""
|
||||
|
||||
def call_vision_messages(
|
||||
user_message: str,
|
||||
images_b64: List[str],
|
||||
model_override: str = None,
|
||||
instructions_override: Optional[str] = None,
|
||||
runtime_config: Optional[RuntimeAiConfig] = None,
|
||||
) -> str:
|
||||
"""调用当前视觉模型分析图片。"""
|
||||
use_model = model_override or settings.AI_VISION_MODEL
|
||||
client = get_ai_client(runtime_config, capability="vision")
|
||||
|
||||
if is_gemini_model(use_model):
|
||||
log.info(f"[Vision] 使用 Gemini 视觉模型: {use_model}")
|
||||
msgs = [{"role": "system", "content": VISION_PROMPT}]
|
||||
for h in history[-10:]:
|
||||
role = h["role"] if h["role"] != "tool" else "user"
|
||||
msgs.append({"role": role, "content": h["content"]})
|
||||
msgs.append({"role": "user", "content": user_message})
|
||||
return call_gemini(msgs, use_model, images_b64=[image_base64])
|
||||
content = []
|
||||
for image_b64 in images_b64:
|
||||
content.append({"type": "input_image", "image_url": _image_data_url(image_b64)})
|
||||
content.append({"type": "input_text", "text": user_message})
|
||||
|
||||
from openai import OpenAI
|
||||
client = OpenAI(api_key=settings.AI_API_KEY, base_url=settings.AI_BASE_URL or "https://api.openai.com/v1")
|
||||
response = client.responses.create(
|
||||
model=use_model,
|
||||
instructions=instructions_override or VISION_PROMPT,
|
||||
input=[{"role": "user", "content": content}],
|
||||
)
|
||||
|
||||
user_content = [
|
||||
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}},
|
||||
{"type": "text", "text": user_message}
|
||||
]
|
||||
messages = [{"role": "system", "content": VISION_PROMPT}]
|
||||
for h in history[-10:]:
|
||||
role = h["role"] if h["role"] != "tool" else "user"
|
||||
messages.append({"role": role, "content": h["content"]})
|
||||
messages.append({"role": "user", "content": user_content})
|
||||
|
||||
completion = client.chat.completions.create(model=use_model, messages=messages)
|
||||
return completion.choices[0].message.content or ""
|
||||
result = _extract_response_text(response)
|
||||
log.info(
|
||||
f"[Vision] provider={get_ai_provider(runtime_config)} model={use_model} 输出长度: {len(result)}"
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
# ==================== 图片编辑/生成模型 ====================
|
||||
def call_vision_llm(
|
||||
user_message: str,
|
||||
image_base64: str,
|
||||
history: List[dict],
|
||||
model_override: str = None,
|
||||
skill_id: Optional[str] = None,
|
||||
runtime_config: Optional[RuntimeAiConfig] = None,
|
||||
) -> str:
|
||||
"""兼容旧接口,调用当前视觉模型。"""
|
||||
skill = get_ai_skill(skill_id) or resolve_ai_skill(
|
||||
message=user_message,
|
||||
history_messages=history,
|
||||
requested_skill_id=skill_id,
|
||||
has_image=True,
|
||||
)
|
||||
combined_message = user_message
|
||||
if history:
|
||||
recent_text = "\n".join(
|
||||
str(item.get("content", "")) for item in history[-5:] if item.get("content")
|
||||
)
|
||||
if recent_text:
|
||||
combined_message = f"历史对话:\n{recent_text}\n\n当前问题:{user_message}"
|
||||
return call_vision_messages(
|
||||
combined_message,
|
||||
[image_base64],
|
||||
model_override,
|
||||
instructions_override=_vision_instructions(skill),
|
||||
runtime_config=runtime_config,
|
||||
)
|
||||
|
||||
def call_image_model(images_b64: List[str], prompt: str, model_override: str = None) -> tuple:
|
||||
"""调用图片编辑/生成模型(自动路由 DashScope / Gemini),返回 (url_or_datauri, description)"""
|
||||
import requests as http_requests
|
||||
|
||||
# ==================== 图片编辑 / 生成 ====================
|
||||
|
||||
|
||||
def call_image_model(
|
||||
images_b64: List[str],
|
||||
prompt: str,
|
||||
model_override: str = None,
|
||||
runtime_config: Optional[RuntimeAiConfig] = None,
|
||||
) -> tuple:
|
||||
"""调用当前图片模型,返回 (url_or_datauri, description)。"""
|
||||
result = call_image_model_batch(
|
||||
images_b64=images_b64,
|
||||
prompt=prompt,
|
||||
model_override=model_override,
|
||||
size="2K",
|
||||
max_images=1,
|
||||
stream=False,
|
||||
runtime_config=runtime_config,
|
||||
)
|
||||
first_image = result["images"][0]
|
||||
return str(first_image["url"]), ""
|
||||
|
||||
|
||||
def _image_response_error_message(payload) -> str:
|
||||
error = getattr(payload, "error", None)
|
||||
return str(getattr(error, "message", "") or "").strip()
|
||||
|
||||
|
||||
def _image_usage_dict(payload) -> Optional[dict]:
|
||||
usage = getattr(payload, "usage", None)
|
||||
if usage is None:
|
||||
return None
|
||||
if hasattr(usage, "model_dump"):
|
||||
return usage.model_dump(mode="json")
|
||||
if isinstance(usage, dict):
|
||||
return usage
|
||||
return None
|
||||
|
||||
|
||||
def _image_item_to_result(image_item, index: int) -> dict:
|
||||
image_url = getattr(image_item, "url", "") or ""
|
||||
image_b64 = getattr(image_item, "b64_json", "") or ""
|
||||
image_size = getattr(image_item, "size", "") or ""
|
||||
|
||||
if image_url:
|
||||
return {"index": index, "url": image_url, "size": image_size}
|
||||
|
||||
if image_b64:
|
||||
padding = (-len(image_b64)) % 4
|
||||
if padding:
|
||||
image_b64 += "=" * padding
|
||||
return {
|
||||
"index": index,
|
||||
"url": f"data:image/png;base64,{image_b64}",
|
||||
"size": image_size,
|
||||
}
|
||||
|
||||
raise ValueError("图片模型返回了空图片结果")
|
||||
|
||||
|
||||
def call_image_model_batch(
|
||||
images_b64: List[str],
|
||||
prompt: str,
|
||||
model_override: str = None,
|
||||
size: str = "2K",
|
||||
max_images: int = 1,
|
||||
stream: Optional[bool] = None,
|
||||
runtime_config: Optional[RuntimeAiConfig] = None,
|
||||
) -> dict:
|
||||
"""调用当前图片模型,支持单图、参考图生图和连续组图。"""
|
||||
from volcenginesdkarkruntime.types.images.images import (
|
||||
SequentialImageGenerationOptions,
|
||||
)
|
||||
|
||||
use_model = model_override or settings.AI_IMAGE_EDIT_MODEL
|
||||
client = get_ai_client(runtime_config, capability="image")
|
||||
image_count = max(1, min(int(max_images or 1), 4))
|
||||
should_stream = (image_count > 1) if stream is None else bool(stream)
|
||||
|
||||
# ---------- Gemini(OpenAI 兼容代理) ----------
|
||||
if is_gemini_model(use_model):
|
||||
log.info(f"[ImageModel] 使用 Gemini 图片模型: {use_model}")
|
||||
if not settings.GEMINI_API_KEY or not settings.GEMINI_BASE_URL:
|
||||
raise ValueError("GEMINI_API_KEY 或 GEMINI_BASE_URL 未配置")
|
||||
|
||||
from openai import OpenAI as _OpenAI
|
||||
client = _OpenAI(api_key=settings.GEMINI_API_KEY, base_url=f"{settings.GEMINI_BASE_URL}/v1")
|
||||
|
||||
content_parts = [{"type": "text", "text": prompt}]
|
||||
for img_b64 in images_b64:
|
||||
content_parts.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img_b64}"}})
|
||||
|
||||
log.info(f"[ImageModel] Gemini OpenAI 兼容, 模型: {use_model}")
|
||||
log.info(f"[ImageModel] 图片数: {len(images_b64)}, 各图: {[len(b)//1024 for b in images_b64]}KB")
|
||||
log.info(f"[ImageModel] 提示词: {prompt[:200]}")
|
||||
|
||||
completion = client.chat.completions.create(model=use_model, messages=[{"role": "user", "content": content_parts}])
|
||||
result_content = completion.choices[0].message.content or ""
|
||||
log.info(f"[ImageModel] Gemini 回复长度: {len(result_content)} chars")
|
||||
|
||||
# 提取 base64 图片
|
||||
match = re.search(r'!\[.*?\]\((data:image/(\w+);base64,([^)]+))\)', result_content)
|
||||
if not match:
|
||||
match = re.search(r'(data:image/(\w+);base64,([A-Za-z0-9+/=]+))', result_content)
|
||||
if not match:
|
||||
log.warning(f"[ImageModel] Gemini 响应中无图片,前500字: {result_content[:500]}")
|
||||
raise ValueError("Gemini 未返回图片,请检查模型是否支持图片生成")
|
||||
|
||||
img_format = match.group(2)
|
||||
image_b64 = match.group(3)
|
||||
padding = 4 - len(image_b64) % 4
|
||||
if padding != 4:
|
||||
image_b64 += '=' * padding
|
||||
|
||||
log.info(f"[ImageModel] Gemini 返回图片: image/{img_format}, {len(image_b64)//1024}KB")
|
||||
_save_debug_image(base64.b64decode(image_b64), f'gemini_output.{img_format}')
|
||||
|
||||
description = re.sub(r'!\[.*?\]\(data:image/[^)]+\)', '', result_content).strip()
|
||||
return f"data:image/{img_format};base64,{image_b64}", description
|
||||
|
||||
# ---------- DashScope 原生接口 ----------
|
||||
api_url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/multimodal-generation/generation"
|
||||
content_parts = []
|
||||
for img_b64 in images_b64:
|
||||
content_parts.append({"image": f"data:image/jpeg;base64,{img_b64}"})
|
||||
content_parts.append({"text": prompt})
|
||||
|
||||
payload = {
|
||||
"model": use_model,
|
||||
"input": {"messages": [{"role": "user", "content": content_parts}]},
|
||||
"parameters": {"n": 1, "watermark": False, "prompt_extend": True}
|
||||
}
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {settings.AI_API_KEY}"}
|
||||
|
||||
log.info(f"{'='*60}")
|
||||
log.info(f"[ImageModel] DashScope 原生 API, 模型: {use_model}")
|
||||
log.info(f"[ImageModel] 图片数: {len(images_b64)}, 各图: {[len(b)//1024 for b in images_b64]}KB")
|
||||
image_inputs = [_image_data_url(img_b64) for img_b64 in images_b64] or None
|
||||
image_payload = None
|
||||
if image_inputs:
|
||||
image_payload = image_inputs[0] if len(image_inputs) == 1 else image_inputs
|
||||
log.info(f"{'=' * 60}")
|
||||
log.info(f"[ImageModel] provider={get_ai_provider(runtime_config)} model={use_model} source={(runtime_config.source if runtime_config else 'server')}")
|
||||
log.info(f"[ImageModel] 图片数: {len(images_b64)}")
|
||||
log.info(f"[ImageModel] 目标生成张数: {image_count}")
|
||||
log.info(f"[ImageModel] 提示词: {prompt[:200]}")
|
||||
|
||||
for idx, img_b64 in enumerate(images_b64):
|
||||
_save_debug_image(base64.b64decode(img_b64), f'input_{idx}.jpg')
|
||||
request_kwargs = {
|
||||
"model": use_model,
|
||||
"prompt": prompt,
|
||||
"image": image_payload,
|
||||
"response_format": "url",
|
||||
"size": size or "2K",
|
||||
"watermark": True,
|
||||
"stream": should_stream,
|
||||
"sequential_image_generation": "auto" if image_count > 1 else "disabled",
|
||||
}
|
||||
if image_payload is None:
|
||||
request_kwargs.pop("image")
|
||||
if image_count > 1:
|
||||
request_kwargs["sequential_image_generation_options"] = (
|
||||
SequentialImageGenerationOptions(max_images=image_count)
|
||||
)
|
||||
|
||||
resp = http_requests.post(api_url, json=payload, headers=headers, timeout=120)
|
||||
if resp.status_code != 200:
|
||||
error_data = resp.json() if resp.headers.get("content-type", "").startswith("application/json") else {}
|
||||
err_msg = error_data.get("message", resp.text[:300])
|
||||
log.error(f"[ImageModel] API 错误: {resp.status_code} {err_msg}")
|
||||
if "data_inspection_failed" in str(error_data):
|
||||
raise ValueError("图片内容未通过安全审核,请更换图片")
|
||||
raise ValueError(f"图片模型调用失败({resp.status_code}): {err_msg}")
|
||||
if should_stream:
|
||||
stream_response = client.images.generate(**request_kwargs)
|
||||
images: List[dict] = []
|
||||
usage: Optional[dict] = None
|
||||
errors: List[str] = []
|
||||
|
||||
data = resp.json()
|
||||
output = data.get("output", {})
|
||||
choices = output.get("choices", [])
|
||||
if not choices:
|
||||
raise ValueError("模型未返回结果")
|
||||
for event in stream_response:
|
||||
if event is None:
|
||||
continue
|
||||
|
||||
content_list = choices[0].get("message", {}).get("content", [])
|
||||
image_url = None
|
||||
description = ""
|
||||
for item in content_list:
|
||||
if isinstance(item, dict):
|
||||
if "image" in item:
|
||||
image_url = item["image"]
|
||||
elif "text" in item:
|
||||
description += item["text"]
|
||||
event_type = str(getattr(event, "type", "") or "")
|
||||
event_error = _image_response_error_message(event)
|
||||
|
||||
if not image_url:
|
||||
raise ValueError("模型未返回图片")
|
||||
if event_type.endswith("partial_failed"):
|
||||
if event_error:
|
||||
errors.append(event_error)
|
||||
continue
|
||||
|
||||
log.info(f"[ImageModel] 输出图片 URL: {image_url[:120]}...")
|
||||
try:
|
||||
out_resp = http_requests.get(image_url, timeout=60)
|
||||
if out_resp.status_code == 200:
|
||||
_save_debug_image(out_resp.content, 'output.png')
|
||||
except Exception:
|
||||
pass
|
||||
if event_type.endswith("partial_succeeded"):
|
||||
try:
|
||||
images.append(
|
||||
_image_item_to_result(
|
||||
event,
|
||||
index=int(getattr(event, "image_index", len(images))),
|
||||
)
|
||||
)
|
||||
except Exception as exc:
|
||||
errors.append(str(exc))
|
||||
continue
|
||||
|
||||
return image_url, description
|
||||
if event_type.endswith("completed"):
|
||||
usage = _image_usage_dict(event)
|
||||
|
||||
if errors and not images:
|
||||
raise ValueError(";".join(errors))
|
||||
if not images:
|
||||
raise ValueError("图片模型未返回结果")
|
||||
|
||||
images.sort(key=lambda item: int(item.get("index", 0)))
|
||||
return {"images": images, "usage": usage, "errors": errors}
|
||||
|
||||
response_kwargs = {
|
||||
"model": use_model,
|
||||
"prompt": prompt,
|
||||
"image": image_payload,
|
||||
"response_format": "url",
|
||||
"size": size or "2K",
|
||||
"stream": False,
|
||||
"watermark": True,
|
||||
"sequential_image_generation": "auto" if image_count > 1 else "disabled",
|
||||
"sequential_image_generation_options": (
|
||||
SequentialImageGenerationOptions(max_images=image_count)
|
||||
if image_count > 1
|
||||
else None
|
||||
),
|
||||
}
|
||||
if image_payload is None:
|
||||
response_kwargs.pop("image")
|
||||
|
||||
response = client.images.generate(**response_kwargs)
|
||||
|
||||
error_message = _image_response_error_message(response)
|
||||
if error_message:
|
||||
raise ValueError(error_message)
|
||||
|
||||
data = getattr(response, "data", None) or []
|
||||
if not data:
|
||||
raise ValueError("图片模型未返回结果")
|
||||
|
||||
images = [_image_item_to_result(item, index=idx) for idx, item in enumerate(data)]
|
||||
return {"images": images, "usage": _image_usage_dict(response), "errors": []}
|
||||
|
||||
|
||||
# ==================== 验证套图效果 ====================
|
||||
|
||||
def verify_pattern_result(garment_b64: str, canvas_b64: str, extra_prompt: str = None, vision_model: str = None) -> str:
|
||||
"""用视觉模型对比原始成衣和套图结果"""
|
||||
use_model = vision_model or settings.AI_VISION_MODEL
|
||||
|
||||
log.info(f"[Verify] 模型: {use_model}, 成衣: {len(garment_b64)//1024}KB, 画布: {len(canvas_b64)//1024}KB")
|
||||
|
||||
def verify_pattern_result(
|
||||
garment_b64: str,
|
||||
canvas_b64: str,
|
||||
extra_prompt: str = None,
|
||||
vision_model: str = None,
|
||||
runtime_config: Optional[RuntimeAiConfig] = None,
|
||||
) -> str:
|
||||
"""用当前视觉模型对比原始成衣和套图结果。"""
|
||||
verify_prompt = (
|
||||
"请对比这两张图片:第一张是原始成衣照片,第二张是套图结果。\n"
|
||||
"验证:1. 花样还原度 2. 裁片覆盖完整度 3. 对齐质量 4. 整体效果。给出评分(1-10)和改进建议。"
|
||||
"请对比这两张图片:第一张是原始成衣照片,第二张是套图结果。"
|
||||
"请从花样还原度、裁片覆盖完整度、对齐质量、整体效果四个方面评分,并给出改进建议。"
|
||||
)
|
||||
if extra_prompt:
|
||||
verify_prompt += f"\n用户补充:{extra_prompt}"
|
||||
|
||||
if is_gemini_model(use_model):
|
||||
return call_gemini(
|
||||
[{"role": "user", "content": "你是服装套图质量检验专家。"}, {"role": "user", "content": verify_prompt}],
|
||||
use_model, images_b64=[garment_b64, canvas_b64]
|
||||
)
|
||||
|
||||
from openai import OpenAI
|
||||
client = OpenAI(api_key=settings.AI_API_KEY, base_url=settings.AI_BASE_URL or "https://api.openai.com/v1")
|
||||
completion = client.chat.completions.create(
|
||||
model=use_model,
|
||||
messages=[
|
||||
{"role": "system", "content": "你是服装套图质量检验专家。"},
|
||||
{"role": "user", "content": [
|
||||
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{garment_b64}"}},
|
||||
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{canvas_b64}"}},
|
||||
{"type": "text", "text": verify_prompt},
|
||||
]},
|
||||
],
|
||||
return call_vision_messages(
|
||||
verify_prompt,
|
||||
[garment_b64, canvas_b64],
|
||||
model_override=vision_model or settings.AI_VISION_MODEL,
|
||||
runtime_config=runtime_config,
|
||||
)
|
||||
return completion.choices[0].message.content or ""
|
||||
|
||||
|
||||
# ==================== Mock ====================
|
||||
|
||||
|
||||
def mock_reply(message: str, has_image: bool = False) -> str:
|
||||
"""未配置 API Key 时的模拟回复"""
|
||||
"""未配置 API Key 时的模拟回复。"""
|
||||
if has_image:
|
||||
return "【模拟分析】收到图片。AI 分析功能需要配置 AI_API_KEY。"
|
||||
return "【模拟分析】收到图片。请先在 API 配置页填写你自己的转发 Key。"
|
||||
if "套图" in message:
|
||||
return "套图功能在「参数预设」页面。先选择花样组和裁片组,再添加规则,最后点击生成。"
|
||||
return "你好!我是 DesignerCEP AI 助手。你可以问我关于套图、裁片、对齐等功能的问题。"
|
||||
return "套图功能已就绪,但正式 AI 分析前请先在 API 配置页填写你自己的转发 Key。"
|
||||
return "你好!我是 DesignerCEP AI 助手。请先在 API 配置页填写你的转发地址和 Key,然后我再帮你执行聊天、看图和套图任务。"
|
||||
|
||||
|
||||
# ==================== 工具函数 ====================
|
||||
# ==================== 调试工具 ====================
|
||||
|
||||
|
||||
def _save_debug_image(data: bytes, filename: str):
|
||||
"""保存调试图片到 debug_images 目录"""
|
||||
import os, time
|
||||
debug_dir = os.path.join(os.path.dirname(__file__), '..', '..', '..', 'debug_images')
|
||||
"""保存调试图片到 debug_images 目录。"""
|
||||
import time
|
||||
|
||||
debug_dir = os.path.join(
|
||||
os.path.dirname(__file__), "..", "..", "..", "debug_images"
|
||||
)
|
||||
os.makedirs(debug_dir, exist_ok=True)
|
||||
ts = int(time.time())
|
||||
path = os.path.join(debug_dir, f'{ts}_{filename}')
|
||||
path = os.path.join(debug_dir, f"{ts}_{filename}")
|
||||
try:
|
||||
with open(path, 'wb') as f:
|
||||
f.write(data)
|
||||
with open(path, "wb") as file_obj:
|
||||
file_obj.write(data)
|
||||
log.info(f"[Debug] 图片已保存: {path}")
|
||||
except Exception as e:
|
||||
log.warning(f"[Debug] 保存失败: {e}")
|
||||
except Exception as exc:
|
||||
log.warning(f"[Debug] 保存失败: {exc}")
|
||||
|
||||
Reference in New Issue
Block a user