568 lines
26 KiB
Python
Executable File
568 lines
26 KiB
Python
Executable File
#!/usr/bin/env python3
|
||
"""
|
||
Gemini印花提取V2服务 - 使用服务
|
||
更经济的选择:1.4毛/张
|
||
"""
|
||
|
||
import asyncio
|
||
import aiohttp
|
||
import base64
|
||
import json
|
||
import re
|
||
import os
|
||
import time
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
import logging
|
||
from dotenv import load_dotenv
|
||
from utils.metrics_tracker import emit as metrics_emit
|
||
|
||
|
||
from utils.service_base import BaseService
|
||
|
||
logger = logging.getLogger(__name__)
|
||
load_dotenv()
|
||
GEMINI_IMAGE_MODEL = os.getenv("GEMINI_IMAGE_MODEL", "gemini-3.1-flash-image-preview")
|
||
GEMINI_IMAGE_FALLBACK_MODEL = os.getenv("GEMINI_IMAGE_FALLBACK_MODEL", "gemini-2.5-flash-image")
|
||
GEMINI_IMAGE_SIZE = os.getenv("GEMINI_IMAGE_SIZE", "1K")
|
||
GEMINI_THINKING_LEVEL = os.getenv("GEMINI_THINKING_LEVEL", "MINIMAL")
|
||
GEMINI_PERSON_GENERATION = os.getenv("GEMINI_PERSON_GENERATION", "")
|
||
|
||
|
||
class GeminiExtractV2Service(BaseService):
|
||
"""Gemini印花提取V2服务类 - 使用服务,更经济"""
|
||
|
||
SERVICE_NAME = "gemini_extract_v2"
|
||
|
||
# 多API配置,按优先级排序(便宜的优先使用)
|
||
API_CONFIGS = [
|
||
|
||
|
||
|
||
|
||
# {
|
||
# "name": "西风接口$0.003逆向",
|
||
# "api_key": "sk-UT9aupbfHI4rc3RUn8x5D8gN5Kk31yvLZQu8M3BCY5Nja1Fc",
|
||
# "api_url": "https://api.apiqik.com/v1/chat/completions" ,
|
||
# "api_model": "gemini-2.5-flash-image",
|
||
# "max_retries": 3, # 贵接口少重试
|
||
# "cost": "低"
|
||
# },
|
||
|
||
|
||
{
|
||
"name": "西风接口$0.014",
|
||
"api_key": "sk-uRuvzLfIHsc3BiHZ2cyebk0cYsZ8NR9rLL326QqXCKIy9EpK",
|
||
"api_url": "https://api.apiqik.online/v1beta/models",
|
||
"api_model": GEMINI_IMAGE_MODEL,
|
||
"max_retries": 2,
|
||
"cost": "中",
|
||
"use_gemini_format": True # 使用Gemini原生API格式
|
||
},
|
||
{
|
||
"name": "西风接口Fallback",
|
||
"api_key": "sk-uRuvzLfIHsc3BiHZ2cyebk0cYsZ8NR9rLL326QqXCKIy9EpK",
|
||
"api_url": "https://api.apiqik.online/v1beta/models",
|
||
"api_model": GEMINI_IMAGE_FALLBACK_MODEL,
|
||
"max_retries": 1,
|
||
"cost": "中",
|
||
"use_gemini_format": True
|
||
},
|
||
|
||
{
|
||
"name": "最贵的",
|
||
"api_key": "sk-8i7uYE0RtnQwDImV8a5f7014DcAb46F6BcEb72Df92218aC8",
|
||
"api_url": "https://api.laozhang.ai/v1/chat/completions",
|
||
"api_model": GEMINI_IMAGE_MODEL,
|
||
"max_retries": 1,
|
||
"cost": "高"
|
||
}
|
||
]
|
||
|
||
# 默认提示词
|
||
DEFAULT_PROMPT = "提取印花图案,把褶皱移除。补齐缺失的部分,要生成完整,细节丰富,严格按照原图的元素位置生成平面的印花图,不要相似的,相似度要100%,生成高质量的印刷图"
|
||
# DEFAULT_PROMPT = "生成图片,把衣服的图案展开起来做成数码印花印刷平面图。去掉皱褶,生成图案增强细节。排除衣服图案以外内容"
|
||
def __init__(self):
|
||
super().__init__(name="gemini_extract_v2")
|
||
self.session = None
|
||
|
||
def image_to_base64(self, image_path: str) -> str:
|
||
"""将图片文件转换为base64编码字符串"""
|
||
try:
|
||
if not os.path.exists(image_path):
|
||
logger.error(f"文件不存在: {image_path}")
|
||
return None
|
||
|
||
with open(image_path, "rb") as image_file:
|
||
encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
|
||
return encoded_string
|
||
|
||
except Exception as e:
|
||
logger.error(f"Base64转换失败: {e}")
|
||
return None
|
||
|
||
async def extract_pattern(
|
||
self,
|
||
input_path: str,
|
||
output_path: str,
|
||
custom_prompt: str = None,
|
||
aspect_ratio: str = "1:1",
|
||
image_size: str = "",
|
||
person_generation: str = "",
|
||
thinking_level: str = "",
|
||
) -> tuple[bool, str, dict]:
|
||
"""
|
||
使用多API配置进行印花图案提取
|
||
|
||
Args:
|
||
input_path: 输入图片路径
|
||
output_path: 输出图片路径
|
||
custom_prompt: 自定义提示词
|
||
|
||
Returns:
|
||
tuple: (success, message, data)
|
||
"""
|
||
# 转换图片为Base64
|
||
img64 = self.image_to_base64(input_path)
|
||
if not img64:
|
||
return False, "图片编码失败", {}
|
||
|
||
# 使用自定义提示词或默认提示词
|
||
prompt = custom_prompt or self.DEFAULT_PROMPT
|
||
|
||
# 按优先级逐个尝试API配置
|
||
for config_index, config in enumerate(self.API_CONFIGS):
|
||
logger.info(f"尝试使用API: {config['name']} (成本: {config['cost']})")
|
||
metrics_emit("gemini_request", model=config.get("api_model", ""), provider=config.get("name", ""))
|
||
|
||
# 对每个API配置进行重试
|
||
for attempt in range(config['max_retries']):
|
||
try:
|
||
logger.info(f"开始Gemini V2印花提取 - {config['name']} (第{attempt + 1}/{config['max_retries']}次尝试): {input_path}")
|
||
|
||
# 准备请求数据和URL
|
||
if config.get('use_gemini_format', False):
|
||
# Gemini原生API格式
|
||
api_url = f"{config['api_url']}/{config['api_model']}:generateContent?key={config['api_key']}"
|
||
headers = {
|
||
"Content-Type": "application/json"
|
||
}
|
||
|
||
# 有效比例列表(Auto 不传 aspectRatio)
|
||
valid_ratios = {"1:1", "9:16", "16:9", "3:4", "4:3", "3:2", "2:3", "5:4", "4:5"}
|
||
valid_sizes = {"1K", "2K", "4K"}
|
||
valid_thinking = {"MINIMAL", "LOW", "MEDIUM", "HIGH"}
|
||
image_config = {}
|
||
if aspect_ratio in valid_ratios:
|
||
image_config["aspectRatio"] = aspect_ratio
|
||
size_val = (image_size or GEMINI_IMAGE_SIZE or "").upper().strip()
|
||
if size_val in valid_sizes:
|
||
image_config["imageSize"] = size_val
|
||
person_val = (person_generation or GEMINI_PERSON_GENERATION or "").strip()
|
||
if person_val:
|
||
# 中转接口若支持该字段会生效;不设置时不发送,保证兼容
|
||
image_config["personGeneration"] = person_val
|
||
thinking_val = (thinking_level or GEMINI_THINKING_LEVEL or "").upper().strip()
|
||
thinking_config = {}
|
||
if thinking_val in valid_thinking:
|
||
thinking_config["thinkingLevel"] = thinking_val
|
||
|
||
data = {
|
||
"contents": [
|
||
{
|
||
"role": "user",
|
||
"parts": [
|
||
{
|
||
"inlineData": {
|
||
"mimeType": "image/jpeg",
|
||
"data": img64
|
||
}
|
||
},
|
||
{
|
||
"text": prompt
|
||
}
|
||
]
|
||
}
|
||
],
|
||
"generationConfig": {
|
||
"responseModalities": ["IMAGE"],
|
||
**({"imageConfig": image_config} if image_config else {}),
|
||
**({"thinkingConfig": thinking_config} if thinking_config else {}),
|
||
}
|
||
}
|
||
logger.info(
|
||
f"Gemini 生成配置: 比例={aspect_ratio} 尺寸={image_config.get('imageSize', '默认')} "
|
||
f"person={image_config.get('personGeneration', '默认')} thinking={thinking_config.get('thinkingLevel', '默认')}"
|
||
)
|
||
else:
|
||
# OpenAI兼容格式
|
||
api_url = config['api_url']
|
||
headers = {
|
||
"Authorization": f"Bearer {config['api_key']}",
|
||
"Content-Type": "application/json"
|
||
}
|
||
|
||
data = {
|
||
"model": config['api_model'],
|
||
"stream": False,
|
||
"messages": [
|
||
{
|
||
"role": "user",
|
||
"content": [
|
||
{
|
||
"type": "text",
|
||
"text": prompt
|
||
},
|
||
{
|
||
"type": "image_url",
|
||
"image_url": {
|
||
"url": f"data:image/png;base64,{img64}"
|
||
}
|
||
}
|
||
]
|
||
}
|
||
]
|
||
}
|
||
|
||
logger.info(f"正在请求{config['name']}服务 (第{attempt + 1}次)...")
|
||
|
||
# 发送异步请求
|
||
timeout = aiohttp.ClientTimeout(total=300, connect=30)
|
||
connector = aiohttp.TCPConnector(limit=10, limit_per_host=5)
|
||
|
||
try:
|
||
async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
|
||
async with session.post(api_url, headers=headers, json=data) as response:
|
||
if response.status != 200:
|
||
error_text = await response.text()
|
||
logger.error(f"{config['name']} API请求失败 (第{attempt + 1}次): {response.status} - {error_text}")
|
||
|
||
# 如果是当前API配置的最后一次重试
|
||
if attempt == config['max_retries'] - 1:
|
||
logger.warning(f"{config['name']} 所有重试已用完,切换到下一个API配置")
|
||
break
|
||
|
||
# 当前API配置内部重试
|
||
base_wait_time = 2
|
||
wait_time = base_wait_time * (attempt + 1)
|
||
logger.info(f"等待{wait_time}秒后重试{config['name']}...")
|
||
await asyncio.sleep(wait_time)
|
||
continue
|
||
|
||
result = await response.json()
|
||
# Gemini 偶发只返回文本不返回图片:NO_IMAGE 时快速重试/降级
|
||
if config.get('use_gemini_format', False):
|
||
finish_reason = ""
|
||
try:
|
||
finish_reason = (
|
||
(result.get("candidates") or [{}])[0].get("finishReason", "")
|
||
)
|
||
except Exception:
|
||
finish_reason = ""
|
||
if finish_reason == "NO_IMAGE":
|
||
logger.warning(
|
||
f"{config['name']} 返回 NO_IMAGE (模型={config.get('api_model')}),第{attempt + 1}次"
|
||
)
|
||
metrics_emit("gemini_no_image", model=config.get("api_model", ""), provider=config.get("name", ""))
|
||
if attempt == config['max_retries'] - 1:
|
||
logger.warning(f"{config['name']} NO_IMAGE 重试已用完,切换下一个配置")
|
||
break
|
||
await asyncio.sleep(1 + attempt)
|
||
continue
|
||
|
||
except (aiohttp.ClientError, asyncio.TimeoutError, AssertionError) as e:
|
||
logger.error(f"{config['name']} 网络连接错误 (第{attempt + 1}次): {str(e)}")
|
||
|
||
# 如果是当前API配置的最后一次重试
|
||
if attempt == config['max_retries'] - 1:
|
||
logger.warning(f"{config['name']} 网络重试已用完,切换到下一个API配置")
|
||
break
|
||
|
||
# 当前API配置内部重试
|
||
base_wait_time = 2
|
||
wait_time = base_wait_time * (attempt + 1)
|
||
logger.info(f"等待{wait_time}秒后重试{config['name']}...")
|
||
await asyncio.sleep(wait_time)
|
||
continue
|
||
|
||
logger.info(f"{config['name']} 服务请求成功 (第{attempt + 1}次),正在处理响应...")
|
||
|
||
# 处理API响应并提取图片
|
||
success, message, data = await self._process_api_response(result, output_path, config['name'], config)
|
||
|
||
if success:
|
||
logger.info(f"使用 {config['name']} 成功完成印花提取")
|
||
metrics_emit("gemini_success", model=config.get("api_model", ""), provider=config.get("name", ""))
|
||
try:
|
||
from utils.api_cost_tracker import record
|
||
record("gemini_extract", count=1)
|
||
except Exception:
|
||
pass
|
||
return True, f"Gemini V2印花提取完成 - 使用{config['name']}", data
|
||
else:
|
||
logger.warning(f"{config['name']} 响应处理失败: {message}")
|
||
|
||
# 如果是当前API配置的最后一次重试
|
||
if attempt == config['max_retries'] - 1:
|
||
logger.warning(f"{config['name']} 所有重试已用完,切换到下一个API配置")
|
||
break
|
||
|
||
# 当前API配置内部重试
|
||
base_wait_time = 2
|
||
wait_time = base_wait_time * (attempt + 1)
|
||
logger.info(f"等待{wait_time}秒后重试{config['name']}...")
|
||
await asyncio.sleep(wait_time)
|
||
continue
|
||
|
||
except Exception as e:
|
||
logger.error(f"{config['name']} API调用异常 (第{attempt + 1}次): {str(e)}")
|
||
|
||
# 如果是当前API配置的最后一次重试
|
||
if attempt == config['max_retries'] - 1:
|
||
logger.warning(f"{config['name']} 异常重试已用完,切换到下一个API配置")
|
||
break
|
||
|
||
# 当前API配置内部重试
|
||
base_wait_time = 2
|
||
wait_time = base_wait_time * (attempt + 1)
|
||
logger.info(f"等待{wait_time}秒后重试{config['name']}...")
|
||
await asyncio.sleep(wait_time)
|
||
continue
|
||
|
||
# 所有API配置都尝试过了,返回失败
|
||
return False, "所有API配置都已尝试失败", {}
|
||
|
||
async def _process_api_response(self, result: dict, output_path: str, api_name: str, config: dict) -> tuple[bool, str, dict]:
|
||
"""处理API响应并提取图片"""
|
||
try:
|
||
# 根据API格式提取内容
|
||
if config.get('use_gemini_format', False):
|
||
# Gemini原生API格式: candidates[0].content.parts[0]
|
||
content_parts = result['candidates'][0]['content']['parts']
|
||
|
||
# 查找包含图片数据的part
|
||
image_data = None
|
||
for part in content_parts:
|
||
# 注意:响应中使用驼峰命名 inlineData
|
||
if 'inlineData' in part:
|
||
# 提取Base64图片数据
|
||
base64_data = part['inlineData']['data']
|
||
logger.info(f"{api_name} 找到Gemini格式的inlineData图片")
|
||
try:
|
||
image_data = base64.b64decode(base64_data)
|
||
break
|
||
except Exception as e:
|
||
logger.error(f"{api_name} Base64解码失败: {e}")
|
||
return False, f"Base64解码失败: {e}", {}
|
||
|
||
if not image_data:
|
||
logger.error(f"{api_name} 在Gemini响应中未找到图片数据")
|
||
return False, "未找到图片数据", {}
|
||
|
||
# 直接保存图片
|
||
return await self._save_image(image_data, output_path, api_name)
|
||
|
||
else:
|
||
# OpenAI兼容格式: choices[0].message.content
|
||
content = result['choices'][0]['message']['content']
|
||
logger.info(f"{api_name} 收到内容: {content[:200]}...")
|
||
|
||
# 使用原有的URL/Base64提取逻辑
|
||
return await self._extract_and_save_image(content, output_path, api_name)
|
||
|
||
except KeyError as e:
|
||
logger.error(f"{api_name} 响应格式不正确,缺少字段: {e}")
|
||
logger.error(f"响应内容: {json.dumps(result, ensure_ascii=False)[:500]}")
|
||
return False, f"响应格式错误: {e}", {}
|
||
except Exception as e:
|
||
logger.error(f"{api_name} 处理响应时发生异常: {e}")
|
||
return False, f"处理异常: {e}", {}
|
||
|
||
async def _save_image(self, image_data: bytes, output_path: str, api_name: str) -> tuple[bool, str, dict]:
|
||
"""保存图片文件"""
|
||
try:
|
||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||
|
||
with open(output_path, 'wb') as f:
|
||
f.write(image_data)
|
||
|
||
logger.info(f"{api_name} 图片已保存到: {output_path}")
|
||
|
||
# 验证保存的图片
|
||
if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
|
||
file_size = os.path.getsize(output_path)
|
||
logger.info(f"{api_name} 图片保存成功,文件大小: {file_size} bytes")
|
||
|
||
return True, f"{api_name} 印花提取完成", {
|
||
'output_path': output_path,
|
||
'file_size': file_size,
|
||
'api_used': api_name
|
||
}
|
||
else:
|
||
logger.error(f"{api_name} 保存的图片文件无效")
|
||
return False, "保存的图片文件无效", {}
|
||
|
||
except Exception as e:
|
||
logger.error(f"{api_name} 保存图片时发生错误: {e}")
|
||
return False, f"保存图片失败: {e}", {}
|
||
|
||
async def _extract_and_save_image(self, content: str, output_path: str, api_name: str) -> tuple[bool, str, dict]:
|
||
"""从响应内容中提取并保存图片(URL或Base64格式)"""
|
||
# 查找和处理图片数据
|
||
image_data = None
|
||
|
||
# 方法1: 查找URL链接 (优先检查URL格式)
|
||
url_match = re.search(r'https?://[^\s\)]+\.(?:png|jpg|jpeg|gif|webp)', content)
|
||
if url_match:
|
||
image_url = url_match.group(0)
|
||
logger.info(f"{api_name} 找到图片URL: {image_url}")
|
||
|
||
# 图片下载重试机制
|
||
download_retries = 3
|
||
for download_attempt in range(download_retries):
|
||
try:
|
||
logger.info(f"{api_name} 开始下载图片 (第{download_attempt + 1}/{download_retries}次尝试): {image_url}")
|
||
|
||
# 异步下载图片,增加超时时间
|
||
timeout = aiohttp.ClientTimeout(total=300, connect=60)
|
||
connector = aiohttp.TCPConnector(limit=5, limit_per_host=2)
|
||
headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'}
|
||
|
||
async with aiohttp.ClientSession(
|
||
timeout=timeout,
|
||
connector=connector,
|
||
headers=headers
|
||
) as download_session:
|
||
logger.info(f"{api_name} 正在发送HTTP请求...")
|
||
async with download_session.get(image_url) as img_response:
|
||
logger.info(f"{api_name} 收到HTTP响应: {img_response.status}")
|
||
if img_response.status == 200:
|
||
image_data = await img_response.read()
|
||
logger.info(f"{api_name} 图片下载成功,大小: {len(image_data)} bytes")
|
||
break # 成功则跳出重试循环
|
||
else:
|
||
logger.error(f"{api_name} 图片下载失败,HTTP状态码: {img_response.status}")
|
||
if download_attempt == download_retries - 1:
|
||
return False, "图片下载失败", {}
|
||
else:
|
||
await asyncio.sleep(2)
|
||
continue
|
||
|
||
except Exception as e:
|
||
logger.error(f"{api_name} 下载图片时发生异常 (第{download_attempt + 1}次): {type(e).__name__}: {str(e)}")
|
||
if download_attempt == download_retries - 1:
|
||
return False, f"图片下载异常: {str(e)}", {}
|
||
else:
|
||
await asyncio.sleep(2)
|
||
continue
|
||
|
||
else:
|
||
# 方法2: 查找标准格式 data:image/type;base64,data
|
||
base64_match = re.search(r'data:image/[^;]+;base64,([A-Za-z0-9+/=]+)', content)
|
||
|
||
if base64_match:
|
||
base64_data = base64_match.group(1)
|
||
logger.info(f"{api_name} 找到标准格式的Base64数据")
|
||
try:
|
||
image_data = base64.b64decode(base64_data)
|
||
except Exception as e:
|
||
logger.error(f"{api_name} Base64解码失败: {e}")
|
||
return False, f"Base64解码失败: {e}", {}
|
||
else:
|
||
# 方法3: 查找纯Base64数据(长字符串)
|
||
base64_match = re.search(r'([A-Za-z0-9+/=]{100,})', content)
|
||
if base64_match:
|
||
base64_data = base64_match.group(1)
|
||
logger.info(f"{api_name} 找到纯Base64数据")
|
||
try:
|
||
image_data = base64.b64decode(base64_data)
|
||
except Exception as e:
|
||
logger.error(f"{api_name} Base64解码失败: {e}")
|
||
return False, f"Base64解码失败: {e}", {}
|
||
else:
|
||
logger.error(f"{api_name} 在响应中未找到图片数据")
|
||
return False, "未找到图片数据", {}
|
||
|
||
# 检查图片数据
|
||
if not image_data:
|
||
logger.error(f"{api_name} 图片数据为空")
|
||
return False, "图片数据为空", {}
|
||
|
||
# 保存图片
|
||
return await self._save_image(image_data, output_path, api_name)
|
||
|
||
async def correct_perspective(
|
||
self,
|
||
input_path: str,
|
||
output_path: str,
|
||
level: str = "mild",
|
||
) -> tuple[bool, str, dict]:
|
||
"""
|
||
透视矫正:先把有透视畸变的图还原为正面平铺视图,再做后续处理。
|
||
|
||
Args:
|
||
input_path: 本地图片路径
|
||
output_path: 矫正后输出路径
|
||
level: "mild" 或 "strong"
|
||
"""
|
||
if level == "strong":
|
||
prompt = (
|
||
"这张图存在明显透视畸变(俯拍/斜拍/贴墙)。"
|
||
"请对图片进行透视矫正:将主体变换为正面平铺视图,"
|
||
"使所有边缘变成水平或垂直,去除梯形形变,"
|
||
"保持图案颜色和细节完全不变,只矫正几何形状,输出矫正后的完整图片。"
|
||
)
|
||
else:
|
||
prompt = (
|
||
"这张图存在轻微透视畸变(衣物悬挂/桌面斜拍)。"
|
||
"请做轻度透视矫正:将主体调整为尽量正视角,"
|
||
"消除轻微的梯形拉伸感,保持图案颜色和细节不变,输出矫正后的图片。"
|
||
)
|
||
|
||
# 透视矫正使用 1:1 比例避免比例失真
|
||
return await self.extract_pattern(
|
||
input_path=input_path,
|
||
output_path=output_path,
|
||
custom_prompt=prompt,
|
||
aspect_ratio="1:1",
|
||
)
|
||
|
||
async def cleanup(self):
|
||
"""清理资源"""
|
||
if self.session and not self.session.closed:
|
||
await self.session.close()
|
||
|
||
# 便捷函数
|
||
async def extract_pattern_v2(
|
||
input_path: str,
|
||
output_path: str,
|
||
custom_prompt: str = None,
|
||
aspect_ratio: str = "1:1",
|
||
) -> tuple[bool, str, dict]:
|
||
"""Gemini V2印花提取便捷函数"""
|
||
service = GeminiExtractV2Service()
|
||
try:
|
||
return await service.extract_pattern(input_path, output_path, custom_prompt, aspect_ratio)
|
||
finally:
|
||
await service.cleanup()
|
||
|
||
if __name__ == "__main__":
|
||
# 测试代码
|
||
import asyncio
|
||
|
||
async def test():
|
||
service = GeminiExtractV2Service()
|
||
|
||
input_path = "F:/api/134.png"
|
||
output_path = "test_output_v2.png"
|
||
|
||
success, message, data = await service.extract_pattern(input_path, output_path)
|
||
|
||
print(f"结果: {success}")
|
||
print(f"消息: {message}")
|
||
print(f"数据: {data}")
|
||
|
||
await service.cleanup()
|
||
|
||
asyncio.run(test())
|