Files
DP/Server/app/api/v1/algorithm.py

507 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""
PLT 裁片处理接口 (Photoshop 插件专用)
功能:解析 PLT 文件,生成裁片图片和坐标信息
优化版:提升匹配计算速度 + 并行图片生成
"""
import os
import io
import base64
import json
import numpy as np
import cv2
from typing import List, Optional, Dict, Tuple
from fastapi import APIRouter, UploadFile, File, Form, HTTPException, Depends
from pydantic import BaseModel
from sqlalchemy.orm import Session
from PIL import Image
from shapely import affinity
from scipy.optimize import linear_sum_assignment
from concurrent.futures import ThreadPoolExecutor, as_completed
import asyncio
# 导入自定义模块
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))))
from pltreader import PltReader
from app.db import get_db
from app.core.security import get_current_user
from app.core.qiniu_storage import qiniu_storage
router = APIRouter()
# ==================== 数据模型 ====================
class PieceInfo(BaseModel):
    """A single rendered cut piece: its image plus pixel and centimetre geometry."""

    size: str
    # Inline PNG data URI; populated when no cloud-storage URL is available.
    image_base64: Optional[str] = None
    # Cloud-storage URL; populated when the Qiniu upload succeeded.
    image_url: Optional[str] = None
    width_px: int
    height_px: int
    width_cm: float
    height_cm: float
    center_x_cm: float
    center_y_cm: float
    left_cm: float
    top_cm: float
class GroupInfo(BaseModel):
    """A group of cut pieces that share one group id."""

    group_id: int
    pieces: List[PieceInfo]
class MatchInfo(BaseModel):
    """One matched pair of pieces between the standard file and the rotated file."""

    standard_id: int
    rotated_id: int
    # Matching cost of the pair (lower means more similar).
    distance: float
    # Rotation angle associated with the best match.
    angle: float
class SizeMatchInfo(BaseModel):
    """Two-file matching results for one size label."""

    size: str
    matches: List[MatchInfo]
class ProcessPltResponse(BaseModel):
    """Response body of the PLT processing endpoint."""

    success: bool
    total_groups: int
    groups: List[GroupInfo]
    # Only present when a second (rotated) file was supplied.
    match_analysis: Optional[List[SizeMatchInfo]] = None
# ==================== 优化的辅助函数 ====================
def parse_plt_file(file_content: str, tolerance: int = 10):
    """Parse PLT text and return the reader's output object.

    Args:
        file_content: Decoded text of the PLT file.
        tolerance: Forwarded unchanged to ``PltReader.get_output``.

    Returns:
        Whatever ``PltReader.get_output`` returns for this content.
    """
    stream = io.StringIO(file_content)
    return PltReader(stream).get_output(tolerance=tolerance)
def normalize_polygon(poly):
    """Centre a polygon on its centroid and rescale it to a fixed area.

    Args:
        poly: Polygon exposing ``area``, ``length`` and ``centroid``.

    Returns:
        Tuple ``(normalized_polygon, original_area, original_perimeter)``.
        For a polygon whose area is not positive the input is returned
        untouched together with ``0, 0``.
    """
    original_area = poly.area
    original_perimeter = poly.length
    if original_area <= 0:
        return poly, 0, 0

    # Move the centroid to the origin so shapes can be overlaid directly.
    centre = poly.centroid
    centred = affinity.translate(poly, xoff=-centre.x, yoff=-centre.y)

    # Uniform scaling by sqrt(target / area) brings every shape to the same area.
    target_area = 10000
    factor = (target_area / original_area) ** 0.5
    rescaled = affinity.scale(centred, xfact=factor, yfact=factor, origin=(0, 0))
    return rescaled, original_area, original_perimeter
def fast_shape_similarity(poly1_data: Tuple, poly2_data: Tuple) -> Tuple[float, int]:
    """Score how different two pre-normalized polygons are.

    Args:
        poly1_data: ``(normalized_polygon, original_area, original_perimeter)``.
        poly2_data: Same layout as ``poly1_data``.

    Returns:
        ``(distance, best_angle)``. The distance is the area of the symmetric
        difference divided by ``max(poly1.area, 1)``; ``best_angle`` is 0 or
        180. ``(inf, 0)`` is returned when the original areas differ by more
        than a factor of two or when neither angle could be evaluated.
    """
    poly1, area1, _ = poly1_data
    poly2, area2, _ = poly2_data

    # Cheap pre-filter: skip the geometry work when the areas are too far apart.
    if area1 > 0 and area2 > 0:
        area_ratio = min(area1, area2) / max(area1, area2)
        if area_ratio < 0.5:
            return float('inf'), 0

    # Only 0 and 180 degrees are tried.
    dist_min = float('inf')
    best_angle = 0
    for angle in (0, 180):
        try:
            rotated = affinity.rotate(poly2, angle, origin='centroid') if angle != 0 else poly2
            diff_area = poly1.symmetric_difference(rotated).area
            dist = diff_area / max(poly1.area, 1)
        except Exception:
            # Geometry failures count as "no match" for this angle.
            # KeyboardInterrupt / SystemExit are deliberately NOT swallowed.
            dist = float('inf')
        if dist < dist_min:
            dist_min = dist
            best_angle = angle
    return dist_min, best_angle
def batch_normalize_polygons(pieces: List[Dict]) -> List[Tuple]:
    """Normalize the ``"data"`` polygon of every piece, preserving order."""
    normalized = []
    for piece in pieces:
        normalized.append(normalize_polygon(piece["data"]))
    return normalized
def compute_matching_matrix(base_polys: List[Tuple], compare_polys: List[Tuple]) -> np.ndarray:
    """Build the pairwise cost matrix between two lists of normalized polygons.

    Entry ``[i, j]`` holds the distance reported by ``fast_shape_similarity``
    for ``base_polys[i]`` against ``compare_polys[j]``.
    """
    costs = np.zeros((len(base_polys), len(compare_polys)))
    for row, base_entry in enumerate(base_polys):
        for col, compare_entry in enumerate(compare_polys):
            costs[row, col] = fast_shape_similarity(base_entry, compare_entry)[0]
    return costs
def apply_rotation_to_image(pil_img: Image.Image, rotation: int) -> Image.Image:
    """Rotate an image for the requested rotation code.

    The codes 90 and -90 are passed to ``Image.rotate`` with the opposite
    sign, 180 is passed as-is; any other value returns the image unchanged.
    """
    for requested, pil_angle in ((90, -90), (-90, 90), (180, 180)):
        if rotation == requested:
            return pil_img.rotate(pil_angle, expand=True)
    return pil_img
def calculate_piece_coordinates(polygon, bounds, plt_to_cm, rotation=0) -> Dict:
    """Compute a piece's centre, top-left corner and size in centimetres.

    Args:
        polygon: Piece geometry exposing ``centroid`` and ``bounds``.
        bounds: ``(min_x, min_y, max_x, max_y)`` of the whole canvas.
        plt_to_cm: Multiplier converting PLT units to centimetres.
        rotation: 90, -90 or 180 remap the coordinates; anything else
            leaves them as they are.

    Returns:
        Dict with ``center_x_cm``, ``center_y_cm``, ``left_cm``, ``top_cm``,
        ``width_cm`` and ``height_cm``, each rounded to two decimals.
    """
    min_x, min_y, max_x, max_y = bounds

    # Unrotated measurements, relative to the canvas origin.
    centre = polygon.centroid
    raw_cx = (centre.x - min_x) * plt_to_cm
    raw_cy = (centre.y - min_y) * plt_to_cm
    box = polygon.bounds
    raw_w = (box[2] - box[0]) * plt_to_cm
    raw_h = (box[3] - box[1]) * plt_to_cm
    canvas_w = (max_x - min_x) * plt_to_cm
    canvas_h = (max_y - min_y) * plt_to_cm

    # Quarter turns swap the axes (and therefore width/height); a half turn mirrors both.
    if rotation == 90:
        centre_xy, extent = (raw_cy, canvas_w - raw_cx), (raw_h, raw_w)
    elif rotation == -90:
        centre_xy, extent = (canvas_h - raw_cy, raw_cx), (raw_h, raw_w)
    elif rotation == 180:
        centre_xy, extent = (canvas_w - raw_cx, canvas_h - raw_cy), (raw_w, raw_h)
    else:
        centre_xy, extent = (raw_cx, raw_cy), (raw_w, raw_h)

    cx, cy = centre_xy
    width, height = extent
    fields = (
        ("center_x_cm", cx),
        ("center_y_cm", cy),
        ("left_cm", cx - width / 2),
        ("top_cm", cy - height / 2),
        ("width_cm", width),
        ("height_cm", height),
    )
    return {name: round(value, 2) for name, value in fields}
def get_size_matching(clusters: List, size_num: int) -> Dict[int, List[int]]:
    """Match every parent piece of the first size to its counterpart in each other size.

    Args:
        clusters: One list of piece dicts per size. Only pieces whose
            ``"parent"`` is ``None`` take part in the matching.
        size_num: Unused; kept so existing callers keep working.

    Returns:
        ``{base_index: [base_index, match_in_size_2, match_in_size_3, ...]}``
        keyed by the ``"index"`` of each parent piece of the first size.
        An empty ``clusters`` list yields ``{}``.

    NOTE(review): the assignment pairs at most ``min(n, m)`` pieces, so when a
    later size has fewer parent pieces than the first size some base pieces
    get no entry for that size and the positions in their list stop lining up
    with the size order. Callers that index the list by size position should
    confirm this cannot happen for their data.
    """
    # Nothing to match against: avoid indexing into an empty list below.
    if len(clusters) == 0:
        return {}

    # Keep only the top-level (parent) pieces of every size.
    size_parents = [
        [piece for piece in cluster if piece["parent"] is None]
        for cluster in clusters
    ]

    # The first size is the reference every other size is matched against.
    base_pieces = size_parents[0]
    base_polys = batch_normalize_polygons(base_pieces)
    match_result = {piece["index"]: [piece["index"]] for piece in base_pieces}

    for compare_pieces in size_parents[1:]:
        compare_polys = batch_normalize_polygons(compare_pieces)
        cost_matrix = compute_matching_matrix(base_polys, compare_polys)
        # Hungarian algorithm: minimum-cost one-to-one assignment.
        row_ind, col_ind = linear_sum_assignment(cost_matrix)
        for i, j in zip(row_ind, col_ind):
            base_index = base_pieces[i]['index']
            match_result[base_index].append(compare_pieces[j]['index'])
    return match_result
def generate_single_piece_image(args: Tuple) -> Dict:
    """Render one piece to a PNG data URI and compute its coordinates.

    Designed to be submitted to a worker pool, hence the single tuple argument:
    ``(output, node, scale_factor, rotation, dpi, size_label, group_id,
    global_bounds, plt_to_cm)``.

    Returns:
        On success a dict with ``success=True``, ``group_id``, ``size_label``,
        ``image_base64``, ``width_px``, ``height_px`` and the coordinate
        fields. Any exception during rendering is caught and reported as
        ``{"success": False, "error": ..., "group_id": ..., "size_label": ...}``.
    """
    (output, node, scale_factor, rotation, dpi,
     size_label, group_id, global_bounds, plt_to_cm) = args
    try:
        # Draw the piece together with its child nodes.
        drawn = output._draw_nodes([node] + node['child'], scale_factor, show_id=False)
        picture = Image.fromarray(cv2.cvtColor(drawn, cv2.COLOR_BGRA2RGBA))
        picture = apply_rotation_to_image(picture, rotation)

        # Encode as PNG and wrap it in a data URI.
        png_buffer = io.BytesIO()
        picture.save(png_buffer, format='PNG', dpi=(dpi, dpi))
        encoded = base64.b64encode(png_buffer.getvalue()).decode('utf-8')

        payload = {
            "success": True,
            "group_id": group_id,
            "size_label": size_label,
            "image_base64": f"data:image/png;base64,{encoded}",
            "width_px": picture.width,
            "height_px": picture.height,
        }
        payload.update(
            calculate_piece_coordinates(node["data"], global_bounds, plt_to_cm, rotation)
        )
        return payload
    except Exception as e:
        return {"success": False, "error": str(e), "group_id": group_id, "size_label": size_label}
def get_match_result_between_files(standard_clusters, rotated_clusters) -> List[List[Dict]]:
    """Match the parent pieces of two files size by size.

    Args:
        standard_clusters: Per-size piece lists of the standard file.
        rotated_clusters: Per-size piece lists of the rotated file.

    Returns:
        One list per zipped size, in the same order as the inputs. Each entry
        is a list of ``{"standard_id", "rotated_id", "distance", "angle"}``
        dicts. A size yields an empty list when the two files have a
        different number of parent pieces for it, or when no assignment
        exists; positions therefore always correspond to size positions.
    """
    all_size_matches = []
    for standard_cluster, rotated_cluster in zip(standard_clusters, rotated_clusters):
        standard_parents = [p for p in standard_cluster if p["parent"] is None]
        rotated_parents = [p for p in rotated_cluster if p["parent"] is None]

        n = len(standard_parents)
        if n == 0 or n != len(rotated_parents):
            # Keep a placeholder instead of skipping: callers label results by
            # position, so dropping an entry would shift every later size.
            all_size_matches.append([])
            continue

        std_polys = batch_normalize_polygons(standard_parents)
        rot_polys = batch_normalize_polygons(rotated_parents)

        # Pairwise cost plus the angle that produced it.
        cost_matrix = np.zeros((n, n))
        rotation_matrix = np.zeros((n, n))
        for i, std_entry in enumerate(std_polys):
            for j, rot_entry in enumerate(rot_polys):
                dist, angle = fast_shape_similarity(std_entry, rot_entry)
                cost_matrix[i, j] = dist
                rotation_matrix[i, j] = angle

        try:
            # Hungarian algorithm: minimum-cost one-to-one assignment.
            row_ind, col_ind = linear_sum_assignment(cost_matrix)
        except ValueError as exc:
            # Infinite costs can leave no feasible assignment; report the size
            # as unmatched rather than failing the whole analysis.
            print(f"[PLT API] no feasible piece assignment for one size: {exc}")
            all_size_matches.append([])
            continue

        all_size_matches.append([
            {
                "standard_id": standard_parents[i]["index"],
                "rotated_id": rotated_parents[j]["index"],
                # Plain floats instead of numpy scalars.
                "distance": round(float(cost_matrix[i, j]), 4),
                "angle": float(rotation_matrix[i, j]),
            }
            for i, j in zip(row_ind, col_ind)
        ])
    return all_size_matches
# ==================== API 接口 ====================
def _parse_size_labels(size_labels: str) -> List[str]:
    """Decode the ``size_labels`` form field into a non-empty list of strings.

    Raises:
        HTTPException: 400 when the value is not valid JSON, is not a JSON
            array, or is an empty array.
    """
    try:
        parsed = json.loads(size_labels)
    except (TypeError, ValueError) as exc:  # JSONDecodeError is a ValueError
        raise HTTPException(status_code=400, detail="size_labels 格式错误") from exc
    if not isinstance(parsed, list) or not parsed:
        raise HTTPException(status_code=400, detail="size_labels 格式错误")
    # Coerce to str so numeric labels behave the same as quoted ones.
    return [str(label) for label in parsed]


@router.post("/algorithm/process_plt", response_model=ProcessPltResponse)
async def process_plt(
    file: UploadFile = File(..., description="标准 PLT 文件"),
    rotated_file: Optional[UploadFile] = File(None, description="旋转后的 PLT 文件(可选)"),
    size_labels: str = Form(..., description="尺码标签 JSON 数组"),
    dpi: int = Form(150, description="输出图片分辨率DPI"),
    rotation: int = Form(0, description="强制旋转角度 (0/90/-90/180)"),
    current_username: str = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Process a PLT file into grouped piece images with coordinates.

    Parses the uploaded file, matches pieces across sizes, renders every
    piece to PNG, optionally uploads the images to Qiniu, and optionally
    matches the pieces against a second (rotated) file.

    Raises:
        HTTPException: 400 for an invalid ``size_labels`` or a non-positive
            ``dpi``; 500 for any other failure during processing.
    """
    # NOTE(review): parsing, matching and the thread-pool join below run
    # synchronously inside this coroutine, so the event loop is blocked while
    # a request is being processed.
    try:
        # 1. Validate request parameters.
        size_labels_list = _parse_size_labels(size_labels)
        if dpi <= 0:
            raise HTTPException(status_code=400, detail="dpi 必须大于 0")
        size_num = len(size_labels_list)
        # Both factors treat 1016 PLT units as one inch (2.54 cm).
        scale_factor = dpi / 1016
        plt_to_cm = 2.54 / 1016

        # 2. Read and parse the PLT file.
        print(f"[PLT API] 用户 {current_username} 正在处理: {file.filename}")
        file_content = (await file.read()).decode('utf-8', errors='ignore')
        output = parse_plt_file(file_content)

        # 3. Cluster pieces by size and match them across sizes.
        clusters = output.get_single_size_info(size_num=size_num)
        match_result = get_size_matching(clusters, size_num)

        # 4. Index nodes and compute the bounds of the whole layout.
        all_nodes = {node["index"]: node for node in output.nodes}
        from shapely.geometry import MultiPolygon as ShapelyMultiPolygon
        all_polygons = [node["data"] for node in output.nodes]
        global_bounds = ShapelyMultiPolygon(all_polygons).bounds

        # 5. Render every piece in a thread pool.
        print("[PLT API] 开始并行生成图片...")
        tasks = []
        for group_id, matched_indices in enumerate(match_result.values(), start=1):
            for size_idx, piece_index in enumerate(matched_indices):
                tasks.append((
                    output, all_nodes[piece_index], scale_factor, rotation, dpi,
                    size_labels_list[size_idx], group_id, global_bounds, plt_to_cm
                ))

        max_workers = min(os.cpu_count() or 4, 8)  # capped at 8 threads
        results = []
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # map() yields in submission order, which keeps the response
            # deterministic (as_completed would order by finishing time).
            for outcome in executor.map(generate_single_piece_image, tasks):
                if outcome["success"]:
                    results.append(outcome)
                else:
                    # Failed pieces are left out of the response; log why.
                    print(
                        f"[PLT API] piece render failed (group {outcome['group_id']}, "
                        f"size {outcome['size_label']}): {outcome['error']}"
                    )

        # 6. Upload images to Qiniu when cloud storage is enabled.
        image_urls = {}  # (group_id, size_label) -> url
        if qiniu_storage.enabled:
            # The user id is passed to the uploader to separate users' files.
            from app.models.user import User
            user = db.query(User).filter(User.username == current_username).first()
            user_id = user.id if user else None
            print(f"[PLT API] 正在上传图片到七牛云,共 {len(results)} 张,用户: {current_username}(id={user_id})...")
            upload_tasks = [
                (r["image_base64"], f"g{r['group_id']}_{r['size_label']}")
                for r in results
            ]
            upload_results = qiniu_storage.upload_batch(upload_tasks, key_prefix="plt", user_id=user_id)
            success_count = 0
            # Assumes upload_batch returns one entry per task in task order — TODO confirm.
            for r, (success, url_or_base64, _name) in zip(results, upload_results):
                if success:
                    image_urls[(r["group_id"], r["size_label"])] = url_or_base64
                    success_count += 1
            print(f"[PLT API] 七牛云上传完成,成功 {success_count}/{len(results)}")

        # 7. Build the response pieces, grouped by the provisional group id.
        groups_dict: Dict[int, List] = {}
        for r in results:
            gid = r["group_id"]
            key = (gid, r["size_label"])
            uploaded = key in image_urls
            # Prefer the cloud URL; fall back to the inline base64 image.
            piece = PieceInfo(
                size=r["size_label"],
                image_url=image_urls[key] if uploaded else None,
                image_base64=None if uploaded else r["image_base64"],
                width_px=r["width_px"],
                height_px=r["height_px"],
                center_x_cm=r["center_x_cm"],
                center_y_cm=r["center_y_cm"],
                left_cm=r["left_cm"],
                top_cm=r["top_cm"],
                width_cm=r["width_cm"],
                height_cm=r["height_cm"]
            )
            groups_dict.setdefault(gid, []).append(piece)

        # Position of each label in the request (first occurrence wins).
        label_rank: Dict[str, int] = {}
        for position, label in enumerate(size_labels_list):
            label_rank.setdefault(label, position)

        groups_list = []
        for gid, pieces in groups_dict.items():
            sorted_pieces = sorted(pieces, key=lambda p: label_rank.get(p.size, 999))
            # Bounding-box area of the first piece stands in for the group's size.
            area = sorted_pieces[0].width_cm * sorted_pieces[0].height_cm if sorted_pieces else 0
            groups_list.append((gid, sorted_pieces, area))

        # Largest groups first; group ids are renumbered in that order.
        groups_list.sort(key=lambda x: -x[2])
        groups = [
            GroupInfo(group_id=new_id, pieces=pieces)
            for new_id, (_, pieces, _) in enumerate(groups_list, start=1)
        ]
        print(f"[PLT API] 图片生成完成,共 {len(results)}")

        # 8. Optional two-file matching.
        match_analysis = None
        if rotated_file:
            print(f"[PLT API] 双文件匹配: {file.filename} vs {rotated_file.filename}")
            rotated_content = (await rotated_file.read()).decode('utf-8', errors='ignore')
            rotated_output = parse_plt_file(rotated_content)
            rotated_clusters = rotated_output.get_single_size_info(size_num=size_num)
            all_matches = get_match_result_between_files(clusters, rotated_clusters)
            match_analysis = [
                SizeMatchInfo(size=size_labels_list[i], matches=[MatchInfo(**m) for m in matches])
                for i, matches in enumerate(all_matches)
            ]

        # 9. Return the result.
        print(f"[PLT API] 处理完成,共 {len(groups)} 组裁片")
        return ProcessPltResponse(
            success=True,
            total_groups=len(groups),
            groups=groups,
            match_analysis=match_analysis
        )
    except HTTPException:
        # Deliberate client errors pass through untouched.
        raise
    except Exception as e:
        print(f"[PLT API] 处理失败: {str(e)}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}") from e