#!/usr/bin/env python3
"""Gemini image-generation service.

Always targets LaoZhang's native Gemini ``generateContent`` image endpoint.
"""
import asyncio
import base64
import logging
import mimetypes
import os
from typing import Dict

import aiohttp
from dotenv import load_dotenv

from utils.metrics_tracker import emit as metrics_emit
from utils.service_base import BaseService

logger = logging.getLogger(__name__)

load_dotenv()

# The API key must come from the environment (or a .env file). A real key was
# previously hard-coded here as the fallback value; treat it as leaked and
# rotate it.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
GEMINI_IMAGE_MODEL = os.getenv("GEMINI_IMAGE_MODEL", "gemini-3-pro-image-preview")
GEMINI_IMAGE_SIZE = os.getenv("GEMINI_IMAGE_SIZE", "2K")
GEMINI_PERSON_GENERATION = os.getenv("GEMINI_PERSON_GENERATION", "")


class GeminiExtractV2Service(BaseService):
    """Gemini image service bound to a single fixed endpoint."""

    SERVICE_NAME = "gemini_extract_v2"
    API_BASE_URL = "https://api.laozhang.ai/v1beta/models"
    # Total HTTP attempts per request (first try + retries).
    MAX_ATTEMPTS = 2
    DEFAULT_PROMPT = (
        "提取印花图案,把褶皱移除。补齐缺失的部分,要生成完整,细节丰富,"
        "严格按照原图的元素位置生成平面的印花图,不要相似的,相似度要100%,生成高质量的印刷图"
    )

    def __init__(self):
        super().__init__(name=self.SERVICE_NAME)

    @staticmethod
    def _image_to_base64(image_path: str) -> str:
        """Read ``image_path`` and return its base64 text, or "" on any failure."""
        try:
            with open(image_path, "rb") as image_file:
                raw = image_file.read()
        except FileNotFoundError:
            logger.error(f"文件不存在: {image_path}")
            return ""
        except Exception as e:
            logger.error(f"Base64转换失败: {e}")
            return ""
        return base64.b64encode(raw).decode("utf-8")

    @staticmethod
    def _guess_mime_type(image_path: str) -> str:
        """Guess the MIME type from the file extension, defaulting to PNG."""
        mime_type, _ = mimetypes.guess_type(str(image_path))
        return mime_type or "image/png"

    @staticmethod
    def _build_generation_config(
        aspect_ratio: str,
        image_size: str,
        person_generation: str,
        thinking_level: str,
    ) -> Dict:
        """Build the ``generationConfig`` payload for an image-only response.

        Args:
            aspect_ratio: Included only if it is one of the supported ratios;
                otherwise silently omitted.
            image_size: "1K"/"2K"/"4K" (case-insensitive); falls back to
                GEMINI_IMAGE_SIZE, and is omitted if still not valid.
            person_generation: Passed through verbatim when non-empty; falls
                back to GEMINI_PERSON_GENERATION.
            thinking_level: Currently ignored; accepted only so callers can
                pass it without breaking.

        Returns:
            A dict suitable for the request's ``generationConfig`` field.
        """
        valid_ratios = {
            "1:1", "9:16", "16:9", "3:4", "4:3", "3:2", "2:3", "5:4", "4:5", "21:9",
        }
        valid_sizes = {"1K", "2K", "4K"}
        image_config = {}

        ratio_val = (aspect_ratio or "").strip()
        if ratio_val in valid_ratios:
            image_config["aspectRatio"] = ratio_val

        size_val = (image_size or GEMINI_IMAGE_SIZE or "").upper().strip()
        if size_val in valid_sizes:
            image_config["imageSize"] = size_val

        person_val = (person_generation or GEMINI_PERSON_GENERATION or "").strip()
        if person_val:
            image_config["personGeneration"] = person_val

        generation_config = {"responseModalities": ["IMAGE"]}
        if image_config:
            generation_config["imageConfig"] = image_config
        return generation_config

    @staticmethod
    def _extract_image_bytes(result: Dict) -> bytes:
        """Decode the generated image from a ``generateContent`` response.

        Prefers a non-thought image part. Some image models may also emit
        interim parts flagged ``"thought": true``; if only those exist, the
        last one is returned.

        Raises:
            ValueError: No candidates, the model reported NO_IMAGE, or no
                inline image data was found.
        """
        candidates = result.get("candidates") or []
        if not candidates:
            raise ValueError("响应缺少 candidates")
        first = candidates[0] or {}
        parts = (first.get("content") or {}).get("parts") or []

        thought_image = None
        for part in parts:
            if not isinstance(part, dict):
                continue
            # Accept both REST camelCase and snake_case (some proxies re-serialize).
            inline_data = part.get("inlineData") or part.get("inline_data") or {}
            encoded = inline_data.get("data")
            if not encoded:
                continue
            if part.get("thought"):
                thought_image = encoded
                continue
            return base64.b64decode(encoded)
        if thought_image:
            return base64.b64decode(thought_image)

        if first.get("finishReason") == "NO_IMAGE":
            raise ValueError("模型未返回图片(NO_IMAGE)")
        raise ValueError("响应中未找到 inlineData 图片")

    @staticmethod
    def _save_image(image_data: bytes, output_path: str) -> Dict:
        """Write ``image_data`` to ``output_path``, creating parent dirs if needed."""
        directory = os.path.dirname(output_path)
        # dirname("out.png") == "" and os.makedirs("") raises, so only create
        # directories when the path actually has one.
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(output_path, "wb") as f:
            f.write(image_data)
        file_size = os.path.getsize(output_path)
        return {
            "output_path": output_path,
            "file_size": file_size,
            "api_used": "laozhang_gemini_native",
        }

    @staticmethod
    def _record_success() -> None:
        """Best-effort success bookkeeping; never fails (or re-runs) a generation."""
        try:
            metrics_emit("gemini_success", model=GEMINI_IMAGE_MODEL,
                         provider="laozhang_gemini_native")
        except Exception:
            logger.warning("gemini_success 指标上报失败", exc_info=True)
        try:
            # Imported lazily: the cost tracker is optional.
            from utils.api_cost_tracker import record
            record("gemini_extract", count=1)
        except Exception:
            logger.debug("api_cost_tracker 记录失败", exc_info=True)

    async def extract_pattern(
        self,
        input_path: str,
        output_path: str,
        custom_prompt: str = None,
        aspect_ratio: str = "1:1",
        image_size: str = "",
        person_generation: str = "",
        thinking_level: str = "",
    ) -> tuple[bool, str, dict]:
        """Send ``input_path`` plus a prompt to Gemini and save the result image.

        Retries once (after a short back-off) on non-200 responses or any
        exception while requesting, parsing or saving.

        Returns:
            ``(success, message, data)``; ``data`` is the dict from
            ``_save_image`` on success, otherwise ``{}``.
        """
        if not GEMINI_API_KEY:
            logger.error("GEMINI_API_KEY 未配置")
            return False, "GEMINI_API_KEY 未配置", {}

        img64 = self._image_to_base64(input_path)
        if not img64:
            return False, "图片编码失败", {}

        prompt = str(custom_prompt or self.DEFAULT_PROMPT).strip()
        api_url = f"{self.API_BASE_URL}/{GEMINI_IMAGE_MODEL}:generateContent"
        headers = {
            "Authorization": f"Bearer {GEMINI_API_KEY}",
            "Content-Type": "application/json",
        }
        payload = {
            "contents": [
                {
                    "parts": [
                        {"inlineData": {"mimeType": self._guess_mime_type(input_path), "data": img64}},
                        {"text": prompt},
                    ]
                }
            ],
            "generationConfig": self._build_generation_config(
                aspect_ratio=aspect_ratio,
                image_size=image_size,
                person_generation=person_generation,
                thinking_level=thinking_level,
            ),
        }

        metrics_emit("gemini_request", model=GEMINI_IMAGE_MODEL, provider="laozhang_gemini_native")
        timeout = aiohttp.ClientTimeout(total=300, connect=30)
        max_attempts = self.MAX_ATTEMPTS

        # One session reused across attempts (connection pooling).
        async with aiohttp.ClientSession(timeout=timeout) as session:
            for attempt in range(1, max_attempts + 1):
                is_last = attempt >= max_attempts
                try:
                    logger.info(
                        f"Gemini 出图开始 attempt={attempt}/{max_attempts} "
                        f"model={GEMINI_IMAGE_MODEL} input={input_path}"
                    )
                    async with session.post(api_url, headers=headers, json=payload) as response:
                        if response.status != 200:
                            error_text = await response.text()
                            logger.error(f"Gemini API请求失败 attempt={attempt}: {response.status} - {error_text}")
                            if not is_last:
                                await asyncio.sleep(attempt)
                                continue
                            return False, f"Gemini API请求失败: {response.status}", {}
                        # content_type=None: tolerate proxies that mislabel JSON bodies.
                        result = await response.json(content_type=None)
                    image_bytes = self._extract_image_bytes(result)
                    data = self._save_image(image_bytes, output_path)
                except Exception as e:
                    logger.error(f"Gemini 出图异常 attempt={attempt}: {e}")
                    if not is_last:
                        await asyncio.sleep(attempt)
                        continue
                    return False, f"Gemini 出图失败: {e}", {}

                # Outside the try: bookkeeping errors must not trigger a paid retry.
                self._record_success()
                return True, "Gemini 出图完成", data

        return False, "Gemini 出图失败", {}

    async def correct_perspective(
        self,
        input_path: str,
        output_path: str,
        level: str = "mild",
    ) -> tuple[bool, str, dict]:
        """Ask the model for a perspective-corrected image.

        ``level == "strong"`` requests a full front-facing rectification; any
        other value requests a mild correction.
        """
        if level == "strong":
            prompt = (
                "这张图存在明显透视畸变。请把主体矫正为正面平铺视图,"
                "所有边缘尽量水平或垂直,保持图案颜色和细节不变,只做几何矫正。"
            )
        else:
            prompt = (
                "这张图存在轻微透视畸变。请做轻度透视矫正,"
                "消除斜拍拉伸感,保持图案颜色和细节不变。"
            )
        return await self.extract_pattern(
            input_path=input_path,
            output_path=output_path,
            custom_prompt=prompt,
            aspect_ratio="1:1",
        )

    async def cleanup(self):
        """No resources are held between calls; nothing to release."""
        return None


async def extract_pattern_v2(
    input_path: str,
    output_path: str,
    custom_prompt: str = None,
    aspect_ratio: str = "1:1",
) -> tuple[bool, str, dict]:
    """Convenience wrapper: run one extraction on a fresh service instance."""
    service = GeminiExtractV2Service()
    try:
        return await service.extract_pattern(
            input_path,
            output_path,
            custom_prompt=custom_prompt,
            aspect_ratio=aspect_ratio,
        )
    finally:
        await service.cleanup()