# -*- coding: utf-8 -*-
"""
AI model invocation layer.

Centralizes every AI model call made by the backend. The default provider is
currently Ark; configuration names stay provider-agnostic, and skill-based
execution strategies are supported.
"""
from typing import List, Optional
from dataclasses import dataclass
import json
import logging
import os

from app.core.config import settings
from app.api.v1.ai_tools import PS_TOOLS, TOOL_DISPLAY_NAMES
from app.api.v1.ai_skills import AiSkill, get_ai_skill, resolve_ai_skill

log = logging.getLogger(__name__)


@dataclass
class RuntimeAiConfig:
    """Provider settings resolved for one request (see get_runtime_ai_config)."""

    provider: str
    api_key: str
    base_url: str
    # Per-capability overrides; an empty string means "use base_url".
    chat_base_url: str = ""
    vision_base_url: str = ""
    image_base_url: str = ""
    # "user" when the key came from the user object, otherwise "server".
    source: str = "server"


# ==================== Prompts ====================

SYSTEM_PROMPT = """你是 DesignerCEP 的 AI 助手,运行在 Adobe Photoshop CEP 插件中。

你的能力:
1. 回答关于 Photoshop 操作和插件使用的问题
2. 通过工具直接操作 Photoshop(创建图层、对齐、查看文档信息等)
3. 帮用户排查操作中遇到的错误
4. AI 智能套图:两阶段流程,先生成预览确认,再提取套到裁片上

## AI 智能套图流程
当用户上传成衣图片并要求套图时:
1. 先调用 identify_pieces 识别每个图层是什么裁片部位
2. 告诉用户识别结果
3. 调用 generate_garment_preview 生成带标签的裁片预览图
4. 等用户确认 OK 后,再调用 extract_and_apply_all_pieces 正式套图

重要:
- 预览生成后必须等用户确认,不能自己直接进入正式套图
- 每个工具最多调用 1 次,除非上次执行失败
- 用户要求执行 Photoshop 操作时,优先使用工具完成
- 用户如果是在问“不会怎么做 / 为什么不行 / 下一步怎么操作”,要优先结合当前 Photoshop 上下文答疑,不要只给脱离现场的教程
- 答疑、查询当前信息、执行操作可以在同一轮里完成,不要把它们人为拆成两种模式
- 如果用户的问题和当前文档、图层、选区、参考图有关,优先先查 1 个最相关的上下文工具,再给结论
- 如果用户意图已经明确且风险低,可以边解释边执行;如果涉及删除、覆盖、关闭、合并等风险操作,先确认
- 不要只讲原理不动手,也不要只闷头执行不解释;默认输出“判断 + 当前检查/执行结果 + 下一步”
- 用简洁中文回答,适合在小面板中阅读
- 复杂任务先给 2-4 条简短思路,再开始执行
- 默认按 观察 -> 判断 -> 执行 -> 验证 的顺序做事
- 如果任务包含多步骤,不要一上来同时调很多工具,先完成当前阶段再进入下一阶段
- 回答里尽量显式说明“当前阶段”和“下一步”
"""

VISION_PROMPT = """你是一位资深的服装设计分析师,同时也是 DesignerCEP 的 AI 助手。

请结合图片和用户问题,从服装类别、面料、颜色印花、版型特点、工艺细节、设计评价、改进建议等维度做专业分析。

规则:
- 如果图片不是服装相关,也尽力分析图片内容
- 用清晰、结构化的中文回答
- 回答要专业但不啰嗦,突出重点
"""

# Function-type tool schemas indexed by function name, for per-skill filtering.
TOOL_BY_NAME = {
    tool["function"]["name"]: tool
    for tool in PS_TOOLS
    if tool.get("type") == "function" and tool.get("function", {}).get("name")
}


# ==================== Client factory ====================

def get_runtime_ai_config(user=None) -> RuntimeAiConfig:
    """Resolve the AI configuration for a request.

    When ``user`` carries a non-empty ``ai_api_key`` the user's own settings
    win (``source="user"``), with server settings filling any blanks.
    Otherwise the server-level settings / environment variables are used
    (``source="server"``).
    """

    def _user_value(attr: str) -> str:
        # Missing attributes and None both collapse to "".
        return (getattr(user, attr, "") or "").strip()

    user_provider = _user_value("ai_provider").lower()
    user_api_key = _user_value("ai_api_key")
    user_base_url = _user_value("ai_base_url")
    user_chat_base_url = _user_value("ai_chat_base_url")
    user_vision_base_url = _user_value("ai_vision_base_url")
    user_image_base_url = _user_value("ai_image_base_url")

    default_base_url = (
        settings.AI_BASE_URL
        or settings.ARK_BASE_URL
        or "https://ark.cn-beijing.volces.com/api/v3"
    )

    if user_api_key:
        return RuntimeAiConfig(
            provider=user_provider or (settings.AI_PROVIDER or "ark").strip().lower(),
            api_key=user_api_key,
            base_url=user_base_url or default_base_url,
            chat_base_url=user_chat_base_url or settings.AI_CHAT_BASE_URL or "",
            vision_base_url=user_vision_base_url or settings.AI_VISION_BASE_URL or "",
            image_base_url=user_image_base_url or settings.AI_IMAGE_BASE_URL or "",
            source="user",
        )

    return RuntimeAiConfig(
        provider=(settings.AI_PROVIDER or "ark").strip().lower(),
        api_key=(
            settings.AI_API_KEY
            or settings.ARK_API_KEY
            or os.getenv("AI_API_KEY", "")
            or os.getenv("ARK_API_KEY", "")
        ),
        base_url=default_base_url,
        chat_base_url=settings.AI_CHAT_BASE_URL or "",
        vision_base_url=settings.AI_VISION_BASE_URL or "",
        image_base_url=settings.AI_IMAGE_BASE_URL or "",
        source="server",
    )


def get_ai_provider(runtime_config: Optional[RuntimeAiConfig] = None) -> str:
    """Return the lower-cased provider name, defaulting to ``"ark"``.

    An empty or blank provider (on the runtime config or in settings) also
    falls back to ``"ark"`` instead of yielding an empty provider name.
    """
    provider = runtime_config.provider if runtime_config else settings.AI_PROVIDER
    return (provider or "").strip().lower() or "ark"


def get_ai_api_key(runtime_config: Optional[RuntimeAiConfig] = None) -> str:
    """Return the stripped API key, or ``""`` when none is configured."""
    config = runtime_config or get_runtime_ai_config()
    return (config.api_key or "").strip()


def get_ai_base_url(
    runtime_config: Optional[RuntimeAiConfig] = None, capability: str = "chat"
) -> str:
    """Return the base URL for ``capability`` ("chat", "vision" or "image").

    The capability-specific URL is used when set; otherwise the generic
    ``base_url``. Any other capability value is treated as "chat".
    """
    config = runtime_config or get_runtime_ai_config()
    if capability == "vision":
        specific = config.vision_base_url
    elif capability == "image":
        specific = config.image_base_url
    else:
        specific = config.chat_base_url
    return (specific or config.base_url).strip()


def has_ai_config(runtime_config: Optional[RuntimeAiConfig] = None) -> bool:
    """Return True when a non-empty API key is available."""
    return bool(get_ai_api_key(runtime_config))


def get_ai_client(
    runtime_config: Optional[RuntimeAiConfig] = None, capability: str = "chat"
):
    """Build a client for the current provider.

    Raises:
        ValueError: if the provider is not "ark" or no API key is configured.
    """
    # Validate first so configuration problems are reported as ValueError
    # even when the Ark SDK is not installed.
    provider = get_ai_provider(runtime_config)
    if provider != "ark":
        raise ValueError(f"当前暂不支持的 AI_PROVIDER: {provider}")
    api_key = get_ai_api_key(runtime_config)
    if not api_key:
        raise ValueError("AI_API_KEY 未配置")

    from volcenginesdkarkruntime import Ark

    return Ark(
        base_url=get_ai_base_url(runtime_config, capability=capability),
        api_key=api_key,
    )


def is_gemini_model(model_name: str) -> bool:
    """Legacy compatibility shim; always returns False."""
    return False


def _tool_schemas_for_skill(skill: AiSkill) -> List[dict]:
    """Return the known tool schemas listed in ``skill.allowed_tools``."""
    if not skill.allowed_tools:
        return []
    return [TOOL_BY_NAME[name] for name in skill.allowed_tools if name in TOOL_BY_NAME]


def _numbered_block(items: tuple[str, ...]) -> str:
    """Render ``items`` as a 1-based numbered list, or "无" when empty."""
    if not items:
        return "无"
    return "\n".join(f"{number}. {item}" for number, item in enumerate(items, start=1))


def _bullet_block(items: tuple[str, ...]) -> str:
    """Render ``items`` as a "- " bullet list, or "无" when empty."""
    if not items:
        return "无"
    return "\n".join(f"- {item}" for item in items)


def _build_skill_prompt(skill: AiSkill, has_image: bool = False) -> str:
    """Build the skill-specific section appended to the base prompts."""
    workflow = _numbered_block(skill.workflow)
    guardrails = _bullet_block(skill.guardrails)
    success_criteria = _bullet_block(skill.success_criteria)
    triage_questions = _numbered_block(skill.triage_questions)
    deliverables = _bullet_block(skill.deliverables)
    execution_notes = _bullet_block(skill.execution_notes)
    tool_names = (
        "、".join(TOOL_DISPLAY_NAMES.get(name, name) for name in skill.allowed_tools)
        if skill.allowed_tools
        else "当前技能以分析和建议为主,不主动调用 Photoshop 工具。"
    )
    origin = skill.origin or "项目内置技能"
    origin_url = skill.origin_url or "无"
    upstream_skill = skill.upstream_skill or "无"
    image_note = (
        "本轮带有图片,请结合图片内容和技能策略做判断。"
        if has_image
        else "本轮没有新图片,优先结合聊天历史和 Photoshop 上下文。"
    )
    return f"""当前启用技能:{skill.name}({skill.id})
技能定位:{skill.description}
执行模式:{skill.mode}
规划风格:{skill.planning_style}
图片提示:{skill.image_hint or '无'}
来源:{origin}
上游技能:{upstream_skill}
来源地址:{origin_url}

推荐工作流:
{workflow}

预判问题:
{triage_questions}

期望交付:
{deliverables}

执行边界:
{guardrails}

执行备注:
{execution_notes}

成功标准:
{success_criteria}

本技能优先工具:
{tool_names}

本轮补充要求:
- 回复保持中文、简洁、能落地。
- 涉及具体执行时,先根据当前技能给出简短思路,再调用工具。
- 根据用户问题动态混合答疑、查询与执行;如果用户是在求助“怎么做/为什么失败/下一步怎么弄”,不要把答疑和执行拆成两轮。
- 能读取当前 Photoshop 上下文时,优先读当前信息后再回答;不要只给泛化教程。
- 优先先想清楚计划,再动手执行;如果任务复杂,先向用户同步 2 到 4 条阶段计划。
- 每轮优先调用最相关的 1 个工具,避免一口气乱调很多工具。
- 如果图片信息不足、文档上下文不足或风险较高,要先说明再操作。
- 每完成一个阶段,要用一句中文总结已完成内容,并指向下一步。
- {image_note}
"""


def _chat_system_prompt(skill: AiSkill, has_image: bool = False) -> str:
    """Return the chat system prompt: base prompt plus skill section."""
    return f"{SYSTEM_PROMPT}\n\n{_build_skill_prompt(skill, has_image)}"


def _vision_instructions(skill: AiSkill) -> str:
    """Return the vision instructions: vision prompt plus skill section."""
    return f"{VISION_PROMPT}\n\n{_build_skill_prompt(skill, has_image=True)}"


def _messages_with_system(
    messages_history: List[dict], skill: AiSkill, has_image: bool = False
) -> List[dict]:
    """Return a new list with the system message followed by the history."""
    system_message = {"role": "system", "content": _chat_system_prompt(skill, has_image)}
    return [system_message] + messages_history


def _extract_response_text(response) -> str:
    """Join every non-empty ``output_text`` item of a response with newlines."""
    texts: List[str] = []
    for output_item in getattr(response, "output", []) or []:
        for content_item in getattr(output_item, "content", []) or []:
            if getattr(content_item, "type", "") != "output_text":
                continue
            text_value = getattr(content_item, "text", "")
            if text_value:
                texts.append(text_value)
    return "\n".join(texts).strip()


def _image_data_url(image_base64: str) -> str:
    """Wrap raw base64 image data in a JPEG data URL.

    Values that already are a data URL or an http(s) URL are returned
    unchanged so they are not prefixed a second time. Raw base64 can never
    start with these prefixes because ":" is not in the base64 alphabet.
    """
    if image_base64.startswith(("data:", "http://", "https://")):
        return image_base64
    return f"data:image/jpeg;base64,{image_base64}"


# ==================== Chat / tool calling ====================

def _tool_call_to_dict(tool_call) -> dict:
    """Convert one model tool call into the pending-call dict sent to clients."""
    function = getattr(tool_call, "function", None)
    function_name = (getattr(function, "name", "") or "") if function else ""
    raw_args = getattr(function, "arguments", "") if function else ""

    args = {}
    if raw_args:
        try:
            args = json.loads(raw_args)
        except (TypeError, ValueError):
            # Malformed arguments are treated as "no arguments".
            args = {}
    if not isinstance(args, dict):
        # Valid JSON that is not an object (e.g. a list) is not usable as kwargs.
        args = {}

    return {
        # Fall back to the function name when the id is missing or None.
        "id": getattr(tool_call, "id", None) or function_name,
        "name": function_name,
        "display_name": TOOL_DISPLAY_NAMES.get(function_name, function_name),
        "args": args,
        "status": "pending",
    }


def call_llm_with_tools(
    messages_history: List[dict],
    model_override: Optional[str] = None,
    skill_id: Optional[str] = None,
    has_image: bool = False,
    runtime_config: Optional[RuntimeAiConfig] = None,
):
    """Call the configured chat model with function calling enabled.

    Returns:
        A ``(reply_text, tool_calls, skill_dict)`` tuple. ``tool_calls`` is a
        list of pending tool-call dicts, or ``None`` when the model requested
        no tools. ``skill_dict`` is ``skill.to_public_dict()``.
    """
    use_model = model_override or settings.AI_MODEL
    client = get_ai_client(runtime_config, capability="chat")
    skill = get_ai_skill(skill_id) or resolve_ai_skill(
        message=str(messages_history[-1].get("content", "")) if messages_history else "",
        history_messages=messages_history[:-1],
        requested_skill_id=skill_id,
        has_image=has_image,
    )
    messages = _messages_with_system(messages_history, skill, has_image)
    tools = _tool_schemas_for_skill(skill)

    log.info("=" * 60)
    log.info(
        "[LLM] provider=%s model=%s source=%s",
        get_ai_provider(runtime_config),
        use_model,
        runtime_config.source if runtime_config else "server",
    )
    log.info("[LLM] skill=%s", skill.id)
    log.info("[LLM] 消息数量: %s", len(messages))
    log.info("[LLM] 工具数量: %s", len(tools))

    request_kwargs = {
        "model": use_model,
        "messages": messages,
        "temperature": 0.3,
    }
    if tools:
        request_kwargs.update(
            {
                "tools": tools,
                "tool_choice": "auto",
                "parallel_tool_calls": False,
            }
        )

    completion = client.chat.completions.create(**request_kwargs)
    message = completion.choices[0].message
    reply_text = getattr(message, "content", "") or ""
    tool_calls = getattr(message, "tool_calls", None) or []

    if tool_calls:
        tool_calls_data = [_tool_call_to_dict(tc) for tc in tool_calls]
        return reply_text, tool_calls_data, skill.to_public_dict()
    return reply_text, None, skill.to_public_dict()


def call_gemini(
    messages_history: List[dict], model: str, images_b64: Optional[List[str]] = None
):
    """Legacy entry point; routes to the current vision or chat model.

    With images, the last non-empty user message is used as the vision prompt.
    Without images, the reply text of ``call_llm_with_tools`` is returned.
    """
    prompt = ""
    for msg in messages_history:
        if msg.get("role") == "user" and msg.get("content"):
            prompt = str(msg["content"])  # keep overwriting: last one wins
    if images_b64:
        return call_vision_messages(prompt or "请分析这张图片。", images_b64, model)
    reply, _, _ = call_llm_with_tools(messages_history, model_override=model)
    return reply


def call_gemini_with_tools(messages_history: List[dict], model: str) -> tuple:
    """Legacy entry point; delegates to ``call_llm_with_tools``."""
    return call_llm_with_tools(messages_history, model_override=model)


# ==================== Vision model ====================

def call_vision_messages(
    user_message: str,
    images_b64: List[str],
    model_override: Optional[str] = None,
    instructions_override: Optional[str] = None,
    runtime_config: Optional[RuntimeAiConfig] = None,
) -> str:
    """Send images plus a text prompt to the vision model and return its text."""
    use_model = model_override or settings.AI_VISION_MODEL
    client = get_ai_client(runtime_config, capability="vision")

    # Images first, then the text prompt, all within a single user turn.
    content = [
        {"type": "input_image", "image_url": _image_data_url(image_b64)}
        for image_b64 in images_b64 or []
    ]
    content.append({"type": "input_text", "text": user_message})

    response = client.responses.create(
        model=use_model,
        instructions=instructions_override or VISION_PROMPT,
        input=[{"role": "user", "content": content}],
    )
    result = _extract_response_text(response)
    log.info(
        "[Vision] provider=%s model=%s 输出长度: %s",
        get_ai_provider(runtime_config),
        use_model,
        len(result),
    )
    return result


def call_vision_llm(
    user_message: str,
    image_base64: str,
    history: List[dict],
    model_override: Optional[str] = None,
    skill_id: Optional[str] = None,
    runtime_config: Optional[RuntimeAiConfig] = None,
) -> str:
    """Legacy single-image vision entry point.

    Resolves the skill, folds the content of the last five history items into
    the prompt, and forwards to ``call_vision_messages``.
    """
    skill = get_ai_skill(skill_id) or resolve_ai_skill(
        message=user_message,
        history_messages=history,
        requested_skill_id=skill_id,
        has_image=True,
    )

    combined_message = user_message
    if history:
        recent_text = "\n".join(
            str(item.get("content", "")) for item in history[-5:] if item.get("content")
        )
        if recent_text:
            combined_message = f"历史对话:\n{recent_text}\n\n当前问题:{user_message}"

    return call_vision_messages(
        combined_message,
        [image_base64],
        model_override,
        instructions_override=_vision_instructions(skill),
        runtime_config=runtime_config,
    )


# ==================== Image editing / generation ====================

def call_image_model(
    images_b64: List[str],
    prompt: str,
    model_override: Optional[str] = None,
    runtime_config: Optional[RuntimeAiConfig] = None,
) -> tuple:
    """Generate a single image; returns ``(url_or_datauri, description)``.

    The description element is always an empty string.
    """
    result = call_image_model_batch(
        images_b64=images_b64,
        prompt=prompt,
        model_override=model_override,
        size="2K",
        max_images=1,
        stream=False,
        runtime_config=runtime_config,
    )
    first_image = result["images"][0]
    return str(first_image["url"]), ""


def _image_response_error_message(payload) -> str:
    """Return ``payload.error.message`` stripped, or ``""`` when absent."""
    error = getattr(payload, "error", None)
    return str(getattr(error, "message", "") or "").strip()


def _image_usage_dict(payload) -> Optional[dict]:
    """Return ``payload.usage`` as a dict, or None when it is unavailable."""
    usage = getattr(payload, "usage", None)
    if usage is None:
        return None
    if hasattr(usage, "model_dump"):
        return usage.model_dump(mode="json")
    if isinstance(usage, dict):
        return usage
    return None


def _image_item_to_result(image_item, index: int) -> dict:
    """Convert an image item into ``{"index", "url", "size"}``.

    A URL is preferred; otherwise base64 data is re-padded and wrapped in a
    PNG data URL.

    Raises:
        ValueError: if the item carries neither a URL nor base64 data.
    """
    image_url = getattr(image_item, "url", "") or ""
    image_b64 = getattr(image_item, "b64_json", "") or ""
    image_size = getattr(image_item, "size", "") or ""

    if image_url:
        return {"index": index, "url": image_url, "size": image_size}
    if image_b64:
        # Restore stripped "=" padding so the length is a multiple of 4.
        image_b64 += "=" * ((-len(image_b64)) % 4)
        return {
            "index": index,
            "url": f"data:image/png;base64,{image_b64}",
            "size": image_size,
        }
    raise ValueError("图片模型返回了空图片结果")


def _collect_streamed_images(stream_response) -> dict:
    """Drain a streaming image response into ``{"images", "usage", "errors"}``.

    Raises:
        ValueError: if the stream produced no image at all.
    """
    images: List[dict] = []
    usage: Optional[dict] = None
    errors: List[str] = []

    for event in stream_response:
        if event is None:
            continue
        event_type = str(getattr(event, "type", "") or "")
        if event_type.endswith("partial_failed"):
            event_error = _image_response_error_message(event)
            if event_error:
                errors.append(event_error)
        elif event_type.endswith("partial_succeeded"):
            # A missing or None image_index falls back to arrival order
            # instead of discarding an otherwise valid image.
            image_index = getattr(event, "image_index", None)
            if image_index is None:
                image_index = len(images)
            try:
                images.append(_image_item_to_result(event, index=int(image_index)))
            except (TypeError, ValueError) as exc:
                errors.append(str(exc))
        elif event_type.endswith("completed"):
            usage = _image_usage_dict(event)

    if errors and not images:
        raise ValueError(";".join(errors))
    if not images:
        raise ValueError("图片模型未返回结果")
    images.sort(key=lambda item: int(item.get("index", 0)))
    return {"images": images, "usage": usage, "errors": errors}


def call_image_model_batch(
    images_b64: List[str],
    prompt: str,
    model_override: Optional[str] = None,
    size: str = "2K",
    max_images: int = 1,
    stream: Optional[bool] = None,
    runtime_config: Optional[RuntimeAiConfig] = None,
) -> dict:
    """Call the image model: single image, reference-guided, or a sequence.

    Args:
        images_b64: Reference images; may be empty for pure text-to-image.
        prompt: Generation prompt.
        model_override: Model name; defaults to settings.AI_IMAGE_EDIT_MODEL.
        size: Requested output size; an empty value falls back to "2K".
        max_images: Number of images wanted, clamped to the range 1..4.
        stream: Force streaming on or off; None streams only for multi-image.
        runtime_config: Optional per-request provider configuration.

    Returns:
        ``{"images": [...], "usage": dict | None, "errors": [...]}``.

    Raises:
        ValueError: if the model reports an error or returns no image.
    """
    from volcenginesdkarkruntime.types.images.images import (
        SequentialImageGenerationOptions,
    )

    use_model = model_override or settings.AI_IMAGE_EDIT_MODEL
    client = get_ai_client(runtime_config, capability="image")
    image_count = max(1, min(int(max_images or 1), 4))
    should_stream = (image_count > 1) if stream is None else bool(stream)

    image_inputs = [_image_data_url(img_b64) for img_b64 in images_b64 or []]

    log.info("=" * 60)
    log.info(
        "[ImageModel] provider=%s model=%s source=%s",
        get_ai_provider(runtime_config),
        use_model,
        runtime_config.source if runtime_config else "server",
    )
    log.info("[ImageModel] 图片数: %s", len(image_inputs))
    log.info("[ImageModel] 目标生成张数: %s", image_count)
    log.info("[ImageModel] 提示词: %s", prompt[:200])

    # One request dict serves both the streaming and the blocking path.
    request_kwargs = {
        "model": use_model,
        "prompt": prompt,
        "response_format": "url",
        "size": size or "2K",
        "watermark": True,
        "stream": should_stream,
        "sequential_image_generation": "auto" if image_count > 1 else "disabled",
    }
    if image_inputs:
        # A single reference is sent as a bare string, several as a list.
        request_kwargs["image"] = (
            image_inputs[0] if len(image_inputs) == 1 else image_inputs
        )
    if image_count > 1:
        request_kwargs["sequential_image_generation_options"] = (
            SequentialImageGenerationOptions(max_images=image_count)
        )

    response = client.images.generate(**request_kwargs)
    if should_stream:
        return _collect_streamed_images(response)

    error_message = _image_response_error_message(response)
    if error_message:
        raise ValueError(error_message)
    data = getattr(response, "data", None) or []
    if not data:
        raise ValueError("图片模型未返回结果")
    images = [_image_item_to_result(item, index=idx) for idx, item in enumerate(data)]
    return {"images": images, "usage": _image_usage_dict(response), "errors": []}


# ==================== Pattern result verification ====================

def verify_pattern_result(
    garment_b64: str,
    canvas_b64: str,
    extra_prompt: Optional[str] = None,
    vision_model: Optional[str] = None,
    runtime_config: Optional[RuntimeAiConfig] = None,
) -> str:
    """Ask the vision model to compare the original garment with the result.

    The garment image is sent first and the canvas image second, matching the
    order the prompt refers to.
    """
    verify_prompt = (
        "请对比这两张图片:第一张是原始成衣照片,第二张是套图结果。"
        "请从花样还原度、裁片覆盖完整度、对齐质量、整体效果四个方面评分,并给出改进建议。"
    )
    if extra_prompt:
        verify_prompt += f"\n用户补充:{extra_prompt}"
    return call_vision_messages(
        verify_prompt,
        [garment_b64, canvas_b64],
        model_override=vision_model or settings.AI_VISION_MODEL,
        runtime_config=runtime_config,
    )


# ==================== Mock ====================

def mock_reply(message: str, has_image: bool = False) -> str:
    """Return a canned reply used when no API key is configured."""
    if has_image:
        return "【模拟分析】收到图片。请先在 API 配置页填写你自己的转发 Key。"
    if "套图" in (message or ""):
        return "套图功能已就绪,但正式 AI 分析前请先在 API 配置页填写你自己的转发 Key。"
    return "你好!我是 DesignerCEP AI 助手。请先在 API 配置页填写你的转发地址和 Key,然后我再帮你执行聊天、看图和套图任务。"


# ==================== Debug helpers ====================

def _save_debug_image(data: bytes, filename: str):
    """Best-effort: write ``data`` to ``debug_images/<timestamp>_<filename>``.

    The directory is resolved three levels above this file's directory.
    Failures are logged as warnings and never raised to the caller.
    """
    import time

    debug_dir = os.path.join(
        os.path.dirname(__file__), "..", "..", "..", "debug_images"
    )
    # Keep only the final path component so a caller-supplied name cannot
    # escape the debug directory.
    safe_name = os.path.basename(str(filename))
    path = os.path.join(debug_dir, f"{int(time.time())}_{safe_name}")
    try:
        os.makedirs(debug_dir, exist_ok=True)
        with open(path, "wb") as file_obj:
            file_obj.write(data)
        log.info("[Debug] 图片已保存: %s", path)
    except Exception as exc:  # deliberate: debugging aid must never break callers
        log.warning("[Debug] 保存失败: %s", exc)