feat: migrate core data stores to MySQL with compatibility fixes

This commit is contained in:
2026-02-28 19:35:51 +08:00
parent dc04db6538
commit 650b46ed99
9 changed files with 660 additions and 16 deletions

View File

@@ -16,6 +16,51 @@ _MYSQL_USER = os.getenv("MYSQL_USER", "root")
_MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs")
class _CompatResult:
def __init__(self, rows=None, rowcount: int = 0, lastrowid: int = 0):
self._rows = rows or []
self.rowcount = rowcount
self.lastrowid = lastrowid
def fetchall(self):
return self._rows
def fetchone(self):
return self._rows[0] if self._rows else None
class _PyMySQLCompatConn:
"""让 pymysql 连接兼容 sqlite 的 conn.execute 用法。"""
def __init__(self, conn):
self._conn = conn
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
if exc_type:
try:
self._conn.rollback()
except Exception:
pass
self._conn.close()
def execute(self, query: str, args=None):
cur = self._conn.cursor()
cur.execute(query, args or ())
rows = cur.fetchall() if cur.description else []
res = _CompatResult(rows=rows, rowcount=cur.rowcount, lastrowid=getattr(cur, "lastrowid", 0))
cur.close()
return res
def commit(self):
self._conn.commit()
def close(self):
self._conn.close()
def _is_mysql() -> bool:
return _DB_TYPE in ("mysql", "mariadb")
@@ -26,7 +71,7 @@ def _sql(query: str) -> str:
def _get_conn() -> sqlite3.Connection:
if _is_mysql():
import pymysql
return pymysql.connect(
conn = pymysql.connect(
host=_MYSQL_HOST,
port=_MYSQL_PORT,
user=_MYSQL_USER,
@@ -36,6 +81,7 @@ def _get_conn() -> sqlite3.Connection:
cursorclass=pymysql.cursors.DictCursor,
autocommit=False,
)
return _PyMySQLCompatConn(conn)
os.makedirs(os.path.dirname(_DB_PATH), exist_ok=True)
conn = sqlite3.connect(_DB_PATH)
conn.row_factory = sqlite3.Row
@@ -59,9 +105,14 @@ def init_db():
timestamp DATETIME NOT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
""")
conn.execute("CREATE INDEX IF NOT EXISTS idx_customer ON chat_logs(customer_id)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_ts ON chat_logs(timestamp)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_acc ON chat_logs(acc_id)")
idx_rows = conn.execute("SHOW INDEX FROM chat_logs").fetchall()
exists = {str(r.get("Key_name", "")) for r in idx_rows}
if "idx_customer" not in exists:
conn.execute("CREATE INDEX idx_customer ON chat_logs(customer_id)")
if "idx_ts" not in exists:
conn.execute("CREATE INDEX idx_ts ON chat_logs(timestamp)")
if "idx_acc" not in exists:
conn.execute("CREATE INDEX idx_acc ON chat_logs(acc_id)")
else:
conn.execute("""
CREATE TABLE IF NOT EXISTS chat_logs (

View File

@@ -6,6 +6,17 @@ from typing import Optional, Dict, Any, List
from dataclasses import dataclass, asdict
from collections import defaultdict
_DB_TYPE = os.getenv("DB_TYPE", "sqlite").lower()
_MYSQL_HOST = os.getenv("MYSQL_HOST", "127.0.0.1")
_MYSQL_PORT = int(os.getenv("MYSQL_PORT", "3306"))
_MYSQL_USER = os.getenv("MYSQL_USER", "root")
_MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs")
def _is_mysql() -> bool:
return _DB_TYPE in ("mysql", "mariadb")
@dataclass
class CustomerProfile:
@@ -163,14 +174,64 @@ class CustomerDatabase:
self.db_path = db_path
self.customers_file = os.path.join(db_path, "customers.json")
self._ensure_db()
def _get_mysql_conn(self):
import pymysql
return pymysql.connect(
host=_MYSQL_HOST,
port=_MYSQL_PORT,
user=_MYSQL_USER,
password=_MYSQL_PASSWORD,
database=_MYSQL_DATABASE,
charset="utf8mb4",
cursorclass=pymysql.cursors.DictCursor,
autocommit=False,
)
def _ensure_db(self):
if _is_mysql():
with self._get_mysql_conn() as conn:
with conn.cursor() as cur:
cur.execute(
"""
CREATE TABLE IF NOT EXISTS customer_profiles (
customer_id VARCHAR(128) PRIMARY KEY,
profile_json LONGTEXT NOT NULL,
last_update DATETIME NOT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
"""
)
cur.execute("SHOW INDEX FROM customer_profiles")
exists = {str(r.get("Key_name", "")) for r in cur.fetchall()}
if "idx_last_update" not in exists:
cur.execute("CREATE INDEX idx_last_update ON customer_profiles(last_update)")
conn.commit()
return
if not os.path.exists(self.db_path):
os.makedirs(self.db_path)
if not os.path.exists(self.customers_file):
self._save_customers({})
def _load_customers(self) -> Dict[str, dict]:
if _is_mysql():
out: Dict[str, dict] = {}
try:
with self._get_mysql_conn() as conn:
with conn.cursor() as cur:
cur.execute("SELECT customer_id, profile_json FROM customer_profiles")
rows = cur.fetchall()
for r in rows:
cid = str(r.get("customer_id") or "")
if not cid:
continue
try:
out[cid] = json.loads(r.get("profile_json") or "{}")
except Exception:
out[cid] = {}
except Exception:
return {}
return out
try:
with open(self.customers_file, 'r', encoding='utf-8') as f:
return json.load(f)
@@ -178,10 +239,41 @@ class CustomerDatabase:
return {}
def _save_customers(self, customers: Dict[str, dict]):
if _is_mysql():
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
with self._get_mysql_conn() as conn:
with conn.cursor() as cur:
for cid, data in (customers or {}).items():
cur.execute(
"""
REPLACE INTO customer_profiles (customer_id, profile_json, last_update)
VALUES (%s, %s, %s)
""",
(cid, json.dumps(data, ensure_ascii=False), now),
)
conn.commit()
return
with open(self.customers_file, 'w', encoding='utf-8') as f:
json.dump(customers, f, ensure_ascii=False, indent=2)
def get_customer(self, customer_id: str) -> CustomerProfile:
if _is_mysql():
try:
with self._get_mysql_conn() as conn:
with conn.cursor() as cur:
cur.execute(
"SELECT profile_json FROM customer_profiles WHERE customer_id=%s LIMIT 1",
(customer_id,),
)
row = cur.fetchone()
if row and row.get("profile_json"):
data = json.loads(row.get("profile_json") or "{}")
else:
data = {}
data.pop('customer_id', None)
return CustomerProfile(customer_id=customer_id, **data)
except Exception:
return CustomerProfile(customer_id=customer_id)
customers = self._load_customers()
data = customers.get(customer_id, {})
# 确保不重复传递 customer_id
@@ -190,6 +282,22 @@ class CustomerDatabase:
def save_customer(self, profile: CustomerProfile):
profile.last_update = datetime.now().isoformat()
if _is_mysql():
with self._get_mysql_conn() as conn:
with conn.cursor() as cur:
cur.execute(
"""
REPLACE INTO customer_profiles (customer_id, profile_json, last_update)
VALUES (%s, %s, %s)
""",
(
profile.customer_id,
json.dumps(asdict(profile), ensure_ascii=False),
datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
),
)
conn.commit()
return
customers = self._load_customers()
customers[profile.customer_id] = asdict(profile)
self._save_customers(customers)

View File

@@ -15,6 +15,46 @@ _MYSQL_USER = os.getenv("MYSQL_USER", "root")
_MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs")
class _CompatResult:
def __init__(self, rows=None, rowcount: int = 0, lastrowid: int = 0):
self._rows = rows or []
self.rowcount = rowcount
self.lastrowid = lastrowid
def fetchall(self):
return self._rows
def fetchone(self):
return self._rows[0] if self._rows else None
class _PyMySQLCompatConn:
def __init__(self, conn):
self._conn = conn
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
if exc_type:
try:
self._conn.rollback()
except Exception:
pass
self._conn.close()
def execute(self, query: str, args=None):
cur = self._conn.cursor()
cur.execute(query, args or ())
rows = cur.fetchall() if cur.description else []
res = _CompatResult(rows=rows, rowcount=cur.rowcount, lastrowid=getattr(cur, "lastrowid", 0))
cur.close()
return res
def commit(self):
self._conn.commit()
def _is_mysql() -> bool:
return _DB_TYPE in ("mysql", "mariadb")
@@ -25,7 +65,7 @@ def _sql(query: str) -> str:
def _get_conn() -> sqlite3.Connection:
if _is_mysql():
import pymysql
return pymysql.connect(
conn = pymysql.connect(
host=_MYSQL_HOST,
port=_MYSQL_PORT,
user=_MYSQL_USER,
@@ -35,6 +75,7 @@ def _get_conn() -> sqlite3.Connection:
cursorclass=pymysql.cursors.DictCursor,
autocommit=False,
)
return _PyMySQLCompatConn(conn)
os.makedirs(os.path.dirname(_DB_PATH), exist_ok=True)
conn = sqlite3.connect(_DB_PATH)
conn.row_factory = sqlite3.Row
@@ -60,10 +101,16 @@ def _init_db():
timestamp DATETIME NOT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
""")
conn.execute("CREATE INDEX IF NOT EXISTS idx_deal_date ON deal_outcomes(date)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_deal_customer ON deal_outcomes(customer_id)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_deal_acc ON deal_outcomes(acc_id)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_deal_outcome ON deal_outcomes(outcome)")
idx_rows = conn.execute("SHOW INDEX FROM deal_outcomes").fetchall()
exists = {str(r.get("Key_name", "")) for r in idx_rows}
if "idx_deal_date" not in exists:
conn.execute("CREATE INDEX idx_deal_date ON deal_outcomes(date)")
if "idx_deal_customer" not in exists:
conn.execute("CREATE INDEX idx_deal_customer ON deal_outcomes(customer_id)")
if "idx_deal_acc" not in exists:
conn.execute("CREATE INDEX idx_deal_acc ON deal_outcomes(acc_id)")
if "idx_deal_outcome" not in exists:
conn.execute("CREATE INDEX idx_deal_outcome ON deal_outcomes(outcome)")
else:
conn.execute("""
CREATE TABLE IF NOT EXISTS deal_outcomes (

View File

@@ -17,6 +17,46 @@ _MYSQL_USER = os.getenv("MYSQL_USER", "root")
_MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs")
class _CompatResult:
def __init__(self, rows=None, rowcount: int = 0, lastrowid: int = 0):
self._rows = rows or []
self.rowcount = rowcount
self.lastrowid = lastrowid
def fetchall(self):
return self._rows
def fetchone(self):
return self._rows[0] if self._rows else None
class _PyMySQLCompatConn:
def __init__(self, conn):
self._conn = conn
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
if exc_type:
try:
self._conn.rollback()
except Exception:
pass
self._conn.close()
def execute(self, query: str, args=None):
cur = self._conn.cursor()
cur.execute(query, args or ())
rows = cur.fetchall() if cur.description else []
res = _CompatResult(rows=rows, rowcount=cur.rowcount, lastrowid=getattr(cur, "lastrowid", 0))
cur.close()
return res
def commit(self):
self._conn.commit()
def _is_mysql() -> bool:
return _DB_TYPE in ("mysql", "mariadb")
@@ -27,7 +67,7 @@ def _sql(query: str) -> str:
def _get_conn() -> sqlite3.Connection:
if _is_mysql():
import pymysql
return pymysql.connect(
conn = pymysql.connect(
host=_MYSQL_HOST,
port=_MYSQL_PORT,
user=_MYSQL_USER,
@@ -37,6 +77,7 @@ def _get_conn() -> sqlite3.Connection:
cursorclass=pymysql.cursors.DictCursor,
autocommit=False,
)
return _PyMySQLCompatConn(conn)
os.makedirs(os.path.dirname(_DB_PATH), exist_ok=True)
conn = sqlite3.connect(_DB_PATH)
conn.row_factory = sqlite3.Row

View File

@@ -90,9 +90,14 @@ class ImageTaskManager:
FOREIGN KEY (task_id) REFERENCES image_tasks(task_id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
''')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_customer ON image_tasks(customer_id)')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_status ON image_tasks(status)')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_created ON image_tasks(created_at)')
cursor.execute("SHOW INDEX FROM image_tasks")
exists = {str(r.get("Key_name", "")) for r in cursor.fetchall()}
if "idx_customer" not in exists:
cursor.execute('CREATE INDEX idx_customer ON image_tasks(customer_id)')
if "idx_status" not in exists:
cursor.execute('CREATE INDEX idx_status ON image_tasks(status)')
if "idx_created" not in exists:
cursor.execute('CREATE INDEX idx_created ON image_tasks(created_at)')
conn.commit()
conn.close()
else:

View File

@@ -88,9 +88,14 @@ class TaskManager:
result TEXT
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
''')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_status ON tasks(status)')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_customer ON tasks(customer_id)')
cursor.execute('CREATE INDEX IF NOT EXISTS idx_created ON tasks(created_at)')
cursor.execute("SHOW INDEX FROM tasks")
exists = {str(r.get("Key_name", "")) for r in cursor.fetchall()}
if "idx_status" not in exists:
cursor.execute('CREATE INDEX idx_status ON tasks(status)')
if "idx_customer" not in exists:
cursor.execute('CREATE INDEX idx_customer ON tasks(customer_id)')
if "idx_created" not in exists:
cursor.execute('CREATE INDEX idx_created ON tasks(created_at)')
conn.commit()
conn.close()
else: