Files
tw/services/service_gemini.py

224 lines
8.1 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""Gemini 出图服务。固定走老张 Gemini 原生出图接口。"""
import asyncio
import base64
import logging
import mimetypes
import os
from typing import Dict
import aiohttp
from dotenv import load_dotenv
from utils.metrics_tracker import emit as metrics_emit
from utils.service_base import BaseService
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Load .env values (python-dotenv) before the os.getenv lookups below run.
load_dotenv()
# NOTE(review): the fallback value below is a literal API key committed to
# source control. It is used whenever GEMINI_API_KEY is not set in the
# environment. Consider rotating the key and removing the default -- confirm
# first that no deployment relies on the fallback.
GEMINI_API_KEY = os.getenv(
"GEMINI_API_KEY",
"sk-8i7uYE0RtnQwDImV8a5f7014DcAb46F6BcEb72Df92218aC8",
)
# Model name; interpolated into the generateContent request URL.
GEMINI_IMAGE_MODEL = os.getenv("GEMINI_IMAGE_MODEL", "gemini-3-pro-image-preview")
# Default output size, used when a call does not pass its own image_size.
GEMINI_IMAGE_SIZE = os.getenv("GEMINI_IMAGE_SIZE", "2K")
# Default personGeneration value; when empty the field is left out of the request.
GEMINI_PERSON_GENERATION = os.getenv("GEMINI_PERSON_GENERATION", "")
class GeminiExtractV2Service(BaseService):
    """Single-endpoint Gemini image-generation service.

    Sends an input image plus a text prompt to the ``generateContent``
    endpoint under ``API_BASE_URL`` and writes the first inline image found
    in the response to disk.
    """

    SERVICE_NAME = "gemini_extract_v2"
    API_BASE_URL = "https://api.laozhang.ai/v1beta/models"
    # Total number of HTTP attempts made by extract_pattern (first try + retries).
    REQUEST_MAX_ATTEMPTS = 2
    DEFAULT_PROMPT = (
        "提取印花图案,把褶皱移除。补齐缺失的部分,要生成完整,细节丰富,"
        "严格按照原图的元素位置生成平面的印花图不要相似的相似度要100%,生成高质量的印刷图"
    )

    def __init__(self):
        """Initialise the base service under ``SERVICE_NAME``."""
        super().__init__(name=self.SERVICE_NAME)

    @staticmethod
    def _image_to_base64(image_path: str) -> str:
        """Return the file at *image_path* as a base64 string.

        Returns an empty string (and logs an error) when the file does not
        exist or cannot be read; it never raises for those cases.
        """
        if not os.path.exists(image_path):
            logger.error(f"文件不存在: {image_path}")
            return ""
        try:
            with open(image_path, "rb") as image_file:
                return base64.b64encode(image_file.read()).decode("utf-8")
        except Exception as e:
            logger.error(f"Base64转换失败: {e}")
            return ""

    @staticmethod
    def _guess_mime_type(image_path: str) -> str:
        """Guess the MIME type from the file name, defaulting to ``image/png``."""
        mime_type, _ = mimetypes.guess_type(str(image_path))
        return mime_type or "image/png"

    @staticmethod
    def _build_generation_config(
        aspect_ratio: str,
        image_size: str,
        person_generation: str,
        thinking_level: str,
    ) -> Dict:
        """Build the ``generationConfig`` section of the request payload.

        Args:
            aspect_ratio: Included as ``aspectRatio`` only when it is one of
                the accepted ratios; otherwise silently omitted.
            image_size: Falls back to ``GEMINI_IMAGE_SIZE`` when empty; included
                as ``imageSize`` only when it normalises to 1K, 2K or 4K.
            person_generation: Falls back to ``GEMINI_PERSON_GENERATION`` when
                empty; included as ``personGeneration`` when non-empty.
            thinking_level: Accepted for interface compatibility; currently
                not used when building the config.

        Returns:
            A dict that always contains ``responseModalities`` and contains
            ``imageConfig`` only when at least one image option was set.
        """
        valid_ratios = {"1:1", "9:16", "16:9", "3:4", "4:3", "3:2", "2:3", "5:4", "4:5"}
        valid_sizes = {"1K", "2K", "4K"}

        image_config = {}
        if aspect_ratio in valid_ratios:
            image_config["aspectRatio"] = aspect_ratio

        size_val = (image_size or GEMINI_IMAGE_SIZE or "").upper().strip()
        if size_val in valid_sizes:
            image_config["imageSize"] = size_val

        person_val = (person_generation or GEMINI_PERSON_GENERATION or "").strip()
        if person_val:
            image_config["personGeneration"] = person_val

        generation_config = {"responseModalities": ["IMAGE"]}
        if image_config:
            generation_config["imageConfig"] = image_config
        return generation_config

    @staticmethod
    def _extract_image_bytes(result: Dict) -> bytes:
        """Decode the first inline image found in the first response candidate.

        Raises:
            ValueError: When the response has no candidates, or the first
                candidate carries no ``inlineData`` payload (with a dedicated
                message when its ``finishReason`` is ``NO_IMAGE``).
        """
        candidates = result.get("candidates") or []
        if not candidates:
            raise ValueError("响应缺少 candidates")

        # Normalise a null candidate once so every lookup below is safe.
        first_candidate = candidates[0] or {}
        parts = (first_candidate.get("content") or {}).get("parts") or []
        for part in parts:
            # A null part is treated like a part without image data.
            inline_data = (part or {}).get("inlineData") or {}
            encoded = inline_data.get("data")
            if encoded:
                return base64.b64decode(encoded)

        finish_reason = first_candidate.get("finishReason") or ""
        if finish_reason == "NO_IMAGE":
            raise ValueError("模型未返回图片(NO_IMAGE)")
        raise ValueError("响应中未找到 inlineData 图片")

    @staticmethod
    def _save_image(image_data: bytes, output_path: str) -> Dict:
        """Write *image_data* to *output_path* and describe the written file.

        Returns:
            A dict with ``output_path``, ``file_size`` (bytes on disk) and
            ``api_used``.
        """
        # dirname() is "" for a bare filename, and os.makedirs("") raises,
        # so only create the parent directory when there is one.
        parent_dir = os.path.dirname(output_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(output_path, "wb") as f:
            f.write(image_data)
        file_size = os.path.getsize(output_path)
        return {
            "output_path": output_path,
            "file_size": file_size,
            "api_used": "laozhang_gemini_native",
        }

    async def extract_pattern(
        self,
        input_path: str,
        output_path: str,
        custom_prompt: str = None,
        aspect_ratio: str = "1:1",
        image_size: str = "",
        person_generation: str = "",
        thinking_level: str = "",
    ) -> tuple[bool, str, dict]:
        """Send the input image and prompt to Gemini and save the returned image.

        Args:
            input_path: Path of the image to upload.
            output_path: Path the generated image is written to.
            custom_prompt: Prompt text; ``DEFAULT_PROMPT`` is used when falsy.
            aspect_ratio: Forwarded to ``_build_generation_config``.
            image_size: Forwarded to ``_build_generation_config``.
            person_generation: Forwarded to ``_build_generation_config``.
            thinking_level: Forwarded to ``_build_generation_config``.

        Returns:
            ``(success, message, data)``. ``data`` is the dict produced by
            ``_save_image`` on success and an empty dict otherwise. Failures
            are reported through the tuple; request errors are not raised.
        """
        img64 = self._image_to_base64(input_path)
        if not img64:
            return False, "图片编码失败", {}

        prompt = str(custom_prompt or self.DEFAULT_PROMPT).strip()
        api_url = f"{self.API_BASE_URL}/{GEMINI_IMAGE_MODEL}:generateContent"
        headers = {
            "Authorization": f"Bearer {GEMINI_API_KEY}",
            "Content-Type": "application/json",
        }
        payload = {
            "contents": [
                {
                    "parts": [
                        {"inlineData": {"mimeType": self._guess_mime_type(input_path), "data": img64}},
                        {"text": prompt},
                    ]
                }
            ],
            "generationConfig": self._build_generation_config(
                aspect_ratio=aspect_ratio,
                image_size=image_size,
                person_generation=person_generation,
                thinking_level=thinking_level,
            ),
        }

        metrics_emit("gemini_request", model=GEMINI_IMAGE_MODEL, provider="laozhang_gemini_native")
        timeout = aiohttp.ClientTimeout(total=300, connect=30)
        max_attempts = self.REQUEST_MAX_ATTEMPTS

        for attempt in range(1, max_attempts + 1):
            has_retry_left = attempt < max_attempts
            try:
                logger.info(
                    f"Gemini 出图开始 attempt={attempt}/{max_attempts} model={GEMINI_IMAGE_MODEL} input={input_path}"
                )
                async with aiohttp.ClientSession(timeout=timeout) as session:
                    async with session.post(api_url, headers=headers, json=payload) as response:
                        if response.status != 200:
                            error_text = await response.text()
                            logger.error(f"Gemini API请求失败 attempt={attempt}: {response.status} - {error_text}")
                            if has_retry_left:
                                # Back-off grows with the attempt number (seconds).
                                await asyncio.sleep(attempt)
                                continue
                            return False, f"Gemini API请求失败: {response.status}", {}
                        result = await response.json()

                image_bytes = self._extract_image_bytes(result)
                data = self._save_image(image_bytes, output_path)
                metrics_emit("gemini_success", model=GEMINI_IMAGE_MODEL, provider="laozhang_gemini_native")
                try:
                    from utils.api_cost_tracker import record
                    record("gemini_extract", count=1)
                except Exception:
                    # Cost accounting is best-effort: a tracker problem must not
                    # fail a generation that already succeeded.
                    logger.debug("api_cost_tracker record failed", exc_info=True)
                return True, "Gemini 出图完成", data
            except Exception as e:
                # Covers network errors, bad JSON, missing image data and
                # disk-write failures; all of them are retried alike.
                logger.error(f"Gemini 出图异常 attempt={attempt}: {e}")
                if has_retry_left:
                    await asyncio.sleep(attempt)
                    continue
                return False, f"Gemini 出图失败: {e}", {}

        # Reached only when no attempt was made (REQUEST_MAX_ATTEMPTS < 1).
        return False, "Gemini 出图失败", {}

    async def correct_perspective(
        self,
        input_path: str,
        output_path: str,
        level: str = "mild",
    ) -> tuple[bool, str, dict]:
        """Run ``extract_pattern`` with a perspective-correction prompt.

        Args:
            input_path: Path of the image to correct.
            output_path: Path the corrected image is written to.
            level: ``"strong"`` selects the strong-correction prompt; any
                other value selects the mild one.

        Returns:
            The ``(success, message, data)`` tuple from ``extract_pattern``.
        """
        if level == "strong":
            prompt = (
                "这张图存在明显透视畸变。请把主体矫正为正面平铺视图,"
                "所有边缘尽量水平或垂直,保持图案颜色和细节不变,只做几何矫正。"
            )
        else:
            prompt = (
                "这张图存在轻微透视畸变。请做轻度透视矫正,"
                "消除斜拍拉伸感,保持图案颜色和细节不变。"
            )
        return await self.extract_pattern(
            input_path=input_path,
            output_path=output_path,
            custom_prompt=prompt,
            aspect_ratio="1:1",
        )

    async def cleanup(self):
        """No-op: each request opens and closes its own HTTP session."""
        return None
async def extract_pattern_v2(
    input_path: str,
    output_path: str,
    custom_prompt: str = None,
    aspect_ratio: str = "1:1",
) -> tuple[bool, str, dict]:
    """Run a one-off pattern extraction with a throwaway service instance.

    Creates a ``GeminiExtractV2Service``, delegates to its ``extract_pattern``
    and always awaits ``cleanup()`` afterwards, whether or not the call raised.

    Returns:
        The ``(success, message, data)`` tuple from ``extract_pattern``.
    """
    extractor = GeminiExtractV2Service()
    try:
        outcome = await extractor.extract_pattern(
            input_path=input_path,
            output_path=output_path,
            custom_prompt=custom_prompt,
            aspect_ratio=aspect_ratio,
        )
    finally:
        await extractor.cleanup()
    return outcome