This commit is contained in:
2026-03-08 11:54:39 +08:00
parent 3c52061861
commit 54231cbd5c
6 changed files with 310 additions and 53 deletions

View File

@@ -13,6 +13,9 @@ _MYSQL_USER = os.getenv("MYSQL_USER", "root")
_MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs")
# 复用 chat_log_db 的连接池
from db.chat_log_db import _get_pooled_conn, _return_conn
def _is_mysql() -> bool:
return _DB_TYPE in ("mysql", "mariadb")
@@ -171,6 +174,23 @@ class CustomerProfile:
self.image_analysis_history = []
class _PooledMySQLConn:
"""包装 pymysql 连接,支持连接池归还"""
def __init__(self, conn):
self._conn = conn
def __enter__(self):
return self._conn
def __exit__(self, exc_type, exc, tb):
if exc_type:
try:
self._conn.rollback()
except Exception:
pass
_return_conn(self._conn)
class CustomerDatabase:
"""客户数据库"""
@@ -180,17 +200,8 @@ class CustomerDatabase:
self._ensure_db()
def _get_mysql_conn(self):
import pymysql
return pymysql.connect(
host=_MYSQL_HOST,
port=_MYSQL_PORT,
user=_MYSQL_USER,
password=_MYSQL_PASSWORD,
database=_MYSQL_DATABASE,
charset="utf8mb4",
cursorclass=pymysql.cursors.DictCursor,
autocommit=False,
)
"""从连接池获取 MySQL 连接"""
return _PooledMySQLConn(_get_pooled_conn(timeout=5.0))
def _ensure_db(self):
if _is_mysql():
@@ -284,24 +295,41 @@ class CustomerDatabase:
data.pop('customer_id', None)
return CustomerProfile(customer_id=customer_id, **data)
def save_customer(self, profile: CustomerProfile):
def save_customer(self, profile: CustomerProfile, max_retries: int = 3):
"""保存客户画像(带重试机制)"""
import time
profile.last_update = datetime.now().isoformat()
if _is_mysql():
with self._get_mysql_conn() as conn:
with conn.cursor() as cur:
cur.execute(
"""
REPLACE INTO customer_profiles (customer_id, profile_json, last_update)
VALUES (%s, %s, %s)
""",
(
profile.customer_id,
json.dumps(asdict(profile), ensure_ascii=False),
datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
),
)
conn.commit()
return
last_error = None
for attempt in range(max_retries):
try:
with self._get_mysql_conn() as conn:
with conn.cursor() as cur:
cur.execute(
"""
REPLACE INTO customer_profiles (customer_id, profile_json, last_update)
VALUES (%s, %s, %s)
""",
(
profile.customer_id,
json.dumps(asdict(profile), ensure_ascii=False),
datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
),
)
conn.commit()
return
except Exception as e:
last_error = e
err_str = str(e).lower()
is_conn_error = any(k in err_str for k in [
"lost connection", "gone away", "connection reset",
"can't connect", "connection refused", "2013", "2006"
])
if is_conn_error and attempt < max_retries - 1:
time.sleep(0.5 * (attempt + 1))
continue
raise
raise last_error
customers = self._load_customers()
customers[profile.customer_id] = asdict(profile)
self._save_customers(customers)