diff --git a/db/chat_log_db.py b/db/chat_log_db.py index d900566..e967181 100755 --- a/db/chat_log_db.py +++ b/db/chat_log_db.py @@ -16,6 +16,51 @@ _MYSQL_USER = os.getenv("MYSQL_USER", "root") _MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "") _MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs") + +class _CompatResult: + def __init__(self, rows=None, rowcount: int = 0, lastrowid: int = 0): + self._rows = rows or [] + self.rowcount = rowcount + self.lastrowid = lastrowid + + def fetchall(self): + return self._rows + + def fetchone(self): + return self._rows[0] if self._rows else None + + +class _PyMySQLCompatConn: + """让 pymysql 连接兼容 sqlite 的 conn.execute 用法。""" + + def __init__(self, conn): + self._conn = conn + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + if exc_type: + try: + self._conn.rollback() + except Exception: + pass + self._conn.close() + + def execute(self, query: str, args=None): + cur = self._conn.cursor() + cur.execute(query, args or ()) + rows = cur.fetchall() if cur.description else [] + res = _CompatResult(rows=rows, rowcount=cur.rowcount, lastrowid=getattr(cur, "lastrowid", 0)) + cur.close() + return res + + def commit(self): + self._conn.commit() + + def close(self): + self._conn.close() + def _is_mysql() -> bool: return _DB_TYPE in ("mysql", "mariadb") @@ -26,7 +71,7 @@ def _sql(query: str) -> str: def _get_conn() -> sqlite3.Connection: if _is_mysql(): import pymysql - return pymysql.connect( + conn = pymysql.connect( host=_MYSQL_HOST, port=_MYSQL_PORT, user=_MYSQL_USER, @@ -36,6 +81,7 @@ def _get_conn() -> sqlite3.Connection: cursorclass=pymysql.cursors.DictCursor, autocommit=False, ) + return _PyMySQLCompatConn(conn) os.makedirs(os.path.dirname(_DB_PATH), exist_ok=True) conn = sqlite3.connect(_DB_PATH) conn.row_factory = sqlite3.Row @@ -59,9 +105,14 @@ def init_db(): timestamp DATETIME NOT NULL ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 """) - conn.execute("CREATE INDEX IF NOT EXISTS idx_customer ON chat_logs(customer_id)") - conn.execute("CREATE INDEX IF NOT EXISTS idx_ts ON chat_logs(timestamp)") - conn.execute("CREATE INDEX IF NOT EXISTS idx_acc ON chat_logs(acc_id)") + idx_rows = conn.execute("SHOW INDEX FROM chat_logs").fetchall() + exists = {str(r.get("Key_name", "")) for r in idx_rows} + if "idx_customer" not in exists: + conn.execute("CREATE INDEX idx_customer ON chat_logs(customer_id)") + if "idx_ts" not in exists: + conn.execute("CREATE INDEX idx_ts ON chat_logs(timestamp)") + if "idx_acc" not in exists: + conn.execute("CREATE INDEX idx_acc ON chat_logs(acc_id)") else: conn.execute(""" CREATE TABLE IF NOT EXISTS chat_logs ( diff --git a/db/customer_db.py b/db/customer_db.py index b3e8725..9f80065 100755 --- a/db/customer_db.py +++ b/db/customer_db.py @@ -6,6 +6,17 @@ from typing import Optional, Dict, Any, List from dataclasses import dataclass, asdict from collections import defaultdict +_DB_TYPE = os.getenv("DB_TYPE", "sqlite").lower() +_MYSQL_HOST = os.getenv("MYSQL_HOST", "127.0.0.1") +_MYSQL_PORT = int(os.getenv("MYSQL_PORT", "3306")) +_MYSQL_USER = os.getenv("MYSQL_USER", "root") +_MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "") +_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs") + + +def _is_mysql() -> bool: + return _DB_TYPE in ("mysql", "mariadb") + @dataclass class CustomerProfile: @@ -163,14 +174,64 @@ class CustomerDatabase: self.db_path = db_path self.customers_file = os.path.join(db_path, "customers.json") self._ensure_db() + + def _get_mysql_conn(self): + import pymysql + return pymysql.connect( + host=_MYSQL_HOST, + port=_MYSQL_PORT, + user=_MYSQL_USER, + password=_MYSQL_PASSWORD, + database=_MYSQL_DATABASE, + charset="utf8mb4", + cursorclass=pymysql.cursors.DictCursor, + autocommit=False, + ) def _ensure_db(self): + if _is_mysql(): + with self._get_mysql_conn() as conn: + with conn.cursor() as cur: + cur.execute( + """ + CREATE TABLE IF NOT EXISTS customer_profiles ( + customer_id VARCHAR(128) PRIMARY KEY, + profile_json LONGTEXT NOT NULL, + last_update DATETIME NOT NULL + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 + """ + ) + cur.execute("SHOW INDEX FROM customer_profiles") + exists = {str(r.get("Key_name", "")) for r in cur.fetchall()} + if "idx_last_update" not in exists: + cur.execute("CREATE INDEX idx_last_update ON customer_profiles(last_update)") + conn.commit() + return + if not os.path.exists(self.db_path): os.makedirs(self.db_path) if not os.path.exists(self.customers_file): self._save_customers({}) def _load_customers(self) -> Dict[str, dict]: + if _is_mysql(): + out: Dict[str, dict] = {} + try: + with self._get_mysql_conn() as conn: + with conn.cursor() as cur: + cur.execute("SELECT customer_id, profile_json FROM customer_profiles") + rows = cur.fetchall() + for r in rows: + cid = str(r.get("customer_id") or "") + if not cid: + continue + try: + out[cid] = json.loads(r.get("profile_json") or "{}") + except Exception: + out[cid] = {} + except Exception: + return {} + return out try: with open(self.customers_file, 'r', encoding='utf-8') as f: return json.load(f) @@ -178,10 +239,41 @@ class CustomerDatabase: return {} def _save_customers(self, customers: Dict[str, dict]): + if _is_mysql(): + now = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + with self._get_mysql_conn() as conn: + with conn.cursor() as cur: + for cid, data in (customers or {}).items(): + cur.execute( + """ + REPLACE INTO customer_profiles (customer_id, profile_json, last_update) + VALUES (%s, %s, %s) + """, + (cid, json.dumps(data, ensure_ascii=False), now), + ) + conn.commit() + return with open(self.customers_file, 'w', encoding='utf-8') as f: json.dump(customers, f, ensure_ascii=False, indent=2) def get_customer(self, customer_id: str) -> CustomerProfile: + if _is_mysql(): + try: + with self._get_mysql_conn() as conn: + with conn.cursor() as cur: + cur.execute( + "SELECT profile_json FROM customer_profiles WHERE customer_id=%s LIMIT 1", + (customer_id,), + ) + row = cur.fetchone() + if row and row.get("profile_json"): + data = json.loads(row.get("profile_json") or "{}") + else: + data = {} + data.pop('customer_id', None) + return CustomerProfile(customer_id=customer_id, **data) + except Exception: + return CustomerProfile(customer_id=customer_id) customers = self._load_customers() data = customers.get(customer_id, {}) # 确保不重复传递 customer_id @@ -190,6 +282,22 @@ class CustomerDatabase: def save_customer(self, profile: CustomerProfile): profile.last_update = datetime.now().isoformat() + if _is_mysql(): + with self._get_mysql_conn() as conn: + with conn.cursor() as cur: + cur.execute( + """ + REPLACE INTO customer_profiles (customer_id, profile_json, last_update) + VALUES (%s, %s, %s) + """, + ( + profile.customer_id, + json.dumps(asdict(profile), ensure_ascii=False), + datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + ), + ) + conn.commit() + return customers = self._load_customers() customers[profile.customer_id] = asdict(profile) self._save_customers(customers) diff --git a/db/deal_outcome_db.py b/db/deal_outcome_db.py index 8c377b1..5204a00 100755 --- a/db/deal_outcome_db.py +++ b/db/deal_outcome_db.py @@ -15,6 +15,46 @@ _MYSQL_USER = os.getenv("MYSQL_USER", "root") _MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "") _MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs") + +class _CompatResult: + def __init__(self, rows=None, rowcount: int = 0, lastrowid: int = 0): + self._rows = rows or [] + self.rowcount = rowcount + self.lastrowid = lastrowid + + def fetchall(self): + return self._rows + + def fetchone(self): + return self._rows[0] if self._rows else None + + +class _PyMySQLCompatConn: + def __init__(self, conn): + self._conn = conn + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + if exc_type: + try: + self._conn.rollback() + except Exception: + pass + self._conn.close() + + def execute(self, query: str, args=None): + cur = self._conn.cursor() + cur.execute(query, args or ()) + rows = cur.fetchall() if cur.description else [] + res = _CompatResult(rows=rows, rowcount=cur.rowcount, lastrowid=getattr(cur, "lastrowid", 0)) + cur.close() + return res + + def commit(self): + self._conn.commit() + def _is_mysql() -> bool: return _DB_TYPE in ("mysql", "mariadb") @@ -25,7 +65,7 @@ def _sql(query: str) -> str: def _get_conn() -> sqlite3.Connection: if _is_mysql(): import pymysql - return pymysql.connect( + conn = pymysql.connect( host=_MYSQL_HOST, port=_MYSQL_PORT, user=_MYSQL_USER, @@ -35,6 +75,7 @@ def _get_conn() -> sqlite3.Connection: cursorclass=pymysql.cursors.DictCursor, autocommit=False, ) + return _PyMySQLCompatConn(conn) os.makedirs(os.path.dirname(_DB_PATH), exist_ok=True) conn = sqlite3.connect(_DB_PATH) conn.row_factory = sqlite3.Row @@ -60,10 +101,16 @@ def _init_db(): timestamp DATETIME NOT NULL ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 """) - conn.execute("CREATE INDEX IF NOT EXISTS idx_deal_date ON deal_outcomes(date)") - conn.execute("CREATE INDEX IF NOT EXISTS idx_deal_customer ON deal_outcomes(customer_id)") - conn.execute("CREATE INDEX IF NOT EXISTS idx_deal_acc ON deal_outcomes(acc_id)") - conn.execute("CREATE INDEX IF NOT EXISTS idx_deal_outcome ON deal_outcomes(outcome)") + idx_rows = conn.execute("SHOW INDEX FROM deal_outcomes").fetchall() + exists = {str(r.get("Key_name", "")) for r in idx_rows} + if "idx_deal_date" not in exists: + conn.execute("CREATE INDEX idx_deal_date ON deal_outcomes(date)") + if "idx_deal_customer" not in exists: + conn.execute("CREATE INDEX idx_deal_customer ON deal_outcomes(customer_id)") + if "idx_deal_acc" not in exists: + conn.execute("CREATE INDEX idx_deal_acc ON deal_outcomes(acc_id)") + if "idx_deal_outcome" not in exists: + conn.execute("CREATE INDEX idx_deal_outcome ON deal_outcomes(outcome)") else: conn.execute(""" CREATE TABLE IF NOT EXISTS deal_outcomes ( diff --git a/db/designer_roster_db.py b/db/designer_roster_db.py index 8eab246..8f5c2dd 100755 --- a/db/designer_roster_db.py +++ b/db/designer_roster_db.py @@ -17,6 +17,46 @@ _MYSQL_USER = os.getenv("MYSQL_USER", "root") _MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "") _MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs") + +class _CompatResult: + def __init__(self, rows=None, rowcount: int = 0, lastrowid: int = 0): + self._rows = rows or [] + self.rowcount = rowcount + self.lastrowid = lastrowid + + def fetchall(self): + return self._rows + + def fetchone(self): + return self._rows[0] if self._rows else None + + +class _PyMySQLCompatConn: + def __init__(self, conn): + self._conn = conn + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + if exc_type: + try: + self._conn.rollback() + except Exception: + pass + self._conn.close() + + def execute(self, query: str, args=None): + cur = self._conn.cursor() + cur.execute(query, args or ()) + rows = cur.fetchall() if cur.description else [] + res = _CompatResult(rows=rows, rowcount=cur.rowcount, lastrowid=getattr(cur, "lastrowid", 0)) + cur.close() + return res + + def commit(self): + self._conn.commit() + def _is_mysql() -> bool: return _DB_TYPE in ("mysql", "mariadb") @@ -27,7 +67,7 @@ def _sql(query: str) -> str: def _get_conn() -> sqlite3.Connection: if _is_mysql(): import pymysql - return pymysql.connect( + conn = pymysql.connect( host=_MYSQL_HOST, port=_MYSQL_PORT, user=_MYSQL_USER, @@ -37,6 +77,7 @@ def _get_conn() -> sqlite3.Connection: cursorclass=pymysql.cursors.DictCursor, autocommit=False, ) + return _PyMySQLCompatConn(conn) os.makedirs(os.path.dirname(_DB_PATH), exist_ok=True) conn = sqlite3.connect(_DB_PATH) conn.row_factory = sqlite3.Row diff --git a/db/image_tasks_db.py b/db/image_tasks_db.py index 2b1fbe9..173419b 100644 --- a/db/image_tasks_db.py +++ b/db/image_tasks_db.py @@ -90,9 +90,14 @@ class ImageTaskManager: FOREIGN KEY (task_id) REFERENCES image_tasks(task_id) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 ''') - cursor.execute('CREATE INDEX IF NOT EXISTS idx_customer ON image_tasks(customer_id)') - cursor.execute('CREATE INDEX IF NOT EXISTS idx_status ON image_tasks(status)') - cursor.execute('CREATE INDEX IF NOT EXISTS idx_created ON image_tasks(created_at)') + cursor.execute("SHOW INDEX FROM image_tasks") + exists = {str(r.get("Key_name", "")) for r in cursor.fetchall()} + if "idx_customer" not in exists: + cursor.execute('CREATE INDEX idx_customer ON image_tasks(customer_id)') + if "idx_status" not in exists: + cursor.execute('CREATE INDEX idx_status ON image_tasks(status)') + if "idx_created" not in exists: + cursor.execute('CREATE INDEX idx_created ON image_tasks(created_at)') conn.commit() conn.close() else: diff --git a/db/task_db/task_model.py b/db/task_db/task_model.py index 84c4a49..5062297 100644 --- a/db/task_db/task_model.py +++ b/db/task_db/task_model.py @@ -88,9 +88,14 @@ class TaskManager: result TEXT ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 ''') - cursor.execute('CREATE INDEX IF NOT EXISTS idx_status ON tasks(status)') - cursor.execute('CREATE INDEX IF NOT EXISTS idx_customer ON tasks(customer_id)') - cursor.execute('CREATE INDEX IF NOT EXISTS idx_created ON tasks(created_at)') + cursor.execute("SHOW INDEX FROM tasks") + exists = {str(r.get("Key_name", "")) for r in cursor.fetchall()} + if "idx_status" not in exists: + cursor.execute('CREATE INDEX idx_status ON tasks(status)') + if "idx_customer" not in exists: + cursor.execute('CREATE INDEX idx_customer ON tasks(customer_id)') + if "idx_created" not in exists: + cursor.execute('CREATE INDEX idx_created ON tasks(created_at)') conn.commit() conn.close() else: diff --git a/scripts/migrate_chat_logs_to_mysql.py b/scripts/migrate_chat_logs_to_mysql.py new file mode 100644 index 0000000..57905e4 --- /dev/null +++ b/scripts/migrate_chat_logs_to_mysql.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +把本地 SQLite 聊天记录迁移到 MySQL: + source: db/chat_log_db/chats.db -> table chat_logs + +用法示例: + python scripts/migrate_chat_logs_to_mysql.py --host xinhui.cloud --port 3306 \ + --user ai_cs_user --password xxx --database ai_cs --batch-size 2000 --truncate-target +""" + +from __future__ import annotations + +import argparse +import os +import sqlite3 +import time +from pathlib import Path + +import pymysql + + +def ensure_mysql_table(conn): + with conn.cursor() as cur: + cur.execute( + """ + CREATE TABLE IF NOT EXISTS chat_logs ( + id INTEGER PRIMARY KEY AUTO_INCREMENT, + customer_id VARCHAR(128) NOT NULL, + customer_name VARCHAR(255) DEFAULT '', + acc_id VARCHAR(128) DEFAULT '', + platform VARCHAR(64) DEFAULT '', + direction VARCHAR(8) NOT NULL, + message TEXT NOT NULL, + msg_type INTEGER DEFAULT 0, + timestamp DATETIME NOT NULL + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 + """ + ) + cur.execute("SHOW INDEX FROM chat_logs") + exists = {str(r.get("Key_name", "")) for r in cur.fetchall()} + if "idx_customer" not in exists: + cur.execute("CREATE INDEX idx_customer ON chat_logs(customer_id)") + if "idx_ts" not in exists: + cur.execute("CREATE INDEX idx_ts ON chat_logs(timestamp)") + if "idx_acc" not in exists: + cur.execute("CREATE INDEX idx_acc ON chat_logs(acc_id)") + conn.commit() + + +def get_sqlite_conn(path: Path): + conn = sqlite3.connect(str(path)) + conn.row_factory = sqlite3.Row + return conn + + +def get_mysql_conn(host: str, port: int, user: str, password: str, database: str): + return pymysql.connect( + host=host, + port=port, + user=user, + password=password, + database=database, + charset="utf8mb4", + autocommit=False, + cursorclass=pymysql.cursors.DictCursor, + ) + + +def migrate(sqlite_path: Path, host: str, port: int, user: str, password: str, database: str, batch_size: int, truncate_target: bool): + if not sqlite_path.exists(): + raise FileNotFoundError(f"SQLite 文件不存在: {sqlite_path}") + + s_conn = get_sqlite_conn(sqlite_path) + m_conn = get_mysql_conn(host, port, user, password, database) + try: + ensure_mysql_table(m_conn) + if truncate_target: + with m_conn.cursor() as cur: + cur.execute("TRUNCATE TABLE chat_logs") + m_conn.commit() + + total = s_conn.execute("SELECT COUNT(*) AS c FROM chat_logs").fetchone()["c"] + print(f"[MIGRATE] SQLite 源总行数: {total}") + if total == 0: + return 0 + + migrated = 0 + last_id = 0 + started = time.time() + + insert_sql = ( + "INSERT INTO chat_logs " + "(customer_id, customer_name, acc_id, platform, direction, message, msg_type, timestamp) " + "VALUES (%s,%s,%s,%s,%s,%s,%s,%s)" + ) + + while True: + rows = s_conn.execute( + """ + SELECT id, customer_id, customer_name, acc_id, platform, direction, message, msg_type, timestamp + FROM chat_logs + WHERE id > ? + ORDER BY id ASC + LIMIT ? + """, + (last_id, batch_size), + ).fetchall() + if not rows: + break + + vals = [] + for r in rows: + vals.append( + ( + r["customer_id"] or "", + r["customer_name"] or "", + r["acc_id"] or "", + r["platform"] or "", + r["direction"] or "in", + r["message"] or "", + int(r["msg_type"] or 0), + r["timestamp"], + ) + ) + last_id = r["id"] + + with m_conn.cursor() as cur: + cur.executemany(insert_sql, vals) + m_conn.commit() + + migrated += len(vals) + elapsed = time.time() - started + print(f"[MIGRATE] {migrated}/{total} ({(migrated/total)*100:.1f}%) elapsed={elapsed:.1f}s") + + return migrated + finally: + try: + s_conn.close() + except Exception: + pass + try: + m_conn.close() + except Exception: + pass + + +def main(): + parser = argparse.ArgumentParser(description="迁移 chat_logs: SQLite -> MySQL") + parser.add_argument("--sqlite-path", default=str(Path("db") / "chat_log_db" / "chats.db")) + parser.add_argument("--host", required=True) + parser.add_argument("--port", type=int, default=3306) + parser.add_argument("--user", required=True) + parser.add_argument("--password", required=True) + parser.add_argument("--database", required=True) + parser.add_argument("--batch-size", type=int, default=2000) + parser.add_argument("--truncate-target", action="store_true") + args = parser.parse_args() + + sqlite_path = Path(args.sqlite_path) + migrated = migrate( + sqlite_path=sqlite_path, + host=args.host, + port=args.port, + user=args.user, + password=args.password, + database=args.database, + batch_size=max(100, int(args.batch_size)), + truncate_target=bool(args.truncate_target), + ) + print(f"[DONE] 迁移完成,写入 {migrated} 条") + + +if __name__ == "__main__": + main() diff --git a/scripts/migrate_customers_json_to_mysql.py b/scripts/migrate_customers_json_to_mysql.py new file mode 100644 index 0000000..268ca8c --- /dev/null +++ b/scripts/migrate_customers_json_to_mysql.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +迁移 customer_db/customers.json -> MySQL customer_profiles +""" + +from __future__ import annotations + +import argparse +import json +from datetime import datetime +from pathlib import Path + +import pymysql + + +def get_conn(host: str, port: int, user: str, password: str, database: str): + return pymysql.connect( + host=host, + port=port, + user=user, + password=password, + database=database, + charset="utf8mb4", + autocommit=False, + cursorclass=pymysql.cursors.DictCursor, + ) + + +def ensure_table(conn): + with conn.cursor() as cur: + cur.execute( + """ + CREATE TABLE IF NOT EXISTS customer_profiles ( + customer_id VARCHAR(128) PRIMARY KEY, + profile_json LONGTEXT NOT NULL, + last_update DATETIME NOT NULL + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 + """ + ) + cur.execute("SHOW INDEX FROM customer_profiles") + exists = {str(r.get("Key_name", "")) for r in cur.fetchall()} + if "idx_last_update" not in exists: + cur.execute("CREATE INDEX idx_last_update ON customer_profiles(last_update)") + conn.commit() + + +def migrate(json_path: Path, host: str, port: int, user: str, password: str, database: str, truncate_target: bool): + if not json_path.exists(): + raise FileNotFoundError(f"customers.json 不存在: {json_path}") + customers = json.loads(json_path.read_text(encoding="utf-8") or "{}") + if not isinstance(customers, dict): + raise RuntimeError("customers.json 格式错误,期望对象映射") + + conn = get_conn(host, port, user, password, database) + try: + ensure_table(conn) + if truncate_target: + with conn.cursor() as cur: + cur.execute("TRUNCATE TABLE customer_profiles") + conn.commit() + + now = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + sql = ( + "REPLACE INTO customer_profiles (customer_id, profile_json, last_update) " + "VALUES (%s, %s, %s)" + ) + total = 0 + with conn.cursor() as cur: + for cid, profile in customers.items(): + cur.execute(sql, (str(cid), json.dumps(profile, ensure_ascii=False), now)) + total += 1 + conn.commit() + return total + finally: + conn.close() + + +def main(): + parser = argparse.ArgumentParser(description="迁移 customers.json 到 MySQL") + parser.add_argument("--json-path", default=str(Path("customer_db") / "customers.json")) + parser.add_argument("--host", required=True) + parser.add_argument("--port", type=int, default=3306) + parser.add_argument("--user", required=True) + parser.add_argument("--password", required=True) + parser.add_argument("--database", required=True) + parser.add_argument("--truncate-target", action="store_true") + args = parser.parse_args() + + total = migrate( + json_path=Path(args.json_path), + host=args.host, + port=args.port, + user=args.user, + password=args.password, + database=args.database, + truncate_target=bool(args.truncate_target), + ) + print(f"[DONE] customer_profiles 写入 {total} 条") + + +if __name__ == "__main__": + main() diff --git a/scripts/migrate_remaining_sqlite_to_mysql.py b/scripts/migrate_remaining_sqlite_to_mysql.py new file mode 100644 index 0000000..98f49bb --- /dev/null +++ b/scripts/migrate_remaining_sqlite_to_mysql.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +迁移其余 SQLite 业务库到 MySQL(保留主键): +- deal_outcome_db/outcomes.db -> deal_outcomes +- designer_roster_db/roster.db -> designers/designer_shops/designer_online/round_robin +- image_tasks.db -> image_tasks/requirement_history +- task_db/tasks.db -> tasks/task_logs +""" + +from __future__ import annotations + +import argparse +import sqlite3 +from pathlib import Path +from typing import List, Dict + +import pymysql + + +MAPPINGS = [ + {"sqlite": Path("db/deal_outcome_db/outcomes.db"), "tables": ["deal_outcomes"]}, + {"sqlite": Path("db/designer_roster_db/roster.db"), "tables": ["designers", "designer_shops", "designer_online", "round_robin"]}, + {"sqlite": Path("db/image_tasks.db"), "tables": ["image_tasks", "task_requirement_changes"]}, + {"sqlite": Path("db/task_db/tasks.db"), "tables": ["tasks"]}, +] + + +def mysql_conn(host: str, port: int, user: str, password: str, database: str): + return pymysql.connect( + host=host, + port=port, + user=user, + password=password, + database=database, + charset="utf8mb4", + autocommit=False, + cursorclass=pymysql.cursors.DictCursor, + ) + + +def sqlite_table_exists(conn: sqlite3.Connection, table: str) -> bool: + row = conn.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name=?", + (table,), + ).fetchone() + return row is not None + + +def sqlite_fetch_all(conn: sqlite3.Connection, table: str) -> List[sqlite3.Row]: + conn.row_factory = sqlite3.Row + return conn.execute(f"SELECT * FROM {table}").fetchall() + + +def migrate_table(mysql, rows: List[sqlite3.Row], table: str, truncate_target: bool) -> int: + if not rows: + return 0 + cols = list(rows[0].keys()) + col_sql = ", ".join(cols) + val_sql = ", ".join(["%s"] * len(cols)) + sql = f"REPLACE INTO {table} ({col_sql}) VALUES ({val_sql})" + if truncate_target: + with mysql.cursor() as cur: + try: + cur.execute(f"TRUNCATE TABLE {table}") + except Exception: + try: + cur.execute(f"DELETE FROM {table}") + except Exception: + return 0 + values = [tuple(r[c] for c in cols) for r in rows] + with mysql.cursor() as cur: + cur.executemany(sql, values) + mysql.commit() + return len(values) + + +def main(): + p = argparse.ArgumentParser(description="迁移剩余 SQLite 业务库到 MySQL") + p.add_argument("--host", required=True) + p.add_argument("--port", type=int, default=3306) + p.add_argument("--user", required=True) + p.add_argument("--password", required=True) + p.add_argument("--database", required=True) + p.add_argument("--truncate-target", action="store_true") + args = p.parse_args() + + total = 0 + with mysql_conn(args.host, args.port, args.user, args.password, args.database) as mconn: + for item in MAPPINGS: + sp = item["sqlite"] + if not sp.exists(): + continue + sconn = sqlite3.connect(str(sp)) + try: + for table in item["tables"]: + if not sqlite_table_exists(sconn, table): + continue + rows = sqlite_fetch_all(sconn, table) + n = migrate_table(mconn, rows, table, truncate_target=bool(args.truncate_target)) + total += n + print(f"[MIGRATE] {sp}::{table} -> {n}") + finally: + sconn.close() + print(f"[DONE] migrated total rows: {total}") + + +if __name__ == "__main__": + main()