feat: 添加 AI Agent 对话测试工具 + 代码优化
主要变更: - 新增 tests/test_ai_chat.py: AI Agent 对话测试工具 - 优化 core/pydantic_ai_agent.py 和 db/chat_log_db.py - 清理归档文件,更新文档 Made-with: Cursor
This commit is contained in:
@@ -9,8 +9,26 @@ from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, List
|
||||
from pathlib import Path
|
||||
from enum import Enum
|
||||
import os
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_DB_TYPE = os.getenv("DB_TYPE", "sqlite").lower()
|
||||
_MYSQL_HOST = os.getenv("MYSQL_HOST", "127.0.0.1")
|
||||
_MYSQL_PORT = int(os.getenv("MYSQL_PORT", "3306"))
|
||||
_MYSQL_USER = os.getenv("MYSQL_USER", "root")
|
||||
_MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
|
||||
_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs")
|
||||
|
||||
def _is_mysql() -> bool:
|
||||
return _DB_TYPE in ("mysql", "mariadb")
|
||||
|
||||
def _sql(query: str) -> str:
|
||||
return query.replace("?", "%s") if _is_mysql() else query
|
||||
|
||||
def _now_str() -> str:
|
||||
if _is_mysql():
|
||||
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
return datetime.now().isoformat()
|
||||
|
||||
class TaskStatus(Enum):
|
||||
"""任务状态"""
|
||||
@@ -40,51 +58,93 @@ class TaskManager:
|
||||
|
||||
def _init_db(self):
|
||||
"""初始化数据库"""
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 创建任务表
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS tasks (
|
||||
task_id TEXT PRIMARY KEY,
|
||||
specified_customer_id TEXT,
|
||||
specified_customer_name TEXT,
|
||||
type TEXT NOT NULL,
|
||||
customer_name TEXT,
|
||||
customer_id TEXT,
|
||||
trigger_type TEXT,
|
||||
trigger_keyword TEXT,
|
||||
trigger_keywords TEXT, -- JSON array
|
||||
action_type TEXT,
|
||||
action_file_url TEXT,
|
||||
action_message TEXT,
|
||||
priority TEXT DEFAULT 'normal',
|
||||
timeout_hours INTEGER DEFAULT 24,
|
||||
status TEXT DEFAULT 'pending',
|
||||
retry_count INTEGER DEFAULT 0,
|
||||
max_retry INTEGER DEFAULT 3,
|
||||
created_at TEXT,
|
||||
created_by TEXT,
|
||||
triggered_at TEXT,
|
||||
completed_at TEXT,
|
||||
error_message TEXT,
|
||||
result TEXT -- JSON
|
||||
)
|
||||
''')
|
||||
|
||||
# 创建索引
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_status ON tasks(status)')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_customer ON tasks(customer_id)')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_created ON tasks(created_at)')
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
if _is_mysql():
|
||||
conn = self._get_conn()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS tasks (
|
||||
task_id VARCHAR(128) PRIMARY KEY,
|
||||
specified_customer_id VARCHAR(128),
|
||||
specified_customer_name VARCHAR(255),
|
||||
type VARCHAR(64) NOT NULL,
|
||||
customer_name VARCHAR(255),
|
||||
customer_id VARCHAR(128),
|
||||
trigger_type VARCHAR(64),
|
||||
trigger_keyword VARCHAR(255),
|
||||
trigger_keywords TEXT,
|
||||
action_type VARCHAR(64),
|
||||
action_file_url TEXT,
|
||||
action_message TEXT,
|
||||
priority VARCHAR(16) DEFAULT 'normal',
|
||||
timeout_hours INT DEFAULT 24,
|
||||
status VARCHAR(32) DEFAULT 'pending',
|
||||
retry_count INT DEFAULT 0,
|
||||
max_retry INT DEFAULT 3,
|
||||
created_at DATETIME,
|
||||
created_by VARCHAR(255),
|
||||
triggered_at DATETIME,
|
||||
completed_at DATETIME,
|
||||
error_message TEXT,
|
||||
result TEXT
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
|
||||
''')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_status ON tasks(status)')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_customer ON tasks(customer_id)')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_created ON tasks(created_at)')
|
||||
conn.commit()
|
||||
conn.close()
|
||||
else:
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS tasks (
|
||||
task_id TEXT PRIMARY KEY,
|
||||
specified_customer_id TEXT,
|
||||
specified_customer_name TEXT,
|
||||
type TEXT NOT NULL,
|
||||
customer_name TEXT,
|
||||
customer_id TEXT,
|
||||
trigger_type TEXT,
|
||||
trigger_keyword TEXT,
|
||||
trigger_keywords TEXT,
|
||||
action_type TEXT,
|
||||
action_file_url TEXT,
|
||||
action_message TEXT,
|
||||
priority TEXT DEFAULT 'normal',
|
||||
timeout_hours INTEGER DEFAULT 24,
|
||||
status TEXT DEFAULT 'pending',
|
||||
retry_count INTEGER DEFAULT 0,
|
||||
max_retry INTEGER DEFAULT 3,
|
||||
created_at TEXT,
|
||||
created_by TEXT,
|
||||
triggered_at TEXT,
|
||||
completed_at TEXT,
|
||||
error_message TEXT,
|
||||
result TEXT
|
||||
)
|
||||
''')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_status ON tasks(status)')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_customer ON tasks(customer_id)')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_created ON tasks(created_at)')
|
||||
conn.commit()
|
||||
conn.close()
|
||||
logger.info("数据库表初始化完成")
|
||||
|
||||
def _get_conn(self):
|
||||
"""获取数据库连接"""
|
||||
if _is_mysql():
|
||||
import pymysql
|
||||
return pymysql.connect(
|
||||
host=_MYSQL_HOST,
|
||||
port=_MYSQL_PORT,
|
||||
user=_MYSQL_USER,
|
||||
password=_MYSQL_PASSWORD,
|
||||
database=_MYSQL_DATABASE,
|
||||
charset="utf8mb4",
|
||||
cursorclass=pymysql.cursors.DictCursor,
|
||||
autocommit=False,
|
||||
)
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
return conn
|
||||
@@ -100,7 +160,7 @@ class TaskManager:
|
||||
if isinstance(trigger_keywords, list):
|
||||
trigger_keywords = json.dumps(trigger_keywords)
|
||||
|
||||
cursor.execute('''
|
||||
insert_sql = '''
|
||||
INSERT OR REPLACE INTO tasks (
|
||||
task_id, specified_customer_id, specified_customer_name,
|
||||
type, customer_name, customer_id,
|
||||
@@ -109,7 +169,19 @@ class TaskManager:
|
||||
priority, timeout_hours, status, retry_count, max_retry,
|
||||
created_at, created_by
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
''', (
|
||||
'''
|
||||
if _is_mysql():
|
||||
insert_sql = '''
|
||||
REPLACE INTO tasks (
|
||||
task_id, specified_customer_id, specified_customer_name,
|
||||
type, customer_name, customer_id,
|
||||
trigger_type, trigger_keyword, trigger_keywords,
|
||||
action_type, action_file_url, action_message,
|
||||
priority, timeout_hours, status, retry_count, max_retry,
|
||||
created_at, created_by
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
'''
|
||||
cursor.execute(_sql(insert_sql), (
|
||||
task.get('task_id'),
|
||||
task.get('customer', {}).get('id'),
|
||||
task.get('customer', {}).get('name'),
|
||||
@@ -127,7 +199,7 @@ class TaskManager:
|
||||
task.get('status', 'pending'),
|
||||
task.get('retry_count', 0),
|
||||
task.get('max_retry', 3),
|
||||
task.get('created_at', datetime.now().isoformat()),
|
||||
task.get('created_at', _now_str()),
|
||||
task.get('created_by')
|
||||
))
|
||||
|
||||
@@ -147,7 +219,7 @@ class TaskManager:
|
||||
conn = self._get_conn()
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('SELECT * FROM tasks WHERE task_id = ?', (task_id,))
|
||||
cursor.execute(_sql('SELECT * FROM tasks WHERE task_id = ?'), (task_id,))
|
||||
row = cursor.fetchone()
|
||||
conn.close()
|
||||
|
||||
@@ -165,32 +237,33 @@ class TaskManager:
|
||||
conn = self._get_conn()
|
||||
cursor = conn.cursor()
|
||||
|
||||
updates = ['status = ?']
|
||||
placeholder = "%s" if _is_mysql() else "?"
|
||||
updates = [f'status = {placeholder}']
|
||||
params = [status.value]
|
||||
|
||||
if status == TaskStatus.RUNNING:
|
||||
updates.append('triggered_at = ?')
|
||||
params.append(datetime.now().isoformat())
|
||||
updates.append(f'triggered_at = {placeholder}')
|
||||
params.append(_now_str())
|
||||
|
||||
if status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED]:
|
||||
updates.append('completed_at = ?')
|
||||
params.append(datetime.now().isoformat())
|
||||
updates.append(f'completed_at = {placeholder}')
|
||||
params.append(_now_str())
|
||||
|
||||
if error_message:
|
||||
updates.append('error_message = ?')
|
||||
updates.append(f'error_message = {placeholder}')
|
||||
params.append(error_message)
|
||||
|
||||
if result:
|
||||
updates.append('result = ?')
|
||||
updates.append(f'result = {placeholder}')
|
||||
params.append(json.dumps(result))
|
||||
|
||||
params.append(task_id)
|
||||
|
||||
cursor.execute(f'''
|
||||
cursor.execute(_sql(f'''
|
||||
UPDATE tasks
|
||||
SET {', '.join(updates)}
|
||||
WHERE task_id = ?
|
||||
''', params)
|
||||
'''), params)
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
@@ -206,13 +279,13 @@ class TaskManager:
|
||||
conn = self._get_conn()
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('''
|
||||
cursor.execute(_sql('''
|
||||
UPDATE tasks
|
||||
SET retry_count = retry_count + 1
|
||||
WHERE task_id = ?
|
||||
''', (task_id,))
|
||||
'''), (task_id,))
|
||||
|
||||
cursor.execute('SELECT retry_count, max_retry FROM tasks WHERE task_id = ?', (task_id,))
|
||||
cursor.execute(_sql('SELECT retry_count, max_retry FROM tasks WHERE task_id = ?'), (task_id,))
|
||||
row = cursor.fetchone()
|
||||
conn.close()
|
||||
|
||||
@@ -231,7 +304,7 @@ class TaskManager:
|
||||
cursor = conn.cursor()
|
||||
|
||||
if customer_id:
|
||||
cursor.execute('''
|
||||
cursor.execute(_sql('''
|
||||
SELECT * FROM tasks
|
||||
WHERE status = 'pending'
|
||||
AND customer_id = ?
|
||||
@@ -242,7 +315,7 @@ class TaskManager:
|
||||
ELSE 3
|
||||
END,
|
||||
created_at
|
||||
''', (customer_id,))
|
||||
'''), (customer_id,))
|
||||
else:
|
||||
cursor.execute('''
|
||||
SELECT * FROM tasks
|
||||
@@ -271,11 +344,18 @@ class TaskManager:
|
||||
conn = self._get_conn()
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('''
|
||||
SELECT * FROM tasks
|
||||
WHERE status = 'pending'
|
||||
AND datetime(created_at, '+' || timeout_hours || ' hours') < datetime('now')
|
||||
''')
|
||||
if _is_mysql():
|
||||
cursor.execute('''
|
||||
SELECT * FROM tasks
|
||||
WHERE status = 'pending'
|
||||
AND created_at < DATE_SUB(NOW(), INTERVAL timeout_hours HOUR)
|
||||
''')
|
||||
else:
|
||||
cursor.execute('''
|
||||
SELECT * FROM tasks
|
||||
WHERE status = 'pending'
|
||||
AND datetime(created_at, '+' || timeout_hours || ' hours') < datetime('now')
|
||||
''')
|
||||
|
||||
rows = cursor.fetchall()
|
||||
conn.close()
|
||||
@@ -299,7 +379,7 @@ class TaskManager:
|
||||
|
||||
stats = {}
|
||||
for status in TaskStatus:
|
||||
cursor.execute('SELECT COUNT(*) FROM tasks WHERE status = ?', (status.value,))
|
||||
cursor.execute(_sql('SELECT COUNT(*) FROM tasks WHERE status = ?'), (status.value,))
|
||||
stats[status.value] = cursor.fetchone()[0]
|
||||
|
||||
conn.close()
|
||||
|
||||
Reference in New Issue
Block a user