feat: 添加 AI Agent 对话测试工具 + 代码优化
主要变更: - 新增 tests/test_ai_chat.py: AI Agent 对话测试工具 - 优化 core/pydantic_ai_agent.py 和 db/chat_log_db.py - 清理归档文件,更新文档 Made-with: Cursor
This commit is contained in:
@@ -9,9 +9,33 @@ from datetime import datetime
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
_DB_PATH = os.path.join(os.path.dirname(__file__), "chat_log_db", "chats.db")
|
||||
_DB_TYPE = os.getenv("DB_TYPE", "sqlite").lower()
|
||||
_MYSQL_HOST = os.getenv("MYSQL_HOST", "127.0.0.1")
|
||||
_MYSQL_PORT = int(os.getenv("MYSQL_PORT", "3306"))
|
||||
_MYSQL_USER = os.getenv("MYSQL_USER", "root")
|
||||
_MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
|
||||
_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs")
|
||||
|
||||
def _is_mysql() -> bool:
|
||||
return _DB_TYPE in ("mysql", "mariadb")
|
||||
|
||||
def _sql(query: str) -> str:
|
||||
return query.replace("?", "%s") if _is_mysql() else query
|
||||
|
||||
|
||||
def _get_conn() -> sqlite3.Connection:
|
||||
if _is_mysql():
|
||||
import pymysql
|
||||
return pymysql.connect(
|
||||
host=_MYSQL_HOST,
|
||||
port=_MYSQL_PORT,
|
||||
user=_MYSQL_USER,
|
||||
password=_MYSQL_PASSWORD,
|
||||
database=_MYSQL_DATABASE,
|
||||
charset="utf8mb4",
|
||||
cursorclass=pymysql.cursors.DictCursor,
|
||||
autocommit=False,
|
||||
)
|
||||
os.makedirs(os.path.dirname(_DB_PATH), exist_ok=True)
|
||||
conn = sqlite3.connect(_DB_PATH)
|
||||
conn.row_factory = sqlite3.Row
|
||||
@@ -21,27 +45,44 @@ def _get_conn() -> sqlite3.Connection:
|
||||
def init_db():
|
||||
"""建表(首次运行时自动调用)"""
|
||||
with _get_conn() as conn:
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS chat_logs (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
customer_id TEXT NOT NULL,
|
||||
customer_name TEXT DEFAULT '',
|
||||
acc_id TEXT DEFAULT '',
|
||||
platform TEXT DEFAULT '',
|
||||
direction TEXT NOT NULL CHECK(direction IN ('in','out')),
|
||||
message TEXT NOT NULL,
|
||||
msg_type INTEGER DEFAULT 0,
|
||||
timestamp TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_customer ON chat_logs(customer_id)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_ts ON chat_logs(timestamp)")
|
||||
# 兼容旧表:若缺少 acc_id 列则补上(必须在创建该列索引之前)
|
||||
try:
|
||||
conn.execute("ALTER TABLE chat_logs ADD COLUMN acc_id TEXT DEFAULT ''")
|
||||
except Exception:
|
||||
pass
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_acc ON chat_logs(acc_id)")
|
||||
if _is_mysql():
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS chat_logs (
|
||||
id INTEGER PRIMARY KEY AUTO_INCREMENT,
|
||||
customer_id VARCHAR(128) NOT NULL,
|
||||
customer_name VARCHAR(255) DEFAULT '',
|
||||
acc_id VARCHAR(128) DEFAULT '',
|
||||
platform VARCHAR(64) DEFAULT '',
|
||||
direction VARCHAR(8) NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
msg_type INTEGER DEFAULT 0,
|
||||
timestamp DATETIME NOT NULL
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
|
||||
""")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_customer ON chat_logs(customer_id)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_ts ON chat_logs(timestamp)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_acc ON chat_logs(acc_id)")
|
||||
else:
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS chat_logs (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
customer_id TEXT NOT NULL,
|
||||
customer_name TEXT DEFAULT '',
|
||||
acc_id TEXT DEFAULT '',
|
||||
platform TEXT DEFAULT '',
|
||||
direction TEXT NOT NULL CHECK(direction IN ('in','out')),
|
||||
message TEXT NOT NULL,
|
||||
msg_type INTEGER DEFAULT 0,
|
||||
timestamp TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_customer ON chat_logs(customer_id)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_ts ON chat_logs(timestamp)")
|
||||
try:
|
||||
conn.execute("ALTER TABLE chat_logs ADD COLUMN acc_id TEXT DEFAULT ''")
|
||||
except Exception:
|
||||
pass
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_acc ON chat_logs(acc_id)")
|
||||
conn.commit()
|
||||
|
||||
|
||||
@@ -63,9 +104,9 @@ def log_message(
|
||||
ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
with _get_conn() as conn:
|
||||
conn.execute(
|
||||
"INSERT INTO chat_logs "
|
||||
"(customer_id, customer_name, acc_id, platform, direction, message, msg_type, timestamp) "
|
||||
"VALUES (?,?,?,?,?,?,?,?)",
|
||||
_sql("INSERT INTO chat_logs "
|
||||
"(customer_id, customer_name, acc_id, platform, direction, message, msg_type, timestamp) "
|
||||
"VALUES (?,?,?,?,?,?,?,?)"),
|
||||
(customer_id, customer_name, acc_id, platform, direction, message, msg_type, ts),
|
||||
)
|
||||
conn.commit()
|
||||
@@ -77,6 +118,19 @@ def get_customers(limit: int = 100) -> List[Dict]:
|
||||
"""返回所有有记录的客户列表(按最新消息时间排序)"""
|
||||
with _get_conn() as conn:
|
||||
rows = conn.execute("""
|
||||
SELECT
|
||||
customer_id,
|
||||
MAX(customer_name) AS customer_name,
|
||||
MAX(platform) AS platform,
|
||||
COUNT(*) AS total_msgs,
|
||||
SUM(direction='in') AS recv,
|
||||
SUM(direction='out') AS sent,
|
||||
MAX(timestamp) AS last_time
|
||||
FROM chat_logs
|
||||
GROUP BY customer_id
|
||||
ORDER BY last_time DESC
|
||||
LIMIT %s
|
||||
""" if _is_mysql() else """
|
||||
SELECT
|
||||
customer_id,
|
||||
MAX(customer_name) AS customer_name,
|
||||
@@ -96,13 +150,13 @@ def get_customers(limit: int = 100) -> List[Dict]:
|
||||
def get_conversation(customer_id: str, limit: int = 200) -> List[Dict]:
|
||||
"""返回某客户的全部对话记录(按时间升序)"""
|
||||
with _get_conn() as conn:
|
||||
rows = conn.execute("""
|
||||
rows = conn.execute(_sql("""
|
||||
SELECT id, direction, message, msg_type, timestamp, acc_id
|
||||
FROM chat_logs
|
||||
WHERE customer_id = ?
|
||||
ORDER BY timestamp ASC, id ASC
|
||||
LIMIT ?
|
||||
""", (customer_id, limit)).fetchall()
|
||||
"""), (customer_id, limit)).fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
|
||||
@@ -110,21 +164,21 @@ def get_recent_conversation(customer_id: str, acc_id: str = "", limit: int = 10)
|
||||
"""返回某客户近期对话(同店铺),用于企微推送保持连贯"""
|
||||
with _get_conn() as conn:
|
||||
if acc_id:
|
||||
rows = conn.execute("""
|
||||
rows = conn.execute(_sql("""
|
||||
SELECT id, direction, message, timestamp, acc_id
|
||||
FROM chat_logs
|
||||
WHERE customer_id = ? AND acc_id = ?
|
||||
ORDER BY id DESC
|
||||
LIMIT ?
|
||||
""", (customer_id, acc_id, limit)).fetchall()
|
||||
"""), (customer_id, acc_id, limit)).fetchall()
|
||||
else:
|
||||
rows = conn.execute("""
|
||||
rows = conn.execute(_sql("""
|
||||
SELECT id, direction, message, timestamp, acc_id
|
||||
FROM chat_logs
|
||||
WHERE customer_id = ?
|
||||
ORDER BY id DESC
|
||||
LIMIT ?
|
||||
""", (customer_id, limit)).fetchall()
|
||||
"""), (customer_id, limit)).fetchall()
|
||||
out = [dict(r) for r in reversed(rows)]
|
||||
return out
|
||||
|
||||
@@ -133,12 +187,12 @@ def get_conversation_today(customer_id: str) -> List[Dict]:
|
||||
"""返回某客户今天的对话"""
|
||||
today = datetime.now().strftime("%Y-%m-%d")
|
||||
with _get_conn() as conn:
|
||||
rows = conn.execute("""
|
||||
rows = conn.execute(_sql("""
|
||||
SELECT id, direction, message, msg_type, timestamp
|
||||
FROM chat_logs
|
||||
WHERE customer_id = ? AND timestamp LIKE ?
|
||||
ORDER BY timestamp ASC, id ASC
|
||||
""", (customer_id, f"{today}%")).fetchall()
|
||||
"""), (customer_id, f"{today}%")).fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
|
||||
@@ -151,7 +205,7 @@ def get_daily_stats(date: str = "") -> List[Dict]:
|
||||
if not date:
|
||||
date = datetime.now().strftime("%Y-%m-%d")
|
||||
with _get_conn() as conn:
|
||||
rows = conn.execute("""
|
||||
rows = conn.execute(_sql("""
|
||||
SELECT
|
||||
acc_id,
|
||||
platform,
|
||||
@@ -165,7 +219,7 @@ def get_daily_stats(date: str = "") -> List[Dict]:
|
||||
WHERE timestamp LIKE ?
|
||||
GROUP BY acc_id
|
||||
ORDER BY unique_customers DESC
|
||||
""", (f"{date}%",)).fetchall()
|
||||
"""), (f"{date}%",)).fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
|
||||
@@ -176,22 +230,39 @@ def get_daily_conversations(date: str = "") -> List[Dict]:
|
||||
if not date:
|
||||
date = datetime.now().strftime("%Y-%m-%d")
|
||||
with _get_conn() as conn:
|
||||
rows = conn.execute("""
|
||||
SELECT
|
||||
acc_id,
|
||||
customer_id,
|
||||
MAX(customer_name) AS customer_name,
|
||||
COUNT(*) AS msg_count,
|
||||
GROUP_CONCAT(
|
||||
CASE WHEN direction='in' THEN '买:' || SUBSTR(message,1,40)
|
||||
ELSE '客:' || SUBSTR(message,1,40) END,
|
||||
' | '
|
||||
) AS snippet
|
||||
FROM chat_logs
|
||||
WHERE timestamp LIKE ?
|
||||
GROUP BY acc_id, customer_id
|
||||
ORDER BY acc_id, MAX(timestamp) DESC
|
||||
""", (f"{date}%",)).fetchall()
|
||||
if _is_mysql():
|
||||
rows = conn.execute(_sql("""
|
||||
SELECT
|
||||
acc_id,
|
||||
customer_id,
|
||||
MAX(customer_name) AS customer_name,
|
||||
COUNT(*) AS msg_count,
|
||||
GROUP_CONCAT(
|
||||
CONCAT(CASE WHEN direction='in' THEN '买:' ELSE '客:' END, LEFT(message,40))
|
||||
SEPARATOR ' | '
|
||||
) AS snippet
|
||||
FROM chat_logs
|
||||
WHERE timestamp LIKE ?
|
||||
GROUP BY acc_id, customer_id
|
||||
ORDER BY acc_id, MAX(timestamp) DESC
|
||||
"""), (f"{date}%",)).fetchall()
|
||||
else:
|
||||
rows = conn.execute(_sql("""
|
||||
SELECT
|
||||
acc_id,
|
||||
customer_id,
|
||||
MAX(customer_name) AS customer_name,
|
||||
COUNT(*) AS msg_count,
|
||||
GROUP_CONCAT(
|
||||
CASE WHEN direction='in' THEN '买:' || SUBSTR(message,1,40)
|
||||
ELSE '客:' || SUBSTR(message,1,40) END,
|
||||
' | '
|
||||
) AS snippet
|
||||
FROM chat_logs
|
||||
WHERE timestamp LIKE ?
|
||||
GROUP BY acc_id, customer_id
|
||||
ORDER BY acc_id, MAX(timestamp) DESC
|
||||
"""), (f"{date}%",)).fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
|
||||
@@ -199,18 +270,28 @@ def search_messages(keyword: str, customer_id: Optional[str] = None, limit: int
|
||||
"""全文搜索消息"""
|
||||
if customer_id:
|
||||
with _get_conn() as conn:
|
||||
rows = conn.execute("""
|
||||
rows = conn.execute(_sql("""
|
||||
SELECT customer_id, customer_name, direction, message, timestamp
|
||||
FROM chat_logs
|
||||
WHERE customer_id = ? AND message LIKE ?
|
||||
ORDER BY timestamp DESC LIMIT ?
|
||||
""", (customer_id, f"%{keyword}%", limit)).fetchall()
|
||||
"""), (customer_id, f"%{keyword}%", limit)).fetchall()
|
||||
else:
|
||||
with _get_conn() as conn:
|
||||
rows = conn.execute("""
|
||||
rows = conn.execute(_sql("""
|
||||
SELECT customer_id, customer_name, direction, message, timestamp
|
||||
FROM chat_logs
|
||||
WHERE message LIKE ?
|
||||
ORDER BY timestamp DESC LIMIT ?
|
||||
""", (f"%{keyword}%", limit)).fetchall()
|
||||
"""), (f"%{keyword}%", limit)).fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
|
||||
def get_latest_messages(limit: int = 20) -> List[Dict]:
|
||||
with _get_conn() as conn:
|
||||
rows = conn.execute(_sql("""
|
||||
SELECT id, customer_id, customer_name, direction, message, timestamp
|
||||
FROM chat_logs
|
||||
ORDER BY id DESC LIMIT ?
|
||||
"""), (limit,)).fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
Reference in New Issue
Block a user