feat: migrate core data stores to MySQL with compatibility fixes
This commit is contained in:
@@ -16,6 +16,51 @@ _MYSQL_USER = os.getenv("MYSQL_USER", "root")
|
||||
_MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
|
||||
_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs")
|
||||
|
||||
|
||||
class _CompatResult:
|
||||
def __init__(self, rows=None, rowcount: int = 0, lastrowid: int = 0):
|
||||
self._rows = rows or []
|
||||
self.rowcount = rowcount
|
||||
self.lastrowid = lastrowid
|
||||
|
||||
def fetchall(self):
|
||||
return self._rows
|
||||
|
||||
def fetchone(self):
|
||||
return self._rows[0] if self._rows else None
|
||||
|
||||
|
||||
class _PyMySQLCompatConn:
|
||||
"""让 pymysql 连接兼容 sqlite 的 conn.execute 用法。"""
|
||||
|
||||
def __init__(self, conn):
|
||||
self._conn = conn
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
if exc_type:
|
||||
try:
|
||||
self._conn.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
self._conn.close()
|
||||
|
||||
def execute(self, query: str, args=None):
|
||||
cur = self._conn.cursor()
|
||||
cur.execute(query, args or ())
|
||||
rows = cur.fetchall() if cur.description else []
|
||||
res = _CompatResult(rows=rows, rowcount=cur.rowcount, lastrowid=getattr(cur, "lastrowid", 0))
|
||||
cur.close()
|
||||
return res
|
||||
|
||||
def commit(self):
|
||||
self._conn.commit()
|
||||
|
||||
def close(self):
|
||||
self._conn.close()
|
||||
|
||||
def _is_mysql() -> bool:
|
||||
return _DB_TYPE in ("mysql", "mariadb")
|
||||
|
||||
@@ -26,7 +71,7 @@ def _sql(query: str) -> str:
|
||||
def _get_conn() -> sqlite3.Connection:
|
||||
if _is_mysql():
|
||||
import pymysql
|
||||
return pymysql.connect(
|
||||
conn = pymysql.connect(
|
||||
host=_MYSQL_HOST,
|
||||
port=_MYSQL_PORT,
|
||||
user=_MYSQL_USER,
|
||||
@@ -36,6 +81,7 @@ def _get_conn() -> sqlite3.Connection:
|
||||
cursorclass=pymysql.cursors.DictCursor,
|
||||
autocommit=False,
|
||||
)
|
||||
return _PyMySQLCompatConn(conn)
|
||||
os.makedirs(os.path.dirname(_DB_PATH), exist_ok=True)
|
||||
conn = sqlite3.connect(_DB_PATH)
|
||||
conn.row_factory = sqlite3.Row
|
||||
@@ -59,9 +105,14 @@ def init_db():
|
||||
timestamp DATETIME NOT NULL
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
|
||||
""")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_customer ON chat_logs(customer_id)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_ts ON chat_logs(timestamp)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_acc ON chat_logs(acc_id)")
|
||||
idx_rows = conn.execute("SHOW INDEX FROM chat_logs").fetchall()
|
||||
exists = {str(r.get("Key_name", "")) for r in idx_rows}
|
||||
if "idx_customer" not in exists:
|
||||
conn.execute("CREATE INDEX idx_customer ON chat_logs(customer_id)")
|
||||
if "idx_ts" not in exists:
|
||||
conn.execute("CREATE INDEX idx_ts ON chat_logs(timestamp)")
|
||||
if "idx_acc" not in exists:
|
||||
conn.execute("CREATE INDEX idx_acc ON chat_logs(acc_id)")
|
||||
else:
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS chat_logs (
|
||||
|
||||
@@ -6,6 +6,17 @@ from typing import Optional, Dict, Any, List
|
||||
from dataclasses import dataclass, asdict
|
||||
from collections import defaultdict
|
||||
|
||||
_DB_TYPE = os.getenv("DB_TYPE", "sqlite").lower()
|
||||
_MYSQL_HOST = os.getenv("MYSQL_HOST", "127.0.0.1")
|
||||
_MYSQL_PORT = int(os.getenv("MYSQL_PORT", "3306"))
|
||||
_MYSQL_USER = os.getenv("MYSQL_USER", "root")
|
||||
_MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
|
||||
_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs")
|
||||
|
||||
|
||||
def _is_mysql() -> bool:
|
||||
return _DB_TYPE in ("mysql", "mariadb")
|
||||
|
||||
|
||||
@dataclass
|
||||
class CustomerProfile:
|
||||
@@ -163,14 +174,64 @@ class CustomerDatabase:
|
||||
self.db_path = db_path
|
||||
self.customers_file = os.path.join(db_path, "customers.json")
|
||||
self._ensure_db()
|
||||
|
||||
def _get_mysql_conn(self):
|
||||
import pymysql
|
||||
return pymysql.connect(
|
||||
host=_MYSQL_HOST,
|
||||
port=_MYSQL_PORT,
|
||||
user=_MYSQL_USER,
|
||||
password=_MYSQL_PASSWORD,
|
||||
database=_MYSQL_DATABASE,
|
||||
charset="utf8mb4",
|
||||
cursorclass=pymysql.cursors.DictCursor,
|
||||
autocommit=False,
|
||||
)
|
||||
|
||||
def _ensure_db(self):
|
||||
if _is_mysql():
|
||||
with self._get_mysql_conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS customer_profiles (
|
||||
customer_id VARCHAR(128) PRIMARY KEY,
|
||||
profile_json LONGTEXT NOT NULL,
|
||||
last_update DATETIME NOT NULL
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
|
||||
"""
|
||||
)
|
||||
cur.execute("SHOW INDEX FROM customer_profiles")
|
||||
exists = {str(r.get("Key_name", "")) for r in cur.fetchall()}
|
||||
if "idx_last_update" not in exists:
|
||||
cur.execute("CREATE INDEX idx_last_update ON customer_profiles(last_update)")
|
||||
conn.commit()
|
||||
return
|
||||
|
||||
if not os.path.exists(self.db_path):
|
||||
os.makedirs(self.db_path)
|
||||
if not os.path.exists(self.customers_file):
|
||||
self._save_customers({})
|
||||
|
||||
def _load_customers(self) -> Dict[str, dict]:
|
||||
if _is_mysql():
|
||||
out: Dict[str, dict] = {}
|
||||
try:
|
||||
with self._get_mysql_conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("SELECT customer_id, profile_json FROM customer_profiles")
|
||||
rows = cur.fetchall()
|
||||
for r in rows:
|
||||
cid = str(r.get("customer_id") or "")
|
||||
if not cid:
|
||||
continue
|
||||
try:
|
||||
out[cid] = json.loads(r.get("profile_json") or "{}")
|
||||
except Exception:
|
||||
out[cid] = {}
|
||||
except Exception:
|
||||
return {}
|
||||
return out
|
||||
try:
|
||||
with open(self.customers_file, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
@@ -178,10 +239,41 @@ class CustomerDatabase:
|
||||
return {}
|
||||
|
||||
def _save_customers(self, customers: Dict[str, dict]):
|
||||
if _is_mysql():
|
||||
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
with self._get_mysql_conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
for cid, data in (customers or {}).items():
|
||||
cur.execute(
|
||||
"""
|
||||
REPLACE INTO customer_profiles (customer_id, profile_json, last_update)
|
||||
VALUES (%s, %s, %s)
|
||||
""",
|
||||
(cid, json.dumps(data, ensure_ascii=False), now),
|
||||
)
|
||||
conn.commit()
|
||||
return
|
||||
with open(self.customers_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(customers, f, ensure_ascii=False, indent=2)
|
||||
|
||||
def get_customer(self, customer_id: str) -> CustomerProfile:
|
||||
if _is_mysql():
|
||||
try:
|
||||
with self._get_mysql_conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"SELECT profile_json FROM customer_profiles WHERE customer_id=%s LIMIT 1",
|
||||
(customer_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if row and row.get("profile_json"):
|
||||
data = json.loads(row.get("profile_json") or "{}")
|
||||
else:
|
||||
data = {}
|
||||
data.pop('customer_id', None)
|
||||
return CustomerProfile(customer_id=customer_id, **data)
|
||||
except Exception:
|
||||
return CustomerProfile(customer_id=customer_id)
|
||||
customers = self._load_customers()
|
||||
data = customers.get(customer_id, {})
|
||||
# 确保不重复传递 customer_id
|
||||
@@ -190,6 +282,22 @@ class CustomerDatabase:
|
||||
|
||||
def save_customer(self, profile: CustomerProfile):
|
||||
profile.last_update = datetime.now().isoformat()
|
||||
if _is_mysql():
|
||||
with self._get_mysql_conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
REPLACE INTO customer_profiles (customer_id, profile_json, last_update)
|
||||
VALUES (%s, %s, %s)
|
||||
""",
|
||||
(
|
||||
profile.customer_id,
|
||||
json.dumps(asdict(profile), ensure_ascii=False),
|
||||
datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
return
|
||||
customers = self._load_customers()
|
||||
customers[profile.customer_id] = asdict(profile)
|
||||
self._save_customers(customers)
|
||||
|
||||
@@ -15,6 +15,46 @@ _MYSQL_USER = os.getenv("MYSQL_USER", "root")
|
||||
_MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
|
||||
_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs")
|
||||
|
||||
|
||||
class _CompatResult:
|
||||
def __init__(self, rows=None, rowcount: int = 0, lastrowid: int = 0):
|
||||
self._rows = rows or []
|
||||
self.rowcount = rowcount
|
||||
self.lastrowid = lastrowid
|
||||
|
||||
def fetchall(self):
|
||||
return self._rows
|
||||
|
||||
def fetchone(self):
|
||||
return self._rows[0] if self._rows else None
|
||||
|
||||
|
||||
class _PyMySQLCompatConn:
|
||||
def __init__(self, conn):
|
||||
self._conn = conn
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
if exc_type:
|
||||
try:
|
||||
self._conn.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
self._conn.close()
|
||||
|
||||
def execute(self, query: str, args=None):
|
||||
cur = self._conn.cursor()
|
||||
cur.execute(query, args or ())
|
||||
rows = cur.fetchall() if cur.description else []
|
||||
res = _CompatResult(rows=rows, rowcount=cur.rowcount, lastrowid=getattr(cur, "lastrowid", 0))
|
||||
cur.close()
|
||||
return res
|
||||
|
||||
def commit(self):
|
||||
self._conn.commit()
|
||||
|
||||
def _is_mysql() -> bool:
|
||||
return _DB_TYPE in ("mysql", "mariadb")
|
||||
|
||||
@@ -25,7 +65,7 @@ def _sql(query: str) -> str:
|
||||
def _get_conn() -> sqlite3.Connection:
|
||||
if _is_mysql():
|
||||
import pymysql
|
||||
return pymysql.connect(
|
||||
conn = pymysql.connect(
|
||||
host=_MYSQL_HOST,
|
||||
port=_MYSQL_PORT,
|
||||
user=_MYSQL_USER,
|
||||
@@ -35,6 +75,7 @@ def _get_conn() -> sqlite3.Connection:
|
||||
cursorclass=pymysql.cursors.DictCursor,
|
||||
autocommit=False,
|
||||
)
|
||||
return _PyMySQLCompatConn(conn)
|
||||
os.makedirs(os.path.dirname(_DB_PATH), exist_ok=True)
|
||||
conn = sqlite3.connect(_DB_PATH)
|
||||
conn.row_factory = sqlite3.Row
|
||||
@@ -60,10 +101,16 @@ def _init_db():
|
||||
timestamp DATETIME NOT NULL
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
|
||||
""")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_deal_date ON deal_outcomes(date)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_deal_customer ON deal_outcomes(customer_id)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_deal_acc ON deal_outcomes(acc_id)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_deal_outcome ON deal_outcomes(outcome)")
|
||||
idx_rows = conn.execute("SHOW INDEX FROM deal_outcomes").fetchall()
|
||||
exists = {str(r.get("Key_name", "")) for r in idx_rows}
|
||||
if "idx_deal_date" not in exists:
|
||||
conn.execute("CREATE INDEX idx_deal_date ON deal_outcomes(date)")
|
||||
if "idx_deal_customer" not in exists:
|
||||
conn.execute("CREATE INDEX idx_deal_customer ON deal_outcomes(customer_id)")
|
||||
if "idx_deal_acc" not in exists:
|
||||
conn.execute("CREATE INDEX idx_deal_acc ON deal_outcomes(acc_id)")
|
||||
if "idx_deal_outcome" not in exists:
|
||||
conn.execute("CREATE INDEX idx_deal_outcome ON deal_outcomes(outcome)")
|
||||
else:
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS deal_outcomes (
|
||||
|
||||
@@ -17,6 +17,46 @@ _MYSQL_USER = os.getenv("MYSQL_USER", "root")
|
||||
_MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
|
||||
_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs")
|
||||
|
||||
|
||||
class _CompatResult:
|
||||
def __init__(self, rows=None, rowcount: int = 0, lastrowid: int = 0):
|
||||
self._rows = rows or []
|
||||
self.rowcount = rowcount
|
||||
self.lastrowid = lastrowid
|
||||
|
||||
def fetchall(self):
|
||||
return self._rows
|
||||
|
||||
def fetchone(self):
|
||||
return self._rows[0] if self._rows else None
|
||||
|
||||
|
||||
class _PyMySQLCompatConn:
|
||||
def __init__(self, conn):
|
||||
self._conn = conn
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
if exc_type:
|
||||
try:
|
||||
self._conn.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
self._conn.close()
|
||||
|
||||
def execute(self, query: str, args=None):
|
||||
cur = self._conn.cursor()
|
||||
cur.execute(query, args or ())
|
||||
rows = cur.fetchall() if cur.description else []
|
||||
res = _CompatResult(rows=rows, rowcount=cur.rowcount, lastrowid=getattr(cur, "lastrowid", 0))
|
||||
cur.close()
|
||||
return res
|
||||
|
||||
def commit(self):
|
||||
self._conn.commit()
|
||||
|
||||
def _is_mysql() -> bool:
|
||||
return _DB_TYPE in ("mysql", "mariadb")
|
||||
|
||||
@@ -27,7 +67,7 @@ def _sql(query: str) -> str:
|
||||
def _get_conn() -> sqlite3.Connection:
|
||||
if _is_mysql():
|
||||
import pymysql
|
||||
return pymysql.connect(
|
||||
conn = pymysql.connect(
|
||||
host=_MYSQL_HOST,
|
||||
port=_MYSQL_PORT,
|
||||
user=_MYSQL_USER,
|
||||
@@ -37,6 +77,7 @@ def _get_conn() -> sqlite3.Connection:
|
||||
cursorclass=pymysql.cursors.DictCursor,
|
||||
autocommit=False,
|
||||
)
|
||||
return _PyMySQLCompatConn(conn)
|
||||
os.makedirs(os.path.dirname(_DB_PATH), exist_ok=True)
|
||||
conn = sqlite3.connect(_DB_PATH)
|
||||
conn.row_factory = sqlite3.Row
|
||||
|
||||
@@ -90,9 +90,14 @@ class ImageTaskManager:
|
||||
FOREIGN KEY (task_id) REFERENCES image_tasks(task_id)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
|
||||
''')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_customer ON image_tasks(customer_id)')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_status ON image_tasks(status)')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_created ON image_tasks(created_at)')
|
||||
cursor.execute("SHOW INDEX FROM image_tasks")
|
||||
exists = {str(r.get("Key_name", "")) for r in cursor.fetchall()}
|
||||
if "idx_customer" not in exists:
|
||||
cursor.execute('CREATE INDEX idx_customer ON image_tasks(customer_id)')
|
||||
if "idx_status" not in exists:
|
||||
cursor.execute('CREATE INDEX idx_status ON image_tasks(status)')
|
||||
if "idx_created" not in exists:
|
||||
cursor.execute('CREATE INDEX idx_created ON image_tasks(created_at)')
|
||||
conn.commit()
|
||||
conn.close()
|
||||
else:
|
||||
|
||||
@@ -88,9 +88,14 @@ class TaskManager:
|
||||
result TEXT
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
|
||||
''')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_status ON tasks(status)')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_customer ON tasks(customer_id)')
|
||||
cursor.execute('CREATE INDEX IF NOT EXISTS idx_created ON tasks(created_at)')
|
||||
cursor.execute("SHOW INDEX FROM tasks")
|
||||
exists = {str(r.get("Key_name", "")) for r in cursor.fetchall()}
|
||||
if "idx_status" not in exists:
|
||||
cursor.execute('CREATE INDEX idx_status ON tasks(status)')
|
||||
if "idx_customer" not in exists:
|
||||
cursor.execute('CREATE INDEX idx_customer ON tasks(customer_id)')
|
||||
if "idx_created" not in exists:
|
||||
cursor.execute('CREATE INDEX idx_created ON tasks(created_at)')
|
||||
conn.commit()
|
||||
conn.close()
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user