# -*- coding: utf-8 -*-
"""
AI model invocation layer.

Centralizes the calling logic for every AI model used by the service.
The default provider is currently Ark; configuration keys keep
provider-neutral names, and skill-based execution strategies are supported.
"""
import json
import logging
import os
from dataclasses import dataclass, field
from typing import List, Optional

from app.api.v1.ai_skills import AiSkill, get_ai_skill, resolve_ai_skill
from app.api.v1.ai_tools import PS_TOOLS, TOOL_DISPLAY_NAMES
from app.core.config import settings
# Module-level logger, named after this module, used by every function below.
log = logging.getLogger(__name__)
@dataclass
class RuntimeAiConfig:
    """Resolved AI connection settings used for a single call.

    The three ``*_base_url`` fields are optional per-capability endpoints;
    an empty string means "use ``base_url``" (see ``get_ai_base_url``).
    """

    # Provider identifier, e.g. "ark".
    provider: str
    # Secret credential. Excluded from the generated repr so that printing or
    # logging the config object cannot leak the key.
    api_key: str = field(repr=False)
    # Fallback endpoint used when no capability-specific URL is set.
    base_url: str
    chat_base_url: str = ""
    vision_base_url: str = ""
    image_base_url: str = ""
    # "user" when built from per-user credentials, otherwise "server".
    source: str = "server"
# ==================== Prompts ====================
# Base system prompt for the chat model. The text is sent to the model
# verbatim, so it must not contain anything except the intended instructions.
SYSTEM_PROMPT = """你是 DesignerCEP 的 AI 助手,运行在 Adobe Photoshop CEP 插件中。

你的能力:
1. 回答关于 Photoshop 操作和插件使用的问题
2. 通过工具直接操作 Photoshop(创建图层、对齐、查看文档信息等)
3. 帮用户排查操作中遇到的错误
4. AI 智能套图:两阶段流程,先生成预览确认,再提取套到裁片上

## AI 智能套图流程

当用户上传成衣图片并要求套图时:
1. 先调用 identify_pieces 识别每个图层是什么裁片部位
2. 告诉用户识别结果
3. 调用 generate_garment_preview 生成带标签的裁片预览图
4. 等用户确认 OK 后,再调用 extract_and_apply_all_pieces 正式套图

重要:
- 预览生成后必须等用户确认,不能自己直接进入正式套图
- 每个工具最多调用 1 次,除非上次执行失败
- 用户要求执行 Photoshop 操作时,优先使用工具完成
- 用户如果是在问“不会怎么做 / 为什么不行 / 下一步怎么操作”,要优先结合当前 Photoshop 上下文答疑,不要只给脱离现场的教程
- 答疑、查询当前信息、执行操作可以在同一轮里完成,不要把它们人为拆成两种模式
- 如果用户的问题和当前文档、图层、选区、参考图有关,优先先查 1 个最相关的上下文工具,再给结论
- 如果用户意图已经明确且风险低,可以边解释边执行;如果涉及删除、覆盖、关闭、合并等风险操作,先确认
- 不要只讲原理不动手,也不要只闷头执行不解释;默认输出“判断 + 当前检查/执行结果 + 下一步”
- 用简洁中文回答,适合在小面板中阅读
- 复杂任务先给 2-4 条简短思路,再开始执行
- 默认按 观察 -> 判断 -> 执行 -> 验证 的顺序做事
- 如果任务包含多步骤,不要一上来同时调很多工具,先完成当前阶段再进入下一阶段
- 回答里尽量显式说明“当前阶段”和“下一步”
"""
# Default instructions for the vision model; sent verbatim, so it must hold
# only the intended instruction text.
VISION_PROMPT = """你是一位资深的服装设计分析师,同时也是 DesignerCEP 的 AI 助手。

请结合图片和用户问题,从服装类别、面料、颜色印花、版型特点、工艺细节、设计评价、改进建议等维度做专业分析。

规则:
- 如果图片不是服装相关,也尽力分析图片内容
- 用清晰、结构化的中文回答
- 回答要专业但不啰嗦,突出重点
"""
# Lookup table: tool function name -> full tool schema entry from PS_TOOLS.
# Entries whose type is not "function", or that carry no function name, are
# left out.
TOOL_BY_NAME = {
    tool["function"]["name"]: tool
    for tool in PS_TOOLS
    if tool.get("type") == "function" and tool.get("function", {}).get("name")
}
# ==================== Client factory ====================
def get_runtime_ai_config(user=None) -> "RuntimeAiConfig":
    """Resolve the AI configuration to use for ``user``.

    If the user object carries a non-empty ``ai_api_key``, a user-sourced
    config is returned: the user's provider and URLs are used, and each one
    that is empty falls back to the corresponding server setting. Otherwise
    the server-wide configuration is returned.

    Args:
        user: Optional object that may expose ``ai_provider``, ``ai_api_key``,
            ``ai_base_url``, ``ai_chat_base_url``, ``ai_vision_base_url`` and
            ``ai_image_base_url``. Missing or falsy attributes count as "".

    Returns:
        A ``RuntimeAiConfig`` whose ``source`` is ``"user"`` or ``"server"``.
    """

    def user_value(attr: str) -> str:
        # Missing attribute, None and "" all normalise to "".
        return (getattr(user, attr, "") or "").strip()

    default_base_url = (
        settings.AI_BASE_URL
        or settings.ARK_BASE_URL
        or "https://ark.cn-beijing.volces.com/api/v3"
    )

    user_api_key = user_value("ai_api_key")
    if user_api_key:
        # The user's own key wins; every other user field is optional.
        return RuntimeAiConfig(
            provider=(
                user_value("ai_provider").lower()
                or (settings.AI_PROVIDER or "ark").strip().lower()
            ),
            api_key=user_api_key,
            base_url=user_value("ai_base_url") or default_base_url,
            chat_base_url=user_value("ai_chat_base_url") or settings.AI_CHAT_BASE_URL or "",
            vision_base_url=user_value("ai_vision_base_url") or settings.AI_VISION_BASE_URL or "",
            image_base_url=user_value("ai_image_base_url") or settings.AI_IMAGE_BASE_URL or "",
            source="user",
        )

    # Server key lookup order: generic setting, Ark-specific setting, then the
    # same two names read directly from the process environment.
    server_api_key = (
        settings.AI_API_KEY
        or settings.ARK_API_KEY
        or os.getenv("AI_API_KEY", "")
        or os.getenv("ARK_API_KEY", "")
    )
    return RuntimeAiConfig(
        provider=(settings.AI_PROVIDER or "ark").strip().lower(),
        api_key=server_api_key,
        base_url=default_base_url,
        chat_base_url=settings.AI_CHAT_BASE_URL or "",
        vision_base_url=settings.AI_VISION_BASE_URL or "",
        image_base_url=settings.AI_IMAGE_BASE_URL or "",
        source="server",
    )
def get_ai_provider(runtime_config: Optional["RuntimeAiConfig"] = None) -> str:
    """Return the normalised (stripped, lower-case) provider name.

    The provider on ``runtime_config`` is used when it is non-blank;
    otherwise ``settings.AI_PROVIDER`` is used, defaulting to ``"ark"``.

    Args:
        runtime_config: Optional resolved config for the current call.

    Returns:
        The provider identifier; never an empty string.
    """
    configured = (runtime_config.provider or "") if runtime_config else ""
    provider = configured.strip().lower()
    if provider:
        return provider
    # A config with a blank/None provider falls back to the server default
    # instead of yielding "" (which no client factory branch would accept).
    return (settings.AI_PROVIDER or "ark").strip().lower()
def get_ai_api_key(runtime_config: Optional["RuntimeAiConfig"] = None) -> str:
    """Return the stripped API key for the call.

    Args:
        runtime_config: Config to read the key from. When omitted, the
            config is resolved via ``get_runtime_ai_config()``.

    Returns:
        The key without surrounding whitespace, or "" when none is set.
    """
    config = runtime_config if runtime_config else get_runtime_ai_config()
    # ``or ""`` keeps a None key from crashing on ``.strip()``.
    return (config.api_key or "").strip()
def get_ai_base_url(
    runtime_config: Optional["RuntimeAiConfig"] = None, capability: str = "chat"
) -> str:
    """Return the endpoint URL for ``capability``.

    Args:
        runtime_config: Config to read from. When omitted, the config is
            resolved via ``get_runtime_ai_config()``.
        capability: ``"vision"`` or ``"image"`` select the matching
            capability URL; any other value selects the chat URL.

    Returns:
        The capability-specific URL when it is non-blank, else ``base_url``;
        surrounding whitespace is removed.
    """
    config = runtime_config or get_runtime_ai_config()

    if capability == "vision":
        specific = config.vision_base_url
    elif capability == "image":
        specific = config.image_base_url
    else:
        specific = config.chat_base_url

    # Strip before choosing, so a whitespace-only override does not win over
    # base_url and then collapse to an empty URL.
    return (specific or "").strip() or (config.base_url or "").strip()
def has_ai_config(runtime_config: Optional["RuntimeAiConfig"] = None) -> bool:
    """Return True when a non-empty API key is available for the call."""
    return bool(get_ai_api_key(runtime_config))
def get_ai_client(
    runtime_config: Optional["RuntimeAiConfig"] = None, capability: str = "chat"
):
    """Build the SDK client for the current provider.

    Args:
        runtime_config: Optional resolved config for the current call.
        capability: Passed to ``get_ai_base_url`` to pick the endpoint
            (``"chat"``, ``"vision"`` or ``"image"``).

    Returns:
        An ``Ark`` client bound to the selected base URL and API key.

    Raises:
        ValueError: If the provider is not ``"ark"`` or no API key is set.
    """
    provider = get_ai_provider(runtime_config)
    if provider != "ark":
        raise ValueError(f"当前暂不支持的 AI_PROVIDER: {provider}")

    api_key = get_ai_api_key(runtime_config)
    if not api_key:
        raise ValueError("AI_API_KEY 未配置")

    # Imported only after validation, so a misconfiguration is reported as
    # the ValueError above rather than as an ImportError from a missing SDK.
    from volcenginesdkarkruntime import Ark

    return Ark(
        base_url=get_ai_base_url(runtime_config, capability=capability),
        api_key=api_key,
    )
def is_gemini_model(model_name: str) -> bool:
    """Legacy compatibility shim; always returns False.

    Kept so older call sites keep working. ``model_name`` is ignored.
    """
    return False
def _tool_schemas_for_skill(skill: "AiSkill") -> List[dict]:
    """Return the tool schemas the given skill may use.

    Names in ``skill.allowed_tools`` that are not present in ``TOOL_BY_NAME``
    are skipped; the order of ``allowed_tools`` is preserved.
    """
    allowed = skill.allowed_tools
    if not allowed:
        return []
    return [TOOL_BY_NAME[name] for name in allowed if name in TOOL_BY_NAME]
def _numbered_block(items: tuple[str, ...]) -> str:
    """Render ``items`` as a 1-based numbered list, one item per line.

    Returns the placeholder "无" when ``items`` is empty.
    """
    if not items:
        return "无"
    return "\n".join(
        f"{number}. {item}" for number, item in enumerate(items, start=1)
    )
def _bullet_block(items: tuple[str, ...]) -> str:
    """Render ``items`` as a "- " bulleted list, one item per line.

    Returns the placeholder "无" when ``items`` is empty.
    """
    if not items:
        return "无"
    return "\n".join(f"- {item}" for item in items)
def _build_skill_prompt(skill: "AiSkill", has_image: bool = False) -> str:
    """Render the skill-specific part of the system prompt.

    Args:
        skill: The active skill; its name, id, description, mode, planning
            style, image hint, origin fields, workflow, triage questions,
            deliverables, guardrails, execution notes, success criteria and
            allowed tools are all interpolated into the prompt.
        has_image: Selects which closing hint is emitted (image present vs.
            no new image this turn).

    Returns:
        The prompt text, ending with a newline.
    """
    workflow = _numbered_block(skill.workflow)
    guardrails = _bullet_block(skill.guardrails)
    success_criteria = _bullet_block(skill.success_criteria)
    triage_questions = _numbered_block(skill.triage_questions)
    deliverables = _bullet_block(skill.deliverables)
    execution_notes = _bullet_block(skill.execution_notes)

    if skill.allowed_tools:
        # Prefer the human-readable display name, falling back to the raw id.
        tool_names = "、".join(
            TOOL_DISPLAY_NAMES.get(name, name) for name in skill.allowed_tools
        )
    else:
        tool_names = "当前技能以分析和建议为主,不主动调用 Photoshop 工具。"

    origin = skill.origin or "项目内置技能"
    origin_url = skill.origin_url or "无"
    upstream_skill = skill.upstream_skill or "无"
    image_hint = skill.image_hint or "无"
    if has_image:
        image_note = "本轮带有图片,请结合图片内容和技能策略做判断。"
    else:
        image_note = "本轮没有新图片,优先结合聊天历史和 Photoshop 上下文。"

    return f"""当前启用技能:{skill.name}({skill.id})
技能定位:{skill.description}
执行模式:{skill.mode}
规划风格:{skill.planning_style}
图片提示:{image_hint}
来源:{origin}
上游技能:{upstream_skill}
来源地址:{origin_url}

推荐工作流:
{workflow}

预判问题:
{triage_questions}

期望交付:
{deliverables}

执行边界:
{guardrails}

执行备注:
{execution_notes}

成功标准:
{success_criteria}

本技能优先工具:
{tool_names}

本轮补充要求:
- 回复保持中文、简洁、能落地。
- 涉及具体执行时,先根据当前技能给出简短思路,再调用工具。
- 根据用户问题动态混合答疑、查询与执行;如果用户是在求助“怎么做/为什么失败/下一步怎么弄”,不要把答疑和执行拆成两轮。
- 能读取当前 Photoshop 上下文时,优先读当前信息后再回答;不要只给泛化教程。
- 优先先想清楚计划,再动手执行;如果任务复杂,先向用户同步 2 到 4 条阶段计划。
- 每轮优先调用最相关的 1 个工具,避免一口气乱调很多工具。
- 如果图片信息不足、文档上下文不足或风险较高,要先说明再操作。
- 每完成一个阶段,要用一句中文总结已完成内容,并指向下一步。
- {image_note}
"""
def _chat_system_prompt(skill: "AiSkill", has_image: bool = False) -> str:
    """Return ``SYSTEM_PROMPT`` followed by the skill prompt, blank-line separated."""
    skill_prompt = _build_skill_prompt(skill, has_image)
    return f"{SYSTEM_PROMPT}\n\n{skill_prompt}"
def _vision_instructions(skill: "AiSkill") -> str:
    """Return ``VISION_PROMPT`` followed by the skill prompt rendered with an image."""
    skill_prompt = _build_skill_prompt(skill, has_image=True)
    return f"{VISION_PROMPT}\n\n{skill_prompt}"
def _messages_with_system(
    messages_history: List[dict], skill: "AiSkill", has_image: bool = False
) -> List[dict]:
    """Return a new message list with the system prompt prepended.

    ``messages_history`` itself is not modified.
    """
    system_message = {
        "role": "system",
        "content": _chat_system_prompt(skill, has_image),
    }
    return [system_message] + messages_history
def _extract_response_text(response) -> str:
    """Collect the text of every ``output_text`` content item in ``response``.

    Walks ``response.output[*].content[*]``; missing or None attributes are
    treated as empty. Non-empty texts are joined with newlines and the result
    is stripped.
    """
    texts: List[str] = []
    for output_item in getattr(response, "output", []) or []:
        for content_item in getattr(output_item, "content", []) or []:
            if getattr(content_item, "type", "") != "output_text":
                continue
            text_value = getattr(content_item, "text", "")
            if text_value:
                texts.append(text_value)
    return "\n".join(texts).strip()
def _image_data_url(image_base64: str) -> str:
    """Wrap a base64 image payload in a ``data:`` URL.

    Args:
        image_base64: Base64 image data, or a string that is already a
            ``data:`` URL.

    Returns:
        ``image_base64`` unchanged when it already starts with ``data:``;
        otherwise ``data:<mime>;base64,<payload>`` where the MIME type is
        recognised from the payload's leading characters (PNG, GIF, WebP)
        and defaults to ``image/jpeg``.
    """
    if image_base64.startswith("data:"):
        # Already a data URL: prefixing again would produce an invalid URL.
        return image_base64

    # Base64 renderings of the formats' magic bytes.
    known_prefixes = (
        ("iVBORw0KGgo", "image/png"),
        ("R0lGOD", "image/gif"),
        ("UklGR", "image/webp"),
    )
    mime = "image/jpeg"
    for prefix, detected in known_prefixes:
        if image_base64.startswith(prefix):
            mime = detected
            break
    return f"data:{mime};base64,{image_base64}"
# ==================== Chat / tool calling ====================
def _parse_tool_calls(tool_calls) -> List[dict]:
    """Convert SDK tool-call objects into plain dicts with status "pending"."""
    parsed: List[dict] = []
    for tool_call in tool_calls:
        function = getattr(tool_call, "function", None)
        function_name = getattr(function, "name", "") if function else ""
        raw_args = getattr(function, "arguments", "") if function else ""

        args = {}
        if raw_args:
            try:
                args = json.loads(raw_args)
            except json.JSONDecodeError:
                # Still fall back to empty args, but leave a trace: a silent
                # fallback hides why a tool later runs without parameters.
                log.warning("[LLM] 工具参数解析失败,已按空参数处理: tool=%s", function_name)
                args = {}

        parsed.append(
            {
                "id": getattr(tool_call, "id", function_name),
                "name": function_name,
                "display_name": TOOL_DISPLAY_NAMES.get(function_name, function_name),
                "args": args,
                "status": "pending",
            }
        )
    return parsed


def call_llm_with_tools(
    messages_history: List[dict],
    model_override: Optional[str] = None,
    skill_id: Optional[str] = None,
    has_image: bool = False,
    runtime_config: Optional["RuntimeAiConfig"] = None,
):
    """Call the configured chat model, with function calling when the skill allows tools.

    Args:
        messages_history: Conversation messages; the last one is treated as
            the current user message when a skill has to be resolved.
        model_override: Model name to use instead of ``settings.AI_MODEL``.
        skill_id: Explicit skill id; when it does not resolve via
            ``get_ai_skill`` the skill is chosen by ``resolve_ai_skill``.
        has_image: Whether the current turn carries an image.
        runtime_config: Optional resolved config for the current call.

    Returns:
        A 3-tuple ``(reply_text, tool_calls, skill_info)`` where
        ``tool_calls`` is a list of dicts (id, name, display_name, args,
        status) or ``None`` when the model requested no tool, and
        ``skill_info`` is ``skill.to_public_dict()``.

    Raises:
        ValueError: Propagated from ``get_ai_client`` on bad configuration.
    """
    use_model = model_override or settings.AI_MODEL
    client = get_ai_client(runtime_config, capability="chat")

    latest_message = (
        str(messages_history[-1].get("content", "")) if messages_history else ""
    )
    skill = get_ai_skill(skill_id) or resolve_ai_skill(
        message=latest_message,
        history_messages=messages_history[:-1],
        requested_skill_id=skill_id,
        has_image=has_image,
    )
    messages = _messages_with_system(messages_history, skill, has_image)
    tools = _tool_schemas_for_skill(skill)

    log.info("=" * 60)
    log.info(
        "[LLM] provider=%s model=%s source=%s",
        get_ai_provider(runtime_config),
        use_model,
        runtime_config.source if runtime_config else "server",
    )
    log.info("[LLM] skill=%s", skill.id)
    log.info("[LLM] 消息数量: %s", len(messages))
    log.info("[LLM] 工具数量: %s", len(tools))

    request_kwargs = {
        "model": use_model,
        "messages": messages,
        "temperature": 0.3,
    }
    if tools:
        # Tool parameters are only sent when the skill actually exposes tools.
        request_kwargs["tools"] = tools
        request_kwargs["tool_choice"] = "auto"
        request_kwargs["parallel_tool_calls"] = False

    completion = client.chat.completions.create(**request_kwargs)
    message = completion.choices[0].message

    tool_calls = getattr(message, "tool_calls", None) or []
    tool_calls_data = _parse_tool_calls(tool_calls) if tool_calls else None
    reply_text = getattr(message, "content", "") or ""
    return reply_text, tool_calls_data, skill.to_public_dict()
def call_gemini(
    messages_history: List[dict],
    model: str,
    images_b64: Optional[List[str]] = None,
):
    """Legacy entry point, routed to the current vision or chat model.

    Args:
        messages_history: Conversation messages.
        model: Model name forwarded as the override.
        images_b64: Optional base64 images. When non-empty, the call goes to
            ``call_vision_messages`` using the last non-empty user message as
            the prompt (or a default prompt if there is none).

    Returns:
        The reply text.
    """
    if images_b64:
        # The most recent user message with content is used as the prompt.
        prompt = next(
            (
                str(msg["content"])
                for msg in reversed(messages_history)
                if msg.get("role") == "user" and msg.get("content")
            ),
            "",
        )
        return call_vision_messages(prompt or "请分析这张图片。", images_b64, model)

    reply, _, _ = call_llm_with_tools(messages_history, model_override=model)
    return reply
def call_gemini_with_tools(messages_history: List[dict], model: str) -> tuple:
    """Legacy entry point; forwards to ``call_llm_with_tools`` with ``model`` as override."""
    return call_llm_with_tools(messages_history, model_override=model)
# ==================== Vision model ====================
def call_vision_messages(
    user_message: str,
    images_b64: List[str],
    model_override: Optional[str] = None,
    instructions_override: Optional[str] = None,
    runtime_config: Optional["RuntimeAiConfig"] = None,
) -> str:
    """Ask the current vision model about one or more images.

    Args:
        user_message: Text sent after the images in the same user turn.
        images_b64: Base64 images, each converted with ``_image_data_url``.
        model_override: Model name to use instead of
            ``settings.AI_VISION_MODEL``.
        instructions_override: Instructions to use instead of
            ``VISION_PROMPT``.
        runtime_config: Optional resolved config for the current call.

    Returns:
        The text extracted from the response by ``_extract_response_text``.

    Raises:
        ValueError: Propagated from ``get_ai_client`` on bad configuration.
    """
    use_model = model_override or settings.AI_VISION_MODEL
    client = get_ai_client(runtime_config, capability="vision")

    # Images first, then the question, all in a single user message.
    content = [
        {"type": "input_image", "image_url": _image_data_url(image_b64)}
        for image_b64 in images_b64
    ]
    content.append({"type": "input_text", "text": user_message})

    response = client.responses.create(
        model=use_model,
        instructions=instructions_override or VISION_PROMPT,
        input=[{"role": "user", "content": content}],
    )

    result = _extract_response_text(response)
    log.info(
        "[Vision] provider=%s model=%s 输出长度: %s",
        get_ai_provider(runtime_config),
        use_model,
        len(result),
    )
    return result
def call_vision_llm(
    user_message: str,
    image_base64: str,
    history: List[dict],
    model_override: Optional[str] = None,
    skill_id: Optional[str] = None,
    runtime_config: Optional["RuntimeAiConfig"] = None,
) -> str:
    """Legacy single-image vision entry point with skill-aware instructions.

    Args:
        user_message: The current question.
        image_base64: The image to analyse.
        history: Earlier messages; the contents of the last five non-empty
            ones are prepended to the question as plain text.
        model_override: Forwarded to ``call_vision_messages``.
        skill_id: Explicit skill id; when it does not resolve via
            ``get_ai_skill`` the skill is chosen by ``resolve_ai_skill``.
        runtime_config: Optional resolved config for the current call.

    Returns:
        The vision model's reply text.
    """
    skill = get_ai_skill(skill_id) or resolve_ai_skill(
        message=user_message,
        history_messages=history,
        requested_skill_id=skill_id,
        has_image=True,
    )

    combined_message = user_message
    if history:
        recent_text = "\n".join(
            str(item.get("content", ""))
            for item in history[-5:]
            if item.get("content")
        )
        if recent_text:
            combined_message = f"历史对话:\n{recent_text}\n\n当前问题:{user_message}"

    return call_vision_messages(
        combined_message,
        [image_base64],
        model_override,
        instructions_override=_vision_instructions(skill),
        runtime_config=runtime_config,
    )
# ==================== Image editing / generation ====================
def call_image_model(
    images_b64: List[str],
    prompt: str,
    model_override: Optional[str] = None,
    runtime_config: Optional["RuntimeAiConfig"] = None,
) -> tuple:
    """Generate a single image with the current image model.

    Thin wrapper over ``call_image_model_batch`` fixed to one non-streamed
    image at size "2K".

    Returns:
        ``(url_or_data_uri, description)``; the description is always "".
    """
    result = call_image_model_batch(
        images_b64=images_b64,
        prompt=prompt,
        model_override=model_override,
        size="2K",
        max_images=1,
        stream=False,
        runtime_config=runtime_config,
    )
    first_image = result["images"][0]
    return str(first_image["url"]), ""
def _image_response_error_message(payload) -> str:
    """Return ``payload.error.message`` stripped, or "" when absent or empty."""
    error = getattr(payload, "error", None)
    message = getattr(error, "message", "") or ""
    return str(message).strip()
def _image_usage_dict(payload) -> Optional[dict]:
    """Return ``payload.usage`` as a dict, or None when it cannot be converted.

    Objects exposing ``model_dump`` are dumped with ``mode="json"``; plain
    dicts are returned as-is; anything else (including a missing usage)
    yields None.
    """
    usage = getattr(payload, "usage", None)
    if usage is None:
        return None
    if hasattr(usage, "model_dump"):
        return usage.model_dump(mode="json")
    return usage if isinstance(usage, dict) else None
def _image_item_to_result(image_item, index: int) -> dict:
    """Convert one SDK image item into ``{"index", "url", "size"}``.

    A non-empty ``url`` attribute is preferred. Otherwise ``b64_json`` is
    padded to a multiple of four characters and returned as a PNG data URL.

    Raises:
        ValueError: If the item has neither a URL nor base64 data.
    """
    image_size = getattr(image_item, "size", "") or ""

    image_url = getattr(image_item, "url", "") or ""
    if image_url:
        return {"index": index, "url": image_url, "size": image_size}

    image_b64 = getattr(image_item, "b64_json", "") or ""
    if not image_b64:
        raise ValueError("图片模型返回了空图片结果")

    # Restore any "=" padding missing from the payload so that the length is
    # a multiple of four.
    image_b64 += "=" * ((-len(image_b64)) % 4)
    return {
        "index": index,
        "url": f"data:image/png;base64,{image_b64}",
        "size": image_size,
    }
def _collect_streamed_images(stream_response) -> dict:
    """Drain a streaming image response into ``{"images", "usage", "errors"}``."""
    images: List[dict] = []
    usage: Optional[dict] = None
    errors: List[str] = []

    for event in stream_response:
        if event is None:
            continue
        event_type = str(getattr(event, "type", "") or "")

        if event_type.endswith("partial_failed"):
            event_error = _image_response_error_message(event)
            if event_error:
                errors.append(event_error)
        elif event_type.endswith("partial_succeeded"):
            image_index = getattr(event, "image_index", None)
            if image_index is None:
                # No usable index on the event: fall back to arrival order
                # rather than dropping the image.
                image_index = len(images)
            try:
                images.append(_image_item_to_result(event, index=int(image_index)))
            except (TypeError, ValueError) as exc:
                # One bad image must not abort the whole batch.
                errors.append(str(exc))
        elif event_type.endswith("completed"):
            usage = _image_usage_dict(event)

    if not images:
        raise ValueError(";".join(errors) if errors else "图片模型未返回结果")

    images.sort(key=lambda item: int(item.get("index", 0)))
    return {"images": images, "usage": usage, "errors": errors}


def _collect_image_response(response) -> dict:
    """Convert a non-streaming image response into ``{"images", "usage", "errors"}``."""
    error_message = _image_response_error_message(response)
    if error_message:
        raise ValueError(error_message)

    data = getattr(response, "data", None) or []
    if not data:
        raise ValueError("图片模型未返回结果")

    images = [_image_item_to_result(item, index=idx) for idx, item in enumerate(data)]
    return {"images": images, "usage": _image_usage_dict(response), "errors": []}


def call_image_model_batch(
    images_b64: List[str],
    prompt: str,
    model_override: Optional[str] = None,
    size: str = "2K",
    max_images: int = 1,
    stream: Optional[bool] = None,
    runtime_config: Optional["RuntimeAiConfig"] = None,
) -> dict:
    """Call the current image model: single image, reference-based, or a sequence.

    Args:
        images_b64: Reference images (base64). Empty or None means no
            reference image is sent.
        prompt: Generation prompt.
        model_override: Model name to use instead of
            ``settings.AI_IMAGE_EDIT_MODEL``.
        size: Requested size; a falsy value falls back to "2K".
        max_images: Number of images wanted, clamped to the range 1-4.
        stream: Force streaming on or off. ``None`` streams only when more
            than one image is requested.
        runtime_config: Optional resolved config for the current call.

    Returns:
        ``{"images": [{"index", "url", "size"}, ...], "usage": dict | None,
        "errors": [str, ...]}`` with images ordered by index.

    Raises:
        ValueError: If the model reports an error or returns no image, or
            propagated from ``get_ai_client`` on bad configuration.
    """
    from volcenginesdkarkruntime.types.images.images import (
        SequentialImageGenerationOptions,
    )

    use_model = model_override or settings.AI_IMAGE_EDIT_MODEL
    client = get_ai_client(runtime_config, capability="image")
    image_count = max(1, min(int(max_images or 1), 4))
    should_stream = (image_count > 1) if stream is None else bool(stream)

    image_inputs = [_image_data_url(img_b64) for img_b64 in (images_b64 or [])]

    log.info("=" * 60)
    log.info(
        "[ImageModel] provider=%s model=%s source=%s",
        get_ai_provider(runtime_config),
        use_model,
        runtime_config.source if runtime_config else "server",
    )
    log.info("[ImageModel] 图片数: %s", len(image_inputs))
    log.info("[ImageModel] 目标生成张数: %s", image_count)
    log.info("[ImageModel] 提示词: %s", prompt[:200])

    # One request definition for both the streaming and non-streaming paths;
    # optional parameters are simply left out when they do not apply.
    request_kwargs = {
        "model": use_model,
        "prompt": prompt,
        "response_format": "url",
        "size": size or "2K",
        "watermark": True,
        "stream": should_stream,
        "sequential_image_generation": "auto" if image_count > 1 else "disabled",
    }
    if image_inputs:
        # A single reference image is sent as a bare string, several as a list.
        request_kwargs["image"] = (
            image_inputs[0] if len(image_inputs) == 1 else image_inputs
        )
    if image_count > 1:
        request_kwargs["sequential_image_generation_options"] = (
            SequentialImageGenerationOptions(max_images=image_count)
        )

    response = client.images.generate(**request_kwargs)
    if should_stream:
        return _collect_streamed_images(response)
    return _collect_image_response(response)
# ==================== Pattern-application verification ====================
def verify_pattern_result(
    garment_b64: str,
    canvas_b64: str,
    extra_prompt: Optional[str] = None,
    vision_model: Optional[str] = None,
    runtime_config: Optional["RuntimeAiConfig"] = None,
) -> str:
    """Compare the original garment photo with the pattern result using the vision model.

    Args:
        garment_b64: Base64 of the original garment photo (sent first).
        canvas_b64: Base64 of the pattern-application result (sent second).
        extra_prompt: Optional user note appended to the comparison prompt.
        vision_model: Model name to use instead of
            ``settings.AI_VISION_MODEL``.
        runtime_config: Optional resolved config for the current call.

    Returns:
        The vision model's assessment text.
    """
    prompt_parts = [
        "请对比这两张图片:第一张是原始成衣照片,第二张是套图结果。",
        "请从花样还原度、裁片覆盖完整度、对齐质量、整体效果四个方面评分,并给出改进建议。",
    ]
    if extra_prompt:
        prompt_parts.append(f"\n用户补充:{extra_prompt}")
    verify_prompt = "".join(prompt_parts)

    return call_vision_messages(
        verify_prompt,
        [garment_b64, canvas_b64],
        model_override=vision_model or settings.AI_VISION_MODEL,
        runtime_config=runtime_config,
    )
# ==================== Mock ====================
def mock_reply(message: str, has_image: bool = False) -> str:
    """Return a canned reply for use when no API key is configured.

    Args:
        message: The user's message; only checked for the keyword "套图".
            None or "" is accepted and treated as having no keyword.
        has_image: When True the image-specific reply is returned
            regardless of ``message``.

    Returns:
        One of three fixed Chinese hints asking the user to configure a key.
    """
    if has_image:
        return "【模拟分析】收到图片。请先在 API 配置页填写你自己的转发 Key。"
    if "套图" in (message or ""):
        return "套图功能已就绪,但正式 AI 分析前请先在 API 配置页填写你自己的转发 Key。"
    return "你好!我是 DesignerCEP AI 助手。请先在 API 配置页填写你的转发地址和 Key,然后我再帮你执行聊天、看图和套图任务。"
# ==================== Debug utilities ====================
def _save_debug_image(data: bytes, filename: str):
    """Best-effort dump of ``data`` into the ``debug_images`` directory.

    The directory is located three levels above this file's directory and is
    created on demand. The file is named ``<unix_seconds>_<filename>``. Any
    failure is logged as a warning and never raised, so a debugging aid
    cannot break the request that triggered it.
    """
    import time

    debug_dir = os.path.join(
        os.path.dirname(__file__), "..", "..", "..", "debug_images"
    )
    path = os.path.join(debug_dir, f"{int(time.time())}_{filename}")
    try:
        # Directory creation sits inside the try so that an unwritable parent
        # is handled the same way as a failed write.
        os.makedirs(debug_dir, exist_ok=True)
        with open(path, "wb") as file_obj:
            file_obj.write(data)
    except Exception as exc:
        log.warning("[Debug] 保存失败: %s", exc)
    else:
        log.info("[Debug] 图片已保存: %s", path)