# -*- coding: utf-8 -*-
"""
Task data model for the Tianwang ("天网") task system.

Tasks are persisted in a SQLite database managed by :class:`TaskManager`;
:func:`get_task_manager` returns a lazily created module-level instance.
"""
import sqlite3
import json
import logging
from contextlib import contextmanager
from datetime import datetime, timedelta
from typing import Optional, Dict, List, Iterator, Union
from pathlib import Path
from enum import Enum

logger = logging.getLogger(__name__)


class TaskStatus(Enum):
    """Task status, stored as its ``value`` in the ``status`` column."""
    PENDING = "pending"      # not yet triggered
    WAITING = "waiting"      # waiting for the trigger
    RUNNING = "running"      # executing
    COMPLETED = "completed"  # finished
    FAILED = "failed"        # failed
    CANCELLED = "cancelled"  # cancelled


class TaskPriority(Enum):
    """Task priority, stored as its ``value`` in the ``priority`` column."""
    NORMAL = "normal"
    HIGH = "high"
    URGENT = "urgent"


class TaskManager:
    """Task manager backed by SQLite storage.

    Each public method opens its own connection and closes it before
    returning, so no database handle is kept open between calls.
    """

    # Dispatch order for pending tasks: 'urgent' first, then 'high', then any
    # other priority; ties are broken by the created_at column.
    _PRIORITY_ORDER_BY = (
        "ORDER BY CASE priority WHEN 'urgent' THEN 1 WHEN 'high' THEN 2 ELSE 3 END, "
        "created_at"
    )

    def __init__(self, db_path: Optional[Union[str, Path]] = None):
        """Create the manager and make sure the database schema exists.

        Args:
            db_path: Location of the SQLite file, as ``str`` or ``Path``.
                Defaults to ``tasks.db`` in this module's directory.
        """
        if db_path is None:
            db_path = Path(__file__).parent / "tasks.db"
        # Normalise to Path: _init_db relies on ``.parent``, which str lacks.
        self.db_path = Path(db_path)
        self._init_db()
        logger.info(f"任务管理器初始化完成,数据库:{self.db_path}")

    def _init_db(self):
        """Create the parent directory, the ``tasks`` table and its indexes if missing."""
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        with self._connection() as conn:
            # Task table
            conn.execute('''
                CREATE TABLE IF NOT EXISTS tasks (
                    task_id TEXT PRIMARY KEY,
                    specified_customer_id TEXT,
                    specified_customer_name TEXT,
                    type TEXT NOT NULL,
                    customer_name TEXT,
                    customer_id TEXT,
                    trigger_type TEXT,
                    trigger_keyword TEXT,
                    trigger_keywords TEXT,  -- JSON array
                    action_type TEXT,
                    action_file_url TEXT,
                    action_message TEXT,
                    priority TEXT DEFAULT 'normal',
                    timeout_hours INTEGER DEFAULT 24,
                    status TEXT DEFAULT 'pending',
                    retry_count INTEGER DEFAULT 0,
                    max_retry INTEGER DEFAULT 3,
                    created_at TEXT,
                    created_by TEXT,
                    triggered_at TEXT,
                    completed_at TEXT,
                    error_message TEXT,
                    result TEXT  -- JSON
                )
            ''')
            # Indexes on the columns used for filtering / ordering
            conn.execute('CREATE INDEX IF NOT EXISTS idx_status ON tasks(status)')
            conn.execute('CREATE INDEX IF NOT EXISTS idx_customer ON tasks(customer_id)')
            conn.execute('CREATE INDEX IF NOT EXISTS idx_created ON tasks(created_at)')
        logger.info("数据库表初始化完成")

    def _get_conn(self) -> sqlite3.Connection:
        """Open a new connection whose rows can be accessed by column name."""
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        return conn

    @contextmanager
    def _connection(self) -> Iterator[sqlite3.Connection]:
        """Yield a fresh connection; commit on success, roll back on error, always close.

        Using ``sqlite3.Connection`` as a context manager only manages the
        transaction and does not close the connection, hence the ``finally``.
        """
        conn = self._get_conn()
        try:
            with conn:
                yield conn
        finally:
            conn.close()

    def add_task(self, task: dict) -> bool:
        """Insert a task, or replace the existing row with the same ``task_id``.

        Args:
            task: Task description. The nested sections ``customer``
                (``id``, ``name``), ``trigger`` (``type``, ``keyword``,
                ``keywords``) and ``action`` (``type``, ``file_url``,
                ``message``) are flattened into columns; a missing or ``None``
                section is treated as empty. ``trigger.keywords`` given as a
                list is stored as a JSON string.

        Returns:
            True on success, False if the insert failed (the error is logged).
        """
        try:
            # ``or {}`` also tolerates sections explicitly set to None.
            customer = task.get('customer') or {}
            trigger = task.get('trigger') or {}
            action = task.get('action') or {}

            trigger_keywords = trigger.get('keywords', [])
            if isinstance(trigger_keywords, list):
                trigger_keywords = json.dumps(trigger_keywords)

            with self._connection() as conn:
                # NOTE(review): the specified_customer_* and customer_* columns
                # are both filled from the same ``customer`` section — confirm
                # this duplication is intended.
                conn.execute('''
                    INSERT OR REPLACE INTO tasks (
                        task_id, specified_customer_id, specified_customer_name,
                        type, customer_name, customer_id,
                        trigger_type, trigger_keyword, trigger_keywords,
                        action_type, action_file_url, action_message,
                        priority, timeout_hours, status,
                        retry_count, max_retry, created_at, created_by
                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                ''', (
                    task.get('task_id'),
                    customer.get('id'),
                    customer.get('name'),
                    task.get('type'),
                    customer.get('name'),
                    customer.get('id'),
                    trigger.get('type'),
                    trigger.get('keyword'),
                    trigger_keywords,
                    action.get('type'),
                    action.get('file_url'),
                    action.get('message'),
                    task.get('priority', 'normal'),
                    task.get('timeout_hours', 24),
                    task.get('status', 'pending'),
                    task.get('retry_count', 0),
                    task.get('max_retry', 3),
                    task.get('created_at', datetime.now().isoformat()),
                    task.get('created_by'),
                ))
            logger.info(f"任务添加成功:{task.get('task_id')}")
            return True
        except Exception as e:
            logger.error(f"添加任务失败:{e}")
            return False

    def get_task(self, task_id: str) -> Optional[dict]:
        """Return the task row as a dict, or None if absent or on error (logged)."""
        try:
            with self._connection() as conn:
                row = conn.execute(
                    'SELECT * FROM tasks WHERE task_id = ?', (task_id,)
                ).fetchone()
        except Exception as e:
            logger.error(f"查询任务失败:{e}")
            return None
        return dict(row) if row else None

    def update_task_status(self, task_id: str, status: TaskStatus,
                           error_message: str = None, result: dict = None):
        """Set a task's status and the related bookkeeping columns.

        Moving to RUNNING stamps ``triggered_at``; moving to COMPLETED, FAILED
        or CANCELLED stamps ``completed_at``. A truthy ``error_message`` is
        stored, and a truthy ``result`` is stored as JSON. Errors are logged,
        not raised.
        """
        try:
            now = datetime.now().isoformat()
            updates = ['status = ?']
            params = [status.value]

            if status == TaskStatus.RUNNING:
                updates.append('triggered_at = ?')
                params.append(now)

            if status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED):
                updates.append('completed_at = ?')
                params.append(now)

            if error_message:
                updates.append('error_message = ?')
                params.append(error_message)

            if result:
                updates.append('result = ?')
                params.append(json.dumps(result))

            params.append(task_id)

            with self._connection() as conn:
                # Only the fixed column assignments above are interpolated;
                # all values are bound parameters.
                conn.execute(
                    f"UPDATE tasks SET {', '.join(updates)} WHERE task_id = ?",
                    params,
                )
            logger.info(f"任务状态更新:{task_id} -> {status.value}")
        except Exception as e:
            logger.error(f"更新任务状态失败:{e}")

    def increment_retry(self, task_id: str) -> int:
        """Increment the task's retry counter and return the new value.

        Returns:
            The updated ``retry_count``; 0 if the task does not exist; 999 on a
            database error, so callers see the retry limit as exceeded.
        """
        try:
            with self._connection() as conn:
                conn.execute(
                    'UPDATE tasks SET retry_count = retry_count + 1 WHERE task_id = ?',
                    (task_id,),
                )
                row = conn.execute(
                    'SELECT retry_count, max_retry FROM tasks WHERE task_id = ?',
                    (task_id,),
                ).fetchone()
            # Leaving the ``with`` commits the increment; previously the
            # connection was closed without committing and the update was lost.
        except Exception as e:
            logger.error(f"增加重试次数失败:{e}")
            return 999  # treated as "max retries exceeded"
        return row['retry_count'] if row else 0

    def get_pending_tasks(self, customer_id: str = None) -> List[dict]:
        """Return pending tasks in dispatch order, optionally for one customer.

        Args:
            customer_id: If truthy, only tasks with this ``customer_id``.

        Returns:
            List of task dicts ordered by priority then ``created_at``; an
            empty list on error (logged).
        """
        sql = "SELECT * FROM tasks WHERE status = 'pending'"
        params = []
        if customer_id:
            sql += " AND customer_id = ?"
            params.append(customer_id)
        sql += " " + self._PRIORITY_ORDER_BY
        try:
            with self._connection() as conn:
                rows = conn.execute(sql, params).fetchall()
        except Exception as e:
            logger.error(f"获取待触发任务失败:{e}")
            return []
        return [dict(row) for row in rows]

    def get_timeout_tasks(self) -> List[dict]:
        """Return pending tasks whose ``created_at + timeout_hours`` is in the past.

        Returns an empty list on error (logged).
        """
        # Compare against Python's local clock, the same clock add_task uses for
        # the default created_at. SQLite's datetime('now') is UTC and would skew
        # every timeout by the local UTC offset.
        now = datetime.now().isoformat()
        try:
            with self._connection() as conn:
                rows = conn.execute('''
                    SELECT * FROM tasks
                    WHERE status = 'pending'
                    AND datetime(created_at, '+' || timeout_hours || ' hours') < datetime(?)
                ''', (now,)).fetchall()
        except Exception as e:
            logger.error(f"获取超时任务失败:{e}")
            return []
        return [dict(row) for row in rows]

    def cancel_task(self, task_id: str, reason: str = None):
        """Mark a task CANCELLED, recording ``reason`` as its error message."""
        self.update_task_status(task_id, TaskStatus.CANCELLED, error_message=reason)
        logger.info(f"任务已取消:{task_id}")

    def get_statistics(self) -> dict:
        """Return ``{status value: task count}`` for every TaskStatus.

        Statuses with no tasks map to 0. Returns an empty dict on error (logged).
        """
        try:
            with self._connection() as conn:
                rows = conn.execute(
                    'SELECT status, COUNT(*) FROM tasks GROUP BY status'
                ).fetchall()
        except Exception as e:
            logger.error(f"获取统计失败:{e}")
            return {}
        counts = {row[0]: row[1] for row in rows}
        return {status.value: counts.get(status.value, 0) for status in TaskStatus}


# Module-level singleton, created lazily by get_task_manager().
_task_manager: Optional[TaskManager] = None


def get_task_manager() -> TaskManager:
    """Return the shared TaskManager, creating it with the default database on first use."""
    global _task_manager
    if _task_manager is None:
        _task_manager = TaskManager()
    return _task_manager