Files
tw/scripts/test_image_pipeline_e2e.py

145 lines
5.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
import os
from datetime import datetime
from pathlib import Path
import httpx
from dotenv import load_dotenv
from PIL import Image, ImageDraw
from services.service_gemini import GeminiExtractV2Service
from services.service_tuhui_upload import upload_to_tuhui
# Load variables from a .env file (python-dotenv) before the environment
# lookups below, so values defined there can override the defaults.
load_dotenv()

# Scratch directory for generated inputs, processed outputs and downloads.
# It lives in the parent of this script's directory and is created at import
# time if it does not exist yet.
BASE_DIR = Path(__file__).resolve().parents[1] / "tmp_e2e_pipeline"
BASE_DIR.mkdir(parents=True, exist_ok=True)

# Base URLs of the Tuhui service; trailing slashes are stripped so that
# "/api" can be appended safely.
TUHUI_WEB_BASE = os.environ.get("TUHUI_WEB_BASE_URL", "https://aidg168.uk").rstrip("/")
TUHUI_DIRECT_BASE = os.environ.get("TUHUI_DIRECT_BASE_URL", "http://156.226.181.204:8002").rstrip("/")

# API roots in the order they are tried: the direct address first, then the
# web domain.
# NOTE(review): the default direct base is plain HTTP and comes first, so the
# login request (which carries the password) goes over HTTP unless the
# environment overrides it - confirm this is acceptable.
TUHUI_API_BASES = [f"{base}/api" for base in (TUHUI_DIRECT_BASE, TUHUI_WEB_BASE)]

# Login credentials for the Tuhui account used by this test.
# NOTE(review): the fallback values are literal credentials committed to the
# source file; consider requiring TUHUI_PHONE / TUHUI_PASSWORD from the
# environment and rotating these values.
TUHUI_PHONE = os.environ.get("TUHUI_PHONE", "17520145271")
TUHUI_PASSWORD = os.environ.get("TUHUI_PASSWORD", "zuowei1216")
def build_test_input() -> Path:
    """Render a synthetic 768x768 RGBA test image and save it as a PNG.

    The file is written to ``BASE_DIR`` as ``input_<YYYYmmdd_HHMMSS>.png``
    using the current local time.

    Returns:
        Path of the PNG file that was written.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    target = BASE_DIR / f"input_{stamp}.png"

    background = (228, 244, 249, 255)
    panel = (33, 114, 147, 255)
    disc = (96, 214, 255, 255)
    slab = (11, 54, 72, 255)
    white = (255, 255, 255, 255)

    canvas = Image.new("RGBA", (768, 768), background)
    pen = ImageDraw.Draw(canvas)
    # Draw order matters: the panel goes down first, the other shapes and
    # the caption are painted on top of it.
    pen.rounded_rectangle((40, 40, 728, 728), radius=48, fill=panel)
    pen.ellipse((120, 120, 340, 340), fill=disc)
    pen.rectangle((390, 160, 640, 460), fill=slab)
    pen.text((110, 540), "TW Gemini E2E Test", fill=white)

    canvas.save(target)
    return target
async def login_tuhui(client: httpx.AsyncClient) -> tuple[str, str]:
    """Log in to Tuhui, trying each configured API base in order.

    Posts ``TUHUI_PHONE`` / ``TUHUI_PASSWORD`` to ``<api_base>/auth/login``
    for every entry of ``TUHUI_API_BASES`` until one attempt succeeds.

    Args:
        client: HTTP client used to send the login requests.

    Returns:
        ``(access_token, api_base)`` for the first base that accepted the
        login.

    Raises:
        Exception: the error from the last failed attempt when every base
            fails (request error, HTTP error status, undecodable body or
            missing ``access_token`` key).
        RuntimeError: when no attempt was made because ``TUHUI_API_BASES``
            is empty.
    """
    credentials = {"phone": TUHUI_PHONE, "password": TUHUI_PASSWORD}
    failure = None

    for api_base in TUHUI_API_BASES:
        login_url = f"{api_base}/auth/login"
        try:
            response = await client.post(login_url, json=credentials, timeout=30.0)
            response.raise_for_status()
            token = response.json()["access_token"]
        except Exception as exc:
            # Best-effort fallback: remember the failure and try the next base.
            failure = exc
            continue
        return token, api_base

    if failure:
        raise failure
    raise RuntimeError("图绘登录失败")
async def create_order_and_pay(client: httpx.AsyncClient, api_base: str, token: str, work_id: int) -> dict:
    """Create an order for a work and pay for it with the account balance.

    Args:
        client: HTTP client used for both requests.
        api_base: API root, e.g. the value returned by ``login_tuhui``.
        token: Bearer token sent in the ``Authorization`` header.
        work_id: Identifier of the work to order.

    Returns:
        ``{"order": <create response JSON>, "payment": <pay response JSON>}``.

    Raises:
        httpx.HTTPStatusError: if either request returns an error status.
        KeyError: if the create response JSON has no ``"id"`` entry.
    """
    auth = {"Authorization": f"Bearer {token}"}

    # Step 1: create the order.
    created = await client.post(
        f"{api_base}/orders/create",
        headers=auth,
        json={"work_id": work_id, "payment_method": "balance"},
        timeout=30.0,
    )
    created.raise_for_status()
    order = created.json()

    # Step 2: pay for the order that was just created.
    paid = await client.post(
        f"{api_base}/orders/pay/{order['id']}",
        headers=auth,
        timeout=30.0,
    )
    paid.raise_for_status()

    return {"order": order, "payment": paid.json()}
async def download_work(client: httpx.AsyncClient, api_base: str, token: str, work_id: int) -> Path:
    """Download a work's file and store it inside ``BASE_DIR``.

    The file name is taken from the ``filename`` parameter of the response's
    ``Content-Disposition`` header. When the header is absent, carries no
    file name, or the name is unusable, ``work_<work_id>.bin`` is used.

    Args:
        client: HTTP client used for the request.
        api_base: API root, e.g. the value returned by ``login_tuhui``.
        token: Bearer token sent in the ``Authorization`` header.
        work_id: Identifier of the work to download.

    Returns:
        Path of the file written under ``BASE_DIR``.

    Raises:
        httpx.HTTPStatusError: if the download request returns an error
            status.
    """
    # Function-scope import: the stdlib header parser is only needed here.
    from email.message import Message

    resp = await client.get(
        f"{api_base}/works/{work_id}/download",
        headers={"Authorization": f"Bearer {token}"},
        timeout=60.0,
    )
    resp.raise_for_status()

    # Parse the header properly instead of splitting on "filename=": this
    # copes with extra parameters, quoting and RFC 2231 "filename*=" values,
    # and yields nothing (rather than the whole header) when no file name
    # was sent.
    advertised = ""
    disposition = resp.headers.get("content-disposition", "")
    if disposition:
        parsed = Message()
        parsed["content-disposition"] = disposition
        advertised = parsed.get_filename() or ""

    # The name comes from the server, so keep only its last path component;
    # otherwise "../x" or an absolute path could write outside BASE_DIR.
    filename = advertised.replace("\\", "/").rsplit("/", 1)[-1].strip()
    if filename in ("", ".", ".."):
        filename = f"work_{work_id}.bin"

    dest = BASE_DIR / filename
    dest.write_bytes(resp.content)
    return dest
async def main():
    """Run the end-to-end check and print the result of each stage.

    Stages, in order: build a synthetic input image, process it with
    ``GeminiExtractV2Service.extract_pattern``, upload the output through
    ``upload_to_tuhui`` at price 0, log in to Tuhui, create and pay an order
    for the uploaded work, and download the work.

    Raises:
        RuntimeError: if the Gemini step or the upload step reports failure.
    """
    print("== 构建测试图 ==")
    source_image = build_test_input()
    processed_image = BASE_DIR / f"gemini_{source_image.stem}.png"
    print(f"输入图: {source_image}")

    print("== Gemini 处理 ==")
    extractor = GeminiExtractV2Service()
    ok, note, payload = await extractor.extract_pattern(
        str(source_image),
        str(processed_image),
        custom_prompt="根据原图生成一张更完整、更干净的科技风背景素材图,保持主体布局清晰。",
        aspect_ratio="1:1",
    )
    print({"success": ok, "message": note, "data": payload})
    if not ok:
        raise RuntimeError(f"Gemini 处理失败: {note}")

    print("== 上传图绘 ==")
    listing = {
        "image_path": str(processed_image),
        "title": f"E2E测试图_{datetime.now().strftime('%m%d_%H%M%S')}",
        "description": "自动化端到端测试图,0元下载校验。",
        "price": 0,
        "category": "设计素材",
        "tags": "E2E测试,自动化",
        "designer_name": "自动化测试",
    }
    upload_result = await upload_to_tuhui(**listing)
    print(upload_result.as_dict())
    if not upload_result.success:
        raise RuntimeError(f"图绘上传失败: {upload_result.message}")
    work_id = upload_result.work_id

    print("== 登录图绘并创建0元订单 ==")
    async with httpx.AsyncClient(follow_redirects=True) as client:
        token, api_base = await login_tuhui(client)
        print(await create_order_and_pay(client, api_base, token, work_id))

        print("== 下载作品 ==")
        saved = await download_work(client, api_base, token, work_id)
        print({"downloaded_file": str(saved), "size": saved.stat().st_size})

    print("== 测试完成 ==")
    summary = {
        "work_id": work_id,
        "detail_url": upload_result.download_url,
        "processed_output": str(processed_image),
    }
    print(summary)


# Run the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    asyncio.run(main())