This commit is contained in:
2026-03-08 11:54:39 +08:00
parent 3c52061861
commit 54231cbd5c
6 changed files with 310 additions and 53 deletions

View File

@@ -1,10 +1,13 @@
"""
聊天记录数据库SQLite
聊天记录数据库SQLite / MySQL
每条消息独立存储按客户ID分开支持查询和展示。
支持 MySQL 连接池以提高性能。
"""
import sqlite3
import os
import threading
from queue import Queue, Empty
from datetime import datetime
from typing import List, Dict, Optional
@@ -16,6 +19,84 @@ _MYSQL_USER = os.getenv("MYSQL_USER", "root")
_MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs")
# ========== MySQL 连接池 ==========
_POOL_SIZE = int(os.getenv("MYSQL_POOL_SIZE", "50"))
_mysql_pool: Optional[Queue] = None
_pool_lock = threading.Lock()
def _create_mysql_conn():
"""创建单个 MySQL 连接"""
import pymysql
return pymysql.connect(
host=_MYSQL_HOST,
port=_MYSQL_PORT,
user=_MYSQL_USER,
password=_MYSQL_PASSWORD,
database=_MYSQL_DATABASE,
charset="utf8mb4",
cursorclass=pymysql.cursors.DictCursor,
autocommit=False,
connect_timeout=10,
read_timeout=30,
write_timeout=30,
)
def _init_mysql_pool():
"""初始化 MySQL 连接池"""
global _mysql_pool
with _pool_lock:
if _mysql_pool is None:
_mysql_pool = Queue(maxsize=_POOL_SIZE)
for _ in range(_POOL_SIZE):
try:
conn = _create_mysql_conn()
_mysql_pool.put(conn)
except Exception:
pass # 启动时连接失败不阻塞,后续会重建
def _get_pooled_conn(timeout: float = 5.0):
"""从连接池获取连接"""
global _mysql_pool
if _mysql_pool is None:
_init_mysql_pool()
try:
conn = _mysql_pool.get(timeout=timeout)
# 检查连接是否有效
try:
conn.ping(reconnect=True)
except Exception:
# 连接失效,创建新连接
try:
conn.close()
except Exception:
pass
conn = _create_mysql_conn()
return conn
except Empty:
# 池空了,创建新连接(不放回池)
return _create_mysql_conn()
def _return_conn(conn):
"""归还连接到池"""
global _mysql_pool
if _mysql_pool is None:
return
try:
if _mysql_pool.qsize() < _POOL_SIZE:
_mysql_pool.put_nowait(conn)
else:
conn.close()
except Exception:
try:
conn.close()
except Exception:
pass
class _CompatResult:
def __init__(self, rows=None, rowcount: int = 0, lastrowid: int = 0):
@@ -31,10 +112,11 @@ class _CompatResult:
class _PyMySQLCompatConn:
"""让 pymysql 连接兼容 sqlite 的 conn.execute 用法。"""
"""让 pymysql 连接兼容 sqlite 的 conn.execute 用法,支持连接池"""
def __init__(self, conn):
def __init__(self, conn, use_pool: bool = True):
self._conn = conn
self._use_pool = use_pool
def __enter__(self):
return self
@@ -45,7 +127,11 @@ class _PyMySQLCompatConn:
self._conn.rollback()
except Exception:
pass
self._conn.close()
# 归还连接到池而不是关闭
if self._use_pool:
_return_conn(self._conn)
else:
self._conn.close()
def execute(self, query: str, args=None):
cur = self._conn.cursor()
@@ -59,7 +145,10 @@ class _PyMySQLCompatConn:
self._conn.commit()
def close(self):
self._conn.close()
if self._use_pool:
_return_conn(self._conn)
else:
self._conn.close()
def _is_mysql() -> bool:
return _DB_TYPE in ("mysql", "mariadb")
@@ -68,20 +157,22 @@ def _sql(query: str) -> str:
return query.replace("?", "%s") if _is_mysql() else query
def _get_conn() -> sqlite3.Connection:
def _get_conn(max_retries: int = 3, retry_delay: float = 0.5) -> sqlite3.Connection:
"""获取数据库连接MySQL 使用连接池"""
if _is_mysql():
import pymysql
conn = pymysql.connect(
host=_MYSQL_HOST,
port=_MYSQL_PORT,
user=_MYSQL_USER,
password=_MYSQL_PASSWORD,
database=_MYSQL_DATABASE,
charset="utf8mb4",
cursorclass=pymysql.cursors.DictCursor,
autocommit=False,
)
return _PyMySQLCompatConn(conn)
import time
last_error = None
for attempt in range(max_retries):
try:
conn = _get_pooled_conn(timeout=5.0)
return _PyMySQLCompatConn(conn, use_pool=True)
except Exception as e:
last_error = e
if attempt < max_retries - 1:
time.sleep(retry_delay * (attempt + 1))
continue
raise
raise last_error
os.makedirs(os.path.dirname(_DB_PATH), exist_ok=True)
conn = sqlite3.connect(_DB_PATH)
conn.row_factory = sqlite3.Row
@@ -192,8 +283,38 @@ def init_db():
init_db()
# ========== 重试装饰器 ==========
def _retry_db_operation(func):
"""数据库操作重试装饰器,处理连接丢失等临时错误"""
import functools
import time
@functools.wraps(func)
def wrapper(*args, **kwargs):
max_retries = 3
last_error = None
for attempt in range(max_retries):
try:
return func(*args, **kwargs)
except Exception as e:
err_str = str(e).lower()
# 判断是否为可重试的连接错误
is_conn_error = any(k in err_str for k in [
"lost connection", "gone away", "connection reset",
"can't connect", "connection refused", "2013", "2006"
])
if is_conn_error and attempt < max_retries - 1:
last_error = e
time.sleep(0.5 * (attempt + 1))
continue
raise
raise last_error
return wrapper
# ========== 写入 ==========
@_retry_db_operation
def log_message(
customer_id: str,
message: str,
@@ -251,6 +372,7 @@ def get_customers(limit: int = 100) -> List[Dict]:
return [dict(r) for r in rows]
@_retry_db_operation
def get_conversation(customer_id: str, limit: int = 200, acc_id: str = "") -> List[Dict]:
"""返回某客户的最近对话记录(按时间升序)"""
# 忽略 acc_id 过滤,实现全店铺记忆
@@ -398,6 +520,7 @@ def get_latest_messages(limit: int = 20) -> List[Dict]:
# ========== 订单相关 ==========
@_retry_db_operation
def upsert_order(
customer_id: str,
order_id: str,
@@ -431,6 +554,7 @@ def upsert_order(
conn.commit()
@_retry_db_operation
def get_customer_orders(customer_id: str, limit: int = 10) -> List[Dict]:
"""查询某客户的订单记录(按时间倒序)"""
with _get_conn() as conn:

View File

@@ -13,6 +13,9 @@ _MYSQL_USER = os.getenv("MYSQL_USER", "root")
_MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs")
# 复用 chat_log_db 的连接池
from db.chat_log_db import _get_pooled_conn, _return_conn
def _is_mysql() -> bool:
return _DB_TYPE in ("mysql", "mariadb")
@@ -171,6 +174,23 @@ class CustomerProfile:
self.image_analysis_history = []
class _PooledMySQLConn:
"""包装 pymysql 连接,支持连接池归还"""
def __init__(self, conn):
self._conn = conn
def __enter__(self):
return self._conn
def __exit__(self, exc_type, exc, tb):
if exc_type:
try:
self._conn.rollback()
except Exception:
pass
_return_conn(self._conn)
class CustomerDatabase:
"""客户数据库"""
@@ -180,17 +200,8 @@ class CustomerDatabase:
self._ensure_db()
def _get_mysql_conn(self):
import pymysql
return pymysql.connect(
host=_MYSQL_HOST,
port=_MYSQL_PORT,
user=_MYSQL_USER,
password=_MYSQL_PASSWORD,
database=_MYSQL_DATABASE,
charset="utf8mb4",
cursorclass=pymysql.cursors.DictCursor,
autocommit=False,
)
"""从连接池获取 MySQL 连接"""
return _PooledMySQLConn(_get_pooled_conn(timeout=5.0))
def _ensure_db(self):
if _is_mysql():
@@ -284,24 +295,41 @@ class CustomerDatabase:
data.pop('customer_id', None)
return CustomerProfile(customer_id=customer_id, **data)
def save_customer(self, profile: CustomerProfile):
def save_customer(self, profile: CustomerProfile, max_retries: int = 3):
"""保存客户画像(带重试机制)"""
import time
profile.last_update = datetime.now().isoformat()
if _is_mysql():
with self._get_mysql_conn() as conn:
with conn.cursor() as cur:
cur.execute(
"""
REPLACE INTO customer_profiles (customer_id, profile_json, last_update)
VALUES (%s, %s, %s)
""",
(
profile.customer_id,
json.dumps(asdict(profile), ensure_ascii=False),
datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
),
)
conn.commit()
return
last_error = None
for attempt in range(max_retries):
try:
with self._get_mysql_conn() as conn:
with conn.cursor() as cur:
cur.execute(
"""
REPLACE INTO customer_profiles (customer_id, profile_json, last_update)
VALUES (%s, %s, %s)
""",
(
profile.customer_id,
json.dumps(asdict(profile), ensure_ascii=False),
datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
),
)
conn.commit()
return
except Exception as e:
last_error = e
err_str = str(e).lower()
is_conn_error = any(k in err_str for k in [
"lost connection", "gone away", "connection reset",
"can't connect", "connection refused", "2013", "2006"
])
if is_conn_error and attempt < max_retries - 1:
time.sleep(0.5 * (attempt + 1))
continue
raise
raise last_error
customers = self._load_customers()
customers[profile.customer_id] = asdict(profile)
self._save_customers(customers)

BIN
db/image_tasks.db Normal file

Binary file not shown.