newtw66
This commit is contained in:
@@ -1,10 +1,13 @@
|
||||
"""
|
||||
聊天记录数据库(SQLite)
|
||||
聊天记录数据库(SQLite / MySQL)
|
||||
每条消息独立存储,按客户ID分开,支持查询和展示。
|
||||
支持 MySQL 连接池以提高性能。
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
import os
|
||||
import threading
|
||||
from queue import Queue, Empty
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
@@ -16,6 +19,84 @@ _MYSQL_USER = os.getenv("MYSQL_USER", "root")
|
||||
_MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
|
||||
_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs")
|
||||
|
||||
# ========== MySQL 连接池 ==========
|
||||
_POOL_SIZE = int(os.getenv("MYSQL_POOL_SIZE", "50"))
|
||||
_mysql_pool: Optional[Queue] = None
|
||||
_pool_lock = threading.Lock()
|
||||
|
||||
|
||||
def _create_mysql_conn():
|
||||
"""创建单个 MySQL 连接"""
|
||||
import pymysql
|
||||
return pymysql.connect(
|
||||
host=_MYSQL_HOST,
|
||||
port=_MYSQL_PORT,
|
||||
user=_MYSQL_USER,
|
||||
password=_MYSQL_PASSWORD,
|
||||
database=_MYSQL_DATABASE,
|
||||
charset="utf8mb4",
|
||||
cursorclass=pymysql.cursors.DictCursor,
|
||||
autocommit=False,
|
||||
connect_timeout=10,
|
||||
read_timeout=30,
|
||||
write_timeout=30,
|
||||
)
|
||||
|
||||
|
||||
def _init_mysql_pool():
|
||||
"""初始化 MySQL 连接池"""
|
||||
global _mysql_pool
|
||||
with _pool_lock:
|
||||
if _mysql_pool is None:
|
||||
_mysql_pool = Queue(maxsize=_POOL_SIZE)
|
||||
for _ in range(_POOL_SIZE):
|
||||
try:
|
||||
conn = _create_mysql_conn()
|
||||
_mysql_pool.put(conn)
|
||||
except Exception:
|
||||
pass # 启动时连接失败不阻塞,后续会重建
|
||||
|
||||
|
||||
def _get_pooled_conn(timeout: float = 5.0):
|
||||
"""从连接池获取连接"""
|
||||
global _mysql_pool
|
||||
if _mysql_pool is None:
|
||||
_init_mysql_pool()
|
||||
|
||||
try:
|
||||
conn = _mysql_pool.get(timeout=timeout)
|
||||
# 检查连接是否有效
|
||||
try:
|
||||
conn.ping(reconnect=True)
|
||||
except Exception:
|
||||
# 连接失效,创建新连接
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
conn = _create_mysql_conn()
|
||||
return conn
|
||||
except Empty:
|
||||
# 池空了,创建新连接(不放回池)
|
||||
return _create_mysql_conn()
|
||||
|
||||
|
||||
def _return_conn(conn):
|
||||
"""归还连接到池"""
|
||||
global _mysql_pool
|
||||
if _mysql_pool is None:
|
||||
return
|
||||
try:
|
||||
if _mysql_pool.qsize() < _POOL_SIZE:
|
||||
_mysql_pool.put_nowait(conn)
|
||||
else:
|
||||
conn.close()
|
||||
except Exception:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class _CompatResult:
|
||||
def __init__(self, rows=None, rowcount: int = 0, lastrowid: int = 0):
|
||||
@@ -31,10 +112,11 @@ class _CompatResult:
|
||||
|
||||
|
||||
class _PyMySQLCompatConn:
|
||||
"""让 pymysql 连接兼容 sqlite 的 conn.execute 用法。"""
|
||||
"""让 pymysql 连接兼容 sqlite 的 conn.execute 用法,支持连接池。"""
|
||||
|
||||
def __init__(self, conn):
|
||||
def __init__(self, conn, use_pool: bool = True):
|
||||
self._conn = conn
|
||||
self._use_pool = use_pool
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
@@ -45,7 +127,11 @@ class _PyMySQLCompatConn:
|
||||
self._conn.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
self._conn.close()
|
||||
# 归还连接到池而不是关闭
|
||||
if self._use_pool:
|
||||
_return_conn(self._conn)
|
||||
else:
|
||||
self._conn.close()
|
||||
|
||||
def execute(self, query: str, args=None):
|
||||
cur = self._conn.cursor()
|
||||
@@ -59,7 +145,10 @@ class _PyMySQLCompatConn:
|
||||
self._conn.commit()
|
||||
|
||||
def close(self):
|
||||
self._conn.close()
|
||||
if self._use_pool:
|
||||
_return_conn(self._conn)
|
||||
else:
|
||||
self._conn.close()
|
||||
|
||||
def _is_mysql() -> bool:
|
||||
return _DB_TYPE in ("mysql", "mariadb")
|
||||
@@ -68,20 +157,22 @@ def _sql(query: str) -> str:
|
||||
return query.replace("?", "%s") if _is_mysql() else query
|
||||
|
||||
|
||||
def _get_conn() -> sqlite3.Connection:
|
||||
def _get_conn(max_retries: int = 3, retry_delay: float = 0.5) -> sqlite3.Connection:
|
||||
"""获取数据库连接,MySQL 使用连接池"""
|
||||
if _is_mysql():
|
||||
import pymysql
|
||||
conn = pymysql.connect(
|
||||
host=_MYSQL_HOST,
|
||||
port=_MYSQL_PORT,
|
||||
user=_MYSQL_USER,
|
||||
password=_MYSQL_PASSWORD,
|
||||
database=_MYSQL_DATABASE,
|
||||
charset="utf8mb4",
|
||||
cursorclass=pymysql.cursors.DictCursor,
|
||||
autocommit=False,
|
||||
)
|
||||
return _PyMySQLCompatConn(conn)
|
||||
import time
|
||||
last_error = None
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
conn = _get_pooled_conn(timeout=5.0)
|
||||
return _PyMySQLCompatConn(conn, use_pool=True)
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
if attempt < max_retries - 1:
|
||||
time.sleep(retry_delay * (attempt + 1))
|
||||
continue
|
||||
raise
|
||||
raise last_error
|
||||
os.makedirs(os.path.dirname(_DB_PATH), exist_ok=True)
|
||||
conn = sqlite3.connect(_DB_PATH)
|
||||
conn.row_factory = sqlite3.Row
|
||||
@@ -192,8 +283,38 @@ def init_db():
|
||||
init_db()
|
||||
|
||||
|
||||
# ========== 重试装饰器 ==========
|
||||
|
||||
def _retry_db_operation(func):
|
||||
"""数据库操作重试装饰器,处理连接丢失等临时错误"""
|
||||
import functools
|
||||
import time
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
max_retries = 3
|
||||
last_error = None
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
err_str = str(e).lower()
|
||||
# 判断是否为可重试的连接错误
|
||||
is_conn_error = any(k in err_str for k in [
|
||||
"lost connection", "gone away", "connection reset",
|
||||
"can't connect", "connection refused", "2013", "2006"
|
||||
])
|
||||
if is_conn_error and attempt < max_retries - 1:
|
||||
last_error = e
|
||||
time.sleep(0.5 * (attempt + 1))
|
||||
continue
|
||||
raise
|
||||
raise last_error
|
||||
return wrapper
|
||||
|
||||
|
||||
# ========== 写入 ==========
|
||||
|
||||
@_retry_db_operation
|
||||
def log_message(
|
||||
customer_id: str,
|
||||
message: str,
|
||||
@@ -251,6 +372,7 @@ def get_customers(limit: int = 100) -> List[Dict]:
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
|
||||
@_retry_db_operation
|
||||
def get_conversation(customer_id: str, limit: int = 200, acc_id: str = "") -> List[Dict]:
|
||||
"""返回某客户的最近对话记录(按时间升序)"""
|
||||
# 忽略 acc_id 过滤,实现全店铺记忆
|
||||
@@ -398,6 +520,7 @@ def get_latest_messages(limit: int = 20) -> List[Dict]:
|
||||
|
||||
# ========== 订单相关 ==========
|
||||
|
||||
@_retry_db_operation
|
||||
def upsert_order(
|
||||
customer_id: str,
|
||||
order_id: str,
|
||||
@@ -431,6 +554,7 @@ def upsert_order(
|
||||
conn.commit()
|
||||
|
||||
|
||||
@_retry_db_operation
|
||||
def get_customer_orders(customer_id: str, limit: int = 10) -> List[Dict]:
|
||||
"""查询某客户的订单记录(按时间倒序)"""
|
||||
with _get_conn() as conn:
|
||||
|
||||
@@ -13,6 +13,9 @@ _MYSQL_USER = os.getenv("MYSQL_USER", "root")
|
||||
_MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
|
||||
_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs")
|
||||
|
||||
# 复用 chat_log_db 的连接池
|
||||
from db.chat_log_db import _get_pooled_conn, _return_conn
|
||||
|
||||
|
||||
def _is_mysql() -> bool:
|
||||
return _DB_TYPE in ("mysql", "mariadb")
|
||||
@@ -171,6 +174,23 @@ class CustomerProfile:
|
||||
self.image_analysis_history = []
|
||||
|
||||
|
||||
class _PooledMySQLConn:
|
||||
"""包装 pymysql 连接,支持连接池归还"""
|
||||
def __init__(self, conn):
|
||||
self._conn = conn
|
||||
|
||||
def __enter__(self):
|
||||
return self._conn
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
if exc_type:
|
||||
try:
|
||||
self._conn.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
_return_conn(self._conn)
|
||||
|
||||
|
||||
class CustomerDatabase:
|
||||
"""客户数据库"""
|
||||
|
||||
@@ -180,17 +200,8 @@ class CustomerDatabase:
|
||||
self._ensure_db()
|
||||
|
||||
def _get_mysql_conn(self):
|
||||
import pymysql
|
||||
return pymysql.connect(
|
||||
host=_MYSQL_HOST,
|
||||
port=_MYSQL_PORT,
|
||||
user=_MYSQL_USER,
|
||||
password=_MYSQL_PASSWORD,
|
||||
database=_MYSQL_DATABASE,
|
||||
charset="utf8mb4",
|
||||
cursorclass=pymysql.cursors.DictCursor,
|
||||
autocommit=False,
|
||||
)
|
||||
"""从连接池获取 MySQL 连接"""
|
||||
return _PooledMySQLConn(_get_pooled_conn(timeout=5.0))
|
||||
|
||||
def _ensure_db(self):
|
||||
if _is_mysql():
|
||||
@@ -284,24 +295,41 @@ class CustomerDatabase:
|
||||
data.pop('customer_id', None)
|
||||
return CustomerProfile(customer_id=customer_id, **data)
|
||||
|
||||
def save_customer(self, profile: CustomerProfile):
|
||||
def save_customer(self, profile: CustomerProfile, max_retries: int = 3):
|
||||
"""保存客户画像(带重试机制)"""
|
||||
import time
|
||||
profile.last_update = datetime.now().isoformat()
|
||||
if _is_mysql():
|
||||
with self._get_mysql_conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
REPLACE INTO customer_profiles (customer_id, profile_json, last_update)
|
||||
VALUES (%s, %s, %s)
|
||||
""",
|
||||
(
|
||||
profile.customer_id,
|
||||
json.dumps(asdict(profile), ensure_ascii=False),
|
||||
datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
return
|
||||
last_error = None
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
with self._get_mysql_conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
REPLACE INTO customer_profiles (customer_id, profile_json, last_update)
|
||||
VALUES (%s, %s, %s)
|
||||
""",
|
||||
(
|
||||
profile.customer_id,
|
||||
json.dumps(asdict(profile), ensure_ascii=False),
|
||||
datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
return
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
err_str = str(e).lower()
|
||||
is_conn_error = any(k in err_str for k in [
|
||||
"lost connection", "gone away", "connection reset",
|
||||
"can't connect", "connection refused", "2013", "2006"
|
||||
])
|
||||
if is_conn_error and attempt < max_retries - 1:
|
||||
time.sleep(0.5 * (attempt + 1))
|
||||
continue
|
||||
raise
|
||||
raise last_error
|
||||
customers = self._load_customers()
|
||||
customers[profile.customer_id] = asdict(profile)
|
||||
self._save_customers(customers)
|
||||
|
||||
BIN
db/image_tasks.db
Normal file
BIN
db/image_tasks.db
Normal file
Binary file not shown.
Reference in New Issue
Block a user