Files
DP/Server/app/api/v1/ai_pattern.py

477 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""
AI 图案生成接口
功能:生成预览图、面料图、裁切、细化
"""
from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel
from typing import Optional
import json, base64, logging
from app.core.config import settings
from app.core.security import get_current_user
from app.db import get_db
from sqlalchemy.orm import Session
from app.models.user import User
from app.api.v1.ai_llm import (
call_image_model,
call_image_model_batch,
get_runtime_ai_config,
has_ai_config,
)
log = logging.getLogger(__name__)
if not log.handlers:
log.setLevel(logging.INFO)
log.propagate = False
_h = logging.StreamHandler()
_h.setFormatter(
logging.Formatter("[%(asctime)s][%(name)s][%(levelname)s] %(message)s")
)
log.addHandler(_h)
router = APIRouter()
# ==================== 数据模型 ====================
class GeneratePatternRequest(BaseModel):
image_base64: str
canvas_base64: Optional[str] = None
prompt: Optional[str] = None
image_edit_model: Optional[str] = None
class GenerateFabricRequest(BaseModel):
image_base64: str
fabric_info: dict
image_edit_model: Optional[str] = None
class GenerateDesignImagesRequest(BaseModel):
prompt: str
reference_image_base64: Optional[str] = None
count: int = 1
size: Optional[str] = "2K"
image_edit_model: Optional[str] = None
class CropPieceRequest(BaseModel):
preview_base64: str
canvas_width: int
canvas_height: int
piece_left: float
piece_top: float
piece_width: float
piece_height: float
piece_name: str
padding: float = 0.15
class RefinePieceRequest(BaseModel):
cropped_base64: str
piece_name: str
pattern_type: str = "fill_pattern"
aspect_ratio: Optional[str] = None
piece_width: Optional[int] = None
piece_height: Optional[int] = None
prompt: Optional[str] = None
image_edit_model: Optional[str] = None
class ExtractPieceRequest(BaseModel):
preview_base64: str
piece_name: str
piece_description: str = ""
prompt: Optional[str] = None
# ==================== 图案生成接口 ====================
@router.post("/ai/generate-preview")
async def generate_preview(
data: GeneratePatternRequest,
current_username: str = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""阶段1成衣照片 + 裁片截图 → AI 生成带轮廓的花样预览图"""
log.info(f"{'=' * 60}")
log.info(
f"[Preview] 收到预览请求, 成衣图: {len(data.image_base64) // 1024}KB, 裁片图: {len(data.canvas_base64) // 1024 if data.canvas_base64 else 0}KB"
)
if data.prompt:
log.info(f"[Preview] 自定义提示词: {data.prompt[:150]}")
user = db.query(User).filter(User.username == current_username).first()
runtime_ai_config = get_runtime_ai_config(user)
effective_image_model = (
data.image_edit_model
or (user.ai_image_model if user else None)
or settings.AI_IMAGE_EDIT_MODEL
)
if not has_ai_config(runtime_ai_config) or not effective_image_model:
raise HTTPException(400, "请先在 API 配置页填写转发 Key或联系管理员配置默认 AI_API_KEY")
try:
images = [data.image_base64]
if data.canvas_base64:
images.append(data.canvas_base64)
default_prompt = (
(
"Image 1 是一件成衣照片Image 2 是裁片轮廓图。"
"请将 Image 1 成衣上的面料花样提取出来,填充到 Image 2 的每个裁片轮廓内。"
"要求:保持花样颜色、比例、方向一致;保留裁片轮廓线;不要改变裁片排列位置。"
)
if data.canvas_base64
else ("请从这件成衣中提取面料花样,生成一张干净的花样平铺图。")
)
result_url, desc = call_image_model(
images_b64=images,
prompt=data.prompt or default_prompt,
model_override=effective_image_model,
runtime_config=runtime_ai_config,
)
return {"code": 200, "data": {"image_url": result_url, "description": desc}}
except Exception as e:
log.error(f"预览生成失败: {e}", exc_info=True)
err_msg = str(e)
if "data_inspection_failed" in err_msg or "inappropriate" in err_msg:
raise HTTPException(
400,
"图片内容未通过平台安全审核,请更换图片后重试(避免使用含版权角色/敏感内容的图片)",
)
raise HTTPException(500, f"预览生成失败: {err_msg}")
@router.post("/ai/generate-fabric")
async def generate_fabric(
data: GenerateFabricRequest,
current_username: str = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""all_over 模式:根据成衣照片 + fabric_info 生成面料平铺图"""
log.info(f"{'=' * 60}")
log.info(f"[Fabric] 生成面料图, 成衣: {len(data.image_base64) // 1024}KB")
log.info(
f"[Fabric] fabric_info: {json.dumps(data.fabric_info, ensure_ascii=False)}"
)
user = db.query(User).filter(User.username == current_username).first()
runtime_ai_config = get_runtime_ai_config(user)
effective_image_model = (
data.image_edit_model
or (user.ai_image_model if user else None)
or settings.AI_IMAGE_EDIT_MODEL
)
if not has_ai_config(runtime_ai_config) or not effective_image_model:
raise HTTPException(400, "请先在 API 配置页填写转发 Key或联系管理员配置默认 AI_API_KEY")
try:
fi = data.fabric_info
prompt = (
f"Please generate a high-resolution flat fabric pattern image based on this garment photo.\n"
f"\nRequirements:\n"
f"1. Output a RECTANGULAR flat fabric image (not garment shape, just the fabric laid flat)\n"
f"2. Fabric type: {fi.get('fabric_type', 'digital print')}\n"
f"3. Pattern elements: {fi.get('pattern_elements', 'floral')}\n"
f"4. Pattern direction: {fi.get('pattern_direction', 'uniform')}\n"
f"5. Color transition: {fi.get('color_transition', 'none')}\n"
f"6. Pattern density: {fi.get('density', 'medium')}\n"
f"7. Remove all garment wrinkles/shadows - show flat fabric only\n"
f"8. The pattern must be clear, colors accurate, suitable for garment piece overlay\n"
f"9. Aspect ratio: 3:4\n"
f"10. The image should look like a piece of fabric photographed from directly above on a flat surface"
)
log.info(f"[Fabric] 提示词: {prompt[:200]}")
result_url, desc = call_image_model(
images_b64=[data.image_base64],
prompt=prompt,
model_override=effective_image_model,
runtime_config=runtime_ai_config,
)
log.info(f"[Fabric] 面料图生成完成")
return {"code": 200, "data": {"fabric_url": result_url, "description": desc}}
except Exception as e:
log.error(f"面料图生成失败: {e}", exc_info=True)
raise HTTPException(500, f"面料图生成失败: {str(e)}")
@router.post("/ai/generate-design-images")
async def generate_design_images(
data: GenerateDesignImagesRequest,
current_username: str = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""通用豆包出图:支持纯文生图和基于参考图的连续组图。"""
log.info(f"{'=' * 60}")
log.info(
f"[DesignImage] 通用出图, 参考图={'' if data.reference_image_base64 else ''}, count={data.count}, size={data.size or '2K'}"
)
user = db.query(User).filter(User.username == current_username).first()
runtime_ai_config = get_runtime_ai_config(user)
effective_image_model = (
data.image_edit_model
or (user.ai_image_model if user else None)
or settings.AI_IMAGE_EDIT_MODEL
)
if not has_ai_config(runtime_ai_config) or not effective_image_model:
raise HTTPException(400, "请先在 API 配置页填写转发 Key或联系管理员配置默认 AI_API_KEY")
prompt = (data.prompt or "").strip()
if not prompt:
raise HTTPException(400, "prompt 不能为空")
try:
count = max(1, min(int(data.count or 1), 4))
images_b64 = [data.reference_image_base64] if data.reference_image_base64 else []
result = call_image_model_batch(
images_b64=images_b64,
prompt=prompt,
model_override=effective_image_model,
size=data.size or "2K",
max_images=count,
stream=(count > 1),
runtime_config=runtime_ai_config,
)
images = result.get("images", [])
return {
"code": 200,
"data": {
"images": images,
"image_urls": [item["url"] for item in images],
"count": len(images),
"used_reference_image": bool(data.reference_image_base64),
"usage": result.get("usage"),
},
}
except Exception as e:
log.error(f"通用出图失败: {e}", exc_info=True)
err_msg = str(e)
if "data_inspection_failed" in err_msg or "inappropriate" in err_msg:
raise HTTPException(400, "图片内容未通过平台安全审核,请更换提示词或参考图")
raise HTTPException(500, f"通用出图失败: {err_msg}")
@router.post("/ai/crop-piece")
async def crop_piece(
data: CropPieceRequest, current_username: str = Depends(get_current_user)
):
"""从预览图中按坐标裁切指定裁片区域Pillow像素级精准"""
log.info(f"{'=' * 60}")
log.info(f"[CropPiece] 裁切: {data.piece_name}")
log.info(
f"[CropPiece] 画布: {data.canvas_width}x{data.canvas_height}, 裁片区域: ({data.piece_left},{data.piece_top}) {data.piece_width}x{data.piece_height}"
)
try:
result_b64 = _crop_piece_from_preview(data)
return {
"code": 200,
"data": {"cropped_base64": result_b64, "piece_name": data.piece_name},
}
except Exception as e:
log.error(f"裁切失败: {e}", exc_info=True)
raise HTTPException(500, f"裁切失败: {str(e)}")
@router.post("/ai/refine-piece")
async def refine_piece(
data: RefinePieceRequest,
current_username: str = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""AI 细化裁切图:根据图案类型生成不同提示词"""
log.info(f"{'=' * 60}")
log.info(
f"[RefinePiece] 细化: {data.piece_name}, 类型: {data.pattern_type}, 比例: {data.aspect_ratio}"
)
log.info(
f"[RefinePiece] 裁片尺寸: {data.piece_width}x{data.piece_height}, 图片: {len(data.cropped_base64) // 1024}KB"
)
user = db.query(User).filter(User.username == current_username).first()
runtime_ai_config = get_runtime_ai_config(user)
effective_image_model = (
data.image_edit_model
or (user.ai_image_model if user else None)
or settings.AI_IMAGE_EDIT_MODEL
)
if not has_ai_config(runtime_ai_config) or not effective_image_model:
raise HTTPException(400, "请先在 API 配置页填写转发 Key或联系管理员配置默认 AI_API_KEY")
try:
size_hint = ""
if data.aspect_ratio:
size_hint = f"输出图片宽高比必须为 {data.aspect_ratio}"
if data.piece_width and data.piece_height:
size_hint += f"输出尺寸至少 {int(data.piece_width * 1.15)}x{int(data.piece_height * 1.15)} 像素。"
if data.pattern_type == "fill_pattern":
prompt = data.prompt or (
f"这是从服装裁片中裁切出来的花型图案区域。"
f"请将它处理为一张干净的矩形花型图:"
f"去掉所有轮廓线、标签文字和边缘空白;"
f"花型图案要完整铺满整个矩形,无空白边距;"
f"向四周自然延展重复纹样,保持花型连续无接缝;"
f"保持花样的颜色和细节清晰。{size_hint}"
)
elif data.pattern_type == "theme_pattern":
prompt = data.prompt or (
f"这是从服装裁片中裁切出来的主题图案区域如卡通人物、Logo、印花"
f"请提取其中的主题图案,输出为纯白色背景(#FFFFFF)上的主题图案:"
f"主题图案要完整清晰,保持原始颜色和细节;"
f"背景必须是纯白色(#FFFFFF),不要任何其他底色或纹理;"
f"去掉轮廓线和标签文字;主题图案在画面中居中放置。{size_hint}"
)
elif data.pattern_type == "mixed_fill":
prompt = data.prompt or (
f"这是从服装裁片中裁切出来的区域,包含主题图案和花型底纹。"
f"请只提取底部的花型/纹理图案,去掉上面的主题图案(人物/Logo"
f"用底纹自然填充主题图案原来的位置;"
f"花型要完整铺满整个矩形,保持纹样连续;"
f"去掉轮廓线和标签文字。{size_hint}"
)
elif data.pattern_type == "mixed_theme":
prompt = data.prompt or (
f"这是从服装裁片中裁切出来的区域,包含主题图案和花型底纹。"
f"请只提取主题图案(人物/Logo/大面积印花),输出为纯白色背景(#FFFFFF)"
f"主题图案要完整清晰,保持原始颜色和细节;"
f"背景必须是纯白色(#FFFFFF),去掉所有底纹;"
f"去掉轮廓线和标签文字;主题图案在画面中居中放置。{size_hint}"
)
else:
prompt = (
data.prompt
or f"细化这张图片,去掉轮廓线和标签文字,保持清晰。{size_hint}"
)
log.info(f"[RefinePiece] 提示词: {prompt[:200]}")
result_url, desc = call_image_model(
images_b64=[data.cropped_base64],
prompt=prompt,
model_override=effective_image_model,
runtime_config=runtime_ai_config,
)
return {
"code": 200,
"data": {
"refined_url": result_url,
"piece_name": data.piece_name,
"pattern_type": data.pattern_type,
},
}
except Exception as e:
log.error(f"细化失败: {e}", exc_info=True)
err_msg = str(e)
if "data_inspection_failed" in err_msg:
raise HTTPException(400, "图片内容未通过安全审核")
raise HTTPException(500, f"细化失败: {err_msg}")
@router.post("/ai/extract-piece-pattern")
async def extract_piece_pattern(
data: ExtractPieceRequest,
current_username: str = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""阶段2从预览图中提取指定裁片区域 → 输出矩形花样图"""
log.info(f"{'=' * 60}")
log.info(
f"[ExtractPiece] 裁片: {data.piece_name}, 描述: '{data.piece_description}', 预览图: {len(data.preview_base64) // 1024}KB"
)
user = db.query(User).filter(User.username == current_username).first()
runtime_ai_config = get_runtime_ai_config(user)
effective_image_model = (
getattr(data, "image_edit_model", None)
or (user.ai_image_model if user else None)
or settings.AI_IMAGE_EDIT_MODEL
)
if not has_ai_config(runtime_ai_config) or not effective_image_model:
raise HTTPException(400, "请先在 API 配置页填写转发 Key或联系管理员配置默认 AI_API_KEY")
try:
prompt = data.prompt or (
f"这是一张服装裁片花样预览图,包含多个裁片。"
f"请提取其中【{data.piece_name}】的花样区域"
f"{'' + data.piece_description + '' if data.piece_description else ''}"
f"将该区域的花样输出为一张完整的矩形图片。"
f"要求:只保留花样内容,去掉轮廓线,填满整个矩形,保持清晰。"
)
result_url, desc = call_image_model(
images_b64=[data.preview_base64],
prompt=prompt,
model_override=effective_image_model,
runtime_config=runtime_ai_config,
)
return {
"code": 200,
"data": {
"pattern_url": result_url,
"piece_name": data.piece_name,
"description": desc,
},
}
except Exception as e:
log.error(f"裁片提取失败: {e}", exc_info=True)
raise HTTPException(500, f"裁片提取失败: {str(e)}")
# ==================== 工具函数 ====================
def _crop_piece_from_preview(data: CropPieceRequest) -> str:
"""用 Pillow 从预览图中按 PS 坐标裁切指定区域(带出血线)"""
from PIL import Image
import io
img_bytes = base64.b64decode(data.preview_base64)
preview_img = Image.open(io.BytesIO(img_bytes))
pw, ph = preview_img.size
scale_x = pw / data.canvas_width
scale_y = ph / data.canvas_height
bleed_x = data.piece_width * data.padding
bleed_y = data.piece_height * data.padding
left = max(0, (data.piece_left - bleed_x) * scale_x)
top = max(0, (data.piece_top - bleed_y) * scale_y)
right = min(pw, (data.piece_left + data.piece_width + bleed_x) * scale_x)
bottom = min(ph, (data.piece_top + data.piece_height + bleed_y) * scale_y)
log.info(f"[CropPiece] PS 画布: {data.canvas_width}x{data.canvas_height}")
log.info(f"[CropPiece] 预览图: {pw}x{ph}")
log.info(f"[CropPiece] 缩放比: x={scale_x:.4f}, y={scale_y:.4f}")
log.info(
f"[CropPiece] PS 裁片区域: ({data.piece_left},{data.piece_top}) {data.piece_width}x{data.piece_height}"
)
log.info(
f"[CropPiece] 出血线: {data.padding * 100:.0f}% → x±{bleed_x:.0f}px, y±{bleed_y:.0f}px"
)
log.info(
f"[CropPiece] 预览图裁切: ({left:.0f},{top:.0f})-({right:.0f},{bottom:.0f}) = {right - left:.0f}x{bottom - top:.0f}px"
)
cropped = preview_img.crop((int(left), int(top), int(right), int(bottom)))
import os, time
debug_dir = os.path.join(
os.path.dirname(__file__), "..", "..", "..", "debug_images"
)
os.makedirs(debug_dir, exist_ok=True)
ts = int(time.time())
debug_path = os.path.join(debug_dir, f"{ts}_crop_{data.piece_name}.png")
cropped.save(debug_path)
log.info(f"[CropPiece] 裁切结果已保存: {debug_path}")
buf = io.BytesIO()
cropped.save(buf, format="PNG")
return base64.b64encode(buf.getvalue()).decode()