507 lines
18 KiB
Python
507 lines
18 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
PLT 裁片处理接口 (Photoshop 插件专用)
|
||
功能:解析 PLT 文件,生成裁片图片和坐标信息
|
||
优化版:提升匹配计算速度 + 并行图片生成
|
||
"""
|
||
|
||
import os
|
||
import io
|
||
import base64
|
||
import json
|
||
import numpy as np
|
||
import cv2
|
||
from typing import List, Optional, Dict, Tuple
|
||
from fastapi import APIRouter, UploadFile, File, Form, HTTPException, Depends
|
||
from pydantic import BaseModel
|
||
from sqlalchemy.orm import Session
|
||
from PIL import Image
|
||
from shapely import affinity
|
||
from scipy.optimize import linear_sum_assignment
|
||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
import asyncio
|
||
|
||
# 导入自定义模块
|
||
import sys
|
||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))))
|
||
from pltreader import PltReader
|
||
|
||
from app.db import get_db
|
||
from app.core.security import get_current_user
|
||
from app.core.qiniu_storage import qiniu_storage
|
||
|
||
# FastAPI router that the endpoint(s) of this module are registered on via
# ``@router.post``; presumably included into the application elsewhere.
router = APIRouter()
|
||
|
||
# ==================== 数据模型 ====================
|
||
|
||
class PieceInfo(BaseModel):
    """Rendered image and placement of a single cut piece for one size.

    ``process_plt`` fills ``width_px``/``height_px`` from the rendered PNG and
    the ``*_cm`` fields from ``calculate_piece_coordinates``.
    """

    size: str  # size label taken from the request's ``size_labels``
    image_base64: Optional[str] = None  # PNG data URL; used when cloud storage is disabled or the upload did not succeed
    image_url: Optional[str] = None  # cloud-storage (Qiniu) URL; used when the upload succeeded
    width_px: int  # pixel width of the (possibly rotated) rendered image
    height_px: int  # pixel height of the (possibly rotated) rendered image
    width_cm: float  # piece bounding-box width in cm, after the requested rotation
    height_cm: float  # piece bounding-box height in cm, after the requested rotation
    center_x_cm: float  # centroid x in cm, relative to the bounding box of all pieces (rotated canvas if rotation != 0)
    center_y_cm: float  # centroid y in cm, same reference frame as center_x_cm
    left_cm: float  # center_x_cm minus half of width_cm
    top_cm: float  # center_y_cm minus half of height_cm
|
||
|
||
class GroupInfo(BaseModel):
    """One group of corresponding pieces: the same piece across different sizes."""

    group_id: int  # 1-based; process_plt numbers groups by descending bounding-box area of their first piece
    pieces: List[PieceInfo]  # at most one piece per size, ordered as in ``size_labels``
|
||
|
||
class MatchInfo(BaseModel):
    """Pairing of a piece in the standard PLT file with a piece in the rotated PLT file."""

    standard_id: int  # node "index" in the standard file
    rotated_id: int  # node "index" in the rotated file
    distance: float  # shape distance computed by fast_shape_similarity (lower = more similar)
    angle: float  # best-fit rotation in degrees, as returned by fast_shape_similarity
|
||
|
||
class SizeMatchInfo(BaseModel):
    """Standard-vs-rotated file matches for one size."""

    size: str  # size label from the request's ``size_labels``
    matches: List[MatchInfo]  # one entry per matched pair of top-level pieces
|
||
|
||
class ProcessPltResponse(BaseModel):
    """Response body of the ``/algorithm/process_plt`` endpoint."""

    success: bool  # process_plt always sets True here; failures are reported as HTTP errors instead
    total_groups: int  # number of entries in ``groups``
    groups: List[GroupInfo]
    match_analysis: Optional[List[SizeMatchInfo]] = None  # only populated when a rotated PLT file is uploaded
|
||
|
||
# ==================== 优化的辅助函数 ====================
|
||
|
||
def parse_plt_file(file_content: str, tolerance: int = 10):
    """Parse PLT text with ``PltReader`` and return its output object.

    Args:
        file_content: the whole PLT file as text.
        tolerance: forwarded unchanged to ``PltReader.get_output``.

    Returns:
        Whatever ``PltReader.get_output`` returns. Callers in this module use
        its ``nodes``, ``get_single_size_info`` and ``_draw_nodes`` members.
    """
    text_stream = io.StringIO(file_content)
    parsed = PltReader(text_stream).get_output(tolerance=tolerance)
    return parsed
|
||
|
||
def normalize_polygon(poly):
    """Centre a polygon on its centroid and rescale it to an area of 10000.

    Normalising position and size lets shapes of different sizes be compared
    by overlap alone.

    Args:
        poly: a shapely polygon.

    Returns:
        ``(normalized_polygon, original_area, original_perimeter)``. A polygon
        whose area is not positive is returned unchanged as ``(poly, 0, 0)``.
    """
    original_area = poly.area
    original_perimeter = poly.length

    if original_area <= 0:
        return poly, 0, 0

    centroid = poly.centroid
    centred = affinity.translate(poly, xoff=-centroid.x, yoff=-centroid.y)

    # Uniform scale factor that brings the area to 10000 (area scales with factor**2).
    factor = (10000 / original_area) ** 0.5
    rescaled = affinity.scale(centred, xfact=factor, yfact=factor, origin=(0, 0))

    return rescaled, original_area, original_perimeter
|
||
|
||
def fast_shape_similarity(
    poly1_data: Tuple,
    poly2_data: Tuple,
    angles: Tuple[int, ...] = (0, 180),
    min_area_ratio: float = 0.5,
) -> Tuple[float, int]:
    """Fast shape distance between two normalized polygons.

    Args:
        poly1_data: ``(normalized polygon, original area, original perimeter)``
            as produced by ``normalize_polygon``.
        poly2_data: same layout, for the polygon being compared.
        angles: rotations (degrees) of the second polygon to try. The default
            tries only 0 and 180 degrees, since garment pieces usually only
            need these two orientations.
        min_area_ratio: pairs whose original-area ratio (smaller / larger) is
            below this value are rejected without any geometry work.

    Returns:
        ``(distance, best_angle)``. The distance is the area of the symmetric
        difference divided by the first polygon's area (``max(area, 1)`` guards
        against zero). ``(inf, 0)`` is returned when the pair is rejected by the
        area filter or when every angle fails.
    """
    poly1, area1, _peri1 = poly1_data
    poly2, area2, _peri2 = poly2_data

    # Cheap pre-filter: shapes whose original areas differ too much cannot match.
    if area1 > 0 and area2 > 0:
        if min(area1, area2) / max(area1, area2) < min_area_ratio:
            return float('inf'), 0

    best_dist = float('inf')
    best_angle = 0

    for angle in angles:
        try:
            # Normalized polygons are centred on their centroid, so rotating
            # about the centroid keeps them aligned with poly1.
            candidate = poly2 if angle == 0 else affinity.rotate(poly2, angle, origin='centroid')
            # Symmetric-difference area is much cheaper than a Hausdorff distance.
            dist = poly1.symmetric_difference(candidate).area / max(poly1.area, 1)
        except Exception:
            # Geometry operations can fail on degenerate/invalid shapes; treat
            # that orientation as a non-match. Interrupts still propagate.
            dist = float('inf')

        if dist < best_dist:
            best_dist = dist
            best_angle = angle

    return best_dist, best_angle
|
||
|
||
def batch_normalize_polygons(pieces: List[Dict]) -> List[Tuple]:
    """Apply ``normalize_polygon`` to the ``"data"`` polygon of every piece.

    Args:
        pieces: piece dicts, each holding its polygon under ``"data"``.

    Returns:
        One ``normalize_polygon`` result per piece, in input order.
    """
    polygons = (piece["data"] for piece in pieces)
    return [normalize_polygon(polygon) for polygon in polygons]
|
||
|
||
def compute_matching_matrix(base_polys: List[Tuple], compare_polys: List[Tuple]) -> np.ndarray:
    """Build the pairwise cost matrix for assignment.

    Args:
        base_polys: ``normalize_polygon`` results for the reference pieces (rows).
        compare_polys: ``normalize_polygon`` results for the candidate pieces (columns).

    Returns:
        A float array of shape ``(len(base_polys), len(compare_polys))``. Entry
        ``[r, c]`` is the distance from ``fast_shape_similarity``, which may be
        ``inf`` for rejected pairs.
    """
    costs = np.zeros((len(base_polys), len(compare_polys)))

    for row, base in enumerate(base_polys):
        for col, other in enumerate(compare_polys):
            costs[row, col] = fast_shape_similarity(base, other)[0]

    return costs
|
||
|
||
def apply_rotation_to_image(pil_img: Image.Image, rotation: int) -> Image.Image:
    """Rotate a PIL image by a quarter or half turn.

    Args:
        pil_img: image to rotate.
        rotation: 90 (clockwise), -90 (counter-clockwise) or 180; any other
            value returns the image unchanged.

    Returns:
        The rotated image, with the canvas expanded to fit it (``expand=True``).
    """
    # PIL's Image.rotate() turns counter-clockwise for positive angles, hence
    # a requested +90 (clockwise) maps to rotate(-90) and vice versa.
    for requested, pil_angle in ((90, -90), (-90, 90), (180, 180)):
        if rotation == requested:
            return pil_img.rotate(pil_angle, expand=True)
    return pil_img
|
||
|
||
def calculate_piece_coordinates(polygon, bounds, plt_to_cm, rotation=0) -> Dict:
    """Place a piece on the (optionally rotated) canvas, in centimetres.

    The canvas is the rectangle ``bounds = (min_x, min_y, max_x, max_y)`` in
    PLT units. Coordinates are measured from its ``(min_x, min_y)`` corner and
    converted with ``plt_to_cm``.

    Args:
        polygon: piece geometry exposing ``centroid`` (with ``.x``/``.y``) and
            ``bounds`` (``(minx, miny, maxx, maxy)``).
        bounds: canvas extent in PLT units.
        plt_to_cm: centimetres per PLT unit.
        rotation: 0, 90, -90 or 180; any other value is treated like 0.

    Returns:
        dict with ``center_x_cm``, ``center_y_cm``, ``left_cm``, ``top_cm``,
        ``width_cm`` and ``height_cm``, each rounded to 2 decimals.
        ``left``/``top`` are the centroid minus half the bounding-box size, so
        they describe a box centred on the centroid. That is not necessarily
        the piece's true bounding box.
    """
    min_x, min_y, max_x, max_y = bounds

    centroid = polygon.centroid
    cx = (centroid.x - min_x) * plt_to_cm
    cy = (centroid.y - min_y) * plt_to_cm

    box = polygon.bounds
    w = (box[2] - box[0]) * plt_to_cm
    h = (box[3] - box[1]) * plt_to_cm

    canvas_w = (max_x - min_x) * plt_to_cm
    canvas_h = (max_y - min_y) * plt_to_cm

    # Map the centre into the rotated canvas; quarter turns also swap the
    # piece's width and height.
    if rotation == 90:
        cx, cy, w, h = cy, canvas_w - cx, h, w
    elif rotation == -90:
        cx, cy, w, h = canvas_h - cy, cx, h, w
    elif rotation == 180:
        cx, cy = canvas_w - cx, canvas_h - cy

    left = cx - w / 2
    top = cy - h / 2

    return {
        "center_x_cm": round(cx, 2),
        "center_y_cm": round(cy, 2),
        "left_cm": round(left, 2),
        "top_cm": round(top, 2),
        "width_cm": round(w, 2),
        "height_cm": round(h, 2),
    }
|
||
|
||
def get_size_matching(clusters: List, size_num: int) -> Dict[int, List[int]]:
    """Match each top-level piece of the first size against every other size.

    Each later size is matched against the first (base) size with the
    Hungarian algorithm on the ``compute_matching_matrix`` cost matrix.

    Args:
        clusters: one list of piece dicts per size. Each dict is read for
            ``"parent"`` (``None`` marks a top-level piece), ``"index"`` and
            ``"data"`` (the polygon).
        size_num: not used by the matching itself; kept for interface
            compatibility.

    Returns:
        ``{base_index: [base_index, match_in_size_1, match_in_size_2, ...]}``.
        Returns an empty dict when ``clusters`` is empty.
    """
    if not clusters:
        return {}

    # Large but finite cost for pairs fast_shape_similarity rejected (inf) or
    # that produced NaN. linear_sum_assignment raises ValueError on NaN and on
    # matrices whose inf entries make a full assignment infeasible; a finite
    # penalty keeps it solvable while still using such pairs only as a last resort.
    unmatchable_cost = 1e6

    # Only top-level pieces take part in matching.
    size_parents = [[p for p in cluster if p["parent"] is None] for cluster in clusters]

    base_pieces = size_parents[0]
    base_polys = batch_normalize_polygons(base_pieces)

    match_result = {p["index"]: [p["index"]] for p in base_pieces}

    for compare_pieces in size_parents[1:]:
        compare_polys = batch_normalize_polygons(compare_pieces)

        cost_matrix = compute_matching_matrix(base_polys, compare_polys)
        cost_matrix = np.where(np.isfinite(cost_matrix), cost_matrix, unmatchable_cost)

        row_ind, col_ind = linear_sum_assignment(cost_matrix)

        # NOTE(review): if a size has fewer top-level pieces than the base
        # size, some base pieces get no entry for it. Their lists become
        # shorter, and later sizes shift position. Confirm that clusters
        # always have equal piece counts.
        for i, j in zip(row_ind, col_ind):
            match_result[base_pieces[i]["index"]].append(compare_pieces[j]["index"])

    return match_result
|
||
|
||
def generate_single_piece_image(args: Tuple) -> Dict:
    """Render one piece (with its child nodes) to a PNG data URL and compute its placement.

    Intended for use from a thread pool. Any exception raised while rendering
    or computing coordinates is reported in the returned dict instead of
    being raised.

    Args:
        args: ``(output, node, scale_factor, rotation, dpi, size_label,
            group_id, global_bounds, plt_to_cm)``. ``output`` must provide
            ``_draw_nodes``, and ``node`` must have ``"child"`` and ``"data"`` keys.

    Returns:
        On success: ``{"success": True, "group_id", "size_label",
        "image_base64", "width_px", "height_px"}`` plus the keys returned by
        ``calculate_piece_coordinates``.
        On failure: ``{"success": False, "error", "group_id", "size_label"}``.
    """
    (output, node, scale_factor, rotation, dpi,
     size_label, group_id, global_bounds, plt_to_cm) = args

    try:
        # The piece is drawn together with its child nodes.
        drawn = output._draw_nodes([node] + node['child'], scale_factor, show_id=False)
        # The drawing is in BGRA channel order; PIL expects RGBA.
        picture = Image.fromarray(cv2.cvtColor(drawn, cv2.COLOR_BGRA2RGBA))
        picture = apply_rotation_to_image(picture, rotation)

        with io.BytesIO() as png_stream:
            picture.save(png_stream, format='PNG', dpi=(dpi, dpi))
            png_bytes = png_stream.getvalue()
        encoded = base64.b64encode(png_bytes).decode('utf-8')

        placement = calculate_piece_coordinates(node["data"], global_bounds, plt_to_cm, rotation)

        return {
            "success": True,
            "group_id": group_id,
            "size_label": size_label,
            "image_base64": f"data:image/png;base64,{encoded}",
            "width_px": picture.width,
            "height_px": picture.height,
            **placement,
        }
    except Exception as exc:
        return {"success": False, "error": str(exc), "group_id": group_id, "size_label": size_label}
|
||
|
||
def get_match_result_between_files(standard_clusters, rotated_clusters) -> List[List[Dict]]:
    """Match top-level pieces of a standard PLT file with those of a rotated one, size by size.

    Args:
        standard_clusters: per-size piece lists of the standard file (dicts
            with ``"parent"``, ``"index"`` and ``"data"``).
        rotated_clusters: per-size piece lists of the rotated file, same layout.

    Returns:
        One list per size pair produced by ``zip(standard_clusters,
        rotated_clusters)``, so position ``i`` always corresponds to size
        ``i``. Each list holds dicts with ``standard_id``, ``rotated_id``,
        ``distance`` (rounded to 4 decimals) and ``angle``. A size whose
        top-level piece counts differ between the files yields an empty list.
    """
    # Finite stand-in for inf/NaN costs. It keeps linear_sum_assignment
    # feasible (it raises ValueError otherwise), and the reported distance
    # stays JSON-serialisable (inf is not).
    unmatchable_cost = 1e6

    all_size_matches = []

    for standard_cluster, rotated_cluster in zip(standard_clusters, rotated_clusters):
        standard_parents = [p for p in standard_cluster if p["parent"] is None]
        rotated_parents = [p for p in rotated_cluster if p["parent"] is None]

        if len(standard_parents) != len(rotated_parents):
            # Keep a placeholder so callers can line results up with their
            # size labels; skipping would shift every later size.
            all_size_matches.append([])
            continue

        n = len(standard_parents)
        std_polys = batch_normalize_polygons(standard_parents)
        rot_polys = batch_normalize_polygons(rotated_parents)

        cost_matrix = np.zeros((n, n))
        rotation_matrix = np.zeros((n, n))
        for i, std_poly in enumerate(std_polys):
            for j, rot_poly in enumerate(rot_polys):
                cost_matrix[i, j], rotation_matrix[i, j] = fast_shape_similarity(std_poly, rot_poly)

        cost_matrix = np.where(np.isfinite(cost_matrix), cost_matrix, unmatchable_cost)
        row_ind, col_ind = linear_sum_assignment(cost_matrix)

        all_size_matches.append([
            {
                "standard_id": standard_parents[i]["index"],
                "rotated_id": rotated_parents[j]["index"],
                "distance": round(float(cost_matrix[i, j]), 4),
                "angle": float(rotation_matrix[i, j]),
            }
            for i, j in zip(row_ind, col_ind)
        ])

    return all_size_matches
|
||
|
||
# ==================== API 接口 ====================
|
||
|
||
def _parse_size_labels(size_labels: str) -> List[str]:
    """Decode the ``size_labels`` form field into a non-empty list of label strings.

    Raises:
        HTTPException: 400 when the field is not valid JSON or is not a
            non-empty JSON array.
    """
    try:
        labels = json.loads(size_labels)
    except (ValueError, TypeError):  # json.JSONDecodeError subclasses ValueError
        raise HTTPException(status_code=400, detail="size_labels 格式错误") from None
    if not isinstance(labels, list) or not labels:
        raise HTTPException(status_code=400, detail="size_labels 格式错误")
    # PieceInfo.size / SizeMatchInfo.size are str fields; normalise numeric labels up front.
    return [str(label) for label in labels]


def _render_all_pieces(output, match_result, size_labels_list, scale_factor,
                       rotation, dpi, global_bounds, plt_to_cm) -> List[Dict]:
    """Render every matched piece in a thread pool and return the successful results.

    Results are in submission order (group by group, size by size). Each is
    extended with ``"size_idx"``, the position of its label in
    ``size_labels_list``. Failed pieces are logged and omitted.
    """
    all_nodes = {node["index"]: node for node in output.nodes}

    tasks = []
    task_size_idx = []
    for group_id, matched_indices in enumerate(match_result.values(), start=1):
        # zip() stops at the shorter sequence, so extra matches cannot raise IndexError on the labels.
        for size_idx, (size_label, piece_index) in enumerate(zip(size_labels_list, matched_indices)):
            tasks.append((
                output, all_nodes[piece_index], scale_factor, rotation, dpi,
                size_label, group_id, global_bounds, plt_to_cm,
            ))
            task_size_idx.append(size_idx)

    # One worker per CPU core, capped at 8.
    max_workers = min(os.cpu_count() or 4, 8)
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # executor.map yields results in submission order, keeping the response deterministic.
        rendered = list(executor.map(generate_single_piece_image, tasks))

    results = []
    for size_idx, result in zip(task_size_idx, rendered):
        if result["success"]:
            results.append({**result, "size_idx": size_idx})
        else:
            print(f"[PLT API] 裁片图片生成失败 (group {result['group_id']}, size {result['size_label']}): {result['error']}")
    return results


def _upload_images(results: List[Dict], db: "Session", current_username: str) -> Dict[int, str]:
    """Upload rendered images to Qiniu and return ``{position in results: url}`` for successful uploads."""
    # Imported lazily: only needed when cloud storage is enabled.
    from app.models.user import User

    # The user id isolates files per user (stored under plt/{date}/u{user_id}/).
    user = db.query(User).filter(User.username == current_username).first()
    user_id = user.id if user else None

    print(f"[PLT API] 正在上传图片到七牛云,共 {len(results)} 张,用户: {current_username}(id={user_id})...")
    upload_tasks = [(r["image_base64"], f"g{r['group_id']}_{r['size_label']}") for r in results]
    upload_results = qiniu_storage.upload_batch(upload_tasks, key_prefix="plt", user_id=user_id)

    # Keyed by position, not (group_id, size_label), so duplicate size labels cannot collide.
    urls = {pos: url for pos, (success, url, _name) in enumerate(upload_results) if success}
    print(f"[PLT API] 七牛云上传完成,成功 {len(urls)}/{len(results)} 张")
    return urls


def _build_groups(results: List[Dict], image_urls: Dict[int, str]) -> List["GroupInfo"]:
    """Assemble GroupInfo objects from rendered pieces.

    Pieces within a group are ordered by size position. Groups are ordered by
    the bounding-box area (``width_cm * height_cm``) of their first piece,
    largest first, and renumbered from 1.
    """
    by_group: Dict[int, List[Tuple[int, "PieceInfo"]]] = {}
    for pos, r in enumerate(results):
        uploaded = pos in image_urls
        piece = PieceInfo(
            size=r["size_label"],
            # Prefer the cloud URL; fall back to the inline base64 image.
            image_url=image_urls[pos] if uploaded else None,
            image_base64=None if uploaded else r["image_base64"],
            width_px=r["width_px"],
            height_px=r["height_px"],
            center_x_cm=r["center_x_cm"],
            center_y_cm=r["center_y_cm"],
            left_cm=r["left_cm"],
            top_cm=r["top_cm"],
            width_cm=r["width_cm"],
            height_cm=r["height_cm"],
        )
        by_group.setdefault(r["group_id"], []).append((r["size_idx"], piece))

    ranked = []
    for entries in by_group.values():
        pieces = [piece for _, piece in sorted(entries, key=lambda e: e[0])]
        first = pieces[0]
        ranked.append((first.width_cm * first.height_cm, pieces))
    ranked.sort(key=lambda item: -item[0])  # stable: ties keep match order

    return [GroupInfo(group_id=new_id, pieces=pieces) for new_id, (_, pieces) in enumerate(ranked, start=1)]


def _process_plt_sync(file_content, rotated_content, filename, rotated_filename,
                      size_labels_list, dpi, rotation, current_username, db) -> "ProcessPltResponse":
    """Blocking core of ``process_plt``: parse, match, render, upload and build the response."""
    size_num = len(size_labels_list)
    # 2.54 cm per 1016 PLT units, i.e. 1016 plotter units per inch.
    scale_factor = dpi / 1016  # PLT units -> pixels at the requested DPI
    plt_to_cm = 2.54 / 1016    # PLT units -> centimetres

    output = parse_plt_file(file_content)
    clusters = output.get_single_size_info(size_num=size_num)
    match_result = get_size_matching(clusters, size_num)

    from shapely.geometry import MultiPolygon as ShapelyMultiPolygon
    # Common reference frame for all coordinates: the bounding box of every polygon.
    global_bounds = ShapelyMultiPolygon([node["data"] for node in output.nodes]).bounds

    print("[PLT API] 开始并行生成图片...")
    results = _render_all_pieces(output, match_result, size_labels_list, scale_factor,
                                 rotation, dpi, global_bounds, plt_to_cm)
    image_urls = _upload_images(results, db, current_username) if qiniu_storage.enabled else {}
    groups = _build_groups(results, image_urls)
    print(f"[PLT API] 图片生成完成,共 {len(results)} 张")

    match_analysis = None
    if rotated_content is not None:
        print(f"[PLT API] 双文件匹配: {filename} vs {rotated_filename}")
        rotated_output = parse_plt_file(rotated_content)
        rotated_clusters = rotated_output.get_single_size_info(size_num=size_num)
        all_matches = get_match_result_between_files(clusters, rotated_clusters)
        match_analysis = [
            SizeMatchInfo(size=label, matches=[MatchInfo(**m) for m in matches])
            for label, matches in zip(size_labels_list, all_matches)
        ]

    print(f"[PLT API] 处理完成,共 {len(groups)} 组裁片")
    return ProcessPltResponse(
        success=True,
        total_groups=len(groups),
        groups=groups,
        match_analysis=match_analysis,
    )


@router.post("/algorithm/process_plt", response_model=ProcessPltResponse)
async def process_plt(
    file: UploadFile = File(..., description="标准 PLT 文件"),
    rotated_file: Optional[UploadFile] = File(None, description="旋转后的 PLT 文件(可选)"),
    size_labels: str = Form(..., description="尺码标签 JSON 数组"),
    dpi: int = Form(150, description="输出图片分辨率(DPI)"),
    rotation: int = Form(0, description="强制旋转角度 (0/90/-90/180)"),
    current_username: str = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Process a PLT file: render every piece per size and return images plus coordinates.

    Optionally matches the pieces against a rotated copy of the file.

    Raises:
        HTTPException: 400 for an invalid ``size_labels``, ``dpi`` or
            ``rotation``; 500 for any other processing failure.
    """
    try:
        size_labels_list = _parse_size_labels(size_labels)
        if dpi <= 0:
            raise HTTPException(status_code=400, detail="dpi 必须为正整数")
        if rotation not in (0, 90, -90, 180):
            # Other values would be silently ignored by the rotation helpers.
            raise HTTPException(status_code=400, detail="rotation 只能为 0/90/-90/180")

        print(f"[PLT API] 用户 {current_username} 正在处理: {file.filename}")
        file_content = (await file.read()).decode('utf-8', errors='ignore')

        rotated_content = None
        rotated_filename = None
        if rotated_file:
            rotated_filename = rotated_file.filename
            rotated_content = (await rotated_file.read()).decode('utf-8', errors='ignore')

        # Parsing, matching, rendering, the DB lookup and the upload all block.
        # Run them in the default executor so this coroutine does not stall the event loop.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, _process_plt_sync,
            file_content, rotated_content, file.filename, rotated_filename,
            size_labels_list, dpi, rotation, current_username, db,
        )

    except HTTPException:
        raise
    except Exception as e:
        print(f"[PLT API] 处理失败: {str(e)}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}") from e
|