feat: AI套图分层方案 + Gemini集成 - 4种图案类型处理 + 正片叠底 + 宽高比 + 模型选择
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
261
PltService/main.py
Normal file
261
PltService/main.py
Normal file
@@ -0,0 +1,261 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
PLT 裁片处理微服务
|
||||
独立部署到阿里云 SAE
|
||||
"""
|
||||
|
||||
import os
|
||||
import io
|
||||
import base64
|
||||
import json
|
||||
from typing import List, Optional
|
||||
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
import numpy as np
|
||||
import cv2
|
||||
from PIL import Image
|
||||
from shapely import affinity
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
|
||||
from pltreader import PltReader
|
||||
|
||||
app = FastAPI(title="PLT Processing Service")
|
||||
|
||||
# CORS 配置
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# ==================== 数据模型 ====================
|
||||
|
||||
class PieceInfo(BaseModel):
|
||||
"""单个裁片信息"""
|
||||
size: str
|
||||
image_base64: str
|
||||
width_px: int
|
||||
height_px: int
|
||||
width_cm: float
|
||||
height_cm: float
|
||||
center_x_cm: float
|
||||
center_y_cm: float
|
||||
left_cm: float
|
||||
top_cm: float
|
||||
|
||||
class GroupInfo(BaseModel):
|
||||
"""裁片分组信息"""
|
||||
group_id: int
|
||||
pieces: List[PieceInfo]
|
||||
|
||||
class ProcessPltResponse(BaseModel):
|
||||
"""API 响应"""
|
||||
success: bool
|
||||
total_groups: int
|
||||
groups: List[GroupInfo]
|
||||
|
||||
# ==================== 辅助函数 ====================
|
||||
|
||||
def parse_plt_file(file_content: str, tolerance: int = 10):
|
||||
"""解析 PLT 文件内容"""
|
||||
reader = PltReader(io.StringIO(file_content))
|
||||
output = reader.get_output(tolerance=tolerance)
|
||||
return output
|
||||
|
||||
def apply_rotation_to_image(pil_img: Image.Image, rotation: int) -> Image.Image:
|
||||
"""应用旋转变换"""
|
||||
if rotation == 90:
|
||||
return pil_img.rotate(-90, expand=True)
|
||||
elif rotation == -90:
|
||||
return pil_img.rotate(90, expand=True)
|
||||
elif rotation == 180:
|
||||
return pil_img.rotate(180, expand=True)
|
||||
return pil_img
|
||||
|
||||
def calculate_piece_coordinates(polygon, bounds, plt_to_cm, rotation=0):
|
||||
"""计算裁片坐标(厘米单位)"""
|
||||
min_x, min_y, max_x, max_y = bounds
|
||||
|
||||
centroid = polygon.centroid
|
||||
orig_center_x = (centroid.x - min_x) * plt_to_cm
|
||||
orig_center_y = (centroid.y - min_y) * plt_to_cm
|
||||
|
||||
piece_bounds = polygon.bounds
|
||||
orig_piece_width = (piece_bounds[2] - piece_bounds[0]) * plt_to_cm
|
||||
orig_piece_height = (piece_bounds[3] - piece_bounds[1]) * plt_to_cm
|
||||
|
||||
orig_canvas_width = (max_x - min_x) * plt_to_cm
|
||||
orig_canvas_height = (max_y - min_y) * plt_to_cm
|
||||
|
||||
if rotation == 90:
|
||||
center_x = orig_center_y
|
||||
center_y = orig_canvas_width - orig_center_x
|
||||
piece_width, piece_height = orig_piece_height, orig_piece_width
|
||||
elif rotation == -90:
|
||||
center_x = orig_canvas_height - orig_center_y
|
||||
center_y = orig_center_x
|
||||
piece_width, piece_height = orig_piece_height, orig_piece_width
|
||||
elif rotation == 180:
|
||||
center_x = orig_canvas_width - orig_center_x
|
||||
center_y = orig_canvas_height - orig_center_y
|
||||
piece_width, piece_height = orig_piece_width, orig_piece_height
|
||||
else:
|
||||
center_x, center_y = orig_center_x, orig_center_y
|
||||
piece_width, piece_height = orig_piece_width, orig_piece_height
|
||||
|
||||
left_cm = center_x - piece_width / 2
|
||||
top_cm = center_y - piece_height / 2
|
||||
|
||||
return {
|
||||
"center_x_cm": round(center_x, 2),
|
||||
"center_y_cm": round(center_y, 2),
|
||||
"left_cm": round(left_cm, 2),
|
||||
"top_cm": round(top_cm, 2),
|
||||
"width_cm": round(piece_width, 2),
|
||||
"height_cm": round(piece_height, 2)
|
||||
}
|
||||
|
||||
# ==================== API 接口 ====================
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""健康检查"""
|
||||
return {"status": "healthy", "service": "plt-processor"}
|
||||
|
||||
@app.post("/process", response_model=ProcessPltResponse)
|
||||
async def process_plt(
|
||||
file: UploadFile = File(..., description="PLT 文件"),
|
||||
size_labels: str = Form(..., description="尺码标签 JSON 数组"),
|
||||
dpi: int = Form(150, description="输出图片分辨率(DPI)"),
|
||||
rotation: int = Form(0, description="强制旋转角度 (0/90/-90/180)")
|
||||
):
|
||||
"""
|
||||
PLT 裁片处理接口
|
||||
"""
|
||||
try:
|
||||
# 1. 解析参数
|
||||
try:
|
||||
size_labels_list = json.loads(size_labels)
|
||||
except:
|
||||
raise HTTPException(status_code=400, detail="size_labels 格式错误")
|
||||
|
||||
size_num = len(size_labels_list)
|
||||
scale_factor = dpi / 1016
|
||||
plt_to_cm = 2.54 / 1016
|
||||
|
||||
# 2. 读取文件
|
||||
file_content = (await file.read()).decode('utf-8', errors='ignore')
|
||||
|
||||
# 3. 解析 PLT
|
||||
print(f"[PLT] 正在处理文件: {file.filename}")
|
||||
output = parse_plt_file(file_content)
|
||||
|
||||
# 4. 获取尺码聚类
|
||||
clusters = output.get_single_size_info(size_num=size_num)
|
||||
|
||||
# 5. 建立匹配关系
|
||||
base_cluster = [piece for piece in clusters[0] if piece["parent"] is None]
|
||||
match_result = {piece["index"]: [piece["index"]] for piece in base_cluster}
|
||||
|
||||
for size_index in range(1, len(clusters)):
|
||||
compare_cluster = [piece for piece in clusters[size_index] if piece["parent"] is None]
|
||||
|
||||
cost_matrix = np.zeros((len(base_cluster), len(compare_cluster)))
|
||||
|
||||
for i in range(len(base_cluster)):
|
||||
for j in range(len(compare_cluster)):
|
||||
poly1 = base_cluster[i]["data"]
|
||||
poly2 = compare_cluster[j]["data"]
|
||||
|
||||
poly1 = affinity.translate(poly1, xoff=-poly1.centroid.x, yoff=-poly1.centroid.y)
|
||||
poly2 = affinity.translate(poly2, xoff=-poly2.centroid.x, yoff=-poly2.centroid.y)
|
||||
|
||||
target_area = 10000
|
||||
scale1 = (target_area / poly1.area) ** 0.5
|
||||
scale2 = (target_area / poly2.area) ** 0.5
|
||||
poly1 = affinity.scale(poly1, xfact=scale1, yfact=scale1, origin=(0, 0))
|
||||
poly2 = affinity.scale(poly2, xfact=scale2, yfact=scale2, origin=(0, 0))
|
||||
|
||||
dist_min = float('inf')
|
||||
for angle in [0, 90, 180, 270]:
|
||||
rotated = affinity.rotate(poly2, angle, origin='centroid')
|
||||
dist = poly1.hausdorff_distance(rotated)
|
||||
if dist < dist_min:
|
||||
dist_min = dist
|
||||
|
||||
cost_matrix[i, j] = dist_min
|
||||
|
||||
row_ind, col_ind = linear_sum_assignment(cost_matrix)
|
||||
for i, j in zip(row_ind, col_ind):
|
||||
base_index = base_cluster[i]['index']
|
||||
matched_index = compare_cluster[j]['index']
|
||||
match_result[base_index].append(matched_index)
|
||||
|
||||
# 6. 生成图片
|
||||
all_nodes = {node["index"]: node for node in output.nodes}
|
||||
|
||||
from shapely.geometry import MultiPolygon as ShapelyMultiPolygon
|
||||
all_polygons = [node["data"] for node in output.nodes]
|
||||
global_bounds = ShapelyMultiPolygon(all_polygons).bounds
|
||||
|
||||
groups = []
|
||||
|
||||
for group_id, (base_index, matched_indices) in enumerate(match_result.items(), start=1):
|
||||
pieces = []
|
||||
|
||||
for size_idx, piece_index in enumerate(matched_indices):
|
||||
size_label = size_labels_list[size_idx]
|
||||
|
||||
node = all_nodes[piece_index]
|
||||
nodes_to_draw = [node] + node['child']
|
||||
|
||||
img = output._draw_nodes(nodes_to_draw, scale_factor, show_id=False)
|
||||
|
||||
img_rgba = cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA)
|
||||
pil_img = Image.fromarray(img_rgba)
|
||||
pil_img = apply_rotation_to_image(pil_img, rotation)
|
||||
|
||||
buffer = io.BytesIO()
|
||||
pil_img.save(buffer, format='PNG', dpi=(dpi, dpi))
|
||||
buffer.seek(0)
|
||||
base64_str = base64.b64encode(buffer.read()).decode('utf-8')
|
||||
image_base64 = f"data:image/png;base64,{base64_str}"
|
||||
|
||||
coords = calculate_piece_coordinates(
|
||||
node["data"],
|
||||
global_bounds,
|
||||
plt_to_cm,
|
||||
rotation
|
||||
)
|
||||
|
||||
piece_info = PieceInfo(
|
||||
size=size_label,
|
||||
image_base64=image_base64,
|
||||
width_px=pil_img.width,
|
||||
height_px=pil_img.height,
|
||||
**coords
|
||||
)
|
||||
pieces.append(piece_info)
|
||||
|
||||
groups.append(GroupInfo(group_id=group_id, pieces=pieces))
|
||||
|
||||
print(f"[PLT] 处理完成,共 {len(groups)} 组裁片")
|
||||
return ProcessPltResponse(
|
||||
success=True,
|
||||
total_groups=len(groups),
|
||||
groups=groups
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"[PLT] 处理失败: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
port = int(os.getenv("PORT", 8080))
|
||||
uvicorn.run(app, host="0.0.0.0", port=port)
|
||||
Reference in New Issue
Block a user