Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled
469 lines
21 KiB
Python
469 lines
21 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
Gemini印花提取V2服务 - 使用服务
|
||
更经济的选择:1.4毛/张
|
||
"""
|
||
|
||
import asyncio
|
||
import aiohttp
|
||
import base64
|
||
import json
|
||
import re
|
||
import os
|
||
import time
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
import logging
|
||
|
||
|
||
# Module-level logger named after this module. Handlers/levels are presumably
# configured by the importing application (the __main__ block calls
# logging.basicConfig when run as a script).
logger = logging.getLogger(__name__)
|
||
|
||
class GeminiExtractV2Service:
    """Gemini print-pattern extraction service, V2 (lower cost per image).

    ``extract_pattern`` sends a garment photo plus a prompt to the endpoints
    in ``API_CONFIGS`` (tried in list order, each retried up to its
    ``max_retries``) and writes the first image returned to ``output_path``.

    API keys are deliberately NOT kept in source. Each config names an
    environment variable in ``api_key_env`` that is read at request time; a
    non-empty ``api_key`` set on the config at runtime takes precedence.
    Configs with no usable key are skipped with a warning.
    """

    SERVICE_NAME = "gemini_extract_v2"

    # Seconds; retry n (1-based) of a config waits RETRY_BASE_WAIT_SECONDS * n.
    RETRY_BASE_WAIT_SECONDS = 2

    # Aspect ratios forwarded to the native Gemini API as
    # generationConfig.imageConfig.aspectRatio; any other value is omitted.
    VALID_ASPECT_RATIOS = frozenset(
        {"1:1", "9:16", "16:9", "3:4", "4:3", "3:2", "2:3", "5:4", "4:5"}
    )

    # API configs in priority order (cheaper endpoints first).
    # SECURITY: the keys that used to be hard-coded here were committed to
    # version control and must be treated as leaked -- rotate them and supply
    # the new ones through the environment variables named in api_key_env.
    # (A cheaper OpenAI-compatible endpoint, api.apiqik.com with
    # gemini-2.5-flash-image, was previously listed here but disabled.)
    API_CONFIGS = [
        {
            "name": "西风接口$0.014",
            "api_key": "",
            "api_key_env": "GEMINI_V2_XIFENG_API_KEY",
            "api_url": "https://api.apiqik.online/v1beta/models",
            "api_model": "gemini-2.5-flash-image",  # the more stable model
            "max_retries": 2,  # pricier endpoint: fewer retries
            "cost": "中",
            "use_gemini_format": True,  # native Gemini generateContent API
        },
        {
            "name": "最贵的",
            "api_key": "",
            "api_key_env": "GEMINI_V2_LAOZHANG_API_KEY",
            "api_url": "https://api.laozhang.ai/v1/chat/completions",
            "api_model": "gemini-2.5-flash-image-preview",
            "max_retries": 1,
            "cost": "高",
        },
    ]

    # Default prompt: extract the print, remove wrinkles, fill in missing
    # regions, and keep element positions, colours and details unchanged.
    DEFAULT_PROMPT = (
        "提取印花图案,去褶皱并补齐缺失区域,生成完整清晰的平面图。"
        "严格保持原图元素位置、颜色和细节,不要改风格。"
    )
    # Earlier alternative prompt, kept for reference:
    # DEFAULT_PROMPT = "生成图片,把衣服的图案展开起来做成数码印花印刷平面图。去掉皱褶,生成图案增强细节。排除衣服图案以外内容"

    # Image URL inside an OpenAI-style text reply. Case-insensitive, and the
    # optional query string is kept so signed/expiring CDN links still work.
    _IMAGE_URL_RE = re.compile(
        r"https?://[^\s)\"'<>]+\.(?:png|jpe?g|gif|webp)(?:\?[^\s)\"'<>]*)?",
        re.IGNORECASE,
    )
    # data:image/<type>;base64,<payload>
    _DATA_URI_RE = re.compile(r"data:image/[^;]+;base64,([A-Za-z0-9+/=]+)")
    # Last resort: any run of 100+ base64-alphabet characters.
    _RAW_BASE64_RE = re.compile(r"([A-Za-z0-9+/=]{100,})")

    def __init__(self):
        # Optional shared aiohttp session closed by cleanup(). Requests in
        # this class open their own short-lived sessions, so this stays None
        # unless a caller assigns one.
        self.session = None

    # ------------------------------------------------------------------ #
    # Small helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _resolve_api_key(config: dict) -> str:
        """Return the API key for ``config``, or "" when none is configured.

        A non-empty ``api_key`` entry wins; otherwise the environment variable
        named by ``api_key_env`` is read.
        """
        explicit = config.get("api_key")
        if explicit:
            return explicit
        env_name = config.get("api_key_env")
        return os.environ.get(env_name, "") if env_name else ""

    @staticmethod
    def _guess_mime_type(image_path: str) -> str:
        """Guess the image MIME type from the file extension.

        Falls back to ``image/png`` when the extension is unknown or not an
        image type.
        """
        import mimetypes

        mime, _ = mimetypes.guess_type(str(image_path))
        return mime if mime and mime.startswith("image/") else "image/png"

    @staticmethod
    def _content_to_text(content) -> str:
        """Normalise an OpenAI-style ``message.content`` value to one string.

        ``content`` may be a string, ``None`` (-> ""), or a list of parts such
        as ``{"type": "text", "text": ...}`` and
        ``{"type": "image_url", "image_url": {"url": ...}}``; the text and URL
        pieces are joined with newlines. Anything else is passed to ``str``.
        """
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            pieces = []
            for part in content:
                if isinstance(part, str):
                    pieces.append(part)
                    continue
                if not isinstance(part, dict):
                    continue
                if isinstance(part.get("text"), str):
                    pieces.append(part["text"])
                image_url = part.get("image_url")
                if isinstance(image_url, dict):
                    image_url = image_url.get("url")
                if isinstance(image_url, str):
                    pieces.append(image_url)
            return "\n".join(pieces)
        return str(content)

    @classmethod
    def _find_image_url(cls, content: str) -> "str | None":
        """Return the first http(s) image URL in ``content``, or None."""
        match = cls._IMAGE_URL_RE.search(content)
        return match.group(0) if match else None

    async def _backoff(self, attempt: int, api_name: str) -> None:
        """Sleep before retrying: RETRY_BASE_WAIT_SECONDS * (attempt + 1)."""
        wait_time = self.RETRY_BASE_WAIT_SECONDS * (attempt + 1)
        logger.info(f"等待{wait_time}秒后重试{api_name}...")
        await asyncio.sleep(wait_time)

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #

    def image_to_base64(self, image_path: str) -> "str | None":
        """Read ``image_path`` and return its bytes base64-encoded as str.

        Returns None (after logging) if the file does not exist or cannot be
        read.
        """
        try:
            if not os.path.exists(image_path):
                logger.error(f"文件不存在: {image_path}")
                return None
            with open(image_path, "rb") as image_file:
                return base64.b64encode(image_file.read()).decode("utf-8")
        except (OSError, TypeError, ValueError) as e:
            logger.error(f"Base64转换失败: {e}")
            return None

    async def extract_pattern(
        self,
        input_path: str,
        output_path: str,
        custom_prompt: "str | None" = None,
        aspect_ratio: str = "1:1",
    ) -> tuple[bool, str, dict]:
        """Extract the print pattern from ``input_path`` into ``output_path``.

        Each entry of ``API_CONFIGS`` is tried in order. An entry is retried
        up to its ``max_retries`` with linear backoff before falling through
        to the next one; the first image successfully saved wins.

        Args:
            input_path: Source garment photo.
            output_path: Where the generated image is written.
            custom_prompt: Prompt override; DEFAULT_PROMPT is used when this
                is empty or None.
            aspect_ratio: Output aspect ratio such as "1:1" or "16:9". Only
                sent to native-Gemini configs, and only when it is one of
                VALID_ASPECT_RATIOS.

        Returns:
            ``(success, message, data)``. On success ``data`` holds
            ``output_path``, ``file_size`` and ``api_used``; otherwise ``{}``.
        """
        img64 = self.image_to_base64(input_path)
        if not img64:
            return False, "图片编码失败", {}

        prompt = custom_prompt or self.DEFAULT_PROMPT
        mime_type = self._guess_mime_type(input_path)

        for config in self.API_CONFIGS:
            name = config["name"]
            max_retries = config["max_retries"]

            api_key = self._resolve_api_key(config)
            if not api_key:
                logger.warning(
                    f"{name} 未配置API密钥 (环境变量: {config.get('api_key_env')}),跳过"
                )
                continue

            logger.info(f"尝试使用API: {name} (成本: {config['cost']})")

            for attempt in range(max_retries):
                logger.info(
                    f"开始Gemini V2印花提取 - {name} "
                    f"(第{attempt + 1}/{max_retries}次尝试): {input_path}"
                )
                try:
                    api_url, headers, payload = self._build_request(
                        config, api_key, img64, mime_type, prompt, aspect_ratio
                    )
                    logger.info(f"正在请求{name}服务 (第{attempt + 1}次)...")
                    result = await self._post_json(api_url, headers, payload, name, attempt)

                    if result is not None:
                        logger.info(f"{name} 服务请求成功 (第{attempt + 1}次),正在处理响应...")
                        success, message, data = await self._process_api_response(
                            result, output_path, name, config
                        )
                        if success:
                            logger.info(f"使用 {name} 成功完成印花提取")
                            return True, f"Gemini V2印花提取完成 - 使用{name}", data
                        logger.warning(f"{name} 响应处理失败: {message}")
                except Exception as e:
                    # An unexpected error only fails this attempt; the retry /
                    # fallback logic below still runs.
                    logger.exception(f"{name} API调用异常 (第{attempt + 1}次): {e}")

                if attempt < max_retries - 1:
                    await self._backoff(attempt, name)

            logger.warning(f"{name} 所有重试已用完,切换到下一个API配置")

        return False, "所有API配置都已尝试失败", {}

    # ------------------------------------------------------------------ #
    # Request / response plumbing
    # ------------------------------------------------------------------ #

    def _build_request(
        self,
        config: dict,
        api_key: str,
        img64: str,
        mime_type: str,
        prompt: str,
        aspect_ratio: str,
    ) -> "tuple[str, dict, dict]":
        """Return ``(url, headers, json_payload)`` for one call to ``config``."""
        if config.get("use_gemini_format", False):
            # Native Gemini generateContent; the key travels as a query param.
            api_url = (
                f"{config['api_url']}/{config['api_model']}:generateContent?key={api_key}"
            )
            generation_config = {"responseModalities": ["IMAGE"]}  # image output only
            if aspect_ratio in self.VALID_ASPECT_RATIOS:
                generation_config["imageConfig"] = {"aspectRatio": aspect_ratio}
            payload = {
                "contents": [
                    {
                        "role": "user",
                        "parts": [
                            {"inlineData": {"mimeType": mime_type, "data": img64}},
                            {"text": prompt},
                        ],
                    }
                ],
                "generationConfig": generation_config,
            }
            return api_url, {"Content-Type": "application/json"}, payload

        # OpenAI-compatible chat/completions; aspect_ratio is not forwarded.
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": config["api_model"],
            "stream": False,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:{mime_type};base64,{img64}"},
                        },
                    ],
                }
            ],
        }
        return config["api_url"], headers, payload

    async def _post_json(
        self, api_url: str, headers: dict, payload: dict, api_name: str, attempt: int
    ) -> "dict | None":
        """POST ``payload`` as JSON and return the decoded JSON object.

        Returns None (after logging) on a non-200 status, a non-object JSON
        body, or a network/timeout error, so the caller can retry. Other
        exceptions propagate.
        """
        try:
            timeout = aiohttp.ClientTimeout(total=300, connect=30)
            connector = aiohttp.TCPConnector(limit=10, limit_per_host=5)
            async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
                async with session.post(api_url, headers=headers, json=payload) as response:
                    if response.status != 200:
                        error_text = await response.text()
                        # Truncated: error bodies can be very large.
                        logger.error(
                            f"{api_name} API请求失败 (第{attempt + 1}次): "
                            f"{response.status} - {error_text[:1000]}"
                        )
                        return None
                    result = await response.json()
        except (aiohttp.ClientError, asyncio.TimeoutError, AssertionError) as e:
            logger.error(f"{api_name} 网络连接错误 (第{attempt + 1}次): {e}")
            return None

        if not isinstance(result, dict):
            logger.error(
                f"{api_name} 响应不是JSON对象 (第{attempt + 1}次): {type(result).__name__}"
            )
            return None
        return result

    async def _process_api_response(
        self, result: dict, output_path: str, api_name: str, config: dict
    ) -> tuple[bool, str, dict]:
        """Pull the generated image out of a decoded API response and save it.

        Native Gemini responses (``config['use_gemini_format']``) carry the
        image as base64 in ``candidates[0].content.parts[*].inlineData.data``.
        OpenAI-compatible responses carry it inside
        ``choices[0].message.content`` as a URL or base64 text.

        Returns:
            The ``(success, message, data)`` tuple from ``_save_image``, or
            ``(False, reason, {})`` when no image can be extracted.
        """
        try:
            if config.get("use_gemini_format", False):
                parts = result["candidates"][0]["content"]["parts"]
                for part in parts:
                    # The REST API uses camelCase "inlineData"; the snake_case
                    # spelling is accepted too in case a proxy rewrites keys.
                    inline = part.get("inlineData") or part.get("inline_data")
                    if not inline:
                        continue
                    logger.info(f"{api_name} 找到Gemini格式的inlineData图片")
                    try:
                        image_data = base64.b64decode(inline["data"])
                    except ValueError as e:  # binascii.Error subclasses ValueError
                        logger.error(f"{api_name} Base64解码失败: {e}")
                        return False, f"Base64解码失败: {e}", {}
                    if image_data:
                        return await self._save_image(image_data, output_path, api_name)

                logger.error(f"{api_name} 在Gemini响应中未找到图片数据")
                return False, "未找到图片数据", {}

            content = self._content_to_text(result["choices"][0]["message"]["content"])
            logger.info(f"{api_name} 收到内容: {content[:200]}...")
            return await self._extract_and_save_image(content, output_path, api_name)

        except (KeyError, IndexError, TypeError) as e:
            logger.error(f"{api_name} 响应格式不正确,缺少字段: {e}")
            logger.error(
                f"响应内容: {json.dumps(result, ensure_ascii=False, default=str)[:500]}"
            )
            return False, f"响应格式错误: {e}", {}
        except Exception as e:
            logger.error(f"{api_name} 处理响应时发生异常: {e}")
            return False, f"处理异常: {e}", {}

    async def _download_image(
        self, image_url: str, api_name: str, retries: int = 3
    ) -> "tuple[bytes | None, str]":
        """Download ``image_url`` with up to ``retries`` attempts, 2 s apart.

        Returns:
            ``(image_bytes, "")`` on HTTP 200, otherwise ``(None, message)``
            describing the last failed attempt.
        """
        error = "图片下载失败"
        for attempt in range(retries):
            try:
                logger.info(
                    f"{api_name} 开始下载图片 (第{attempt + 1}/{retries}次尝试): {image_url}"
                )
                timeout = aiohttp.ClientTimeout(total=300, connect=60)
                connector = aiohttp.TCPConnector(limit=5, limit_per_host=2)
                headers = {
                    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
                }
                async with aiohttp.ClientSession(
                    timeout=timeout, connector=connector, headers=headers
                ) as download_session:
                    logger.info(f"{api_name} 正在发送HTTP请求...")
                    async with download_session.get(image_url) as img_response:
                        logger.info(f"{api_name} 收到HTTP响应: {img_response.status}")
                        if img_response.status == 200:
                            image_data = await img_response.read()
                            logger.info(f"{api_name} 图片下载成功,大小: {len(image_data)} bytes")
                            return image_data, ""
                        logger.error(f"{api_name} 图片下载失败,HTTP状态码: {img_response.status}")
                        error = "图片下载失败"
            except Exception as e:
                logger.error(
                    f"{api_name} 下载图片时发生异常 (第{attempt + 1}次): {type(e).__name__}: {e}"
                )
                error = f"图片下载异常: {e}"

            if attempt < retries - 1:
                await asyncio.sleep(2)

        return None, error

    async def _extract_and_save_image(
        self, content: str, output_path: str, api_name: str
    ) -> tuple[bool, str, dict]:
        """Find an image in text ``content`` and save it to ``output_path``.

        Lookup order: an http(s) image URL (downloaded with retries), then a
        ``data:image/...;base64,`` URI, then any run of 100+ base64 characters.
        """
        image_url = self._find_image_url(content)
        if image_url:
            logger.info(f"{api_name} 找到图片URL: {image_url}")
            image_data, error = await self._download_image(image_url, api_name)
            if image_data is None:
                return False, error, {}
        else:
            data_uri = self._DATA_URI_RE.search(content)
            if data_uri:
                logger.info(f"{api_name} 找到标准格式的Base64数据")
                base64_data = data_uri.group(1)
            else:
                raw = self._RAW_BASE64_RE.search(content)
                if not raw:
                    logger.error(f"{api_name} 在响应中未找到图片数据")
                    return False, "未找到图片数据", {}
                logger.info(f"{api_name} 找到纯Base64数据")
                base64_data = raw.group(1)

            try:
                image_data = base64.b64decode(base64_data)
            except ValueError as e:  # includes binascii.Error (bad padding etc.)
                logger.error(f"{api_name} Base64解码失败: {e}")
                return False, f"Base64解码失败: {e}", {}

        if not image_data:
            logger.error(f"{api_name} 图片数据为空")
            return False, "图片数据为空", {}

        return await self._save_image(image_data, output_path, api_name)

    async def _save_image(
        self, image_data: bytes, output_path: str, api_name: str
    ) -> tuple[bool, str, dict]:
        """Write ``image_data`` to ``output_path``, creating parent dirs.

        The bytes go to a temporary ``<output_path>.part`` file that is then
        moved into place with ``os.replace``, so a failed write never leaves a
        truncated image at ``output_path``.

        Returns:
            ``(True, message, {"output_path", "file_size", "api_used"})`` on
            success, otherwise ``(False, error_message, {})``.
        """
        tmp_path = f"{output_path}.part"
        try:
            output_dir = os.path.dirname(output_path)
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)

            with open(tmp_path, "wb") as f:
                f.write(image_data)
            os.replace(tmp_path, output_path)
            logger.info(f"{api_name} 图片已保存到: {output_path}")

            file_size = os.path.getsize(output_path)
            if file_size <= 0:
                logger.error(f"{api_name} 保存的图片文件无效")
                return False, "保存的图片文件无效", {}

            logger.info(f"{api_name} 图片保存成功,文件大小: {file_size} bytes")
            return True, f"{api_name} 印花提取完成", {
                "output_path": output_path,
                "file_size": file_size,
                "api_used": api_name,
            }
        except (OSError, TypeError) as e:
            logger.error(f"{api_name} 保存图片时发生错误: {e}")
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # temp file was never created or is already gone
            return False, f"保存图片失败: {e}", {}

    async def cleanup(self):
        """Close ``self.session`` if one is attached and still open."""
        if self.session and not self.session.closed:
            await self.session.close()
|
||
|
||
# Convenience function
async def extract_pattern_v2(
    input_path: str,
    output_path: str,
    custom_prompt: "str | None" = None,
    aspect_ratio: str = "1:1",
) -> tuple[bool, str, dict]:
    """Run one Gemini V2 print extraction with a throwaway service instance.

    Args:
        input_path: Source garment photo.
        output_path: Where the generated image is written.
        custom_prompt: Optional prompt override (service default when None).
        aspect_ratio: Output aspect ratio forwarded to
            ``GeminiExtractV2Service.extract_pattern``; defaults to "1:1",
            the same default the service method uses.

    Returns:
        The ``(success, message, data)`` tuple from ``extract_pattern``.
    """
    service = GeminiExtractV2Service()
    try:
        return await service.extract_pattern(
            input_path, output_path, custom_prompt, aspect_ratio=aspect_ratio
        )
    finally:
        # Always release the service's resources, even if extraction raises.
        await service.cleanup()
|
||
|
||
if __name__ == "__main__":
    import sys

    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s] %(levelname)s %(name)s: %(message)s",
        datefmt="%H:%M:%S",
    )

    async def test():
        """Manual smoke test: extract the print from one local image.

        The input path is the first command-line argument (default
        ``image.png``); the result is written to a timestamped
        ``image_output_<unix-seconds>.png`` in the working directory.
        """
        service = GeminiExtractV2Service()

        input_path = sys.argv[1] if len(sys.argv) > 1 else "image.png"
        output_path = f"image_output_{int(time.time())}.png"

        try:
            success, message, data = await service.extract_pattern(input_path, output_path)
        finally:
            # Release resources even if extraction raises.
            await service.cleanup()

        print(f"结果: {success}")
        print(f"消息: {message}")
        print(f"数据: {data}")

    asyncio.run(test())
|