init
This commit is contained in:
0
services/__init__.py
Normal file
0
services/__init__.py
Normal file
BIN
services/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
services/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
services/__pycache__/service_gemini.cpython-310.pyc
Normal file
BIN
services/__pycache__/service_gemini.cpython-310.pyc
Normal file
Binary file not shown.
BIN
services/__pycache__/service_meitu.cpython-310.pyc
Normal file
BIN
services/__pycache__/service_meitu.cpython-310.pyc
Normal file
Binary file not shown.
BIN
services/__pycache__/service_vectorizer.cpython-310.pyc
Normal file
BIN
services/__pycache__/service_vectorizer.cpython-310.pyc
Normal file
Binary file not shown.
512
services/service_gemini.py
Normal file
512
services/service_gemini.py
Normal file
@@ -0,0 +1,512 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Gemini印花提取V2服务 - 使用服务
|
||||
更经济的选择:1.4毛/张
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import base64
|
||||
import json
|
||||
import re
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ServiceBase:
|
||||
"""最小化基类,替代缺失的 utils.service_base"""
|
||||
pass
|
||||
|
||||
|
||||
class GeminiExtractV2Service(ServiceBase):
|
||||
"""Gemini印花提取V2服务类 - 使用服务,更经济"""
|
||||
|
||||
SERVICE_NAME = "gemini_extract_v2"
|
||||
|
||||
# 多API配置,按优先级排序(便宜的优先使用)
|
||||
API_CONFIGS = [
|
||||
|
||||
|
||||
|
||||
|
||||
# {
|
||||
# "name": "西风接口$0.003逆向",
|
||||
# "api_key": "sk-UT9aupbfHI4rc3RUn8x5D8gN5Kk31yvLZQu8M3BCY5Nja1Fc",
|
||||
# "api_url": "https://api.apiqik.com/v1/chat/completions" ,
|
||||
# "api_model": "gemini-2.5-flash-image",
|
||||
# "max_retries": 3, # 贵接口少重试
|
||||
# "cost": "低"
|
||||
# },
|
||||
|
||||
|
||||
{
|
||||
"name": "西风接口$0.014",
|
||||
"api_key": "sk-uRuvzLfIHsc3BiHZ2cyebk0cYsZ8NR9rLL326QqXCKIy9EpK",
|
||||
"api_url": "https://api.apiqik.online/v1beta/models",
|
||||
"api_model": "gemini-2.5-flash-image", # 更稳定的模型
|
||||
"max_retries": 2, # 贵接口少重试
|
||||
"cost": "中",
|
||||
"use_gemini_format": True # 使用Gemini原生API格式
|
||||
},
|
||||
|
||||
{
|
||||
"name": "最贵的",
|
||||
"api_key": "sk-8i7uYE0RtnQwDImV8a5f7014DcAb46F6BcEb72Df92218aC8",
|
||||
"api_url": "https://api.laozhang.ai/v1/chat/completions",
|
||||
"api_model": "gemini-2.5-flash-image-preview",
|
||||
"max_retries": 1,
|
||||
"cost": "高"
|
||||
}
|
||||
]
|
||||
|
||||
# 默认提示词
|
||||
DEFAULT_PROMPT = "提取印花图案,把褶皱移除。补齐缺失的部分,要生成完整,细节丰富,严格按照原图的元素位置生成平面的印花图,不要相似的,相似度要100%,生成高质量的印刷图"
|
||||
# DEFAULT_PROMPT = "生成图片,把衣服的图案展开起来做成数码印花印刷平面图。去掉皱褶,生成图案增强细节。排除衣服图案以外内容"
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.session = None
|
||||
|
||||
def image_to_base64(self, image_path: str) -> str:
|
||||
"""将图片文件转换为base64编码字符串"""
|
||||
try:
|
||||
if not os.path.exists(image_path):
|
||||
logger.error(f"文件不存在: {image_path}")
|
||||
return None
|
||||
|
||||
with open(image_path, "rb") as image_file:
|
||||
encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
|
||||
return encoded_string
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Base64转换失败: {e}")
|
||||
return None
|
||||
|
||||
async def extract_pattern(
|
||||
self,
|
||||
input_path: str,
|
||||
output_path: str,
|
||||
custom_prompt: str = None,
|
||||
aspect_ratio: str = "1:1",
|
||||
) -> tuple[bool, str, dict]:
|
||||
"""
|
||||
使用多API配置进行印花图案提取
|
||||
|
||||
Args:
|
||||
input_path: 输入图片路径
|
||||
output_path: 输出图片路径
|
||||
custom_prompt: 自定义提示词
|
||||
|
||||
Returns:
|
||||
tuple: (success, message, data)
|
||||
"""
|
||||
# 转换图片为Base64
|
||||
img64 = self.image_to_base64(input_path)
|
||||
if not img64:
|
||||
return False, "图片编码失败", {}
|
||||
|
||||
# 使用自定义提示词或默认提示词
|
||||
prompt = custom_prompt or self.DEFAULT_PROMPT
|
||||
|
||||
# 按优先级逐个尝试API配置
|
||||
for config_index, config in enumerate(self.API_CONFIGS):
|
||||
logger.info(f"尝试使用API: {config['name']} (成本: {config['cost']})")
|
||||
|
||||
# 对每个API配置进行重试
|
||||
for attempt in range(config['max_retries']):
|
||||
try:
|
||||
logger.info(f"开始Gemini V2印花提取 - {config['name']} (第{attempt + 1}/{config['max_retries']}次尝试): {input_path}")
|
||||
|
||||
# 准备请求数据和URL
|
||||
if config.get('use_gemini_format', False):
|
||||
# Gemini原生API格式
|
||||
api_url = f"{config['api_url']}/{config['api_model']}:generateContent?key={config['api_key']}"
|
||||
headers = {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
# 有效比例列表(Auto 不传 aspectRatio)
|
||||
valid_ratios = {"1:1", "9:16", "16:9", "3:4", "4:3", "3:2", "2:3", "5:4", "4:5"}
|
||||
image_config = {}
|
||||
if aspect_ratio in valid_ratios:
|
||||
image_config["aspectRatio"] = aspect_ratio
|
||||
|
||||
data = {
|
||||
"contents": [
|
||||
{
|
||||
"role": "user",
|
||||
"parts": [
|
||||
{
|
||||
"inlineData": {
|
||||
"mimeType": "image/jpeg",
|
||||
"data": img64
|
||||
}
|
||||
},
|
||||
{
|
||||
"text": prompt
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"generationConfig": {
|
||||
"responseModalities": ["IMAGE"],
|
||||
**({"imageConfig": image_config} if image_config else {}),
|
||||
}
|
||||
}
|
||||
logger.info(f"Gemini 生成配置: 比例={aspect_ratio} 格式=JPEG")
|
||||
else:
|
||||
# OpenAI兼容格式
|
||||
api_url = config['api_url']
|
||||
headers = {
|
||||
"Authorization": f"Bearer {config['api_key']}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
data = {
|
||||
"model": config['api_model'],
|
||||
"stream": False,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": prompt
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{img64}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
logger.info(f"正在请求{config['name']}服务 (第{attempt + 1}次)...")
|
||||
|
||||
# 发送异步请求
|
||||
timeout = aiohttp.ClientTimeout(total=300, connect=30)
|
||||
connector = aiohttp.TCPConnector(limit=10, limit_per_host=5)
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
|
||||
async with session.post(api_url, headers=headers, json=data) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
logger.error(f"{config['name']} API请求失败 (第{attempt + 1}次): {response.status} - {error_text}")
|
||||
|
||||
# 如果是当前API配置的最后一次重试
|
||||
if attempt == config['max_retries'] - 1:
|
||||
logger.warning(f"{config['name']} 所有重试已用完,切换到下一个API配置")
|
||||
break
|
||||
|
||||
# 当前API配置内部重试
|
||||
base_wait_time = 2
|
||||
wait_time = base_wait_time * (attempt + 1)
|
||||
logger.info(f"等待{wait_time}秒后重试{config['name']}...")
|
||||
await asyncio.sleep(wait_time)
|
||||
continue
|
||||
|
||||
result = await response.json()
|
||||
|
||||
except (aiohttp.ClientError, asyncio.TimeoutError, AssertionError) as e:
|
||||
logger.error(f"{config['name']} 网络连接错误 (第{attempt + 1}次): {str(e)}")
|
||||
|
||||
# 如果是当前API配置的最后一次重试
|
||||
if attempt == config['max_retries'] - 1:
|
||||
logger.warning(f"{config['name']} 网络重试已用完,切换到下一个API配置")
|
||||
break
|
||||
|
||||
# 当前API配置内部重试
|
||||
base_wait_time = 2
|
||||
wait_time = base_wait_time * (attempt + 1)
|
||||
logger.info(f"等待{wait_time}秒后重试{config['name']}...")
|
||||
await asyncio.sleep(wait_time)
|
||||
continue
|
||||
|
||||
logger.info(f"{config['name']} 服务请求成功 (第{attempt + 1}次),正在处理响应...")
|
||||
|
||||
# 处理API响应并提取图片
|
||||
success, message, data = await self._process_api_response(result, output_path, config['name'], config)
|
||||
|
||||
if success:
|
||||
logger.info(f"使用 {config['name']} 成功完成印花提取")
|
||||
try:
|
||||
from utils.api_cost_tracker import record
|
||||
record("gemini_extract", count=1)
|
||||
except Exception:
|
||||
pass
|
||||
return True, f"Gemini V2印花提取完成 - 使用{config['name']}", data
|
||||
else:
|
||||
logger.warning(f"{config['name']} 响应处理失败: {message}")
|
||||
|
||||
# 如果是当前API配置的最后一次重试
|
||||
if attempt == config['max_retries'] - 1:
|
||||
logger.warning(f"{config['name']} 所有重试已用完,切换到下一个API配置")
|
||||
break
|
||||
|
||||
# 当前API配置内部重试
|
||||
base_wait_time = 2
|
||||
wait_time = base_wait_time * (attempt + 1)
|
||||
logger.info(f"等待{wait_time}秒后重试{config['name']}...")
|
||||
await asyncio.sleep(wait_time)
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{config['name']} API调用异常 (第{attempt + 1}次): {str(e)}")
|
||||
|
||||
# 如果是当前API配置的最后一次重试
|
||||
if attempt == config['max_retries'] - 1:
|
||||
logger.warning(f"{config['name']} 异常重试已用完,切换到下一个API配置")
|
||||
break
|
||||
|
||||
# 当前API配置内部重试
|
||||
base_wait_time = 2
|
||||
wait_time = base_wait_time * (attempt + 1)
|
||||
logger.info(f"等待{wait_time}秒后重试{config['name']}...")
|
||||
await asyncio.sleep(wait_time)
|
||||
continue
|
||||
|
||||
# 所有API配置都尝试过了,返回失败
|
||||
return False, "所有API配置都已尝试失败", {}
|
||||
|
||||
async def _process_api_response(self, result: dict, output_path: str, api_name: str, config: dict) -> tuple[bool, str, dict]:
|
||||
"""处理API响应并提取图片"""
|
||||
try:
|
||||
# 根据API格式提取内容
|
||||
if config.get('use_gemini_format', False):
|
||||
# Gemini原生API格式: candidates[0].content.parts[0]
|
||||
content_parts = result['candidates'][0]['content']['parts']
|
||||
|
||||
# 查找包含图片数据的part
|
||||
image_data = None
|
||||
for part in content_parts:
|
||||
# 注意:响应中使用驼峰命名 inlineData
|
||||
if 'inlineData' in part:
|
||||
# 提取Base64图片数据
|
||||
base64_data = part['inlineData']['data']
|
||||
logger.info(f"{api_name} 找到Gemini格式的inlineData图片")
|
||||
try:
|
||||
image_data = base64.b64decode(base64_data)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"{api_name} Base64解码失败: {e}")
|
||||
return False, f"Base64解码失败: {e}", {}
|
||||
|
||||
if not image_data:
|
||||
logger.error(f"{api_name} 在Gemini响应中未找到图片数据")
|
||||
return False, "未找到图片数据", {}
|
||||
|
||||
# 直接保存图片
|
||||
return await self._save_image(image_data, output_path, api_name)
|
||||
|
||||
else:
|
||||
# OpenAI兼容格式: choices[0].message.content
|
||||
content = result['choices'][0]['message']['content']
|
||||
logger.info(f"{api_name} 收到内容: {content[:200]}...")
|
||||
|
||||
# 使用原有的URL/Base64提取逻辑
|
||||
return await self._extract_and_save_image(content, output_path, api_name)
|
||||
|
||||
except KeyError as e:
|
||||
logger.error(f"{api_name} 响应格式不正确,缺少字段: {e}")
|
||||
logger.error(f"响应内容: {json.dumps(result, ensure_ascii=False)[:500]}")
|
||||
return False, f"响应格式错误: {e}", {}
|
||||
except Exception as e:
|
||||
logger.error(f"{api_name} 处理响应时发生异常: {e}")
|
||||
return False, f"处理异常: {e}", {}
|
||||
|
||||
async def _save_image(self, image_data: bytes, output_path: str, api_name: str) -> tuple[bool, str, dict]:
|
||||
"""保存图片文件"""
|
||||
try:
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
|
||||
with open(output_path, 'wb') as f:
|
||||
f.write(image_data)
|
||||
|
||||
logger.info(f"{api_name} 图片已保存到: {output_path}")
|
||||
|
||||
# 验证保存的图片
|
||||
if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
|
||||
file_size = os.path.getsize(output_path)
|
||||
logger.info(f"{api_name} 图片保存成功,文件大小: {file_size} bytes")
|
||||
|
||||
return True, f"{api_name} 印花提取完成", {
|
||||
'output_path': output_path,
|
||||
'file_size': file_size,
|
||||
'api_used': api_name
|
||||
}
|
||||
else:
|
||||
logger.error(f"{api_name} 保存的图片文件无效")
|
||||
return False, "保存的图片文件无效", {}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{api_name} 保存图片时发生错误: {e}")
|
||||
return False, f"保存图片失败: {e}", {}
|
||||
|
||||
async def _extract_and_save_image(self, content: str, output_path: str, api_name: str) -> tuple[bool, str, dict]:
|
||||
"""从响应内容中提取并保存图片(URL或Base64格式)"""
|
||||
# 查找和处理图片数据
|
||||
image_data = None
|
||||
|
||||
# 方法1: 查找URL链接 (优先检查URL格式)
|
||||
url_match = re.search(r'https?://[^\s\)]+\.(?:png|jpg|jpeg|gif|webp)', content)
|
||||
if url_match:
|
||||
image_url = url_match.group(0)
|
||||
logger.info(f"{api_name} 找到图片URL: {image_url}")
|
||||
|
||||
# 图片下载重试机制
|
||||
download_retries = 3
|
||||
for download_attempt in range(download_retries):
|
||||
try:
|
||||
logger.info(f"{api_name} 开始下载图片 (第{download_attempt + 1}/{download_retries}次尝试): {image_url}")
|
||||
|
||||
# 异步下载图片,增加超时时间
|
||||
timeout = aiohttp.ClientTimeout(total=300, connect=60)
|
||||
connector = aiohttp.TCPConnector(limit=5, limit_per_host=2)
|
||||
headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'}
|
||||
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=timeout,
|
||||
connector=connector,
|
||||
headers=headers
|
||||
) as download_session:
|
||||
logger.info(f"{api_name} 正在发送HTTP请求...")
|
||||
async with download_session.get(image_url) as img_response:
|
||||
logger.info(f"{api_name} 收到HTTP响应: {img_response.status}")
|
||||
if img_response.status == 200:
|
||||
image_data = await img_response.read()
|
||||
logger.info(f"{api_name} 图片下载成功,大小: {len(image_data)} bytes")
|
||||
break # 成功则跳出重试循环
|
||||
else:
|
||||
logger.error(f"{api_name} 图片下载失败,HTTP状态码: {img_response.status}")
|
||||
if download_attempt == download_retries - 1:
|
||||
return False, "图片下载失败", {}
|
||||
else:
|
||||
await asyncio.sleep(2)
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{api_name} 下载图片时发生异常 (第{download_attempt + 1}次): {type(e).__name__}: {str(e)}")
|
||||
if download_attempt == download_retries - 1:
|
||||
return False, f"图片下载异常: {str(e)}", {}
|
||||
else:
|
||||
await asyncio.sleep(2)
|
||||
continue
|
||||
|
||||
else:
|
||||
# 方法2: 查找标准格式 data:image/type;base64,data
|
||||
base64_match = re.search(r'data:image/[^;]+;base64,([A-Za-z0-9+/=]+)', content)
|
||||
|
||||
if base64_match:
|
||||
base64_data = base64_match.group(1)
|
||||
logger.info(f"{api_name} 找到标准格式的Base64数据")
|
||||
try:
|
||||
image_data = base64.b64decode(base64_data)
|
||||
except Exception as e:
|
||||
logger.error(f"{api_name} Base64解码失败: {e}")
|
||||
return False, f"Base64解码失败: {e}", {}
|
||||
else:
|
||||
# 方法3: 查找纯Base64数据(长字符串)
|
||||
base64_match = re.search(r'([A-Za-z0-9+/=]{100,})', content)
|
||||
if base64_match:
|
||||
base64_data = base64_match.group(1)
|
||||
logger.info(f"{api_name} 找到纯Base64数据")
|
||||
try:
|
||||
image_data = base64.b64decode(base64_data)
|
||||
except Exception as e:
|
||||
logger.error(f"{api_name} Base64解码失败: {e}")
|
||||
return False, f"Base64解码失败: {e}", {}
|
||||
else:
|
||||
logger.error(f"{api_name} 在响应中未找到图片数据")
|
||||
return False, "未找到图片数据", {}
|
||||
|
||||
# 检查图片数据
|
||||
if not image_data:
|
||||
logger.error(f"{api_name} 图片数据为空")
|
||||
return False, "图片数据为空", {}
|
||||
|
||||
# 保存图片
|
||||
return await self._save_image(image_data, output_path, api_name)
|
||||
|
||||
async def correct_perspective(
|
||||
self,
|
||||
input_path: str,
|
||||
output_path: str,
|
||||
level: str = "mild",
|
||||
) -> tuple[bool, str, dict]:
|
||||
"""
|
||||
透视矫正:先把有透视畸变的图还原为正面平铺视图,再做后续处理。
|
||||
|
||||
Args:
|
||||
input_path: 本地图片路径
|
||||
output_path: 矫正后输出路径
|
||||
level: "mild" 或 "strong"
|
||||
"""
|
||||
if level == "strong":
|
||||
prompt = (
|
||||
"这张图存在明显透视畸变(俯拍/斜拍/贴墙)。"
|
||||
"请对图片进行透视矫正:将主体变换为正面平铺视图,"
|
||||
"使所有边缘变成水平或垂直,去除梯形形变,"
|
||||
"保持图案颜色和细节完全不变,只矫正几何形状,输出矫正后的完整图片。"
|
||||
)
|
||||
else:
|
||||
prompt = (
|
||||
"这张图存在轻微透视畸变(衣物悬挂/桌面斜拍)。"
|
||||
"请做轻度透视矫正:将主体调整为尽量正视角,"
|
||||
"消除轻微的梯形拉伸感,保持图案颜色和细节不变,输出矫正后的图片。"
|
||||
)
|
||||
|
||||
# 透视矫正使用 1:1 比例避免比例失真
|
||||
return await self.extract_pattern(
|
||||
input_path=input_path,
|
||||
output_path=output_path,
|
||||
custom_prompt=prompt,
|
||||
aspect_ratio="1:1",
|
||||
)
|
||||
|
||||
async def cleanup(self):
|
||||
"""清理资源"""
|
||||
if self.session and not self.session.closed:
|
||||
await self.session.close()
|
||||
|
||||
# 便捷函数
|
||||
async def extract_pattern_v2(
|
||||
input_path: str,
|
||||
output_path: str,
|
||||
custom_prompt: str = None,
|
||||
aspect_ratio: str = "1:1",
|
||||
) -> tuple[bool, str, dict]:
|
||||
"""Gemini V2印花提取便捷函数"""
|
||||
service = GeminiExtractV2Service()
|
||||
try:
|
||||
return await service.extract_pattern(input_path, output_path, custom_prompt, aspect_ratio)
|
||||
finally:
|
||||
await service.cleanup()
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试代码
|
||||
import asyncio
|
||||
|
||||
async def test():
|
||||
service = GeminiExtractV2Service()
|
||||
|
||||
input_path = "F:/api/134.png"
|
||||
output_path = "test_output_v2.png"
|
||||
|
||||
success, message, data = await service.extract_pattern(input_path, output_path)
|
||||
|
||||
print(f"结果: {success}")
|
||||
print(f"消息: {message}")
|
||||
print(f"数据: {data}")
|
||||
|
||||
await service.cleanup()
|
||||
|
||||
asyncio.run(test())
|
||||
755
services/service_meitu.py
Normal file
755
services/service_meitu.py
Normal file
@@ -0,0 +1,755 @@
|
||||
"""
|
||||
美图API服务模块 - 处理与美图API的交互
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
import json
|
||||
import asyncio
|
||||
import logging
|
||||
import aiohttp
|
||||
import aiofiles
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any, Callable, List, Tuple
|
||||
|
||||
# 配置日志
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 统计信息
|
||||
class MeituServiceStats:
|
||||
def __init__(self):
|
||||
self.total_requests = 0
|
||||
self.successful_requests = 0
|
||||
self.failed_requests = 0
|
||||
self.timeout_requests = 0
|
||||
self.network_error_requests = 0
|
||||
self.last_success_time = None
|
||||
self.last_error_time = None
|
||||
self.last_error_message = None
|
||||
|
||||
def record_success(self):
|
||||
self.total_requests += 1
|
||||
self.successful_requests += 1
|
||||
self.last_success_time = time.time()
|
||||
|
||||
def record_failure(self, error_type="general", message=""):
|
||||
self.total_requests += 1
|
||||
self.failed_requests += 1
|
||||
self.last_error_time = time.time()
|
||||
self.last_error_message = message
|
||||
|
||||
if error_type == "timeout":
|
||||
self.timeout_requests += 1
|
||||
elif error_type == "network":
|
||||
self.network_error_requests += 1
|
||||
|
||||
def get_success_rate(self):
|
||||
if self.total_requests == 0:
|
||||
return 0.0
|
||||
return self.successful_requests / self.total_requests * 100
|
||||
|
||||
def get_stats(self):
|
||||
return {
|
||||
"total_requests": self.total_requests,
|
||||
"successful_requests": self.successful_requests,
|
||||
"failed_requests": self.failed_requests,
|
||||
"timeout_requests": self.timeout_requests,
|
||||
"network_error_requests": self.network_error_requests,
|
||||
"success_rate": self.get_success_rate(),
|
||||
"last_success_time": self.last_success_time,
|
||||
"last_error_time": self.last_error_time,
|
||||
"last_error_message": self.last_error_message
|
||||
}
|
||||
|
||||
# 全局统计实例
|
||||
_service_stats = MeituServiceStats()
|
||||
|
||||
class MeituServiceError(Exception):
|
||||
"""美图服务异常"""
|
||||
pass
|
||||
|
||||
class MeituTimeoutError(MeituServiceError):
|
||||
"""美图服务超时异常"""
|
||||
pass
|
||||
|
||||
class MeituNetworkError(MeituServiceError):
|
||||
"""美图服务网络异常"""
|
||||
pass
|
||||
|
||||
class MeituAPIService:
|
||||
"""美图API服务类,处理与美图API的所有交互"""
|
||||
|
||||
# 服务状态
|
||||
_service_status = {
|
||||
"available": False,
|
||||
"last_check": 0,
|
||||
"error": None
|
||||
}
|
||||
|
||||
# 支持的处理模式
|
||||
SUPPORTED_MODES = {
|
||||
"crystal": "极速重绘",
|
||||
"standard": "标准处理",
|
||||
"enhance": "增强处理",
|
||||
"hdr": "HDR处理",
|
||||
"portrait": "人像优化"
|
||||
}
|
||||
|
||||
def __init__(self, api_url: str = None):
|
||||
"""
|
||||
初始化美图API服务
|
||||
:param api_url: API基础URL,如果为None则使用环境变量或默认值
|
||||
"""
|
||||
# self.api_url = api_url or os.environ.get('MEITU_API_URL', 'http://89358zi786.goho.co:38226')
|
||||
self.api_url = api_url or os.environ.get('MEITU_API_URL', 'https://127.0.0.1:6668') # 本地
|
||||
self.stats = _service_stats # 使用全局统计实例
|
||||
self._active_tasks = set() # 追踪活跃任务,确保高并发安全
|
||||
self._task_cancellation_tokens = {} # 任务取消令牌
|
||||
|
||||
async def process_image(self,
|
||||
image_path: str,
|
||||
mode: str,
|
||||
output_dir: Path,
|
||||
progress_callback: Optional[Callable[[int, str], None]] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
处理图片
|
||||
:param image_path: 图片路径
|
||||
:param mode: 处理模式(crystal, standard等)
|
||||
:param output_dir: 输出目录
|
||||
:param progress_callback: 进度回调函数
|
||||
:return: 处理结果,包含任务ID和结果图片路径
|
||||
"""
|
||||
# 检查模式是否支持
|
||||
if mode not in self.SUPPORTED_MODES:
|
||||
supported_modes = ", ".join(self.SUPPORTED_MODES.keys())
|
||||
raise MeituServiceError(f"不支持的处理模式: {mode},支持的模式: {supported_modes}")
|
||||
|
||||
# 检查图片文件是否存在
|
||||
if not os.path.exists(image_path):
|
||||
raise MeituServiceError(f"图片文件不存在: {image_path}")
|
||||
|
||||
try:
|
||||
# 生成唯一文件名
|
||||
unique_filename = os.path.basename(image_path)
|
||||
file_extension = os.path.splitext(unique_filename)[1]
|
||||
|
||||
# 上传图片并获取任务ID
|
||||
task_id = await self._upload_image(image_path, unique_filename, mode)
|
||||
|
||||
if not task_id:
|
||||
raise MeituServiceError("美图API上传失败,未获取到任务ID")
|
||||
|
||||
# 记录任务ID类型和值
|
||||
logger.info(f"获取到任务ID: {task_id}, 类型: {type(task_id)}")
|
||||
|
||||
# 注册活跃任务
|
||||
self._active_tasks.add(task_id)
|
||||
self._task_cancellation_tokens[task_id] = False
|
||||
|
||||
# 轮询等待处理完成 - 确保异常正确传播
|
||||
try:
|
||||
await self._wait_for_completion(task_id, progress_callback)
|
||||
except (MeituTimeoutError, MeituServiceError) as e:
|
||||
logger.error(f"美图处理被终止: {str(e)}")
|
||||
# 立即清理任务并重新抛出异常
|
||||
await self.cleanup_failed_task(task_id)
|
||||
raise
|
||||
finally:
|
||||
# 确保任务总是从活跃列表中移除
|
||||
self._active_tasks.discard(task_id)
|
||||
self._task_cancellation_tokens.pop(task_id, None)
|
||||
logger.info(f"任务{task_id}已从活跃列表中移除")
|
||||
|
||||
# 下载处理后的图片
|
||||
processed_filename = f"processed_{task_id}_{int(time.time())}{file_extension}"
|
||||
processed_path = output_dir / processed_filename
|
||||
|
||||
# 确保输出目录存在
|
||||
output_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
# 下载结果图片
|
||||
try:
|
||||
await self._download_result(task_id, processed_path)
|
||||
except Exception as download_error:
|
||||
logger.error(f"下载处理结果失败: {str(download_error)}")
|
||||
raise MeituServiceError(f"下载处理结果失败: {str(download_error)}")
|
||||
|
||||
# 安全处理processing_time计算
|
||||
processing_time = 0
|
||||
try:
|
||||
if '_' in task_id:
|
||||
# 尝试从task_id中提取时间戳
|
||||
timestamp_str = task_id.split('_')[-1]
|
||||
logger.info(f"从task_id提取时间戳: {timestamp_str}, 类型: {type(timestamp_str)}")
|
||||
|
||||
# 安全转换为整数
|
||||
try:
|
||||
timestamp = int(timestamp_str)
|
||||
processing_time = time.time() - timestamp
|
||||
logger.info(f"计算处理时间: {processing_time}秒")
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.warning(f"时间戳转换失败: {e}")
|
||||
else:
|
||||
logger.warning(f"任务ID不包含下划线分隔符: {task_id}")
|
||||
except Exception as time_error:
|
||||
logger.warning(f"处理时间计算错误: {str(time_error)}")
|
||||
|
||||
# 记录成功
|
||||
self.stats.record_success()
|
||||
logger.info(f"美图服务处理成功 - 任务ID: {task_id}, 耗时: {processing_time:.2f}秒")
|
||||
|
||||
return {
|
||||
"task_id": task_id,
|
||||
"processed_path": processed_path,
|
||||
"processed_filename": processed_filename,
|
||||
"mode": mode,
|
||||
"mode_name": self.SUPPORTED_MODES.get(mode, "未知模式"),
|
||||
"processing_time": processing_time
|
||||
}
|
||||
|
||||
except MeituTimeoutError as e:
|
||||
self.stats.record_failure("timeout", str(e))
|
||||
logger.error(f"美图服务超时: {str(e)}")
|
||||
raise
|
||||
except MeituNetworkError as e:
|
||||
self.stats.record_failure("network", str(e))
|
||||
logger.error(f"美图服务网络错误: {str(e)}")
|
||||
raise
|
||||
except MeituServiceError as e:
|
||||
self.stats.record_failure("general", str(e))
|
||||
logger.error(f"美图服务错误: {str(e)}")
|
||||
raise
|
||||
except Exception as e:
|
||||
self.stats.record_failure("unknown", str(e))
|
||||
logger.error(f"美图API处理失败: {str(e)}")
|
||||
import traceback
|
||||
logger.error(f"异常堆栈: {traceback.format_exc()}")
|
||||
raise MeituServiceError(f"美图API处理失败: {str(e)}")
|
||||
|
||||
async def process_batch(self,
|
||||
image_paths: List[str],
|
||||
mode: str,
|
||||
output_dir: Path,
|
||||
max_concurrent: int = 3) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
批量处理图片
|
||||
:param image_paths: 图片路径列表
|
||||
:param mode: 处理模式
|
||||
:param output_dir: 输出目录
|
||||
:param max_concurrent: 最大并发数
|
||||
:return: 处理结果列表
|
||||
"""
|
||||
results = []
|
||||
semaphore = asyncio.Semaphore(max_concurrent)
|
||||
|
||||
async def process_single(image_path):
|
||||
async with semaphore:
|
||||
try:
|
||||
result = await self.process_image(image_path, mode, output_dir)
|
||||
return {
|
||||
"success": True,
|
||||
"result": result,
|
||||
"image_path": image_path
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"处理图片失败: {image_path}, 错误: {str(e)}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"image_path": image_path
|
||||
}
|
||||
|
||||
# 创建任务
|
||||
tasks = [process_single(path) for path in image_paths]
|
||||
|
||||
# 等待所有任务完成
|
||||
try:
|
||||
results = await asyncio.gather(*tasks)
|
||||
except Exception as e:
|
||||
logger.error(f"批量处理任务异常: {str(e)}")
|
||||
# 即使有异常,也返回已完成的结果
|
||||
results = [{"success": False, "error": f"批量处理异常: {str(e)}", "image_path": path} for path in image_paths]
|
||||
|
||||
return results
|
||||
|
||||
async def _upload_image(self, image_path: str, filename: str, mode: str) -> str:
|
||||
"""
|
||||
上传图片到美图API
|
||||
:param image_path: 图片路径
|
||||
:param filename: 文件名
|
||||
:param mode: 处理模式
|
||||
:return: 任务ID
|
||||
"""
|
||||
logger.info(f"上传图片到美图API - 文件: {filename}, 模式: {mode}")
|
||||
|
||||
# 检查文件大小
|
||||
file_size = os.path.getsize(image_path)
|
||||
if file_size > 10 * 1024 * 1024: # 10MB
|
||||
raise MeituServiceError(f"文件太大: {file_size / 1024 / 1024:.2f}MB,最大允许10MB")
|
||||
|
||||
# 检查文件类型
|
||||
content_type = await self._get_content_type(filename)
|
||||
if not content_type.startswith('image/'):
|
||||
raise MeituServiceError(f"不支持的文件类型: {content_type}")
|
||||
|
||||
# 重试逻辑
|
||||
max_retries = 3
|
||||
retry_count = 0
|
||||
last_error = None
|
||||
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
# 设置上传专用的超时配置
|
||||
upload_timeout = aiohttp.ClientTimeout(total=60, connect=10) # 上传允许更长时间
|
||||
async with aiohttp.ClientSession(timeout=upload_timeout) as session:
|
||||
with open(image_path, 'rb') as f:
|
||||
data = aiohttp.FormData()
|
||||
data.add_field(filename,
|
||||
f.read(),
|
||||
filename=filename,
|
||||
content_type=content_type)
|
||||
|
||||
async with session.post(f'{self.api_url}/add_task?scene={mode}', data=data, ssl=False) as resp:
|
||||
if resp.status != 200:
|
||||
error_text = await resp.text()
|
||||
logger.error(f"美图API上传失败 - 状态码: {resp.status}, 响应: {error_text}")
|
||||
raise MeituServiceError(f"美图API上传失败: {error_text}")
|
||||
|
||||
response_text = await resp.text()
|
||||
logger.info(f"美图API上传响应: {response_text}")
|
||||
|
||||
try:
|
||||
result = json.loads(response_text)
|
||||
if result.get('code') != 0:
|
||||
raise MeituServiceError(result.get('message', '美图API上传失败'))
|
||||
|
||||
task_id = result.get('id')
|
||||
logger.info(f"美图API上传成功 - 任务ID: {task_id}")
|
||||
return task_id
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"美图API响应解析失败: {response_text}")
|
||||
raise MeituServiceError("美图API响应格式错误")
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
last_error = e
|
||||
wait_time = 2 ** retry_count # 指数退避
|
||||
logger.warning(f"上传失败,正在重试 ({retry_count}/{max_retries}),等待 {wait_time} 秒: {str(e)}")
|
||||
await asyncio.sleep(wait_time)
|
||||
|
||||
# 所有重试都失败
|
||||
raise last_error or MeituServiceError("上传图片失败,已达到最大重试次数")
|
||||
|
||||
async def _wait_for_completion(self,
|
||||
task_id: str,
|
||||
progress_callback: Optional[Callable[[int, str], None]] = None,
|
||||
max_wait_time: int = 180, # 最长等待3分钟(缩短超时时间)
|
||||
check_interval: int = 2 # 每2秒检查一次
|
||||
) -> None:
|
||||
"""
|
||||
轮询等待处理完成
|
||||
:param task_id: 任务ID
|
||||
:param progress_callback: 进度回调函数
|
||||
:param max_wait_time: 最长等待时间(秒)
|
||||
:param check_interval: 检查间隔(秒)
|
||||
"""
|
||||
logger.info(f"开始等待美图API处理完成 - 任务ID: {task_id}")
|
||||
|
||||
start_time = time.time()
|
||||
progress = 0
|
||||
|
||||
# 连续失败计数
|
||||
consecutive_failures = 0
|
||||
max_consecutive_failures = 3 # 减少连续失败次数,更快失败
|
||||
|
||||
# 早期检测到的致命错误类型
|
||||
detected_fatal_errors = set()
|
||||
|
||||
# 创建当前任务的取消令牌
|
||||
is_cancelled = False
|
||||
|
||||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30)) as session:
|
||||
for i in range(max_wait_time // check_interval):
|
||||
# 检查是否被取消 - 支持外部取消令牌
|
||||
if is_cancelled or self._task_cancellation_tokens.get(task_id, False):
|
||||
logger.error(f"美图处理任务被取消: {task_id}")
|
||||
await self.cleanup_failed_task(task_id)
|
||||
raise MeituTimeoutError("美图处理任务被取消")
|
||||
|
||||
# 检查任务是否还在活跃列表中
|
||||
if task_id not in self._active_tasks:
|
||||
logger.warning(f"任务{task_id}不在活跃列表中,可能已被清理")
|
||||
raise MeituTimeoutError("任务已被停止")
|
||||
|
||||
# 计算预估进度
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed < max_wait_time:
|
||||
# 假设进度在90%以内线性增长
|
||||
progress = min(90, int(90 * elapsed / max_wait_time))
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(progress, f"美图API处理中 - {progress}%")
|
||||
|
||||
try:
|
||||
# 在每次API调用前再次检查取消状态
|
||||
if self._task_cancellation_tokens.get(task_id, False):
|
||||
logger.info(f"任务{task_id}在API调用前被取消")
|
||||
raise MeituTimeoutError("任务被取消")
|
||||
|
||||
async with session.get(f'{self.api_url}/try_get?taskid={task_id}', ssl=False) as resp:
|
||||
if resp.status != 200:
|
||||
consecutive_failures += 1
|
||||
error_text = await resp.text()
|
||||
logger.error(f"美图API状态检查失败 - 状态码: {resp.status}, 响应: {error_text}")
|
||||
|
||||
# 如果连续失败次数过多,抛出异常
|
||||
if consecutive_failures >= max_consecutive_failures:
|
||||
logger.error(f"美图API连续失败{consecutive_failures}次,尝试清理任务")
|
||||
await self.cleanup_failed_task(task_id)
|
||||
raise MeituTimeoutError(f"美图API连续失败{consecutive_failures}次: {error_text}")
|
||||
|
||||
# 否则继续等待
|
||||
await asyncio.sleep(check_interval * 2) # 失败时等待更长时间
|
||||
continue
|
||||
|
||||
# 重置失败计数
|
||||
consecutive_failures = 0
|
||||
|
||||
# 获取响应内容
|
||||
try:
|
||||
response_text = await resp.text()
|
||||
logger.debug(f"美图API状态检查响应: {response_text}")
|
||||
|
||||
# 解析JSON响应
|
||||
try:
|
||||
result = json.loads(response_text)
|
||||
logger.debug(f"解析的JSON响应: {result}")
|
||||
|
||||
# 获取code字段
|
||||
if 'code' not in result:
|
||||
logger.warning(f"美图API响应中缺少code字段: {result}")
|
||||
# 继续等待,不抛出异常
|
||||
await asyncio.sleep(check_interval)
|
||||
continue
|
||||
|
||||
# 安全获取code值
|
||||
try:
|
||||
code = result['code']
|
||||
# 确保code是整数
|
||||
if not isinstance(code, (int, float)):
|
||||
code = int(code) if str(code).isdigit() else -1
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.warning(f"无法转换code为整数: {e}, 原始值: {result.get('code')}")
|
||||
code = -1
|
||||
|
||||
# 检查处理状态
|
||||
if code < 0:
|
||||
error_message = str(result.get('message', '美图API处理失败'))
|
||||
logger.error(f"美图API处理失败: {error_message}")
|
||||
|
||||
# 检查是否是超时或严重错误,如果是则立即终止
|
||||
timeout_indicators = [
|
||||
"TimeoutException",
|
||||
"selenium.common.exceptions.TimeoutException",
|
||||
"WebDriver",
|
||||
"已终止处理",
|
||||
"处理超时",
|
||||
"连接超时",
|
||||
"响应超时"
|
||||
]
|
||||
|
||||
is_timeout = any(indicator in error_message for indicator in timeout_indicators)
|
||||
|
||||
if is_timeout:
|
||||
logger.error(f"检测到严重错误/超时,立即终止美图处理: {error_message}")
|
||||
raise MeituTimeoutError("美图处理超时或遇到严重错误,已终止处理")
|
||||
|
||||
raise MeituServiceError(error_message)
|
||||
elif code == 0:
|
||||
logger.info(f"美图API处理完成 - 任务ID: {task_id}")
|
||||
if progress_callback:
|
||||
progress_callback(100, "美图API处理完成")
|
||||
return
|
||||
else:
|
||||
# 其他状态码,继续等待
|
||||
logger.info(f"美图API处理中 - 状态码: {code}")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"JSON解析失败: {e}, 响应内容: {response_text[:100]}...")
|
||||
# 继续等待,不抛出异常
|
||||
except Exception as e:
|
||||
error_str = str(e)
|
||||
logger.error(f"处理响应内容异常: {error_str}")
|
||||
import traceback
|
||||
logger.error(f"异常堆栈: {traceback.format_exc()}")
|
||||
|
||||
# 检查是否是致命错误,如果是则立即终止
|
||||
fatal_error_indicators = [
|
||||
"MeituTimeoutError",
|
||||
"美图处理超时",
|
||||
"已终止处理",
|
||||
"WebDriver",
|
||||
"TimeoutException"
|
||||
]
|
||||
|
||||
is_fatal = any(indicator in error_str for indicator in fatal_error_indicators)
|
||||
if is_fatal:
|
||||
# 记录检测到的致命错误类型
|
||||
for indicator in fatal_error_indicators:
|
||||
if indicator in error_str:
|
||||
detected_fatal_errors.add(indicator)
|
||||
|
||||
logger.error(f"检测到致命错误,立即终止处理: {error_str}")
|
||||
logger.error(f"已检测到的致命错误类型: {detected_fatal_errors}")
|
||||
is_cancelled = True # 设置取消标志
|
||||
await self.cleanup_failed_task(task_id) # 立即清理
|
||||
raise MeituTimeoutError(f"美图处理遇到致命错误: {error_str}")
|
||||
|
||||
# 增加连续失败计数
|
||||
consecutive_failures += 1
|
||||
if consecutive_failures >= max_consecutive_failures:
|
||||
logger.error(f"连续失败{consecutive_failures}次,终止处理")
|
||||
raise MeituServiceError(f"美图处理连续失败{consecutive_failures}次")
|
||||
except asyncio.TimeoutError:
|
||||
consecutive_failures += 1
|
||||
logger.error(f"美图API请求超时 - 第{consecutive_failures}次")
|
||||
|
||||
# 如果连续超时次数过多,抛出异常
|
||||
if consecutive_failures >= max_consecutive_failures:
|
||||
logger.error(f"美图API连续超时{consecutive_failures}次,尝试清理任务")
|
||||
await self.cleanup_failed_task(task_id)
|
||||
raise MeituTimeoutError(f"美图API连续超时{consecutive_failures}次,任务可能已失败")
|
||||
|
||||
await asyncio.sleep(check_interval * 2)
|
||||
continue
|
||||
except aiohttp.ClientError as e:
|
||||
consecutive_failures += 1
|
||||
logger.error(f"网络请求异常: {str(e)} - 第{consecutive_failures}次")
|
||||
|
||||
# 如果连续失败次数过多,抛出异常
|
||||
if consecutive_failures >= max_consecutive_failures:
|
||||
logger.error(f"美图API网络连接连续失败{consecutive_failures}次,尝试清理任务")
|
||||
await self.cleanup_failed_task(task_id)
|
||||
raise MeituNetworkError(f"美图API网络连接连续失败{consecutive_failures}次")
|
||||
|
||||
await asyncio.sleep(check_interval * 2)
|
||||
continue
|
||||
except Exception as e:
|
||||
if isinstance(e, MeituServiceError):
|
||||
raise
|
||||
|
||||
consecutive_failures += 1
|
||||
logger.error(f"美图API状态检查异常: {str(e)} - 第{consecutive_failures}次")
|
||||
import traceback
|
||||
logger.error(f"异常堆栈: {traceback.format_exc()}")
|
||||
|
||||
# 如果连续失败次数过多,抛出异常
|
||||
if consecutive_failures >= max_consecutive_failures:
|
||||
logger.error(f"美图API状态检查连续异常{consecutive_failures}次,尝试清理任务")
|
||||
await self.cleanup_failed_task(task_id)
|
||||
raise MeituServiceError(f"美图API状态检查连续异常{consecutive_failures}次: {str(e)}")
|
||||
|
||||
await asyncio.sleep(check_interval * 2)
|
||||
continue
|
||||
|
||||
await asyncio.sleep(check_interval)
|
||||
|
||||
# 超时
|
||||
logger.error(f"美图API处理超时({max_wait_time}秒),尝试清理任务")
|
||||
await self.cleanup_failed_task(task_id)
|
||||
raise MeituTimeoutError(f"美图API处理超时({max_wait_time}秒) - 任务ID: {task_id}")
|
||||
|
||||
async def _download_result(self, task_id: str, output_path: Path) -> None:
|
||||
"""
|
||||
下载处理结果
|
||||
:param task_id: 任务ID
|
||||
:param output_path: 输出路径
|
||||
"""
|
||||
logger.info(f"开始下载美图API处理结果 - 任务ID: {task_id}, 输出路径: {output_path}")
|
||||
|
||||
# 确保输出目录存在
|
||||
output_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
# 重试逻辑
|
||||
max_retries = 3
|
||||
retry_count = 0
|
||||
last_error = None
|
||||
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
# 设置下载专用的超时配置
|
||||
download_timeout = aiohttp.ClientTimeout(total=120, connect=10) # 下载允许更长时间
|
||||
async with aiohttp.ClientSession(timeout=download_timeout) as session:
|
||||
async with session.get(f'{self.api_url}/get_image?taskid={task_id}', ssl=False) as resp:
|
||||
if resp.status != 200:
|
||||
error_text = await resp.text()
|
||||
logger.error(f"美图API下载失败 - 状态码: {resp.status}, 响应: {error_text}")
|
||||
raise MeituServiceError(f"美图API下载失败: {error_text}")
|
||||
|
||||
# 读取图片数据
|
||||
image_data = await resp.read()
|
||||
if not image_data:
|
||||
raise MeituServiceError("美图API返回空图片数据")
|
||||
|
||||
# 保存图片
|
||||
async with aiofiles.open(output_path, 'wb') as f:
|
||||
await f.write(image_data)
|
||||
|
||||
# 安全获取文件大小
|
||||
try:
|
||||
file_size = len(image_data)
|
||||
logger.info(f"美图API结果下载成功 - 任务ID: {task_id}, 文件大小: {file_size} 字节")
|
||||
except Exception as size_error:
|
||||
logger.warning(f"获取文件大小失败: {str(size_error)}")
|
||||
logger.info(f"美图API结果下载成功 - 任务ID: {task_id}")
|
||||
|
||||
# 验证文件是否成功保存
|
||||
if not output_path.exists() or output_path.stat().st_size == 0:
|
||||
raise MeituServiceError("图片保存失败或文件大小为0")
|
||||
|
||||
return
|
||||
except MeituServiceError:
|
||||
# 直接抛出服务特定错误
|
||||
raise
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
last_error = e
|
||||
wait_time = 2 ** retry_count # 指数退避
|
||||
logger.warning(f"下载失败,正在重试 ({retry_count}/{max_retries}),等待 {wait_time} 秒: {str(e)}")
|
||||
await asyncio.sleep(wait_time)
|
||||
|
||||
# 所有重试都失败
|
||||
error_msg = str(last_error) if last_error else "未知错误"
|
||||
logger.error(f"下载结果失败,已达到最大重试次数: {error_msg}")
|
||||
raise MeituServiceError(f"下载结果失败,已达到最大重试次数: {error_msg}")
|
||||
|
||||
@staticmethod
|
||||
async def _get_content_type(filename: str) -> str:
|
||||
"""
|
||||
根据文件名获取内容类型
|
||||
:param filename: 文件名
|
||||
:return: 内容类型
|
||||
"""
|
||||
ext = os.path.splitext(filename)[1].lower()
|
||||
content_types = {
|
||||
'.jpg': 'image/jpeg',
|
||||
'.jpeg': 'image/jpeg',
|
||||
'.png': 'image/png',
|
||||
'.gif': 'image/gif',
|
||||
'.webp': 'image/webp',
|
||||
'.bmp': 'image/bmp'
|
||||
}
|
||||
return content_types.get(ext, 'application/octet-stream')
|
||||
|
||||
@classmethod
|
||||
async def test_connection(cls, api_url: str = None) -> Tuple[bool, str]:
|
||||
"""
|
||||
测试与美图API的连接
|
||||
:param api_url: API URL,如果为None则使用默认值
|
||||
:return: (连接是否成功, 状态消息)
|
||||
"""
|
||||
# 如果最后检查时间在30秒内,直接返回缓存的状态
|
||||
if time.time() - cls._service_status["last_check"] < 30:
|
||||
return (
|
||||
cls._service_status["available"],
|
||||
"服务在线" if cls._service_status["available"] else f"服务离线: {cls._service_status['error']}"
|
||||
)
|
||||
|
||||
service = MeituAPIService(api_url)
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# 尝试健康检查端点
|
||||
try:
|
||||
async with session.get(f'{service.api_url}/health', ssl=False, timeout=5) as resp:
|
||||
if resp.status == 200:
|
||||
cls._service_status = {
|
||||
"available": True,
|
||||
"last_check": time.time(),
|
||||
"error": None
|
||||
}
|
||||
return True, "服务在线"
|
||||
except:
|
||||
# 健康检查失败,尝试其他端点
|
||||
pass
|
||||
|
||||
# 尝试调用API的其他端点
|
||||
try:
|
||||
async with session.get(f'{service.api_url}/', ssl=False, timeout=5) as resp:
|
||||
cls._service_status = {
|
||||
"available": resp.status < 500,
|
||||
"last_check": time.time(),
|
||||
"error": f"状态码: {resp.status}" if resp.status >= 400 else None
|
||||
}
|
||||
return cls._service_status["available"], "服务可能在线,但健康检查失败"
|
||||
except Exception as e:
|
||||
cls._service_status = {
|
||||
"available": False,
|
||||
"last_check": time.time(),
|
||||
"error": str(e)
|
||||
}
|
||||
return False, f"服务离线: {str(e)}"
|
||||
except Exception as e:
|
||||
cls._service_status = {
|
||||
"available": False,
|
||||
"last_check": time.time(),
|
||||
"error": str(e)
|
||||
}
|
||||
logger.error(f"美图API连接测试失败: {str(e)}")
|
||||
return False, f"服务离线: {str(e)}"
|
||||
|
||||
async def cancel_task(self, task_id: str) -> bool:
|
||||
"""取消处理任务"""
|
||||
try:
|
||||
# 设置取消任务的超时配置
|
||||
cancel_timeout = aiohttp.ClientTimeout(total=15, connect=5) # 取消操作超时短一些
|
||||
async with aiohttp.ClientSession(timeout=cancel_timeout) as session:
|
||||
async with session.post(f'{self.api_url}/cancel_task',
|
||||
data={'taskid': task_id},
|
||||
ssl=False) as resp:
|
||||
if resp.status == 200:
|
||||
logger.info(f"任务取消成功: {task_id}")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"任务取消失败: {task_id}, 状态码: {resp.status}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"取消任务异常: {task_id}, 错误: {str(e)}")
|
||||
return False
|
||||
|
||||
async def cancel_all_tasks(self) -> None:
|
||||
"""取消所有活跃任务"""
|
||||
logger.warning(f"取消所有活跃任务,共{len(self._active_tasks)}个")
|
||||
|
||||
for task_id in list(self._active_tasks):
|
||||
self._task_cancellation_tokens[task_id] = True
|
||||
await self.cleanup_failed_task(task_id)
|
||||
|
||||
self._active_tasks.clear()
|
||||
self._task_cancellation_tokens.clear()
|
||||
|
||||
async def cleanup_failed_task(self, task_id: str) -> None:
|
||||
"""清理失败的任务"""
|
||||
try:
|
||||
# 尝试取消任务
|
||||
await self.cancel_task(task_id)
|
||||
|
||||
# 清理可能的临时文件
|
||||
# 这里可以添加更多清理逻辑
|
||||
logger.info(f"失败任务清理完成: {task_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"清理失败任务异常: {task_id}, 错误: {str(e)}")
|
||||
|
||||
@classmethod
|
||||
def get_service_stats(cls) -> Dict[str, Any]:
|
||||
"""获取服务统计信息"""
|
||||
return _service_stats.get_stats()
|
||||
|
||||
@classmethod
|
||||
def reset_stats(cls) -> None:
|
||||
"""重置服务统计信息"""
|
||||
global _service_stats
|
||||
_service_stats = MeituServiceStats()
|
||||
|
||||
@classmethod
|
||||
def get_supported_modes(cls) -> Dict[str, str]:
|
||||
"""获取支持的处理模式"""
|
||||
return cls.SUPPORTED_MODES
|
||||
219
services/service_qwen.py
Normal file
219
services/service_qwen.py
Normal file
@@ -0,0 +1,219 @@
|
||||
import time
|
||||
import asyncio
|
||||
import aiohttp
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
api_key = '8e32d44e3007447cb4be6ee52c5d3110'
|
||||
|
||||
|
||||
class UploadInfo(BaseModel):
|
||||
fileName: str
|
||||
fileType: str
|
||||
|
||||
|
||||
class CreateInfo(BaseModel):
|
||||
taskId: str # 创建的任务 ID,可用于查询状态或获取结果
|
||||
taskStatus: str # 初始状态,可能为:QUEUED、RUNNING、FAILED
|
||||
clientId: str # 平台内部标识,用于排错,无需关注
|
||||
netWssUrl: str # WebSocket 地址(当前不稳定,不推荐使用)
|
||||
promptTips: str # ComfyUI 校验信息(字符串格式的 JSON),可用于识别配置异常节点
|
||||
|
||||
|
||||
class RunHubResponse(BaseModel):
|
||||
code: int # 状态码,0 表示成功
|
||||
msg: str # 提示信息
|
||||
data: UploadInfo | CreateInfo | str | None = None # 数据对象
|
||||
|
||||
class Config:
|
||||
extra = 'allow' # 允许添加额外字段
|
||||
|
||||
|
||||
async def upload(img_path: str) -> RunHubResponse:
|
||||
with open(img_path, 'rb') as f:
|
||||
img_data = f.read()
|
||||
|
||||
form = aiohttp.FormData()
|
||||
form.add_field('apiKey', api_key)
|
||||
form.add_field('file', img_data, filename='image.jpg', content_type='image/jpeg')
|
||||
form.add_field('fileType', 'image')
|
||||
|
||||
url = 'https://www.runninghub.cn/task/openapi/upload'
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, data=form) as resp:
|
||||
response = await resp.json()
|
||||
|
||||
return RunHubResponse.model_validate(response)
|
||||
|
||||
|
||||
async def create(workflow_id: str, node_info_list: list[dict[str, str]]) -> RunHubResponse:
|
||||
url = 'https://www.runninghub.cn/task/openapi/create'
|
||||
json_data = {'apiKey': api_key, 'workflowId': workflow_id, 'nodeInfoList': node_info_list}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, json=json_data) as resp:
|
||||
response = await resp.json()
|
||||
|
||||
return RunHubResponse.model_validate(response)
|
||||
|
||||
|
||||
async def status(task_id: str) -> RunHubResponse:
|
||||
# 查询状态
|
||||
url = 'https://www.runninghub.cn/task/openapi/status'
|
||||
payload = {'apiKey': api_key, 'taskId': task_id}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, json=payload) as resp:
|
||||
response = await resp.json()
|
||||
|
||||
# ["QUEUED","RUNNING","FAILED","SUCCESS"]
|
||||
return RunHubResponse.model_validate(response)
|
||||
|
||||
|
||||
async def outputs(task_id: str) -> dict:
|
||||
# 获取结果
|
||||
url = 'https://www.runninghub.cn/task/openapi/outputs'
|
||||
payload = {'apiKey': api_key, 'taskId': task_id}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, json=payload) as resp:
|
||||
response = await resp.json()
|
||||
|
||||
return response
|
||||
|
||||
|
||||
async def 花纹提取_api(img_path: str, save_path: str, prompt: str = '') -> bool:
|
||||
"""
|
||||
异步花纹提取API
|
||||
|
||||
Args:
|
||||
img_path: 输入图片路径
|
||||
save_path: 输出图片路径
|
||||
prompt: 自定义提示词,为空则使用默认提示词
|
||||
|
||||
Returns:
|
||||
bool: 处理是否成功
|
||||
"""
|
||||
try:
|
||||
upload_res = await upload(img_path=img_path)
|
||||
if upload_res.code != 0 or not upload_res.data:
|
||||
logger.error(f"Qwen上传失败: code={upload_res.code}, msg={upload_res.msg}")
|
||||
return False
|
||||
|
||||
# 确保 data 是 UploadInfo 类型
|
||||
if not hasattr(upload_res.data, 'fileName'):
|
||||
logger.error(f"Qwen上传返回数据格式错误: {upload_res.data}")
|
||||
return False
|
||||
|
||||
logger.info(f"Qwen上传成功: {upload_res.data.fileName}")
|
||||
|
||||
workflow_id = '1980864078929379330'
|
||||
if len(prompt) == 0:
|
||||
prompt = '提取桌布上的花纹,自动补全空白,使得所有位置饱满并且完美衔接,去除所有的皱纹和扭曲和凸凹不平,图案自动摆正对齐并且铺平,使直线变得笔直,平行的花纹更有规律,没有残缺的花纹和折痕和断痕,铺满画布,完整的图案。简单的纯色背景'
|
||||
|
||||
node_info_list = [
|
||||
{
|
||||
'nodeId': '78',
|
||||
'fieldName': 'image',
|
||||
'fieldValue': upload_res.data.fileName,
|
||||
},
|
||||
{
|
||||
'nodeId': '103',
|
||||
'fieldName': 'text',
|
||||
'fieldValue': prompt,
|
||||
},
|
||||
]
|
||||
create_res = await create(workflow_id=workflow_id, node_info_list=node_info_list)
|
||||
|
||||
if create_res.code != 0 or not create_res.data:
|
||||
logger.error(f"Qwen任务创建失败: code={create_res.code}, msg={create_res.msg}")
|
||||
return False
|
||||
|
||||
# 确保 data 是 CreateInfo 类型
|
||||
if not hasattr(create_res.data, 'taskId'):
|
||||
logger.error(f"Qwen任务创建返回数据格式错误: {create_res.data}")
|
||||
return False
|
||||
|
||||
task_id = create_res.data.taskId
|
||||
logger.info(f"Qwen任务创建成功: {task_id}")
|
||||
|
||||
# 轮询检查状态
|
||||
max_retries = 120 # 最多等待10分钟(120次 * 5秒)
|
||||
retry_count = 0
|
||||
|
||||
while retry_count < max_retries:
|
||||
status_res = await status(task_id=task_id)
|
||||
if status_res.code == 0:
|
||||
if status_res.data == 'QUEUED':
|
||||
logger.info('Qwen队列排队中...')
|
||||
elif status_res.data == 'RUNNING':
|
||||
logger.info('Qwen正在处理中...')
|
||||
elif status_res.data == 'FAILED':
|
||||
logger.error(f'Qwen处理失败: {status_res}')
|
||||
return False
|
||||
elif status_res.data == 'SUCCESS':
|
||||
logger.info('Qwen处理完成,开始下载结果')
|
||||
outputs_res = await outputs(task_id=task_id)
|
||||
img_url = outputs_res['data'][0]['fileUrl']
|
||||
|
||||
# 下载结果图片
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(img_url) as resp:
|
||||
img_data = await resp.read()
|
||||
|
||||
with open(save_path, 'wb') as f:
|
||||
f.write(img_data)
|
||||
|
||||
logger.info(f"Qwen结果保存成功: {save_path}")
|
||||
try:
|
||||
from utils.api_cost_tracker import record
|
||||
record("qwen_enhance", count=1)
|
||||
except Exception:
|
||||
pass
|
||||
return True
|
||||
|
||||
await asyncio.sleep(5) # 每5秒检查一次
|
||||
retry_count += 1
|
||||
else:
|
||||
logger.error(f'Qwen处理失败: {status_res}')
|
||||
return False
|
||||
|
||||
logger.error(f"Qwen处理超时,超过{max_retries * 5}秒")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Qwen花纹提取异常: {e}")
|
||||
import traceback
|
||||
logger.error(f"异常堆栈: {traceback.format_exc()}")
|
||||
return False
|
||||
|
||||
|
||||
async def 清晰化_api(img_path: str, save_path: str) -> bool:
|
||||
"""
|
||||
高清增强:对透视矫正后的图案进行清晰化处理。
|
||||
使用与花纹提取相同的 ComfyUI 工作流,但提示词聚焦于清晰度增强。
|
||||
|
||||
Args:
|
||||
img_path: 输入图片路径(透视矫正后的结果)
|
||||
save_path: 输出图片路径
|
||||
|
||||
Returns:
|
||||
bool: 处理是否成功
|
||||
"""
|
||||
prompt = (
|
||||
"对这张已展平的图案进行高清增强处理:"
|
||||
"提升整体清晰度和锐利度,修复模糊边缘,补全细节纹理,"
|
||||
"使图案线条清晰笔直,颜色鲜艳均匀,"
|
||||
"去除噪点和压缩痕迹,输出印刷级高质量平面图,"
|
||||
"背景保持纯白色,不要改变图案内容和构图。"
|
||||
)
|
||||
return await 花纹提取_api(img_path=img_path, save_path=save_path, prompt=prompt)
|
||||
|
||||
|
||||
# 测试代码(注释掉)
|
||||
# if __name__ == "__main__":
|
||||
# asyncio.run(花纹提取_api(img_path=r'1.jpg', save_path='save1.png', prompt=''))
|
||||
433
services/service_vectorizer.py
Normal file
433
services/service_vectorizer.py
Normal file
@@ -0,0 +1,433 @@
|
||||
"""
|
||||
矢量化服务模块 - 使用统一异常处理机制
|
||||
"""
|
||||
|
||||
import aiohttp
|
||||
import time
|
||||
import urllib3
|
||||
from typing import Callable, Optional, Dict, Any
|
||||
import logging
|
||||
import os
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
# 导入基础服务类
|
||||
from utils.service_base import (
|
||||
BaseService, PollingMixin, ServiceError, ServiceTimeoutError,
|
||||
ServiceNetworkError, RetryConfig, TimeoutConfig
|
||||
)
|
||||
|
||||
# 禁用SSL警告
|
||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VectorizerServiceError(ServiceError):
|
||||
"""矢量化服务特定异常"""
|
||||
pass
|
||||
|
||||
|
||||
class VectorizerService(BaseService, PollingMixin):
|
||||
"""矢量化服务类 - 继承统一异常处理机制"""
|
||||
|
||||
# def __init__(self, base_url: str = "https://frp-dad.com:50529"):
|
||||
def __init__(self, base_url: str = "https://127.0.0.1:8090"):
|
||||
# 配置重试和超时
|
||||
retry_config = RetryConfig(
|
||||
max_retries=3,
|
||||
base_delay=2.0,
|
||||
max_delay=30.0
|
||||
)
|
||||
|
||||
timeout_config = TimeoutConfig(
|
||||
connection_timeout=60.0,
|
||||
read_timeout=240.0,
|
||||
total_timeout=1200.0
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
name="VectorizerService",
|
||||
base_url=base_url,
|
||||
retry_config=retry_config,
|
||||
timeout_config=timeout_config
|
||||
)
|
||||
|
||||
async def image_to_eps(self,
|
||||
image_path: str,
|
||||
save_eps_path: Optional[str] = None,
|
||||
timeout: int = 1200,
|
||||
poll_interval: float = 2.0,
|
||||
status_callback: Optional[Callable[[str, dict], None]] = None) -> str:
|
||||
"""
|
||||
将图片转换为EPS矢量文件
|
||||
|
||||
Args:
|
||||
image_path: 输入图片路径
|
||||
save_eps_path: 输出EPS文件路径(可选)
|
||||
timeout: 最大等待时间(秒)
|
||||
poll_interval: 轮询间隔(秒)
|
||||
status_callback: 状态回调函数
|
||||
|
||||
Returns:
|
||||
str: EPS文件保存路径
|
||||
|
||||
Raises:
|
||||
VectorizerServiceError: 矢量化服务异常
|
||||
ServiceTimeoutError: 超时异常
|
||||
ServiceNetworkError: 网络异常
|
||||
"""
|
||||
# 验证输入文件
|
||||
if not os.path.exists(image_path):
|
||||
raise VectorizerServiceError(f"输入图片文件不存在: {image_path}")
|
||||
|
||||
# 设置输出路径
|
||||
if save_eps_path is None:
|
||||
save_eps_path = os.path.splitext(image_path)[0] + '.eps'
|
||||
|
||||
# 确保输出目录存在
|
||||
output_dir = os.path.dirname(save_eps_path)
|
||||
if output_dir:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
try:
|
||||
# 1. 上传图片
|
||||
task_id = await self.execute_with_retry(
|
||||
self._upload_image,
|
||||
image_path,
|
||||
status_callback,
|
||||
error_context=" - 上传图片"
|
||||
)
|
||||
|
||||
# 2. 轮询等待处理完成
|
||||
await self.execute_with_retry(
|
||||
self._wait_for_processing,
|
||||
task_id,
|
||||
timeout,
|
||||
poll_interval,
|
||||
status_callback,
|
||||
error_context=" - 等待处理完成"
|
||||
)
|
||||
|
||||
# 3. 下载结果文件
|
||||
await self.execute_with_retry(
|
||||
self._download_result,
|
||||
task_id,
|
||||
save_eps_path,
|
||||
status_callback,
|
||||
error_context=" - 下载结果文件"
|
||||
)
|
||||
|
||||
# 验证输出文件
|
||||
if not os.path.exists(save_eps_path) or os.path.getsize(save_eps_path) == 0:
|
||||
raise VectorizerServiceError(f"输出文件创建失败或为空: {save_eps_path}")
|
||||
|
||||
self.logger.info(f"矢量化转换成功: {image_path} -> {save_eps_path}")
|
||||
|
||||
if status_callback:
|
||||
status_callback('finished', {
|
||||
'message': '转换完成!',
|
||||
'input_path': image_path,
|
||||
'output_path': save_eps_path
|
||||
})
|
||||
|
||||
return save_eps_path
|
||||
|
||||
except (ServiceTimeoutError, ServiceNetworkError, VectorizerServiceError):
|
||||
# 直接传递这些异常
|
||||
if status_callback:
|
||||
status_callback('error', {'message': '处理失败,请稍后重试'})
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = f"矢量化转换失败: {str(e)}"
|
||||
self.logger.error(error_msg)
|
||||
if status_callback:
|
||||
status_callback('error', {'message': error_msg})
|
||||
raise VectorizerServiceError(error_msg)
|
||||
|
||||
async def _upload_image(self,
|
||||
image_path: str,
|
||||
status_callback: Optional[Callable[[str, dict], None]] = None) -> str:
|
||||
"""上传图片到矢量化服务"""
|
||||
if status_callback:
|
||||
status_callback('uploading', {'message': '正在上传图片...', 'image_path': image_path})
|
||||
|
||||
async with await self.create_http_session() as session:
|
||||
with open(image_path, 'rb') as f:
|
||||
data = aiohttp.FormData()
|
||||
data.add_field('file', f, filename=os.path.basename(image_path))
|
||||
|
||||
async with session.post(f"{self.base_url}/add_task", data=data, ssl=False) as resp:
|
||||
if resp.status != 200:
|
||||
error_text = await resp.text()
|
||||
raise VectorizerServiceError(f"上传失败 - HTTP {resp.status}: {error_text}")
|
||||
|
||||
response_data = await resp.json()
|
||||
|
||||
if response_data.get('code') != 0:
|
||||
error_msg = response_data.get('message', '未知错误')
|
||||
raise VectorizerServiceError(f"上传失败: {error_msg}")
|
||||
|
||||
task_id = response_data.get('id') or response_data.get('taskid')
|
||||
if not task_id:
|
||||
raise VectorizerServiceError("上传失败,未获取到任务ID")
|
||||
|
||||
self.logger.info(f"图片上传成功,任务ID: {task_id}")
|
||||
|
||||
if status_callback:
|
||||
status_callback('uploaded', {
|
||||
'message': '图片上传成功,开始处理...',
|
||||
'taskid': task_id
|
||||
})
|
||||
|
||||
return task_id
|
||||
|
||||
async def _wait_for_processing(self,
|
||||
task_id: str,
|
||||
timeout: int,
|
||||
poll_interval: float,
|
||||
status_callback: Optional[Callable[[str, dict], None]] = None):
|
||||
"""等待处理完成"""
|
||||
start_time = time.time()
|
||||
poll_count = 0
|
||||
consecutive_failures = 0
|
||||
max_consecutive_failures = 5
|
||||
|
||||
async with await self.create_http_session() as session:
|
||||
while True:
|
||||
poll_count += 1
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# 检查超时
|
||||
if elapsed_time > timeout:
|
||||
await self.cleanup_failed_task(task_id)
|
||||
raise ServiceTimeoutError(f"处理超时 (超过{timeout}秒) - 任务ID: {task_id}")
|
||||
|
||||
if status_callback:
|
||||
progress = min(90, int(90 * elapsed_time / timeout))
|
||||
status_callback('processing', {
|
||||
'message': f'正在处理中... (第{poll_count}次检查)',
|
||||
'taskid': task_id,
|
||||
'elapsed_time': elapsed_time,
|
||||
'poll_count': poll_count,
|
||||
'progress': progress
|
||||
})
|
||||
|
||||
try:
|
||||
async with session.get(f"{self.base_url}/try_get",
|
||||
params={'taskid': task_id},
|
||||
ssl=False) as resp:
|
||||
if resp.status != 200:
|
||||
consecutive_failures += 1
|
||||
error_text = await resp.text()
|
||||
self.logger.warning(f"状态查询失败 - HTTP {resp.status}: {error_text}")
|
||||
|
||||
if consecutive_failures >= max_consecutive_failures:
|
||||
await self.cleanup_failed_task(task_id)
|
||||
raise VectorizerServiceError(f"连续{consecutive_failures}次状态查询失败")
|
||||
|
||||
await asyncio.sleep(poll_interval * 2) # 失败时等待更长时间
|
||||
continue
|
||||
|
||||
# 重置失败计数
|
||||
consecutive_failures = 0
|
||||
|
||||
result = await resp.json()
|
||||
|
||||
if result.get('code') == 0:
|
||||
# 处理完成
|
||||
self.logger.info(f"任务处理完成 - 任务ID: {task_id}, 耗时: {elapsed_time:.2f}秒")
|
||||
if status_callback:
|
||||
status_callback('completed', {
|
||||
'message': '处理完成,准备下载...',
|
||||
'taskid': task_id,
|
||||
'total_time': elapsed_time,
|
||||
'poll_count': poll_count
|
||||
})
|
||||
return
|
||||
|
||||
elif result.get('code') == -1:
|
||||
# 处理失败
|
||||
error_msg = result.get('message', '处理失败')
|
||||
await self.cleanup_failed_task(task_id)
|
||||
raise VectorizerServiceError(f"服务器处理失败: {error_msg}")
|
||||
|
||||
# 其他状态码,继续等待
|
||||
self.logger.debug(f"任务处理中 - 任务ID: {task_id}, 状态码: {result.get('code')}")
|
||||
|
||||
except VectorizerServiceError:
|
||||
# 直接传递服务异常
|
||||
raise
|
||||
except Exception as e:
|
||||
consecutive_failures += 1
|
||||
self.logger.warning(f"轮询检查异常 (第{consecutive_failures}次): {str(e)}")
|
||||
|
||||
if consecutive_failures >= max_consecutive_failures:
|
||||
await self.cleanup_failed_task(task_id)
|
||||
raise VectorizerServiceError(f"连续{consecutive_failures}次轮询异常: {str(e)}")
|
||||
|
||||
await asyncio.sleep(poll_interval)
|
||||
|
||||
async def _download_result(self,
|
||||
task_id: str,
|
||||
save_path: str,
|
||||
status_callback: Optional[Callable[[str, dict], None]] = None):
|
||||
"""下载处理结果"""
|
||||
if status_callback:
|
||||
status_callback('downloading', {'message': '正在下载EPS文件...', 'taskid': task_id})
|
||||
|
||||
async with await self.create_http_session() as session:
|
||||
async with session.get(f"{self.base_url}/get_image",
|
||||
params={'taskid': task_id},
|
||||
ssl=False) as resp:
|
||||
if resp.status != 200:
|
||||
error_text = await resp.text()
|
||||
raise VectorizerServiceError(f"下载失败 - HTTP {resp.status}: {error_text}")
|
||||
|
||||
# 检查内容长度
|
||||
content_length = resp.headers.get('Content-Length')
|
||||
if content_length and int(content_length) == 0:
|
||||
raise VectorizerServiceError("下载的文件为空")
|
||||
|
||||
# 使用临时文件确保原子写入
|
||||
temp_path = save_path + '.tmp'
|
||||
try:
|
||||
with open(temp_path, 'wb') as f:
|
||||
bytes_written = 0
|
||||
while True:
|
||||
chunk = await resp.content.read(8192)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
bytes_written += len(chunk)
|
||||
|
||||
# 验证文件大小
|
||||
if bytes_written == 0:
|
||||
raise VectorizerServiceError("下载的文件为空")
|
||||
|
||||
# 原子移动到最终位置
|
||||
os.rename(temp_path, save_path)
|
||||
|
||||
self.logger.info(f"文件下载成功: {save_path}, 大小: {bytes_written} 字节")
|
||||
|
||||
except Exception as e:
|
||||
# 清理临时文件
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
raise VectorizerServiceError(f"文件下载失败: {str(e)}")
|
||||
|
||||
async def get_system_status(self) -> Dict[str, Any]:
|
||||
"""获取系统状态"""
|
||||
return await self.execute_with_retry(
|
||||
self._do_get_system_status,
|
||||
error_context=" - 获取系统状态"
|
||||
)
|
||||
|
||||
async def _do_get_system_status(self) -> Dict[str, Any]:
|
||||
"""执行系统状态查询"""
|
||||
async with await self.create_http_session() as session:
|
||||
async with session.get(f"{self.base_url}/status", ssl=False) as resp:
|
||||
if resp.status != 200:
|
||||
error_text = await resp.text()
|
||||
raise VectorizerServiceError(f"获取系统状态失败 - HTTP {resp.status}: {error_text}")
|
||||
|
||||
return await resp.json()
|
||||
|
||||
async def _do_health_check(self) -> bool:
|
||||
"""执行健康检查"""
|
||||
try:
|
||||
status = await self._do_get_system_status()
|
||||
return True
|
||||
except Exception as e:
|
||||
self.logger.warning(f"健康检查失败: {str(e)}")
|
||||
return False
|
||||
|
||||
async def cleanup_failed_task(self, task_id: str) -> None:
|
||||
"""清理失败的任务"""
|
||||
try:
|
||||
# 尝试取消任务(如果服务支持)
|
||||
async with await self.create_http_session() as session:
|
||||
async with session.delete(f"{self.base_url}/cancel_task",
|
||||
params={'taskid': task_id},
|
||||
ssl=False) as resp:
|
||||
if resp.status == 200:
|
||||
self.logger.info(f"任务取消成功: {task_id}")
|
||||
else:
|
||||
self.logger.warning(f"任务取消失败: {task_id}, HTTP {resp.status}")
|
||||
except Exception as e:
|
||||
self.logger.warning(f"清理任务异常: {task_id}, 错误: {str(e)}")
|
||||
|
||||
await super().cleanup_failed_task(task_id)
|
||||
|
||||
async def _do_health_check(self) -> bool:
|
||||
"""执行健康检查"""
|
||||
try:
|
||||
async with self.create_http_session() as session:
|
||||
# 测试系统状态端点
|
||||
async with session.get(f"{self.base_url}/system/status", ssl=False) as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
self.logger.debug(f"矢量化服务健康检查成功: {data}")
|
||||
return True
|
||||
else:
|
||||
self.logger.warning(f"矢量化服务健康检查失败: HTTP {resp.status}")
|
||||
return False
|
||||
except Exception as e:
|
||||
self.logger.warning(f"矢量化服务健康检查异常: {e}")
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def test_connection(cls, base_url: Optional[str] = None) -> tuple[bool, str]:
|
||||
"""测试矢量化服务连接"""
|
||||
service = cls(base_url) if base_url else cls()
|
||||
|
||||
try:
|
||||
is_available, message = await service.test_connection()
|
||||
if is_available:
|
||||
# 额外测试系统状态
|
||||
status = await service.get_system_status()
|
||||
return True, f"连接成功: {status}"
|
||||
else:
|
||||
return False, message
|
||||
except Exception as e:
|
||||
logger.error(f"矢量化服务连接失败: {e}")
|
||||
return False, f"连接失败: {str(e)}"
|
||||
|
||||
# 实现 PollingMixin 的抽象方法
|
||||
def _is_task_complete(self, result: Any) -> bool:
|
||||
"""检查任务是否完成"""
|
||||
return result.get('code') == 0
|
||||
|
||||
def _is_task_failed(self, result: Any) -> bool:
|
||||
"""检查任务是否失败"""
|
||||
return result.get('code') == -1
|
||||
|
||||
def _get_error_message(self, result: Any) -> str:
|
||||
"""获取错误消息"""
|
||||
return result.get('message', '未知错误')
|
||||
|
||||
|
||||
# 为了保持向后兼容性,保留原有的简单接口
|
||||
async def vectorize_image(image_path: str,
|
||||
save_eps_path: Optional[str] = None,
|
||||
timeout: int = 1200,
|
||||
progress_callback: Optional[Callable[[str, dict], None]] = None) -> str:
|
||||
"""
|
||||
简单的矢量化接口(向后兼容)
|
||||
|
||||
Args:
|
||||
image_path: 输入图片路径
|
||||
save_eps_path: 输出EPS文件路径
|
||||
timeout: 超时时间
|
||||
progress_callback: 进度回调函数
|
||||
|
||||
Returns:
|
||||
str: EPS文件路径
|
||||
"""
|
||||
service = VectorizerService()
|
||||
return await service.image_to_eps(
|
||||
image_path=image_path,
|
||||
save_eps_path=save_eps_path,
|
||||
timeout=timeout,
|
||||
status_callback=progress_callback
|
||||
)
|
||||
Reference in New Issue
Block a user