feat: 添加 AI Agent 对话测试工具 + 代码优化
主要变更: - 新增 tests/test_ai_chat.py: AI Agent 对话测试工具 - 优化 core/pydantic_ai_agent.py 和 db/chat_log_db.py - 清理归档文件,更新文档 Made-with: Cursor
This commit is contained in:
@@ -9,9 +9,33 @@ from datetime import datetime
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
_DB_PATH = os.path.join(os.path.dirname(__file__), "chat_log_db", "chats.db")
|
||||
_DB_TYPE = os.getenv("DB_TYPE", "sqlite").lower()
|
||||
_MYSQL_HOST = os.getenv("MYSQL_HOST", "127.0.0.1")
|
||||
_MYSQL_PORT = int(os.getenv("MYSQL_PORT", "3306"))
|
||||
_MYSQL_USER = os.getenv("MYSQL_USER", "root")
|
||||
_MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
|
||||
_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs")
|
||||
|
||||
def _is_mysql() -> bool:
|
||||
return _DB_TYPE in ("mysql", "mariadb")
|
||||
|
||||
def _sql(query: str) -> str:
|
||||
return query.replace("?", "%s") if _is_mysql() else query
|
||||
|
||||
|
||||
def _get_conn() -> sqlite3.Connection:
|
||||
if _is_mysql():
|
||||
import pymysql
|
||||
return pymysql.connect(
|
||||
host=_MYSQL_HOST,
|
||||
port=_MYSQL_PORT,
|
||||
user=_MYSQL_USER,
|
||||
password=_MYSQL_PASSWORD,
|
||||
database=_MYSQL_DATABASE,
|
||||
charset="utf8mb4",
|
||||
cursorclass=pymysql.cursors.DictCursor,
|
||||
autocommit=False,
|
||||
)
|
||||
os.makedirs(os.path.dirname(_DB_PATH), exist_ok=True)
|
||||
conn = sqlite3.connect(_DB_PATH)
|
||||
conn.row_factory = sqlite3.Row
|
||||
@@ -21,27 +45,44 @@ def _get_conn() -> sqlite3.Connection:
|
||||
def init_db():
|
||||
"""建表(首次运行时自动调用)"""
|
||||
with _get_conn() as conn:
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS chat_logs (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
customer_id TEXT NOT NULL,
|
||||
customer_name TEXT DEFAULT '',
|
||||
acc_id TEXT DEFAULT '',
|
||||
platform TEXT DEFAULT '',
|
||||
direction TEXT NOT NULL CHECK(direction IN ('in','out')),
|
||||
message TEXT NOT NULL,
|
||||
msg_type INTEGER DEFAULT 0,
|
||||
timestamp TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_customer ON chat_logs(customer_id)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_ts ON chat_logs(timestamp)")
|
||||
# 兼容旧表:若缺少 acc_id 列则补上(必须在创建该列索引之前)
|
||||
try:
|
||||
conn.execute("ALTER TABLE chat_logs ADD COLUMN acc_id TEXT DEFAULT ''")
|
||||
except Exception:
|
||||
pass
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_acc ON chat_logs(acc_id)")
|
||||
if _is_mysql():
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS chat_logs (
|
||||
id INTEGER PRIMARY KEY AUTO_INCREMENT,
|
||||
customer_id VARCHAR(128) NOT NULL,
|
||||
customer_name VARCHAR(255) DEFAULT '',
|
||||
acc_id VARCHAR(128) DEFAULT '',
|
||||
platform VARCHAR(64) DEFAULT '',
|
||||
direction VARCHAR(8) NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
msg_type INTEGER DEFAULT 0,
|
||||
timestamp DATETIME NOT NULL
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
|
||||
""")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_customer ON chat_logs(customer_id)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_ts ON chat_logs(timestamp)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_acc ON chat_logs(acc_id)")
|
||||
else:
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS chat_logs (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
customer_id TEXT NOT NULL,
|
||||
customer_name TEXT DEFAULT '',
|
||||
acc_id TEXT DEFAULT '',
|
||||
platform TEXT DEFAULT '',
|
||||
direction TEXT NOT NULL CHECK(direction IN ('in','out')),
|
||||
message TEXT NOT NULL,
|
||||
msg_type INTEGER DEFAULT 0,
|
||||
timestamp TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_customer ON chat_logs(customer_id)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_ts ON chat_logs(timestamp)")
|
||||
try:
|
||||
conn.execute("ALTER TABLE chat_logs ADD COLUMN acc_id TEXT DEFAULT ''")
|
||||
except Exception:
|
||||
pass
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_acc ON chat_logs(acc_id)")
|
||||
conn.commit()
|
||||
|
||||
|
||||
@@ -63,9 +104,9 @@ def log_message(
|
||||
ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
with _get_conn() as conn:
|
||||
conn.execute(
|
||||
"INSERT INTO chat_logs "
|
||||
"(customer_id, customer_name, acc_id, platform, direction, message, msg_type, timestamp) "
|
||||
"VALUES (?,?,?,?,?,?,?,?)",
|
||||
_sql("INSERT INTO chat_logs "
|
||||
"(customer_id, customer_name, acc_id, platform, direction, message, msg_type, timestamp) "
|
||||
"VALUES (?,?,?,?,?,?,?,?)"),
|
||||
(customer_id, customer_name, acc_id, platform, direction, message, msg_type, ts),
|
||||
)
|
||||
conn.commit()
|
||||
@@ -77,6 +118,19 @@ def get_customers(limit: int = 100) -> List[Dict]:
|
||||
"""返回所有有记录的客户列表(按最新消息时间排序)"""
|
||||
with _get_conn() as conn:
|
||||
rows = conn.execute("""
|
||||
SELECT
|
||||
customer_id,
|
||||
MAX(customer_name) AS customer_name,
|
||||
MAX(platform) AS platform,
|
||||
COUNT(*) AS total_msgs,
|
||||
SUM(direction='in') AS recv,
|
||||
SUM(direction='out') AS sent,
|
||||
MAX(timestamp) AS last_time
|
||||
FROM chat_logs
|
||||
GROUP BY customer_id
|
||||
ORDER BY last_time DESC
|
||||
LIMIT %s
|
||||
""" if _is_mysql() else """
|
||||
SELECT
|
||||
customer_id,
|
||||
MAX(customer_name) AS customer_name,
|
||||
@@ -96,13 +150,13 @@ def get_customers(limit: int = 100) -> List[Dict]:
|
||||
def get_conversation(customer_id: str, limit: int = 200) -> List[Dict]:
|
||||
"""返回某客户的全部对话记录(按时间升序)"""
|
||||
with _get_conn() as conn:
|
||||
rows = conn.execute("""
|
||||
rows = conn.execute(_sql("""
|
||||
SELECT id, direction, message, msg_type, timestamp, acc_id
|
||||
FROM chat_logs
|
||||
WHERE customer_id = ?
|
||||
ORDER BY timestamp ASC, id ASC
|
||||
LIMIT ?
|
||||
""", (customer_id, limit)).fetchall()
|
||||
"""), (customer_id, limit)).fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
|
||||
@@ -110,21 +164,21 @@ def get_recent_conversation(customer_id: str, acc_id: str = "", limit: int = 10)
|
||||
"""返回某客户近期对话(同店铺),用于企微推送保持连贯"""
|
||||
with _get_conn() as conn:
|
||||
if acc_id:
|
||||
rows = conn.execute("""
|
||||
rows = conn.execute(_sql("""
|
||||
SELECT id, direction, message, timestamp, acc_id
|
||||
FROM chat_logs
|
||||
WHERE customer_id = ? AND acc_id = ?
|
||||
ORDER BY id DESC
|
||||
LIMIT ?
|
||||
""", (customer_id, acc_id, limit)).fetchall()
|
||||
"""), (customer_id, acc_id, limit)).fetchall()
|
||||
else:
|
||||
rows = conn.execute("""
|
||||
rows = conn.execute(_sql("""
|
||||
SELECT id, direction, message, timestamp, acc_id
|
||||
FROM chat_logs
|
||||
WHERE customer_id = ?
|
||||
ORDER BY id DESC
|
||||
LIMIT ?
|
||||
""", (customer_id, limit)).fetchall()
|
||||
"""), (customer_id, limit)).fetchall()
|
||||
out = [dict(r) for r in reversed(rows)]
|
||||
return out
|
||||
|
||||
@@ -133,12 +187,12 @@ def get_conversation_today(customer_id: str) -> List[Dict]:
|
||||
"""返回某客户今天的对话"""
|
||||
today = datetime.now().strftime("%Y-%m-%d")
|
||||
with _get_conn() as conn:
|
||||
rows = conn.execute("""
|
||||
rows = conn.execute(_sql("""
|
||||
SELECT id, direction, message, msg_type, timestamp
|
||||
FROM chat_logs
|
||||
WHERE customer_id = ? AND timestamp LIKE ?
|
||||
ORDER BY timestamp ASC, id ASC
|
||||
""", (customer_id, f"{today}%")).fetchall()
|
||||
"""), (customer_id, f"{today}%")).fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
|
||||
@@ -151,7 +205,7 @@ def get_daily_stats(date: str = "") -> List[Dict]:
|
||||
if not date:
|
||||
date = datetime.now().strftime("%Y-%m-%d")
|
||||
with _get_conn() as conn:
|
||||
rows = conn.execute("""
|
||||
rows = conn.execute(_sql("""
|
||||
SELECT
|
||||
acc_id,
|
||||
platform,
|
||||
@@ -165,7 +219,7 @@ def get_daily_stats(date: str = "") -> List[Dict]:
|
||||
WHERE timestamp LIKE ?
|
||||
GROUP BY acc_id
|
||||
ORDER BY unique_customers DESC
|
||||
""", (f"{date}%",)).fetchall()
|
||||
"""), (f"{date}%",)).fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
|
||||
@@ -176,22 +230,39 @@ def get_daily_conversations(date: str = "") -> List[Dict]:
|
||||
if not date:
|
||||
date = datetime.now().strftime("%Y-%m-%d")
|
||||
with _get_conn() as conn:
|
||||
rows = conn.execute("""
|
||||
SELECT
|
||||
acc_id,
|
||||
customer_id,
|
||||
MAX(customer_name) AS customer_name,
|
||||
COUNT(*) AS msg_count,
|
||||
GROUP_CONCAT(
|
||||
CASE WHEN direction='in' THEN '买:' || SUBSTR(message,1,40)
|
||||
ELSE '客:' || SUBSTR(message,1,40) END,
|
||||
' | '
|
||||
) AS snippet
|
||||
FROM chat_logs
|
||||
WHERE timestamp LIKE ?
|
||||
GROUP BY acc_id, customer_id
|
||||
ORDER BY acc_id, MAX(timestamp) DESC
|
||||
""", (f"{date}%",)).fetchall()
|
||||
if _is_mysql():
|
||||
rows = conn.execute(_sql("""
|
||||
SELECT
|
||||
acc_id,
|
||||
customer_id,
|
||||
MAX(customer_name) AS customer_name,
|
||||
COUNT(*) AS msg_count,
|
||||
GROUP_CONCAT(
|
||||
CONCAT(CASE WHEN direction='in' THEN '买:' ELSE '客:' END, LEFT(message,40))
|
||||
SEPARATOR ' | '
|
||||
) AS snippet
|
||||
FROM chat_logs
|
||||
WHERE timestamp LIKE ?
|
||||
GROUP BY acc_id, customer_id
|
||||
ORDER BY acc_id, MAX(timestamp) DESC
|
||||
"""), (f"{date}%",)).fetchall()
|
||||
else:
|
||||
rows = conn.execute(_sql("""
|
||||
SELECT
|
||||
acc_id,
|
||||
customer_id,
|
||||
MAX(customer_name) AS customer_name,
|
||||
COUNT(*) AS msg_count,
|
||||
GROUP_CONCAT(
|
||||
CASE WHEN direction='in' THEN '买:' || SUBSTR(message,1,40)
|
||||
ELSE '客:' || SUBSTR(message,1,40) END,
|
||||
' | '
|
||||
) AS snippet
|
||||
FROM chat_logs
|
||||
WHERE timestamp LIKE ?
|
||||
GROUP BY acc_id, customer_id
|
||||
ORDER BY acc_id, MAX(timestamp) DESC
|
||||
"""), (f"{date}%",)).fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
|
||||
@@ -199,18 +270,28 @@ def search_messages(keyword: str, customer_id: Optional[str] = None, limit: int
|
||||
"""全文搜索消息"""
|
||||
if customer_id:
|
||||
with _get_conn() as conn:
|
||||
rows = conn.execute("""
|
||||
rows = conn.execute(_sql("""
|
||||
SELECT customer_id, customer_name, direction, message, timestamp
|
||||
FROM chat_logs
|
||||
WHERE customer_id = ? AND message LIKE ?
|
||||
ORDER BY timestamp DESC LIMIT ?
|
||||
""", (customer_id, f"%{keyword}%", limit)).fetchall()
|
||||
"""), (customer_id, f"%{keyword}%", limit)).fetchall()
|
||||
else:
|
||||
with _get_conn() as conn:
|
||||
rows = conn.execute("""
|
||||
rows = conn.execute(_sql("""
|
||||
SELECT customer_id, customer_name, direction, message, timestamp
|
||||
FROM chat_logs
|
||||
WHERE message LIKE ?
|
||||
ORDER BY timestamp DESC LIMIT ?
|
||||
""", (f"%{keyword}%", limit)).fetchall()
|
||||
"""), (f"%{keyword}%", limit)).fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
|
||||
def get_latest_messages(limit: int = 20) -> List[Dict]:
|
||||
with _get_conn() as conn:
|
||||
rows = conn.execute(_sql("""
|
||||
SELECT id, customer_id, customer_name, direction, message, timestamp
|
||||
FROM chat_logs
|
||||
ORDER BY id DESC LIMIT ?
|
||||
"""), (limit,)).fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
@@ -8,9 +8,33 @@ from datetime import datetime
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
_DB_PATH = os.path.join(os.path.dirname(__file__), "deal_outcome_db", "outcomes.db")
|
||||
_DB_TYPE = os.getenv("DB_TYPE", "sqlite").lower()
|
||||
_MYSQL_HOST = os.getenv("MYSQL_HOST", "127.0.0.1")
|
||||
_MYSQL_PORT = int(os.getenv("MYSQL_PORT", "3306"))
|
||||
_MYSQL_USER = os.getenv("MYSQL_USER", "root")
|
||||
_MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
|
||||
_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs")
|
||||
|
||||
def _is_mysql() -> bool:
|
||||
return _DB_TYPE in ("mysql", "mariadb")
|
||||
|
||||
def _sql(query: str) -> str:
|
||||
return query.replace("?", "%s") if _is_mysql() else query
|
||||
|
||||
|
||||
def _get_conn() -> sqlite3.Connection:
|
||||
if _is_mysql():
|
||||
import pymysql
|
||||
return pymysql.connect(
|
||||
host=_MYSQL_HOST,
|
||||
port=_MYSQL_PORT,
|
||||
user=_MYSQL_USER,
|
||||
password=_MYSQL_PASSWORD,
|
||||
database=_MYSQL_DATABASE,
|
||||
charset="utf8mb4",
|
||||
cursorclass=pymysql.cursors.DictCursor,
|
||||
autocommit=False,
|
||||
)
|
||||
os.makedirs(os.path.dirname(_DB_PATH), exist_ok=True)
|
||||
conn = sqlite3.connect(_DB_PATH)
|
||||
conn.row_factory = sqlite3.Row
|
||||
@@ -19,26 +43,48 @@ def _get_conn() -> sqlite3.Connection:
|
||||
|
||||
def _init_db():
|
||||
with _get_conn() as conn:
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS deal_outcomes (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
customer_id TEXT NOT NULL,
|
||||
customer_name TEXT DEFAULT '',
|
||||
acc_id TEXT DEFAULT '',
|
||||
platform TEXT DEFAULT '',
|
||||
date TEXT NOT NULL,
|
||||
outcome TEXT NOT NULL CHECK(outcome IN ('成交','未成交')),
|
||||
reason TEXT DEFAULT '',
|
||||
order_id TEXT DEFAULT '',
|
||||
amount REAL DEFAULT 0,
|
||||
discount_given INTEGER DEFAULT 0,
|
||||
timestamp TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_deal_date ON deal_outcomes(date)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_deal_customer ON deal_outcomes(customer_id)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_deal_acc ON deal_outcomes(acc_id)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_deal_outcome ON deal_outcomes(outcome)")
|
||||
if _is_mysql():
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS deal_outcomes (
|
||||
id INTEGER PRIMARY KEY AUTO_INCREMENT,
|
||||
customer_id VARCHAR(128) NOT NULL,
|
||||
customer_name VARCHAR(255) DEFAULT '',
|
||||
acc_id VARCHAR(128) DEFAULT '',
|
||||
platform VARCHAR(64) DEFAULT '',
|
||||
date DATE NOT NULL,
|
||||
outcome VARCHAR(16) NOT NULL,
|
||||
reason TEXT,
|
||||
order_id VARCHAR(128) DEFAULT '',
|
||||
amount REAL DEFAULT 0,
|
||||
discount_given INTEGER DEFAULT 0,
|
||||
timestamp DATETIME NOT NULL
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
|
||||
""")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_deal_date ON deal_outcomes(date)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_deal_customer ON deal_outcomes(customer_id)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_deal_acc ON deal_outcomes(acc_id)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_deal_outcome ON deal_outcomes(outcome)")
|
||||
else:
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS deal_outcomes (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
customer_id TEXT NOT NULL,
|
||||
customer_name TEXT DEFAULT '',
|
||||
acc_id TEXT DEFAULT '',
|
||||
platform TEXT DEFAULT '',
|
||||
date TEXT NOT NULL,
|
||||
outcome TEXT NOT NULL CHECK(outcome IN ('成交','未成交')),
|
||||
reason TEXT DEFAULT '',
|
||||
order_id TEXT DEFAULT '',
|
||||
amount REAL DEFAULT 0,
|
||||
discount_given INTEGER DEFAULT 0,
|
||||
timestamp TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_deal_date ON deal_outcomes(date)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_deal_customer ON deal_outcomes(customer_id)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_deal_acc ON deal_outcomes(acc_id)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_deal_outcome ON deal_outcomes(outcome)")
|
||||
conn.commit()
|
||||
|
||||
|
||||
@@ -61,10 +107,10 @@ def record_deal(
|
||||
date = datetime.now().strftime("%Y-%m-%d")
|
||||
with _get_conn() as conn:
|
||||
conn.execute(
|
||||
"""INSERT INTO deal_outcomes
|
||||
_sql("""INSERT INTO deal_outcomes
|
||||
(customer_id, customer_name, acc_id, platform, date, outcome, reason,
|
||||
order_id, amount, discount_given, timestamp)
|
||||
VALUES (?,?,?,?,?,?,?,?,?,?,?)""",
|
||||
VALUES (?,?,?,?,?,?,?,?,?,?,?)"""),
|
||||
(
|
||||
customer_id,
|
||||
customer_name or "",
|
||||
@@ -88,13 +134,13 @@ def get_daily_outcomes(date: str = "") -> List[Dict]:
|
||||
date = datetime.now().strftime("%Y-%m-%d")
|
||||
with _get_conn() as conn:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
_sql("""
|
||||
SELECT customer_id, customer_name, acc_id, outcome, reason,
|
||||
order_id, amount, discount_given, timestamp
|
||||
FROM deal_outcomes
|
||||
WHERE date = ?
|
||||
ORDER BY timestamp ASC
|
||||
""",
|
||||
"""),
|
||||
(date,),
|
||||
).fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
@@ -131,19 +177,19 @@ def export_for_analysis(start_date: str = "", end_date: str = "") -> List[Dict]:
|
||||
with _get_conn() as conn:
|
||||
if start_date and end_date:
|
||||
rows = conn.execute(
|
||||
"""SELECT * FROM deal_outcomes
|
||||
_sql("""SELECT * FROM deal_outcomes
|
||||
WHERE date BETWEEN ? AND ?
|
||||
ORDER BY date, timestamp""",
|
||||
ORDER BY date, timestamp"""),
|
||||
(start_date, end_date),
|
||||
).fetchall()
|
||||
elif start_date:
|
||||
rows = conn.execute(
|
||||
"""SELECT * FROM deal_outcomes WHERE date >= ? ORDER BY date, timestamp""",
|
||||
_sql("""SELECT * FROM deal_outcomes WHERE date >= ? ORDER BY date, timestamp"""),
|
||||
(start_date,),
|
||||
).fetchall()
|
||||
elif end_date:
|
||||
rows = conn.execute(
|
||||
"""SELECT * FROM deal_outcomes WHERE date <= ? ORDER BY date, timestamp""",
|
||||
_sql("""SELECT * FROM deal_outcomes WHERE date <= ? ORDER BY date, timestamp"""),
|
||||
(end_date,),
|
||||
).fetchall()
|
||||
else:
|
||||
|
||||
@@ -10,9 +10,33 @@ import os
|
||||
from typing import Optional
|
||||
|
||||
_DB_PATH = os.path.join(os.path.dirname(__file__), "designer_roster_db", "roster.db")
|
||||
_DB_TYPE = os.getenv("DB_TYPE", "sqlite").lower()
|
||||
_MYSQL_HOST = os.getenv("MYSQL_HOST", "127.0.0.1")
|
||||
_MYSQL_PORT = int(os.getenv("MYSQL_PORT", "3306"))
|
||||
_MYSQL_USER = os.getenv("MYSQL_USER", "root")
|
||||
_MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
|
||||
_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs")
|
||||
|
||||
def _is_mysql() -> bool:
|
||||
return _DB_TYPE in ("mysql", "mariadb")
|
||||
|
||||
def _sql(query: str) -> str:
|
||||
return query.replace("?", "%s") if _is_mysql() else query
|
||||
|
||||
|
||||
def _get_conn() -> sqlite3.Connection:
|
||||
if _is_mysql():
|
||||
import pymysql
|
||||
return pymysql.connect(
|
||||
host=_MYSQL_HOST,
|
||||
port=_MYSQL_PORT,
|
||||
user=_MYSQL_USER,
|
||||
password=_MYSQL_PASSWORD,
|
||||
database=_MYSQL_DATABASE,
|
||||
charset="utf8mb4",
|
||||
cursorclass=pymysql.cursors.DictCursor,
|
||||
autocommit=False,
|
||||
)
|
||||
os.makedirs(os.path.dirname(_DB_PATH), exist_ok=True)
|
||||
conn = sqlite3.connect(_DB_PATH)
|
||||
conn.row_factory = sqlite3.Row
|
||||
@@ -21,35 +45,66 @@ def _get_conn() -> sqlite3.Connection:
|
||||
|
||||
def init_db():
|
||||
with _get_conn() as conn:
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS designers (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL,
|
||||
wechat_user_id TEXT UNIQUE NOT NULL
|
||||
)
|
||||
""")
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS designer_shops (
|
||||
designer_id INTEGER NOT NULL,
|
||||
shop_id TEXT NOT NULL,
|
||||
group_id TEXT NOT NULL,
|
||||
PRIMARY KEY (designer_id, shop_id),
|
||||
FOREIGN KEY (designer_id) REFERENCES designers(id)
|
||||
)
|
||||
""")
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS designer_online (
|
||||
wechat_user_id TEXT PRIMARY KEY,
|
||||
is_online INTEGER NOT NULL DEFAULT 0,
|
||||
updated_at TEXT
|
||||
)
|
||||
""")
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS round_robin (
|
||||
shop_id TEXT PRIMARY KEY,
|
||||
last_index INTEGER NOT NULL DEFAULT 0
|
||||
)
|
||||
""")
|
||||
if _is_mysql():
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS designers (
|
||||
id INTEGER PRIMARY KEY AUTO_INCREMENT,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
wechat_user_id VARCHAR(128) UNIQUE NOT NULL
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
|
||||
""")
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS designer_shops (
|
||||
designer_id INTEGER NOT NULL,
|
||||
shop_id VARCHAR(128) NOT NULL,
|
||||
group_id VARCHAR(128) NOT NULL,
|
||||
PRIMARY KEY (designer_id, shop_id),
|
||||
FOREIGN KEY (designer_id) REFERENCES designers(id)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
|
||||
""")
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS designer_online (
|
||||
wechat_user_id VARCHAR(128) PRIMARY KEY,
|
||||
is_online INTEGER NOT NULL DEFAULT 0,
|
||||
updated_at DATETIME
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
|
||||
""")
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS round_robin (
|
||||
shop_id VARCHAR(128) PRIMARY KEY,
|
||||
last_index INTEGER NOT NULL DEFAULT 0
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
|
||||
""")
|
||||
else:
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS designers (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL,
|
||||
wechat_user_id TEXT UNIQUE NOT NULL
|
||||
)
|
||||
""")
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS designer_shops (
|
||||
designer_id INTEGER NOT NULL,
|
||||
shop_id TEXT NOT NULL,
|
||||
group_id TEXT NOT NULL,
|
||||
PRIMARY KEY (designer_id, shop_id),
|
||||
FOREIGN KEY (designer_id) REFERENCES designers(id)
|
||||
)
|
||||
""")
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS designer_online (
|
||||
wechat_user_id TEXT PRIMARY KEY,
|
||||
is_online INTEGER NOT NULL DEFAULT 0,
|
||||
updated_at TEXT
|
||||
)
|
||||
""")
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS round_robin (
|
||||
shop_id TEXT PRIMARY KEY,
|
||||
last_index INTEGER NOT NULL DEFAULT 0
|
||||
)
|
||||
""")
|
||||
conn.commit()
|
||||
|
||||
|
||||
@@ -61,22 +116,34 @@ init_db()
|
||||
def add_designer(name: str, wechat_user_id: str) -> int:
|
||||
"""添加设计师,返回 id"""
|
||||
with _get_conn() as conn:
|
||||
conn.execute(
|
||||
"INSERT OR IGNORE INTO designers (name, wechat_user_id) VALUES (?, ?)",
|
||||
(name, wechat_user_id),
|
||||
)
|
||||
if _is_mysql():
|
||||
conn.execute(
|
||||
"INSERT IGNORE INTO designers (name, wechat_user_id) VALUES (%s, %s)",
|
||||
(name, wechat_user_id),
|
||||
)
|
||||
else:
|
||||
conn.execute(
|
||||
"INSERT OR IGNORE INTO designers (name, wechat_user_id) VALUES (?, ?)",
|
||||
(name, wechat_user_id),
|
||||
)
|
||||
conn.commit()
|
||||
row = conn.execute("SELECT id FROM designers WHERE wechat_user_id = ?", (wechat_user_id,)).fetchone()
|
||||
row = conn.execute(_sql("SELECT id FROM designers WHERE wechat_user_id = ?"), (wechat_user_id,)).fetchone()
|
||||
return row["id"] if row else 0
|
||||
|
||||
|
||||
def set_designer_shop(designer_id: int, shop_id: str, group_id: str):
|
||||
"""设置设计师在某店铺的分组 ID(同一设计师不同店铺不同 group_id)"""
|
||||
with _get_conn() as conn:
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO designer_shops (designer_id, shop_id, group_id) VALUES (?, ?, ?)",
|
||||
(designer_id, shop_id, group_id),
|
||||
)
|
||||
if _is_mysql():
|
||||
conn.execute(
|
||||
"REPLACE INTO designer_shops (designer_id, shop_id, group_id) VALUES (%s, %s, %s)",
|
||||
(designer_id, shop_id, group_id),
|
||||
)
|
||||
else:
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO designer_shops (designer_id, shop_id, group_id) VALUES (?, ?, ?)",
|
||||
(designer_id, shop_id, group_id),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
@@ -85,10 +152,16 @@ def update_online(wechat_user_id: str, is_online: bool):
|
||||
from datetime import datetime
|
||||
ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
with _get_conn() as conn:
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO designer_online (wechat_user_id, is_online, updated_at) VALUES (?, ?, ?)",
|
||||
(wechat_user_id, 1 if is_online else 0, ts),
|
||||
)
|
||||
if _is_mysql():
|
||||
conn.execute(
|
||||
"REPLACE INTO designer_online (wechat_user_id, is_online, updated_at) VALUES (%s, %s, %s)",
|
||||
(wechat_user_id, 1 if is_online else 0, ts),
|
||||
)
|
||||
else:
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO designer_online (wechat_user_id, is_online, updated_at) VALUES (?, ?, ?)",
|
||||
(wechat_user_id, 1 if is_online else 0, ts),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
@@ -101,26 +174,32 @@ def get_transfer_group_for_shop(shop_id: str) -> Optional[str]:
|
||||
无人在线则返回 None。
|
||||
"""
|
||||
with _get_conn() as conn:
|
||||
rows = conn.execute("""
|
||||
rows = conn.execute(_sql("""
|
||||
SELECT d.wechat_user_id, ds.group_id
|
||||
FROM designer_shops ds
|
||||
JOIN designers d ON d.id = ds.designer_id
|
||||
JOIN designer_online o ON o.wechat_user_id = d.wechat_user_id AND o.is_online = 1
|
||||
WHERE ds.shop_id = ?
|
||||
""", (shop_id,)).fetchall()
|
||||
"""), (shop_id,)).fetchall()
|
||||
|
||||
if not rows:
|
||||
return None
|
||||
|
||||
with _get_conn() as conn:
|
||||
rr = conn.execute("SELECT last_index FROM round_robin WHERE shop_id = ?", (shop_id,)).fetchone()
|
||||
rr = conn.execute(_sql("SELECT last_index FROM round_robin WHERE shop_id = ?"), (shop_id,)).fetchone()
|
||||
last = rr["last_index"] if rr else 0
|
||||
idx = last % len(rows)
|
||||
chosen = rows[idx]
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO round_robin (shop_id, last_index) VALUES (?, ?)",
|
||||
(shop_id, idx + 1),
|
||||
)
|
||||
if _is_mysql():
|
||||
conn.execute(
|
||||
"REPLACE INTO round_robin (shop_id, last_index) VALUES (%s, %s)",
|
||||
(shop_id, idx + 1),
|
||||
)
|
||||
else:
|
||||
conn.execute(
|
||||
"INSERT OR REPLACE INTO round_robin (shop_id, last_index) VALUES (?, ?)",
|
||||
(shop_id, idx + 1),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
return chosen["group_id"]
|
||||
@@ -142,11 +221,11 @@ def list_designers():
|
||||
result = []
|
||||
for d in designers:
|
||||
shops = conn.execute(
|
||||
"SELECT shop_id, group_id FROM designer_shops WHERE designer_id = ?",
|
||||
_sql("SELECT shop_id, group_id FROM designer_shops WHERE designer_id = ?"),
|
||||
(d["id"],),
|
||||
).fetchall()
|
||||
online = conn.execute(
|
||||
"SELECT is_online FROM designer_online WHERE wechat_user_id = ?",
|
||||
_sql("SELECT is_online FROM designer_online WHERE wechat_user_id = ?"),
|
||||
(d["wechat_user_id"],),
|
||||
).fetchone()
|
||||
result.append({
|
||||
|
||||
@@ -10,8 +10,26 @@ from typing import Optional, List, Dict
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
import os
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_DB_TYPE = os.getenv("DB_TYPE", "sqlite").lower()
|
||||
_MYSQL_HOST = os.getenv("MYSQL_HOST", "127.0.0.1")
|
||||
_MYSQL_PORT = int(os.getenv("MYSQL_PORT", "3306"))
|
||||
_MYSQL_USER = os.getenv("MYSQL_USER", "root")
|
||||
_MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
|
||||
_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs")
|
||||
|
||||
def _is_mysql() -> bool:
|
||||
return _DB_TYPE in ("mysql", "mariadb")
|
||||
|
||||
def _sql(query: str) -> str:
|
||||
return query.replace("?", "%s") if _is_mysql() else query
|
||||
|
||||
def _now_str() -> str:
|
||||
if _is_mysql():
|
||||
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
return datetime.now().isoformat()
|
||||
|
||||
class TaskStatus(Enum):
|
||||
"""任务状态"""
|
||||
@@ -36,61 +54,105 @@ class ImageTaskManager:
|
||||
|
||||
def _init_db(self):
|
||||
"""初始化数据库"""
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 创建图片任务表
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS image_tasks (
|
||||
task_id TEXT PRIMARY KEY,
|
||||
customer_id TEXT NOT NULL,
|
||||
customer_name TEXT,
|
||||
original_image TEXT NOT NULL,
|
||||
operation TEXT DEFAULT 'enhance',
|
||||
requirements TEXT, -- JSON 格式:复杂度、比例、透视等
|
||||
customer_notes TEXT, -- 客户备注/需求细节
|
||||
status TEXT DEFAULT 'pending',
|
||||
created_at TEXT,
|
||||
paid_at TEXT,
|
||||
started_at TEXT,
|
||||
completed_at TEXT,
|
||||
result_image TEXT,
|
||||
error_message TEXT,
|
||||
retry_count INTEGER DEFAULT 0,
|
||||
|
||||
-- 店铺信息
|
||||
acc_id TEXT,
|
||||
acc_type TEXT DEFAULT 'AliWorkbench'
|
||||
)
|
||||
''')
|
||||
|
||||
# 创建需求变更记录表(支持客户后续增加需求)
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS task_requirement_changes (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
task_id TEXT NOT NULL,
|
||||
change_type TEXT, -- add_note/modify_operation/add_requirement
|
||||
old_value TEXT,
|
||||
new_value TEXT,
|
||||
changed_at TEXT,
|
||||
changed_by TEXT, -- customer/staff
|
||||
FOREIGN KEY (task_id) REFERENCES image_tasks(task_id)
|
||||
)
|
||||
''')
|
||||
|
||||
# 创建索引
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_customer ON image_tasks(customer_id)')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_status ON image_tasks(status)')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_created ON image_tasks(created_at)')
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
if _is_mysql():
|
||||
conn = self._get_conn()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS image_tasks (
|
||||
task_id VARCHAR(128) PRIMARY KEY,
|
||||
customer_id VARCHAR(128) NOT NULL,
|
||||
customer_name VARCHAR(255),
|
||||
original_image TEXT NOT NULL,
|
||||
operation VARCHAR(64) DEFAULT 'enhance',
|
||||
requirements TEXT,
|
||||
customer_notes TEXT,
|
||||
status VARCHAR(32) DEFAULT 'pending',
|
||||
created_at DATETIME,
|
||||
paid_at DATETIME,
|
||||
started_at DATETIME,
|
||||
completed_at DATETIME,
|
||||
result_image TEXT,
|
||||
error_message TEXT,
|
||||
retry_count INT DEFAULT 0,
|
||||
acc_id VARCHAR(128),
|
||||
acc_type VARCHAR(64) DEFAULT 'AliWorkbench'
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
|
||||
''')
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS task_requirement_changes (
|
||||
id INTEGER PRIMARY KEY AUTO_INCREMENT,
|
||||
task_id VARCHAR(128) NOT NULL,
|
||||
change_type VARCHAR(64),
|
||||
old_value TEXT,
|
||||
new_value TEXT,
|
||||
changed_at DATETIME,
|
||||
changed_by VARCHAR(32),
|
||||
FOREIGN KEY (task_id) REFERENCES image_tasks(task_id)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
|
||||
''')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_customer ON image_tasks(customer_id)')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_status ON image_tasks(status)')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_created ON image_tasks(created_at)')
|
||||
conn.commit()
|
||||
conn.close()
|
||||
else:
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS image_tasks (
|
||||
task_id TEXT PRIMARY KEY,
|
||||
customer_id TEXT NOT NULL,
|
||||
customer_name TEXT,
|
||||
original_image TEXT NOT NULL,
|
||||
operation TEXT DEFAULT 'enhance',
|
||||
requirements TEXT,
|
||||
customer_notes TEXT,
|
||||
status TEXT DEFAULT 'pending',
|
||||
created_at TEXT,
|
||||
paid_at TEXT,
|
||||
started_at TEXT,
|
||||
completed_at TEXT,
|
||||
result_image TEXT,
|
||||
error_message TEXT,
|
||||
retry_count INTEGER DEFAULT 0,
|
||||
acc_id TEXT,
|
||||
acc_type TEXT DEFAULT 'AliWorkbench'
|
||||
)
|
||||
''')
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS task_requirement_changes (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
task_id TEXT NOT NULL,
|
||||
change_type TEXT,
|
||||
old_value TEXT,
|
||||
new_value TEXT,
|
||||
changed_at TEXT,
|
||||
changed_by TEXT,
|
||||
FOREIGN KEY (task_id) REFERENCES image_tasks(task_id)
|
||||
)
|
||||
''')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_customer ON image_tasks(customer_id)')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_status ON image_tasks(status)')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_created ON image_tasks(created_at)')
|
||||
conn.commit()
|
||||
conn.close()
|
||||
logger.info("数据库表初始化完成")
|
||||
|
||||
def _get_conn(self):
|
||||
"""获取数据库连接"""
|
||||
if _is_mysql():
|
||||
import pymysql
|
||||
return pymysql.connect(
|
||||
host=_MYSQL_HOST,
|
||||
port=_MYSQL_PORT,
|
||||
user=_MYSQL_USER,
|
||||
password=_MYSQL_PASSWORD,
|
||||
database=_MYSQL_DATABASE,
|
||||
charset="utf8mb4",
|
||||
cursorclass=pymysql.cursors.DictCursor,
|
||||
autocommit=False,
|
||||
)
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
return conn
|
||||
@@ -105,13 +167,13 @@ class ImageTaskManager:
|
||||
|
||||
requirements_json = json.dumps(requirements) if requirements else None
|
||||
|
||||
cursor.execute('''
|
||||
cursor.execute(_sql('''
|
||||
INSERT INTO image_tasks (
|
||||
task_id, customer_id, customer_name, original_image,
|
||||
operation, requirements, customer_notes, status,
|
||||
created_at, acc_id, acc_type
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
''', (
|
||||
'''), (
|
||||
task_id,
|
||||
customer_id,
|
||||
customer_name,
|
||||
@@ -120,7 +182,7 @@ class ImageTaskManager:
|
||||
requirements_json,
|
||||
'', # 初始备注为空
|
||||
TaskStatus.PENDING.value,
|
||||
datetime.now().isoformat(),
|
||||
_now_str(),
|
||||
acc_id,
|
||||
acc_type
|
||||
))
|
||||
@@ -141,7 +203,7 @@ class ImageTaskManager:
|
||||
conn = self._get_conn()
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('SELECT * FROM image_tasks WHERE task_id = ?', (task_id,))
|
||||
cursor.execute(_sql('SELECT * FROM image_tasks WHERE task_id = ?'), (task_id,))
|
||||
row = cursor.fetchone()
|
||||
conn.close()
|
||||
|
||||
@@ -164,17 +226,17 @@ class ImageTaskManager:
|
||||
cursor = conn.cursor()
|
||||
|
||||
if status:
|
||||
cursor.execute('''
|
||||
cursor.execute(_sql('''
|
||||
SELECT * FROM image_tasks
|
||||
WHERE customer_id = ? AND status = ?
|
||||
ORDER BY created_at DESC
|
||||
''', (customer_id, status))
|
||||
'''), (customer_id, status))
|
||||
else:
|
||||
cursor.execute('''
|
||||
cursor.execute(_sql('''
|
||||
SELECT * FROM image_tasks
|
||||
WHERE customer_id = ?
|
||||
ORDER BY created_at DESC
|
||||
''', (customer_id,))
|
||||
'''), (customer_id,))
|
||||
|
||||
rows = cursor.fetchall()
|
||||
conn.close()
|
||||
@@ -198,27 +260,28 @@ class ImageTaskManager:
|
||||
conn = self._get_conn()
|
||||
cursor = conn.cursor()
|
||||
|
||||
updates = ['status = ?']
|
||||
placeholder = "%s" if _is_mysql() else "?"
|
||||
updates = [f'status = {placeholder}']
|
||||
params = [status.value]
|
||||
|
||||
# 根据状态设置时间
|
||||
if status == TaskStatus.PAID:
|
||||
updates.append('paid_at = ?')
|
||||
params.append(datetime.now().isoformat())
|
||||
updates.append(f'paid_at = {placeholder}')
|
||||
params.append(_now_str())
|
||||
elif status == TaskStatus.PROCESSING:
|
||||
updates.append('started_at = ?')
|
||||
params.append(datetime.now().isoformat())
|
||||
updates.append(f'started_at = {placeholder}')
|
||||
params.append(_now_str())
|
||||
elif status in [TaskStatus.COMPLETED, TaskStatus.FAILED]:
|
||||
updates.append('completed_at = ?')
|
||||
params.append(datetime.now().isoformat())
|
||||
updates.append(f'completed_at = {placeholder}')
|
||||
params.append(_now_str())
|
||||
|
||||
params.append(task_id)
|
||||
|
||||
cursor.execute(f'''
|
||||
cursor.execute(_sql(f'''
|
||||
UPDATE image_tasks
|
||||
SET {', '.join(updates)}
|
||||
WHERE task_id = ?
|
||||
''', params)
|
||||
'''), params)
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
@@ -234,11 +297,11 @@ class ImageTaskManager:
|
||||
conn = self._get_conn()
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('''
|
||||
cursor.execute(_sql('''
|
||||
UPDATE image_tasks
|
||||
SET result_image = ?, error_message = ?
|
||||
WHERE task_id = ?
|
||||
''', (result_image, error_message, task_id))
|
||||
'''), (result_image, error_message, task_id))
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
@@ -265,30 +328,30 @@ class ImageTaskManager:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 获取旧备注
|
||||
cursor.execute('SELECT customer_notes FROM image_tasks WHERE task_id = ?', (task_id,))
|
||||
cursor.execute(_sql('SELECT customer_notes FROM image_tasks WHERE task_id = ?'), (task_id,))
|
||||
row = cursor.fetchone()
|
||||
old_note = row['customer_notes'] if row else ''
|
||||
|
||||
# 更新备注
|
||||
new_note = f"{old_note}\n[{datetime.now().strftime('%m-%d %H:%M')}] {note}" if old_note else f"[{datetime.now().strftime('%m-%d %H:%M')}] {note}"
|
||||
|
||||
cursor.execute('''
|
||||
cursor.execute(_sql('''
|
||||
UPDATE image_tasks
|
||||
SET customer_notes = ?
|
||||
WHERE task_id = ?
|
||||
''', (new_note, task_id))
|
||||
'''), (new_note, task_id))
|
||||
|
||||
# 记录变更历史
|
||||
cursor.execute('''
|
||||
cursor.execute(_sql('''
|
||||
INSERT INTO task_requirement_changes (
|
||||
task_id, change_type, old_value, new_value, changed_at, changed_by
|
||||
) VALUES (?, ?, ?, ?, ?, ?)
|
||||
''', (
|
||||
'''), (
|
||||
task_id,
|
||||
'add_note',
|
||||
old_note or '无',
|
||||
note,
|
||||
datetime.now().isoformat(),
|
||||
_now_str(),
|
||||
changed_by
|
||||
))
|
||||
|
||||
@@ -319,28 +382,28 @@ class ImageTaskManager:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 获取旧操作
|
||||
cursor.execute('SELECT operation FROM image_tasks WHERE task_id = ?', (task_id,))
|
||||
cursor.execute(_sql('SELECT operation FROM image_tasks WHERE task_id = ?'), (task_id,))
|
||||
row = cursor.fetchone()
|
||||
old_operation = row['operation'] if row else ''
|
||||
|
||||
# 更新操作
|
||||
cursor.execute('''
|
||||
cursor.execute(_sql('''
|
||||
UPDATE image_tasks
|
||||
SET operation = ?
|
||||
WHERE task_id = ?
|
||||
''', (new_operation, task_id))
|
||||
'''), (new_operation, task_id))
|
||||
|
||||
# 记录变更历史
|
||||
cursor.execute('''
|
||||
cursor.execute(_sql('''
|
||||
INSERT INTO task_requirement_changes (
|
||||
task_id, change_type, old_value, new_value, changed_at, changed_by
|
||||
) VALUES (?, ?, ?, ?, ?, ?)
|
||||
''', (
|
||||
'''), (
|
||||
task_id,
|
||||
'modify_operation',
|
||||
old_operation,
|
||||
new_operation,
|
||||
datetime.now().isoformat(),
|
||||
_now_str(),
|
||||
changed_by
|
||||
))
|
||||
|
||||
@@ -360,11 +423,11 @@ class ImageTaskManager:
|
||||
conn = self._get_conn()
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('''
|
||||
cursor.execute(_sql('''
|
||||
SELECT * FROM task_requirement_changes
|
||||
WHERE task_id = ?
|
||||
ORDER BY changed_at DESC
|
||||
''', (task_id,))
|
||||
'''), (task_id,))
|
||||
|
||||
rows = cursor.fetchall()
|
||||
conn.close()
|
||||
@@ -385,13 +448,13 @@ class ImageTaskManager:
|
||||
conn = self._get_conn()
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('''
|
||||
cursor.execute(_sql('''
|
||||
UPDATE image_tasks
|
||||
SET retry_count = retry_count + 1
|
||||
WHERE task_id = ?
|
||||
''', (task_id,))
|
||||
'''), (task_id,))
|
||||
|
||||
cursor.execute('SELECT retry_count FROM image_tasks WHERE task_id = ?', (task_id,))
|
||||
cursor.execute(_sql('SELECT retry_count FROM image_tasks WHERE task_id = ?'), (task_id,))
|
||||
row = cursor.fetchone()
|
||||
conn.close()
|
||||
|
||||
|
||||
@@ -9,8 +9,26 @@ from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, List
|
||||
from pathlib import Path
|
||||
from enum import Enum
|
||||
import os
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_DB_TYPE = os.getenv("DB_TYPE", "sqlite").lower()
|
||||
_MYSQL_HOST = os.getenv("MYSQL_HOST", "127.0.0.1")
|
||||
_MYSQL_PORT = int(os.getenv("MYSQL_PORT", "3306"))
|
||||
_MYSQL_USER = os.getenv("MYSQL_USER", "root")
|
||||
_MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
|
||||
_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs")
|
||||
|
||||
def _is_mysql() -> bool:
|
||||
return _DB_TYPE in ("mysql", "mariadb")
|
||||
|
||||
def _sql(query: str) -> str:
|
||||
return query.replace("?", "%s") if _is_mysql() else query
|
||||
|
||||
def _now_str() -> str:
|
||||
if _is_mysql():
|
||||
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
return datetime.now().isoformat()
|
||||
|
||||
class TaskStatus(Enum):
|
||||
"""任务状态"""
|
||||
@@ -40,51 +58,93 @@ class TaskManager:
|
||||
|
||||
def _init_db(self):
|
||||
"""初始化数据库"""
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 创建任务表
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS tasks (
|
||||
task_id TEXT PRIMARY KEY,
|
||||
specified_customer_id TEXT,
|
||||
specified_customer_name TEXT,
|
||||
type TEXT NOT NULL,
|
||||
customer_name TEXT,
|
||||
customer_id TEXT,
|
||||
trigger_type TEXT,
|
||||
trigger_keyword TEXT,
|
||||
trigger_keywords TEXT, -- JSON array
|
||||
action_type TEXT,
|
||||
action_file_url TEXT,
|
||||
action_message TEXT,
|
||||
priority TEXT DEFAULT 'normal',
|
||||
timeout_hours INTEGER DEFAULT 24,
|
||||
status TEXT DEFAULT 'pending',
|
||||
retry_count INTEGER DEFAULT 0,
|
||||
max_retry INTEGER DEFAULT 3,
|
||||
created_at TEXT,
|
||||
created_by TEXT,
|
||||
triggered_at TEXT,
|
||||
completed_at TEXT,
|
||||
error_message TEXT,
|
||||
result TEXT -- JSON
|
||||
)
|
||||
''')
|
||||
|
||||
# 创建索引
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_status ON tasks(status)')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_customer ON tasks(customer_id)')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_created ON tasks(created_at)')
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
if _is_mysql():
|
||||
conn = self._get_conn()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS tasks (
|
||||
task_id VARCHAR(128) PRIMARY KEY,
|
||||
specified_customer_id VARCHAR(128),
|
||||
specified_customer_name VARCHAR(255),
|
||||
type VARCHAR(64) NOT NULL,
|
||||
customer_name VARCHAR(255),
|
||||
customer_id VARCHAR(128),
|
||||
trigger_type VARCHAR(64),
|
||||
trigger_keyword VARCHAR(255),
|
||||
trigger_keywords TEXT,
|
||||
action_type VARCHAR(64),
|
||||
action_file_url TEXT,
|
||||
action_message TEXT,
|
||||
priority VARCHAR(16) DEFAULT 'normal',
|
||||
timeout_hours INT DEFAULT 24,
|
||||
status VARCHAR(32) DEFAULT 'pending',
|
||||
retry_count INT DEFAULT 0,
|
||||
max_retry INT DEFAULT 3,
|
||||
created_at DATETIME,
|
||||
created_by VARCHAR(255),
|
||||
triggered_at DATETIME,
|
||||
completed_at DATETIME,
|
||||
error_message TEXT,
|
||||
result TEXT
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
|
||||
''')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_status ON tasks(status)')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_customer ON tasks(customer_id)')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_created ON tasks(created_at)')
|
||||
conn.commit()
|
||||
conn.close()
|
||||
else:
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS tasks (
|
||||
task_id TEXT PRIMARY KEY,
|
||||
specified_customer_id TEXT,
|
||||
specified_customer_name TEXT,
|
||||
type TEXT NOT NULL,
|
||||
customer_name TEXT,
|
||||
customer_id TEXT,
|
||||
trigger_type TEXT,
|
||||
trigger_keyword TEXT,
|
||||
trigger_keywords TEXT,
|
||||
action_type TEXT,
|
||||
action_file_url TEXT,
|
||||
action_message TEXT,
|
||||
priority TEXT DEFAULT 'normal',
|
||||
timeout_hours INTEGER DEFAULT 24,
|
||||
status TEXT DEFAULT 'pending',
|
||||
retry_count INTEGER DEFAULT 0,
|
||||
max_retry INTEGER DEFAULT 3,
|
||||
created_at TEXT,
|
||||
created_by TEXT,
|
||||
triggered_at TEXT,
|
||||
completed_at TEXT,
|
||||
error_message TEXT,
|
||||
result TEXT
|
||||
)
|
||||
''')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_status ON tasks(status)')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_customer ON tasks(customer_id)')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_created ON tasks(created_at)')
|
||||
conn.commit()
|
||||
conn.close()
|
||||
logger.info("数据库表初始化完成")
|
||||
|
||||
def _get_conn(self):
|
||||
"""获取数据库连接"""
|
||||
if _is_mysql():
|
||||
import pymysql
|
||||
return pymysql.connect(
|
||||
host=_MYSQL_HOST,
|
||||
port=_MYSQL_PORT,
|
||||
user=_MYSQL_USER,
|
||||
password=_MYSQL_PASSWORD,
|
||||
database=_MYSQL_DATABASE,
|
||||
charset="utf8mb4",
|
||||
cursorclass=pymysql.cursors.DictCursor,
|
||||
autocommit=False,
|
||||
)
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
return conn
|
||||
@@ -100,7 +160,7 @@ class TaskManager:
|
||||
if isinstance(trigger_keywords, list):
|
||||
trigger_keywords = json.dumps(trigger_keywords)
|
||||
|
||||
cursor.execute('''
|
||||
insert_sql = '''
|
||||
INSERT OR REPLACE INTO tasks (
|
||||
task_id, specified_customer_id, specified_customer_name,
|
||||
type, customer_name, customer_id,
|
||||
@@ -109,7 +169,19 @@ class TaskManager:
|
||||
priority, timeout_hours, status, retry_count, max_retry,
|
||||
created_at, created_by
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
''', (
|
||||
'''
|
||||
if _is_mysql():
|
||||
insert_sql = '''
|
||||
REPLACE INTO tasks (
|
||||
task_id, specified_customer_id, specified_customer_name,
|
||||
type, customer_name, customer_id,
|
||||
trigger_type, trigger_keyword, trigger_keywords,
|
||||
action_type, action_file_url, action_message,
|
||||
priority, timeout_hours, status, retry_count, max_retry,
|
||||
created_at, created_by
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
'''
|
||||
cursor.execute(_sql(insert_sql), (
|
||||
task.get('task_id'),
|
||||
task.get('customer', {}).get('id'),
|
||||
task.get('customer', {}).get('name'),
|
||||
@@ -127,7 +199,7 @@ class TaskManager:
|
||||
task.get('status', 'pending'),
|
||||
task.get('retry_count', 0),
|
||||
task.get('max_retry', 3),
|
||||
task.get('created_at', datetime.now().isoformat()),
|
||||
task.get('created_at', _now_str()),
|
||||
task.get('created_by')
|
||||
))
|
||||
|
||||
@@ -147,7 +219,7 @@ class TaskManager:
|
||||
conn = self._get_conn()
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('SELECT * FROM tasks WHERE task_id = ?', (task_id,))
|
||||
cursor.execute(_sql('SELECT * FROM tasks WHERE task_id = ?'), (task_id,))
|
||||
row = cursor.fetchone()
|
||||
conn.close()
|
||||
|
||||
@@ -165,32 +237,33 @@ class TaskManager:
|
||||
conn = self._get_conn()
|
||||
cursor = conn.cursor()
|
||||
|
||||
updates = ['status = ?']
|
||||
placeholder = "%s" if _is_mysql() else "?"
|
||||
updates = [f'status = {placeholder}']
|
||||
params = [status.value]
|
||||
|
||||
if status == TaskStatus.RUNNING:
|
||||
updates.append('triggered_at = ?')
|
||||
params.append(datetime.now().isoformat())
|
||||
updates.append(f'triggered_at = {placeholder}')
|
||||
params.append(_now_str())
|
||||
|
||||
if status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED]:
|
||||
updates.append('completed_at = ?')
|
||||
params.append(datetime.now().isoformat())
|
||||
updates.append(f'completed_at = {placeholder}')
|
||||
params.append(_now_str())
|
||||
|
||||
if error_message:
|
||||
updates.append('error_message = ?')
|
||||
updates.append(f'error_message = {placeholder}')
|
||||
params.append(error_message)
|
||||
|
||||
if result:
|
||||
updates.append('result = ?')
|
||||
updates.append(f'result = {placeholder}')
|
||||
params.append(json.dumps(result))
|
||||
|
||||
params.append(task_id)
|
||||
|
||||
cursor.execute(f'''
|
||||
cursor.execute(_sql(f'''
|
||||
UPDATE tasks
|
||||
SET {', '.join(updates)}
|
||||
WHERE task_id = ?
|
||||
''', params)
|
||||
'''), params)
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
@@ -206,13 +279,13 @@ class TaskManager:
|
||||
conn = self._get_conn()
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('''
|
||||
cursor.execute(_sql('''
|
||||
UPDATE tasks
|
||||
SET retry_count = retry_count + 1
|
||||
WHERE task_id = ?
|
||||
''', (task_id,))
|
||||
'''), (task_id,))
|
||||
|
||||
cursor.execute('SELECT retry_count, max_retry FROM tasks WHERE task_id = ?', (task_id,))
|
||||
cursor.execute(_sql('SELECT retry_count, max_retry FROM tasks WHERE task_id = ?'), (task_id,))
|
||||
row = cursor.fetchone()
|
||||
conn.close()
|
||||
|
||||
@@ -231,7 +304,7 @@ class TaskManager:
|
||||
cursor = conn.cursor()
|
||||
|
||||
if customer_id:
|
||||
cursor.execute('''
|
||||
cursor.execute(_sql('''
|
||||
SELECT * FROM tasks
|
||||
WHERE status = 'pending'
|
||||
AND customer_id = ?
|
||||
@@ -242,7 +315,7 @@ class TaskManager:
|
||||
ELSE 3
|
||||
END,
|
||||
created_at
|
||||
''', (customer_id,))
|
||||
'''), (customer_id,))
|
||||
else:
|
||||
cursor.execute('''
|
||||
SELECT * FROM tasks
|
||||
@@ -271,11 +344,18 @@ class TaskManager:
|
||||
conn = self._get_conn()
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('''
|
||||
SELECT * FROM tasks
|
||||
WHERE status = 'pending'
|
||||
AND datetime(created_at, '+' || timeout_hours || ' hours') < datetime('now')
|
||||
''')
|
||||
if _is_mysql():
|
||||
cursor.execute('''
|
||||
SELECT * FROM tasks
|
||||
WHERE status = 'pending'
|
||||
AND created_at < DATE_SUB(NOW(), INTERVAL timeout_hours HOUR)
|
||||
''')
|
||||
else:
|
||||
cursor.execute('''
|
||||
SELECT * FROM tasks
|
||||
WHERE status = 'pending'
|
||||
AND datetime(created_at, '+' || timeout_hours || ' hours') < datetime('now')
|
||||
''')
|
||||
|
||||
rows = cursor.fetchall()
|
||||
conn.close()
|
||||
@@ -299,7 +379,7 @@ class TaskManager:
|
||||
|
||||
stats = {}
|
||||
for status in TaskStatus:
|
||||
cursor.execute('SELECT COUNT(*) FROM tasks WHERE status = ?', (status.value,))
|
||||
cursor.execute(_sql('SELECT COUNT(*) FROM tasks WHERE status = ?'), (status.value,))
|
||||
stats[status.value] = cursor.fetchone()[0]
|
||||
|
||||
conn.close()
|
||||
|
||||
Reference in New Issue
Block a user