176 lines
5.5 KiB
Python
176 lines
5.5 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
"""
|
|
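把本地 SQLite 聊天记录迁移到 MySQL。
Copy chat logs from the local SQLite database into MySQL.

source: db/chat_log_db/chats.db -> table chat_logs
target: MySQL table chat_logs in --database (created, with its indexes, if missing)

Rows are copied in ascending id order, --batch-size rows per committed batch.
The SQLite id is not copied: MySQL assigns new AUTO_INCREMENT ids. Without
--truncate-target, running the script again appends the same rows again
(duplicates).

用法示例 / Usage example:

    python scripts/migrate_chat_logs_to_mysql.py --host xinhui.cloud --port 3306 \
        --user ai_cs_user --password xxx --database ai_cs --batch-size 2000 --truncate-target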
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import os
|
|
import sqlite3
|
|
import time
|
|
from pathlib import Path
|
|
|
|
import pymysql
|
|
|
|
|
|
def ensure_mysql_table(conn):
    """Create the MySQL ``chat_logs`` table and its secondary indexes if absent.

    The table is created with ``CREATE TABLE IF NOT EXISTS``. Each index is
    created only when ``SHOW INDEX`` does not already list its name, so the
    function can be run repeatedly. The work is committed before returning.

    Args:
        conn: an open MySQL connection whose cursors yield dict rows
            (rows are read with ``row.get("Key_name")``).
    """
    # (index name, indexed column), created in this order when missing.
    wanted_indexes = (
        ("idx_customer", "customer_id"),
        ("idx_ts", "timestamp"),
        ("idx_acc", "acc_id"),
    )
    # NOTE(review): MySQL TEXT holds at most 65,535 bytes while SQLite TEXT is
    # unbounded; very long messages may be rejected or truncated depending on
    # sql_mode. Consider MEDIUMTEXT if that can occur — confirm with the app schema.
    create_table_sql = """
        CREATE TABLE IF NOT EXISTS chat_logs (
            id INTEGER PRIMARY KEY AUTO_INCREMENT,
            customer_id VARCHAR(128) NOT NULL,
            customer_name VARCHAR(255) DEFAULT '',
            acc_id VARCHAR(128) DEFAULT '',
            platform VARCHAR(64) DEFAULT '',
            direction VARCHAR(8) NOT NULL,
            message TEXT NOT NULL,
            msg_type INTEGER DEFAULT 0,
            timestamp DATETIME NOT NULL
        ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
    """

    with conn.cursor() as cursor:
        cursor.execute(create_table_sql)

        cursor.execute("SHOW INDEX FROM chat_logs")
        present = {str(info.get("Key_name", "")) for info in cursor.fetchall()}

        # MySQL has no CREATE INDEX IF NOT EXISTS, hence the explicit lookup.
        for index_name, column in wanted_indexes:
            if index_name in present:
                continue
            cursor.execute(f"CREATE INDEX {index_name} ON chat_logs({column})")

    conn.commit()
|
|
|
|
|
|
def get_sqlite_conn(path: Path):
    """Open the SQLite database at *path* with name-addressable rows.

    Rows fetched through the returned connection are ``sqlite3.Row`` objects,
    so columns can be read as ``row["column_name"]`` as well as by position.
    """
    handle = sqlite3.connect(str(path))
    # Column-name access is what the migration loop relies on (r["id"], ...).
    handle.row_factory = sqlite3.Row
    return handle
|
|
|
|
|
|
def get_mysql_conn(host: str, port: int, user: str, password: str, database: str):
    """Open a PyMySQL connection to *database* on *host*:*port*.

    The connection uses the utf8mb4 charset, has autocommit disabled (callers
    commit explicitly), and its cursors return rows as dicts keyed by column
    name (``DictCursor``).
    """
    options = dict(
        host=host,
        port=port,
        user=user,
        password=password,
        database=database,
        charset="utf8mb4",
        autocommit=False,
        cursorclass=pymysql.cursors.DictCursor,
    )
    return pymysql.connect(**options)
|
|
|
|
|
|
_INSERT_SQL = (
    "INSERT INTO chat_logs "
    "(customer_id, customer_name, acc_id, platform, direction, message, msg_type, timestamp) "
    "VALUES (%s,%s,%s,%s,%s,%s,%s,%s)"
)

_SELECT_BATCH_SQL = """
    SELECT id, customer_id, customer_name, acc_id, platform, direction, message, msg_type, timestamp
    FROM chat_logs
    WHERE id > ?
    ORDER BY id ASC
    LIMIT ?
"""


def _close_quietly(conn) -> None:
    """Best-effort close: a failing close must not mask the migration's outcome."""
    try:
        conn.close()
    except Exception:
        pass


def _fetch_batch(s_conn, last_id, batch_size: int) -> list:
    """Return up to ``batch_size`` SQLite rows with ``id > last_id``, ascending by id."""
    return s_conn.execute(_SELECT_BATCH_SQL, (last_id, batch_size)).fetchall()


def _row_to_values(r) -> tuple:
    """Map one SQLite row to INSERT parameters, substituting defaults for NULL/empty values."""
    return (
        r["customer_id"] or "",
        r["customer_name"] or "",
        r["acc_id"] or "",
        r["platform"] or "",
        r["direction"] or "in",
        r["message"] or "",
        int(r["msg_type"] or 0),
        # Passed through unchanged. NOTE(review): assumes SQLite holds a value
        # MySQL DATETIME accepts (and never NULL: the column is NOT NULL) — confirm.
        r["timestamp"],
    )


def _copy_chat_logs(s_conn, m_conn, batch_size: int, truncate_target: bool) -> int:
    """Prepare the MySQL table, then copy SQLite chat_logs into it batch by batch."""
    # Read the source first: if its chat_logs table is missing or unreadable we
    # fail here, before --truncate-target has emptied the destination.
    total = s_conn.execute("SELECT COUNT(*) AS c FROM chat_logs").fetchone()["c"]
    print(f"[MIGRATE] SQLite 源总行数: {total}")

    ensure_mysql_table(m_conn)
    if truncate_target:
        with m_conn.cursor() as cur:
            cur.execute("TRUNCATE TABLE chat_logs")
        m_conn.commit()

    if total == 0:
        return 0

    migrated = 0
    last_id = 0
    started = time.monotonic()  # monotonic: immune to wall-clock adjustments

    while True:
        rows = _fetch_batch(s_conn, last_id, batch_size)
        if not rows:
            break

        vals = [_row_to_values(r) for r in rows]
        with m_conn.cursor() as cur:
            cur.executemany(_INSERT_SQL, vals)
        # One commit per batch: rows of earlier batches persist even if a later batch fails.
        m_conn.commit()

        # Keyset pagination: rows are ORDER BY id, so resume after the last one.
        last_id = rows[-1]["id"]
        migrated += len(vals)
        elapsed = time.monotonic() - started
        print(f"[MIGRATE] {migrated}/{total} ({(migrated/total)*100:.1f}%) elapsed={elapsed:.1f}s")

    return migrated


def migrate(sqlite_path: Path, host: str, port: int, user: str, password: str, database: str, batch_size: int, truncate_target: bool):
    """Copy every row of the SQLite ``chat_logs`` table into MySQL ``chat_logs``.

    Source rows are read in ascending ``id`` order, ``batch_size`` at a time.
    Each batch is inserted and committed separately. The SQLite ``id`` is not
    copied; MySQL assigns new AUTO_INCREMENT ids, so re-running without
    ``truncate_target`` inserts duplicates.

    Args:
        sqlite_path: path of the source SQLite database file.
        host, port, user, password, database: MySQL connection parameters.
        batch_size: rows per read/insert/commit round trip; must be >= 1.
        truncate_target: when true, empty the MySQL table before copying.

    Returns:
        The number of rows inserted into MySQL.

    Raises:
        FileNotFoundError: ``sqlite_path`` does not exist.
        ValueError: ``batch_size`` is smaller than 1.
    """
    if not sqlite_path.exists():
        raise FileNotFoundError(f"SQLite 文件不存在: {sqlite_path}")
    # With batch_size 0 the fetch loop would copy nothing and report success
    # (possibly right after truncating the target), so reject it up front.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    s_conn = get_sqlite_conn(sqlite_path)
    try:
        # Opened inside the outer try so the SQLite handle is still closed
        # when the MySQL connection itself cannot be established.
        m_conn = get_mysql_conn(host, port, user, password, database)
        try:
            return _copy_chat_logs(s_conn, m_conn, batch_size, truncate_target)
        finally:
            # If a batch failed mid-way its rows are uncommitted (get_mysql_conn
            # uses autocommit=False); MySQL rolls them back when the session ends.
            _close_quietly(m_conn)
    finally:
        _close_quietly(s_conn)
|
|
|
|
|
|
def main():
    """Command-line entry point: parse arguments and run ``migrate``.

    The MySQL password is taken from ``--password`` or, when that option is
    omitted, from the ``MYSQL_PASSWORD`` environment variable. The variable
    keeps the secret out of the process list and shell history. A missing
    password is reported as a usage error. ``--batch-size`` values below 100
    are raised to 100.
    """
    parser = argparse.ArgumentParser(description="迁移 chat_logs: SQLite -> MySQL")
    parser.add_argument(
        "--sqlite-path",
        default=str(Path("db") / "chat_log_db" / "chats.db"),
        help="source SQLite database file",
    )
    parser.add_argument("--host", required=True, help="MySQL host")
    parser.add_argument("--port", type=int, default=3306, help="MySQL port")
    parser.add_argument("--user", required=True, help="MySQL user")
    # Read the env var at call time so the fallback reflects the current environment.
    parser.add_argument(
        "--password",
        default=os.environ.get("MYSQL_PASSWORD"),
        help="MySQL password (default: $MYSQL_PASSWORD; prefer the env var, "
        "command-line arguments are visible to other local users)",
    )
    parser.add_argument("--database", required=True, help="MySQL database name")
    parser.add_argument("--batch-size", type=int, default=2000, help="rows per batch (minimum 100)")
    parser.add_argument(
        "--truncate-target",
        action="store_true",
        help="empty the MySQL chat_logs table before copying",
    )
    args = parser.parse_args()

    if args.password is None:
        # Same exit status (2) as the former required=True argparse check.
        parser.error("the following arguments are required: --password (or set MYSQL_PASSWORD)")

    migrated = migrate(
        sqlite_path=Path(args.sqlite_path),
        host=args.host,
        port=args.port,
        user=args.user,
        password=args.password,
        database=args.database,
        # type=int already guarantees an int; very small batches are bumped to 100.
        batch_size=max(100, args.batch_size),
        truncate_target=args.truncate_target,
    )
    print(f"[DONE] 迁移完成,写入 {migrated} 条")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the command-line migration only when executed as a script, not on import.
    main()
|