feat: 添加 AI Agent 对话测试工具 + 代码优化

主要变更:

- 新增 tests/test_ai_chat.py: AI Agent 对话测试工具

- 优化 core/pydantic_ai_agent.py 和 db/chat_log_db.py

- 清理归档文件,更新文档

Made-with: Cursor
This commit is contained in:
2026-02-28 16:19:35 +08:00
parent a6c42d505a
commit c39840fe15
49 changed files with 2453 additions and 8556 deletions

View File

@@ -9,9 +9,33 @@ from datetime import datetime
from typing import List, Dict, Optional
_DB_PATH = os.path.join(os.path.dirname(__file__), "chat_log_db", "chats.db")
_DB_TYPE = os.getenv("DB_TYPE", "sqlite").lower()
_MYSQL_HOST = os.getenv("MYSQL_HOST", "127.0.0.1")
_MYSQL_PORT = int(os.getenv("MYSQL_PORT", "3306"))
_MYSQL_USER = os.getenv("MYSQL_USER", "root")
_MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs")
def _is_mysql() -> bool:
return _DB_TYPE in ("mysql", "mariadb")
def _sql(query: str) -> str:
return query.replace("?", "%s") if _is_mysql() else query
def _get_conn() -> sqlite3.Connection:
if _is_mysql():
import pymysql
return pymysql.connect(
host=_MYSQL_HOST,
port=_MYSQL_PORT,
user=_MYSQL_USER,
password=_MYSQL_PASSWORD,
database=_MYSQL_DATABASE,
charset="utf8mb4",
cursorclass=pymysql.cursors.DictCursor,
autocommit=False,
)
os.makedirs(os.path.dirname(_DB_PATH), exist_ok=True)
conn = sqlite3.connect(_DB_PATH)
conn.row_factory = sqlite3.Row
@@ -21,27 +45,44 @@ def _get_conn() -> sqlite3.Connection:
def init_db():
"""建表(首次运行时自动调用)"""
with _get_conn() as conn:
conn.execute("""
CREATE TABLE IF NOT EXISTS chat_logs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
customer_id TEXT NOT NULL,
customer_name TEXT DEFAULT '',
acc_id TEXT DEFAULT '',
platform TEXT DEFAULT '',
direction TEXT NOT NULL CHECK(direction IN ('in','out')),
message TEXT NOT NULL,
msg_type INTEGER DEFAULT 0,
timestamp TEXT NOT NULL
)
""")
conn.execute("CREATE INDEX IF NOT EXISTS idx_customer ON chat_logs(customer_id)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_ts ON chat_logs(timestamp)")
# 兼容旧表:若缺少 acc_id 列则补上(必须在创建该列索引之前)
try:
conn.execute("ALTER TABLE chat_logs ADD COLUMN acc_id TEXT DEFAULT ''")
except Exception:
pass
conn.execute("CREATE INDEX IF NOT EXISTS idx_acc ON chat_logs(acc_id)")
if _is_mysql():
conn.execute("""
CREATE TABLE IF NOT EXISTS chat_logs (
id INTEGER PRIMARY KEY AUTO_INCREMENT,
customer_id VARCHAR(128) NOT NULL,
customer_name VARCHAR(255) DEFAULT '',
acc_id VARCHAR(128) DEFAULT '',
platform VARCHAR(64) DEFAULT '',
direction VARCHAR(8) NOT NULL,
message TEXT NOT NULL,
msg_type INTEGER DEFAULT 0,
timestamp DATETIME NOT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
""")
conn.execute("CREATE INDEX IF NOT EXISTS idx_customer ON chat_logs(customer_id)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_ts ON chat_logs(timestamp)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_acc ON chat_logs(acc_id)")
else:
conn.execute("""
CREATE TABLE IF NOT EXISTS chat_logs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
customer_id TEXT NOT NULL,
customer_name TEXT DEFAULT '',
acc_id TEXT DEFAULT '',
platform TEXT DEFAULT '',
direction TEXT NOT NULL CHECK(direction IN ('in','out')),
message TEXT NOT NULL,
msg_type INTEGER DEFAULT 0,
timestamp TEXT NOT NULL
)
""")
conn.execute("CREATE INDEX IF NOT EXISTS idx_customer ON chat_logs(customer_id)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_ts ON chat_logs(timestamp)")
try:
conn.execute("ALTER TABLE chat_logs ADD COLUMN acc_id TEXT DEFAULT ''")
except Exception:
pass
conn.execute("CREATE INDEX IF NOT EXISTS idx_acc ON chat_logs(acc_id)")
conn.commit()
@@ -63,9 +104,9 @@ def log_message(
ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
with _get_conn() as conn:
conn.execute(
"INSERT INTO chat_logs "
"(customer_id, customer_name, acc_id, platform, direction, message, msg_type, timestamp) "
"VALUES (?,?,?,?,?,?,?,?)",
_sql("INSERT INTO chat_logs "
"(customer_id, customer_name, acc_id, platform, direction, message, msg_type, timestamp) "
"VALUES (?,?,?,?,?,?,?,?)"),
(customer_id, customer_name, acc_id, platform, direction, message, msg_type, ts),
)
conn.commit()
@@ -77,6 +118,19 @@ def get_customers(limit: int = 100) -> List[Dict]:
"""返回所有有记录的客户列表(按最新消息时间排序)"""
with _get_conn() as conn:
rows = conn.execute("""
SELECT
customer_id,
MAX(customer_name) AS customer_name,
MAX(platform) AS platform,
COUNT(*) AS total_msgs,
SUM(direction='in') AS recv,
SUM(direction='out') AS sent,
MAX(timestamp) AS last_time
FROM chat_logs
GROUP BY customer_id
ORDER BY last_time DESC
LIMIT %s
""" if _is_mysql() else """
SELECT
customer_id,
MAX(customer_name) AS customer_name,
@@ -96,13 +150,13 @@ def get_customers(limit: int = 100) -> List[Dict]:
def get_conversation(customer_id: str, limit: int = 200) -> List[Dict]:
"""返回某客户的全部对话记录(按时间升序)"""
with _get_conn() as conn:
rows = conn.execute("""
rows = conn.execute(_sql("""
SELECT id, direction, message, msg_type, timestamp, acc_id
FROM chat_logs
WHERE customer_id = ?
ORDER BY timestamp ASC, id ASC
LIMIT ?
""", (customer_id, limit)).fetchall()
"""), (customer_id, limit)).fetchall()
return [dict(r) for r in rows]
@@ -110,21 +164,21 @@ def get_recent_conversation(customer_id: str, acc_id: str = "", limit: int = 10)
"""返回某客户近期对话(同店铺),用于企微推送保持连贯"""
with _get_conn() as conn:
if acc_id:
rows = conn.execute("""
rows = conn.execute(_sql("""
SELECT id, direction, message, timestamp, acc_id
FROM chat_logs
WHERE customer_id = ? AND acc_id = ?
ORDER BY id DESC
LIMIT ?
""", (customer_id, acc_id, limit)).fetchall()
"""), (customer_id, acc_id, limit)).fetchall()
else:
rows = conn.execute("""
rows = conn.execute(_sql("""
SELECT id, direction, message, timestamp, acc_id
FROM chat_logs
WHERE customer_id = ?
ORDER BY id DESC
LIMIT ?
""", (customer_id, limit)).fetchall()
"""), (customer_id, limit)).fetchall()
out = [dict(r) for r in reversed(rows)]
return out
@@ -133,12 +187,12 @@ def get_conversation_today(customer_id: str) -> List[Dict]:
"""返回某客户今天的对话"""
today = datetime.now().strftime("%Y-%m-%d")
with _get_conn() as conn:
rows = conn.execute("""
rows = conn.execute(_sql("""
SELECT id, direction, message, msg_type, timestamp
FROM chat_logs
WHERE customer_id = ? AND timestamp LIKE ?
ORDER BY timestamp ASC, id ASC
""", (customer_id, f"{today}%")).fetchall()
"""), (customer_id, f"{today}%")).fetchall()
return [dict(r) for r in rows]
@@ -151,7 +205,7 @@ def get_daily_stats(date: str = "") -> List[Dict]:
if not date:
date = datetime.now().strftime("%Y-%m-%d")
with _get_conn() as conn:
rows = conn.execute("""
rows = conn.execute(_sql("""
SELECT
acc_id,
platform,
@@ -165,7 +219,7 @@ def get_daily_stats(date: str = "") -> List[Dict]:
WHERE timestamp LIKE ?
GROUP BY acc_id
ORDER BY unique_customers DESC
""", (f"{date}%",)).fetchall()
"""), (f"{date}%",)).fetchall()
return [dict(r) for r in rows]
@@ -176,22 +230,39 @@ def get_daily_conversations(date: str = "") -> List[Dict]:
if not date:
date = datetime.now().strftime("%Y-%m-%d")
with _get_conn() as conn:
rows = conn.execute("""
SELECT
acc_id,
customer_id,
MAX(customer_name) AS customer_name,
COUNT(*) AS msg_count,
GROUP_CONCAT(
CASE WHEN direction='in' THEN '买:' || SUBSTR(message,1,40)
ELSE '客:' || SUBSTR(message,1,40) END,
' | '
) AS snippet
FROM chat_logs
WHERE timestamp LIKE ?
GROUP BY acc_id, customer_id
ORDER BY acc_id, MAX(timestamp) DESC
""", (f"{date}%",)).fetchall()
if _is_mysql():
rows = conn.execute(_sql("""
SELECT
acc_id,
customer_id,
MAX(customer_name) AS customer_name,
COUNT(*) AS msg_count,
GROUP_CONCAT(
CONCAT(CASE WHEN direction='in' THEN '买:' ELSE '客:' END, LEFT(message,40))
SEPARATOR ' | '
) AS snippet
FROM chat_logs
WHERE timestamp LIKE ?
GROUP BY acc_id, customer_id
ORDER BY acc_id, MAX(timestamp) DESC
"""), (f"{date}%",)).fetchall()
else:
rows = conn.execute(_sql("""
SELECT
acc_id,
customer_id,
MAX(customer_name) AS customer_name,
COUNT(*) AS msg_count,
GROUP_CONCAT(
CASE WHEN direction='in' THEN '买:' || SUBSTR(message,1,40)
ELSE '客:' || SUBSTR(message,1,40) END,
' | '
) AS snippet
FROM chat_logs
WHERE timestamp LIKE ?
GROUP BY acc_id, customer_id
ORDER BY acc_id, MAX(timestamp) DESC
"""), (f"{date}%",)).fetchall()
return [dict(r) for r in rows]
@@ -199,18 +270,28 @@ def search_messages(keyword: str, customer_id: Optional[str] = None, limit: int
"""全文搜索消息"""
if customer_id:
with _get_conn() as conn:
rows = conn.execute("""
rows = conn.execute(_sql("""
SELECT customer_id, customer_name, direction, message, timestamp
FROM chat_logs
WHERE customer_id = ? AND message LIKE ?
ORDER BY timestamp DESC LIMIT ?
""", (customer_id, f"%{keyword}%", limit)).fetchall()
"""), (customer_id, f"%{keyword}%", limit)).fetchall()
else:
with _get_conn() as conn:
rows = conn.execute("""
rows = conn.execute(_sql("""
SELECT customer_id, customer_name, direction, message, timestamp
FROM chat_logs
WHERE message LIKE ?
ORDER BY timestamp DESC LIMIT ?
""", (f"%{keyword}%", limit)).fetchall()
"""), (f"%{keyword}%", limit)).fetchall()
return [dict(r) for r in rows]
def get_latest_messages(limit: int = 20) -> List[Dict]:
with _get_conn() as conn:
rows = conn.execute(_sql("""
SELECT id, customer_id, customer_name, direction, message, timestamp
FROM chat_logs
ORDER BY id DESC LIMIT ?
"""), (limit,)).fetchall()
return [dict(r) for r in rows]

View File

@@ -8,9 +8,33 @@ from datetime import datetime
from typing import List, Dict, Optional
_DB_PATH = os.path.join(os.path.dirname(__file__), "deal_outcome_db", "outcomes.db")
_DB_TYPE = os.getenv("DB_TYPE", "sqlite").lower()
_MYSQL_HOST = os.getenv("MYSQL_HOST", "127.0.0.1")
_MYSQL_PORT = int(os.getenv("MYSQL_PORT", "3306"))
_MYSQL_USER = os.getenv("MYSQL_USER", "root")
_MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs")
def _is_mysql() -> bool:
return _DB_TYPE in ("mysql", "mariadb")
def _sql(query: str) -> str:
return query.replace("?", "%s") if _is_mysql() else query
def _get_conn() -> sqlite3.Connection:
if _is_mysql():
import pymysql
return pymysql.connect(
host=_MYSQL_HOST,
port=_MYSQL_PORT,
user=_MYSQL_USER,
password=_MYSQL_PASSWORD,
database=_MYSQL_DATABASE,
charset="utf8mb4",
cursorclass=pymysql.cursors.DictCursor,
autocommit=False,
)
os.makedirs(os.path.dirname(_DB_PATH), exist_ok=True)
conn = sqlite3.connect(_DB_PATH)
conn.row_factory = sqlite3.Row
@@ -19,26 +43,48 @@ def _get_conn() -> sqlite3.Connection:
def _init_db():
with _get_conn() as conn:
conn.execute("""
CREATE TABLE IF NOT EXISTS deal_outcomes (
id INTEGER PRIMARY KEY AUTOINCREMENT,
customer_id TEXT NOT NULL,
customer_name TEXT DEFAULT '',
acc_id TEXT DEFAULT '',
platform TEXT DEFAULT '',
date TEXT NOT NULL,
outcome TEXT NOT NULL CHECK(outcome IN ('成交','未成交')),
reason TEXT DEFAULT '',
order_id TEXT DEFAULT '',
amount REAL DEFAULT 0,
discount_given INTEGER DEFAULT 0,
timestamp TEXT NOT NULL
)
""")
conn.execute("CREATE INDEX IF NOT EXISTS idx_deal_date ON deal_outcomes(date)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_deal_customer ON deal_outcomes(customer_id)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_deal_acc ON deal_outcomes(acc_id)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_deal_outcome ON deal_outcomes(outcome)")
if _is_mysql():
conn.execute("""
CREATE TABLE IF NOT EXISTS deal_outcomes (
id INTEGER PRIMARY KEY AUTO_INCREMENT,
customer_id VARCHAR(128) NOT NULL,
customer_name VARCHAR(255) DEFAULT '',
acc_id VARCHAR(128) DEFAULT '',
platform VARCHAR(64) DEFAULT '',
date DATE NOT NULL,
outcome VARCHAR(16) NOT NULL,
reason TEXT,
order_id VARCHAR(128) DEFAULT '',
amount REAL DEFAULT 0,
discount_given INTEGER DEFAULT 0,
timestamp DATETIME NOT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
""")
conn.execute("CREATE INDEX IF NOT EXISTS idx_deal_date ON deal_outcomes(date)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_deal_customer ON deal_outcomes(customer_id)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_deal_acc ON deal_outcomes(acc_id)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_deal_outcome ON deal_outcomes(outcome)")
else:
conn.execute("""
CREATE TABLE IF NOT EXISTS deal_outcomes (
id INTEGER PRIMARY KEY AUTOINCREMENT,
customer_id TEXT NOT NULL,
customer_name TEXT DEFAULT '',
acc_id TEXT DEFAULT '',
platform TEXT DEFAULT '',
date TEXT NOT NULL,
outcome TEXT NOT NULL CHECK(outcome IN ('成交','未成交')),
reason TEXT DEFAULT '',
order_id TEXT DEFAULT '',
amount REAL DEFAULT 0,
discount_given INTEGER DEFAULT 0,
timestamp TEXT NOT NULL
)
""")
conn.execute("CREATE INDEX IF NOT EXISTS idx_deal_date ON deal_outcomes(date)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_deal_customer ON deal_outcomes(customer_id)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_deal_acc ON deal_outcomes(acc_id)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_deal_outcome ON deal_outcomes(outcome)")
conn.commit()
@@ -61,10 +107,10 @@ def record_deal(
date = datetime.now().strftime("%Y-%m-%d")
with _get_conn() as conn:
conn.execute(
"""INSERT INTO deal_outcomes
_sql("""INSERT INTO deal_outcomes
(customer_id, customer_name, acc_id, platform, date, outcome, reason,
order_id, amount, discount_given, timestamp)
VALUES (?,?,?,?,?,?,?,?,?,?,?)""",
VALUES (?,?,?,?,?,?,?,?,?,?,?)"""),
(
customer_id,
customer_name or "",
@@ -88,13 +134,13 @@ def get_daily_outcomes(date: str = "") -> List[Dict]:
date = datetime.now().strftime("%Y-%m-%d")
with _get_conn() as conn:
rows = conn.execute(
"""
_sql("""
SELECT customer_id, customer_name, acc_id, outcome, reason,
order_id, amount, discount_given, timestamp
FROM deal_outcomes
WHERE date = ?
ORDER BY timestamp ASC
""",
"""),
(date,),
).fetchall()
return [dict(r) for r in rows]
@@ -131,19 +177,19 @@ def export_for_analysis(start_date: str = "", end_date: str = "") -> List[Dict]:
with _get_conn() as conn:
if start_date and end_date:
rows = conn.execute(
"""SELECT * FROM deal_outcomes
_sql("""SELECT * FROM deal_outcomes
WHERE date BETWEEN ? AND ?
ORDER BY date, timestamp""",
ORDER BY date, timestamp"""),
(start_date, end_date),
).fetchall()
elif start_date:
rows = conn.execute(
"""SELECT * FROM deal_outcomes WHERE date >= ? ORDER BY date, timestamp""",
_sql("""SELECT * FROM deal_outcomes WHERE date >= ? ORDER BY date, timestamp"""),
(start_date,),
).fetchall()
elif end_date:
rows = conn.execute(
"""SELECT * FROM deal_outcomes WHERE date <= ? ORDER BY date, timestamp""",
_sql("""SELECT * FROM deal_outcomes WHERE date <= ? ORDER BY date, timestamp"""),
(end_date,),
).fetchall()
else:

View File

@@ -10,9 +10,33 @@ import os
from typing import Optional
_DB_PATH = os.path.join(os.path.dirname(__file__), "designer_roster_db", "roster.db")
_DB_TYPE = os.getenv("DB_TYPE", "sqlite").lower()
_MYSQL_HOST = os.getenv("MYSQL_HOST", "127.0.0.1")
_MYSQL_PORT = int(os.getenv("MYSQL_PORT", "3306"))
_MYSQL_USER = os.getenv("MYSQL_USER", "root")
_MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs")
def _is_mysql() -> bool:
return _DB_TYPE in ("mysql", "mariadb")
def _sql(query: str) -> str:
return query.replace("?", "%s") if _is_mysql() else query
def _get_conn() -> sqlite3.Connection:
if _is_mysql():
import pymysql
return pymysql.connect(
host=_MYSQL_HOST,
port=_MYSQL_PORT,
user=_MYSQL_USER,
password=_MYSQL_PASSWORD,
database=_MYSQL_DATABASE,
charset="utf8mb4",
cursorclass=pymysql.cursors.DictCursor,
autocommit=False,
)
os.makedirs(os.path.dirname(_DB_PATH), exist_ok=True)
conn = sqlite3.connect(_DB_PATH)
conn.row_factory = sqlite3.Row
@@ -21,35 +45,66 @@ def _get_conn() -> sqlite3.Connection:
def init_db():
with _get_conn() as conn:
conn.execute("""
CREATE TABLE IF NOT EXISTS designers (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
wechat_user_id TEXT UNIQUE NOT NULL
)
""")
conn.execute("""
CREATE TABLE IF NOT EXISTS designer_shops (
designer_id INTEGER NOT NULL,
shop_id TEXT NOT NULL,
group_id TEXT NOT NULL,
PRIMARY KEY (designer_id, shop_id),
FOREIGN KEY (designer_id) REFERENCES designers(id)
)
""")
conn.execute("""
CREATE TABLE IF NOT EXISTS designer_online (
wechat_user_id TEXT PRIMARY KEY,
is_online INTEGER NOT NULL DEFAULT 0,
updated_at TEXT
)
""")
conn.execute("""
CREATE TABLE IF NOT EXISTS round_robin (
shop_id TEXT PRIMARY KEY,
last_index INTEGER NOT NULL DEFAULT 0
)
""")
if _is_mysql():
conn.execute("""
CREATE TABLE IF NOT EXISTS designers (
id INTEGER PRIMARY KEY AUTO_INCREMENT,
name VARCHAR(255) NOT NULL,
wechat_user_id VARCHAR(128) UNIQUE NOT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
""")
conn.execute("""
CREATE TABLE IF NOT EXISTS designer_shops (
designer_id INTEGER NOT NULL,
shop_id VARCHAR(128) NOT NULL,
group_id VARCHAR(128) NOT NULL,
PRIMARY KEY (designer_id, shop_id),
FOREIGN KEY (designer_id) REFERENCES designers(id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
""")
conn.execute("""
CREATE TABLE IF NOT EXISTS designer_online (
wechat_user_id VARCHAR(128) PRIMARY KEY,
is_online INTEGER NOT NULL DEFAULT 0,
updated_at DATETIME
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
""")
conn.execute("""
CREATE TABLE IF NOT EXISTS round_robin (
shop_id VARCHAR(128) PRIMARY KEY,
last_index INTEGER NOT NULL DEFAULT 0
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
""")
else:
conn.execute("""
CREATE TABLE IF NOT EXISTS designers (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
wechat_user_id TEXT UNIQUE NOT NULL
)
""")
conn.execute("""
CREATE TABLE IF NOT EXISTS designer_shops (
designer_id INTEGER NOT NULL,
shop_id TEXT NOT NULL,
group_id TEXT NOT NULL,
PRIMARY KEY (designer_id, shop_id),
FOREIGN KEY (designer_id) REFERENCES designers(id)
)
""")
conn.execute("""
CREATE TABLE IF NOT EXISTS designer_online (
wechat_user_id TEXT PRIMARY KEY,
is_online INTEGER NOT NULL DEFAULT 0,
updated_at TEXT
)
""")
conn.execute("""
CREATE TABLE IF NOT EXISTS round_robin (
shop_id TEXT PRIMARY KEY,
last_index INTEGER NOT NULL DEFAULT 0
)
""")
conn.commit()
@@ -61,22 +116,34 @@ init_db()
def add_designer(name: str, wechat_user_id: str) -> int:
"""添加设计师,返回 id"""
with _get_conn() as conn:
conn.execute(
"INSERT OR IGNORE INTO designers (name, wechat_user_id) VALUES (?, ?)",
(name, wechat_user_id),
)
if _is_mysql():
conn.execute(
"INSERT IGNORE INTO designers (name, wechat_user_id) VALUES (%s, %s)",
(name, wechat_user_id),
)
else:
conn.execute(
"INSERT OR IGNORE INTO designers (name, wechat_user_id) VALUES (?, ?)",
(name, wechat_user_id),
)
conn.commit()
row = conn.execute("SELECT id FROM designers WHERE wechat_user_id = ?", (wechat_user_id,)).fetchone()
row = conn.execute(_sql("SELECT id FROM designers WHERE wechat_user_id = ?"), (wechat_user_id,)).fetchone()
return row["id"] if row else 0
def set_designer_shop(designer_id: int, shop_id: str, group_id: str):
"""设置设计师在某店铺的分组 ID同一设计师不同店铺不同 group_id"""
with _get_conn() as conn:
conn.execute(
"INSERT OR REPLACE INTO designer_shops (designer_id, shop_id, group_id) VALUES (?, ?, ?)",
(designer_id, shop_id, group_id),
)
if _is_mysql():
conn.execute(
"REPLACE INTO designer_shops (designer_id, shop_id, group_id) VALUES (%s, %s, %s)",
(designer_id, shop_id, group_id),
)
else:
conn.execute(
"INSERT OR REPLACE INTO designer_shops (designer_id, shop_id, group_id) VALUES (?, ?, ?)",
(designer_id, shop_id, group_id),
)
conn.commit()
@@ -85,10 +152,16 @@ def update_online(wechat_user_id: str, is_online: bool):
from datetime import datetime
ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
with _get_conn() as conn:
conn.execute(
"INSERT OR REPLACE INTO designer_online (wechat_user_id, is_online, updated_at) VALUES (?, ?, ?)",
(wechat_user_id, 1 if is_online else 0, ts),
)
if _is_mysql():
conn.execute(
"REPLACE INTO designer_online (wechat_user_id, is_online, updated_at) VALUES (%s, %s, %s)",
(wechat_user_id, 1 if is_online else 0, ts),
)
else:
conn.execute(
"INSERT OR REPLACE INTO designer_online (wechat_user_id, is_online, updated_at) VALUES (?, ?, ?)",
(wechat_user_id, 1 if is_online else 0, ts),
)
conn.commit()
@@ -101,26 +174,32 @@ def get_transfer_group_for_shop(shop_id: str) -> Optional[str]:
无人在线则返回 None。
"""
with _get_conn() as conn:
rows = conn.execute("""
rows = conn.execute(_sql("""
SELECT d.wechat_user_id, ds.group_id
FROM designer_shops ds
JOIN designers d ON d.id = ds.designer_id
JOIN designer_online o ON o.wechat_user_id = d.wechat_user_id AND o.is_online = 1
WHERE ds.shop_id = ?
""", (shop_id,)).fetchall()
"""), (shop_id,)).fetchall()
if not rows:
return None
with _get_conn() as conn:
rr = conn.execute("SELECT last_index FROM round_robin WHERE shop_id = ?", (shop_id,)).fetchone()
rr = conn.execute(_sql("SELECT last_index FROM round_robin WHERE shop_id = ?"), (shop_id,)).fetchone()
last = rr["last_index"] if rr else 0
idx = last % len(rows)
chosen = rows[idx]
conn.execute(
"INSERT OR REPLACE INTO round_robin (shop_id, last_index) VALUES (?, ?)",
(shop_id, idx + 1),
)
if _is_mysql():
conn.execute(
"REPLACE INTO round_robin (shop_id, last_index) VALUES (%s, %s)",
(shop_id, idx + 1),
)
else:
conn.execute(
"INSERT OR REPLACE INTO round_robin (shop_id, last_index) VALUES (?, ?)",
(shop_id, idx + 1),
)
conn.commit()
return chosen["group_id"]
@@ -142,11 +221,11 @@ def list_designers():
result = []
for d in designers:
shops = conn.execute(
"SELECT shop_id, group_id FROM designer_shops WHERE designer_id = ?",
_sql("SELECT shop_id, group_id FROM designer_shops WHERE designer_id = ?"),
(d["id"],),
).fetchall()
online = conn.execute(
"SELECT is_online FROM designer_online WHERE wechat_user_id = ?",
_sql("SELECT is_online FROM designer_online WHERE wechat_user_id = ?"),
(d["wechat_user_id"],),
).fetchone()
result.append({

View File

@@ -10,8 +10,26 @@ from typing import Optional, List, Dict
from pathlib import Path
from datetime import datetime
from enum import Enum
import os
logger = logging.getLogger(__name__)
_DB_TYPE = os.getenv("DB_TYPE", "sqlite").lower()
_MYSQL_HOST = os.getenv("MYSQL_HOST", "127.0.0.1")
_MYSQL_PORT = int(os.getenv("MYSQL_PORT", "3306"))
_MYSQL_USER = os.getenv("MYSQL_USER", "root")
_MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs")
def _is_mysql() -> bool:
return _DB_TYPE in ("mysql", "mariadb")
def _sql(query: str) -> str:
return query.replace("?", "%s") if _is_mysql() else query
def _now_str() -> str:
if _is_mysql():
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
return datetime.now().isoformat()
class TaskStatus(Enum):
"""任务状态"""
@@ -36,61 +54,105 @@ class ImageTaskManager:
def _init_db(self):
"""初始化数据库"""
self.db_path.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
# 创建图片任务表
cursor.execute('''
CREATE TABLE IF NOT EXISTS image_tasks (
task_id TEXT PRIMARY KEY,
customer_id TEXT NOT NULL,
customer_name TEXT,
original_image TEXT NOT NULL,
operation TEXT DEFAULT 'enhance',
requirements TEXT, -- JSON 格式:复杂度、比例、透视等
customer_notes TEXT, -- 客户备注/需求细节
status TEXT DEFAULT 'pending',
created_at TEXT,
paid_at TEXT,
started_at TEXT,
completed_at TEXT,
result_image TEXT,
error_message TEXT,
retry_count INTEGER DEFAULT 0,
-- 店铺信息
acc_id TEXT,
acc_type TEXT DEFAULT 'AliWorkbench'
)
''')
# 创建需求变更记录表(支持客户后续增加需求)
cursor.execute('''
CREATE TABLE IF NOT EXISTS task_requirement_changes (
id INTEGER PRIMARY KEY AUTOINCREMENT,
task_id TEXT NOT NULL,
change_type TEXT, -- add_note/modify_operation/add_requirement
old_value TEXT,
new_value TEXT,
changed_at TEXT,
changed_by TEXT, -- customer/staff
FOREIGN KEY (task_id) REFERENCES image_tasks(task_id)
)
''')
# 创建索引
cursor.execute('CREATE INDEX IF NOT EXISTS idx_customer ON image_tasks(customer_id)')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_status ON image_tasks(status)')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_created ON image_tasks(created_at)')
conn.commit()
conn.close()
if _is_mysql():
conn = self._get_conn()
cursor = conn.cursor()
cursor.execute('''
CREATE TABLE IF NOT EXISTS image_tasks (
task_id VARCHAR(128) PRIMARY KEY,
customer_id VARCHAR(128) NOT NULL,
customer_name VARCHAR(255),
original_image TEXT NOT NULL,
operation VARCHAR(64) DEFAULT 'enhance',
requirements TEXT,
customer_notes TEXT,
status VARCHAR(32) DEFAULT 'pending',
created_at DATETIME,
paid_at DATETIME,
started_at DATETIME,
completed_at DATETIME,
result_image TEXT,
error_message TEXT,
retry_count INT DEFAULT 0,
acc_id VARCHAR(128),
acc_type VARCHAR(64) DEFAULT 'AliWorkbench'
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
''')
cursor.execute('''
CREATE TABLE IF NOT EXISTS task_requirement_changes (
id INTEGER PRIMARY KEY AUTO_INCREMENT,
task_id VARCHAR(128) NOT NULL,
change_type VARCHAR(64),
old_value TEXT,
new_value TEXT,
changed_at DATETIME,
changed_by VARCHAR(32),
FOREIGN KEY (task_id) REFERENCES image_tasks(task_id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
''')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_customer ON image_tasks(customer_id)')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_status ON image_tasks(status)')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_created ON image_tasks(created_at)')
conn.commit()
conn.close()
else:
self.db_path.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute('''
CREATE TABLE IF NOT EXISTS image_tasks (
task_id TEXT PRIMARY KEY,
customer_id TEXT NOT NULL,
customer_name TEXT,
original_image TEXT NOT NULL,
operation TEXT DEFAULT 'enhance',
requirements TEXT,
customer_notes TEXT,
status TEXT DEFAULT 'pending',
created_at TEXT,
paid_at TEXT,
started_at TEXT,
completed_at TEXT,
result_image TEXT,
error_message TEXT,
retry_count INTEGER DEFAULT 0,
acc_id TEXT,
acc_type TEXT DEFAULT 'AliWorkbench'
)
''')
cursor.execute('''
CREATE TABLE IF NOT EXISTS task_requirement_changes (
id INTEGER PRIMARY KEY AUTOINCREMENT,
task_id TEXT NOT NULL,
change_type TEXT,
old_value TEXT,
new_value TEXT,
changed_at TEXT,
changed_by TEXT,
FOREIGN KEY (task_id) REFERENCES image_tasks(task_id)
)
''')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_customer ON image_tasks(customer_id)')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_status ON image_tasks(status)')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_created ON image_tasks(created_at)')
conn.commit()
conn.close()
logger.info("数据库表初始化完成")
def _get_conn(self):
"""获取数据库连接"""
if _is_mysql():
import pymysql
return pymysql.connect(
host=_MYSQL_HOST,
port=_MYSQL_PORT,
user=_MYSQL_USER,
password=_MYSQL_PASSWORD,
database=_MYSQL_DATABASE,
charset="utf8mb4",
cursorclass=pymysql.cursors.DictCursor,
autocommit=False,
)
conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row
return conn
@@ -105,13 +167,13 @@ class ImageTaskManager:
requirements_json = json.dumps(requirements) if requirements else None
cursor.execute('''
cursor.execute(_sql('''
INSERT INTO image_tasks (
task_id, customer_id, customer_name, original_image,
operation, requirements, customer_notes, status,
created_at, acc_id, acc_type
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
''', (
'''), (
task_id,
customer_id,
customer_name,
@@ -120,7 +182,7 @@ class ImageTaskManager:
requirements_json,
'', # 初始备注为空
TaskStatus.PENDING.value,
datetime.now().isoformat(),
_now_str(),
acc_id,
acc_type
))
@@ -141,7 +203,7 @@ class ImageTaskManager:
conn = self._get_conn()
cursor = conn.cursor()
cursor.execute('SELECT * FROM image_tasks WHERE task_id = ?', (task_id,))
cursor.execute(_sql('SELECT * FROM image_tasks WHERE task_id = ?'), (task_id,))
row = cursor.fetchone()
conn.close()
@@ -164,17 +226,17 @@ class ImageTaskManager:
cursor = conn.cursor()
if status:
cursor.execute('''
cursor.execute(_sql('''
SELECT * FROM image_tasks
WHERE customer_id = ? AND status = ?
ORDER BY created_at DESC
''', (customer_id, status))
'''), (customer_id, status))
else:
cursor.execute('''
cursor.execute(_sql('''
SELECT * FROM image_tasks
WHERE customer_id = ?
ORDER BY created_at DESC
''', (customer_id,))
'''), (customer_id,))
rows = cursor.fetchall()
conn.close()
@@ -198,27 +260,28 @@ class ImageTaskManager:
conn = self._get_conn()
cursor = conn.cursor()
updates = ['status = ?']
placeholder = "%s" if _is_mysql() else "?"
updates = [f'status = {placeholder}']
params = [status.value]
# 根据状态设置时间
if status == TaskStatus.PAID:
updates.append('paid_at = ?')
params.append(datetime.now().isoformat())
updates.append(f'paid_at = {placeholder}')
params.append(_now_str())
elif status == TaskStatus.PROCESSING:
updates.append('started_at = ?')
params.append(datetime.now().isoformat())
updates.append(f'started_at = {placeholder}')
params.append(_now_str())
elif status in [TaskStatus.COMPLETED, TaskStatus.FAILED]:
updates.append('completed_at = ?')
params.append(datetime.now().isoformat())
updates.append(f'completed_at = {placeholder}')
params.append(_now_str())
params.append(task_id)
cursor.execute(f'''
cursor.execute(_sql(f'''
UPDATE image_tasks
SET {', '.join(updates)}
WHERE task_id = ?
''', params)
'''), params)
conn.commit()
conn.close()
@@ -234,11 +297,11 @@ class ImageTaskManager:
conn = self._get_conn()
cursor = conn.cursor()
cursor.execute('''
cursor.execute(_sql('''
UPDATE image_tasks
SET result_image = ?, error_message = ?
WHERE task_id = ?
''', (result_image, error_message, task_id))
'''), (result_image, error_message, task_id))
conn.commit()
conn.close()
@@ -265,30 +328,30 @@ class ImageTaskManager:
cursor = conn.cursor()
# 获取旧备注
cursor.execute('SELECT customer_notes FROM image_tasks WHERE task_id = ?', (task_id,))
cursor.execute(_sql('SELECT customer_notes FROM image_tasks WHERE task_id = ?'), (task_id,))
row = cursor.fetchone()
old_note = row['customer_notes'] if row else ''
# 更新备注
new_note = f"{old_note}\n[{datetime.now().strftime('%m-%d %H:%M')}] {note}" if old_note else f"[{datetime.now().strftime('%m-%d %H:%M')}] {note}"
cursor.execute('''
cursor.execute(_sql('''
UPDATE image_tasks
SET customer_notes = ?
WHERE task_id = ?
''', (new_note, task_id))
'''), (new_note, task_id))
# 记录变更历史
cursor.execute('''
cursor.execute(_sql('''
INSERT INTO task_requirement_changes (
task_id, change_type, old_value, new_value, changed_at, changed_by
) VALUES (?, ?, ?, ?, ?, ?)
''', (
'''), (
task_id,
'add_note',
old_note or '',
note,
datetime.now().isoformat(),
_now_str(),
changed_by
))
@@ -319,28 +382,28 @@ class ImageTaskManager:
cursor = conn.cursor()
# 获取旧操作
cursor.execute('SELECT operation FROM image_tasks WHERE task_id = ?', (task_id,))
cursor.execute(_sql('SELECT operation FROM image_tasks WHERE task_id = ?'), (task_id,))
row = cursor.fetchone()
old_operation = row['operation'] if row else ''
# 更新操作
cursor.execute('''
cursor.execute(_sql('''
UPDATE image_tasks
SET operation = ?
WHERE task_id = ?
''', (new_operation, task_id))
'''), (new_operation, task_id))
# 记录变更历史
cursor.execute('''
cursor.execute(_sql('''
INSERT INTO task_requirement_changes (
task_id, change_type, old_value, new_value, changed_at, changed_by
) VALUES (?, ?, ?, ?, ?, ?)
''', (
'''), (
task_id,
'modify_operation',
old_operation,
new_operation,
datetime.now().isoformat(),
_now_str(),
changed_by
))
@@ -360,11 +423,11 @@ class ImageTaskManager:
conn = self._get_conn()
cursor = conn.cursor()
cursor.execute('''
cursor.execute(_sql('''
SELECT * FROM task_requirement_changes
WHERE task_id = ?
ORDER BY changed_at DESC
''', (task_id,))
'''), (task_id,))
rows = cursor.fetchall()
conn.close()
@@ -385,13 +448,13 @@ class ImageTaskManager:
conn = self._get_conn()
cursor = conn.cursor()
cursor.execute('''
cursor.execute(_sql('''
UPDATE image_tasks
SET retry_count = retry_count + 1
WHERE task_id = ?
''', (task_id,))
'''), (task_id,))
cursor.execute('SELECT retry_count FROM image_tasks WHERE task_id = ?', (task_id,))
cursor.execute(_sql('SELECT retry_count FROM image_tasks WHERE task_id = ?'), (task_id,))
row = cursor.fetchone()
conn.close()

View File

@@ -9,8 +9,26 @@ from datetime import datetime, timedelta
from typing import Optional, Dict, List
from pathlib import Path
from enum import Enum
import os
logger = logging.getLogger(__name__)
_DB_TYPE = os.getenv("DB_TYPE", "sqlite").lower()
_MYSQL_HOST = os.getenv("MYSQL_HOST", "127.0.0.1")
_MYSQL_PORT = int(os.getenv("MYSQL_PORT", "3306"))
_MYSQL_USER = os.getenv("MYSQL_USER", "root")
_MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs")
def _is_mysql() -> bool:
return _DB_TYPE in ("mysql", "mariadb")
def _sql(query: str) -> str:
return query.replace("?", "%s") if _is_mysql() else query
def _now_str() -> str:
if _is_mysql():
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
return datetime.now().isoformat()
class TaskStatus(Enum):
"""任务状态"""
@@ -40,51 +58,93 @@ class TaskManager:
def _init_db(self):
"""初始化数据库"""
self.db_path.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
# 创建任务表
cursor.execute('''
CREATE TABLE IF NOT EXISTS tasks (
task_id TEXT PRIMARY KEY,
specified_customer_id TEXT,
specified_customer_name TEXT,
type TEXT NOT NULL,
customer_name TEXT,
customer_id TEXT,
trigger_type TEXT,
trigger_keyword TEXT,
trigger_keywords TEXT, -- JSON array
action_type TEXT,
action_file_url TEXT,
action_message TEXT,
priority TEXT DEFAULT 'normal',
timeout_hours INTEGER DEFAULT 24,
status TEXT DEFAULT 'pending',
retry_count INTEGER DEFAULT 0,
max_retry INTEGER DEFAULT 3,
created_at TEXT,
created_by TEXT,
triggered_at TEXT,
completed_at TEXT,
error_message TEXT,
result TEXT -- JSON
)
''')
# 创建索引
cursor.execute('CREATE INDEX IF NOT EXISTS idx_status ON tasks(status)')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_customer ON tasks(customer_id)')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_created ON tasks(created_at)')
conn.commit()
conn.close()
if _is_mysql():
conn = self._get_conn()
cursor = conn.cursor()
cursor.execute('''
CREATE TABLE IF NOT EXISTS tasks (
task_id VARCHAR(128) PRIMARY KEY,
specified_customer_id VARCHAR(128),
specified_customer_name VARCHAR(255),
type VARCHAR(64) NOT NULL,
customer_name VARCHAR(255),
customer_id VARCHAR(128),
trigger_type VARCHAR(64),
trigger_keyword VARCHAR(255),
trigger_keywords TEXT,
action_type VARCHAR(64),
action_file_url TEXT,
action_message TEXT,
priority VARCHAR(16) DEFAULT 'normal',
timeout_hours INT DEFAULT 24,
status VARCHAR(32) DEFAULT 'pending',
retry_count INT DEFAULT 0,
max_retry INT DEFAULT 3,
created_at DATETIME,
created_by VARCHAR(255),
triggered_at DATETIME,
completed_at DATETIME,
error_message TEXT,
result TEXT
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
''')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_status ON tasks(status)')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_customer ON tasks(customer_id)')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_created ON tasks(created_at)')
conn.commit()
conn.close()
else:
self.db_path.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute('''
CREATE TABLE IF NOT EXISTS tasks (
task_id TEXT PRIMARY KEY,
specified_customer_id TEXT,
specified_customer_name TEXT,
type TEXT NOT NULL,
customer_name TEXT,
customer_id TEXT,
trigger_type TEXT,
trigger_keyword TEXT,
trigger_keywords TEXT,
action_type TEXT,
action_file_url TEXT,
action_message TEXT,
priority TEXT DEFAULT 'normal',
timeout_hours INTEGER DEFAULT 24,
status TEXT DEFAULT 'pending',
retry_count INTEGER DEFAULT 0,
max_retry INTEGER DEFAULT 3,
created_at TEXT,
created_by TEXT,
triggered_at TEXT,
completed_at TEXT,
error_message TEXT,
result TEXT
)
''')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_status ON tasks(status)')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_customer ON tasks(customer_id)')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_created ON tasks(created_at)')
conn.commit()
conn.close()
logger.info("数据库表初始化完成")
def _get_conn(self):
"""获取数据库连接"""
if _is_mysql():
import pymysql
return pymysql.connect(
host=_MYSQL_HOST,
port=_MYSQL_PORT,
user=_MYSQL_USER,
password=_MYSQL_PASSWORD,
database=_MYSQL_DATABASE,
charset="utf8mb4",
cursorclass=pymysql.cursors.DictCursor,
autocommit=False,
)
conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row
return conn
@@ -100,7 +160,7 @@ class TaskManager:
if isinstance(trigger_keywords, list):
trigger_keywords = json.dumps(trigger_keywords)
cursor.execute('''
insert_sql = '''
INSERT OR REPLACE INTO tasks (
task_id, specified_customer_id, specified_customer_name,
type, customer_name, customer_id,
@@ -109,7 +169,19 @@ class TaskManager:
priority, timeout_hours, status, retry_count, max_retry,
created_at, created_by
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
''', (
'''
if _is_mysql():
insert_sql = '''
REPLACE INTO tasks (
task_id, specified_customer_id, specified_customer_name,
type, customer_name, customer_id,
trigger_type, trigger_keyword, trigger_keywords,
action_type, action_file_url, action_message,
priority, timeout_hours, status, retry_count, max_retry,
created_at, created_by
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
'''
cursor.execute(_sql(insert_sql), (
task.get('task_id'),
task.get('customer', {}).get('id'),
task.get('customer', {}).get('name'),
@@ -127,7 +199,7 @@ class TaskManager:
task.get('status', 'pending'),
task.get('retry_count', 0),
task.get('max_retry', 3),
task.get('created_at', datetime.now().isoformat()),
task.get('created_at', _now_str()),
task.get('created_by')
))
@@ -147,7 +219,7 @@ class TaskManager:
conn = self._get_conn()
cursor = conn.cursor()
cursor.execute('SELECT * FROM tasks WHERE task_id = ?', (task_id,))
cursor.execute(_sql('SELECT * FROM tasks WHERE task_id = ?'), (task_id,))
row = cursor.fetchone()
conn.close()
@@ -165,32 +237,33 @@ class TaskManager:
conn = self._get_conn()
cursor = conn.cursor()
updates = ['status = ?']
placeholder = "%s" if _is_mysql() else "?"
updates = [f'status = {placeholder}']
params = [status.value]
if status == TaskStatus.RUNNING:
updates.append('triggered_at = ?')
params.append(datetime.now().isoformat())
updates.append(f'triggered_at = {placeholder}')
params.append(_now_str())
if status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED]:
updates.append('completed_at = ?')
params.append(datetime.now().isoformat())
updates.append(f'completed_at = {placeholder}')
params.append(_now_str())
if error_message:
updates.append('error_message = ?')
updates.append(f'error_message = {placeholder}')
params.append(error_message)
if result:
updates.append('result = ?')
updates.append(f'result = {placeholder}')
params.append(json.dumps(result))
params.append(task_id)
cursor.execute(f'''
cursor.execute(_sql(f'''
UPDATE tasks
SET {', '.join(updates)}
WHERE task_id = ?
''', params)
'''), params)
conn.commit()
conn.close()
@@ -206,13 +279,13 @@ class TaskManager:
conn = self._get_conn()
cursor = conn.cursor()
cursor.execute('''
cursor.execute(_sql('''
UPDATE tasks
SET retry_count = retry_count + 1
WHERE task_id = ?
''', (task_id,))
'''), (task_id,))
cursor.execute('SELECT retry_count, max_retry FROM tasks WHERE task_id = ?', (task_id,))
cursor.execute(_sql('SELECT retry_count, max_retry FROM tasks WHERE task_id = ?'), (task_id,))
row = cursor.fetchone()
conn.close()
@@ -231,7 +304,7 @@ class TaskManager:
cursor = conn.cursor()
if customer_id:
cursor.execute('''
cursor.execute(_sql('''
SELECT * FROM tasks
WHERE status = 'pending'
AND customer_id = ?
@@ -242,7 +315,7 @@ class TaskManager:
ELSE 3
END,
created_at
''', (customer_id,))
'''), (customer_id,))
else:
cursor.execute('''
SELECT * FROM tasks
@@ -271,11 +344,18 @@ class TaskManager:
conn = self._get_conn()
cursor = conn.cursor()
cursor.execute('''
SELECT * FROM tasks
WHERE status = 'pending'
AND datetime(created_at, '+' || timeout_hours || ' hours') < datetime('now')
''')
if _is_mysql():
cursor.execute('''
SELECT * FROM tasks
WHERE status = 'pending'
AND created_at < DATE_SUB(NOW(), INTERVAL timeout_hours HOUR)
''')
else:
cursor.execute('''
SELECT * FROM tasks
WHERE status = 'pending'
AND datetime(created_at, '+' || timeout_hours || ' hours') < datetime('now')
''')
rows = cursor.fetchall()
conn.close()
@@ -299,7 +379,7 @@ class TaskManager:
stats = {}
for status in TaskStatus:
cursor.execute('SELECT COUNT(*) FROM tasks WHERE status = ?', (status.value,))
cursor.execute(_sql('SELECT COUNT(*) FROM tasks WHERE status = ?'), (status.value,))
stats[status.value] = cursor.fetchone()[0]
conn.close()