主要变更: - 新增 tests/test_ai_chat.py: AI Agent 对话测试工具 - 优化 core/pydantic_ai_agent.py 和 db/chat_log_db.py - 清理归档文件,更新文档 Made-with: Cursor
476 lines
17 KiB
Python
476 lines
17 KiB
Python
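# -*- coding: utf-8 -*-
"""
图片任务数据库管理
支持客户后续增加需求细节

Image task database management: stores image-processing tasks and lets
customers add requirement details after a task has been created.
"""
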
import json
import logging
import os
import sqlite3
from contextlib import closing
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional

logger = logging.getLogger(__name__)

# Storage backend and MySQL connection settings, read once from the
# environment at import time. DB_TYPE is lower-cased so "MySQL" works too.
_DB_TYPE = os.environ.get("DB_TYPE", "sqlite").lower()
_MYSQL_HOST = os.environ.get("MYSQL_HOST", "127.0.0.1")
_MYSQL_PORT = int(os.environ.get("MYSQL_PORT", "3306"))  # ValueError at import if not an integer
_MYSQL_USER = os.environ.get("MYSQL_USER", "root")
_MYSQL_PASSWORD = os.environ.get("MYSQL_PASSWORD", "")
_MYSQL_DATABASE = os.environ.get("MYSQL_DATABASE", "ai_cs")


def _is_mysql() -> bool:
    """Return True when DB_TYPE selects a MySQL-compatible backend ("mysql" or "mariadb")."""
    mysql_like_types = {"mysql", "mariadb"}
    return _DB_TYPE in mysql_like_types


def _sql(query: str) -> str:
    """Adapt a '?'-placeholder query to the active backend.

    SQLite queries are returned unchanged; for MySQL every '?' is replaced by
    '%s'. The replacement is purely textual, so a literal '?' inside the SQL
    text would be rewritten as well.
    """
    if not _is_mysql():
        return query
    return query.replace("?", "%s")


def _now_str() -> str:
    """Return the current local time formatted for the timestamp columns.

    MySQL gets 'YYYY-MM-DD HH:MM:SS' (second resolution). SQLite gets
    ``datetime.isoformat()`` output, which includes microseconds when they
    are nonzero.
    """
    current = datetime.now()
    return current.strftime("%Y-%m-%d %H:%M:%S") if _is_mysql() else current.isoformat()


class TaskStatus(Enum):
    """Lifecycle states of an image task; ``.value`` is what the ``status`` column stores."""

    # Created, awaiting payment.
    PENDING = "pending"
    # Paid, waiting to be processed.
    PAID = "paid"
    # Being processed.
    PROCESSING = "processing"
    # Result produced, waiting for the customer to confirm it.
    AWAITING_CONFIRM = "awaiting_confirm"
    # Finished.
    COMPLETED = "completed"
    # Processing failed.
    FAILED = "failed"
    # Cancelled.
    CANCELLED = "cancelled"


class ImageTaskManager:
    """Persist image-processing tasks and the history of requirement changes.

    Storage is SQLite (the file at ``db_path``) unless the module-level
    ``DB_TYPE`` selects MySQL/MariaDB. In that case the configured MySQL
    server is used and ``db_path`` only appears in the startup log line.

    Each public method opens its own connection and always closes it before
    returning. Apart from construction, failures are logged and reported via
    the return value (False / None / [] / 999) instead of being raised.
    """

    # Secondary indexes on image_tasks, as (index name, indexed column).
    _TASK_INDEXES = (
        ('idx_customer', 'customer_id'),
        ('idx_status', 'status'),
        ('idx_created', 'created_at'),
    )

    def __init__(self, db_path: Optional[str] = None):
        """Create the manager and make sure the schema exists.

        Args:
            db_path: SQLite database file (str or Path). Defaults to
                ``image_tasks.db`` next to this module. Not used for storage
                when the MySQL backend is active.

        Raises:
            Any driver error raised while creating the schema; unlike the
            other methods, construction does not swallow errors.
        """
        if db_path is None:
            db_path = Path(__file__).parent / "image_tasks.db"
        self.db_path = db_path
        self._init_db()
        logger.info(f"图片任务管理器初始化完成,数据库:{self.db_path}")

    # ------------------------------------------------------------------ schema

    def _init_db(self):
        """Create the tables and indexes for the active backend if they are missing."""
        if _is_mysql():
            self._init_mysql_schema()
        else:
            self._init_sqlite_schema()
        logger.info("数据库表初始化完成")

    def _init_mysql_schema(self):
        """Create the MySQL/MariaDB tables and indexes; safe to run repeatedly."""
        with closing(self._get_conn()) as conn:
            cursor = conn.cursor()
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS image_tasks (
                    task_id VARCHAR(128) PRIMARY KEY,
                    customer_id VARCHAR(128) NOT NULL,
                    customer_name VARCHAR(255),
                    original_image TEXT NOT NULL,
                    operation VARCHAR(64) DEFAULT 'enhance',
                    requirements TEXT,
                    customer_notes TEXT,
                    status VARCHAR(32) DEFAULT 'pending',
                    created_at DATETIME,
                    paid_at DATETIME,
                    started_at DATETIME,
                    completed_at DATETIME,
                    result_image TEXT,
                    error_message TEXT,
                    retry_count INT DEFAULT 0,
                    acc_id VARCHAR(128),
                    acc_type VARCHAR(64) DEFAULT 'AliWorkbench'
                ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
            ''')
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS task_requirement_changes (
                    id INTEGER PRIMARY KEY AUTO_INCREMENT,
                    task_id VARCHAR(128) NOT NULL,
                    change_type VARCHAR(64),
                    old_value TEXT,
                    new_value TEXT,
                    changed_at DATETIME,
                    changed_by VARCHAR(32),
                    FOREIGN KEY (task_id) REFERENCES image_tasks(task_id)
                ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
            ''')
            # MySQL has no "CREATE INDEX IF NOT EXISTS" (only MariaDB does), so
            # look each index up first to stay idempotent on both servers.
            for index_name, column in self._TASK_INDEXES:
                cursor.execute(
                    "SHOW INDEX FROM image_tasks WHERE Key_name = %s", (index_name,))
                if cursor.fetchone() is None:
                    # Names come from the constant tuple above, never from input.
                    cursor.execute(f'CREATE INDEX {index_name} ON image_tasks({column})')
            conn.commit()

    def _init_sqlite_schema(self):
        """Create the SQLite file's directory, tables and indexes; safe to run repeatedly."""
        # db_path may be a plain str (the signature allows it); wrap it so .parent works.
        Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
        with closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.cursor()
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS image_tasks (
                    task_id TEXT PRIMARY KEY,
                    customer_id TEXT NOT NULL,
                    customer_name TEXT,
                    original_image TEXT NOT NULL,
                    operation TEXT DEFAULT 'enhance',
                    requirements TEXT,
                    customer_notes TEXT,
                    status TEXT DEFAULT 'pending',
                    created_at TEXT,
                    paid_at TEXT,
                    started_at TEXT,
                    completed_at TEXT,
                    result_image TEXT,
                    error_message TEXT,
                    retry_count INTEGER DEFAULT 0,
                    acc_id TEXT,
                    acc_type TEXT DEFAULT 'AliWorkbench'
                )
            ''')
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS task_requirement_changes (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    task_id TEXT NOT NULL,
                    change_type TEXT,
                    old_value TEXT,
                    new_value TEXT,
                    changed_at TEXT,
                    changed_by TEXT,
                    FOREIGN KEY (task_id) REFERENCES image_tasks(task_id)
                )
            ''')
            for index_name, column in self._TASK_INDEXES:
                cursor.execute(
                    f'CREATE INDEX IF NOT EXISTS {index_name} ON image_tasks({column})')
            conn.commit()

    # ------------------------------------------------------------- helpers

    def _get_conn(self):
        """Open a new connection to the active backend.

        Rows are returned as mappings on both backends (pymysql ``DictCursor``
        or ``sqlite3.Row``), so callers can index rows by column name. The
        MySQL connection has autocommit disabled; callers must commit.
        """
        if _is_mysql():
            import pymysql  # only needed on the MySQL code path
            return pymysql.connect(
                host=_MYSQL_HOST,
                port=_MYSQL_PORT,
                user=_MYSQL_USER,
                password=_MYSQL_PASSWORD,
                database=_MYSQL_DATABASE,
                charset="utf8mb4",
                cursorclass=pymysql.cursors.DictCursor,
                autocommit=False,
            )
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        return conn

    @staticmethod
    def _row_to_task(row) -> dict:
        """Convert a fetched image_tasks row to a dict, decoding the JSON ``requirements``."""
        task = dict(row)
        if task.get('requirements'):
            task['requirements'] = json.loads(task['requirements'])
        return task

    def _query_tasks(self, where: str, params: tuple) -> List[dict]:
        """Return decoded tasks matching ``where``, newest first. Raises on DB errors.

        ``where`` is spliced into the SQL text, so it must be a constant
        '?'-style condition; every value has to go through ``params``.
        """
        with closing(self._get_conn()) as conn:
            cursor = conn.cursor()
            cursor.execute(_sql(f'''
                SELECT * FROM image_tasks
                WHERE {where}
                ORDER BY created_at DESC
            '''), params)
            rows = cursor.fetchall()
        return [self._row_to_task(row) for row in rows]

    @staticmethod
    def _record_change(cursor, task_id: str, change_type: str,
                       old_value: str, new_value: str, changed_by: str):
        """Append one row to task_requirement_changes on ``cursor``; the caller commits."""
        cursor.execute(_sql('''
            INSERT INTO task_requirement_changes (
                task_id, change_type, old_value, new_value, changed_at, changed_by
            ) VALUES (?, ?, ?, ?, ?, ?)
        '''), (task_id, change_type, old_value, new_value, _now_str(), changed_by))

    # -------------------------------------------------------------- public API

    def create_task(self, task_id: str, customer_id: str, customer_name: str,
                    original_image: str, operation: str = 'enhance',
                    requirements: Optional[dict] = None, acc_id: str = '',
                    acc_type: str = 'AliWorkbench') -> bool:
        """Insert a new task in PENDING status with empty customer notes.

        Args:
            task_id: Primary key of the task; reusing an existing id fails.
            customer_id / customer_name: Owner of the task.
            original_image: Reference to the source image, stored as given.
            operation: Requested operation, 'enhance' by default.
            requirements: Stored JSON-encoded; falsy values (None, {}) store NULL.
            acc_id / acc_type: Account columns, stored as given.

        Returns:
            True if the row was inserted and committed; False on any error (logged).
        """
        try:
            requirements_json = json.dumps(requirements) if requirements else None
            with closing(self._get_conn()) as conn:
                cursor = conn.cursor()
                cursor.execute(_sql('''
                    INSERT INTO image_tasks (
                        task_id, customer_id, customer_name, original_image,
                        operation, requirements, customer_notes, status,
                        created_at, acc_id, acc_type
                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                '''), (
                    task_id,
                    customer_id,
                    customer_name,
                    original_image,
                    operation,
                    requirements_json,
                    '',  # customer notes start empty
                    TaskStatus.PENDING.value,
                    _now_str(),
                    acc_id,
                    acc_type,
                ))
                conn.commit()
            logger.info(f"图片任务创建成功:{task_id}")
            return True
        except Exception as e:
            logger.error(f"创建图片任务失败:{e}")
            return False

    def get_task(self, task_id: str) -> Optional[dict]:
        """Return the task as a dict (``requirements`` decoded), or None if missing or on error."""
        try:
            with closing(self._get_conn()) as conn:
                cursor = conn.cursor()
                cursor.execute(_sql('SELECT * FROM image_tasks WHERE task_id = ?'), (task_id,))
                row = cursor.fetchone()
            return self._row_to_task(row) if row else None
        except Exception as e:
            logger.error(f"查询任务失败:{e}")
            return None

    def get_customer_tasks(self, customer_id: str, status: str = None) -> List[dict]:
        """Return a customer's tasks, newest first, optionally filtered by status.

        Args:
            customer_id: Customer whose tasks to list.
            status: Status value (e.g. 'paid') or a TaskStatus member; a falsy
                value disables the status filter.

        Returns:
            The matching tasks, or [] on error (logged).
        """
        try:
            if isinstance(status, TaskStatus):
                status = status.value
            if status:
                return self._query_tasks('customer_id = ? AND status = ?', (customer_id, status))
            return self._query_tasks('customer_id = ?', (customer_id,))
        except Exception as e:
            logger.error(f"查询客户任务失败:{e}")
            return []

    def update_status(self, task_id: str, status: TaskStatus) -> bool:
        """Set a task's status and stamp the timestamp column tied to it.

        PAID stamps ``paid_at``, PROCESSING stamps ``started_at``, and
        COMPLETED/FAILED stamp ``completed_at``. Other statuses only change
        ``status``.

        Args:
            task_id: Task to update. An unknown id matches no row and is not an error.
            status: A TaskStatus member, or its value string (e.g. 'paid').

        Returns:
            True if the update was committed; False on error, including an
            unrecognised status (logged).
        """
        try:
            status = TaskStatus(status)  # accepts a member or its raw value
            assignments = ['status = ?']
            params = [status.value]
            timestamp_column = {
                TaskStatus.PAID: 'paid_at',
                TaskStatus.PROCESSING: 'started_at',
                TaskStatus.COMPLETED: 'completed_at',
                TaskStatus.FAILED: 'completed_at',
            }.get(status)
            if timestamp_column:
                assignments.append(f'{timestamp_column} = ?')
                params.append(_now_str())
            params.append(task_id)

            with closing(self._get_conn()) as conn:
                cursor = conn.cursor()
                cursor.execute(_sql(f'''
                    UPDATE image_tasks
                    SET {', '.join(assignments)}
                    WHERE task_id = ?
                '''), params)
                conn.commit()
            logger.info(f"任务状态更新:{task_id} -> {status.value}")
            return True
        except Exception as e:
            logger.error(f"更新任务状态失败:{e}")
            return False

    def update_result(self, task_id: str, result_image: str, error_message: str = None) -> bool:
        """Store a task's result image and error message; ``status`` is left unchanged.

        Returns:
            True if the update was committed (an unknown task_id matches no
            row and still counts); False on error (logged).
        """
        try:
            with closing(self._get_conn()) as conn:
                cursor = conn.cursor()
                cursor.execute(_sql('''
                    UPDATE image_tasks
                    SET result_image = ?, error_message = ?
                    WHERE task_id = ?
                '''), (result_image, error_message, task_id))
                conn.commit()
            logger.info(f"任务结果更新:{task_id}")
            return True
        except Exception as e:
            logger.error(f"更新任务结果失败:{e}")
            return False

    def add_customer_note(self, task_id: str, note: str, changed_by: str = 'customer') -> bool:
        """Append a timestamped note to a task's customer notes and log the change.

        Each note is prefixed with '[MM-DD HH:MM]' and placed on its own line
        after any existing notes. A row with change_type 'add_note' is written
        to task_requirement_changes in the same transaction.

        Args:
            task_id: Task to annotate.
            note: Note text.
            changed_by: Who made the change (e.g. 'customer' or 'staff').

        Returns:
            True on success; False if the task does not exist or on error (logged).
        """
        try:
            with closing(self._get_conn()) as conn:
                cursor = conn.cursor()
                cursor.execute(_sql('SELECT customer_notes FROM image_tasks WHERE task_id = ?'), (task_id,))
                row = cursor.fetchone()
                if row is None:
                    # Do not write an orphan history row for a nonexistent task.
                    logger.warning(f"添加客户备注失败:任务不存在 {task_id}")
                    return False
                old_note = row['customer_notes'] or ''
                stamped_note = f"[{datetime.now().strftime('%m-%d %H:%M')}] {note}"
                new_note = f"{old_note}\n{stamped_note}" if old_note else stamped_note

                cursor.execute(_sql('''
                    UPDATE image_tasks
                    SET customer_notes = ?
                    WHERE task_id = ?
                '''), (new_note, task_id))
                self._record_change(cursor, task_id, 'add_note', old_note or '无', note, changed_by)
                conn.commit()
            logger.info(f"客户添加备注成功:{task_id}")
            return True
        except Exception as e:
            logger.error(f"添加客户备注失败:{e}")
            return False

    def modify_operation(self, task_id: str, new_operation: str, changed_by: str = 'customer') -> bool:
        """Change a task's operation type and log the old/new values.

        Args:
            task_id: Task to modify.
            new_operation: Operation type to store.
            changed_by: Who made the change.

        Returns:
            True on success; False if the task does not exist or on error (logged).
        """
        try:
            with closing(self._get_conn()) as conn:
                cursor = conn.cursor()
                cursor.execute(_sql('SELECT operation FROM image_tasks WHERE task_id = ?'), (task_id,))
                row = cursor.fetchone()
                if row is None:
                    logger.warning(f"修改操作类型失败:任务不存在 {task_id}")
                    return False
                old_operation = row['operation']

                cursor.execute(_sql('''
                    UPDATE image_tasks
                    SET operation = ?
                    WHERE task_id = ?
                '''), (new_operation, task_id))
                self._record_change(cursor, task_id, 'modify_operation',
                                    old_operation, new_operation, changed_by)
                conn.commit()
            logger.info(f"修改操作类型成功:{task_id} -> {new_operation}")
            return True
        except Exception as e:
            logger.error(f"修改操作类型失败:{e}")
            return False

    def get_requirement_history(self, task_id: str) -> List[dict]:
        """Return a task's requirement-change rows, newest first, or [] on error.

        ``id`` breaks ties between changes with equal ``changed_at``. This
        happens easily on MySQL, whose timestamps here have second resolution.
        """
        try:
            with closing(self._get_conn()) as conn:
                cursor = conn.cursor()
                cursor.execute(_sql('''
                    SELECT * FROM task_requirement_changes
                    WHERE task_id = ?
                    ORDER BY changed_at DESC, id DESC
                '''), (task_id,))
                rows = cursor.fetchall()
            return [dict(row) for row in rows]
        except Exception as e:
            logger.error(f"查询需求历史失败:{e}")
            return []

    def get_pending_tasks(self) -> List[dict]:
        """Return every task in PENDING status across all customers, newest first.

        NOTE(review): per TaskStatus, PENDING means "awaiting payment". If the
        intent is "paid, awaiting processing", PAID is the status wanted.
        Confirm with callers.
        """
        try:
            return self._query_tasks('status = ?', (TaskStatus.PENDING.value,))
        except Exception as e:
            logger.error(f"查询待处理任务失败:{e}")
            return []

    def increment_retry(self, task_id: str) -> int:
        """Atomically bump a task's retry_count, commit, and return the new value.

        Returns:
            The updated count; 0 if the task does not exist; 999 on error.
            The 999 presumably makes callers treat the task as out of retries
            (confirm against callers).
        """
        try:
            with closing(self._get_conn()) as conn:
                cursor = conn.cursor()
                cursor.execute(_sql('''
                    UPDATE image_tasks
                    SET retry_count = retry_count + 1
                    WHERE task_id = ?
                '''), (task_id,))
                cursor.execute(_sql('SELECT retry_count FROM image_tasks WHERE task_id = ?'), (task_id,))
                row = cursor.fetchone()
                # Without this commit, closing the connection discards the increment.
                conn.commit()
            return row['retry_count'] if row else 0
        except Exception as e:
            logger.error(f"增加重试次数失败:{e}")
            return 999


# Lazily created, process-wide shared instance (singleton).
_task_manager: Optional[ImageTaskManager] = None


def get_image_task_manager() -> ImageTaskManager:
    """Return the shared ImageTaskManager, constructing it on the first call.

    No locking is done here, so two threads making the very first call at
    the same time could each construct an instance.
    """
    global _task_manager
    if _task_manager is not None:
        return _task_manager
    _task_manager = ImageTaskManager()
    return _task_manager