from fastapi import APIRouter, Depends, HTTPException, status, Header
from sqlalchemy.orm import Session
from datetime import datetime, timedelta
import secrets

from app.core.database import get_db
from app.core.security import decode_access_token
from app.models.user import User
from app.models.work import Work
from app.models.order import Order, OrderStatus
from app.models.download import DownloadRecord
from app.schemas.order import OrderCreate, OrderResponse, PaymentResponse
from app.services.download_tracker import record_download

router = APIRouter(prefix="/orders", tags=["订单"])

# Scheme prefix expected at the start of the Authorization header.
_BEARER_PREFIX = "Bearer "


def _generate_order_no(db: Session) -> str:
    """Generate an order number that is not yet used by any Order row.

    Format: ``ORD`` + date (YYYYMMDD) + current time (HHMMSS) + 6 uppercase
    hex characters. The random suffix avoids collisions between orders
    created within the same second.

    The date component is intentionally the date 180 days before today
    rather than today (this is the documented scheme at the call site in
    ``create_order``).

    Raises:
        HTTPException: 500 if no unused number is found after 5 attempts.
    """
    now = datetime.now()
    six_months_ago = now - timedelta(days=180)
    date_part = six_months_ago.strftime('%Y%m%d')
    time_part = now.strftime('%H%M%S')
    for _ in range(5):
        # token_hex(3) yields 3 random bytes -> 6 hex characters.
        suffix = secrets.token_hex(3).upper()
        order_no = f"ORD{date_part}{time_part}{suffix}"
        # NOTE(review): this check-then-insert is not atomic; only a unique
        # constraint on Order.order_no (not visible here) fully rules out
        # duplicates under concurrent requests.
        exists = db.query(Order.id).filter(Order.order_no == order_no).first()
        if not exists:
            return order_no
    raise HTTPException(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        detail="生成订单号失败,请稍后重试",
    )


def get_current_user(authorization: str = Header(None), db: Session = Depends(get_db)):
    """Resolve the logged-in user from an ``Authorization: Bearer <token>`` header.

    Returns:
        The ``User`` whose id is the token's ``sub`` claim.

    Raises:
        HTTPException: 401 if the header is missing or malformed, the token is
            invalid or expired, the ``sub`` claim is missing or not an integer,
            or no user with that id exists.
    """
    if not authorization or not authorization.startswith(_BEARER_PREFIX):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="未登录"
        )

    # Strip only the leading scheme; str.replace would also delete any
    # later "Bearer " occurrence inside the token itself.
    token = authorization[len(_BEARER_PREFIX):]
    payload = decode_access_token(token)
    if not payload:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token 无效或已过期"
        )

    # A missing (None) or non-numeric "sub" claim would otherwise raise
    # TypeError/ValueError and surface as a 500; treat it as an invalid token.
    try:
        user_id = int(payload.get("sub"))
    except (TypeError, ValueError):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token 无效或已过期"
        ) from None

    user = db.query(User).filter(User.id == user_id).first()
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="用户不存在"
        )
    return user


@router.post("/create", response_model=OrderResponse, summary="创建订单")
def create_order(
    order_data: OrderCreate,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """创建下载订单"""
    # The docstring above is kept verbatim: FastAPI publishes it as the
    # endpoint's OpenAPI description ("create a download order").
    # Flow: 404 if the work does not exist, 400 if the user already has a
    # PAID order for it, otherwise persist a PENDING order priced at
    # work.price and return it.
    work = db.query(Work).filter(Work.id == order_data.work_id).first()
    if not work:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="作品不存在"
        )

    # Reject a second purchase of the same work. Only PAID orders count, so
    # several PENDING orders for the same work may coexist.
    existing_order = db.query(Order).filter(
        Order.user_id == current_user.id,
        Order.work_id == work.id,
        Order.status == OrderStatus.PAID
    ).first()
    if existing_order:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="您已购买过此作品"
        )

    # Unique order number: date from 180 days ago + current HHMMSS + random suffix.
    order_no = _generate_order_no(db)

    new_order = Order(
        order_no=order_no,
        user_id=current_user.id,
        work_id=work.id,
        amount=work.price,
        payment_method=order_data.payment_method,
        status=OrderStatus.PENDING
    )
    db.add(new_order)
    db.commit()
    db.refresh(new_order)
    return new_order


@router.post("/pay/{order_id}", response_model=PaymentResponse, summary="支付订单")
def pay_order(
    order_id: int,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """
    支付订单(模拟支付)
    实际生产环境需要对接真实支付接口
    """
    # The docstring above is kept verbatim because it is the OpenAPI
    # description ("pay an order (simulated); production needs a real
    # payment gateway"). Only balance payment is implemented.

    # Only the caller's own orders are visible; others look like 404.
    order = db.query(Order).filter(
        Order.id == order_id,
        Order.user_id == current_user.id
    ).first()
    if not order:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="订单不存在"
        )

    if order.status != OrderStatus.PENDING:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="订单状态异常"
        )

    # Any method other than "balance" (e.g. Alipay / WeChat) is not
    # supported: report failure without touching the order.
    if order.payment_method != "balance":
        return PaymentResponse(
            success=False,
            message="暂不支持该支付方式,请使用余额支付或联系管理员",
            order_no=order.order_no
        )

    # Resolve the work BEFORE charging. Previously a missing work left the
    # user charged and the order committed as PAID, then crashed on
    # `work.id` while building the response.
    work = db.query(Work).filter(Work.id == order.work_id).first()
    if not work:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="作品不存在"
        )

    if current_user.balance < order.amount:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="余额不足"
        )

    # NOTE(review): the balance check and deduction are not row-locked, so
    # concurrent requests may both pass the check; consider locking the user
    # and order rows (SELECT ... FOR UPDATE) on databases that support it.
    current_user.balance -= order.amount
    order.status = OrderStatus.PAID
    order.paid_at = datetime.now()
    record_download(db, order, work, current_user)
    db.commit()

    return PaymentResponse(
        success=True,
        message="支付成功",
        order_no=order.order_no,
        download_url=f"/api/works/{work.id}/download"
    )


@router.get("/my", summary="我的订单")
def get_my_orders(
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """获取当前用户的订单列表"""
    # Docstring kept verbatim (OpenAPI description: "list the current user's
    # orders"). Returns every Order row of the caller, in no explicit order.
    orders = db.query(Order).filter(Order.user_id == current_user.id).all()
    return orders