newtw66
This commit is contained in:
@@ -13,6 +13,9 @@ _MYSQL_USER = os.getenv("MYSQL_USER", "root")
|
||||
_MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
|
||||
_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs")
|
||||
|
||||
# 复用 chat_log_db 的连接池
|
||||
from db.chat_log_db import _get_pooled_conn, _return_conn
|
||||
|
||||
|
||||
def _is_mysql() -> bool:
|
||||
return _DB_TYPE in ("mysql", "mariadb")
|
||||
@@ -171,6 +174,23 @@ class CustomerProfile:
|
||||
self.image_analysis_history = []
|
||||
|
||||
|
||||
class _PooledMySQLConn:
|
||||
"""包装 pymysql 连接,支持连接池归还"""
|
||||
def __init__(self, conn):
|
||||
self._conn = conn
|
||||
|
||||
def __enter__(self):
|
||||
return self._conn
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
if exc_type:
|
||||
try:
|
||||
self._conn.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
_return_conn(self._conn)
|
||||
|
||||
|
||||
class CustomerDatabase:
|
||||
"""客户数据库"""
|
||||
|
||||
@@ -180,17 +200,8 @@ class CustomerDatabase:
|
||||
self._ensure_db()
|
||||
|
||||
def _get_mysql_conn(self):
|
||||
import pymysql
|
||||
return pymysql.connect(
|
||||
host=_MYSQL_HOST,
|
||||
port=_MYSQL_PORT,
|
||||
user=_MYSQL_USER,
|
||||
password=_MYSQL_PASSWORD,
|
||||
database=_MYSQL_DATABASE,
|
||||
charset="utf8mb4",
|
||||
cursorclass=pymysql.cursors.DictCursor,
|
||||
autocommit=False,
|
||||
)
|
||||
"""从连接池获取 MySQL 连接"""
|
||||
return _PooledMySQLConn(_get_pooled_conn(timeout=5.0))
|
||||
|
||||
def _ensure_db(self):
|
||||
if _is_mysql():
|
||||
@@ -284,24 +295,41 @@ class CustomerDatabase:
|
||||
data.pop('customer_id', None)
|
||||
return CustomerProfile(customer_id=customer_id, **data)
|
||||
|
||||
def save_customer(self, profile: CustomerProfile):
|
||||
def save_customer(self, profile: CustomerProfile, max_retries: int = 3):
|
||||
"""保存客户画像(带重试机制)"""
|
||||
import time
|
||||
profile.last_update = datetime.now().isoformat()
|
||||
if _is_mysql():
|
||||
with self._get_mysql_conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
REPLACE INTO customer_profiles (customer_id, profile_json, last_update)
|
||||
VALUES (%s, %s, %s)
|
||||
""",
|
||||
(
|
||||
profile.customer_id,
|
||||
json.dumps(asdict(profile), ensure_ascii=False),
|
||||
datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
return
|
||||
last_error = None
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
with self._get_mysql_conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
REPLACE INTO customer_profiles (customer_id, profile_json, last_update)
|
||||
VALUES (%s, %s, %s)
|
||||
""",
|
||||
(
|
||||
profile.customer_id,
|
||||
json.dumps(asdict(profile), ensure_ascii=False),
|
||||
datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
return
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
err_str = str(e).lower()
|
||||
is_conn_error = any(k in err_str for k in [
|
||||
"lost connection", "gone away", "connection reset",
|
||||
"can't connect", "connection refused", "2013", "2006"
|
||||
])
|
||||
if is_conn_error and attempt < max_retries - 1:
|
||||
time.sleep(0.5 * (attempt + 1))
|
||||
continue
|
||||
raise
|
||||
raise last_error
|
||||
customers = self._load_customers()
|
||||
customers[profile.customer_id] = asdict(profile)
|
||||
self._save_customers(customers)
|
||||
|
||||
Reference in New Issue
Block a user