#!/usr/bin/env python3
"""
Prompt-iteration test script.

Calls the vision model directly to evaluate how well the identify_pieces
prompt performs on a fixed pair of debug images.
"""

import base64
import json
import os
import sys

# Environment setup: put this script's directory first on sys.path
# (presumably so the local ``app`` package resolves from here), then load
# .env so its variables are in os.environ before app.core.config is imported.
# abspath() keeps the entry valid regardless of how the script was invoked.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from dotenv import load_dotenv  # noqa: E402

load_dotenv()

from openai import OpenAI  # noqa: E402

from app.core.config import settings  # noqa: E402


# ==================== Configuration ====================

# Test inputs; relative paths are resolved against the current working directory.
GARMENT_IMG = "debug_images/1770525898_input_0.jpg"  # photo of the finished garment
CANVAS_IMG = "debug_images/1770525898_input_1.jpg"  # outline drawing of the pattern pieces

# PSD layer listing fed into the prompt, one entry per line:
# "<name> (<kind>, <W>x<H>px, 位于路径: <path>)", where <kind> is "group" for layer groups.
LAYERS_INFO = "\n".join(
    [
        "组 1 (group, 4120x5280px, 位于路径: /组 1)",
        "图层 1 (LayerKind.NORMAL, 4120x5280px, 位于路径: /组 1/图层 1)",
        "组 2 (group, 3999x4561px, 位于路径: /组 2)",
        "图层 2 (LayerKind.NORMAL, 3999x4561px, 位于路径: /组 2/图层 2)",
        "组 3 (group, 3529x333px, 位于路径: /组 3)",
        "图层 3 (LayerKind.NORMAL, 3529x333px, 位于路径: /组 3/图层 3)",
        "组 4 (group, 2423x2004px, 位于路径: /组 4)",
        "图层 4 (LayerKind.NORMAL, 2423x2004px, 位于路径: /组 4/图层 4)",
        "组 5 (group, 2422x2003px, 位于路径: /组 5)",
        "图层 1 拷贝 (LayerKind.NORMAL, 2422x2003px, 位于路径: /组 5/图层 1 拷贝)",
    ]
)


# ==================== Prompt versions ====================


def make_prompt_v2(layers_info: str) -> str:
    """Build the v2 identify_pieces prompt (adds pattern_mode and invisible-piece inference).

    Args:
        layers_info: PSD layer listing, one entry per line, formatted as
            ``<name> (<kind>, <W>x<H>px, 位于路径: <path>)`` where ``<kind>`` is
            ``group`` for layer groups and ``LayerKind.<X>`` for plain layers.

    Returns:
        The full prompt text. The JSON example embedded in it is keyed by real
        names taken from ``layers_info`` so the model copies the right format.
    """
    import re

    # Capture the entry name and its kind. The greedy name group lets names
    # that themselves contain "(" still split at the kind marker.
    entry_re = re.compile(r"^(.*\S)\s*\((group|LayerKind\.\w+)\s*,")

    group_names = []
    all_names = []
    for line in layers_info.strip().splitlines():
        line = line.strip()
        match = entry_re.match(line)
        if match:
            name, kind = match.group(1), match.group(2)
        else:
            # Unrecognised format: fall back to "everything before the first '('".
            name, kind = line.split("(")[0].strip(), ""
        if not name:
            continue
        all_names.append(name)
        if kind == "group":
            group_names.append(name)

    # The prompt requires mapping keys to be layer-GROUP names, so the
    # examples must be group names too; use plain layer names only when the
    # listing contains no groups at all.
    names = group_names or all_names
    ex1 = names[0] if len(names) > 0 else "图层名1"
    ex2 = names[1] if len(names) > 1 else "图层名2"
    ex3 = names[2] if len(names) > 2 else "图层名3"

    return f"""你需要完成三个任务:

**任务1:识别裁片部位**
Image 2 是 PS 文档中的裁片轮廓图。以下是图层信息:
{layers_info}

请根据形状、大小识别每个**图层组**是什么服装部位(前片、后片、袖子、领口等)。
⚠️ mapping 的 key 必须是上面图层信息中的**实际图层组名称**(如 "{ex1}"、"{ex2}" 等),不要自己编名字!

**任务2:判断花型覆盖模式 pattern_mode**
Image 1 是成衣照片。请判断这件衣服的花型属于哪种模式:

- "all_over":所有裁片来自同一块面料(碎花/格子/渐变等整体花型覆盖全身)
- "placement":各裁片独立处理(有的印花有的纯色,如卡通卫衣)
- "color_block":全部纯色拼接

如果是 all_over,还需要提供 fabric_info:
- fabric_type: 面料印花类型(数码印花/提花/丝网印/...)
- pattern_direction: 花型方向(均匀铺满/上下渐变/左右渐变/斜向)
- pattern_elements: 花型元素描述
- color_transition: 色彩过渡描述
- density: 花型密度描述

**任务3:分析每个裁片的图案类型**
对每个裁片分析:
- type: solid / fill_pattern / theme_pattern / mixed_pattern / all_over
- solid 和 theme_pattern 必须给 color(hex值)
- 如果某个裁片在照片中**看不到**(如后片),设置 visible: false,并根据前片推理:
  - 全身花型 → 后片与前片相同
  - 局部印花 → 后片通常为底色纯色
  - 给出推理理由 inference_reason 和置信度 confidence(high/medium/low)
- 如果左右对称(如左右袖),标记 mirror_of 字段

用 JSON 回答(注意:key 用实际图层名):
```json
{{
  "pattern_mode": "all_over",
  "fabric_info": {{
    "fabric_type": "数码印花",
    "pattern_direction": "上下渐变",
    "pattern_elements": "白色百合花朵 + 灰白格纹底",
    "color_transition": "从黑色渐变到灰白色",
    "density": "上部花朵密集大朵,下部稀疏小朵"
  }},
  "mapping": {{
    "{ex1}": "左袖",
    "{ex2}": "前片"
  }},
  "analysis": [
    {{"layer": "{ex1}", "piece": "左袖", "type": "all_over", "visible": true, "description": "黑底白花渐变"}},
    {{"layer": "{ex2}", "piece": "前片", "type": "all_over", "visible": true, "description": "上半黑底大花+下半灰白格纹"}},
    {{"layer": "{ex3}", "piece": "后片", "type": "all_over", "visible": false, "inferred_from": "{ex2}", "inference_reason": "全身花型面料,后片花型与前片相同", "confidence": "high"}}
  ],
  "symmetric_pairs": [["{ex1}", "其他对称片"]],
  "base_color": "#1A1A1A"
}}
```

⚠️ 重要:
- mapping 和 analysis 中的 layer 必须用**实际图层组名称**
- 看不到的部位设 visible: false 并说明推理依据
- 左右对称的片标记 symmetric_pairs

只返回 JSON,不要其他文字。"""


# ==================== Model call ====================


def call_vision(prompt: str, garment_b64: str, canvas_b64: str, model: str = None):
    """Send the two images and the prompt to the vision model; return its reply text.

    Args:
        prompt: Text prompt, placed after both images in the user message.
        garment_b64: Base64-encoded garment photo, sent first (as an image/jpeg data URL).
        canvas_b64: Base64-encoded pattern-outline image, sent second (image/jpeg data URL).
        model: Model name; when falsy, ``settings.AI_VISION_MODEL`` is used.

    Returns:
        The reply content, or ``""`` when the model returned no content.
    """
    chosen_model = model if model else settings.AI_VISION_MODEL
    api = OpenAI(
        api_key=settings.AI_API_KEY,
        base_url=settings.AI_BASE_URL or "https://api.openai.com/v1",
    )

    def as_image_part(b64: str) -> dict:
        # Images are sent inline as base64 data URLs.
        return {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64}"}}

    message_parts = [
        as_image_part(garment_b64),
        as_image_part(canvas_b64),
        {"type": "text", "text": prompt},
    ]

    rule = "=" * 60
    print(f"\n{rule}")
    print(f"模型: {chosen_model}")
    print(f"提示词长度: {len(prompt)} 字符")
    print(f"{rule}\n")

    response = api.chat.completions.create(
        model=chosen_model,
        messages=[{"role": "user", "content": message_parts}],
    )
    reply = response.choices[0].message.content
    return reply if reply else ""


def parse_result(raw: str) -> dict:
    """Parse the model's reply into a JSON object.

    Accepts bare JSON, JSON inside a Markdown code fence (with or without a
    ``json`` language tag, in any letter case), and JSON surrounded by extra
    prose.

    Args:
        raw: Reply text from the model (may be empty).

    Returns:
        The parsed object as a dict, or ``{}`` when no JSON object could be
        parsed. In that case a diagnostic and the first 500 characters of the
        reply are printed.
    """
    text = raw or ""
    clean = text.strip()
    if "```" in clean:
        # Keep the content of the first fenced section and drop a "json" tag.
        clean = clean.split("```")[1]
        if clean[:4].lower() == "json":
            clean = clean[4:]
    clean = clean.strip()

    try:
        data = json.loads(clean)
    except json.JSONDecodeError as e:
        # Fallback: the object may be wrapped in prose; retry on the span
        # from the first "{" to the last "}".
        data = None
        start, end = clean.find("{"), clean.rfind("}")
        if 0 <= start < end:
            try:
                data = json.loads(clean[start:end + 1])
            except json.JSONDecodeError:
                pass  # report the original error below
        if data is None:
            print(f"⚠️ JSON 解析失败: {e}")
            print(f"原始回复:\n{text[:500]}")
            return {}

    # Callers use dict methods on the result, so reject lists/scalars/null.
    if not isinstance(data, dict):
        print(f"⚠️ JSON 顶层不是对象: {type(data).__name__}")
        print(f"原始回复:\n{text[:500]}")
        return {}
    return data


def evaluate_result(result: dict) -> list:
    """Check a parsed model result for missing or malformed fields.

    The result comes straight from the model, so every field's type is
    checked before use instead of assuming it follows the requested schema.

    Args:
        result: Object returned by ``parse_result`` (``{}`` on parse failure).

    Returns:
        A list of human-readable issue strings. When nothing is wrong, the
        list holds a single "all fields present" entry.
    """
    issues = []

    if not result:
        issues.append("❌ 解析失败,无结果")
        return issues
    if not isinstance(result, dict):
        issues.append("❌ 结果不是 JSON 对象")
        return issues

    # pattern_mode must be one of the three modes the prompt defines.
    mode = result.get("pattern_mode")
    if not mode:
        issues.append("❌ 缺少 pattern_mode")
    elif mode not in ("all_over", "placement", "color_block"):
        issues.append(f"⚠️ pattern_mode 值异常: {mode}")

    # mapping keys must be real layer names, not invented ids like "M-1"/"Layer 1".
    mapping = result.get("mapping") or {}
    if not isinstance(mapping, dict):
        issues.append("❌ mapping 格式异常(应为对象)")
        mapping = {}
    for key in mapping:
        if key.startswith(("M-", "Layer")):
            issues.append(f"❌ mapping key '{key}' 不是真实图层名(应为 '组 X' 格式)")

    analysis = result.get("analysis")
    if not analysis:
        issues.append("❌ 缺少 analysis")
    elif not isinstance(analysis, list):
        issues.append("❌ analysis 格式异常(应为数组)")
    else:
        # Skip non-object entries so one malformed item cannot crash the check.
        entries = [a for a in analysis if isinstance(a, dict)]
        if len(entries) != len(analysis):
            issues.append("⚠️ analysis 中存在非对象条目")

        if not any(a.get("visible") is False for a in entries):
            issues.append("⚠️ 没有标记不可见部位(后片通常看不到)")

        has_mirror = any("mirror_of" in a for a in entries)
        if not has_mirror and not result.get("symmetric_pairs"):
            issues.append("⚠️ 没有标记对称片(左右袖应该对称)")

        for a in entries:
            if a.get("type") in ("solid", "theme_pattern") and not a.get("color"):
                issues.append(f"⚠️ {a.get('layer')} 类型为 {a['type']} 但缺少 color")

    # fabric_info is only required for all_over results.
    if mode == "all_over":
        fi = result.get("fabric_info")
        if not fi:
            issues.append("❌ all_over 模式缺少 fabric_info")
        elif not isinstance(fi, dict):
            issues.append("❌ fabric_info 格式异常(应为对象)")
        else:
            for key in ("fabric_type", "pattern_direction", "pattern_elements", "color_transition", "density"):
                if not fi.get(key):
                    issues.append(f"⚠️ fabric_info 缺少 {key}")

    if not issues:
        issues.append("✅ 完美!所有字段齐全")

    return issues


# ==================== Main flow ====================


def _load_image_b64(path: str) -> str:
    """Return the file at *path* as base64-encoded ASCII text."""
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode()


def _extract_inline_image(text: str):
    """Find the first inline base64 image in a chat reply.

    Tries a Markdown image ``![...](data:image/<fmt>;base64,<data>)`` first,
    then a bare ``data:image/<fmt>;base64,<data>`` URL.

    Returns:
        ``(fmt, image_bytes)`` for the first match, or ``None`` if there is none.
    """
    import re

    match = re.search(r'!\[.*?\]\((data:image/(\w+);base64,([^)]+))\)', text)
    if not match:
        match = re.search(r'(data:image/(\w+);base64,([A-Za-z0-9+/=]+))', text)
    if not match:
        return None

    img_format = match.group(2)
    # Drop any whitespace captured inside the Markdown link, then restore
    # "=" padding that some responses omit; otherwise the padding count is off.
    img_b64 = "".join(match.group(3).split())
    img_b64 += "=" * (-len(img_b64) % 4)
    return img_format, base64.b64decode(img_b64)


def _build_fabric_prompt(fi: dict) -> str:
    """Build the English image-generation prompt from the model's fabric_info."""
    # Empty/null values from the model fall back to the defaults as well.
    return (
        "Please generate a high-resolution flat fabric pattern image based on this garment photo.\n"
        "\n"
        "Requirements:\n"
        "1. Output a RECTANGULAR flat fabric image (not garment shape, just the fabric laid flat)\n"
        f"2. Fabric type: {fi.get('fabric_type') or 'digital print'}\n"
        f"3. Pattern elements: {fi.get('pattern_elements') or 'floral'}\n"
        f"4. Pattern direction: {fi.get('pattern_direction') or 'uniform'}\n"
        f"5. Color transition: {fi.get('color_transition') or 'none'}\n"
        f"6. Pattern density: {fi.get('density') or 'medium'}\n"
        "7. Remove all garment wrinkles/shadows - show flat fabric only\n"
        "8. The pattern must be clear, colors accurate, suitable for garment piece overlay\n"
        "9. Aspect ratio: 3:4\n"
        "10. The image should look like a piece of fabric photographed from directly above on a flat surface"
    )


def _generate_fabric_image(fabric_info: dict, garment_b64: str) -> None:
    """Stage 2: ask Gemini for a flat fabric image and save it under debug_images/."""
    import time

    model_name = "gemini-3-pro-image-preview"
    gen_prompt = _build_fabric_prompt(fabric_info)
    print(f" Prompt: {gen_prompt[:200]}...")
    print(f" Model: {model_name}")

    # Gemini is called through its OpenAI-compatible chat endpoint.
    gemini = OpenAI(
        api_key=settings.GEMINI_API_KEY,
        base_url=f"{str(settings.GEMINI_BASE_URL).rstrip('/')}/v1",
    )
    content_parts = [
        {"type": "text", "text": gen_prompt},
        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{garment_b64}"}},
    ]

    print(" Calling Gemini 3 Pro Image...")
    t0 = time.perf_counter()
    completion = gemini.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": content_parts}],
    )
    elapsed = time.perf_counter() - t0
    resp_content = completion.choices[0].message.content or ""
    print(f" Response length: {len(resp_content)} chars, time: {elapsed:.1f}s")

    extracted = _extract_inline_image(resp_content)
    if extracted is None:
        print(" [FAIL] No image in response")
        if resp_content:
            print(f" Response preview: {resp_content[:300]}")
        return

    img_format, img_bytes = extracted
    os.makedirs("debug_images", exist_ok=True)
    out_path = f"debug_images/test_fabric_allover.{img_format}"
    with open(out_path, "wb") as f:
        f.write(img_bytes)
    # Report the decoded file size, not the base64 text length.
    print(f" [OK] Fabric image saved: {out_path} ({len(img_bytes)//1024}KB)")


def main():
    """Run one prompt-iteration round end to end.

    Loads the two test images, builds the v2 prompt, calls the vision model,
    parses and quality-checks the JSON reply, and, when the reply reports an
    ``all_over`` pattern with ``fabric_info``, tries to generate a flat
    fabric image with Gemini.
    """
    # Force UTF-8 console output: a Windows GBK console cannot encode the
    # emoji in the status messages. reconfigure() adjusts the existing stream
    # (keeping its line buffering) instead of wrapping its buffer in a new one.
    if hasattr(sys.stdout, "reconfigure"):
        sys.stdout.reconfigure(encoding="utf-8", errors="replace")

    print("[1] Loading test images...")
    garment_b64 = _load_image_b64(GARMENT_IMG)
    canvas_b64 = _load_image_b64(CANVAS_IMG)
    print(f" garment: {len(garment_b64)//1024}KB, canvas: {len(canvas_b64)//1024}KB")

    prompt = make_prompt_v2(LAYERS_INFO)

    print("\n[2] Calling vision model...")
    raw = call_vision(prompt, garment_b64, canvas_b64)

    print("\n[3] Model response:")
    print(raw[:2000])
    print("\n" + "=" * 60)

    result = parse_result(raw)
    if result:
        print("\n[4] Parsed result:")
        print(json.dumps(result, ensure_ascii=False, indent=2))

    print("\n[5] Quality check:")
    for issue in evaluate_result(result):
        print(f" {issue}")

    # ==================== Stage 2: Gemini image generation ====================
    fabric_info = result.get("fabric_info") if isinstance(result, dict) else None
    if isinstance(fabric_info, dict) and fabric_info and result.get("pattern_mode") == "all_over":
        print("\n" + "=" * 60)
        print("[6] Testing Gemini image generation (all_over fabric)...")
        _generate_fabric_image(fabric_info, garment_b64)

    print(f"\n{'=' * 60}")
    print("Done.")


if __name__ == "__main__":
    # Only run the prompt test when executed as a script, not when imported.
    main()