feat: expand AI workflow support and refresh docs
This commit is contained in:
476
Server/app/api/v1/ai_pattern.py
Normal file
476
Server/app/api/v1/ai_pattern.py
Normal file
@@ -0,0 +1,476 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
AI 图案生成接口
|
||||
功能:生成预览图、面料图、裁切、细化
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
import json, base64, logging
|
||||
from app.core.config import settings
|
||||
from app.core.security import get_current_user
|
||||
from app.db import get_db
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.user import User
|
||||
from app.api.v1.ai_llm import (
|
||||
call_image_model,
|
||||
call_image_model_batch,
|
||||
get_runtime_ai_config,
|
||||
has_ai_config,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
if not log.handlers:
|
||||
log.setLevel(logging.INFO)
|
||||
log.propagate = False
|
||||
_h = logging.StreamHandler()
|
||||
_h.setFormatter(
|
||||
logging.Formatter("[%(asctime)s][%(name)s][%(levelname)s] %(message)s")
|
||||
)
|
||||
log.addHandler(_h)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# ==================== 数据模型 ====================
|
||||
|
||||
|
||||
class GeneratePatternRequest(BaseModel):
|
||||
image_base64: str
|
||||
canvas_base64: Optional[str] = None
|
||||
prompt: Optional[str] = None
|
||||
image_edit_model: Optional[str] = None
|
||||
|
||||
|
||||
class GenerateFabricRequest(BaseModel):
|
||||
image_base64: str
|
||||
fabric_info: dict
|
||||
image_edit_model: Optional[str] = None
|
||||
|
||||
|
||||
class GenerateDesignImagesRequest(BaseModel):
|
||||
prompt: str
|
||||
reference_image_base64: Optional[str] = None
|
||||
count: int = 1
|
||||
size: Optional[str] = "2K"
|
||||
image_edit_model: Optional[str] = None
|
||||
|
||||
|
||||
class CropPieceRequest(BaseModel):
|
||||
preview_base64: str
|
||||
canvas_width: int
|
||||
canvas_height: int
|
||||
piece_left: float
|
||||
piece_top: float
|
||||
piece_width: float
|
||||
piece_height: float
|
||||
piece_name: str
|
||||
padding: float = 0.15
|
||||
|
||||
|
||||
class RefinePieceRequest(BaseModel):
|
||||
cropped_base64: str
|
||||
piece_name: str
|
||||
pattern_type: str = "fill_pattern"
|
||||
aspect_ratio: Optional[str] = None
|
||||
piece_width: Optional[int] = None
|
||||
piece_height: Optional[int] = None
|
||||
prompt: Optional[str] = None
|
||||
image_edit_model: Optional[str] = None
|
||||
|
||||
|
||||
class ExtractPieceRequest(BaseModel):
|
||||
preview_base64: str
|
||||
piece_name: str
|
||||
piece_description: str = ""
|
||||
prompt: Optional[str] = None
|
||||
|
||||
|
||||
# ==================== 图案生成接口 ====================
|
||||
|
||||
|
||||
@router.post("/ai/generate-preview")
|
||||
async def generate_preview(
|
||||
data: GeneratePatternRequest,
|
||||
current_username: str = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""阶段1:成衣照片 + 裁片截图 → AI 生成带轮廓的花样预览图"""
|
||||
log.info(f"{'=' * 60}")
|
||||
log.info(
|
||||
f"[Preview] 收到预览请求, 成衣图: {len(data.image_base64) // 1024}KB, 裁片图: {len(data.canvas_base64) // 1024 if data.canvas_base64 else 0}KB"
|
||||
)
|
||||
if data.prompt:
|
||||
log.info(f"[Preview] 自定义提示词: {data.prompt[:150]}")
|
||||
user = db.query(User).filter(User.username == current_username).first()
|
||||
runtime_ai_config = get_runtime_ai_config(user)
|
||||
effective_image_model = (
|
||||
data.image_edit_model
|
||||
or (user.ai_image_model if user else None)
|
||||
or settings.AI_IMAGE_EDIT_MODEL
|
||||
)
|
||||
if not has_ai_config(runtime_ai_config) or not effective_image_model:
|
||||
raise HTTPException(400, "请先在 API 配置页填写转发 Key,或联系管理员配置默认 AI_API_KEY")
|
||||
|
||||
try:
|
||||
images = [data.image_base64]
|
||||
if data.canvas_base64:
|
||||
images.append(data.canvas_base64)
|
||||
|
||||
default_prompt = (
|
||||
(
|
||||
"Image 1 是一件成衣照片,Image 2 是裁片轮廓图。"
|
||||
"请将 Image 1 成衣上的面料花样提取出来,填充到 Image 2 的每个裁片轮廓内。"
|
||||
"要求:保持花样颜色、比例、方向一致;保留裁片轮廓线;不要改变裁片排列位置。"
|
||||
)
|
||||
if data.canvas_base64
|
||||
else ("请从这件成衣中提取面料花样,生成一张干净的花样平铺图。")
|
||||
)
|
||||
|
||||
result_url, desc = call_image_model(
|
||||
images_b64=images,
|
||||
prompt=data.prompt or default_prompt,
|
||||
model_override=effective_image_model,
|
||||
runtime_config=runtime_ai_config,
|
||||
)
|
||||
return {"code": 200, "data": {"image_url": result_url, "description": desc}}
|
||||
except Exception as e:
|
||||
log.error(f"预览生成失败: {e}", exc_info=True)
|
||||
err_msg = str(e)
|
||||
if "data_inspection_failed" in err_msg or "inappropriate" in err_msg:
|
||||
raise HTTPException(
|
||||
400,
|
||||
"图片内容未通过平台安全审核,请更换图片后重试(避免使用含版权角色/敏感内容的图片)",
|
||||
)
|
||||
raise HTTPException(500, f"预览生成失败: {err_msg}")
|
||||
|
||||
|
||||
@router.post("/ai/generate-fabric")
|
||||
async def generate_fabric(
|
||||
data: GenerateFabricRequest,
|
||||
current_username: str = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""all_over 模式:根据成衣照片 + fabric_info 生成面料平铺图"""
|
||||
log.info(f"{'=' * 60}")
|
||||
log.info(f"[Fabric] 生成面料图, 成衣: {len(data.image_base64) // 1024}KB")
|
||||
log.info(
|
||||
f"[Fabric] fabric_info: {json.dumps(data.fabric_info, ensure_ascii=False)}"
|
||||
)
|
||||
user = db.query(User).filter(User.username == current_username).first()
|
||||
runtime_ai_config = get_runtime_ai_config(user)
|
||||
effective_image_model = (
|
||||
data.image_edit_model
|
||||
or (user.ai_image_model if user else None)
|
||||
or settings.AI_IMAGE_EDIT_MODEL
|
||||
)
|
||||
if not has_ai_config(runtime_ai_config) or not effective_image_model:
|
||||
raise HTTPException(400, "请先在 API 配置页填写转发 Key,或联系管理员配置默认 AI_API_KEY")
|
||||
|
||||
try:
|
||||
fi = data.fabric_info
|
||||
prompt = (
|
||||
f"Please generate a high-resolution flat fabric pattern image based on this garment photo.\n"
|
||||
f"\nRequirements:\n"
|
||||
f"1. Output a RECTANGULAR flat fabric image (not garment shape, just the fabric laid flat)\n"
|
||||
f"2. Fabric type: {fi.get('fabric_type', 'digital print')}\n"
|
||||
f"3. Pattern elements: {fi.get('pattern_elements', 'floral')}\n"
|
||||
f"4. Pattern direction: {fi.get('pattern_direction', 'uniform')}\n"
|
||||
f"5. Color transition: {fi.get('color_transition', 'none')}\n"
|
||||
f"6. Pattern density: {fi.get('density', 'medium')}\n"
|
||||
f"7. Remove all garment wrinkles/shadows - show flat fabric only\n"
|
||||
f"8. The pattern must be clear, colors accurate, suitable for garment piece overlay\n"
|
||||
f"9. Aspect ratio: 3:4\n"
|
||||
f"10. The image should look like a piece of fabric photographed from directly above on a flat surface"
|
||||
)
|
||||
|
||||
log.info(f"[Fabric] 提示词: {prompt[:200]}")
|
||||
|
||||
result_url, desc = call_image_model(
|
||||
images_b64=[data.image_base64],
|
||||
prompt=prompt,
|
||||
model_override=effective_image_model,
|
||||
runtime_config=runtime_ai_config,
|
||||
)
|
||||
|
||||
log.info(f"[Fabric] 面料图生成完成")
|
||||
return {"code": 200, "data": {"fabric_url": result_url, "description": desc}}
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"面料图生成失败: {e}", exc_info=True)
|
||||
raise HTTPException(500, f"面料图生成失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/ai/generate-design-images")
|
||||
async def generate_design_images(
|
||||
data: GenerateDesignImagesRequest,
|
||||
current_username: str = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""通用豆包出图:支持纯文生图和基于参考图的连续组图。"""
|
||||
log.info(f"{'=' * 60}")
|
||||
log.info(
|
||||
f"[DesignImage] 通用出图, 参考图={'是' if data.reference_image_base64 else '否'}, count={data.count}, size={data.size or '2K'}"
|
||||
)
|
||||
user = db.query(User).filter(User.username == current_username).first()
|
||||
runtime_ai_config = get_runtime_ai_config(user)
|
||||
effective_image_model = (
|
||||
data.image_edit_model
|
||||
or (user.ai_image_model if user else None)
|
||||
or settings.AI_IMAGE_EDIT_MODEL
|
||||
)
|
||||
if not has_ai_config(runtime_ai_config) or not effective_image_model:
|
||||
raise HTTPException(400, "请先在 API 配置页填写转发 Key,或联系管理员配置默认 AI_API_KEY")
|
||||
|
||||
prompt = (data.prompt or "").strip()
|
||||
if not prompt:
|
||||
raise HTTPException(400, "prompt 不能为空")
|
||||
|
||||
try:
|
||||
count = max(1, min(int(data.count or 1), 4))
|
||||
images_b64 = [data.reference_image_base64] if data.reference_image_base64 else []
|
||||
result = call_image_model_batch(
|
||||
images_b64=images_b64,
|
||||
prompt=prompt,
|
||||
model_override=effective_image_model,
|
||||
size=data.size or "2K",
|
||||
max_images=count,
|
||||
stream=(count > 1),
|
||||
runtime_config=runtime_ai_config,
|
||||
)
|
||||
images = result.get("images", [])
|
||||
return {
|
||||
"code": 200,
|
||||
"data": {
|
||||
"images": images,
|
||||
"image_urls": [item["url"] for item in images],
|
||||
"count": len(images),
|
||||
"used_reference_image": bool(data.reference_image_base64),
|
||||
"usage": result.get("usage"),
|
||||
},
|
||||
}
|
||||
except Exception as e:
|
||||
log.error(f"通用出图失败: {e}", exc_info=True)
|
||||
err_msg = str(e)
|
||||
if "data_inspection_failed" in err_msg or "inappropriate" in err_msg:
|
||||
raise HTTPException(400, "图片内容未通过平台安全审核,请更换提示词或参考图")
|
||||
raise HTTPException(500, f"通用出图失败: {err_msg}")
|
||||
|
||||
|
||||
@router.post("/ai/crop-piece")
|
||||
async def crop_piece(
|
||||
data: CropPieceRequest, current_username: str = Depends(get_current_user)
|
||||
):
|
||||
"""从预览图中按坐标裁切指定裁片区域(Pillow,像素级精准)"""
|
||||
log.info(f"{'=' * 60}")
|
||||
log.info(f"[CropPiece] 裁切: {data.piece_name}")
|
||||
log.info(
|
||||
f"[CropPiece] 画布: {data.canvas_width}x{data.canvas_height}, 裁片区域: ({data.piece_left},{data.piece_top}) {data.piece_width}x{data.piece_height}"
|
||||
)
|
||||
|
||||
try:
|
||||
result_b64 = _crop_piece_from_preview(data)
|
||||
return {
|
||||
"code": 200,
|
||||
"data": {"cropped_base64": result_b64, "piece_name": data.piece_name},
|
||||
}
|
||||
except Exception as e:
|
||||
log.error(f"裁切失败: {e}", exc_info=True)
|
||||
raise HTTPException(500, f"裁切失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/ai/refine-piece")
|
||||
async def refine_piece(
|
||||
data: RefinePieceRequest,
|
||||
current_username: str = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""AI 细化裁切图:根据图案类型生成不同提示词"""
|
||||
log.info(f"{'=' * 60}")
|
||||
log.info(
|
||||
f"[RefinePiece] 细化: {data.piece_name}, 类型: {data.pattern_type}, 比例: {data.aspect_ratio}"
|
||||
)
|
||||
log.info(
|
||||
f"[RefinePiece] 裁片尺寸: {data.piece_width}x{data.piece_height}, 图片: {len(data.cropped_base64) // 1024}KB"
|
||||
)
|
||||
|
||||
user = db.query(User).filter(User.username == current_username).first()
|
||||
runtime_ai_config = get_runtime_ai_config(user)
|
||||
effective_image_model = (
|
||||
data.image_edit_model
|
||||
or (user.ai_image_model if user else None)
|
||||
or settings.AI_IMAGE_EDIT_MODEL
|
||||
)
|
||||
if not has_ai_config(runtime_ai_config) or not effective_image_model:
|
||||
raise HTTPException(400, "请先在 API 配置页填写转发 Key,或联系管理员配置默认 AI_API_KEY")
|
||||
|
||||
try:
|
||||
size_hint = ""
|
||||
if data.aspect_ratio:
|
||||
size_hint = f"输出图片宽高比必须为 {data.aspect_ratio}。"
|
||||
if data.piece_width and data.piece_height:
|
||||
size_hint += f"输出尺寸至少 {int(data.piece_width * 1.15)}x{int(data.piece_height * 1.15)} 像素。"
|
||||
|
||||
if data.pattern_type == "fill_pattern":
|
||||
prompt = data.prompt or (
|
||||
f"这是从服装裁片中裁切出来的花型图案区域。"
|
||||
f"请将它处理为一张干净的矩形花型图:"
|
||||
f"去掉所有轮廓线、标签文字和边缘空白;"
|
||||
f"花型图案要完整铺满整个矩形,无空白边距;"
|
||||
f"向四周自然延展重复纹样,保持花型连续无接缝;"
|
||||
f"保持花样的颜色和细节清晰。{size_hint}"
|
||||
)
|
||||
elif data.pattern_type == "theme_pattern":
|
||||
prompt = data.prompt or (
|
||||
f"这是从服装裁片中裁切出来的主题图案区域(如卡通人物、Logo、印花)。"
|
||||
f"请提取其中的主题图案,输出为纯白色背景(#FFFFFF)上的主题图案:"
|
||||
f"主题图案要完整清晰,保持原始颜色和细节;"
|
||||
f"背景必须是纯白色(#FFFFFF),不要任何其他底色或纹理;"
|
||||
f"去掉轮廓线和标签文字;主题图案在画面中居中放置。{size_hint}"
|
||||
)
|
||||
elif data.pattern_type == "mixed_fill":
|
||||
prompt = data.prompt or (
|
||||
f"这是从服装裁片中裁切出来的区域,包含主题图案和花型底纹。"
|
||||
f"请只提取底部的花型/纹理图案,去掉上面的主题图案(人物/Logo):"
|
||||
f"用底纹自然填充主题图案原来的位置;"
|
||||
f"花型要完整铺满整个矩形,保持纹样连续;"
|
||||
f"去掉轮廓线和标签文字。{size_hint}"
|
||||
)
|
||||
elif data.pattern_type == "mixed_theme":
|
||||
prompt = data.prompt or (
|
||||
f"这是从服装裁片中裁切出来的区域,包含主题图案和花型底纹。"
|
||||
f"请只提取主题图案(人物/Logo/大面积印花),输出为纯白色背景(#FFFFFF):"
|
||||
f"主题图案要完整清晰,保持原始颜色和细节;"
|
||||
f"背景必须是纯白色(#FFFFFF),去掉所有底纹;"
|
||||
f"去掉轮廓线和标签文字;主题图案在画面中居中放置。{size_hint}"
|
||||
)
|
||||
else:
|
||||
prompt = (
|
||||
data.prompt
|
||||
or f"细化这张图片,去掉轮廓线和标签文字,保持清晰。{size_hint}"
|
||||
)
|
||||
|
||||
log.info(f"[RefinePiece] 提示词: {prompt[:200]}")
|
||||
|
||||
result_url, desc = call_image_model(
|
||||
images_b64=[data.cropped_base64],
|
||||
prompt=prompt,
|
||||
model_override=effective_image_model,
|
||||
runtime_config=runtime_ai_config,
|
||||
)
|
||||
return {
|
||||
"code": 200,
|
||||
"data": {
|
||||
"refined_url": result_url,
|
||||
"piece_name": data.piece_name,
|
||||
"pattern_type": data.pattern_type,
|
||||
},
|
||||
}
|
||||
except Exception as e:
|
||||
log.error(f"细化失败: {e}", exc_info=True)
|
||||
err_msg = str(e)
|
||||
if "data_inspection_failed" in err_msg:
|
||||
raise HTTPException(400, "图片内容未通过安全审核")
|
||||
raise HTTPException(500, f"细化失败: {err_msg}")
|
||||
|
||||
|
||||
@router.post("/ai/extract-piece-pattern")
|
||||
async def extract_piece_pattern(
|
||||
data: ExtractPieceRequest,
|
||||
current_username: str = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""阶段2:从预览图中提取指定裁片区域 → 输出矩形花样图"""
|
||||
log.info(f"{'=' * 60}")
|
||||
log.info(
|
||||
f"[ExtractPiece] 裁片: {data.piece_name}, 描述: '{data.piece_description}', 预览图: {len(data.preview_base64) // 1024}KB"
|
||||
)
|
||||
user = db.query(User).filter(User.username == current_username).first()
|
||||
runtime_ai_config = get_runtime_ai_config(user)
|
||||
effective_image_model = (
|
||||
getattr(data, "image_edit_model", None)
|
||||
or (user.ai_image_model if user else None)
|
||||
or settings.AI_IMAGE_EDIT_MODEL
|
||||
)
|
||||
if not has_ai_config(runtime_ai_config) or not effective_image_model:
|
||||
raise HTTPException(400, "请先在 API 配置页填写转发 Key,或联系管理员配置默认 AI_API_KEY")
|
||||
|
||||
try:
|
||||
prompt = data.prompt or (
|
||||
f"这是一张服装裁片花样预览图,包含多个裁片。"
|
||||
f"请提取其中【{data.piece_name}】的花样区域"
|
||||
f"{'(' + data.piece_description + ')' if data.piece_description else ''},"
|
||||
f"将该区域的花样输出为一张完整的矩形图片。"
|
||||
f"要求:只保留花样内容,去掉轮廓线,填满整个矩形,保持清晰。"
|
||||
)
|
||||
result_url, desc = call_image_model(
|
||||
images_b64=[data.preview_base64],
|
||||
prompt=prompt,
|
||||
model_override=effective_image_model,
|
||||
runtime_config=runtime_ai_config,
|
||||
)
|
||||
return {
|
||||
"code": 200,
|
||||
"data": {
|
||||
"pattern_url": result_url,
|
||||
"piece_name": data.piece_name,
|
||||
"description": desc,
|
||||
},
|
||||
}
|
||||
except Exception as e:
|
||||
log.error(f"裁片提取失败: {e}", exc_info=True)
|
||||
raise HTTPException(500, f"裁片提取失败: {str(e)}")
|
||||
|
||||
|
||||
# ==================== 工具函数 ====================
|
||||
|
||||
|
||||
def _crop_piece_from_preview(data: CropPieceRequest) -> str:
|
||||
"""用 Pillow 从预览图中按 PS 坐标裁切指定区域(带出血线)"""
|
||||
from PIL import Image
|
||||
import io
|
||||
|
||||
img_bytes = base64.b64decode(data.preview_base64)
|
||||
preview_img = Image.open(io.BytesIO(img_bytes))
|
||||
pw, ph = preview_img.size
|
||||
|
||||
scale_x = pw / data.canvas_width
|
||||
scale_y = ph / data.canvas_height
|
||||
|
||||
bleed_x = data.piece_width * data.padding
|
||||
bleed_y = data.piece_height * data.padding
|
||||
|
||||
left = max(0, (data.piece_left - bleed_x) * scale_x)
|
||||
top = max(0, (data.piece_top - bleed_y) * scale_y)
|
||||
right = min(pw, (data.piece_left + data.piece_width + bleed_x) * scale_x)
|
||||
bottom = min(ph, (data.piece_top + data.piece_height + bleed_y) * scale_y)
|
||||
|
||||
log.info(f"[CropPiece] PS 画布: {data.canvas_width}x{data.canvas_height}")
|
||||
log.info(f"[CropPiece] 预览图: {pw}x{ph}")
|
||||
log.info(f"[CropPiece] 缩放比: x={scale_x:.4f}, y={scale_y:.4f}")
|
||||
log.info(
|
||||
f"[CropPiece] PS 裁片区域: ({data.piece_left},{data.piece_top}) {data.piece_width}x{data.piece_height}"
|
||||
)
|
||||
log.info(
|
||||
f"[CropPiece] 出血线: {data.padding * 100:.0f}% → x±{bleed_x:.0f}px, y±{bleed_y:.0f}px"
|
||||
)
|
||||
log.info(
|
||||
f"[CropPiece] 预览图裁切: ({left:.0f},{top:.0f})-({right:.0f},{bottom:.0f}) = {right - left:.0f}x{bottom - top:.0f}px"
|
||||
)
|
||||
|
||||
cropped = preview_img.crop((int(left), int(top), int(right), int(bottom)))
|
||||
|
||||
import os, time
|
||||
|
||||
debug_dir = os.path.join(
|
||||
os.path.dirname(__file__), "..", "..", "..", "debug_images"
|
||||
)
|
||||
os.makedirs(debug_dir, exist_ok=True)
|
||||
ts = int(time.time())
|
||||
debug_path = os.path.join(debug_dir, f"{ts}_crop_{data.piece_name}.png")
|
||||
cropped.save(debug_path)
|
||||
log.info(f"[CropPiece] 裁切结果已保存: {debug_path}")
|
||||
|
||||
buf = io.BytesIO()
|
||||
cropped.save(buf, format="PNG")
|
||||
return base64.b64encode(buf.getvalue()).decode()
|
||||
Reference in New Issue
Block a user