from fastapi import APIRouter, Depends, HTTPException, status, Query, Header, BackgroundTasks
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session
from sqlalchemy import desc, func
from typing import List, Optional
import os
import mimetypes
from urllib.parse import urlparse

from app.core.database import get_db
from app.models.work import Work
from app.models.order import Order, OrderStatus
from app.models.user import User
from app.core.security import decode_access_token
from app.core.config import settings
from app.schemas.work import WorkResponse, WorkListResponse
from app.services.download_tracker import notify_download

router = APIRouter(prefix="/works", tags=["作品"])


def resolve_upload_file_path(image_path: str) -> str:
    """Map a stored image path/URL to a file path under ``settings.UPLOAD_DIR``.

    Two historical storage formats are accepted:

    * a relative path such as ``/uploads/a/b.jpg`` or ``uploads/a/b.jpg``;
    * a full ``http(s)://`` URL, of which only the path component is used.

    Args:
        image_path: The value stored on the model; may be empty or ``None``.

    Returns:
        ``os.path.join(settings.UPLOAD_DIR, <relative part>)``, or ``""`` when
        the input is empty, normalizes to nothing (which would otherwise point
        at the upload directory itself), or would resolve outside the upload
        directory (e.g. via ``..`` segments).
    """
    if not image_path:
        return ""
    normalized = image_path.strip()
    if normalized.startswith(("http://", "https://")):
        # Only the URL path maps to disk; scheme, host and query are dropped.
        normalized = urlparse(normalized).path or ""
    normalized = normalized.lstrip("/")
    if normalized.startswith("uploads/"):
        # The "/uploads" URL prefix corresponds to UPLOAD_DIR itself.
        normalized = normalized[len("uploads/"):]
    if not normalized:
        return ""

    full_path = os.path.join(settings.UPLOAD_DIR, normalized)

    # Refuse paths that escape UPLOAD_DIR (e.g. "../../etc/passwd") or that
    # collapse back onto the directory itself (e.g. ".").
    upload_root = os.path.abspath(settings.UPLOAD_DIR)
    resolved = os.path.abspath(full_path)
    try:
        inside = (
            resolved != upload_root
            and os.path.commonpath([upload_root, resolved]) == upload_root
        )
    except ValueError:
        # commonpath raises for paths on different drives (Windows).
        inside = False
    return full_path if inside else ""


@router.get("", response_model=WorkListResponse, summary="获取作品列表")
def get_works(
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    category: Optional[str] = None,
    keyword: Optional[str] = None,
    db: Session = Depends(get_db)
):
    """
    List works, newest first, with pagination.

    - page: 1-based page number
    - page_size: items per page (1-100)
    - category: exact-match category filter
    - keyword: substring search on the title
    """
    query = db.query(Work)

    if category:
        query = query.filter(Work.category == category)

    if keyword:
        # autoescape=True makes "%" and "_" in user input match literally
        # instead of acting as LIKE wildcards.
        query = query.filter(Work.title.contains(keyword, autoescape=True))

    # Count before applying offset/limit so the total covers all matches.
    total = query.count()

    offset = (page - 1) * page_size
    works = query.order_by(desc(Work.created_at)).offset(offset).limit(page_size).all()

    # Normalize NULL level_text to "" for the response schema. This mutates
    # the ORM instances; nothing in this handler commits afterwards.
    for work in works:
        if work.level_text is None:
            work.level_text = ""

    return WorkListResponse(
        total=total,
        items=works
    )


@router.get("/{work_id}", response_model=WorkResponse, summary="获取作品详情")
def get_work(work_id: int, db: Session = Depends(get_db)):
    """Return a single work by id and increment its view counter.

    Raises:
        HTTPException: 404 if no work with ``work_id`` exists.
    """
    work = db.query(Work).filter(Work.id == work_id).first()
    if not work:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="作品不存在"
        )

    # Increment in SQL (UPDATE ... SET views = COALESCE(views, 0) + 1) so
    # concurrent requests don't lose updates, and a NULL counter starts at 1.
    # SQLAlchemy expires the attribute after the flush, so the response
    # reloads the new value.
    work.views = func.coalesce(Work.views, 0) + 1
    db.commit()

    # Normalize NULL level_text to "" for the response schema.
    if work.level_text is None:
        work.level_text = ""

    return work


def get_current_user(authorization: str = Header(None), db: Session = Depends(get_db)):
    """Resolve the authenticated user from an ``Authorization: Bearer <token>`` header.

    Used as a dependency by the download endpoint.

    Raises:
        HTTPException: 401 if the header is missing or malformed, the token
            cannot be decoded, its ``sub`` claim is not an integer id, or no
            user with that id exists.
    """
    bearer_prefix = "Bearer "
    if not authorization or not authorization.startswith(bearer_prefix):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="未登录,请先登录"
        )

    # Strip only the leading scheme; str.replace would also remove any later
    # occurrence of "Bearer " inside the token.
    token = authorization[len(bearer_prefix):].strip()
    payload = decode_access_token(token)
    if not payload:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token 无效或已过期"
        )

    # A missing or non-numeric "sub" claim is an invalid token (401), not a
    # server error (500).
    try:
        user_id = int(payload.get("sub"))
    except (TypeError, ValueError):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token 无效或已过期"
        ) from None

    user = db.query(User).filter(User.id == user_id).first()
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="用户不存在"
        )
    return user


@router.get("/{work_id}/download", summary="下载作品原图")
def download_work(
    work_id: int,
    background_tasks: BackgroundTasks,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """
    Download the original image of a work.

    - Requires an authenticated user.
    - Requires an order for this work by this user with status PAID.

    Raises:
        HTTPException: 404 if the work or its file is missing, 403 if the
            user has no paid order for it.
    """
    work = db.query(Work).filter(Work.id == work_id).first()
    if not work:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="作品不存在"
        )

    # The user must hold a paid order for this specific work.
    paid_order = db.query(Order).filter(
        Order.user_id == current_user.id,
        Order.work_id == work_id,
        Order.status == OrderStatus.PAID
    ).first()

    if not paid_order:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="您还未购买此作品,请先完成支付"
        )

    full_path = resolve_upload_file_path(work.original_image)

    # isfile (not exists): an empty or directory path must not reach
    # FileResponse, which can only stream regular files.
    if not full_path or not os.path.isfile(full_path):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="文件不存在"
        )

    # NOTE(review): these ORM instances are handed to a task that runs after
    # the response. If get_db closes the session before background tasks run,
    # notify_download must not lazy-load attributes/relationships — confirm.
    background_tasks.add_task(notify_download, work, current_user, paid_order)

    filename = os.path.basename(full_path)
    media_type = mimetypes.guess_type(filename)[0] or "application/octet-stream"
    return FileResponse(
        path=full_path,
        filename=filename,
        media_type=media_type
    )