feat: migrate core data stores to MySQL with compatibility fixes
This commit is contained in:
@@ -6,6 +6,17 @@ from typing import Optional, Dict, Any, List
|
||||
from dataclasses import dataclass, asdict
|
||||
from collections import defaultdict
|
||||
|
||||
_DB_TYPE = os.getenv("DB_TYPE", "sqlite").lower()
|
||||
_MYSQL_HOST = os.getenv("MYSQL_HOST", "127.0.0.1")
|
||||
_MYSQL_PORT = int(os.getenv("MYSQL_PORT", "3306"))
|
||||
_MYSQL_USER = os.getenv("MYSQL_USER", "root")
|
||||
_MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
|
||||
_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs")
|
||||
|
||||
|
||||
def _is_mysql() -> bool:
|
||||
return _DB_TYPE in ("mysql", "mariadb")
|
||||
|
||||
|
||||
@dataclass
|
||||
class CustomerProfile:
|
||||
@@ -163,14 +174,64 @@ class CustomerDatabase:
|
||||
self.db_path = db_path
|
||||
self.customers_file = os.path.join(db_path, "customers.json")
|
||||
self._ensure_db()
|
||||
|
||||
def _get_mysql_conn(self):
|
||||
import pymysql
|
||||
return pymysql.connect(
|
||||
host=_MYSQL_HOST,
|
||||
port=_MYSQL_PORT,
|
||||
user=_MYSQL_USER,
|
||||
password=_MYSQL_PASSWORD,
|
||||
database=_MYSQL_DATABASE,
|
||||
charset="utf8mb4",
|
||||
cursorclass=pymysql.cursors.DictCursor,
|
||||
autocommit=False,
|
||||
)
|
||||
|
||||
def _ensure_db(self):
|
||||
if _is_mysql():
|
||||
with self._get_mysql_conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS customer_profiles (
|
||||
customer_id VARCHAR(128) PRIMARY KEY,
|
||||
profile_json LONGTEXT NOT NULL,
|
||||
last_update DATETIME NOT NULL
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
|
||||
"""
|
||||
)
|
||||
cur.execute("SHOW INDEX FROM customer_profiles")
|
||||
exists = {str(r.get("Key_name", "")) for r in cur.fetchall()}
|
||||
if "idx_last_update" not in exists:
|
||||
cur.execute("CREATE INDEX idx_last_update ON customer_profiles(last_update)")
|
||||
conn.commit()
|
||||
return
|
||||
|
||||
if not os.path.exists(self.db_path):
|
||||
os.makedirs(self.db_path)
|
||||
if not os.path.exists(self.customers_file):
|
||||
self._save_customers({})
|
||||
|
||||
def _load_customers(self) -> Dict[str, dict]:
|
||||
if _is_mysql():
|
||||
out: Dict[str, dict] = {}
|
||||
try:
|
||||
with self._get_mysql_conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("SELECT customer_id, profile_json FROM customer_profiles")
|
||||
rows = cur.fetchall()
|
||||
for r in rows:
|
||||
cid = str(r.get("customer_id") or "")
|
||||
if not cid:
|
||||
continue
|
||||
try:
|
||||
out[cid] = json.loads(r.get("profile_json") or "{}")
|
||||
except Exception:
|
||||
out[cid] = {}
|
||||
except Exception:
|
||||
return {}
|
||||
return out
|
||||
try:
|
||||
with open(self.customers_file, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
@@ -178,10 +239,41 @@ class CustomerDatabase:
|
||||
return {}
|
||||
|
||||
def _save_customers(self, customers: Dict[str, dict]):
|
||||
if _is_mysql():
|
||||
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
with self._get_mysql_conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
for cid, data in (customers or {}).items():
|
||||
cur.execute(
|
||||
"""
|
||||
REPLACE INTO customer_profiles (customer_id, profile_json, last_update)
|
||||
VALUES (%s, %s, %s)
|
||||
""",
|
||||
(cid, json.dumps(data, ensure_ascii=False), now),
|
||||
)
|
||||
conn.commit()
|
||||
return
|
||||
with open(self.customers_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(customers, f, ensure_ascii=False, indent=2)
|
||||
|
||||
def get_customer(self, customer_id: str) -> CustomerProfile:
|
||||
if _is_mysql():
|
||||
try:
|
||||
with self._get_mysql_conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"SELECT profile_json FROM customer_profiles WHERE customer_id=%s LIMIT 1",
|
||||
(customer_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if row and row.get("profile_json"):
|
||||
data = json.loads(row.get("profile_json") or "{}")
|
||||
else:
|
||||
data = {}
|
||||
data.pop('customer_id', None)
|
||||
return CustomerProfile(customer_id=customer_id, **data)
|
||||
except Exception:
|
||||
return CustomerProfile(customer_id=customer_id)
|
||||
customers = self._load_customers()
|
||||
data = customers.get(customer_id, {})
|
||||
# 确保不重复传递 customer_id
|
||||
@@ -190,6 +282,22 @@ class CustomerDatabase:
|
||||
|
||||
def save_customer(self, profile: CustomerProfile):
|
||||
profile.last_update = datetime.now().isoformat()
|
||||
if _is_mysql():
|
||||
with self._get_mysql_conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"""
|
||||
REPLACE INTO customer_profiles (customer_id, profile_json, last_update)
|
||||
VALUES (%s, %s, %s)
|
||||
""",
|
||||
(
|
||||
profile.customer_id,
|
||||
json.dumps(asdict(profile), ensure_ascii=False),
|
||||
datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
return
|
||||
customers = self._load_customers()
|
||||
customers[profile.customer_id] = asdict(profile)
|
||||
self._save_customers(customers)
|
||||
|
||||
Reference in New Issue
Block a user