262 lines
9.1 KiB
Python
262 lines
9.1 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
PLT 裁片处理微服务
|
||
独立部署到阿里云 SAE
|
||
"""
|
||
|
||
import os
|
||
import io
|
||
import base64
|
||
import json
|
||
from typing import List, Optional
|
||
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from pydantic import BaseModel
|
||
import numpy as np
|
||
import cv2
|
||
from PIL import Image
|
||
from shapely import affinity
|
||
from scipy.optimize import linear_sum_assignment
|
||
|
||
from pltreader import PltReader
|
||
|
||
app = FastAPI(title="PLT Processing Service")

# CORS configuration.
# Allowed origins are read from the comma-separated CORS_ALLOW_ORIGINS
# environment variable so a deployment can restrict them; when the variable is
# unset or empty the original behaviour (allow every origin, "*") is kept.
# Because allow_credentials is enabled, restricting origins is advisable in
# production rather than relying on the wildcard.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||
|
||
# ==================== 数据模型 ====================
|
||
|
||
class PieceInfo(BaseModel):
    """A single rendered piece for one size, plus its placement in centimetres."""

    # Size label this piece belongs to (taken from the request's size_labels).
    size: str
    # Rendered PNG, encoded by the /process endpoint as a base64 data URI.
    image_base64: str
    # Pixel dimensions of the (possibly rotated) rendered image.
    width_px: int
    height_px: int
    # Bounding-box width/height of the piece, in centimetres.
    width_cm: float
    height_cm: float
    # Piece centre relative to the whole layout's bounding box, in centimetres.
    center_x_cm: float
    center_y_cm: float
    # Centre minus half the width / half the height, in centimetres.
    left_cm: float
    top_cm: float
||
|
||
class GroupInfo(BaseModel):
    """A group of pieces matched to each other across the requested sizes."""

    # Sequential group number assigned by the /process endpoint.
    group_id: int
    # One entry per size in which a counterpart piece was found.
    pieces: List[PieceInfo]
|
||
|
||
class ProcessPltResponse(BaseModel):
    """Response body of the /process endpoint."""

    # Whether processing completed.
    success: bool
    # Number of entries in ``groups``.
    total_groups: int
    # The matched piece groups.
    groups: List[GroupInfo]
|
||
|
||
# ==================== 辅助函数 ====================
|
||
|
||
def parse_plt_file(file_content: str, tolerance: int = 10):
    """Parse PLT text with ``PltReader`` and return its output object.

    Args:
        file_content: Decoded PLT file text.
        tolerance: Passed through unchanged to ``PltReader.get_output``.

    Returns:
        Whatever ``PltReader.get_output`` returns.
    """
    text_stream = io.StringIO(file_content)
    return PltReader(text_stream).get_output(tolerance=tolerance)
|
||
|
||
def apply_rotation_to_image(pil_img: Image.Image, rotation: int) -> Image.Image:
    """Apply the requested forced rotation to a rendered piece image.

    ``rotation`` is clockwise: 90 is a quarter turn clockwise, -90 a quarter
    turn counter-clockwise, 180 a half turn. PIL's ``rotate`` takes a
    counter-clockwise angle, hence the sign flip in the table below.
    ``expand=True`` enlarges the canvas so the rotated image is not clipped.
    Any other ``rotation`` returns the input image object unchanged.
    """
    pil_angle = {90: -90, -90: 90, 180: 180}.get(rotation)
    if pil_angle is None:
        return pil_img
    return pil_img.rotate(pil_angle, expand=True)
|
||
|
||
def calculate_piece_coordinates(polygon, bounds, plt_to_cm, rotation=0):
    """Compute a piece's placement on the (optionally rotated) layout, in cm.

    Args:
        polygon: Piece geometry; only ``centroid.x``, ``centroid.y`` and
            ``bounds`` (min_x, min_y, max_x, max_y) are read.
        bounds: (min_x, min_y, max_x, max_y) of the whole layout, PLT units.
        plt_to_cm: Conversion factor from PLT units to centimetres.
        rotation: 90 / -90 / 180 remap the centre onto the rotated layout
            (quarter turns also swap width and height); any other value
            leaves the geometry unrotated.

    Returns:
        Dict with center_x_cm, center_y_cm, left_cm, top_cm, width_cm and
        height_cm, each rounded to 2 decimals. left/top are the centre minus
        half the width/height.

    NOTE(review): the remap formulas encode an assumption about how the
    rendered image's axes relate to PLT coordinates; confirm against
    ``_draw_nodes`` before relying on them for a new rotation value.
    """
    layout_min_x, layout_min_y, layout_max_x, layout_max_y = bounds

    centroid = polygon.centroid
    cx = (centroid.x - layout_min_x) * plt_to_cm
    cy = (centroid.y - layout_min_y) * plt_to_cm

    box = polygon.bounds
    w = (box[2] - box[0]) * plt_to_cm
    h = (box[3] - box[1]) * plt_to_cm

    canvas_w = (layout_max_x - layout_min_x) * plt_to_cm
    canvas_h = (layout_max_y - layout_min_y) * plt_to_cm

    # (center_x, center_y, width, height) after the rotation is applied.
    if rotation == 90:
        placed = (cy, canvas_w - cx, h, w)
    elif rotation == -90:
        placed = (canvas_h - cy, cx, h, w)
    elif rotation == 180:
        placed = (canvas_w - cx, canvas_h - cy, w, h)
    else:
        placed = (cx, cy, w, h)
    center_x, center_y, width, height = placed

    raw = {
        "center_x_cm": center_x,
        "center_y_cm": center_y,
        "left_cm": center_x - width / 2,
        "top_cm": center_y - height / 2,
        "width_cm": width,
        "height_cm": height,
    }
    return {key: round(value, 2) for key, value in raw.items()}
|
||
|
||
# ==================== API 接口 ====================
|
||
|
||
@app.get("/health")
async def health_check():
    """Health-check endpoint: always reports this service as healthy."""
    payload = dict(status="healthy", service="plt-processor")
    return payload
|
||
|
||
import threading

# ---- Helpers for the /process endpoint --------------------------------------

# Rotations that apply_rotation_to_image and calculate_piece_coordinates
# implement; other values used to be silently ignored.
_VALID_ROTATIONS = (0, 90, -90, 180)
# PLT units per inch (the original code converts with dpi / 1016 and
# 2.54 / 1016 cm per unit).
_PLT_UNITS_PER_INCH = 1016
# Every outline is rescaled to this area before shape comparison, so pieces of
# different sizes are compared by outline only.
_SHAPE_TARGET_AREA = 10000
# Candidate rotations (degrees) tried when comparing two outlines.
_MATCH_ANGLES = (0, 90, 180, 270)
# The pipeline used to run directly on the event loop, so requests were
# processed one at a time. It now runs in a worker thread (keeping /health
# responsive); this lock preserves one-at-a-time processing because the
# thread-safety of PltReader is not established here.
_PIPELINE_LOCK = threading.Lock()


def _parse_size_labels(raw: str) -> List[str]:
    """Decode the ``size_labels`` form field (a JSON array) into string labels.

    Raises:
        HTTPException(400): not valid JSON, not a non-empty array, or an entry
            that is not a string/number.
    """
    try:
        labels = json.loads(raw)
    except (TypeError, ValueError):  # json.JSONDecodeError is a ValueError
        raise HTTPException(status_code=400, detail="size_labels 格式错误") from None
    if (
        not isinstance(labels, list)
        or not labels
        or not all(isinstance(label, (str, int, float)) for label in labels)
    ):
        raise HTTPException(status_code=400, detail="size_labels 格式错误")
    # PieceInfo.size is a str; numeric labels such as 160 are stringified.
    return [str(label) for label in labels]


def _normalize_shape(polygon):
    """Center a polygon on the origin and scale it to _SHAPE_TARGET_AREA.

    Zero-area polygons are only centred (scaling would divide by zero).
    """
    centroid = polygon.centroid
    centered = affinity.translate(polygon, xoff=-centroid.x, yoff=-centroid.y)
    area = centered.area
    if area <= 0:
        return centered
    factor = (_SHAPE_TARGET_AREA / area) ** 0.5
    return affinity.scale(centered, xfact=factor, yfact=factor, origin=(0, 0))


def _shape_cost_matrix(base_cluster, compare_cluster):
    """Return the base x compare matrix of rotation-tolerant Hausdorff distances.

    Each normalized base outline, and each normalized compare outline in all
    _MATCH_ANGLES rotations, is built once instead of once per pair.
    """
    base_shapes = [_normalize_shape(piece["data"]) for piece in base_cluster]
    compare_variants = []
    for piece in compare_cluster:
        shape = _normalize_shape(piece["data"])
        compare_variants.append(
            [affinity.rotate(shape, angle, origin="centroid") for angle in _MATCH_ANGLES]
        )

    cost = np.zeros((len(base_shapes), len(compare_variants)))
    for i, base_shape in enumerate(base_shapes):
        for j, variants in enumerate(compare_variants):
            cost[i, j] = min(base_shape.hausdorff_distance(v) for v in variants)
    return cost


def _match_pieces_across_sizes(clusters):
    """Match each top-level piece of size 0 to its counterpart in every other size.

    Returns:
        Dict mapping base piece index -> list of (size_index, piece_index),
        starting with (0, base_index). The size index is stored explicitly
        because the assignment may leave a base piece unmatched for some size;
        relying on list position would then attach later sizes to the wrong
        size label.
    """
    base_cluster = [piece for piece in clusters[0] if piece["parent"] is None]
    matches = {piece["index"]: [(0, piece["index"])] for piece in base_cluster}

    for size_index in range(1, len(clusters)):
        compare_cluster = [
            piece for piece in clusters[size_index] if piece["parent"] is None
        ]
        cost = _shape_cost_matrix(base_cluster, compare_cluster)
        rows, cols = linear_sum_assignment(cost)
        for i, j in zip(rows, cols):
            matches[base_cluster[i]["index"]].append(
                (size_index, compare_cluster[j]["index"])
            )
    return matches


def _global_bounds(geometries):
    """Return (min_x, min_y, max_x, max_y) enclosing all non-empty geometries."""
    boxes = [geom.bounds for geom in geometries if not geom.is_empty]
    min_xs, min_ys, max_xs, max_ys = zip(*boxes)
    return min(min_xs), min(min_ys), max(max_xs), max(max_ys)


def _render_piece(output, node, scale_factor, dpi, rotation):
    """Draw a node together with its children; return (data URI, width_px, height_px)."""
    img = output._draw_nodes([node] + node["child"], scale_factor, show_id=False)
    pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA))
    pil_img = apply_rotation_to_image(pil_img, rotation)

    buffer = io.BytesIO()
    pil_img.save(buffer, format="PNG", dpi=(dpi, dpi))
    encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f"data:image/png;base64,{encoded}", pil_img.width, pil_img.height


def _build_groups(file_content, size_labels_list, dpi, rotation):
    """Run the CPU-bound pipeline and return the list of GroupInfo.

    Parses the PLT text, clusters pieces by size, matches pieces across
    sizes, then renders each matched piece and computes its placement.

    Raises:
        HTTPException(400): the file yields no pieces.
    """
    scale_factor = dpi / _PLT_UNITS_PER_INCH
    plt_to_cm = 2.54 / _PLT_UNITS_PER_INCH

    with _PIPELINE_LOCK:
        output = parse_plt_file(file_content)
        nodes = list(output.nodes)
        if not nodes:
            raise HTTPException(status_code=400, detail="PLT 文件中未找到裁片")

        clusters = output.get_single_size_info(size_num=len(size_labels_list))
        if len(clusters) == 0:
            raise HTTPException(status_code=400, detail="PLT 文件中未找到裁片")

        match_result = _match_pieces_across_sizes(clusters)
        all_nodes = {node["index"]: node for node in nodes}
        global_bounds = _global_bounds(node["data"] for node in nodes)

        groups = []
        for group_id, matched in enumerate(match_result.values(), start=1):
            pieces = []
            for size_idx, piece_index in matched:
                node = all_nodes[piece_index]
                image_base64, width_px, height_px = _render_piece(
                    output, node, scale_factor, dpi, rotation
                )
                coords = calculate_piece_coordinates(
                    node["data"], global_bounds, plt_to_cm, rotation
                )
                pieces.append(
                    PieceInfo(
                        size=size_labels_list[size_idx],
                        image_base64=image_base64,
                        width_px=width_px,
                        height_px=height_px,
                        **coords,
                    )
                )
            groups.append(GroupInfo(group_id=group_id, pieces=pieces))
        return groups


@app.post("/process", response_model=ProcessPltResponse)
async def process_plt(
    file: UploadFile = File(..., description="PLT 文件"),
    size_labels: str = Form(..., description="尺码标签 JSON 数组"),
    dpi: int = Form(150, description="输出图片分辨率(DPI)"),
    rotation: int = Form(0, description="强制旋转角度 (0/90/-90/180)"),
):
    """Process an uploaded PLT file into per-size piece images grouped by match.

    Raises:
        HTTPException(400): invalid size_labels, dpi, rotation, or a file
            without pieces.
        HTTPException(500): any other processing failure.
    """
    import asyncio
    import functools
    import traceback

    try:
        size_labels_list = _parse_size_labels(size_labels)
        if dpi <= 0:
            raise HTTPException(status_code=400, detail="dpi 必须为正整数")
        if rotation not in _VALID_ROTATIONS:
            raise HTTPException(status_code=400, detail="rotation 仅支持 0/90/-90/180")

        file_content = (await file.read()).decode("utf-8", errors="ignore")
        print(f"[PLT] 正在处理文件: {file.filename}")

        # Run the CPU-heavy work off the event loop so other requests
        # (e.g. /health) are still served meanwhile.
        loop = asyncio.get_running_loop()
        groups = await loop.run_in_executor(
            None,
            functools.partial(_build_groups, file_content, size_labels_list, dpi, rotation),
        )

        print(f"[PLT] 处理完成,共 {len(groups)} 组裁片")
        return ProcessPltResponse(success=True, total_groups=len(groups), groups=groups)

    except HTTPException:
        # Client errors raised above must keep their status code instead of
        # being rewrapped as a 500 by the generic handler below.
        raise
    except Exception as e:
        print(f"[PLT] 处理失败: {str(e)}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}") from e
|
||
|
||
if __name__ == "__main__":
    import uvicorn

    # Listen on $PORT when it is set, otherwise on 8080, on all interfaces.
    listen_port = int(os.environ.get("PORT", "8080"))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)
|