Files
DP/Server/app/api/v1/ai_llm.py

669 lines
24 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""
AI 模型调用层
统一管理所有 AI 模型的调用逻辑
当前默认 provider 为 Ark配置项保持通用命名并支持技能化执行策略
"""
from typing import List, Optional
from dataclasses import dataclass
import json
import logging
import os
from app.core.config import settings
from app.api.v1.ai_tools import PS_TOOLS, TOOL_DISPLAY_NAMES
from app.api.v1.ai_skills import AiSkill, get_ai_skill, resolve_ai_skill
log = logging.getLogger(__name__)
@dataclass
class RuntimeAiConfig:
    """Resolved AI connection settings for one request.

    Instances are produced by ``get_runtime_ai_config`` and passed to the
    client factory and the ``call_*`` helpers in this module.
    """

    # Provider identifier; ``get_ai_client`` only accepts "ark".
    provider: str
    # API key handed to the provider SDK.
    api_key: str
    # Default endpoint used when no capability-specific URL is set.
    base_url: str
    # Optional per-capability endpoints; an empty string means "use base_url"
    # (see ``get_ai_base_url``).
    chat_base_url: str = ""
    vision_base_url: str = ""
    image_base_url: str = ""
    # "user" when built from the user's own API key, otherwise "server".
    source: str = "server"
# ==================== Prompts ====================
# Base system prompt for chat / tool-calling requests. ``_chat_system_prompt``
# prepends it to the skill-specific section. The text is sent to the model
# verbatim, so it is runtime data rather than documentation.
SYSTEM_PROMPT = """你是 DesignerCEP 的 AI 助手,运行在 Adobe Photoshop CEP 插件中。
你的能力:
1. 回答关于 Photoshop 操作和插件使用的问题
2. 通过工具直接操作 Photoshop创建图层、对齐、查看文档信息等
3. 帮用户排查操作中遇到的错误
4. AI 智能套图:两阶段流程,先生成预览确认,再提取套到裁片上
## AI 智能套图流程
当用户上传成衣图片并要求套图时:
1. 先调用 identify_pieces 识别每个图层是什么裁片部位
2. 告诉用户识别结果
3. 调用 generate_garment_preview 生成带标签的裁片预览图
4. 等用户确认 OK 后,再调用 extract_and_apply_all_pieces 正式套图
重要:
- 预览生成后必须等用户确认,不能自己直接进入正式套图
- 每个工具最多调用 1 次,除非上次执行失败
- 用户要求执行 Photoshop 操作时,优先使用工具完成
- 用户如果是在问“不会怎么做 / 为什么不行 / 下一步怎么操作”,要优先结合当前 Photoshop 上下文答疑,不要只给脱离现场的教程
- 答疑、查询当前信息、执行操作可以在同一轮里完成,不要把它们人为拆成两种模式
- 如果用户的问题和当前文档、图层、选区、参考图有关,优先先查 1 个最相关的上下文工具,再给结论
- 如果用户意图已经明确且风险低,可以边解释边执行;如果涉及删除、覆盖、关闭、合并等风险操作,先确认
- 不要只讲原理不动手,也不要只闷头执行不解释;默认输出“判断 + 当前检查/执行结果 + 下一步”
- 用简洁中文回答,适合在小面板中阅读
- 复杂任务先给 2-4 条简短思路,再开始执行
- 默认按 观察 -> 判断 -> 执行 -> 验证 的顺序做事
- 如果任务包含多步骤,不要一上来同时调很多工具,先完成当前阶段再进入下一阶段
- 回答里尽量显式说明“当前阶段”和“下一步”
"""
# Default instructions for the vision model. ``call_vision_messages`` uses it
# when no override is given, and ``_vision_instructions`` prepends it to the
# skill-specific section. Sent to the model verbatim.
VISION_PROMPT = """你是一位资深的服装设计分析师,同时也是 DesignerCEP 的 AI 助手。
请结合图片和用户问题,从服装类别、面料、颜色印花、版型特点、工艺细节、设计评价、改进建议等维度做专业分析。
规则:
- 如果图片不是服装相关,也尽力分析图片内容
- 用清晰、结构化的中文回答
- 回答要专业但不啰嗦,突出重点
"""
def _index_function_tools(tool_specs) -> dict:
    """Map function name -> tool schema for every function-type entry."""
    indexed = {}
    for spec in tool_specs:
        if spec.get("type") != "function":
            continue
        function_name = spec.get("function", {}).get("name")
        if function_name:
            # Key is read with plain indexing, matching the original lookup.
            indexed[spec["function"]["name"]] = spec
    return indexed


# Lookup table used to pick the tool schemas a skill is allowed to call.
TOOL_BY_NAME = _index_function_tools(PS_TOOLS)
# ==================== Client factory ====================
def get_runtime_ai_config(user=None) -> RuntimeAiConfig:
    """Resolve the AI settings to use for ``user``.

    When the user object carries a non-empty ``ai_api_key`` the user's own
    settings win (each missing field falls back to the server setting) and the
    result is tagged ``source="user"``. Otherwise the server-wide settings and
    environment variables are used and the result is tagged ``source="server"``.
    """

    def _user_text(attr_name: str) -> str:
        # Missing attribute, None and "" all collapse to "".
        return (getattr(user, attr_name, "") or "").strip()

    own_provider = _user_text("ai_provider").lower()
    own_key = _user_text("ai_api_key")
    own_base_url = _user_text("ai_base_url")
    own_chat_url = _user_text("ai_chat_base_url")
    own_vision_url = _user_text("ai_vision_base_url")
    own_image_url = _user_text("ai_image_base_url")

    fallback_base_url = (
        settings.AI_BASE_URL
        or settings.ARK_BASE_URL
        or "https://ark.cn-beijing.volces.com/api/v3"
    )

    if not own_key:
        server_provider = (settings.AI_PROVIDER or "ark").strip().lower()
        server_key = (
            settings.AI_API_KEY
            or settings.ARK_API_KEY
            or os.getenv("AI_API_KEY", "")
            or os.getenv("ARK_API_KEY", "")
        )
        return RuntimeAiConfig(
            provider=server_provider,
            api_key=server_key,
            base_url=fallback_base_url,
            chat_base_url=settings.AI_CHAT_BASE_URL or "",
            vision_base_url=settings.AI_VISION_BASE_URL or "",
            image_base_url=settings.AI_IMAGE_BASE_URL or "",
            source="server",
        )

    return RuntimeAiConfig(
        provider=own_provider or (settings.AI_PROVIDER or "ark").strip().lower(),
        api_key=own_key,
        base_url=own_base_url or fallback_base_url,
        chat_base_url=own_chat_url or settings.AI_CHAT_BASE_URL or "",
        vision_base_url=own_vision_url or settings.AI_VISION_BASE_URL or "",
        image_base_url=own_image_url or settings.AI_IMAGE_BASE_URL or "",
        source="user",
    )
def get_ai_provider(runtime_config: Optional[RuntimeAiConfig] = None) -> str:
    """Return the lower-cased provider name for this request.

    Uses ``runtime_config.provider`` when a config is given and the value is
    non-blank; otherwise falls back to ``settings.AI_PROVIDER`` and finally to
    ``"ark"``.
    """
    if runtime_config:
        # A blank or None provider on the config must not produce "" (which
        # would later be rejected as an unsupported provider); fall through
        # to the server default instead.
        configured = (runtime_config.provider or "").strip()
        if configured:
            return configured.lower()
    return (settings.AI_PROVIDER or "ark").strip().lower()
def get_ai_api_key(runtime_config: Optional[RuntimeAiConfig] = None) -> str:
    """Return the stripped API key from ``runtime_config``.

    When no config is supplied, the server-level config from
    ``get_runtime_ai_config()`` is used.
    """
    active_config = runtime_config if runtime_config else get_runtime_ai_config()
    return active_config.api_key.strip()
def get_ai_base_url(
    runtime_config: Optional[RuntimeAiConfig] = None, capability: str = "chat"
) -> str:
    """Return the endpoint URL for ``capability``.

    ``"vision"`` and ``"image"`` select their dedicated URL; any other value
    selects the chat URL. An empty capability-specific URL falls back to the
    config's ``base_url``. The result is stripped of surrounding whitespace.
    """
    active = runtime_config or get_runtime_ai_config()
    if capability == "vision":
        dedicated_url = active.vision_base_url
    elif capability == "image":
        dedicated_url = active.image_base_url
    else:
        dedicated_url = active.chat_base_url
    return (dedicated_url or active.base_url).strip()
def has_ai_config(runtime_config: Optional[RuntimeAiConfig] = None) -> bool:
    """Return True when a non-empty API key is available."""
    resolved_key = get_ai_api_key(runtime_config)
    return True if resolved_key else False
def get_ai_client(
    runtime_config: Optional[RuntimeAiConfig] = None, capability: str = "chat"
):
    """Build an SDK client for the current provider.

    Args:
        runtime_config: Resolved settings; server defaults are used when None.
        capability: Selects which endpoint URL the client targets.

    Raises:
        ValueError: If the provider is not "ark" or no API key is configured.
    """
    # Imported lazily so the module can be loaded without the SDK installed.
    from volcenginesdkarkruntime import Ark

    provider_name = get_ai_provider(runtime_config)
    if provider_name != "ark":
        raise ValueError(f"当前暂不支持的 AI_PROVIDER: {provider_name}")

    key = get_ai_api_key(runtime_config)
    if not key:
        raise ValueError("AI_API_KEY 未配置")

    endpoint = get_ai_base_url(runtime_config, capability=capability)
    return Ark(base_url=endpoint, api_key=key)
def is_gemini_model(model_name: str) -> bool:
    """Legacy compatibility shim; always reports False.

    ``model_name`` is accepted for backward compatibility and ignored.
    """
    is_gemini = False
    return is_gemini
def _tool_schemas_for_skill(skill: AiSkill) -> List[dict]:
    """Return the tool schemas the skill may call, in the skill's own order.

    Names that are not present in ``TOOL_BY_NAME`` are skipped.
    """
    schemas: List[dict] = []
    if skill.allowed_tools:
        for tool_name in skill.allowed_tools:
            if tool_name in TOOL_BY_NAME:
                schemas.append(TOOL_BY_NAME[tool_name])
    return schemas
def _numbered_block(items: tuple[str, ...]) -> str:
    """Render ``items`` as a 1-based numbered list, one entry per line.

    Returns an empty string when ``items`` is empty.
    """
    if items:
        numbered_lines = [
            f"{position}. {text}" for position, text in enumerate(items, start=1)
        ]
        return "\n".join(numbered_lines)
    return ""
def _bullet_block(items: tuple[str, ...]) -> str:
    """Render ``items`` as a dash-prefixed bullet list, one entry per line.

    Returns an empty string when ``items`` is empty.
    """
    if items:
        bullet_lines = ["- " + format(text) for text in items]
        return "\n".join(bullet_lines)
    return ""
def _build_skill_prompt(skill: AiSkill, has_image: bool = False) -> str:
    """Render the skill-specific section of the prompt.

    Args:
        skill: The active skill; its workflow, guardrails, deliverables and
            related fields are rendered as numbered or bulleted lists.
        has_image: Chooses the final instruction line (image present or not).

    Returns:
        The prompt text that is appended after the base system/vision prompt.
    """
    workflow = _numbered_block(skill.workflow)
    guardrails = _bullet_block(skill.guardrails)
    success_criteria = _bullet_block(skill.success_criteria)
    triage_questions = _numbered_block(skill.triage_questions)
    deliverables = _bullet_block(skill.deliverables)
    execution_notes = _bullet_block(skill.execution_notes)

    if skill.allowed_tools:
        # Join with a list separator: an empty separator fused all display
        # names into one unreadable token in the prompt.
        tool_names = "、".join(
            TOOL_DISPLAY_NAMES.get(name, name) for name in skill.allowed_tools
        )
    else:
        tool_names = "当前技能以分析和建议为主,不主动调用 Photoshop 工具。"

    origin = skill.origin or "项目内置技能"
    origin_url = skill.origin_url or ""
    upstream_skill = skill.upstream_skill or ""
    image_note = (
        "本轮带有图片,请结合图片内容和技能策略做判断。"
        if has_image
        else "本轮没有新图片,优先结合聊天历史和 Photoshop 上下文。"
    )

    # The skill id is parenthesised so it is not fused with the skill name.
    return f"""当前启用技能:{skill.name}({skill.id})
技能定位:{skill.description}
执行模式:{skill.mode}
规划风格:{skill.planning_style}
图片提示:{skill.image_hint or ''}
来源:{origin}
上游技能:{upstream_skill}
来源地址:{origin_url}
推荐工作流:
{workflow}
预判问题:
{triage_questions}
期望交付:
{deliverables}
执行边界:
{guardrails}
执行备注:
{execution_notes}
成功标准:
{success_criteria}
本技能优先工具:
{tool_names}
本轮补充要求:
- 回复保持中文、简洁、能落地。
- 涉及具体执行时,先根据当前技能给出简短思路,再调用工具。
- 根据用户问题动态混合答疑、查询与执行;如果用户是在求助“怎么做/为什么失败/下一步怎么弄”,不要把答疑和执行拆成两轮。
- 能读取当前 Photoshop 上下文时,优先读当前信息后再回答;不要只给泛化教程。
- 优先先想清楚计划,再动手执行;如果任务复杂,先向用户同步 2 到 4 条阶段计划。
- 每轮优先调用最相关的 1 个工具,避免一口气乱调很多工具。
- 如果图片信息不足、文档上下文不足或风险较高,要先说明再操作。
- 每完成一个阶段,要用一句中文总结已完成内容,并指向下一步。
- {image_note}
"""
def _chat_system_prompt(skill: AiSkill, has_image: bool = False) -> str:
    """Return the base system prompt followed by the skill section."""
    skill_section = _build_skill_prompt(skill, has_image)
    return "\n\n".join((SYSTEM_PROMPT, skill_section))
def _vision_instructions(skill: AiSkill) -> str:
    """Return the vision prompt followed by the skill section (image mode)."""
    skill_section = _build_skill_prompt(skill, has_image=True)
    return "\n\n".join((VISION_PROMPT, skill_section))
def _messages_with_system(
    messages_history: List[dict], skill: AiSkill, has_image: bool = False
) -> List[dict]:
    """Return a new message list with the system message placed first.

    ``messages_history`` itself is not modified.
    """
    system_message = {
        "role": "system",
        "content": _chat_system_prompt(skill, has_image),
    }
    return [system_message] + messages_history
def _extract_response_text(response) -> str:
    """Collect every non-empty ``output_text`` part of a response.

    Walks ``response.output[*].content[*]``; missing or None attributes are
    treated as empty. Parts are joined with newlines and the result stripped.
    """

    def _iter_text_parts():
        for entry in getattr(response, "output", []) or []:
            for part in getattr(entry, "content", []) or []:
                if getattr(part, "type", "") != "output_text":
                    continue
                part_text = getattr(part, "text", "")
                if part_text:
                    yield part_text

    return "\n".join(_iter_text_parts()).strip()
def _image_data_url(image_base64: str) -> str:
    """Wrap a base64 payload in a ``data:image/jpeg`` URL."""
    return "data:image/jpeg;base64,{}".format(image_base64)
# ==================== Chat / tool calling ====================
def _parse_tool_call(tool_call) -> dict:
    """Convert one SDK tool-call object into the plain dict returned to callers."""
    function = getattr(tool_call, "function", None)
    function_name = getattr(function, "name", "") if function else ""
    raw_args = getattr(function, "arguments", "") if function else ""

    args = {}
    if raw_args:
        try:
            parsed_args = json.loads(raw_args)
        except (json.JSONDecodeError, TypeError) as exc:
            # Keep the best-effort fallback to {} but leave a trace; a silent
            # swallow made malformed model output impossible to diagnose.
            log.warning(f"[LLM] 工具参数解析失败 tool={function_name}: {exc}")
        else:
            if isinstance(parsed_args, dict):
                args = parsed_args
            else:
                # Tool arguments must be a JSON object; anything else is dropped.
                log.warning(f"[LLM] 工具参数不是对象 tool={function_name}")

    return {
        "id": getattr(tool_call, "id", function_name),
        "name": function_name,
        "display_name": TOOL_DISPLAY_NAMES.get(function_name, function_name),
        "args": args,
        "status": "pending",
    }


def call_llm_with_tools(
    messages_history: List[dict],
    model_override: Optional[str] = None,
    skill_id: Optional[str] = None,
    has_image: bool = False,
    runtime_config: Optional[RuntimeAiConfig] = None,
):
    """Call the configured chat model, with function calling when the skill allows it.

    Args:
        messages_history: Chat messages; the last entry is treated as the
            current user message when a skill has to be resolved.
        model_override: Model name to use instead of ``settings.AI_MODEL``.
        skill_id: Explicit skill id; when it does not resolve via
            ``get_ai_skill`` the skill is chosen by ``resolve_ai_skill``.
        has_image: Whether this turn carries an image (affects the prompt).
        runtime_config: Resolved connection settings; server defaults if None.

    Returns:
        ``(reply_text, tool_calls, skill_info)`` where ``tool_calls`` is a list
        of dicts (``id``, ``name``, ``display_name``, ``args``, ``status``) or
        None when the model requested no tool, and ``skill_info`` is
        ``skill.to_public_dict()``.
    """
    use_model = model_override or settings.AI_MODEL
    client = get_ai_client(runtime_config, capability="chat")
    skill = get_ai_skill(skill_id) or resolve_ai_skill(
        message=str(messages_history[-1].get("content", "")) if messages_history else "",
        history_messages=messages_history[:-1],
        requested_skill_id=skill_id,
        has_image=has_image,
    )
    messages = _messages_with_system(messages_history, skill, has_image)
    tools = _tool_schemas_for_skill(skill)

    log.info(f"{'=' * 60}")
    log.info(f"[LLM] provider={get_ai_provider(runtime_config)} model={use_model} source={(runtime_config.source if runtime_config else 'server')}")
    log.info(f"[LLM] skill={skill.id}")
    log.info(f"[LLM] 消息数量: {len(messages)}")
    log.info(f"[LLM] 工具数量: {len(tools)}")

    request_kwargs = {
        "model": use_model,
        "messages": messages,
        "temperature": 0.3,
    }
    if tools:
        # Tool parameters are only sent when the skill actually exposes tools.
        request_kwargs.update(
            {
                "tools": tools,
                "tool_choice": "auto",
                "parallel_tool_calls": False,
            }
        )

    completion = client.chat.completions.create(**request_kwargs)
    message = completion.choices[0].message
    reply_text = getattr(message, "content", "") or ""

    tool_calls = getattr(message, "tool_calls", None) or []
    # An empty list collapses to None, matching the "no tool requested" result.
    tool_calls_data = [_parse_tool_call(tc) for tc in tool_calls] or None
    return reply_text, tool_calls_data, skill.to_public_dict()
def call_gemini(messages_history: List[dict], model: str, images_b64: List[str] = None):
    """Legacy entry point; forwards to the current vision or chat model.

    With images, the last non-empty user message is used as the prompt for
    the vision model. Without images, the history is sent to the chat model
    and only the reply text is returned.
    """
    user_texts = [
        str(entry["content"])
        for entry in messages_history
        if entry.get("role") == "user" and entry.get("content")
    ]
    prompt = user_texts[-1] if user_texts else ""

    if not images_b64:
        reply, _, _ = call_llm_with_tools(messages_history, model_override=model)
        return reply
    return call_vision_messages(prompt or "请分析这张图片。", images_b64, model)
def call_gemini_with_tools(messages_history: List[dict], model: str) -> tuple:
    """Legacy entry point; forwards to ``call_llm_with_tools`` unchanged."""
    result = call_llm_with_tools(messages_history, model_override=model)
    return result
# ==================== Vision model ====================
def call_vision_messages(
    user_message: str,
    images_b64: List[str],
    model_override: str = None,
    instructions_override: Optional[str] = None,
    runtime_config: Optional[RuntimeAiConfig] = None,
) -> str:
    """Send images plus a text question to the vision model and return its text.

    Args:
        user_message: Text placed after the images in the user turn.
        images_b64: Base64 image payloads, sent in the given order.
        model_override: Model name replacing ``settings.AI_VISION_MODEL``.
        instructions_override: Instructions replacing ``VISION_PROMPT``.
        runtime_config: Resolved connection settings; server defaults if None.
    """
    use_model = model_override or settings.AI_VISION_MODEL
    client = get_ai_client(runtime_config, capability="vision")

    # Images first, then the question as the last content part.
    content = [
        {"type": "input_image", "image_url": _image_data_url(encoded_image)}
        for encoded_image in images_b64
    ]
    content.append({"type": "input_text", "text": user_message})

    instructions = instructions_override or VISION_PROMPT
    response = client.responses.create(
        model=use_model,
        instructions=instructions,
        input=[{"role": "user", "content": content}],
    )

    result = _extract_response_text(response)
    log.info(
        f"[Vision] provider={get_ai_provider(runtime_config)} model={use_model} 输出长度: {len(result)}"
    )
    return result
def call_vision_llm(
    user_message: str,
    image_base64: str,
    history: List[dict],
    model_override: str = None,
    skill_id: Optional[str] = None,
    runtime_config: Optional[RuntimeAiConfig] = None,
) -> str:
    """Legacy single-image entry point for the vision model.

    The skill is taken from ``skill_id`` or resolved from the message and
    history. Non-empty contents of the last five history entries are prepended
    to the question so the model sees recent context.
    """
    skill = get_ai_skill(skill_id) or resolve_ai_skill(
        message=user_message,
        history_messages=history,
        requested_skill_id=skill_id,
        has_image=True,
    )

    recent_text = ""
    if history:
        recent_entries = [
            str(entry.get("content", ""))
            for entry in history[-5:]
            if entry.get("content")
        ]
        recent_text = "\n".join(recent_entries)

    if recent_text:
        combined_message = f"历史对话:\n{recent_text}\n\n当前问题:{user_message}"
    else:
        combined_message = user_message

    return call_vision_messages(
        combined_message,
        [image_base64],
        model_override,
        instructions_override=_vision_instructions(skill),
        runtime_config=runtime_config,
    )
# ==================== Image editing / generation ====================
def call_image_model(
    images_b64: List[str],
    prompt: str,
    model_override: str = None,
    runtime_config: Optional[RuntimeAiConfig] = None,
) -> tuple:
    """Generate a single image and return ``(url_or_data_uri, description)``.

    Delegates to ``call_image_model_batch`` with one 2K image and streaming
    off. The description element is always an empty string.
    """
    batch = call_image_model_batch(
        images_b64=images_b64,
        prompt=prompt,
        model_override=model_override,
        size="2K",
        max_images=1,
        stream=False,
        runtime_config=runtime_config,
    )
    image_url = str(batch["images"][0]["url"])
    return image_url, ""
def _image_response_error_message(payload) -> str:
    """Return ``payload.error.message`` as a stripped string, or "" if absent."""
    error_obj = getattr(payload, "error", None)
    raw_message = getattr(error_obj, "message", "")
    return str(raw_message or "").strip()
def _image_usage_dict(payload) -> Optional[dict]:
    """Return ``payload.usage`` as a plain dict, or None when unavailable.

    Objects exposing ``model_dump`` are dumped in JSON mode; a dict is
    returned as-is; any other value yields None.
    """
    usage_obj = getattr(payload, "usage", None)
    if usage_obj is not None:
        if hasattr(usage_obj, "model_dump"):
            return usage_obj.model_dump(mode="json")
        if isinstance(usage_obj, dict):
            return usage_obj
    return None
def _image_item_to_result(image_item, index: int) -> dict:
    """Normalise one image item into ``{"index", "url", "size"}``.

    A direct URL is preferred. Otherwise the base64 payload is padded to a
    multiple of four characters and wrapped in a ``data:image/png`` URL.

    Raises:
        ValueError: If the item carries neither a URL nor base64 data.
    """
    url = getattr(image_item, "url", "") or ""
    encoded = getattr(image_item, "b64_json", "") or ""
    size = getattr(image_item, "size", "") or ""

    if not url and not encoded:
        raise ValueError("图片模型返回了空图片结果")

    if not url:
        # Restore any "=" padding the service may have dropped.
        encoded += "=" * ((-len(encoded)) % 4)
        url = f"data:image/png;base64,{encoded}"

    return {"index": index, "url": url, "size": size}
def _collect_streamed_images(stream_response) -> dict:
    """Drain a streaming image response into ``{"images", "usage", "errors"}``.

    Raises:
        ValueError: If the stream produced no image at all.
    """
    images: List[dict] = []
    usage: Optional[dict] = None
    errors: List[str] = []

    for event in stream_response:
        if event is None:
            continue
        event_type = str(getattr(event, "type", "") or "")

        if event_type.endswith("partial_failed"):
            event_error = _image_response_error_message(event)
            if event_error:
                errors.append(event_error)
        elif event_type.endswith("partial_succeeded"):
            # A missing or None image_index falls back to arrival order
            # instead of discarding an otherwise valid image.
            image_index = getattr(event, "image_index", None)
            if image_index is None:
                image_index = len(images)
            try:
                images.append(_image_item_to_result(event, index=int(image_index)))
            except Exception as exc:
                # Best effort per image: record the failure, keep the rest.
                errors.append(str(exc))
        elif event_type.endswith("completed"):
            usage = _image_usage_dict(event)

    if errors and not images:
        # Separate the messages so several failures stay readable.
        raise ValueError("; ".join(errors))
    if not images:
        raise ValueError("图片模型未返回结果")

    images.sort(key=lambda item: int(item.get("index", 0)))
    return {"images": images, "usage": usage, "errors": errors}


def call_image_model_batch(
    images_b64: List[str],
    prompt: str,
    model_override: Optional[str] = None,
    size: str = "2K",
    max_images: int = 1,
    stream: Optional[bool] = None,
    runtime_config: Optional[RuntimeAiConfig] = None,
) -> dict:
    """Call the image model for single images, reference-based edits or image sets.

    Args:
        images_b64: Base64 reference images; empty or None means text-to-image.
        prompt: Generation prompt.
        model_override: Model name replacing ``settings.AI_IMAGE_EDIT_MODEL``.
        size: Requested output size; empty falls back to "2K".
        max_images: Number of images wanted, clamped to the range 1..4.
        stream: Force streaming on/off; None streams only when more than one
            image is requested.
        runtime_config: Resolved connection settings; server defaults if None.

    Returns:
        ``{"images": [{"index", "url", "size"}, ...], "usage": dict | None,
        "errors": [str, ...]}``.

    Raises:
        ValueError: If the model reports an error or returns no image.
    """
    from volcenginesdkarkruntime.types.images.images import (
        SequentialImageGenerationOptions,
    )

    use_model = model_override or settings.AI_IMAGE_EDIT_MODEL
    client = get_ai_client(runtime_config, capability="image")
    image_count = max(1, min(int(max_images or 1), 4))
    should_stream = (image_count > 1) if stream is None else bool(stream)

    # None is treated like an empty list so text-to-image callers need not pass [].
    source_images = list(images_b64 or [])
    image_inputs = [_image_data_url(img_b64) for img_b64 in source_images]

    log.info(f"{'=' * 60}")
    log.info(f"[ImageModel] provider={get_ai_provider(runtime_config)} model={use_model} source={(runtime_config.source if runtime_config else 'server')}")
    log.info(f"[ImageModel] 图片数: {len(source_images)}")
    log.info(f"[ImageModel] 目标生成张数: {image_count}")
    log.info(f"[ImageModel] 提示词: {prompt[:200]}")

    # One request dict serves both the streaming and non-streaming path.
    request_kwargs = {
        "model": use_model,
        "prompt": prompt,
        "response_format": "url",
        "size": size or "2K",
        "watermark": True,
        "stream": should_stream,
        "sequential_image_generation": "auto" if image_count > 1 else "disabled",
    }
    if image_inputs:
        # A single reference is sent as a string, several as a list.
        request_kwargs["image"] = (
            image_inputs[0] if len(image_inputs) == 1 else image_inputs
        )
    if image_count > 1:
        request_kwargs["sequential_image_generation_options"] = (
            SequentialImageGenerationOptions(max_images=image_count)
        )
    elif not should_stream:
        # Preserved from the previous implementation: the non-streaming
        # single-image call passes the options argument explicitly as None.
        request_kwargs["sequential_image_generation_options"] = None

    if should_stream:
        return _collect_streamed_images(client.images.generate(**request_kwargs))

    response = client.images.generate(**request_kwargs)
    error_message = _image_response_error_message(response)
    if error_message:
        raise ValueError(error_message)

    data = getattr(response, "data", None) or []
    if not data:
        raise ValueError("图片模型未返回结果")

    images = [_image_item_to_result(item, index=idx) for idx, item in enumerate(data)]
    return {"images": images, "usage": _image_usage_dict(response), "errors": []}
# ==================== Pattern-result verification ====================
def verify_pattern_result(
    garment_b64: str,
    canvas_b64: str,
    extra_prompt: str = None,
    vision_model: str = None,
    runtime_config: Optional[RuntimeAiConfig] = None,
) -> str:
    """Ask the vision model to compare the original garment with the result.

    The garment photo is sent first and the pattern result second. A
    non-empty ``extra_prompt`` is appended to the fixed comparison prompt.
    """
    base_prompt = (
        "请对比这两张图片:第一张是原始成衣照片,第二张是套图结果。"
        "请从花样还原度、裁片覆盖完整度、对齐质量、整体效果四个方面评分,并给出改进建议。"
    )
    user_suffix = f"\n用户补充:{extra_prompt}" if extra_prompt else ""

    return call_vision_messages(
        base_prompt + user_suffix,
        [garment_b64, canvas_b64],
        model_override=vision_model or settings.AI_VISION_MODEL,
        runtime_config=runtime_config,
    )
# ==================== Mock replies ====================
def mock_reply(message: str, has_image: bool = False) -> str:
    """Return a canned reply used when no API key is configured.

    An image takes priority; otherwise a message mentioning "套图" gets the
    pattern-specific reply and anything else gets the generic greeting.
    """
    if has_image:
        reply = "【模拟分析】收到图片。请先在 API 配置页填写你自己的转发 Key。"
    elif "套图" in message:
        reply = "套图功能已就绪,但正式 AI 分析前请先在 API 配置页填写你自己的转发 Key。"
    else:
        reply = "你好!我是 DesignerCEP AI 助手。请先在 API 配置页填写你的转发地址和 Key然后我再帮你执行聊天、看图和套图任务。"
    return reply
# ==================== Debug helpers ====================
def _save_debug_image(data: bytes, filename: str):
    """Write ``data`` to the ``debug_images`` directory, best effort.

    The file is named ``<unix timestamp>_<filename>`` and the directory is
    located three levels above this module's directory. Any failure is logged
    as a warning and never raised.
    """
    import time

    debug_dir = os.path.join(
        os.path.dirname(__file__), "..", "..", "..", "debug_images"
    )
    path = os.path.join(debug_dir, f"{int(time.time())}_{filename}")
    try:
        # Directory creation is inside the guard too: a debugging aid must
        # not break the caller when the directory cannot be created.
        os.makedirs(debug_dir, exist_ok=True)
        with open(path, "wb") as file_obj:
            file_obj.write(data)
    except Exception as exc:
        # Deliberately broad: saving debug output is strictly optional.
        log.warning(f"[Debug] 保存失败: {exc}")
    else:
        log.info(f"[Debug] 图片已保存: {path}")