433 lines
18 KiB
Python
433 lines
18 KiB
Python
"""
|
||
矢量化服务模块 - 使用统一异常处理机制
|
||
"""
|
||
|
||
import aiohttp
|
||
import time
|
||
import urllib3
|
||
from typing import Callable, Optional, Dict, Any
|
||
import logging
|
||
import os
|
||
import asyncio
|
||
from pathlib import Path
|
||
|
||
# 导入基础服务类
|
||
from utils.service_base import (
|
||
BaseService, PollingMixin, ServiceError, ServiceTimeoutError,
|
||
ServiceNetworkError, RetryConfig, TimeoutConfig
|
||
)
|
||
|
||
# 禁用SSL警告
|
||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class VectorizerServiceError(ServiceError):
|
||
"""矢量化服务特定异常"""
|
||
pass
|
||
|
||
|
||
class VectorizerService(BaseService, PollingMixin):
|
||
"""矢量化服务类 - 继承统一异常处理机制"""
|
||
|
||
# def __init__(self, base_url: str = "https://frp-dad.com:50529"):
|
||
def __init__(self, base_url: str = "https://127.0.0.1:8090"):
|
||
# 配置重试和超时
|
||
retry_config = RetryConfig(
|
||
max_retries=3,
|
||
base_delay=2.0,
|
||
max_delay=30.0
|
||
)
|
||
|
||
timeout_config = TimeoutConfig(
|
||
connection_timeout=60.0,
|
||
read_timeout=240.0,
|
||
total_timeout=1200.0
|
||
)
|
||
|
||
super().__init__(
|
||
name="VectorizerService",
|
||
base_url=base_url,
|
||
retry_config=retry_config,
|
||
timeout_config=timeout_config
|
||
)
|
||
|
||
async def image_to_eps(self,
|
||
image_path: str,
|
||
save_eps_path: Optional[str] = None,
|
||
timeout: int = 1200,
|
||
poll_interval: float = 2.0,
|
||
status_callback: Optional[Callable[[str, dict], None]] = None) -> str:
|
||
"""
|
||
将图片转换为EPS矢量文件
|
||
|
||
Args:
|
||
image_path: 输入图片路径
|
||
save_eps_path: 输出EPS文件路径(可选)
|
||
timeout: 最大等待时间(秒)
|
||
poll_interval: 轮询间隔(秒)
|
||
status_callback: 状态回调函数
|
||
|
||
Returns:
|
||
str: EPS文件保存路径
|
||
|
||
Raises:
|
||
VectorizerServiceError: 矢量化服务异常
|
||
ServiceTimeoutError: 超时异常
|
||
ServiceNetworkError: 网络异常
|
||
"""
|
||
# 验证输入文件
|
||
if not os.path.exists(image_path):
|
||
raise VectorizerServiceError(f"输入图片文件不存在: {image_path}")
|
||
|
||
# 设置输出路径
|
||
if save_eps_path is None:
|
||
save_eps_path = os.path.splitext(image_path)[0] + '.eps'
|
||
|
||
# 确保输出目录存在
|
||
output_dir = os.path.dirname(save_eps_path)
|
||
if output_dir:
|
||
os.makedirs(output_dir, exist_ok=True)
|
||
|
||
try:
|
||
# 1. 上传图片
|
||
task_id = await self.execute_with_retry(
|
||
self._upload_image,
|
||
image_path,
|
||
status_callback,
|
||
error_context=" - 上传图片"
|
||
)
|
||
|
||
# 2. 轮询等待处理完成
|
||
await self.execute_with_retry(
|
||
self._wait_for_processing,
|
||
task_id,
|
||
timeout,
|
||
poll_interval,
|
||
status_callback,
|
||
error_context=" - 等待处理完成"
|
||
)
|
||
|
||
# 3. 下载结果文件
|
||
await self.execute_with_retry(
|
||
self._download_result,
|
||
task_id,
|
||
save_eps_path,
|
||
status_callback,
|
||
error_context=" - 下载结果文件"
|
||
)
|
||
|
||
# 验证输出文件
|
||
if not os.path.exists(save_eps_path) or os.path.getsize(save_eps_path) == 0:
|
||
raise VectorizerServiceError(f"输出文件创建失败或为空: {save_eps_path}")
|
||
|
||
self.logger.info(f"矢量化转换成功: {image_path} -> {save_eps_path}")
|
||
|
||
if status_callback:
|
||
status_callback('finished', {
|
||
'message': '转换完成!',
|
||
'input_path': image_path,
|
||
'output_path': save_eps_path
|
||
})
|
||
|
||
return save_eps_path
|
||
|
||
except (ServiceTimeoutError, ServiceNetworkError, VectorizerServiceError):
|
||
# 直接传递这些异常
|
||
if status_callback:
|
||
status_callback('error', {'message': '处理失败,请稍后重试'})
|
||
raise
|
||
except Exception as e:
|
||
error_msg = f"矢量化转换失败: {str(e)}"
|
||
self.logger.error(error_msg)
|
||
if status_callback:
|
||
status_callback('error', {'message': error_msg})
|
||
raise VectorizerServiceError(error_msg)
|
||
|
||
async def _upload_image(self,
|
||
image_path: str,
|
||
status_callback: Optional[Callable[[str, dict], None]] = None) -> str:
|
||
"""上传图片到矢量化服务"""
|
||
if status_callback:
|
||
status_callback('uploading', {'message': '正在上传图片...', 'image_path': image_path})
|
||
|
||
async with await self.create_http_session() as session:
|
||
with open(image_path, 'rb') as f:
|
||
data = aiohttp.FormData()
|
||
data.add_field('file', f, filename=os.path.basename(image_path))
|
||
|
||
async with session.post(f"{self.base_url}/add_task", data=data, ssl=False) as resp:
|
||
if resp.status != 200:
|
||
error_text = await resp.text()
|
||
raise VectorizerServiceError(f"上传失败 - HTTP {resp.status}: {error_text}")
|
||
|
||
response_data = await resp.json()
|
||
|
||
if response_data.get('code') != 0:
|
||
error_msg = response_data.get('message', '未知错误')
|
||
raise VectorizerServiceError(f"上传失败: {error_msg}")
|
||
|
||
task_id = response_data.get('id') or response_data.get('taskid')
|
||
if not task_id:
|
||
raise VectorizerServiceError("上传失败,未获取到任务ID")
|
||
|
||
self.logger.info(f"图片上传成功,任务ID: {task_id}")
|
||
|
||
if status_callback:
|
||
status_callback('uploaded', {
|
||
'message': '图片上传成功,开始处理...',
|
||
'taskid': task_id
|
||
})
|
||
|
||
return task_id
|
||
|
||
async def _wait_for_processing(self,
|
||
task_id: str,
|
||
timeout: int,
|
||
poll_interval: float,
|
||
status_callback: Optional[Callable[[str, dict], None]] = None):
|
||
"""等待处理完成"""
|
||
start_time = time.time()
|
||
poll_count = 0
|
||
consecutive_failures = 0
|
||
max_consecutive_failures = 5
|
||
|
||
async with await self.create_http_session() as session:
|
||
while True:
|
||
poll_count += 1
|
||
elapsed_time = time.time() - start_time
|
||
|
||
# 检查超时
|
||
if elapsed_time > timeout:
|
||
await self.cleanup_failed_task(task_id)
|
||
raise ServiceTimeoutError(f"处理超时 (超过{timeout}秒) - 任务ID: {task_id}")
|
||
|
||
if status_callback:
|
||
progress = min(90, int(90 * elapsed_time / timeout))
|
||
status_callback('processing', {
|
||
'message': f'正在处理中... (第{poll_count}次检查)',
|
||
'taskid': task_id,
|
||
'elapsed_time': elapsed_time,
|
||
'poll_count': poll_count,
|
||
'progress': progress
|
||
})
|
||
|
||
try:
|
||
async with session.get(f"{self.base_url}/try_get",
|
||
params={'taskid': task_id},
|
||
ssl=False) as resp:
|
||
if resp.status != 200:
|
||
consecutive_failures += 1
|
||
error_text = await resp.text()
|
||
self.logger.warning(f"状态查询失败 - HTTP {resp.status}: {error_text}")
|
||
|
||
if consecutive_failures >= max_consecutive_failures:
|
||
await self.cleanup_failed_task(task_id)
|
||
raise VectorizerServiceError(f"连续{consecutive_failures}次状态查询失败")
|
||
|
||
await asyncio.sleep(poll_interval * 2) # 失败时等待更长时间
|
||
continue
|
||
|
||
# 重置失败计数
|
||
consecutive_failures = 0
|
||
|
||
result = await resp.json()
|
||
|
||
if result.get('code') == 0:
|
||
# 处理完成
|
||
self.logger.info(f"任务处理完成 - 任务ID: {task_id}, 耗时: {elapsed_time:.2f}秒")
|
||
if status_callback:
|
||
status_callback('completed', {
|
||
'message': '处理完成,准备下载...',
|
||
'taskid': task_id,
|
||
'total_time': elapsed_time,
|
||
'poll_count': poll_count
|
||
})
|
||
return
|
||
|
||
elif result.get('code') == -1:
|
||
# 处理失败
|
||
error_msg = result.get('message', '处理失败')
|
||
await self.cleanup_failed_task(task_id)
|
||
raise VectorizerServiceError(f"服务器处理失败: {error_msg}")
|
||
|
||
# 其他状态码,继续等待
|
||
self.logger.debug(f"任务处理中 - 任务ID: {task_id}, 状态码: {result.get('code')}")
|
||
|
||
except VectorizerServiceError:
|
||
# 直接传递服务异常
|
||
raise
|
||
except Exception as e:
|
||
consecutive_failures += 1
|
||
self.logger.warning(f"轮询检查异常 (第{consecutive_failures}次): {str(e)}")
|
||
|
||
if consecutive_failures >= max_consecutive_failures:
|
||
await self.cleanup_failed_task(task_id)
|
||
raise VectorizerServiceError(f"连续{consecutive_failures}次轮询异常: {str(e)}")
|
||
|
||
await asyncio.sleep(poll_interval)
|
||
|
||
async def _download_result(self,
|
||
task_id: str,
|
||
save_path: str,
|
||
status_callback: Optional[Callable[[str, dict], None]] = None):
|
||
"""下载处理结果"""
|
||
if status_callback:
|
||
status_callback('downloading', {'message': '正在下载EPS文件...', 'taskid': task_id})
|
||
|
||
async with await self.create_http_session() as session:
|
||
async with session.get(f"{self.base_url}/get_image",
|
||
params={'taskid': task_id},
|
||
ssl=False) as resp:
|
||
if resp.status != 200:
|
||
error_text = await resp.text()
|
||
raise VectorizerServiceError(f"下载失败 - HTTP {resp.status}: {error_text}")
|
||
|
||
# 检查内容长度
|
||
content_length = resp.headers.get('Content-Length')
|
||
if content_length and int(content_length) == 0:
|
||
raise VectorizerServiceError("下载的文件为空")
|
||
|
||
# 使用临时文件确保原子写入
|
||
temp_path = save_path + '.tmp'
|
||
try:
|
||
with open(temp_path, 'wb') as f:
|
||
bytes_written = 0
|
||
while True:
|
||
chunk = await resp.content.read(8192)
|
||
if not chunk:
|
||
break
|
||
f.write(chunk)
|
||
bytes_written += len(chunk)
|
||
|
||
# 验证文件大小
|
||
if bytes_written == 0:
|
||
raise VectorizerServiceError("下载的文件为空")
|
||
|
||
# 原子移动到最终位置
|
||
os.rename(temp_path, save_path)
|
||
|
||
self.logger.info(f"文件下载成功: {save_path}, 大小: {bytes_written} 字节")
|
||
|
||
except Exception as e:
|
||
# 清理临时文件
|
||
if os.path.exists(temp_path):
|
||
os.remove(temp_path)
|
||
raise VectorizerServiceError(f"文件下载失败: {str(e)}")
|
||
|
||
async def get_system_status(self) -> Dict[str, Any]:
|
||
"""获取系统状态"""
|
||
return await self.execute_with_retry(
|
||
self._do_get_system_status,
|
||
error_context=" - 获取系统状态"
|
||
)
|
||
|
||
async def _do_get_system_status(self) -> Dict[str, Any]:
|
||
"""执行系统状态查询"""
|
||
async with await self.create_http_session() as session:
|
||
async with session.get(f"{self.base_url}/status", ssl=False) as resp:
|
||
if resp.status != 200:
|
||
error_text = await resp.text()
|
||
raise VectorizerServiceError(f"获取系统状态失败 - HTTP {resp.status}: {error_text}")
|
||
|
||
return await resp.json()
|
||
|
||
async def _do_health_check(self) -> bool:
|
||
"""执行健康检查"""
|
||
try:
|
||
status = await self._do_get_system_status()
|
||
return True
|
||
except Exception as e:
|
||
self.logger.warning(f"健康检查失败: {str(e)}")
|
||
return False
|
||
|
||
async def cleanup_failed_task(self, task_id: str) -> None:
|
||
"""清理失败的任务"""
|
||
try:
|
||
# 尝试取消任务(如果服务支持)
|
||
async with await self.create_http_session() as session:
|
||
async with session.delete(f"{self.base_url}/cancel_task",
|
||
params={'taskid': task_id},
|
||
ssl=False) as resp:
|
||
if resp.status == 200:
|
||
self.logger.info(f"任务取消成功: {task_id}")
|
||
else:
|
||
self.logger.warning(f"任务取消失败: {task_id}, HTTP {resp.status}")
|
||
except Exception as e:
|
||
self.logger.warning(f"清理任务异常: {task_id}, 错误: {str(e)}")
|
||
|
||
await super().cleanup_failed_task(task_id)
|
||
|
||
async def _do_health_check(self) -> bool:
|
||
"""执行健康检查"""
|
||
try:
|
||
async with self.create_http_session() as session:
|
||
# 测试系统状态端点
|
||
async with session.get(f"{self.base_url}/system/status", ssl=False) as resp:
|
||
if resp.status == 200:
|
||
data = await resp.json()
|
||
self.logger.debug(f"矢量化服务健康检查成功: {data}")
|
||
return True
|
||
else:
|
||
self.logger.warning(f"矢量化服务健康检查失败: HTTP {resp.status}")
|
||
return False
|
||
except Exception as e:
|
||
self.logger.warning(f"矢量化服务健康检查异常: {e}")
|
||
return False
|
||
|
||
@classmethod
|
||
async def test_connection(cls, base_url: Optional[str] = None) -> tuple[bool, str]:
|
||
"""测试矢量化服务连接"""
|
||
service = cls(base_url) if base_url else cls()
|
||
|
||
try:
|
||
is_available, message = await service.test_connection()
|
||
if is_available:
|
||
# 额外测试系统状态
|
||
status = await service.get_system_status()
|
||
return True, f"连接成功: {status}"
|
||
else:
|
||
return False, message
|
||
except Exception as e:
|
||
logger.error(f"矢量化服务连接失败: {e}")
|
||
return False, f"连接失败: {str(e)}"
|
||
|
||
# 实现 PollingMixin 的抽象方法
|
||
def _is_task_complete(self, result: Any) -> bool:
|
||
"""检查任务是否完成"""
|
||
return result.get('code') == 0
|
||
|
||
def _is_task_failed(self, result: Any) -> bool:
|
||
"""检查任务是否失败"""
|
||
return result.get('code') == -1
|
||
|
||
def _get_error_message(self, result: Any) -> str:
|
||
"""获取错误消息"""
|
||
return result.get('message', '未知错误')
|
||
|
||
|
||
# 为了保持向后兼容性,保留原有的简单接口
|
||
async def vectorize_image(image_path: str,
|
||
save_eps_path: Optional[str] = None,
|
||
timeout: int = 1200,
|
||
progress_callback: Optional[Callable[[str, dict], None]] = None) -> str:
|
||
"""
|
||
简单的矢量化接口(向后兼容)
|
||
|
||
Args:
|
||
image_path: 输入图片路径
|
||
save_eps_path: 输出EPS文件路径
|
||
timeout: 超时时间
|
||
progress_callback: 进度回调函数
|
||
|
||
Returns:
|
||
str: EPS文件路径
|
||
"""
|
||
service = VectorizerService()
|
||
return await service.image_to_eps(
|
||
image_path=image_path,
|
||
save_eps_path=save_eps_path,
|
||
timeout=timeout,
|
||
status_callback=progress_callback
|
||
) |