This commit is contained in:
2026-03-08 11:54:39 +08:00
parent 3c52061861
commit 54231cbd5c
6 changed files with 310 additions and 53 deletions

View File

@@ -1,10 +1,13 @@
"""
聊天记录数据库SQLite
聊天记录数据库SQLite / MySQL
每条消息独立存储按客户ID分开支持查询和展示。
支持 MySQL 连接池以提高性能。
"""
import sqlite3
import os
import threading
from queue import Queue, Empty
from datetime import datetime
from typing import List, Dict, Optional
@@ -16,6 +19,84 @@ _MYSQL_USER = os.getenv("MYSQL_USER", "root")
_MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs")
# ========== MySQL 连接池 ==========
_POOL_SIZE = int(os.getenv("MYSQL_POOL_SIZE", "50"))
_mysql_pool: Optional[Queue] = None
_pool_lock = threading.Lock()
def _create_mysql_conn():
"""创建单个 MySQL 连接"""
import pymysql
return pymysql.connect(
host=_MYSQL_HOST,
port=_MYSQL_PORT,
user=_MYSQL_USER,
password=_MYSQL_PASSWORD,
database=_MYSQL_DATABASE,
charset="utf8mb4",
cursorclass=pymysql.cursors.DictCursor,
autocommit=False,
connect_timeout=10,
read_timeout=30,
write_timeout=30,
)
def _init_mysql_pool():
"""初始化 MySQL 连接池"""
global _mysql_pool
with _pool_lock:
if _mysql_pool is None:
_mysql_pool = Queue(maxsize=_POOL_SIZE)
for _ in range(_POOL_SIZE):
try:
conn = _create_mysql_conn()
_mysql_pool.put(conn)
except Exception:
pass # 启动时连接失败不阻塞,后续会重建
def _get_pooled_conn(timeout: float = 5.0):
"""从连接池获取连接"""
global _mysql_pool
if _mysql_pool is None:
_init_mysql_pool()
try:
conn = _mysql_pool.get(timeout=timeout)
# 检查连接是否有效
try:
conn.ping(reconnect=True)
except Exception:
# 连接失效,创建新连接
try:
conn.close()
except Exception:
pass
conn = _create_mysql_conn()
return conn
except Empty:
# 池空了,创建新连接(不放回池)
return _create_mysql_conn()
def _return_conn(conn):
"""归还连接到池"""
global _mysql_pool
if _mysql_pool is None:
return
try:
if _mysql_pool.qsize() < _POOL_SIZE:
_mysql_pool.put_nowait(conn)
else:
conn.close()
except Exception:
try:
conn.close()
except Exception:
pass
class _CompatResult:
def __init__(self, rows=None, rowcount: int = 0, lastrowid: int = 0):
@@ -31,10 +112,11 @@ class _CompatResult:
class _PyMySQLCompatConn:
"""让 pymysql 连接兼容 sqlite 的 conn.execute 用法。"""
"""让 pymysql 连接兼容 sqlite 的 conn.execute 用法,支持连接池"""
def __init__(self, conn):
def __init__(self, conn, use_pool: bool = True):
self._conn = conn
self._use_pool = use_pool
def __enter__(self):
return self
@@ -45,7 +127,11 @@ class _PyMySQLCompatConn:
self._conn.rollback()
except Exception:
pass
self._conn.close()
# 归还连接到池而不是关闭
if self._use_pool:
_return_conn(self._conn)
else:
self._conn.close()
def execute(self, query: str, args=None):
cur = self._conn.cursor()
@@ -59,7 +145,10 @@ class _PyMySQLCompatConn:
self._conn.commit()
def close(self):
self._conn.close()
if self._use_pool:
_return_conn(self._conn)
else:
self._conn.close()
def _is_mysql() -> bool:
return _DB_TYPE in ("mysql", "mariadb")
@@ -68,20 +157,22 @@ def _sql(query: str) -> str:
return query.replace("?", "%s") if _is_mysql() else query
def _get_conn() -> sqlite3.Connection:
def _get_conn(max_retries: int = 3, retry_delay: float = 0.5) -> sqlite3.Connection:
"""获取数据库连接MySQL 使用连接池"""
if _is_mysql():
import pymysql
conn = pymysql.connect(
host=_MYSQL_HOST,
port=_MYSQL_PORT,
user=_MYSQL_USER,
password=_MYSQL_PASSWORD,
database=_MYSQL_DATABASE,
charset="utf8mb4",
cursorclass=pymysql.cursors.DictCursor,
autocommit=False,
)
return _PyMySQLCompatConn(conn)
import time
last_error = None
for attempt in range(max_retries):
try:
conn = _get_pooled_conn(timeout=5.0)
return _PyMySQLCompatConn(conn, use_pool=True)
except Exception as e:
last_error = e
if attempt < max_retries - 1:
time.sleep(retry_delay * (attempt + 1))
continue
raise
raise last_error
os.makedirs(os.path.dirname(_DB_PATH), exist_ok=True)
conn = sqlite3.connect(_DB_PATH)
conn.row_factory = sqlite3.Row
@@ -192,8 +283,38 @@ def init_db():
init_db()
# ========== 重试装饰器 ==========
def _retry_db_operation(func):
"""数据库操作重试装饰器,处理连接丢失等临时错误"""
import functools
import time
@functools.wraps(func)
def wrapper(*args, **kwargs):
max_retries = 3
last_error = None
for attempt in range(max_retries):
try:
return func(*args, **kwargs)
except Exception as e:
err_str = str(e).lower()
# 判断是否为可重试的连接错误
is_conn_error = any(k in err_str for k in [
"lost connection", "gone away", "connection reset",
"can't connect", "connection refused", "2013", "2006"
])
if is_conn_error and attempt < max_retries - 1:
last_error = e
time.sleep(0.5 * (attempt + 1))
continue
raise
raise last_error
return wrapper
# ========== 写入 ==========
@_retry_db_operation
def log_message(
customer_id: str,
message: str,
@@ -251,6 +372,7 @@ def get_customers(limit: int = 100) -> List[Dict]:
return [dict(r) for r in rows]
@_retry_db_operation
def get_conversation(customer_id: str, limit: int = 200, acc_id: str = "") -> List[Dict]:
"""返回某客户的最近对话记录(按时间升序)"""
# 忽略 acc_id 过滤,实现全店铺记忆
@@ -398,6 +520,7 @@ def get_latest_messages(limit: int = 20) -> List[Dict]:
# ========== 订单相关 ==========
@_retry_db_operation
def upsert_order(
customer_id: str,
order_id: str,
@@ -431,6 +554,7 @@ def upsert_order(
conn.commit()
@_retry_db_operation
def get_customer_orders(customer_id: str, limit: int = 10) -> List[Dict]:
"""查询某客户的订单记录(按时间倒序)"""
with _get_conn() as conn: