fix: reduce mysql connection pressure

This commit is contained in:
2026-03-08 12:29:49 +08:00
parent 54231cbd5c
commit 613d375845
2 changed files with 54 additions and 41 deletions

View File

@@ -20,9 +20,11 @@ _MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs") _MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs")
# ========== MySQL 连接池 ========== # ========== MySQL 连接池 ==========
_POOL_SIZE = int(os.getenv("MYSQL_POOL_SIZE", "50")) _POOL_SIZE = int(os.getenv("MYSQL_POOL_SIZE", "10"))
_POOL_WAIT_TIMEOUT = float(os.getenv("MYSQL_POOL_WAIT_TIMEOUT", "10"))
_mysql_pool: Optional[Queue] = None _mysql_pool: Optional[Queue] = None
_pool_lock = threading.Lock() _pool_lock = threading.Lock()
_mysql_conn_count = 0
def _create_mysql_conn(): def _create_mysql_conn():
@@ -44,58 +46,65 @@ def _create_mysql_conn():
def _init_mysql_pool(): def _init_mysql_pool():
"""初始化 MySQL 连接池""" """初始化 MySQL 连接池(懒创建,不在启动时预建满)"""
global _mysql_pool global _mysql_pool
with _pool_lock: with _pool_lock:
if _mysql_pool is None: if _mysql_pool is None:
_mysql_pool = Queue(maxsize=_POOL_SIZE) _mysql_pool = Queue(maxsize=_POOL_SIZE)
for _ in range(_POOL_SIZE):
try:
conn = _create_mysql_conn() def _discard_conn(conn):
_mysql_pool.put(conn) """丢弃失效连接并维护计数"""
except Exception: global _mysql_conn_count
pass # 启动时连接失败不阻塞,后续会重建 try:
conn.close()
except Exception:
pass
with _pool_lock:
if _mysql_conn_count > 0:
_mysql_conn_count -= 1
def _get_pooled_conn(timeout: float = 5.0): def _get_pooled_conn(timeout: float = 5.0):
"""从连接池获取连接""" """从连接池获取连接,达到上限后阻塞等待,不再额外扩容。"""
global _mysql_pool global _mysql_pool, _mysql_conn_count
if _mysql_pool is None: if _mysql_pool is None:
_init_mysql_pool() _init_mysql_pool()
with _pool_lock:
if _mysql_conn_count < _POOL_SIZE:
conn = _create_mysql_conn()
_mysql_conn_count += 1
return conn
try: try:
conn = _mysql_pool.get(timeout=timeout) conn = _mysql_pool.get(timeout=timeout)
# 检查连接是否有效
try: try:
conn.ping(reconnect=True) conn.ping(reconnect=True)
except Exception: except Exception:
# 连接失效,创建新连接 _discard_conn(conn)
try: with _pool_lock:
conn.close() if _mysql_conn_count < _POOL_SIZE:
except Exception: conn = _create_mysql_conn()
pass _mysql_conn_count += 1
conn = _create_mysql_conn() return conn
conn = _mysql_pool.get(timeout=timeout)
conn.ping(reconnect=True)
return conn return conn
except Empty: except Empty:
# 池空了,创建新连接(不放回池) raise TimeoutError(f"MySQL连接池已耗尽pool_size={_POOL_SIZE}, wait_timeout={timeout}s")
return _create_mysql_conn()
def _return_conn(conn): def _return_conn(conn):
"""归还连接到池""" """归还连接到池,失效连接直接丢弃。"""
global _mysql_pool global _mysql_pool
if _mysql_pool is None: if _mysql_pool is None:
return return
try: try:
if _mysql_pool.qsize() < _POOL_SIZE: conn.ping(reconnect=False)
_mysql_pool.put_nowait(conn) _mysql_pool.put_nowait(conn)
else:
conn.close()
except Exception: except Exception:
try: _discard_conn(conn)
conn.close()
except Exception:
pass
class _CompatResult: class _CompatResult:
@@ -164,7 +173,7 @@ def _get_conn(max_retries: int = 3, retry_delay: float = 0.5) -> sqlite3.Connect
last_error = None last_error = None
for attempt in range(max_retries): for attempt in range(max_retries):
try: try:
conn = _get_pooled_conn(timeout=5.0) conn = _get_pooled_conn(timeout=_POOL_WAIT_TIMEOUT)
return _PyMySQLCompatConn(conn, use_pool=True) return _PyMySQLCompatConn(conn, use_pool=True)
except Exception as e: except Exception as e:
last_error = e last_error = e

View File

@@ -10,6 +10,7 @@ from typing import Optional, Dict, List
from pathlib import Path from pathlib import Path
from enum import Enum from enum import Enum
import os import os
from db.chat_log_db import _get_pooled_conn, _return_conn
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_DB_TYPE = os.getenv("DB_TYPE", "sqlite").lower() _DB_TYPE = os.getenv("DB_TYPE", "sqlite").lower()
@@ -45,6 +46,19 @@ class TaskPriority(Enum):
HIGH = "high" HIGH = "high"
URGENT = "urgent" URGENT = "urgent"
class _PooledMySQLConn:
"""包装 pymysql 连接close 时归还到共享连接池。"""
def __init__(self, conn):
self._conn = conn
def __getattr__(self, name):
return getattr(self._conn, name)
def close(self):
_return_conn(self._conn)
class TaskManager: class TaskManager:
"""任务管理器 - SQLite 存储""" """任务管理器 - SQLite 存储"""
@@ -139,17 +153,7 @@ class TaskManager:
def _get_conn(self): def _get_conn(self):
"""获取数据库连接""" """获取数据库连接"""
if _is_mysql(): if _is_mysql():
import pymysql return _PooledMySQLConn(_get_pooled_conn())
return pymysql.connect(
host=_MYSQL_HOST,
port=_MYSQL_PORT,
user=_MYSQL_USER,
password=_MYSQL_PASSWORD,
database=_MYSQL_DATABASE,
charset="utf8mb4",
cursorclass=pymysql.cursors.DictCursor,
autocommit=False,
)
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row conn.row_factory = sqlite3.Row
return conn return conn