Files
tw/db/chat_log_db.py

633 lines
23 KiB
Python
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
聊天记录数据库(SQLite / MySQL)。
每条消息独立存储,按客户 ID 分开,支持查询和展示。
支持 MySQL 连接池以提高性能。
"""
import sqlite3
import os
import threading
from queue import Queue, Empty
from datetime import datetime, timedelta
from typing import List, Dict, Optional
# SQLite database file, used whenever DB_TYPE does not select MySQL/MariaDB.
_DB_PATH = os.path.join(os.path.dirname(__file__), "chat_log_db", "chats.db")
# Backend selector (case-insensitive): "mysql" / "mariadb" -> MySQL, anything else -> SQLite.
_DB_TYPE = os.getenv("DB_TYPE", "sqlite").lower()
# MySQL connection settings, all overridable through environment variables.
_MYSQL_HOST = os.getenv("MYSQL_HOST", "127.0.0.1")
_MYSQL_PORT = int(os.getenv("MYSQL_PORT", "3306"))
_MYSQL_USER = os.getenv("MYSQL_USER", "root")
_MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs")
# ========== MySQL connection pool ==========
# Upper bound on live MySQL connections (idle in the pool + checked out).
# Clamped to >= 1: with 0 the Queue would be unbounded yet no connection could
# ever be created, so every request would simply time out.
_POOL_SIZE = max(1, int(os.getenv("MYSQL_POOL_SIZE", "10")))
# Seconds a caller waits for a free connection; Queue.get() rejects negative timeouts.
_POOL_WAIT_TIMEOUT = max(0.0, float(os.getenv("MYSQL_POOL_WAIT_TIMEOUT", "10")))
_mysql_pool: Optional[Queue] = None  # created lazily by _init_mysql_pool()
_pool_lock = threading.Lock()  # guards pool creation and _mysql_conn_count
_mysql_conn_count = 0  # connections currently open (pooled + checked out)
def _create_mysql_conn():
    """Open one brand-new pymysql connection from the module's MYSQL_* settings.

    Rows are returned as dicts (DictCursor) and autocommit is disabled, so
    callers must commit explicitly. Connecting times out after 10 s; each read
    or write times out after 30 s.
    """
    import pymysql

    settings = dict(
        host=_MYSQL_HOST,
        port=_MYSQL_PORT,
        user=_MYSQL_USER,
        password=_MYSQL_PASSWORD,
        database=_MYSQL_DATABASE,
        charset="utf8mb4",
        cursorclass=pymysql.cursors.DictCursor,
        autocommit=False,
        connect_timeout=10,
        read_timeout=30,
        write_timeout=30,
    )
    return pymysql.connect(**settings)
def _init_mysql_pool():
    """Create the (initially empty) MySQL pool queue exactly once.

    Connections are not pre-opened here; they are created on demand elsewhere.
    The check runs under ``_pool_lock`` so concurrent first callers cannot
    create two queues.
    """
    global _mysql_pool
    with _pool_lock:
        if _mysql_pool is not None:
            return
        _mysql_pool = Queue(maxsize=_POOL_SIZE)
def _discard_conn(conn):
    """Close a connection that will not go back to the pool and free its slot.

    Closing is best-effort (the connection is usually already broken). The
    open-connection counter is decremented but never drops below zero.
    """
    global _mysql_conn_count
    try:
        conn.close()
    except Exception:
        pass  # already unusable; nothing more to do
    with _pool_lock:
        if _mysql_conn_count <= 0:
            return
        _mysql_conn_count -= 1
def _acquire_raw_conn(timeout: float):
    """Hand out one MySQL connection without health-checking it.

    Returns ``(conn, is_new)``. Preference order:
    1. an idle connection already sitting in the pool;
    2. a freshly created one, if the total stays within ``_POOL_SIZE``;
    3. otherwise block up to ``timeout`` seconds for another thread to return one.

    Raises:
        TimeoutError: no connection became available within ``timeout``.
        Exception: whatever ``_create_mysql_conn`` raises (the reserved slot is released first).
    """
    global _mysql_conn_count
    # 1) Reuse before growing: an idle connection costs nothing to hand out.
    try:
        return _mysql_pool.get_nowait(), False
    except Empty:
        pass
    # 2) Reserve a slot under the lock but connect outside it, so a slow or
    #    unreachable server does not stall every other thread on _pool_lock.
    with _pool_lock:
        reserved = _mysql_conn_count < _POOL_SIZE
        if reserved:
            _mysql_conn_count += 1
    if reserved:
        try:
            return _create_mysql_conn(), True
        except Exception:
            with _pool_lock:
                if _mysql_conn_count > 0:
                    _mysql_conn_count -= 1
            raise
    # 3) At capacity: wait for a connection to be handed back.
    try:
        return _mysql_pool.get(timeout=timeout), False
    except Empty:
        raise TimeoutError(f"MySQL连接池已耗尽pool_size={_POOL_SIZE}, wait_timeout={timeout}s") from None


def _get_pooled_conn(timeout: float = 5.0):
    """Borrow a live MySQL connection; never grows beyond ``_POOL_SIZE``.

    Reused connections are pinged with ``reconnect=True``. If that ping still
    fails, the connection is discarded, which frees its slot, and one replacement
    is acquired. Previously a failure on the replacement leaked the slot for good.

    Args:
        timeout: seconds to wait for a connection when the pool is at capacity.

    Raises:
        TimeoutError: the pool stayed exhausted for ``timeout`` seconds.
        Exception: connection-creation or ping errors from the driver.
    """
    if _mysql_pool is None:
        _init_mysql_pool()
    for attempt in range(2):
        conn, is_new = _acquire_raw_conn(timeout)
        if is_new:
            return conn  # just connected; no health check needed
        try:
            conn.ping(reconnect=True)
            return conn
        except Exception:
            # Dead connection: close it and release its slot before retrying,
            # so a failed ping can never permanently shrink the pool.
            _discard_conn(conn)
            if attempt == 1:
                raise
def _return_conn(conn):
    """Give a checked-out MySQL connection back to the pool.

    Any transaction still open on the connection is rolled back first.
    Connections are created with ``autocommit=False``, and read-only callers
    never commit. Without the rollback, a pooled connection would keep its
    InnoDB read snapshot (REPEATABLE READ is the server default), and the next
    borrower would not see rows committed in the meantime.

    The ROLLBACK round trip also serves as the liveness check. A connection that
    cannot execute it, or that does not fit in the pool, is discarded.
    """
    if _mysql_pool is None:
        # Nothing can take the connection back; close it rather than leak it.
        _discard_conn(conn)
        return
    try:
        conn.rollback()
        _mysql_pool.put_nowait(conn)
    except Exception:
        _discard_conn(conn)
class _CompatResult:
    """Fully materialised result of one statement, mimicking the sqlite3 cursor API used here."""

    def __init__(self, rows=None, rowcount: int = 0, lastrowid: int = 0):
        self._rows = rows or []
        self.rowcount = rowcount
        self.lastrowid = lastrowid

    def __iter__(self):
        # sqlite3 cursors are iterable; keep parity so code works on both backends.
        return iter(self._rows)

    def fetchall(self):
        return self._rows

    def fetchone(self):
        return self._rows[0] if self._rows else None


class _PyMySQLCompatConn:
    """Wrap a pymysql connection so it can be used like a sqlite3 connection.

    Supports ``conn.execute(query, args)`` returning a ``_CompatResult``, plus
    ``commit()``, ``close()`` and the ``with`` protocol. Leaving the ``with``
    block rolls back if an exception escaped; it does NOT commit, so callers
    commit explicitly. The underlying connection is then returned to the pool
    (``use_pool=True``) or closed. Releasing is idempotent: calling ``close()``
    inside a ``with`` block no longer puts the same connection into the pool twice.
    """

    def __init__(self, conn, use_pool: bool = True):
        self._conn = conn
        self._use_pool = use_pool
        self._released = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        if exc_type:
            try:
                self._conn.rollback()
            except Exception:
                pass  # connection may already be dead; _release() deals with it
        self._release()

    def _release(self):
        """Return the connection to the pool (or close it) exactly once."""
        if self._released:
            return
        self._released = True
        if self._use_pool:
            _return_conn(self._conn)
        else:
            self._conn.close()

    def execute(self, query: str, args=None):
        """Run one statement and return its rows, rowcount and lastrowid.

        The cursor is always closed, even when the statement raises.
        """
        cur = self._conn.cursor()
        try:
            cur.execute(query, args or ())
            rows = cur.fetchall() if cur.description else []
            return _CompatResult(rows=rows, rowcount=cur.rowcount, lastrowid=getattr(cur, "lastrowid", 0))
        finally:
            cur.close()

    def commit(self):
        self._conn.commit()

    def close(self):
        self._release()
def _is_mysql() -> bool:
    """True when DB_TYPE selects a MySQL-protocol server (MySQL or MariaDB)."""
    return _DB_TYPE == "mysql" or _DB_TYPE == "mariadb"


def _sql(query: str) -> str:
    """Adapt qmark (``?``) placeholders to pymysql's ``%s`` style on MySQL.

    Every ``?`` is rewritten, including one inside a string literal, so queries
    must not contain literal question marks. On SQLite the query is returned
    unchanged.
    """
    if not _is_mysql():
        return query
    return "%s".join(query.split("?"))
class _ClosingSQLiteConnection(sqlite3.Connection):
    """sqlite3 connection whose ``with`` block also closes it.

    The stock ``sqlite3.Connection`` context manager only commits (or rolls back
    on error) and leaves the connection open. Every caller in this module opens
    a fresh connection per ``with _get_conn() as conn:`` block, so closing on
    exit releases the database file handle immediately instead of whenever the
    object is garbage-collected (Python 3.13+ also warns about the latter).
    """

    def __exit__(self, exc_type, exc_value, traceback):
        try:
            return super().__exit__(exc_type, exc_value, traceback)
        finally:
            self.close()


def _get_conn(max_retries: int = 3, retry_delay: float = 0.5) -> sqlite3.Connection:
    """Return a database connection meant to be used as ``with _get_conn() as conn:``.

    MySQL/MariaDB:
        Returns a pooled pymysql connection wrapped in ``_PyMySQLCompatConn``
        (sqlite-style ``execute``). Leaving the ``with`` block hands it back to
        the pool. Acquisition is attempted ``max_retries`` times (at least once),
        sleeping ``retry_delay * attempt`` seconds between failures. The last
        error is re-raised.

    SQLite:
        Opens a new connection to ``_DB_PATH`` (its directory is created on
        demand) that yields ``sqlite3.Row`` rows. Leaving the ``with`` block
        commits on success, rolls back on error, and then closes the connection.
    """
    if _is_mysql():
        import time

        attempts = max(1, max_retries)
        for attempt in range(1, attempts + 1):
            try:
                pooled = _get_pooled_conn(timeout=_POOL_WAIT_TIMEOUT)
            except Exception:
                if attempt == attempts:
                    raise
                time.sleep(retry_delay * attempt)
            else:
                return _PyMySQLCompatConn(pooled, use_pool=True)

    os.makedirs(os.path.dirname(_DB_PATH), exist_ok=True)
    conn = sqlite3.connect(_DB_PATH, factory=_ClosingSQLiteConnection)
    conn.row_factory = sqlite3.Row
    return conn
def _sqlite_column_names(conn, table: str) -> set:
    """Return the column names of an SQLite table (via ``PRAGMA table_info``)."""
    # table_info rows are (cid, name, type, notnull, dflt_value, pk).
    return {row[1] for row in conn.execute(f"PRAGMA table_info({table})").fetchall()}


def _mysql_column_names(conn, table: str) -> set:
    """Return the column names of a MySQL table (``Field`` of ``SHOW COLUMNS``)."""
    rows = conn.execute(f"SHOW COLUMNS FROM {table}").fetchall()
    return {str(r.get("Field", "")) for r in rows}


def _mysql_index_names(conn, table: str) -> set:
    """Return the index names of a MySQL table (``Key_name`` of ``SHOW INDEX``)."""
    rows = conn.execute(f"SHOW INDEX FROM {table}").fetchall()
    return {str(r.get("Key_name", "")) for r in rows}


def _init_mysql_schema(conn) -> None:
    """Create or migrate chat_logs and customer_orders on MySQL/MariaDB (idempotent)."""
    conn.execute("""
        CREATE TABLE IF NOT EXISTS chat_logs (
            id INTEGER PRIMARY KEY AUTO_INCREMENT,
            customer_id VARCHAR(128) NOT NULL,
            customer_name VARCHAR(255) DEFAULT '',
            acc_id VARCHAR(128) DEFAULT '',
            platform VARCHAR(64) DEFAULT '',
            direction VARCHAR(8) NOT NULL,
            message TEXT NOT NULL,
            msg_type INTEGER DEFAULT 0,
            timestamp DATETIME NOT NULL
        ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
    """)
    # MySQL has no CREATE INDEX IF NOT EXISTS, so check existing index names first.
    existing = _mysql_index_names(conn, "chat_logs")
    for index_name, column in (
        ("idx_customer", "customer_id"),
        ("idx_ts", "timestamp"),
        ("idx_acc", "acc_id"),
    ):
        if index_name not in existing:
            conn.execute(f"CREATE INDEX {index_name} ON chat_logs({column})")
    if "image_urls" not in _mysql_column_names(conn, "chat_logs"):
        # Plain nullable TEXT: MySQL rejects a literal DEFAULT on TEXT columns,
        # so the former "TEXT DEFAULT ''" ALTER always failed and was silently
        # swallowed, leaving the column missing. Writers always supply a value.
        conn.execute("ALTER TABLE chat_logs ADD COLUMN image_urls TEXT")

    conn.execute("""
        CREATE TABLE IF NOT EXISTS customer_orders (
            id INTEGER PRIMARY KEY AUTO_INCREMENT,
            customer_id VARCHAR(128) NOT NULL,
            acc_id VARCHAR(128) DEFAULT '',
            order_id VARCHAR(64) NOT NULL,
            order_status VARCHAR(64) DEFAULT '',
            product_title VARCHAR(512) DEFAULT '',
            amount DECIMAL(10,2) DEFAULT 0,
            quantity INTEGER DEFAULT 0,
            buyer_note TEXT,
            created_at DATETIME NOT NULL,
            updated_at DATETIME NOT NULL
        ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
    """)
    existing = _mysql_index_names(conn, "customer_orders")
    if "idx_co_customer" not in existing:
        conn.execute("CREATE INDEX idx_co_customer ON customer_orders(customer_id)")
    if "idx_co_order" not in existing:
        conn.execute("CREATE UNIQUE INDEX idx_co_order ON customer_orders(order_id, order_status)")


def _init_sqlite_schema(conn) -> None:
    """Create or migrate chat_logs and customer_orders on SQLite (idempotent)."""
    conn.execute("""
        CREATE TABLE IF NOT EXISTS chat_logs (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            customer_id TEXT NOT NULL,
            customer_name TEXT DEFAULT '',
            acc_id TEXT DEFAULT '',
            platform TEXT DEFAULT '',
            direction TEXT NOT NULL CHECK(direction IN ('in','out')),
            message TEXT NOT NULL,
            msg_type INTEGER DEFAULT 0,
            timestamp TEXT NOT NULL
        )
    """)
    conn.execute("CREATE INDEX IF NOT EXISTS idx_customer ON chat_logs(customer_id)")
    conn.execute("CREATE INDEX IF NOT EXISTS idx_ts ON chat_logs(timestamp)")
    # Columns missing from older databases: add them only when absent, so real
    # ALTER failures are reported instead of being swallowed.
    columns = _sqlite_column_names(conn, "chat_logs")
    if "acc_id" not in columns:
        conn.execute("ALTER TABLE chat_logs ADD COLUMN acc_id TEXT DEFAULT ''")
    if "image_urls" not in columns:
        conn.execute("ALTER TABLE chat_logs ADD COLUMN image_urls TEXT DEFAULT ''")
    conn.execute("CREATE INDEX IF NOT EXISTS idx_acc ON chat_logs(acc_id)")

    conn.execute("""
        CREATE TABLE IF NOT EXISTS customer_orders (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            customer_id TEXT NOT NULL,
            acc_id TEXT DEFAULT '',
            order_id TEXT NOT NULL,
            order_status TEXT DEFAULT '',
            product_title TEXT DEFAULT '',
            amount REAL DEFAULT 0,
            quantity INTEGER DEFAULT 0,
            buyer_note TEXT DEFAULT '',
            created_at TEXT NOT NULL,
            updated_at TEXT NOT NULL
        )
    """)
    conn.execute("CREATE INDEX IF NOT EXISTS idx_co_customer ON customer_orders(customer_id)")
    conn.execute("CREATE UNIQUE INDEX IF NOT EXISTS idx_co_order ON customer_orders(order_id, order_status)")


def init_db():
    """Create tables, indexes and late-added columns for the configured backend.

    Safe to call repeatedly. The module calls it once at import time.
    """
    with _get_conn() as conn:
        if _is_mysql():
            _init_mysql_schema(conn)
        else:
            _init_sqlite_schema(conn)
        conn.commit()
# Build or migrate the schema as a side effect of importing this module, so the
# query helpers below can assume chat_logs / customer_orders exist. Note that
# this opens a database connection at import time.
init_db()
# ========== 重试装饰器 ==========
def _is_transient_db_error(exc: BaseException) -> bool:
    """Heuristically decide whether ``exc`` is a dropped or refused connection worth retrying.

    Checks the numeric error code (pymysql puts it in ``args[0]``) rather than
    searching the message for "2013"/"2006". A substring search also fired on
    unrelated errors, e.g. a duplicate-entry message containing a 2013 date.
    """
    code = exc.args[0] if exc.args and isinstance(exc.args[0], int) else None
    if code in (2006, 2013):  # MySQL "server has gone away" / "lost connection"
        return True
    message = str(exc).lower()
    return any(marker in message for marker in (
        "lost connection", "gone away", "connection reset",
        "can't connect", "connection refused",
    ))


def _retry_db_operation(func=None, *, max_retries: int = 3, base_delay: float = 0.5):
    """Retry a DB operation on transient connection errors.

    Usable bare (``@_retry_db_operation``) or configured
    (``@_retry_db_operation(max_retries=5, base_delay=0.1)``). The call is made
    up to ``max_retries`` times (at least once), sleeping ``base_delay * n``
    seconds after the n-th failure. Non-transient errors and the final failure
    propagate unchanged.

    NOTE: the whole call is re-run, so a write whose commit reached the server
    but whose acknowledgement was lost may be applied twice.
    """
    import functools
    import time

    attempts = max(1, max_retries)

    def decorate(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for attempt in range(attempts):
                try:
                    return fn(*args, **kwargs)
                except Exception as e:
                    if attempt + 1 >= attempts or not _is_transient_db_error(e):
                        raise
                    time.sleep(base_delay * (attempt + 1))

        return wrapper

    return decorate if func is None else decorate(func)
# ========== 写入 ==========
@_retry_db_operation
def log_message(
    customer_id: str,
    message: str,
    direction: str,  # "in" = sent by the customer, "out" = reply from support
    customer_name: str = "",
    acc_id: str = "",  # shop account ID
    platform: str = "",
    msg_type: int = 0,
    image_urls: str = "",  # image URLs, separated by "\n"
):
    """Append one chat message to chat_logs, stamped with the current local time."""
    stamped_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    statement = _sql(
        "INSERT INTO chat_logs "
        "(customer_id, customer_name, acc_id, platform, direction, message, msg_type, timestamp, image_urls) "
        "VALUES (?,?,?,?,?,?,?,?,?)"
    )
    values = (
        customer_id, customer_name, acc_id, platform, direction,
        message, msg_type, stamped_at, image_urls,
    )
    with _get_conn() as conn:
        conn.execute(statement, values)
        conn.commit()
# ========== 查询 ==========
def get_customers(limit: int = 100) -> List[Dict]:
    """Return up to ``limit`` customers that have chat records, most recently active first.

    Each row contains customer_id, customer_name and platform (MAX over the
    customer's messages), total_msgs, recv (direction 'in'), sent (direction
    'out') and last_time (latest timestamp).
    """
    # One query serves both backends; _sql adapts the placeholder. recv/sent use
    # COUNT(CASE ...) instead of SUM(direction='in'). MySQL returns SUM() of
    # integers as DECIMAL (decimal.Decimal in Python) while SQLite returns int;
    # COUNT yields a plain integer on both.
    with _get_conn() as conn:
        rows = conn.execute(_sql("""
            SELECT
                customer_id,
                MAX(customer_name) AS customer_name,
                MAX(platform) AS platform,
                COUNT(*) AS total_msgs,
                COUNT(CASE WHEN direction = 'in' THEN 1 END) AS recv,
                COUNT(CASE WHEN direction = 'out' THEN 1 END) AS sent,
                MAX(timestamp) AS last_time
            FROM chat_logs
            GROUP BY customer_id
            ORDER BY last_time DESC
            LIMIT ?
        """), (limit,)).fetchall()
    return [dict(r) for r in rows]
@_retry_db_operation
def get_conversation(customer_id: str, limit: int = 200, acc_id: str = "") -> List[Dict]:
    """Return a customer's newest ``limit`` messages, ordered oldest first.

    ``acc_id`` is accepted for API compatibility but deliberately ignored, so the
    history spans every shop account the customer has talked to.
    """
    query = _sql("""
        SELECT * FROM (
            SELECT id, direction, message, msg_type, timestamp, acc_id, image_urls
            FROM chat_logs
            WHERE customer_id = ?
            ORDER BY timestamp DESC, id DESC
            LIMIT ?
        ) AS recent
        ORDER BY timestamp ASC, id ASC
    """)
    with _get_conn() as conn:
        fetched = conn.execute(query, (customer_id, limit)).fetchall()
    return list(map(dict, fetched))
def get_recent_conversation(customer_id: str, acc_id: str = "", limit: int = 10) -> List[Dict]:
    """Return a customer's last ``limit`` messages (by insertion id), oldest first.

    ``acc_id`` is accepted but not used for filtering.
    """
    with _get_conn() as conn:
        newest_first = conn.execute(_sql("""
            SELECT id, direction, message, timestamp, acc_id
            FROM chat_logs
            WHERE customer_id = ?
            ORDER BY id DESC
            LIMIT ?
        """), (customer_id, limit)).fetchall()
    chronological = [dict(row) for row in newest_first]
    chronological.reverse()
    return chronological
def get_conversation_today(customer_id: str) -> List[Dict]:
    """Return today's messages (local date from datetime.now()) for one customer, oldest first.

    Matches rows whose timestamp starts with today's ``YYYY-MM-DD``.
    """
    day_prefix = datetime.now().strftime("%Y-%m-%d") + "%"
    with _get_conn() as conn:
        found = conn.execute(_sql("""
            SELECT id, direction, message, msg_type, timestamp
            FROM chat_logs
            WHERE customer_id = ? AND timestamp LIKE ?
            ORDER BY timestamp ASC, id ASC
        """), (customer_id, day_prefix)).fetchall()
    return list(map(dict, found))
def get_daily_stats(date: str = "") -> List[Dict]:
    """Return per-shop statistics for one day, busiest shop (most customers) first.

    Args:
        date: 'YYYY-MM-DD' (matched as a prefix of timestamp); defaults to today.

    Returns:
        One dict per acc_id (shop) with platform, unique_customers, total_msgs,
        recv, sent, first_msg and last_msg.
    """
    if not date:
        date = datetime.now().strftime("%Y-%m-%d")
    # platform is aggregated with MAX(): a bare non-grouped column is rejected by
    # MySQL's default ONLY_FULL_GROUP_BY mode. recv/sent use COUNT(CASE ...) so
    # both backends return integers (MySQL SUM() of integers yields DECIMAL).
    with _get_conn() as conn:
        rows = conn.execute(_sql("""
            SELECT
                acc_id,
                MAX(platform) AS platform,
                COUNT(DISTINCT customer_id) AS unique_customers,
                COUNT(*) AS total_msgs,
                COUNT(CASE WHEN direction = 'in' THEN 1 END) AS recv,
                COUNT(CASE WHEN direction = 'out' THEN 1 END) AS sent,
                MIN(timestamp) AS first_msg,
                MAX(timestamp) AS last_msg
            FROM chat_logs
            WHERE timestamp LIKE ?
            GROUP BY acc_id
            ORDER BY unique_customers DESC
        """), (f"{date}%",)).fetchall()
    return [dict(r) for r in rows]
def get_daily_conversations(date: str = "") -> List[Dict]:
    """Return one conversation summary per (shop, customer) for one day, for AI summarisation.

    Each row has acc_id, customer_id, customer_name, msg_count and ``snippet``.
    The snippet concatenates every message of that day, truncated to 40
    characters, prefixed with '买:' (direction 'in', from the customer) or '客:'
    (reply from support), and joined with ' | '. There is no per-customer message
    cap. On MySQL the snippet is in message-id order but is also truncated by
    the server's ``group_concat_max_len``; SQLite concatenates in unspecified order.

    Args:
        date: 'YYYY-MM-DD' (matched as a prefix of timestamp); defaults to today.
    """
    if not date:
        date = datetime.now().strftime("%Y-%m-%d")
    with _get_conn() as conn:
        if _is_mysql():
            rows = conn.execute(_sql("""
                SELECT
                    acc_id,
                    customer_id,
                    MAX(customer_name) AS customer_name,
                    COUNT(*) AS msg_count,
                    GROUP_CONCAT(
                        CONCAT(CASE WHEN direction='in' THEN '买:' ELSE '客:' END, LEFT(message,40))
                        ORDER BY id
                        SEPARATOR ' | '
                    ) AS snippet
                FROM chat_logs
                WHERE timestamp LIKE ?
                GROUP BY acc_id, customer_id
                ORDER BY acc_id, MAX(timestamp) DESC
            """), (f"{date}%",)).fetchall()
        else:
            rows = conn.execute(_sql("""
                SELECT
                    acc_id,
                    customer_id,
                    MAX(customer_name) AS customer_name,
                    COUNT(*) AS msg_count,
                    GROUP_CONCAT(
                        CASE WHEN direction='in' THEN '买:' || SUBSTR(message,1,40)
                             ELSE '客:' || SUBSTR(message,1,40) END,
                        ' | '
                    ) AS snippet
                FROM chat_logs
                WHERE timestamp LIKE ?
                GROUP BY acc_id, customer_id
                ORDER BY acc_id, MAX(timestamp) DESC
            """), (f"{date}%",)).fetchall()
    return [dict(r) for r in rows]
def search_messages(keyword: str, customer_id: Optional[str] = None, limit: int = 50) -> List[Dict]:
    """Substring search over message text, newest first, at most ``limit`` rows.

    ``keyword`` is matched literally. Any LIKE wildcards it contains (``%``,
    ``_``) are escaped, so searching for "100%" no longer matches every message
    that merely contains "100". When ``customer_id`` is truthy, the search is
    restricted to that customer.
    """
    # '!' as the LIKE escape character behaves the same on SQLite and MySQL,
    # unlike backslash, which MySQL string literals treat specially.
    literal = str(keyword).replace("!", "!!").replace("%", "!%").replace("_", "!_")
    conditions = ["message LIKE ? ESCAPE '!'"]
    params: list = [f"%{literal}%"]
    if customer_id:
        conditions.insert(0, "customer_id = ?")
        params.insert(0, customer_id)
    params.append(limit)
    query = (
        "SELECT customer_id, customer_name, direction, message, timestamp "
        "FROM chat_logs "
        f"WHERE {' AND '.join(conditions)} "
        "ORDER BY timestamp DESC LIMIT ?"
    )
    with _get_conn() as conn:
        rows = conn.execute(_sql(query), tuple(params)).fetchall()
    return [dict(r) for r in rows]
def get_latest_messages(limit: int = 20) -> List[Dict]:
    """Return the ``limit`` most recently inserted messages across all customers, newest first."""
    query = _sql("""
        SELECT id, customer_id, customer_name, direction, message, timestamp
        FROM chat_logs
        ORDER BY id DESC LIMIT ?
    """)
    with _get_conn() as conn:
        latest = conn.execute(query, (limit,)).fetchall()
    return list(map(dict, latest))
def get_waiting_customer_pool(window_minutes: int = 30) -> Dict:
    """Count sessions still waiting for a reply within the recent time window.

    A session is a (customer_id, acc_id) pair. It counts as waiting when its
    newest message in the last ``window_minutes`` (minimum 1) has direction
    'in'. Rows with an empty or 'unknown' customer, or an empty shop, are
    ignored.

    Returns a dict with total_waiting_customers, ``shops`` (per-acc_id counts,
    highest count first, ties broken by acc_id) and the requested window_minutes.
    """
    since = datetime.now() - timedelta(minutes=max(window_minutes, 1))
    cutoff = since.strftime("%Y-%m-%d %H:%M:%S")
    with _get_conn() as conn:
        rows = conn.execute(_sql("""
            SELECT id, customer_id, acc_id, direction, timestamp
            FROM chat_logs
            WHERE timestamp >= ?
              AND customer_id <> ''
              AND customer_id <> 'unknown'
              AND acc_id <> ''
            ORDER BY id DESC
        """), (cutoff,)).fetchall()

    # Rows arrive newest-first, so the first record seen per session is its latest.
    newest: Dict = {}
    for row in rows:
        record = dict(row)
        session = (str(record.get("customer_id") or ""), str(record.get("acc_id") or ""))
        newest.setdefault(session, record)

    waiting_per_shop: Dict[str, int] = {}
    for record in newest.values():
        shop = str(record.get("acc_id") or "")
        if str(record.get("direction") or "") == "in" and shop:
            waiting_per_shop[shop] = waiting_per_shop.get(shop, 0) + 1

    ranked = sorted(waiting_per_shop.items(), key=lambda kv: (-kv[1], kv[0]))
    return {
        "total_waiting_customers": sum(waiting_per_shop.values()),
        "shops": [{"acc_id": shop, "waiting_customers": count} for shop, count in ranked],
        "window_minutes": window_minutes,
    }
# ========== 订单相关 ==========
@_retry_db_operation
def upsert_order(
    customer_id: str,
    order_id: str,
    order_status: str = "",
    acc_id: str = "",
    product_title: str = "",
    amount: float = 0.0,
    quantity: int = 0,
    buyer_note: str = "",
):
    """Insert or update one order row keyed by (order_id, order_status).

    On both backends an existing row keeps its id and created_at. customer_id,
    acc_id, product_title, amount, quantity, buyer_note and updated_at are
    overwritten with the new values.
    """
    ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    with _get_conn() as conn:
        if _is_mysql():
            conn.execute(
                "INSERT INTO customer_orders "
                "(customer_id, acc_id, order_id, order_status, product_title, amount, quantity, buyer_note, created_at, updated_at) "
                "VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s) "
                "ON DUPLICATE KEY UPDATE customer_id=VALUES(customer_id), acc_id=VALUES(acc_id), "
                "product_title=VALUES(product_title), amount=VALUES(amount), quantity=VALUES(quantity), "
                "buyer_note=VALUES(buyer_note), updated_at=VALUES(updated_at)",
                (customer_id, acc_id, order_id, order_status, product_title, amount, quantity, buyer_note, ts, ts),
            )
        else:
            # INSERT OR REPLACE deletes the conflicting row before inserting, which
            # reset created_at (and id) on every status refresh, unlike the MySQL
            # path. UPDATE-then-INSERT keeps created_at and needs no UPSERT syntax.
            # With sqlite3's default implicit transactions, the UPDATE opens the
            # write transaction, so no other writer can insert in between.
            updated = conn.execute(
                "UPDATE customer_orders SET customer_id=?, acc_id=?, product_title=?, amount=?, "
                "quantity=?, buyer_note=?, updated_at=? "
                "WHERE order_id=? AND order_status=?",
                (customer_id, acc_id, product_title, amount, quantity, buyer_note, ts, order_id, order_status),
            )
            if updated.rowcount == 0:
                conn.execute(
                    "INSERT INTO customer_orders "
                    "(customer_id, acc_id, order_id, order_status, product_title, amount, quantity, buyer_note, created_at, updated_at) "
                    "VALUES (?,?,?,?,?,?,?,?,?,?)",
                    (customer_id, acc_id, order_id, order_status, product_title, amount, quantity, buyer_note, ts, ts),
                )
        conn.commit()
@_retry_db_operation
def get_customer_orders(customer_id: str, limit: int = 10) -> List[Dict]:
    """Return up to ``limit`` order rows of one customer, most recently updated first."""
    statement = _sql("""
        SELECT order_id, order_status, product_title, amount, quantity, buyer_note, created_at, updated_at
        FROM customer_orders
        WHERE customer_id = ?
        ORDER BY updated_at DESC
        LIMIT ?
    """)
    with _get_conn() as conn:
        records = conn.execute(statement, (customer_id, limit)).fetchall()
    return list(map(dict, records))
def get_order_by_id(order_id: str) -> List[Dict]:
    """Return every status row recorded for ``order_id``, least recently updated first."""
    statement = _sql("""
        SELECT customer_id, order_id, order_status, product_title, amount, quantity, buyer_note, created_at, updated_at
        FROM customer_orders
        WHERE order_id = ?
        ORDER BY updated_at ASC
    """)
    with _get_conn() as conn:
        history = conn.execute(statement, (order_id,)).fetchall()
    return list(map(dict, history))