feat: migrate core data stores to MySQL with compatibility fixes
This commit is contained in:
@@ -16,6 +16,51 @@ _MYSQL_USER = os.getenv("MYSQL_USER", "root")
|
||||
_MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
|
||||
_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs")
|
||||
|
||||
|
||||
class _CompatResult:
|
||||
def __init__(self, rows=None, rowcount: int = 0, lastrowid: int = 0):
|
||||
self._rows = rows or []
|
||||
self.rowcount = rowcount
|
||||
self.lastrowid = lastrowid
|
||||
|
||||
def fetchall(self):
|
||||
return self._rows
|
||||
|
||||
def fetchone(self):
|
||||
return self._rows[0] if self._rows else None
|
||||
|
||||
|
||||
class _PyMySQLCompatConn:
|
||||
"""让 pymysql 连接兼容 sqlite 的 conn.execute 用法。"""
|
||||
|
||||
def __init__(self, conn):
|
||||
self._conn = conn
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
if exc_type:
|
||||
try:
|
||||
self._conn.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
self._conn.close()
|
||||
|
||||
def execute(self, query: str, args=None):
|
||||
cur = self._conn.cursor()
|
||||
cur.execute(query, args or ())
|
||||
rows = cur.fetchall() if cur.description else []
|
||||
res = _CompatResult(rows=rows, rowcount=cur.rowcount, lastrowid=getattr(cur, "lastrowid", 0))
|
||||
cur.close()
|
||||
return res
|
||||
|
||||
def commit(self):
|
||||
self._conn.commit()
|
||||
|
||||
def close(self):
|
||||
self._conn.close()
|
||||
|
||||
def _is_mysql() -> bool:
|
||||
return _DB_TYPE in ("mysql", "mariadb")
|
||||
|
||||
@@ -26,7 +71,7 @@ def _sql(query: str) -> str:
|
||||
def _get_conn() -> sqlite3.Connection:
|
||||
if _is_mysql():
|
||||
import pymysql
|
||||
return pymysql.connect(
|
||||
conn = pymysql.connect(
|
||||
host=_MYSQL_HOST,
|
||||
port=_MYSQL_PORT,
|
||||
user=_MYSQL_USER,
|
||||
@@ -36,6 +81,7 @@ def _get_conn() -> sqlite3.Connection:
|
||||
cursorclass=pymysql.cursors.DictCursor,
|
||||
autocommit=False,
|
||||
)
|
||||
return _PyMySQLCompatConn(conn)
|
||||
os.makedirs(os.path.dirname(_DB_PATH), exist_ok=True)
|
||||
conn = sqlite3.connect(_DB_PATH)
|
||||
conn.row_factory = sqlite3.Row
|
||||
@@ -59,9 +105,14 @@ def init_db():
|
||||
timestamp DATETIME NOT NULL
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
|
||||
""")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_customer ON chat_logs(customer_id)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_ts ON chat_logs(timestamp)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_acc ON chat_logs(acc_id)")
|
||||
idx_rows = conn.execute("SHOW INDEX FROM chat_logs").fetchall()
|
||||
exists = {str(r.get("Key_name", "")) for r in idx_rows}
|
||||
if "idx_customer" not in exists:
|
||||
conn.execute("CREATE INDEX idx_customer ON chat_logs(customer_id)")
|
||||
if "idx_ts" not in exists:
|
||||
conn.execute("CREATE INDEX idx_ts ON chat_logs(timestamp)")
|
||||
if "idx_acc" not in exists:
|
||||
conn.execute("CREATE INDEX idx_acc ON chat_logs(acc_id)")
|
||||
else:
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS chat_logs (
|
||||
|
||||
Reference in New Issue
Block a user