fix: streamline gemini flow and add e2e test
This commit is contained in:
144
scripts/test_image_pipeline_e2e.py
Normal file
144
scripts/test_image_pipeline_e2e.py
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from PIL import Image, ImageDraw
|
||||||
|
|
||||||
|
from services.service_gemini import GeminiExtractV2Service
|
||||||
|
from services.service_tuhui_upload import upload_to_tuhui
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
BASE_DIR = Path(__file__).resolve().parents[1] / "tmp_e2e_pipeline"
|
||||||
|
BASE_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
TUHUI_WEB_BASE = os.getenv("TUHUI_WEB_BASE_URL", "https://tuhui.cloud").rstrip("/")
|
||||||
|
TUHUI_DIRECT_BASE = os.getenv("TUHUI_DIRECT_BASE_URL", "http://1.12.50.92:8002").rstrip("/")
|
||||||
|
TUHUI_API_BASES = [f"{TUHUI_DIRECT_BASE}/api", f"{TUHUI_WEB_BASE}/api"]
|
||||||
|
TUHUI_PHONE = os.getenv("TUHUI_PHONE", "17520145271")
|
||||||
|
TUHUI_PASSWORD = os.getenv("TUHUI_PASSWORD", "zuowei1216")
|
||||||
|
|
||||||
|
|
||||||
|
def build_test_input() -> Path:
|
||||||
|
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
path = BASE_DIR / f"input_{ts}.png"
|
||||||
|
img = Image.new("RGBA", (768, 768), (228, 244, 249, 255))
|
||||||
|
draw = ImageDraw.Draw(img)
|
||||||
|
draw.rounded_rectangle((40, 40, 728, 728), radius=48, fill=(33, 114, 147, 255))
|
||||||
|
draw.ellipse((120, 120, 340, 340), fill=(96, 214, 255, 255))
|
||||||
|
draw.rectangle((390, 160, 640, 460), fill=(11, 54, 72, 255))
|
||||||
|
draw.text((110, 540), "TW Gemini E2E Test", fill=(255, 255, 255, 255))
|
||||||
|
img.save(path)
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
async def login_tuhui(client: httpx.AsyncClient) -> tuple[str, str]:
|
||||||
|
last_error = None
|
||||||
|
for api_base in TUHUI_API_BASES:
|
||||||
|
try:
|
||||||
|
resp = await client.post(
|
||||||
|
f"{api_base}/auth/login",
|
||||||
|
json={"phone": TUHUI_PHONE, "password": TUHUI_PASSWORD},
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
data = resp.json()
|
||||||
|
return data["access_token"], api_base
|
||||||
|
except Exception as e:
|
||||||
|
last_error = e
|
||||||
|
raise last_error or RuntimeError("图绘登录失败")
|
||||||
|
|
||||||
|
|
||||||
|
async def create_order_and_pay(client: httpx.AsyncClient, api_base: str, token: str, work_id: int) -> dict:
|
||||||
|
headers = {"Authorization": f"Bearer {token}"}
|
||||||
|
order_resp = await client.post(
|
||||||
|
f"{api_base}/orders/create",
|
||||||
|
headers=headers,
|
||||||
|
json={"work_id": work_id, "payment_method": "balance"},
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
order_resp.raise_for_status()
|
||||||
|
order = order_resp.json()
|
||||||
|
|
||||||
|
pay_resp = await client.post(
|
||||||
|
f"{api_base}/orders/pay/{order['id']}",
|
||||||
|
headers=headers,
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
pay_resp.raise_for_status()
|
||||||
|
payment = pay_resp.json()
|
||||||
|
return {"order": order, "payment": payment}
|
||||||
|
|
||||||
|
|
||||||
|
async def download_work(client: httpx.AsyncClient, api_base: str, token: str, work_id: int) -> Path:
|
||||||
|
headers = {"Authorization": f"Bearer {token}"}
|
||||||
|
resp = await client.get(
|
||||||
|
f"{api_base}/works/{work_id}/download",
|
||||||
|
headers=headers,
|
||||||
|
timeout=60.0,
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
filename = resp.headers.get("content-disposition", "").split("filename=")[-1].strip('"') or f"work_{work_id}.bin"
|
||||||
|
dest = BASE_DIR / filename
|
||||||
|
dest.write_bytes(resp.content)
|
||||||
|
return dest
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
print("== 构建测试图 ==")
|
||||||
|
input_path = build_test_input()
|
||||||
|
output_path = BASE_DIR / f"gemini_{input_path.stem}.png"
|
||||||
|
print(f"输入图: {input_path}")
|
||||||
|
|
||||||
|
print("== Gemini 处理 ==")
|
||||||
|
gemini = GeminiExtractV2Service()
|
||||||
|
success, message, data = await gemini.extract_pattern(
|
||||||
|
str(input_path),
|
||||||
|
str(output_path),
|
||||||
|
custom_prompt="根据原图生成一张更完整、更干净的科技风背景素材图,保持主体布局清晰。",
|
||||||
|
aspect_ratio="1:1",
|
||||||
|
)
|
||||||
|
print({"success": success, "message": message, "data": data})
|
||||||
|
if not success:
|
||||||
|
raise RuntimeError(f"Gemini 处理失败: {message}")
|
||||||
|
|
||||||
|
print("== 上传图绘 ==")
|
||||||
|
title = f"E2E测试图_{datetime.now().strftime('%m%d_%H%M%S')}"
|
||||||
|
upload_result = await upload_to_tuhui(
|
||||||
|
image_path=str(output_path),
|
||||||
|
title=title,
|
||||||
|
description="自动化端到端测试图,0元下载校验。",
|
||||||
|
price=0,
|
||||||
|
category="设计素材",
|
||||||
|
tags="E2E测试,自动化",
|
||||||
|
designer_name="自动化测试",
|
||||||
|
)
|
||||||
|
print(upload_result.as_dict())
|
||||||
|
if not upload_result.success:
|
||||||
|
raise RuntimeError(f"图绘上传失败: {upload_result.message}")
|
||||||
|
|
||||||
|
print("== 登录图绘并创建0元订单 ==")
|
||||||
|
async with httpx.AsyncClient(follow_redirects=True) as client:
|
||||||
|
token, api_base = await login_tuhui(client)
|
||||||
|
pay_info = await create_order_and_pay(client, api_base, token, upload_result.work_id)
|
||||||
|
print(pay_info)
|
||||||
|
|
||||||
|
print("== 下载作品 ==")
|
||||||
|
downloaded = await download_work(client, api_base, token, upload_result.work_id)
|
||||||
|
print({"downloaded_file": str(downloaded), "size": downloaded.stat().st_size})
|
||||||
|
|
||||||
|
print("== 测试完成 ==")
|
||||||
|
print(
|
||||||
|
{
|
||||||
|
"work_id": upload_result.work_id,
|
||||||
|
"detail_url": upload_result.download_url,
|
||||||
|
"processed_output": str(output_path),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
@@ -24,7 +24,6 @@ GEMINI_API_KEY = os.getenv(
|
|||||||
GEMINI_IMAGE_MODEL = os.getenv("GEMINI_IMAGE_MODEL", "gemini-3-pro-image-preview")
|
GEMINI_IMAGE_MODEL = os.getenv("GEMINI_IMAGE_MODEL", "gemini-3-pro-image-preview")
|
||||||
GEMINI_IMAGE_SIZE = os.getenv("GEMINI_IMAGE_SIZE", "2K")
|
GEMINI_IMAGE_SIZE = os.getenv("GEMINI_IMAGE_SIZE", "2K")
|
||||||
GEMINI_PERSON_GENERATION = os.getenv("GEMINI_PERSON_GENERATION", "")
|
GEMINI_PERSON_GENERATION = os.getenv("GEMINI_PERSON_GENERATION", "")
|
||||||
GEMINI_THINKING_LEVEL = os.getenv("GEMINI_THINKING_LEVEL", "MINIMAL")
|
|
||||||
|
|
||||||
|
|
||||||
class GeminiExtractV2Service(BaseService):
|
class GeminiExtractV2Service(BaseService):
|
||||||
@@ -66,7 +65,6 @@ class GeminiExtractV2Service(BaseService):
|
|||||||
) -> Dict:
|
) -> Dict:
|
||||||
valid_ratios = {"1:1", "9:16", "16:9", "3:4", "4:3", "3:2", "2:3", "5:4", "4:5"}
|
valid_ratios = {"1:1", "9:16", "16:9", "3:4", "4:3", "3:2", "2:3", "5:4", "4:5"}
|
||||||
valid_sizes = {"1K", "2K", "4K"}
|
valid_sizes = {"1K", "2K", "4K"}
|
||||||
valid_thinking = {"MINIMAL", "LOW", "MEDIUM", "HIGH"}
|
|
||||||
|
|
||||||
image_config = {}
|
image_config = {}
|
||||||
if aspect_ratio in valid_ratios:
|
if aspect_ratio in valid_ratios:
|
||||||
@@ -82,10 +80,6 @@ class GeminiExtractV2Service(BaseService):
|
|||||||
if image_config:
|
if image_config:
|
||||||
generation_config["imageConfig"] = image_config
|
generation_config["imageConfig"] = image_config
|
||||||
|
|
||||||
thinking_val = (thinking_level or GEMINI_THINKING_LEVEL or "").upper().strip()
|
|
||||||
if thinking_val in valid_thinking:
|
|
||||||
generation_config["thinkingConfig"] = {"thinkingLevel": thinking_val}
|
|
||||||
|
|
||||||
return generation_config
|
return generation_config
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ load_dotenv()
|
|||||||
# 图绘平台配置
|
# 图绘平台配置
|
||||||
TUHUI_BASE_URL = os.getenv("TUHUI_BASE_URL", "https://tuhui.cloud")
|
TUHUI_BASE_URL = os.getenv("TUHUI_BASE_URL", "https://tuhui.cloud")
|
||||||
TUHUI_FALLBACK_BASE_URL = "https://tuhui.cloud"
|
TUHUI_FALLBACK_BASE_URL = "https://tuhui.cloud"
|
||||||
|
TUHUI_DIRECT_BASE_URL = os.getenv("TUHUI_DIRECT_BASE_URL", "http://1.12.50.92:8002")
|
||||||
TUHUI_WEB_BASE_URL = os.getenv("TUHUI_WEB_BASE_URL", "https://tuhui.cloud").rstrip("/")
|
TUHUI_WEB_BASE_URL = os.getenv("TUHUI_WEB_BASE_URL", "https://tuhui.cloud").rstrip("/")
|
||||||
TUHUI_PHONE = os.getenv("TUHUI_PHONE", "17520145271") # 图绘账号手机号
|
TUHUI_PHONE = os.getenv("TUHUI_PHONE", "17520145271") # 图绘账号手机号
|
||||||
TUHUI_PASSWORD = os.getenv("TUHUI_PASSWORD", "zuowei1216") # 图绘账号密码
|
TUHUI_PASSWORD = os.getenv("TUHUI_PASSWORD", "zuowei1216") # 图绘账号密码
|
||||||
@@ -59,7 +60,11 @@ class TuhuiUploadService:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.base_url = TUHUI_BASE_URL.rstrip("/")
|
self.base_url = TUHUI_BASE_URL.rstrip("/")
|
||||||
self.base_urls = []
|
self.base_urls = []
|
||||||
for candidate in (TUHUI_FALLBACK_BASE_URL.rstrip("/"), self.base_url):
|
for candidate in (
|
||||||
|
TUHUI_FALLBACK_BASE_URL.rstrip("/"),
|
||||||
|
TUHUI_DIRECT_BASE_URL.rstrip("/"),
|
||||||
|
self.base_url,
|
||||||
|
):
|
||||||
if candidate and candidate not in self.base_urls:
|
if candidate and candidate not in self.base_urls:
|
||||||
self.base_urls.append(candidate)
|
self.base_urls.append(candidate)
|
||||||
if self.base_urls:
|
if self.base_urls:
|
||||||
@@ -160,7 +165,7 @@ class TuhuiUploadService:
|
|||||||
return TuhuiUploadResult(False, "", 0, message="登录失败")
|
return TuhuiUploadResult(False, "", 0, message="登录失败")
|
||||||
|
|
||||||
# 准备上传数据
|
# 准备上传数据
|
||||||
price = price or self.default_price
|
price = self.default_price if price is None else price
|
||||||
|
|
||||||
# 读取图片文件
|
# 读取图片文件
|
||||||
if not os.path.exists(image_path):
|
if not os.path.exists(image_path):
|
||||||
|
|||||||
Reference in New Issue
Block a user