# Source: DP/PltService/main.py (262 lines, 9.1 KiB, Python)
# NOTE: the repository viewer flagged this file as containing ambiguous Unicode
# characters, i.e. characters that might be confused with other characters.
# -*- coding: utf-8 -*-
"""
PLT 裁片处理微服务
独立部署到阿里云 SAE
"""
import os
import io
import base64
import json
from typing import List, Optional
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import numpy as np
import cv2
from PIL import Image
from shapely import affinity
from scipy.optimize import linear_sum_assignment
from pltreader import PltReader
app = FastAPI(title="PLT Processing Service")

# CORS: accept cross-origin requests from any origin, with any method and header.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive; restrict allow_origins if this service ever relies on cookies or
# other credentials.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# ==================== 数据模型 ====================
class PieceInfo(BaseModel):
    """Information about a single cut piece (one size of one pattern piece)."""

    # Size label taken from the request's size_labels JSON array.
    size: str
    # Rendered piece as a PNG data URI ("data:image/png;base64,...").
    image_base64: str
    # Pixel dimensions of the PNG after the requested rotation was applied.
    width_px: int
    height_px: int
    # Geometry in centimetres on the (possibly rotated) layout canvas, as
    # returned by calculate_piece_coordinates().
    width_cm: float
    height_cm: float
    center_x_cm: float
    center_y_cm: float
    left_cm: float
    top_cm: float
class GroupInfo(BaseModel):
    """A group of cut pieces: one pattern piece matched across the sizes."""

    # 1-based sequence number of the group within the response.
    group_id: int
    # The matched pieces of this group, base size first.
    pieces: List[PieceInfo]
class ProcessPltResponse(BaseModel):
    """Response body of the POST /process endpoint."""

    # True on a successful run; failures are reported via HTTP errors instead.
    success: bool
    # Number of entries in `groups`.
    total_groups: int
    groups: List[GroupInfo]
# ==================== 辅助函数 ====================
def parse_plt_file(file_content: str, tolerance: int = 10):
    """Parse PLT text and return the reader's processed output.

    Args:
        file_content: Full PLT file contents as text.
        tolerance: Forwarded unchanged to ``PltReader.get_output``.

    Returns:
        Whatever ``PltReader.get_output`` produces (used by the /process
        endpoint for ``nodes``, ``get_single_size_info`` and drawing).
    """
    text_stream = io.StringIO(file_content)
    return PltReader(text_stream).get_output(tolerance=tolerance)
def apply_rotation_to_image(pil_img: Image.Image, rotation: int) -> Image.Image:
    """Rotate a rendered piece image by the forced rotation requested.

    Args:
        pil_img: Image to rotate.
        rotation: 90 (quarter turn clockwise), -90 (quarter turn
            counter-clockwise) or 180. Any other value, including 0, returns
            ``pil_img`` unchanged.

    Returns:
        The rotated image; the canvas is enlarged (``expand=True``) so
        nothing is clipped.
    """
    # PIL's Image.rotate takes counter-clockwise degrees, hence the sign flip
    # for the quarter turns.
    pil_angle = {90: -90, -90: 90, 180: 180}.get(rotation)
    if pil_angle is None:
        return pil_img
    return pil_img.rotate(pil_angle, expand=True)
def calculate_piece_coordinates(polygon, bounds, plt_to_cm, rotation=0):
    """Compute a piece's placement on the layout canvas, in centimetres.

    Args:
        polygon: Piece outline. Only ``polygon.centroid.x``,
            ``polygon.centroid.y`` and ``polygon.bounds`` (``min_x, min_y,
            max_x, max_y``) are read.
        bounds: ``(min_x, min_y, max_x, max_y)`` of the whole canvas in PLT
            units. ``(min_x, min_y)`` becomes the origin of the output.
        plt_to_cm: Centimetres per PLT unit.
        rotation: Forced canvas rotation. With W/H the canvas width/height
            in cm, points are mapped as follows:

            - 90:  ``(x, y) -> (y, W - x)``
            - -90: ``(x, y) -> (H - y, x)``
            - 180: ``(x, y) -> (W - x, H - y)``

            Any other value leaves coordinates unrotated.

    Returns:
        A dict of values rounded to 2 decimals:

        - ``center_x_cm``/``center_y_cm``: the mapped centroid.
        - ``left_cm``/``top_cm``: the smallest x/y of the piece's mapped
          bounding box.
        - ``width_cm``/``height_cm``: the bounding-box size, swapped for
          +/-90 rotations.
    """
    min_x, min_y, max_x, max_y = bounds
    canvas_width = (max_x - min_x) * plt_to_cm
    canvas_height = (max_y - min_y) * plt_to_cm

    def to_canvas(x, y):
        # PLT point -> cm relative to the canvas origin, then apply rotation.
        x_cm = (x - min_x) * plt_to_cm
        y_cm = (y - min_y) * plt_to_cm
        if rotation == 90:
            return y_cm, canvas_width - x_cm
        if rotation == -90:
            return canvas_height - y_cm, x_cm
        if rotation == 180:
            return canvas_width - x_cm, canvas_height - y_cm
        return x_cm, y_cm

    centroid = polygon.centroid
    center_x, center_y = to_canvas(centroid.x, centroid.y)

    piece_min_x, piece_min_y, piece_max_x, piece_max_y = polygon.bounds
    # left/top come from the mapped bounding box, not `centroid - size / 2`:
    # the two differ for asymmetric pieces, whose centroid is not the box
    # centre. (Presumably the rendered image spans the bounding box - confirm
    # against PltReader's drawing code.) The rotations only permute/negate
    # axes, so two opposite corners map to two opposite corners.
    corner_a = to_canvas(piece_min_x, piece_min_y)
    corner_b = to_canvas(piece_max_x, piece_max_y)
    left_cm = min(corner_a[0], corner_b[0])
    top_cm = min(corner_a[1], corner_b[1])

    piece_width = (piece_max_x - piece_min_x) * plt_to_cm
    piece_height = (piece_max_y - piece_min_y) * plt_to_cm
    if rotation in (90, -90):
        piece_width, piece_height = piece_height, piece_width

    return {
        "center_x_cm": round(center_x, 2),
        "center_y_cm": round(center_y, 2),
        "left_cm": round(left_cm, 2),
        "top_cm": round(top_cm, 2),
        "width_cm": round(piece_width, 2),
        "height_cm": round(piece_height, 2)
    }
# ==================== API 接口 ====================
@app.get("/health")
async def health_check():
    """Health-check endpoint.

    Always answers with the same fixed JSON payload; no other part of the
    service is probed.
    """
    status_payload = {"status": "healthy", "service": "plt-processor"}
    return status_payload
# Rotation values accepted by the `rotation` form field of /process.
_ALLOWED_ROTATIONS = (0, 90, -90, 180)

# Area every outline is rescaled to before shape matching, so the Hausdorff
# cost compares shapes rather than absolute sizes.
_MATCH_TARGET_AREA = 10000


def _parse_size_labels(size_labels: str) -> list:
    """Decode the `size_labels` form field into a non-empty list.

    Raises:
        HTTPException: 400 if the field is not valid JSON or not a non-empty
            JSON array.
    """
    try:
        labels = json.loads(size_labels)
    except (ValueError, RecursionError):
        raise HTTPException(status_code=400, detail="size_labels 格式错误") from None
    # A JSON string (e.g. "SML") would otherwise be split into per-character
    # labels, and a number would crash later on len().
    if not isinstance(labels, list) or not labels:
        raise HTTPException(status_code=400, detail="size_labels 格式错误")
    return labels


def _normalize_shape(polygon):
    """Center `polygon` on its centroid and scale it to `_MATCH_TARGET_AREA`."""
    centroid = polygon.centroid
    centered = affinity.translate(polygon, xoff=-centroid.x, yoff=-centroid.y)
    factor = (_MATCH_TARGET_AREA / centered.area) ** 0.5
    return affinity.scale(centered, xfact=factor, yfact=factor, origin=(0, 0))


def _match_pieces_across_sizes(clusters) -> dict:
    """Match every top-level piece of size 0 with its counterpart in each other size.

    Cost is the minimum Hausdorff distance between normalized outlines over
    the four 90-degree rotations. Pairs are chosen with the Hungarian
    algorithm (`linear_sum_assignment`).

    Returns:
        ``{base_piece_index: [(size_idx, piece_index), ...]}``, starting with
        ``(0, base_piece_index)``. The size index is stored explicitly: on a
        non-square cost matrix some pieces stay unmatched, and a purely
        positional list would shift later sizes onto the wrong size labels.
    """
    base_cluster = [piece for piece in clusters[0] if piece["parent"] is None]
    # Normalization is deterministic, so compute it once per piece instead of
    # once per (i, j) pair.
    base_shapes = [_normalize_shape(piece["data"]) for piece in base_cluster]
    match_result = {piece["index"]: [(0, piece["index"])] for piece in base_cluster}

    for size_idx in range(1, len(clusters)):
        compare_cluster = [piece for piece in clusters[size_idx] if piece["parent"] is None]
        compare_variants = []
        for piece in compare_cluster:
            shape = _normalize_shape(piece["data"])
            compare_variants.append(
                [affinity.rotate(shape, angle, origin='centroid') for angle in (0, 90, 180, 270)]
            )

        cost_matrix = np.zeros((len(base_cluster), len(compare_cluster)))
        for i, base_shape in enumerate(base_shapes):
            for j, variants in enumerate(compare_variants):
                cost_matrix[i, j] = min(base_shape.hausdorff_distance(v) for v in variants)

        row_ind, col_ind = linear_sum_assignment(cost_matrix)
        for i, j in zip(row_ind, col_ind):
            match_result[base_cluster[i]["index"]].append((size_idx, compare_cluster[j]["index"]))

        if len(compare_cluster) != len(base_cluster):
            print(
                f"[PLT] 警告: 尺码 {size_idx} 裁片数量 ({len(compare_cluster)}) "
                f"与基准尺码 ({len(base_cluster)}) 不一致, 部分裁片未匹配"
            )
    return match_result


def _encode_piece_image(output, node, scale_factor: float, dpi: int, rotation: int):
    """Render `node` plus its children and return ``(data_uri, width_px, height_px)``.

    The pixel size is measured after the rotation is applied. Drawing relies
    on the parser output's private `_draw_nodes` helper.
    """
    img = output._draw_nodes([node] + node['child'], scale_factor, show_id=False)
    pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA))
    pil_img = apply_rotation_to_image(pil_img, rotation)
    buffer = io.BytesIO()
    pil_img.save(buffer, format='PNG', dpi=(dpi, dpi))
    base64_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return f"data:image/png;base64,{base64_str}", pil_img.width, pil_img.height


@app.post("/process", response_model=ProcessPltResponse)
async def process_plt(
    file: UploadFile = File(..., description="PLT 文件"),
    size_labels: str = Form(..., description="尺码标签 JSON 数组"),
    dpi: int = Form(150, description="输出图片分辨率DPI"),
    rotation: int = Form(0, description="强制旋转角度 (0/90/-90/180)"),
):
    """
    Process a PLT file into per-size cut-piece images grouped by pattern piece.

    Steps:

    1. Parse the PLT file and split its pieces into one cluster per size.
    2. Match each piece of the first size with its counterpart in every
       other size.
    3. For each match, return a PNG image plus placement coordinates.

    Invalid parameters give HTTP 400; any other failure gives HTTP 500.

    NOTE(review): all parsing, matching and rendering is synchronous CPU work
    inside an async endpoint. It blocks the event loop (including /health)
    while it runs. Consider offloading it to a thread pool once PltReader's
    thread-safety is confirmed.
    """
    try:
        # 1. Validate parameters; bad input is a client error (400).
        size_labels_list = _parse_size_labels(size_labels)
        if dpi <= 0:
            raise HTTPException(status_code=400, detail="dpi 必须为正整数")
        if rotation not in _ALLOWED_ROTATIONS:
            raise HTTPException(status_code=400, detail="rotation 仅支持 0/90/-90/180")
        size_num = len(size_labels_list)
        # PLT coordinates are treated as 1016 units per inch: convert them to
        # pixels at `dpi` and to centimetres (2.54 cm per inch).
        scale_factor = dpi / 1016
        plt_to_cm = 2.54 / 1016

        # 2. Read the upload; undecodable bytes are dropped.
        file_content = (await file.read()).decode('utf-8', errors='ignore')

        # 3. Parse the PLT geometry.
        print(f"[PLT] 正在处理文件: {file.filename}")
        output = parse_plt_file(file_content)

        # 4. Split the pieces into one cluster per size.
        clusters = output.get_single_size_info(size_num=size_num)

        # 5. Match pieces across sizes.
        match_result = _match_pieces_across_sizes(clusters)

        # 6. Render images and compute placement for every matched piece.
        all_nodes = {node["index"]: node for node in output.nodes}
        from shapely.geometry import MultiPolygon as ShapelyMultiPolygon
        all_polygons = [node["data"] for node in output.nodes]
        global_bounds = ShapelyMultiPolygon(all_polygons).bounds

        groups = []
        for group_id, matches in enumerate(match_result.values(), start=1):
            pieces = []
            for size_idx, piece_index in matches:
                node = all_nodes[piece_index]
                image_base64, width_px, height_px = _encode_piece_image(
                    output, node, scale_factor, dpi, rotation
                )
                coords = calculate_piece_coordinates(
                    node["data"],
                    global_bounds,
                    plt_to_cm,
                    rotation
                )
                pieces.append(
                    PieceInfo(
                        size=size_labels_list[size_idx],
                        image_base64=image_base64,
                        width_px=width_px,
                        height_px=height_px,
                        **coords
                    )
                )
            groups.append(GroupInfo(group_id=group_id, pieces=pieces))

        print(f"[PLT] 处理完成,共 {len(groups)} 组裁片")
        return ProcessPltResponse(
            success=True,
            total_groups=len(groups),
            groups=groups
        )
    except HTTPException:
        # Deliberate HTTP errors (the 400s above) must reach the client
        # unchanged instead of being rewrapped as a 500 below.
        raise
    except Exception as e:
        print(f"[PLT] 处理失败: {str(e)}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}") from e
if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn on all interfaces.
    import uvicorn

    # The listening port comes from $PORT when set, otherwise 8080.
    listen_port = int(os.environ.get("PORT", "8080"))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)