230 lines
8.5 KiB
Python
Executable File
230 lines
8.5 KiB
Python
Executable File
#!/usr/bin/env python3
|
||
"""Gemini 出图服务。固定走老张 Gemini 原生出图接口。"""
|
||
|
||
import asyncio
|
||
import base64
|
||
import logging
|
||
import mimetypes
|
||
import os
|
||
from typing import Dict
|
||
|
||
import aiohttp
|
||
from dotenv import load_dotenv
|
||
|
||
from utils.metrics_tracker import emit as metrics_emit
|
||
from utils.service_base import BaseService
|
||
|
||
logger = logging.getLogger(__name__)

# Populate os.environ from a .env file (if present) before reading settings.
load_dotenv()

# SECURITY: the API key must come from the environment / .env file only.
# A live key used to be hard-coded here as the fallback default; it has been
# removed from source and should be treated as leaked and rotated.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
if not GEMINI_API_KEY:
    logger.warning(
        "GEMINI_API_KEY is not set; Gemini image requests will be sent without a valid key"
    )

# Model name interpolated into the ``{model}:generateContent`` request URL.
GEMINI_IMAGE_MODEL = os.getenv("GEMINI_IMAGE_MODEL", "gemini-3-pro-image-preview")

# Fallback imageSize used when a caller passes none (only 1K/2K/4K are forwarded).
GEMINI_IMAGE_SIZE = os.getenv("GEMINI_IMAGE_SIZE", "2K")

# Fallback personGeneration value; empty means the field is omitted entirely.
GEMINI_PERSON_GENERATION = os.getenv("GEMINI_PERSON_GENERATION", "")

# Fallback thinkingLevel (only MINIMAL/LOW/MEDIUM/HIGH are forwarded).
GEMINI_THINKING_LEVEL = os.getenv("GEMINI_THINKING_LEVEL", "MINIMAL")
|
||
|
||
|
||
class GeminiExtractV2Service(BaseService):
    """Gemini image-generation service bound to a single fixed endpoint.

    Every request is a POST to laozhang.ai's native Gemini
    ``{API_BASE_URL}/{GEMINI_IMAGE_MODEL}:generateContent`` endpoint; the
    first inline image in the response is written to disk.
    """

    SERVICE_NAME = "gemini_extract_v2"

    API_BASE_URL = "https://api.laozhang.ai/v1beta/models"

    # Default instruction (Chinese): extract the print pattern, remove wrinkles,
    # fill in missing regions, keep every element exactly where it is in the
    # source, and output a flat, high-quality print image.
    DEFAULT_PROMPT = (
        "提取印花图案,把褶皱移除。补齐缺失的部分,要生成完整,细节丰富,"
        "严格按照原图的元素位置生成平面的印花图,不要相似的,相似度要100%,生成高质量的印刷图"
    )

    # Total number of HTTP attempts extract_pattern makes (first try included).
    MAX_ATTEMPTS = 2

    def __init__(self):
        super().__init__(name=self.SERVICE_NAME)

    @staticmethod
    def _image_to_base64(image_path: str) -> str:
        """Read ``image_path`` and return its bytes base64-encoded as str.

        Returns "" (after logging) when the path does not exist or the read
        fails, so callers can test the result for truthiness.
        """
        if not os.path.exists(image_path):
            logger.error(f"文件不存在: {image_path}")
            return ""
        try:
            with open(image_path, "rb") as image_file:
                return base64.b64encode(image_file.read()).decode("utf-8")
        except Exception as e:
            logger.error(f"Base64转换失败: {e}")
            return ""

    @staticmethod
    def _guess_mime_type(image_path: str) -> str:
        """Guess the MIME type from the file extension, defaulting to image/png."""
        mime_type, _ = mimetypes.guess_type(str(image_path))
        return mime_type or "image/png"

    @staticmethod
    def _build_generation_config(
        aspect_ratio: str,
        image_size: str,
        person_generation: str,
        thinking_level: str,
    ) -> Dict:
        """Build the ``generationConfig`` object for the request payload.

        Empty arguments fall back to the module-level GEMINI_* settings.
        Values outside the accepted sets (aspect ratio, size, thinking level)
        are silently dropped instead of being sent to the API.
        """
        valid_ratios = {"1:1", "9:16", "16:9", "3:4", "4:3", "3:2", "2:3", "5:4", "4:5"}
        valid_sizes = {"1K", "2K", "4K"}
        valid_thinking = {"MINIMAL", "LOW", "MEDIUM", "HIGH"}

        image_config = {}
        if aspect_ratio in valid_ratios:
            image_config["aspectRatio"] = aspect_ratio
        size_val = (image_size or GEMINI_IMAGE_SIZE or "").upper().strip()
        if size_val in valid_sizes:
            image_config["imageSize"] = size_val
        person_val = (person_generation or GEMINI_PERSON_GENERATION or "").strip()
        if person_val:
            image_config["personGeneration"] = person_val

        generation_config = {"responseModalities": ["IMAGE"]}
        if image_config:
            generation_config["imageConfig"] = image_config

        thinking_val = (thinking_level or GEMINI_THINKING_LEVEL or "").upper().strip()
        if thinking_val in valid_thinking:
            generation_config["thinkingConfig"] = {"thinkingLevel": thinking_val}

        return generation_config

    @staticmethod
    def _extract_image_bytes(result: Dict) -> bytes:
        """Return the decoded bytes of the first inline image in ``result``.

        Only the first candidate is inspected.

        Raises:
            ValueError: no candidates, the model reported ``NO_IMAGE``, or no
                part carries ``inlineData.data``.
        """
        candidates = result.get("candidates") or []
        if not candidates:
            raise ValueError("响应缺少 candidates")
        # Guard once so both the parts lookup and finishReason tolerate null.
        first = candidates[0] or {}
        parts = (first.get("content") or {}).get("parts") or []
        for part in parts:
            inline_data = (part or {}).get("inlineData") or {}
            encoded = inline_data.get("data")
            if encoded:
                return base64.b64decode(encoded)
        if first.get("finishReason") == "NO_IMAGE":
            raise ValueError("模型未返回图片(NO_IMAGE)")
        raise ValueError("响应中未找到 inlineData 图片")

    @staticmethod
    def _save_image(image_data: bytes, output_path: str) -> Dict:
        """Write ``image_data`` to ``output_path`` and return result metadata.

        Parent directories are created when ``output_path`` has any.
        """
        output_dir = os.path.dirname(output_path)
        # A bare filename has dirname "" and os.makedirs("") would raise.
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(output_path, "wb") as f:
            f.write(image_data)
        return {
            "output_path": output_path,
            "file_size": os.path.getsize(output_path),
            "api_used": "laozhang_gemini_native",
        }

    async def extract_pattern(
        self,
        input_path: str,
        output_path: str,
        custom_prompt: str = None,
        aspect_ratio: str = "1:1",
        image_size: str = "",
        person_generation: str = "",
        thinking_level: str = "",
    ) -> tuple[bool, str, dict]:
        """Send ``input_path`` plus a prompt to Gemini and save the returned image.

        Args:
            input_path: source image to send inline.
            output_path: where the generated image is written.
            custom_prompt: prompt text; DEFAULT_PROMPT when falsy.
            aspect_ratio, image_size, person_generation, thinking_level:
                forwarded through ``_build_generation_config``.

        Returns:
            ``(success, message, data)``; ``data`` is the ``_save_image``
            metadata on success and ``{}`` on failure. Errors are reported
            through the tuple rather than raised.

        The HTTP call is attempted up to MAX_ATTEMPTS times (non-200 status,
        network error, or a response without an image), sleeping ``attempt``
        seconds between tries. Saving happens once, after the loop, so a
        local disk error never triggers another (billed) API request.
        """
        img64 = self._image_to_base64(input_path)
        if not img64:
            return False, "图片编码失败", {}

        prompt = str(custom_prompt or self.DEFAULT_PROMPT).strip()
        api_url = f"{self.API_BASE_URL}/{GEMINI_IMAGE_MODEL}:generateContent"
        headers = {
            "Authorization": f"Bearer {GEMINI_API_KEY}",
            "Content-Type": "application/json",
        }
        payload = {
            "contents": [
                {
                    "parts": [
                        {"inlineData": {"mimeType": self._guess_mime_type(input_path), "data": img64}},
                        {"text": prompt},
                    ]
                }
            ],
            "generationConfig": self._build_generation_config(
                aspect_ratio=aspect_ratio,
                image_size=image_size,
                person_generation=person_generation,
                thinking_level=thinking_level,
            ),
        }

        metrics_emit("gemini_request", model=GEMINI_IMAGE_MODEL, provider="laozhang_gemini_native")
        timeout = aiohttp.ClientTimeout(total=300, connect=30)
        max_attempts = self.MAX_ATTEMPTS

        image_bytes = None
        for attempt in range(1, max_attempts + 1):
            is_last = attempt >= max_attempts
            try:
                logger.info(
                    f"Gemini 出图开始 attempt={attempt}/{max_attempts} model={GEMINI_IMAGE_MODEL} input={input_path}"
                )
                async with aiohttp.ClientSession(timeout=timeout) as session:
                    async with session.post(api_url, headers=headers, json=payload) as response:
                        status = response.status
                        if status == 200:
                            # content_type=None: parse the body as JSON even if
                            # the proxy labels it with another Content-Type.
                            result = await response.json(content_type=None)
                        else:
                            error_text = await response.text()
                # Session is closed here, so the retry back-off holds no connection.
                if status != 200:
                    logger.error(f"Gemini API请求失败 attempt={attempt}: {status} - {error_text}")
                    if not is_last:
                        await asyncio.sleep(attempt)
                        continue
                    return False, f"Gemini API请求失败: {status}", {}
                image_bytes = self._extract_image_bytes(result)
                break
            except Exception as e:
                logger.error(f"Gemini 出图异常 attempt={attempt}: {e}")
                if not is_last:
                    await asyncio.sleep(attempt)
                    continue
                return False, f"Gemini 出图失败: {e}", {}

        if image_bytes is None:
            # Only reachable when MAX_ATTEMPTS < 1.
            return False, "Gemini 出图失败", {}

        try:
            data = self._save_image(image_bytes, output_path)
        except Exception as e:
            # Keep the tuple-return contract: callers never see an exception.
            logger.error(f"Gemini 出图异常 save: {e}")
            return False, f"Gemini 出图失败: {e}", {}

        metrics_emit("gemini_success", model=GEMINI_IMAGE_MODEL, provider="laozhang_gemini_native")
        try:
            from utils.api_cost_tracker import record

            record("gemini_extract", count=1)
        except Exception as e:
            # Cost tracking is best-effort; a finished generation must not fail on it.
            logger.debug(f"api_cost_tracker record failed: {e}")
        return True, "Gemini 出图完成", data

    async def correct_perspective(
        self,
        input_path: str,
        output_path: str,
        level: str = "mild",
    ) -> tuple[bool, str, dict]:
        """Ask Gemini to undo perspective distortion in ``input_path``.

        ``level == "strong"`` requests a full front-on flat view; any other
        value requests a light correction. Delegates to extract_pattern with a
        1:1 aspect ratio and returns its ``(success, message, data)`` tuple.
        """
        if level == "strong":
            prompt = (
                "这张图存在明显透视畸变。请把主体矫正为正面平铺视图,"
                "所有边缘尽量水平或垂直,保持图案颜色和细节不变,只做几何矫正。"
            )
        else:
            prompt = (
                "这张图存在轻微透视畸变。请做轻度透视矫正,"
                "消除斜拍拉伸感,保持图案颜色和细节不变。"
            )
        return await self.extract_pattern(
            input_path=input_path,
            output_path=output_path,
            custom_prompt=prompt,
            aspect_ratio="1:1",
        )

    async def cleanup(self):
        """No-op: each request opens and closes its own aiohttp session."""
        return None
|
||
|
||
|
||
async def extract_pattern_v2(
    input_path: str,
    output_path: str,
    custom_prompt: str = None,
    aspect_ratio: str = "1:1",
) -> tuple[bool, str, dict]:
    """Run a single pattern extraction on a short-lived service instance.

    Returns the ``(success, message, data)`` tuple produced by
    ``GeminiExtractV2Service.extract_pattern``; the instance's ``cleanup()``
    is awaited whether or not extraction raises.
    """
    svc = GeminiExtractV2Service()
    try:
        outcome = await svc.extract_pattern(
            input_path,
            output_path,
            custom_prompt,
            aspect_ratio,
        )
    finally:
        await svc.cleanup()
    return outcome
|