# -*- coding: utf-8 -*-
"""
Task data model for the "Tianwang" (Skynet) task system.

Tasks are persisted either in a local SQLite file (default) or in MySQL /
MariaDB, selected by the ``DB_TYPE`` environment variable.
"""
import contextlib
import json
import logging
import os
import sqlite3
from datetime import datetime, timedelta
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional

logger = logging.getLogger(__name__)

# Backend selection and MySQL connection settings, read once at import time.
_DB_TYPE = os.getenv("DB_TYPE", "sqlite").lower()
_MYSQL_HOST = os.getenv("MYSQL_HOST", "127.0.0.1")
_MYSQL_PORT = int(os.getenv("MYSQL_PORT", "3306"))
_MYSQL_USER = os.getenv("MYSQL_USER", "root")
_MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs")


def _is_mysql() -> bool:
    """Return True when the configured backend is MySQL or MariaDB."""
    return _DB_TYPE in ("mysql", "mariadb")


def _sql(query: str) -> str:
    """Adapt a ``?``-placeholder query to the active driver.

    sqlite3 takes ``?`` placeholders; pymysql takes ``%s``.
    """
    return query.replace("?", "%s") if _is_mysql() else query


def _now_str() -> str:
    """Return the current local time formatted for the active backend.

    MySQL gets ``YYYY-MM-DD HH:MM:SS`` (DATETIME); SQLite gets
    ``datetime.isoformat()`` (stored as TEXT).
    """
    if _is_mysql():
        return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    return datetime.now().isoformat()


class TaskStatus(Enum):
    """Task status."""

    PENDING = "pending"  # waiting to be triggered
    WAITING = "waiting"  # awaiting trigger
    RUNNING = "running"  # executing
    COMPLETED = "completed"  # completed
    FAILED = "failed"  # failed
    CANCELLED = "cancelled"  # cancelled


class TaskPriority(Enum):
    """Task priority."""

    NORMAL = "normal"
    HIGH = "high"
    URGENT = "urgent"


_MYSQL_CREATE_TABLE = '''
    CREATE TABLE IF NOT EXISTS tasks (
        task_id VARCHAR(128) PRIMARY KEY,
        specified_customer_id VARCHAR(128),
        specified_customer_name VARCHAR(255),
        type VARCHAR(64) NOT NULL,
        customer_name VARCHAR(255),
        customer_id VARCHAR(128),
        trigger_type VARCHAR(64),
        trigger_keyword VARCHAR(255),
        trigger_keywords TEXT,
        action_type VARCHAR(64),
        action_file_url TEXT,
        action_message TEXT,
        priority VARCHAR(16) DEFAULT 'normal',
        timeout_hours INT DEFAULT 24,
        status VARCHAR(32) DEFAULT 'pending',
        retry_count INT DEFAULT 0,
        max_retry INT DEFAULT 3,
        created_at DATETIME,
        created_by VARCHAR(255),
        triggered_at DATETIME,
        completed_at DATETIME,
        error_message TEXT,
        result TEXT
    ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
'''

_SQLITE_CREATE_TABLE = '''
    CREATE TABLE IF NOT EXISTS tasks (
        task_id TEXT PRIMARY KEY,
        specified_customer_id TEXT,
        specified_customer_name TEXT,
        type TEXT NOT NULL,
        customer_name TEXT,
        customer_id TEXT,
        trigger_type TEXT,
        trigger_keyword TEXT,
        trigger_keywords TEXT,
        action_type TEXT,
        action_file_url TEXT,
        action_message TEXT,
        priority TEXT DEFAULT 'normal',
        timeout_hours INTEGER DEFAULT 24,
        status TEXT DEFAULT 'pending',
        retry_count INTEGER DEFAULT 0,
        max_retry INTEGER DEFAULT 3,
        created_at TEXT,
        created_by TEXT,
        triggered_at TEXT,
        completed_at TEXT,
        error_message TEXT,
        result TEXT
    )
'''

# (index name, indexed column) pairs created on both backends.
_TASK_INDEXES = (
    ("idx_status", "status"),
    ("idx_customer", "customer_id"),
    ("idx_created", "created_at"),
)

# Columns written by add_task, in the same order as its parameter tuple.
_INSERT_COLUMNS = (
    "task_id", "specified_customer_id", "specified_customer_name", "type",
    "customer_name", "customer_id", "trigger_type", "trigger_keyword",
    "trigger_keywords", "action_type", "action_file_url", "action_message",
    "priority", "timeout_hours", "status", "retry_count", "max_retry",
    "created_at", "created_by",
)

# Pending tasks are returned urgent first, then high, then everything else.
_PRIORITY_ORDER = (
    "ORDER BY CASE priority WHEN 'urgent' THEN 1 WHEN 'high' THEN 2 ELSE 3 END, "
    "created_at"
)


class TaskManager:
    """Task manager backed by SQLite (default) or MySQL/MariaDB.

    Every public method opens its own connection and closes it before
    returning, including when a query fails.
    """

    def __init__(self, db_path: str = None):
        """Create the manager and make sure the ``tasks`` table exists.

        Args:
            db_path: SQLite database file (``str`` or ``Path``). Defaults to
                ``tasks.db`` next to this module. Not used for MySQL.
        """
        if db_path is None:
            db_path = Path(__file__).parent / "tasks.db"
        self.db_path = db_path
        self._init_db()
        logger.info(f"任务管理器初始化完成,数据库:{self.db_path}")

    def _init_db(self):
        """Create the ``tasks`` table and its indexes if they are missing."""
        if _is_mysql():
            with contextlib.closing(self._get_conn()) as conn:
                cursor = conn.cursor()
                cursor.execute(_MYSQL_CREATE_TABLE)
                # MySQL's CREATE INDEX has no IF NOT EXISTS, so check the
                # existing index names to keep initialisation re-runnable.
                cursor.execute("SHOW INDEX FROM tasks")
                existing = {str(r.get("Key_name", "")) for r in cursor.fetchall()}
                for name, column in _TASK_INDEXES:
                    if name not in existing:
                        cursor.execute(f"CREATE INDEX {name} ON tasks({column})")
                conn.commit()
        else:
            # Path() so that plain-string db_path values work as well.
            Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
            with contextlib.closing(sqlite3.connect(self.db_path)) as conn:
                cursor = conn.cursor()
                cursor.execute(_SQLITE_CREATE_TABLE)
                for name, column in _TASK_INDEXES:
                    cursor.execute(
                        f"CREATE INDEX IF NOT EXISTS {name} ON tasks({column})"
                    )
                conn.commit()
        logger.info("数据库表初始化完成")

    def _get_conn(self):
        """Open a new connection to the configured backend.

        Rows come back as dicts (pymysql ``DictCursor``) or ``sqlite3.Row``,
        so both support access by column name.
        """
        if _is_mysql():
            import pymysql
            return pymysql.connect(
                host=_MYSQL_HOST,
                port=_MYSQL_PORT,
                user=_MYSQL_USER,
                password=_MYSQL_PASSWORD,
                database=_MYSQL_DATABASE,
                charset="utf8mb4",
                cursorclass=pymysql.cursors.DictCursor,
                autocommit=False,
            )
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        return conn

    def add_task(self, task: dict) -> bool:
        """Insert a task, replacing any existing row with the same task_id.

        Args:
            task: Nested task dict. Top-level keys read: ``task_id``,
                ``type``, ``priority``, ``timeout_hours``, ``status``,
                ``retry_count``, ``max_retry``, ``created_at``,
                ``created_by``. Sub-dicts read:
                ``customer`` (``id``, ``name``),
                ``trigger`` (``type``, ``keyword``, ``keywords``) and
                ``action`` (``type``, ``file_url``, ``message``).
                The customer id/name are written to both the
                ``specified_customer_*`` and ``customer_*`` columns.

        Returns:
            True on success, False if the write failed (the error is logged).
        """
        try:
            customer = task.get('customer') or {}
            trigger = task.get('trigger') or {}
            action = task.get('action') or {}

            # A keyword list is stored as a JSON array string.
            trigger_keywords = trigger.get('keywords', [])
            if isinstance(trigger_keywords, list):
                trigger_keywords = json.dumps(trigger_keywords)

            params = (
                task.get('task_id'),
                customer.get('id'),
                customer.get('name'),
                task.get('type'),
                customer.get('name'),
                customer.get('id'),
                trigger.get('type'),
                trigger.get('keyword'),
                trigger_keywords,
                action.get('type'),
                action.get('file_url'),
                action.get('message'),
                task.get('priority', 'normal'),
                task.get('timeout_hours', 24),
                task.get('status', 'pending'),
                task.get('retry_count', 0),
                task.get('max_retry', 3),
                task.get('created_at', _now_str()),
                task.get('created_by'),
            )
            # Upsert: MySQL spells it REPLACE, SQLite INSERT OR REPLACE.
            verb = "REPLACE" if _is_mysql() else "INSERT OR REPLACE"
            insert_sql = (
                f"{verb} INTO tasks ({', '.join(_INSERT_COLUMNS)}) "
                f"VALUES ({', '.join('?' * len(_INSERT_COLUMNS))})"
            )

            with contextlib.closing(self._get_conn()) as conn:
                conn.cursor().execute(_sql(insert_sql), params)
                conn.commit()
            logger.info(f"任务添加成功:{task.get('task_id')}")
            return True
        except Exception as e:
            logger.error(f"添加任务失败:{e}")
            return False

    def get_task(self, task_id: str) -> Optional[dict]:
        """Return the task row as a dict, or None if missing or on error."""
        try:
            with contextlib.closing(self._get_conn()) as conn:
                cursor = conn.cursor()
                cursor.execute(_sql('SELECT * FROM tasks WHERE task_id = ?'), (task_id,))
                row = cursor.fetchone()
            return dict(row) if row else None
        except Exception as e:
            logger.error(f"查询任务失败:{e}")
            return None

    def update_task_status(self, task_id: str, status: TaskStatus,
                           error_message: str = None, result: dict = None):
        """Set a task's status and the matching timestamp/result columns.

        RUNNING also stamps ``triggered_at``. COMPLETED, FAILED and
        CANCELLED stamp ``completed_at``. ``error_message`` and ``result``
        are written only when truthy; ``result`` is stored as JSON.
        Errors are logged, not raised.
        """
        try:
            assignments = ['status = ?']
            params = [status.value]
            if status == TaskStatus.RUNNING:
                assignments.append('triggered_at = ?')
                params.append(_now_str())
            if status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED):
                assignments.append('completed_at = ?')
                params.append(_now_str())
            if error_message:
                assignments.append('error_message = ?')
                params.append(error_message)
            if result:
                assignments.append('result = ?')
                params.append(json.dumps(result))
            params.append(task_id)

            query = f"UPDATE tasks SET {', '.join(assignments)} WHERE task_id = ?"
            with contextlib.closing(self._get_conn()) as conn:
                conn.cursor().execute(_sql(query), params)
                conn.commit()
            logger.info(f"任务状态更新:{task_id} -> {status.value}")
        except Exception as e:
            logger.error(f"更新任务状态失败:{e}")

    def increment_retry(self, task_id: str) -> int:
        """Increment a task's retry counter and return the new value.

        Returns:
            The updated ``retry_count``. Returns 0 if the task does not
            exist, and 999 on a database error so that callers treat the
            task as having exceeded its retry limit.
        """
        try:
            with contextlib.closing(self._get_conn()) as conn:
                cursor = conn.cursor()
                cursor.execute(
                    _sql('UPDATE tasks SET retry_count = retry_count + 1 WHERE task_id = ?'),
                    (task_id,),
                )
                cursor.execute(
                    _sql('SELECT retry_count, max_retry FROM tasks WHERE task_id = ?'),
                    (task_id,),
                )
                row = cursor.fetchone()
                # Without this commit the increment is rolled back on close
                # and the counter never advances.
                conn.commit()
            return row['retry_count'] if row else 0
        except Exception as e:
            logger.error(f"增加重试次数失败:{e}")
            return 999  # treat as having exceeded the retry limit

    def get_pending_tasks(self, customer_id: str = None) -> List[dict]:
        """Return pending tasks ordered by priority, then creation time.

        Args:
            customer_id: If given, only that customer's tasks are returned.

        Returns:
            A list of row dicts; empty on error (the error is logged).
        """
        try:
            query = "SELECT * FROM tasks WHERE status = 'pending'"
            params = ()
            if customer_id:
                query += " AND customer_id = ?"
                params = (customer_id,)
            query += " " + _PRIORITY_ORDER
            with contextlib.closing(self._get_conn()) as conn:
                cursor = conn.cursor()
                cursor.execute(_sql(query), params)
                rows = cursor.fetchall()
            return [dict(row) for row in rows]
        except Exception as e:
            logger.error(f"获取待触发任务失败:{e}")
            return []

    def get_timeout_tasks(self) -> List[dict]:
        """Return pending tasks older than their own ``timeout_hours``.

        Returns:
            A list of row dicts; empty on error (the error is logged).
        """
        try:
            if _is_mysql():
                # NOTE(review): assumes the MySQL session time zone matches
                # the app's local time used by _now_str() — confirm.
                query = '''
                    SELECT * FROM tasks
                    WHERE status = 'pending'
                    AND created_at < DATE_SUB(NOW(), INTERVAL timeout_hours HOUR)
                '''
            else:
                # created_at holds local time (datetime.now()), while SQLite's
                # bare 'now' is UTC, so compare against the local clock.
                query = '''
                    SELECT * FROM tasks
                    WHERE status = 'pending'
                    AND datetime(created_at, '+' || timeout_hours || ' hours')
                        < datetime('now', 'localtime')
                '''
            with contextlib.closing(self._get_conn()) as conn:
                cursor = conn.cursor()
                cursor.execute(query)
                rows = cursor.fetchall()
            return [dict(row) for row in rows]
        except Exception as e:
            logger.error(f"获取超时任务失败:{e}")
            return []

    def cancel_task(self, task_id: str, reason: str = None):
        """Mark a task CANCELLED, recording ``reason`` as its error message."""
        self.update_task_status(task_id, TaskStatus.CANCELLED, error_message=reason)
        logger.info(f"任务已取消:{task_id}")

    def get_statistics(self) -> dict:
        """Return ``{status value: task count}`` for every TaskStatus.

        Statuses with no tasks map to 0. Returns ``{}`` on error (logged).
        """
        try:
            stats = {status.value: 0 for status in TaskStatus}
            with contextlib.closing(self._get_conn()) as conn:
                cursor = conn.cursor()
                # Alias the count so it can be read by name from both
                # sqlite3.Row and pymysql DictCursor rows.
                cursor.execute('SELECT status, COUNT(*) AS cnt FROM tasks GROUP BY status')
                rows = cursor.fetchall()
            for row in rows:
                if row['status'] in stats:
                    stats[row['status']] = row['cnt']
            return stats
        except Exception as e:
            logger.error(f"获取统计失败:{e}")
            return {}


# Module-level singleton, created lazily by get_task_manager().
_task_manager: Optional[TaskManager] = None


def get_task_manager() -> TaskManager:
    """Return the shared TaskManager, creating it on first call."""
    global _task_manager
    if _task_manager is None:
        _task_manager = TaskManager()
    return _task_manager