Files
tw/db/task_db/task_model.py

407 lines
15 KiB
Python

# -*- coding: utf-8 -*-
"""
天网任务数据模型
"""
import json
import logging
import os
import sqlite3
from contextlib import closing
from datetime import datetime, timedelta
from enum import Enum
from pathlib import Path
from typing import Optional, Dict, List
logger = logging.getLogger(__name__)

# Backend selection and MySQL connection settings. They are read once, at
# import time, from the process environment; later changes to the
# environment are not picked up.
_DB_TYPE = os.environ.get("DB_TYPE", "sqlite").lower()
_MYSQL_HOST = os.environ.get("MYSQL_HOST", "127.0.0.1")
_MYSQL_PORT = int(os.environ.get("MYSQL_PORT", "3306"))
_MYSQL_USER = os.environ.get("MYSQL_USER", "root")
_MYSQL_PASSWORD = os.environ.get("MYSQL_PASSWORD", "")
_MYSQL_DATABASE = os.environ.get("MYSQL_DATABASE", "ai_cs")
def _is_mysql() -> bool:
    """Report whether the configured backend speaks MySQL (MySQL or MariaDB)."""
    return _DB_TYPE == "mysql" or _DB_TYPE == "mariadb"
def _sql(query: str) -> str:
    """Adapt a '?'-placeholder query to the active backend's paramstyle.

    SQLite takes '?' as-is; for the PyMySQL backend every '?' in the text is
    rewritten to '%s', so queries must not contain a literal '?'.
    """
    if not _is_mysql():
        return query
    return query.replace("?", "%s")
def _now_str() -> str:
    """Return the current local (naive) time as a string for the active backend.

    MySQL gets ``YYYY-MM-DD HH:MM:SS`` for its DATETIME columns; SQLite gets
    the full ``datetime.isoformat()`` string (with microseconds when nonzero).
    """
    mysql_format = _is_mysql()
    current = datetime.now()
    if mysql_format:
        return current.strftime("%Y-%m-%d %H:%M:%S")
    return current.isoformat()
class TaskStatus(Enum):
    """Task lifecycle state; ``.value`` is the string stored in ``tasks.status``.

    Member order matters: ``TaskManager.get_statistics`` iterates this enum,
    so it fixes the key order of the statistics dict it returns.
    """
    PENDING = "pending"  # awaiting trigger; the literal 'pending' is what get_pending_tasks / get_timeout_tasks select
    WAITING = "waiting"  # waiting for trigger -- NOTE(review): not referenced by the visible code; how it differs from PENDING is unclear
    RUNNING = "running"  # executing; update_task_status stamps triggered_at
    COMPLETED = "completed"  # finished; update_task_status stamps completed_at
    FAILED = "failed"  # failed; update_task_status stamps completed_at
    CANCELLED = "cancelled"  # cancelled (cancel_task); update_task_status stamps completed_at
class TaskPriority(Enum):
    """Task priority; ``.value`` is the string stored in ``tasks.priority``.

    ``TaskManager.get_pending_tasks`` orders 'urgent' first, then 'high',
    then every other value (including 'normal').
    """
    NORMAL = "normal"  # same string add_task and the column DEFAULT fall back to
    HIGH = "high"
    URGENT = "urgent"
class TaskManager:
    """Task manager that persists tasks to SQLite (default) or MySQL/MariaDB.

    The backend is chosen by ``_is_mysql()``. Every method opens its own
    connection and closes it before returning, including when the operation
    fails. Schema creation in ``__init__`` lets driver errors propagate; the
    other public methods are best-effort: they log the error and return a
    neutral value (``False`` / ``None`` / ``[]`` / ``{}`` / ``999``) instead
    of raising.
    """

    # (index name, column) pairs created on the tasks table by both backends.
    _INDEXES = (
        ("idx_status", "status"),
        ("idx_customer", "customer_id"),
        ("idx_created", "created_at"),
    )

    def __init__(self, db_path: Optional[str] = None):
        """Create the manager and make sure the ``tasks`` schema exists.

        Args:
            db_path: SQLite database file, as ``str`` or ``Path``. Defaults to
                ``tasks.db`` in this module's directory. Not used by the MySQL
                backend, which connects with the ``_MYSQL_*`` module settings.
        """
        if db_path is None:
            db_path = Path(__file__).parent / "tasks.db"
        self.db_path = db_path
        self._init_db()
        logger.info(f"任务管理器初始化完成,数据库:{self.db_path}")

    def _init_db(self):
        """Create the ``tasks`` table and its indexes if they are missing.

        Driver errors propagate to the caller; the connection is closed either way.
        """
        if _is_mysql():
            self._init_mysql_schema()
        else:
            self._init_sqlite_schema()
        logger.info("数据库表初始化完成")

    def _init_mysql_schema(self):
        """Create the MySQL/MariaDB table, then add any index not yet present."""
        with closing(self._get_conn()) as conn:
            cursor = conn.cursor()
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS tasks (
                    task_id VARCHAR(128) PRIMARY KEY,
                    specified_customer_id VARCHAR(128),
                    specified_customer_name VARCHAR(255),
                    type VARCHAR(64) NOT NULL,
                    customer_name VARCHAR(255),
                    customer_id VARCHAR(128),
                    trigger_type VARCHAR(64),
                    trigger_keyword VARCHAR(255),
                    trigger_keywords TEXT,
                    action_type VARCHAR(64),
                    action_file_url TEXT,
                    action_message TEXT,
                    priority VARCHAR(16) DEFAULT 'normal',
                    timeout_hours INT DEFAULT 24,
                    status VARCHAR(32) DEFAULT 'pending',
                    retry_count INT DEFAULT 0,
                    max_retry INT DEFAULT 3,
                    created_at DATETIME,
                    created_by VARCHAR(255),
                    triggered_at DATETIME,
                    completed_at DATETIME,
                    error_message TEXT,
                    result TEXT
                ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
            ''')
            # Only create indexes SHOW INDEX does not already report, so this
            # can run on every startup without re-creating an existing index.
            # Rows are dicts because _get_conn uses DictCursor for MySQL.
            cursor.execute("SHOW INDEX FROM tasks")
            existing = {str(r.get("Key_name", "")) for r in cursor.fetchall()}
            for index_name, column in self._INDEXES:
                if index_name not in existing:
                    cursor.execute(f'CREATE INDEX {index_name} ON tasks({column})')
            conn.commit()

    def _init_sqlite_schema(self):
        """Create the SQLite file's parent directory, the table and its indexes."""
        # db_path may be a plain str, which has no .parent; wrap it in Path
        # here while leaving self.db_path exactly as the caller supplied it.
        Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
        with closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.cursor()
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS tasks (
                    task_id TEXT PRIMARY KEY,
                    specified_customer_id TEXT,
                    specified_customer_name TEXT,
                    type TEXT NOT NULL,
                    customer_name TEXT,
                    customer_id TEXT,
                    trigger_type TEXT,
                    trigger_keyword TEXT,
                    trigger_keywords TEXT,
                    action_type TEXT,
                    action_file_url TEXT,
                    action_message TEXT,
                    priority TEXT DEFAULT 'normal',
                    timeout_hours INTEGER DEFAULT 24,
                    status TEXT DEFAULT 'pending',
                    retry_count INTEGER DEFAULT 0,
                    max_retry INTEGER DEFAULT 3,
                    created_at TEXT,
                    created_by TEXT,
                    triggered_at TEXT,
                    completed_at TEXT,
                    error_message TEXT,
                    result TEXT
                )
            ''')
            for index_name, column in self._INDEXES:
                cursor.execute(f'CREATE INDEX IF NOT EXISTS {index_name} ON tasks({column})')
            conn.commit()

    def _get_conn(self):
        """Open a new connection to the configured backend.

        MySQL: a PyMySQL connection with ``DictCursor`` rows and autocommit
        off. SQLite: a ``sqlite3`` connection whose rows are ``sqlite3.Row``.
        Both row types support access by column name. Callers must commit
        and close the connection themselves.
        """
        if _is_mysql():
            import pymysql
            return pymysql.connect(
                host=_MYSQL_HOST,
                port=_MYSQL_PORT,
                user=_MYSQL_USER,
                password=_MYSQL_PASSWORD,
                database=_MYSQL_DATABASE,
                charset="utf8mb4",
                cursorclass=pymysql.cursors.DictCursor,
                autocommit=False,
            )
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        return conn

    def add_task(self, task: dict) -> bool:
        """Insert a task, replacing any existing row with the same ``task_id``.

        Args:
            task: Task payload. Nested ``customer`` (``id``/``name``),
                ``trigger`` (``type``/``keyword``/``keywords``) and ``action``
                (``type``/``file_url``/``message``) dicts are flattened into
                columns; a missing or ``None`` nested dict yields NULL columns.
                ``trigger.keywords`` given as a list is stored as JSON text.
                Defaults: priority 'normal', timeout_hours 24, status 'pending',
                retry_count 0, max_retry 3, created_at ``_now_str()``.

        Returns:
            True on success, False if the write failed (the error is logged).
        """
        try:
            customer = task.get('customer') or {}
            trigger = task.get('trigger') or {}
            action = task.get('action') or {}
            trigger_keywords = trigger.get('keywords', [])
            if isinstance(trigger_keywords, list):
                trigger_keywords = json.dumps(trigger_keywords)

            row = {
                'task_id': task.get('task_id'),
                'specified_customer_id': customer.get('id'),
                'specified_customer_name': customer.get('name'),
                'type': task.get('type'),
                'customer_name': customer.get('name'),
                'customer_id': customer.get('id'),
                'trigger_type': trigger.get('type'),
                'trigger_keyword': trigger.get('keyword'),
                'trigger_keywords': trigger_keywords,
                'action_type': action.get('type'),
                'action_file_url': action.get('file_url'),
                'action_message': action.get('message'),
                'priority': task.get('priority', 'normal'),
                'timeout_hours': task.get('timeout_hours', 24),
                'status': task.get('status', 'pending'),
                'retry_count': task.get('retry_count', 0),
                'max_retry': task.get('max_retry', 3),
                'created_at': task.get('created_at', _now_str()),
                'created_by': task.get('created_by'),
            }
            # Same upsert semantics, different spelling per dialect.
            verb = 'REPLACE INTO' if _is_mysql() else 'INSERT OR REPLACE INTO'
            insert_sql = (
                f"{verb} tasks ({', '.join(row)}) "
                f"VALUES ({', '.join(['?'] * len(row))})"
            )
            with closing(self._get_conn()) as conn:
                conn.cursor().execute(_sql(insert_sql), tuple(row.values()))
                conn.commit()
            logger.info(f"任务添加成功:{task.get('task_id')}")
            return True
        except Exception as e:
            logger.error(f"添加任务失败:{e}")
            return False

    def get_task(self, task_id: str) -> Optional[dict]:
        """Return the task row as a dict, or None if absent or on error (logged)."""
        try:
            with closing(self._get_conn()) as conn:
                cursor = conn.cursor()
                cursor.execute(_sql('SELECT * FROM tasks WHERE task_id = ?'), (task_id,))
                row = cursor.fetchone()
                return dict(row) if row else None
        except Exception as e:
            logger.error(f"查询任务失败:{e}")
            return None

    def update_task_status(self, task_id: str, status: TaskStatus,
                           error_message: Optional[str] = None,
                           result: Optional[dict] = None):
        """Set a task's status and stamp the matching timestamp column.

        ``RUNNING`` records ``triggered_at``; ``COMPLETED``, ``FAILED`` and
        ``CANCELLED`` record ``completed_at``. ``error_message`` and
        ``result`` (JSON-encoded) are written only when truthy. Failures are
        logged, not raised.

        Args:
            task_id: Task to update.
            status: New status, as a ``TaskStatus`` member or its string value.
            error_message: Optional error text to store.
            result: Optional result payload, serialized with ``json.dumps``.
        """
        try:
            # Accept either a member or its value; TaskStatus(member) is the member.
            status = TaskStatus(status)
            assignments = ['status = ?']
            params = [status.value]
            if status == TaskStatus.RUNNING:
                assignments.append('triggered_at = ?')
                params.append(_now_str())
            if status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED):
                assignments.append('completed_at = ?')
                params.append(_now_str())
            if error_message:
                assignments.append('error_message = ?')
                params.append(error_message)
            if result:
                assignments.append('result = ?')
                params.append(json.dumps(result))
            params.append(task_id)
            query = f"UPDATE tasks SET {', '.join(assignments)} WHERE task_id = ?"
            with closing(self._get_conn()) as conn:
                conn.cursor().execute(_sql(query), params)
                conn.commit()
            logger.info(f"任务状态更新:{task_id} -> {status.value}")
        except Exception as e:
            logger.error(f"更新任务状态失败:{e}")

    def increment_retry(self, task_id: str) -> int:
        """Increment and persist ``retry_count``; return the updated count.

        Returns:
            The new retry_count; 0 when no task has ``task_id``; 999 on a
            database error, a value meant to exceed any max_retry so the
            caller stops retrying.
        """
        try:
            with closing(self._get_conn()) as conn:
                cursor = conn.cursor()
                cursor.execute(_sql('''
                    UPDATE tasks
                    SET retry_count = retry_count + 1
                    WHERE task_id = ?
                '''), (task_id,))
                cursor.execute(_sql('SELECT retry_count FROM tasks WHERE task_id = ?'), (task_id,))
                row = cursor.fetchone()
                # Neither driver autocommits here (sqlite3 opens an implicit
                # transaction; PyMySQL is configured autocommit=False), and
                # close() discards uncommitted work -- commit so the
                # increment is actually persisted.
                conn.commit()
            return row['retry_count'] if row else 0
        except Exception as e:
            logger.error(f"增加重试次数失败:{e}")
            return 999  # treated as "max retries exceeded"

    def get_pending_tasks(self, customer_id: Optional[str] = None) -> List[dict]:
        """Return pending tasks, optionally for one customer.

        Ordered by priority ('urgent', 'high', then anything else), then by
        ``created_at``. Returns [] on error (logged).
        """
        try:
            query = "SELECT * FROM tasks WHERE status = 'pending'"
            params: tuple = ()
            if customer_id:
                query += " AND customer_id = ?"
                params = (customer_id,)
            query += (
                " ORDER BY CASE priority"
                " WHEN 'urgent' THEN 1 WHEN 'high' THEN 2 ELSE 3 END,"
                " created_at"
            )
            with closing(self._get_conn()) as conn:
                cursor = conn.cursor()
                cursor.execute(_sql(query), params)
                return [dict(row) for row in cursor.fetchall()]
        except Exception as e:
            logger.error(f"获取待触发任务失败:{e}")
            return []

    def get_timeout_tasks(self) -> List[dict]:
        """Return pending tasks older than their ``timeout_hours``; [] on error."""
        try:
            if _is_mysql():
                query = '''
                    SELECT * FROM tasks
                    WHERE status = 'pending'
                      AND created_at < DATE_SUB(NOW(), INTERVAL timeout_hours HOUR)
                '''
            else:
                # add_task defaults created_at to _now_str(), i.e. naive LOCAL
                # time, whereas SQLite's datetime('now') is UTC. Compare with
                # local "now" so timeouts are not shifted by the UTC offset.
                query = '''
                    SELECT * FROM tasks
                    WHERE status = 'pending'
                      AND datetime(created_at, '+' || timeout_hours || ' hours')
                          < datetime('now', 'localtime')
                '''
            with closing(self._get_conn()) as conn:
                cursor = conn.cursor()
                cursor.execute(query)
                return [dict(row) for row in cursor.fetchall()]
        except Exception as e:
            logger.error(f"获取超时任务失败:{e}")
            return []

    def cancel_task(self, task_id: str, reason: Optional[str] = None):
        """Mark a task CANCELLED, storing ``reason`` as its error_message."""
        self.update_task_status(task_id, TaskStatus.CANCELLED, error_message=reason)
        logger.info(f"任务已取消:{task_id}")

    def get_statistics(self) -> dict:
        """Return ``{status value: task count}`` for every TaskStatus; {} on error."""
        try:
            with closing(self._get_conn()) as conn:
                cursor = conn.cursor()
                stats = {}
                for status in TaskStatus:
                    # Read the count by alias: MySQL DictCursor rows are dicts
                    # (row[0] would raise KeyError); sqlite3.Row supports names too.
                    cursor.execute(
                        _sql('SELECT COUNT(*) AS cnt FROM tasks WHERE status = ?'),
                        (status.value,),
                    )
                    stats[status.value] = cursor.fetchone()['cnt']
            return stats
        except Exception as e:
            logger.error(f"获取统计失败:{e}")
            return {}
# Module-level singleton, created lazily by get_task_manager().
_task_manager: Optional[TaskManager] = None


def get_task_manager() -> TaskManager:
    """Return the shared TaskManager, constructing it on the first call."""
    global _task_manager
    if _task_manager is not None:
        return _task_manager
    _task_manager = TaskManager()
    return _task_manager