# -*- coding: utf-8 -*-
"""
Tianwang ("Skynet", 天网) task data model.

Tasks are persisted either in a local SQLite file or in MySQL/MariaDB; the
backend is chosen from the ``DB_TYPE`` environment variable at import time.
"""
import sqlite3
import json
import logging
from contextlib import contextmanager
from datetime import datetime, timedelta
from typing import Optional, Dict, List, Union
from pathlib import Path
from enum import Enum
import os

from db.chat_log_db import _get_pooled_conn, _return_conn

logger = logging.getLogger(__name__)

# Backend selector: "mysql" / "mariadb" use the shared pymysql pool, anything
# else (default "sqlite") uses a local SQLite file.
_DB_TYPE = os.getenv("DB_TYPE", "sqlite").lower()
# NOTE(review): these are not referenced elsewhere in this module; presumably
# consumed by importers or kept for parity with db.chat_log_db — verify before removing.
_MYSQL_HOST = os.getenv("MYSQL_HOST", "127.0.0.1")
_MYSQL_PORT = int(os.getenv("MYSQL_PORT", "3306"))
_MYSQL_USER = os.getenv("MYSQL_USER", "root")
_MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs")


def _is_mysql() -> bool:
    """Return True when the configured backend is MySQL or MariaDB."""
    return _DB_TYPE in ("mysql", "mariadb")


def _sql(query: str) -> str:
    """Translate ``?`` placeholders to ``%s`` when running against MySQL.

    Every ``?`` is replaced, so queries must not contain a literal ``?``
    (for example inside a quoted string constant).
    """
    return query.replace("?", "%s") if _is_mysql() else query


def _now_str() -> str:
    """Return the current naive local time formatted for the active backend.

    MySQL gets ``YYYY-MM-DD HH:MM:SS`` (a DATETIME literal); SQLite gets
    ``datetime.isoformat()`` output.
    """
    if _is_mysql():
        return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    return datetime.now().isoformat()


class TaskStatus(Enum):
    """Task lifecycle state, stored as its value in the ``status`` column."""

    PENDING = "pending"      # not yet triggered
    WAITING = "waiting"      # waiting for the trigger
    RUNNING = "running"      # executing
    COMPLETED = "completed"  # finished successfully
    FAILED = "failed"        # failed
    CANCELLED = "cancelled"  # cancelled


class TaskPriority(Enum):
    """Task priority, stored as its value in the ``priority`` column."""

    NORMAL = "normal"
    HIGH = "high"
    URGENT = "urgent"


class _PooledMySQLConn:
    """Wrap a pooled pymysql connection so that ``close()`` returns it to the pool.

    Every other attribute (``cursor``, ``commit``, ``rollback``, ...) is
    delegated to the wrapped connection.
    """

    def __init__(self, conn):
        self._conn = conn
        self._closed = False

    def __getattr__(self, name):
        return getattr(self._conn, name)

    def close(self):
        # Idempotent: handing the same connection back twice could let the
        # pool give it to two borrowers (depending on _return_conn).
        if self._closed:
            return
        self._closed = True
        _return_conn(self._conn)


class TaskManager:
    """Task manager persisting tasks to SQLite or MySQL/MariaDB.

    The backend follows ``DB_TYPE``; ``db_path`` is only used for SQLite.
    Public methods other than the constructor log failures and return a
    fallback value instead of raising.
    """

    # (index name, indexed column) pairs created on the tasks table.
    _INDEXES = (
        ("idx_status", "status"),
        ("idx_customer", "customer_id"),
        ("idx_created", "created_at"),
    )

    # Pending tasks are served urgent -> high -> everything else, oldest first.
    _PENDING_ORDER_BY = (
        "ORDER BY CASE priority WHEN 'urgent' THEN 1 WHEN 'high' THEN 2 "
        "ELSE 3 END, created_at"
    )

    def __init__(self, db_path: Optional[Union[str, Path]] = None):
        """Create the manager and make sure the ``tasks`` table exists.

        Args:
            db_path: SQLite database file (str or Path). Defaults to
                ``tasks.db`` next to this module. Unused for MySQL.

        Raises:
            Exception: whatever the driver raises if schema creation fails.
        """
        if db_path is None:
            db_path = Path(__file__).parent / "tasks.db"
        # Normalise to Path: _init_db calls .parent.mkdir(), which str lacks.
        self.db_path = Path(db_path)
        self._init_db()
        logger.info("任务管理器初始化完成,数据库:%s", self.db_path)

    def _init_db(self):
        """Create the ``tasks`` table and its indexes if they are missing."""
        if not _is_mysql():
            self.db_path.parent.mkdir(parents=True, exist_ok=True)
        with self._connection() as conn:
            cursor = conn.cursor()
            if _is_mysql():
                self._create_mysql_schema(cursor)
            else:
                self._create_sqlite_schema(cursor)
        logger.info("数据库表初始化完成")

    @classmethod
    def _create_mysql_schema(cls, cursor):
        """Create the MySQL table, then each index only if it is not present.

        MySQL has no ``CREATE INDEX IF NOT EXISTS``, so existing index names
        are read from ``SHOW INDEX`` first (rows are accessed as dicts).
        """
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS tasks (
                task_id VARCHAR(128) PRIMARY KEY,
                specified_customer_id VARCHAR(128),
                specified_customer_name VARCHAR(255),
                type VARCHAR(64) NOT NULL,
                customer_name VARCHAR(255),
                customer_id VARCHAR(128),
                trigger_type VARCHAR(64),
                trigger_keyword VARCHAR(255),
                trigger_keywords TEXT,
                action_type VARCHAR(64),
                action_file_url TEXT,
                action_message TEXT,
                priority VARCHAR(16) DEFAULT 'normal',
                timeout_hours INT DEFAULT 24,
                status VARCHAR(32) DEFAULT 'pending',
                retry_count INT DEFAULT 0,
                max_retry INT DEFAULT 3,
                created_at DATETIME,
                created_by VARCHAR(255),
                triggered_at DATETIME,
                completed_at DATETIME,
                error_message TEXT,
                result TEXT
            ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
        ''')
        cursor.execute("SHOW INDEX FROM tasks")
        existing = {str(r.get("Key_name", "")) for r in cursor.fetchall()}
        for index_name, column in cls._INDEXES:
            if index_name not in existing:
                cursor.execute(f'CREATE INDEX {index_name} ON tasks({column})')

    @classmethod
    def _create_sqlite_schema(cls, cursor):
        """Create the SQLite table and indexes (idempotent)."""
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS tasks (
                task_id TEXT PRIMARY KEY,
                specified_customer_id TEXT,
                specified_customer_name TEXT,
                type TEXT NOT NULL,
                customer_name TEXT,
                customer_id TEXT,
                trigger_type TEXT,
                trigger_keyword TEXT,
                trigger_keywords TEXT,
                action_type TEXT,
                action_file_url TEXT,
                action_message TEXT,
                priority TEXT DEFAULT 'normal',
                timeout_hours INTEGER DEFAULT 24,
                status TEXT DEFAULT 'pending',
                retry_count INTEGER DEFAULT 0,
                max_retry INTEGER DEFAULT 3,
                created_at TEXT,
                created_by TEXT,
                triggered_at TEXT,
                completed_at TEXT,
                error_message TEXT,
                result TEXT
            )
        ''')
        for index_name, column in cls._INDEXES:
            cursor.execute(
                f'CREATE INDEX IF NOT EXISTS {index_name} ON tasks({column})'
            )

    def _get_conn(self):
        """Open a connection for the active backend.

        MySQL: a pooled connection wrapped so ``close()`` returns it to the
        pool. SQLite: a new connection whose rows are ``sqlite3.Row``
        (readable by column name and convertible with ``dict()``).
        """
        if _is_mysql():
            return _PooledMySQLConn(_get_pooled_conn())
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        return conn

    @contextmanager
    def _connection(self):
        """Yield a connection scoped as one transaction.

        Commits when the body finishes, rolls back if it raises, and always
        closes the connection (for MySQL: returns it to the pool). Committing
        after read-only work also ends the transaction, so a pooled MySQL
        connection is not handed back with an open read snapshot.
        NOTE(review): harmless if _return_conn already resets connections.
        """
        conn = self._get_conn()
        try:
            yield conn
            conn.commit()
        except Exception:
            try:
                conn.rollback()
            except Exception:
                # Best effort only: surface the original error, not this one.
                logger.debug("rollback failed", exc_info=True)
            raise
        finally:
            conn.close()

    def add_task(self, task: dict) -> bool:
        """Insert a task, replacing any existing row with the same ``task_id``.

        Args:
            task: task dict. Top-level keys read: ``task_id``, ``type``,
                ``priority``, ``timeout_hours``, ``status``, ``retry_count``,
                ``max_retry``, ``created_at``, ``created_by``; nested dicts
                ``customer`` (``id``, ``name``), ``trigger`` (``type``,
                ``keyword``, ``keywords``) and ``action`` (``type``,
                ``file_url``, ``message``). Missing or ``None`` sub-dicts are
                treated as empty.

        Returns:
            True on success, False if the write failed (the error is logged).
        """
        try:
            customer = task.get('customer') or {}
            trigger = task.get('trigger') or {}
            action = task.get('action') or {}

            # A keyword list is stored as a JSON string.
            trigger_keywords = trigger.get('keywords', [])
            if isinstance(trigger_keywords, list):
                trigger_keywords = json.dumps(trigger_keywords)

            # Column -> value; dict order drives both the column list and params.
            row = {
                'task_id': task.get('task_id'),
                'specified_customer_id': customer.get('id'),
                'specified_customer_name': customer.get('name'),
                'type': task.get('type'),
                'customer_name': customer.get('name'),
                'customer_id': customer.get('id'),
                'trigger_type': trigger.get('type'),
                'trigger_keyword': trigger.get('keyword'),
                'trigger_keywords': trigger_keywords,
                'action_type': action.get('type'),
                'action_file_url': action.get('file_url'),
                'action_message': action.get('message'),
                'priority': task.get('priority', 'normal'),
                'timeout_hours': task.get('timeout_hours', 24),
                'status': task.get('status', 'pending'),
                'retry_count': task.get('retry_count', 0),
                'max_retry': task.get('max_retry', 3),
                'created_at': task.get('created_at', _now_str()),
                'created_by': task.get('created_by'),
            }
            # Both statements delete-then-insert on a task_id conflict.
            verb = "REPLACE INTO" if _is_mysql() else "INSERT OR REPLACE INTO"
            columns = ", ".join(row)
            placeholders = ", ".join(["?"] * len(row))
            insert_sql = f"{verb} tasks ({columns}) VALUES ({placeholders})"

            with self._connection() as conn:
                conn.cursor().execute(_sql(insert_sql), tuple(row.values()))
            logger.info("任务添加成功:%s", task.get('task_id'))
            return True
        except Exception as e:
            logger.error("添加任务失败:%s", e)
            return False

    def get_task(self, task_id: str) -> Optional[dict]:
        """Return the task row as a dict, or None if absent or on error."""
        try:
            with self._connection() as conn:
                cursor = conn.cursor()
                cursor.execute(
                    _sql('SELECT * FROM tasks WHERE task_id = ?'), (task_id,)
                )
                row = cursor.fetchone()
            return dict(row) if row else None
        except Exception as e:
            logger.error("查询任务失败:%s", e)
            return None

    def update_task_status(self, task_id: str, status: TaskStatus,
                           error_message: Optional[str] = None,
                           result: Optional[dict] = None):
        """Set a task's status and the matching timestamp/detail columns.

        ``triggered_at`` is stamped when moving to RUNNING; ``completed_at``
        when moving to COMPLETED, FAILED or CANCELLED. ``error_message`` and
        ``result`` (JSON-encoded) are written only when truthy. Errors are
        logged, not raised.
        """
        try:
            assignments = ['status = ?']
            params = [status.value]
            if status == TaskStatus.RUNNING:
                assignments.append('triggered_at = ?')
                params.append(_now_str())
            if status in (TaskStatus.COMPLETED, TaskStatus.FAILED,
                          TaskStatus.CANCELLED):
                assignments.append('completed_at = ?')
                params.append(_now_str())
            if error_message:
                assignments.append('error_message = ?')
                params.append(error_message)
            if result:
                assignments.append('result = ?')
                params.append(json.dumps(result))
            params.append(task_id)

            query = f"UPDATE tasks SET {', '.join(assignments)} WHERE task_id = ?"
            with self._connection() as conn:
                conn.cursor().execute(_sql(query), params)
            logger.info("任务状态更新:%s -> %s", task_id, status.value)
        except Exception as e:
            logger.error("更新任务状态失败:%s", e)

    def increment_retry(self, task_id: str) -> int:
        """Increment a task's ``retry_count`` and return the new value.

        Returns:
            The updated count, 0 if the task does not exist, or 999 if the
            database call failed (treated as "max retries exceeded").
        """
        try:
            with self._connection() as conn:
                cursor = conn.cursor()
                cursor.execute(_sql(
                    'UPDATE tasks SET retry_count = retry_count + 1 '
                    'WHERE task_id = ?'
                ), (task_id,))
                cursor.execute(
                    _sql('SELECT retry_count FROM tasks WHERE task_id = ?'),
                    (task_id,),
                )
                row = cursor.fetchone()
            # The with-block commits; closing without a commit would discard
            # the increment.
            return row['retry_count'] if row else 0
        except Exception as e:
            logger.error("增加重试次数失败:%s", e)
            return 999  # exceeds the max retry count

    def get_pending_tasks(self, customer_id: Optional[str] = None) -> List[dict]:
        """Return pending tasks (optionally for one customer), highest priority first.

        Returns an empty list on error.
        """
        try:
            with self._connection() as conn:
                cursor = conn.cursor()
                if customer_id:
                    cursor.execute(_sql(
                        "SELECT * FROM tasks WHERE status = 'pending' "
                        "AND customer_id = ? " + self._PENDING_ORDER_BY
                    ), (customer_id,))
                else:
                    cursor.execute(
                        "SELECT * FROM tasks WHERE status = 'pending' "
                        + self._PENDING_ORDER_BY
                    )
                rows = cursor.fetchall()
            return [dict(row) for row in rows]
        except Exception as e:
            logger.error("获取待触发任务失败:%s", e)
            return []

    def get_timeout_tasks(self) -> List[dict]:
        """Return pending tasks older than their own ``timeout_hours``.

        Returns an empty list on error.
        """
        if _is_mysql():
            # NOTE(review): NOW() uses the MySQL session time zone while
            # created_at defaults to Python local time — confirm they agree.
            query = (
                "SELECT * FROM tasks WHERE status = 'pending' "
                "AND created_at < DATE_SUB(NOW(), INTERVAL timeout_hours HOUR)"
            )
        else:
            # created_at defaults to _now_str(), i.e. naive *local* time, so
            # compare against local now; bare datetime('now') is UTC and
            # would skew every timeout by the UTC offset.
            query = (
                "SELECT * FROM tasks WHERE status = 'pending' "
                "AND datetime(created_at, '+' || timeout_hours || ' hours') "
                "< datetime('now', 'localtime')"
            )
        try:
            with self._connection() as conn:
                cursor = conn.cursor()
                cursor.execute(query)
                rows = cursor.fetchall()
            return [dict(row) for row in rows]
        except Exception as e:
            logger.error("获取超时任务失败:%s", e)
            return []

    def cancel_task(self, task_id: str, reason: Optional[str] = None):
        """Mark a task CANCELLED, recording ``reason`` as its error message."""
        self.update_task_status(task_id, TaskStatus.CANCELLED, error_message=reason)
        logger.info("任务已取消:%s", task_id)

    def get_statistics(self) -> dict:
        """Count tasks per status.

        Returns:
            ``{status_value: count}`` for every TaskStatus, or ``{}`` on error.
        """
        try:
            stats = {}
            with self._connection() as conn:
                cursor = conn.cursor()
                for status in TaskStatus:
                    # Read the count by alias: MySQL rows are dicts here, so a
                    # positional row[0] would raise KeyError.
                    cursor.execute(
                        _sql('SELECT COUNT(*) AS cnt FROM tasks WHERE status = ?'),
                        (status.value,),
                    )
                    stats[status.value] = cursor.fetchone()['cnt']
            return stats
        except Exception as e:
            logger.error("获取统计失败:%s", e)
            return {}


# Lazily created module-level singleton (creation is not guarded by a lock).
_task_manager: Optional[TaskManager] = None


def get_task_manager() -> TaskManager:
    """Return the shared TaskManager, creating it on first call."""
    global _task_manager
    if _task_manager is None:
        _task_manager = TaskManager()
    return _task_manager