106 lines
3.1 KiB
Python
106 lines
3.1 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
服务基类 - 供 VectorizerService 等异步服务使用
|
||
"""
|
||
import logging
|
||
import asyncio
|
||
from dataclasses import dataclass
|
||
from typing import Any, Callable, TypeVar, Optional
|
||
import aiohttp
|
||
|
||
# Generic type variable. Nothing in this module references it; it may be
# imported by other modules — TODO confirm before removing.
T = TypeVar("T")
|
||
|
||
|
||
@dataclass
class RetryConfig:
    """Retry policy consumed by ``BaseService.execute_with_retry``.

    Attributes:
        max_retries: Number of retries after the first attempt
            (0 means the call is attempted exactly once).
        base_delay: Initial backoff delay in seconds; it doubles on every
            subsequent retry.
        max_delay: Upper bound, in seconds, on a single backoff delay.

    Raises:
        ValueError: If any field is negative.
    """

    max_retries: int = 3
    base_delay: float = 2.0
    max_delay: float = 30.0

    def __post_init__(self) -> None:
        # A negative max_retries makes the retry loop run zero times, so the
        # wrapped call would never be attempted; fail fast at construction.
        if self.max_retries < 0:
            raise ValueError(f"max_retries must be >= 0, got {self.max_retries}")
        # Negative delays are meaningless for a backoff schedule.
        if self.base_delay < 0:
            raise ValueError(f"base_delay must be >= 0, got {self.base_delay}")
        if self.max_delay < 0:
            raise ValueError(f"max_delay must be >= 0, got {self.max_delay}")
|
||
|
||
|
||
@dataclass()
class TimeoutConfig:
    """Timeout settings, in seconds, for the HTTP session a service builds."""

    # Time allowed to establish a connection.
    connection_timeout: float = 60.0
    # Time allowed for reading response data — presumably per socket read;
    # check how the session factory applies it.
    read_timeout: float = 240.0
    # Upper bound for a whole request.
    total_timeout: float = 1200.0
|
||
|
||
|
||
class ServiceError(Exception):
    """Root of this module's service exception hierarchy.

    Catch this type to handle any of the service-specific errors at once.
    """
|
||
|
||
|
||
class ServiceTimeoutError(ServiceError):
    """Timeout error raised by a service.

    ``BaseService.execute_with_retry`` re-raises it immediately rather than
    retrying.
    """
|
||
|
||
|
||
class ServiceNetworkError(ServiceError):
    """Network error raised by a service.

    ``BaseService.execute_with_retry`` re-raises it immediately rather than
    retrying.
    """
|
||
|
||
|
||
class PollingMixin:
    """Polling mixin declaring the hooks needed to inspect a task's status.

    Classes that mix this in must override all three hooks below; the default
    implementations simply raise ``NotImplementedError``.
    """

    def _is_task_complete(self, result: Any) -> bool:
        """Hook: return True when *result* indicates the task has completed.

        The shape of *result* is defined by the implementing subclass.
        """
        raise NotImplementedError()

    def _is_task_failed(self, result: Any) -> bool:
        """Hook: return True when *result* indicates the task has failed."""
        raise NotImplementedError()

    def _get_error_message(self, result: Any) -> str:
        """Hook: return an error message extracted from *result*."""
        raise NotImplementedError()
|
||
|
||
|
||
class BaseService:
    """Base class for asynchronous services.

    Stores the service name, base URL, retry/timeout policies and a logger,
    and provides helpers to build an ``aiohttp`` session and to await a
    coroutine function with exponential-backoff retries.
    """

    def __init__(
        self,
        name: str = "BaseService",
        base_url: str = "",
        retry_config: Optional[RetryConfig] = None,
        timeout_config: Optional[TimeoutConfig] = None,
    ):
        """Initialise the service.

        Args:
            name: Service name; also used as the logger name.
            base_url: Root URL of the remote service; trailing ``/`` characters
                are stripped.
            retry_config: Retry policy; defaults to ``RetryConfig()``.
            timeout_config: Timeout policy; defaults to ``TimeoutConfig()``.
        """
        self.name = name
        # Drop trailing slashes (presumably so callers can append "/path").
        self.base_url = base_url.rstrip("/")
        self.retry_config = retry_config or RetryConfig()
        self.timeout_config = timeout_config or TimeoutConfig()
        self.logger = logging.getLogger(name)
        # Not assigned anywhere else in this class; presumably managed by
        # subclasses — TODO confirm who opens and closes it.
        self._session: Optional[aiohttp.ClientSession] = None

    async def create_http_session(self) -> aiohttp.ClientSession:
        """Create a new aiohttp session configured from ``self.timeout_config``.

        SSL certificate verification is DISABLED for this session. The session
        is not stored on ``self``; the caller owns it and must close it
        (``await session.close()`` or ``async with``).
        """
        # SECURITY NOTE: ssl=False turns off certificate verification, which
        # exposes traffic to man-in-the-middle attacks. Kept because this
        # method's contract explicitly promises no SSL verification; review
        # before using it against untrusted networks.
        connector = aiohttp.TCPConnector(ssl=False)
        timeout = aiohttp.ClientTimeout(
            connect=self.timeout_config.connection_timeout,
            # read_timeout was declared in TimeoutConfig but never applied;
            # sock_read bounds the wait for each chunk of data from the peer.
            sock_read=self.timeout_config.read_timeout,
            total=self.timeout_config.total_timeout,
        )
        return aiohttp.ClientSession(connector=connector, timeout=timeout)

    async def execute_with_retry(
        self,
        func: Callable[..., Any],
        *args,
        error_context: str = "",
        **kwargs,
    ) -> Any:
        """Await ``func(*args, **kwargs)``, retrying failures with backoff.

        After the first attempt, up to ``retry_config.max_retries`` retries are
        made. Before retry ``n`` (0-based attempt index ``n``) the service
        sleeps ``min(base_delay * 2**n, max_delay)`` seconds.

        Args:
            func: Coroutine function; its call result is awaited.
            *args: Positional arguments forwarded to ``func``.
            error_context: Text appended to each failed-attempt log message.
            **kwargs: Keyword arguments forwarded to ``func``.

        Returns:
            The awaited result of the first successful call.

        Raises:
            ServiceTimeoutError, ServiceNetworkError: Re-raised immediately,
                without retrying.
            Exception: The last error raised by ``func`` once retries are
                exhausted, or ``ServiceError`` if no attempt was made.
        """
        last_error: Optional[BaseException] = None
        for attempt in range(self.retry_config.max_retries + 1):
            try:
                return await func(*args, **kwargs)
            except (ServiceTimeoutError, ServiceNetworkError):
                # Treated as non-retryable: propagate on the first occurrence.
                raise
            except Exception as e:
                last_error = e
                # Lazy %-style args: formatting happens only if the record is
                # emitted; the resulting text matches the former f-string.
                self.logger.warning(
                    "第%d次尝试失败%s: %s", attempt + 1, error_context, e
                )
                if attempt < self.retry_config.max_retries:
                    delay = min(
                        self.retry_config.base_delay * (2 ** attempt),
                        self.retry_config.max_delay,
                    )
                    await asyncio.sleep(delay)
        raise last_error or ServiceError("未知错误")

    async def cleanup_failed_task(self, task_id: str) -> None:
        """Clean up after a failed task; the default only logs at DEBUG level.

        Subclasses may override this to perform real cleanup.
        """
        self.logger.debug("清理任务: %s", task_id)
|