273 lines
9.5 KiB
Python
273 lines
9.5 KiB
Python
from fastapi import APIRouter, Depends, HTTPException, status, UploadFile, File, Form, Header, Query
|
||
from sqlalchemy.orm import Session
|
||
from sqlalchemy import desc, or_
|
||
from typing import Optional
|
||
import os
|
||
import uuid
|
||
import shutil
|
||
from datetime import datetime
|
||
from PIL import Image
|
||
from app.core.database import get_db
|
||
from app.models.work import Work
|
||
from app.models.user import User
|
||
from app.core.security import decode_access_token
|
||
|
||
# Router for the upload endpoints; every route below is mounted under /upload.
router = APIRouter(prefix="/upload", tags=["上传"])

# Upload configuration
# Base directory under which the original/thumbnail/watermarked folders are created.
UPLOAD_BASE_DIR = "/app/uploads"
# Lower-cased file extensions accepted by the upload endpoint.
ALLOWED_EXTENSIONS = {".jpg", ".jpeg", ".png", ".gif", ".webp"}
MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB

# Tuhui site domain, prefixed to stored image paths to build the URLs returned to clients.
TUHUI_DOMAIN = "https://tuhui.cloud"
|
||
|
||
|
||
def get_current_user(authorization: str = Header(None), db: Session = Depends(get_db)):
    """Resolve the logged-in user from the ``Authorization`` request header.

    Args:
        authorization: Raw ``Authorization`` header value; it must start with
            ``"Bearer "`` followed by the access token.
        db: Database session used to look up the user row.

    Returns:
        The ``User`` whose id equals the token's ``sub`` claim.

    Raises:
        HTTPException: 401 when the header is missing or not a Bearer header,
            when the token cannot be decoded, when the ``sub`` claim is
            missing or not an integer, or when no matching user exists.
    """
    scheme = "Bearer "
    if not authorization or not authorization.startswith(scheme):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="未登录,请先登录"
        )

    # Strip only the leading scheme. str.replace() would also delete any later
    # "Bearer " occurrence inside the header value and corrupt the token.
    token = authorization[len(scheme):]
    payload = decode_access_token(token)
    if not payload:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token 无效或已过期"
        )

    # A token without a usable numeric "sub" claim is an invalid token, not a
    # server error, so report it as 401 instead of letting int() raise.
    try:
        user_id = int(payload.get("sub"))
    except (TypeError, ValueError):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token 无效或已过期"
        ) from None

    user = db.query(User).filter(User.id == user_id).first()
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="用户不存在"
        )

    return user
|
||
|
||
|
||
def _user_designer_aliases(user: User) -> list[str]:
|
||
aliases = []
|
||
for value in (getattr(user, "nickname", None), getattr(user, "phone", None)):
|
||
cleaned = str(value or "").strip()
|
||
if cleaned and cleaned not in aliases:
|
||
aliases.append(cleaned)
|
||
return aliases
|
||
|
||
|
||
def _generate_fallback_title(category: str, designer_name: str, unique_id: str) -> str:
|
||
category_name = str(category or "设计素材").strip() or "设计素材"
|
||
designer = str(designer_name or "").strip()
|
||
suffix = unique_id.replace("-", "")[:6]
|
||
if designer:
|
||
return f"{category_name}_{designer}_{suffix}"
|
||
return f"{category_name}_自动生成_{suffix}"
|
||
|
||
|
||
def generate_thumbnail(image_path: str, thumb_path: str, size=(400, 400)):
    """Write a downscaled RGB copy of an image to ``thumb_path``.

    The image is rotated according to its EXIF orientation, flattened onto a
    white background if it carries transparency, converted to RGB, shrunk in
    place to fit within ``size`` (aspect ratio preserved), and saved with
    ``quality=85``.

    Args:
        image_path: Path of the source image to read.
        thumb_path: Destination path for the thumbnail.
        size: Maximum ``(width, height)`` of the thumbnail.
    """
    # Function-scope import, matching how add_watermark pulls in PIL helpers.
    from PIL import ImageOps

    with Image.open(image_path) as img:
        # Bake the EXIF orientation into the pixels: the tag is not carried
        # over to the re-encoded file, so without this step photos taken on a
        # rotated camera would produce sideways thumbnails.
        img = ImageOps.exif_transpose(img)

        # Palette images may hide transparency; expand them so the alpha
        # channel (if any) can be used as a paste mask below.
        if img.mode == 'P':
            img = img.convert('RGBA')

        if img.mode in ('RGBA', 'LA'):
            # Flatten onto white so transparent areas do not turn black when
            # the result is stored in a format without alpha (e.g. JPEG).
            background = Image.new('RGB', img.size, (255, 255, 255))
            background.paste(img, mask=img.split()[-1])
            img = background
        elif img.mode != 'RGB':
            img = img.convert('RGB')

        img.thumbnail(size, Image.Resampling.LANCZOS)
        img.save(thumb_path, quality=85)
|
||
|
||
|
||
def add_watermark(image_path: str, watermarked_path: str, watermark_text: str = "图绘"):
    """Save a copy of an image with a semi-transparent text watermark.

    The image is rotated according to its EXIF orientation, flattened onto a
    white background if it carries transparency, and converted to RGB. The
    text is then drawn in white at 50% opacity, 20 px from the bottom-right
    corner, and the result is saved with ``quality=90``.

    Args:
        image_path: Path of the source image to read.
        watermarked_path: Destination path for the watermarked copy.
        watermark_text: Text to draw onto the image.
    """
    from PIL import ImageDraw, ImageFont, ImageOps

    with Image.open(image_path) as img:
        # Bake the EXIF orientation into the pixels so the watermark lands in
        # the corner the viewer actually sees as bottom-right.
        img = ImageOps.exif_transpose(img)

        # Palette images may hide transparency; expand them so the alpha
        # channel (if any) can be used as a paste mask below.
        if img.mode == 'P':
            img = img.convert('RGBA')

        if img.mode in ('RGBA', 'LA'):
            # Flatten onto white so transparent areas do not turn black when
            # the result is stored in a format without alpha (e.g. JPEG).
            background = Image.new('RGB', img.size, (255, 255, 255))
            background.paste(img, mask=img.split()[-1])
            img = background
        elif img.mode != 'RGB':
            img = img.convert('RGB')

        width, height = img.size

        # NOTE(review): DejaVu Sans probably has no CJK glyphs, so the default
        # watermark text may render as placeholder boxes - confirm and point
        # this at a CJK-capable font if so.
        try:
            font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 36)
        except (OSError, ImportError):
            # Font file missing/unreadable or FreeType support unavailable:
            # fall back to Pillow's built-in font instead of failing the upload.
            font = ImageFont.load_default()

        # Draw on a transparent RGBA layer and composite it afterwards. Drawing
        # straight onto the RGB image would ignore the alpha component of the
        # fill colour and produce fully opaque text.
        overlay = Image.new('RGBA', img.size, (255, 255, 255, 0))
        draw = ImageDraw.Draw(overlay)

        bbox = draw.textbbox((0, 0), watermark_text, font=font)
        text_width = bbox[2] - bbox[0]
        text_height = bbox[3] - bbox[1]

        # Clamp to the canvas so small images do not push the text to
        # negative coordinates.
        x = max(0, width - text_width - 20)
        y = max(0, height - text_height - 20)

        draw.text((x, y), watermark_text, fill=(255, 255, 255, 128), font=font)

        img = Image.alpha_composite(img.convert('RGBA'), overlay).convert('RGB')
        img.save(watermarked_path, quality=90)
|
||
|
||
|
||
def _remove_files_quietly(paths):
    """Delete every existing path in *paths*, ignoring filesystem errors."""
    for path in paths:
        try:
            if os.path.exists(path):
                os.remove(path)
        except OSError:
            # Best-effort cleanup: a failed delete must not replace the error
            # that triggered the cleanup in the first place.
            pass


@router.post("", summary="上传作品")
async def upload_work(
    file: UploadFile = File(..., description="作品图片文件"),
    title: Optional[str] = Form(None, description="作品标题"),
    description: Optional[str] = Form(None, description="作品描述"),
    category: str = Form(..., description="作品分类"),
    tags: Optional[str] = Form(None, description="标签,逗号分隔"),
    designer_name: Optional[str] = Form(None, description="归属设计师"),
    price: float = Form(..., ge=0, description="作品价格"),
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """
    Upload a work.

    **Login required**: yes (Bearer token)

    **Supported formats**: jpg, jpeg, png, gif, webp

    **File size**: at most 50MB

    Stores the original file, a thumbnail and a watermarked copy under
    date-stamped folders, inserts a ``Work`` row, and returns a dict with
    ``success``, ``message``, ``work_id``, ``image_url``, ``thumbnail_url``
    and ``watermarked_url``.

    Raises ``HTTPException`` 400 for an unsupported extension or an oversized
    file, and 500 when saving, image processing or the database write fails
    (any files written so far are removed and the session is rolled back).
    """
    # Validate the extension. The filename can be absent on some multipart
    # requests; treat that as "no extension" rather than crashing in splitext.
    file_ext = os.path.splitext(file.filename or "")[1].lower()
    if file_ext not in ALLOWED_EXTENSIONS:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            # sorted() keeps the message stable; set iteration order is not.
            detail=f"不支持的文件格式,仅支持:{', '.join(sorted(ALLOWED_EXTENSIONS))}"
        )

    # Measure the upload by seeking to the end of the underlying file object,
    # then rewind so the content can be copied from the start.
    file.file.seek(0, 2)
    file_size = file.file.tell()
    file.file.seek(0)

    if file_size > MAX_FILE_SIZE:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"文件过大,最大支持 {MAX_FILE_SIZE // 1024 // 1024}MB"
        )

    # Unique base name plus a per-day folder name.
    unique_id = str(uuid.uuid4())
    timestamp = datetime.now().strftime("%Y%m%d")

    # Make sure the three per-day target directories exist.
    upload_dir = os.path.join(UPLOAD_BASE_DIR, "original", timestamp)
    thumb_dir = os.path.join(UPLOAD_BASE_DIR, "thumbnail", timestamp)
    watermarked_dir = os.path.join(UPLOAD_BASE_DIR, "watermarked", timestamp)
    for directory in (upload_dir, thumb_dir, watermarked_dir):
        os.makedirs(directory, exist_ok=True)

    # Target file names and paths.
    original_filename = f"{unique_id}{file_ext}"
    original_path = os.path.join(upload_dir, original_filename)
    thumb_filename = f"{unique_id}_thumb.jpg"
    thumb_path = os.path.join(thumb_dir, thumb_filename)
    watermarked_filename = f"{unique_id}_watermarked.jpg"
    watermarked_path = os.path.join(watermarked_dir, watermarked_filename)

    # Normalise the comma-separated tags into a comma-joined string, because
    # the Work row is given a string here rather than a list.
    if tags:
        tags_str = ",".join(tag.strip() for tag in tags.split(",") if tag.strip())
    else:
        tags_str = None

    # NOTE(review): designer_name is client-supplied and get_my_uploads matches
    # works by this string, so a caller could file a work under another user's
    # nickname/phone - confirm this is intended.
    resolved_designer = str(designer_name or "").strip() or current_user.nickname or current_user.phone
    resolved_title = str(title or "").strip() or _generate_fallback_title(category, resolved_designer, unique_id)

    try:
        # NOTE(review): the file copy and the image processing below are
        # synchronous calls inside an async endpoint - consider moving them
        # to a worker thread if large uploads stall other requests.

        # Save the original file.
        with open(original_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        # Derive the thumbnail and the watermarked copy from the saved original.
        generate_thumbnail(original_path, thumb_path)
        add_watermark(original_path, watermarked_path)

        # Create the database record.
        work = Work(
            title=resolved_title,
            description=description or "",
            category=category,
            tags=tags_str,
            price=price,
            designer=resolved_designer,
            original_image=f"/uploads/original/{timestamp}/{original_filename}",
            thumbnail_image=f"/uploads/thumbnail/{timestamp}/{thumb_filename}",
            watermarked_image=f"/uploads/watermarked/{timestamp}/{watermarked_filename}"
        )

        db.add(work)
        db.commit()
        db.refresh(work)
    except Exception as e:
        # Undo the pending transaction so the session is not left in a failed
        # state, then remove whatever files were already written.
        db.rollback()
        _remove_files_quietly([original_path, thumb_path, watermarked_path])

        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"上传失败:{str(e)}"
        ) from e

    # Build the absolute image URLs returned to the client.
    image_url = f"{TUHUI_DOMAIN}{work.original_image}"
    thumbnail_url = f"{TUHUI_DOMAIN}{work.thumbnail_image}"
    watermarked_url = f"{TUHUI_DOMAIN}{work.watermarked_image}"

    return {
        "success": True,
        "message": "上传成功,等待审核",
        "work_id": work.id,
        "image_url": image_url,
        "thumbnail_url": thumbnail_url,
        "watermarked_url": watermarked_url
    }
|
||
|
||
|
||
@router.get("/my", summary="我的上传")
def get_my_uploads(
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Return one page of the works attributed to the current user.

    Works are matched by comparing ``Work.designer`` with the user's aliases
    (see ``_user_designer_aliases``) and are ordered newest first. The
    response carries ``total`` (all matching works), ``page``, ``page_size``
    and ``items`` (the works on the requested page).
    """
    names = _user_designer_aliases(current_user)
    if names:
        ownership = or_(*(Work.designer == name for name in names))
    else:
        # No usable alias: filter on an id that should match no row
        # (assuming ids are never -1) so the result set is empty.
        ownership = Work.id == -1
    scoped = db.query(Work).filter(ownership)

    skip = (page - 1) * page_size
    newest_first = scoped.order_by(desc(Work.created_at))
    rows = newest_first.offset(skip).limit(page_size).all()

    return {
        "total": scoped.count(),
        "page": page,
        "page_size": page_size,
        "items": rows
    }
|