"""
Chat-log database (SQLite / MySQL).

Every message is stored as its own row, keyed by customer ID, and can be
queried for display.  The MySQL backend uses a small connection pool.

The backend is selected with the ``DB_TYPE`` environment variable
(``sqlite`` by default, ``mysql`` / ``mariadb`` for MySQL).  Importing this
module creates the tables if they do not exist yet.
"""

import functools
import os
import sqlite3
import threading
import time
from datetime import datetime
from queue import Empty, Queue
from typing import Dict, List, Optional

_DB_PATH = os.path.join(os.path.dirname(__file__), "chat_log_db", "chats.db")
_DB_TYPE = os.getenv("DB_TYPE", "sqlite").lower()
_MYSQL_HOST = os.getenv("MYSQL_HOST", "127.0.0.1")
_MYSQL_PORT = int(os.getenv("MYSQL_PORT", "3306"))
_MYSQL_USER = os.getenv("MYSQL_USER", "root")
_MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs")

# ========== MySQL connection pool ==========
_POOL_SIZE = int(os.getenv("MYSQL_POOL_SIZE", "10"))
_POOL_WAIT_TIMEOUT = float(os.getenv("MYSQL_POOL_WAIT_TIMEOUT", "10"))
# While waiting for a connection, wake up this often to re-check whether a
# slot was freed by another thread discarding a dead connection.
_POOL_POLL_INTERVAL = 0.5
_mysql_pool: Optional[Queue] = None
_pool_lock = threading.Lock()
_mysql_conn_count = 0  # connections currently alive (idle + checked out)

# Escape character used for LIKE patterns; valid on both SQLite and MySQL.
_LIKE_ESCAPE = "!"

# MySQL client error numbers that indicate a dropped connection:
# 2006 = server has gone away, 2013 = lost connection during query.
_RETRYABLE_MYSQL_ERRNOS = (2006, 2013)
_RETRYABLE_ERROR_MARKERS = (
    "lost connection",
    "gone away",
    "connection reset",
    "can't connect",
    "connection refused",
)


def _create_mysql_conn():
    """Open one new MySQL connection using the module-level settings."""
    import pymysql

    return pymysql.connect(
        host=_MYSQL_HOST,
        port=_MYSQL_PORT,
        user=_MYSQL_USER,
        password=_MYSQL_PASSWORD,
        database=_MYSQL_DATABASE,
        charset="utf8mb4",
        cursorclass=pymysql.cursors.DictCursor,
        autocommit=False,
        connect_timeout=10,
        read_timeout=30,
        write_timeout=30,
    )


def _init_mysql_pool():
    """Create the (empty) pool queue; connections are opened lazily."""
    global _mysql_pool
    with _pool_lock:
        if _mysql_pool is None:
            _mysql_pool = Queue(maxsize=_POOL_SIZE)


def _discard_conn(conn):
    """Close a broken connection and release its slot in the pool count."""
    global _mysql_conn_count
    try:
        conn.close()
    except Exception:
        # Best effort: the connection is already considered dead.
        pass
    with _pool_lock:
        if _mysql_conn_count > 0:
            _mysql_conn_count -= 1


def _get_pooled_conn(timeout: float = 5.0):
    """Borrow a live MySQL connection from the pool.

    Order of preference: an idle pooled connection, then a newly created one
    if fewer than ``_POOL_SIZE`` exist, otherwise block until one is returned.

    Args:
        timeout: Maximum number of seconds to wait for a connection.

    Raises:
        TimeoutError: No connection became available within ``timeout``.
        Exception: Whatever ``pymysql.connect`` raises when a new connection
            cannot be established.
    """
    global _mysql_conn_count
    if _mysql_pool is None:
        _init_mysql_pool()

    deadline = time.monotonic() + timeout
    while True:
        try:
            conn = _mysql_pool.get_nowait()
        except Empty:
            conn = None

        if conn is None:
            # Reserve a slot under the lock, but connect outside of it so a
            # slow network handshake does not stall every other thread.
            with _pool_lock:
                may_create = _mysql_conn_count < _POOL_SIZE
                if may_create:
                    _mysql_conn_count += 1
            if may_create:
                try:
                    return _create_mysql_conn()
                except Exception:
                    with _pool_lock:
                        if _mysql_conn_count > 0:
                            _mysql_conn_count -= 1
                    raise

            remaining = deadline - time.monotonic()
            if remaining <= 0:
                raise TimeoutError(f"MySQL连接池已耗尽(pool_size={_POOL_SIZE}, wait_timeout={timeout}s)")
            try:
                conn = _mysql_pool.get(timeout=min(remaining, _POOL_POLL_INTERVAL))
            except Empty:
                continue

        try:
            conn.ping(reconnect=True)
        except Exception:
            # Dead and not revivable: drop it and try again.
            _discard_conn(conn)
            continue
        return conn


def _return_conn(conn):
    """Give a connection back to the pool, discarding it if it is broken.

    The connection is rolled back first.  With ``autocommit=False`` even a
    plain SELECT opens a transaction; returning it open would make the next
    borrower read a stale snapshot or inherit uncommitted writes.  The
    rollback round trip also serves as the liveness check.
    """
    if _mysql_pool is None:
        try:
            conn.close()
        except Exception:
            pass
        return
    try:
        conn.rollback()
        _mysql_pool.put_nowait(conn)
    except Exception:
        _discard_conn(conn)


class _CompatResult:
    """Fully-fetched result that mimics the part of a sqlite3 cursor we use."""

    def __init__(self, rows=None, rowcount: int = 0, lastrowid: int = 0):
        self._rows = rows or []
        self.rowcount = rowcount
        self.lastrowid = lastrowid

    def fetchall(self):
        """Return every row of the result."""
        return self._rows

    def fetchone(self):
        """Return the first row, or ``None`` when the result is empty."""
        return self._rows[0] if self._rows else None


class _PyMySQLCompatConn:
    """Adapter giving a pymysql connection sqlite3's ``conn.execute`` API.

    Used as a context manager it releases the underlying connection on exit
    (back to the pool when ``use_pool`` is true, closed otherwise).
    """

    def __init__(self, conn, use_pool: bool = True):
        self._conn = conn
        self._use_pool = use_pool
        # Guards against handing the same connection to the pool twice
        # (e.g. an explicit close() followed by leaving the with-block).
        self._released = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        if exc_type and not self._released:
            try:
                self._conn.rollback()
            except Exception:
                # The connection may already be gone; release() deals with it.
                pass
        self._release()

    def _release(self):
        """Return or close the underlying connection exactly once."""
        if self._released:
            return
        self._released = True
        if self._use_pool:
            _return_conn(self._conn)
        else:
            self._conn.close()

    def execute(self, query: str, args=None):
        """Run ``query`` and return a :class:`_CompatResult` with all rows."""
        cur = self._conn.cursor()
        try:
            cur.execute(query, args or ())
            rows = cur.fetchall() if cur.description else []
            return _CompatResult(
                rows=rows,
                rowcount=cur.rowcount,
                lastrowid=getattr(cur, "lastrowid", 0),
            )
        finally:
            try:
                cur.close()
            except Exception:
                # Never let cursor cleanup mask the real query error.
                pass

    def commit(self):
        """Commit the current transaction."""
        self._conn.commit()

    def close(self):
        """Release the connection (to the pool, or close it)."""
        self._release()


class _ClosingSQLiteConnection(sqlite3.Connection):
    """sqlite3 connection whose ``with`` block also closes it.

    The stock ``sqlite3.Connection`` context manager only commits or rolls
    back; closing on exit matches what the MySQL adapter does and avoids
    leaving one open handle per call.
    """

    def __exit__(self, exc_type, exc, tb):
        try:
            return super().__exit__(exc_type, exc, tb)
        finally:
            self.close()


def _is_mysql() -> bool:
    """Return True when the configured backend is MySQL or MariaDB."""
    return _DB_TYPE in ("mysql", "mariadb")


def _sql(query: str) -> str:
    """Translate ``?`` placeholders to ``%s`` when running on MySQL."""
    return query.replace("?", "%s") if _is_mysql() else query


def _like_contains(keyword: str) -> str:
    """Build a ``%keyword%`` LIKE pattern with wildcards in keyword escaped.

    Must be used together with ``ESCAPE '!'`` in the SQL text.
    """
    escaped = (
        keyword.replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace("%", _LIKE_ESCAPE + "%")
        .replace("_", _LIKE_ESCAPE + "_")
    )
    return f"%{escaped}%"


def _get_conn(max_retries: int = 3, retry_delay: float = 0.5):
    """Return a database connection usable as ``with _get_conn() as conn``.

    For MySQL this is a pooled :class:`_PyMySQLCompatConn`; acquisition is
    retried with a linearly growing delay.  For SQLite it is a
    ``sqlite3.Connection`` (row factory ``sqlite3.Row``) that closes itself
    when the ``with`` block ends.

    Args:
        max_retries: Number of MySQL acquisition attempts (at least one is
            always made).
        retry_delay: Base delay in seconds between attempts.
    """
    if _is_mysql():
        attempts = max(1, max_retries)
        for attempt in range(attempts):
            try:
                raw = _get_pooled_conn(timeout=_POOL_WAIT_TIMEOUT)
                return _PyMySQLCompatConn(raw, use_pool=True)
            except Exception:
                if attempt >= attempts - 1:
                    raise
                time.sleep(retry_delay * (attempt + 1))

    os.makedirs(os.path.dirname(_DB_PATH), exist_ok=True)
    conn = sqlite3.connect(_DB_PATH, factory=_ClosingSQLiteConnection)
    conn.row_factory = sqlite3.Row
    return conn


def _mysql_index_names(conn, table: str) -> set:
    """Return the names of the indexes that exist on a MySQL table."""
    rows = conn.execute(f"SHOW INDEX FROM {table}").fetchall()
    return {str(r.get("Key_name", "")) for r in rows}


def _mysql_column_names(conn, table: str) -> set:
    """Return the column names of a MySQL table."""
    rows = conn.execute(f"SHOW COLUMNS FROM {table}").fetchall()
    return {str(r.get("Field", "")) for r in rows}


def _sqlite_column_names(conn, table: str) -> set:
    """Return the column names of a SQLite table."""
    rows = conn.execute(f"PRAGMA table_info({table})").fetchall()
    return {r["name"] for r in rows}


def _init_chat_logs_mysql(conn):
    """Create / migrate the chat_logs table on MySQL."""
    conn.execute("""
        CREATE TABLE IF NOT EXISTS chat_logs (
            id INTEGER PRIMARY KEY AUTO_INCREMENT,
            customer_id VARCHAR(128) NOT NULL,
            customer_name VARCHAR(255) DEFAULT '',
            acc_id VARCHAR(128) DEFAULT '',
            platform VARCHAR(64) DEFAULT '',
            direction VARCHAR(8) NOT NULL,
            message TEXT NOT NULL,
            msg_type INTEGER DEFAULT 0,
            timestamp DATETIME NOT NULL,
            image_urls TEXT
        ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
    """)
    existing = _mysql_index_names(conn, "chat_logs")
    if "idx_customer" not in existing:
        conn.execute("CREATE INDEX idx_customer ON chat_logs(customer_id)")
    if "idx_ts" not in existing:
        conn.execute("CREATE INDEX idx_ts ON chat_logs(timestamp)")
    if "idx_acc" not in existing:
        conn.execute("CREATE INDEX idx_acc ON chat_logs(acc_id)")
    # Migration for tables created before image_urls existed.  MySQL rejects
    # a literal DEFAULT on TEXT columns, so the column is added without one.
    if "image_urls" not in _mysql_column_names(conn, "chat_logs"):
        conn.execute("ALTER TABLE chat_logs ADD COLUMN image_urls TEXT")


def _init_chat_logs_sqlite(conn):
    """Create / migrate the chat_logs table on SQLite."""
    conn.execute("""
        CREATE TABLE IF NOT EXISTS chat_logs (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            customer_id TEXT NOT NULL,
            customer_name TEXT DEFAULT '',
            acc_id TEXT DEFAULT '',
            platform TEXT DEFAULT '',
            direction TEXT NOT NULL CHECK(direction IN ('in','out')),
            message TEXT NOT NULL,
            msg_type INTEGER DEFAULT 0,
            timestamp TEXT NOT NULL,
            image_urls TEXT DEFAULT ''
        )
    """)
    conn.execute("CREATE INDEX IF NOT EXISTS idx_customer ON chat_logs(customer_id)")
    conn.execute("CREATE INDEX IF NOT EXISTS idx_ts ON chat_logs(timestamp)")
    # Migrations for databases created by older versions of the schema.
    columns = _sqlite_column_names(conn, "chat_logs")
    if "acc_id" not in columns:
        conn.execute("ALTER TABLE chat_logs ADD COLUMN acc_id TEXT DEFAULT ''")
    if "image_urls" not in columns:
        conn.execute("ALTER TABLE chat_logs ADD COLUMN image_urls TEXT DEFAULT ''")
    conn.execute("CREATE INDEX IF NOT EXISTS idx_acc ON chat_logs(acc_id)")


def _init_customer_orders_mysql(conn):
    """Create the customer_orders table and its indexes on MySQL."""
    conn.execute("""
        CREATE TABLE IF NOT EXISTS customer_orders (
            id INTEGER PRIMARY KEY AUTO_INCREMENT,
            customer_id VARCHAR(128) NOT NULL,
            acc_id VARCHAR(128) DEFAULT '',
            order_id VARCHAR(64) NOT NULL,
            order_status VARCHAR(64) DEFAULT '',
            product_title VARCHAR(512) DEFAULT '',
            amount DECIMAL(10,2) DEFAULT 0,
            quantity INTEGER DEFAULT 0,
            buyer_note TEXT,
            created_at DATETIME NOT NULL,
            updated_at DATETIME NOT NULL
        ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
    """)
    existing = _mysql_index_names(conn, "customer_orders")
    if "idx_co_customer" not in existing:
        conn.execute("CREATE INDEX idx_co_customer ON customer_orders(customer_id)")
    if "idx_co_order" not in existing:
        conn.execute("CREATE UNIQUE INDEX idx_co_order ON customer_orders(order_id, order_status)")


def _init_customer_orders_sqlite(conn):
    """Create the customer_orders table and its indexes on SQLite."""
    conn.execute("""
        CREATE TABLE IF NOT EXISTS customer_orders (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            customer_id TEXT NOT NULL,
            acc_id TEXT DEFAULT '',
            order_id TEXT NOT NULL,
            order_status TEXT DEFAULT '',
            product_title TEXT DEFAULT '',
            amount REAL DEFAULT 0,
            quantity INTEGER DEFAULT 0,
            buyer_note TEXT DEFAULT '',
            created_at TEXT NOT NULL,
            updated_at TEXT NOT NULL
        )
    """)
    conn.execute("CREATE INDEX IF NOT EXISTS idx_co_customer ON customer_orders(customer_id)")
    conn.execute("CREATE UNIQUE INDEX IF NOT EXISTS idx_co_order ON customer_orders(order_id, order_status)")


def init_db():
    """Create tables and indexes if missing (runs automatically on import)."""
    with _get_conn() as conn:
        if _is_mysql():
            _init_chat_logs_mysql(conn)
            _init_customer_orders_mysql(conn)
        else:
            _init_chat_logs_sqlite(conn)
            _init_customer_orders_sqlite(conn)
        conn.commit()


init_db()


# ========== Retry decorator ==========

def _is_retryable_conn_error(exc: BaseException) -> bool:
    """Return True when ``exc`` looks like a transient connection failure."""
    code = exc.args[0] if exc.args else None
    if isinstance(code, int) and code in _RETRYABLE_MYSQL_ERRNOS:
        return True
    text = str(exc).lower()
    return any(marker in text for marker in _RETRYABLE_ERROR_MARKERS)


def _retry_db_operation(func):
    """Decorator: retry ``func`` up to 3 times on lost-connection errors.

    Non-connection errors, and the error of the final attempt, propagate
    unchanged.  The delay grows linearly (0.5s, 1.0s).
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        max_retries = 3
        for attempt in range(max_retries):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                if attempt >= max_retries - 1 or not _is_retryable_conn_error(e):
                    raise
                time.sleep(0.5 * (attempt + 1))

    return wrapper


# ========== Writes ==========

@_retry_db_operation
def log_message(
    customer_id: str,
    message: str,
    direction: str,
    customer_name: str = "",
    acc_id: str = "",
    platform: str = "",
    msg_type: int = 0,
    image_urls: str = "",
):
    """Store one chat message.

    Args:
        customer_id: Customer the message belongs to.
        message: Message text.
        direction: ``"in"`` for a message from the customer, ``"out"`` for a
            reply from customer service.
        customer_name: Display name of the customer.
        acc_id: Shop account ID.
        platform: Platform the message came from.
        msg_type: Numeric message type.
        image_urls: Image URLs, separated by ``\\n``.
    """
    ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    with _get_conn() as conn:
        conn.execute(
            _sql("INSERT INTO chat_logs "
                 "(customer_id, customer_name, acc_id, platform, direction, message, msg_type, timestamp, image_urls) "
                 "VALUES (?,?,?,?,?,?,?,?,?)"),
            (customer_id, customer_name, acc_id, platform, direction, message, msg_type, ts, image_urls),
        )
        conn.commit()


# ========== Queries ==========

@_retry_db_operation
def get_customers(limit: int = 100) -> List[Dict]:
    """List customers that have messages, most recently active first."""
    with _get_conn() as conn:
        rows = conn.execute(_sql("""
            SELECT customer_id,
                   MAX(customer_name) AS customer_name,
                   MAX(platform) AS platform,
                   COUNT(*) AS total_msgs,
                   SUM(direction='in') AS recv,
                   SUM(direction='out') AS sent,
                   MAX(timestamp) AS last_time
            FROM chat_logs
            GROUP BY customer_id
            ORDER BY last_time DESC
            LIMIT ?
        """), (limit,)).fetchall()
    return [dict(r) for r in rows]


@_retry_db_operation
def get_conversation(customer_id: str, limit: int = 200, acc_id: str = "") -> List[Dict]:
    """Return a customer's most recent ``limit`` messages, oldest first.

    ``acc_id`` is accepted for compatibility but deliberately ignored so the
    history spans every shop.
    """
    with _get_conn() as conn:
        rows = conn.execute(_sql("""
            SELECT * FROM (
                SELECT id, direction, message, msg_type, timestamp, acc_id
                FROM chat_logs
                WHERE customer_id = ?
                ORDER BY timestamp DESC, id DESC
                LIMIT ?
            ) AS recent
            ORDER BY timestamp ASC, id ASC
        """), (customer_id, limit)).fetchall()
    return [dict(r) for r in rows]


@_retry_db_operation
def get_recent_conversation(customer_id: str, acc_id: str = "", limit: int = 10) -> List[Dict]:
    """Return a customer's last ``limit`` messages by id, oldest first.

    ``acc_id`` is ignored (no per-shop filtering).
    """
    with _get_conn() as conn:
        rows = conn.execute(_sql("""
            SELECT id, direction, message, timestamp, acc_id
            FROM chat_logs
            WHERE customer_id = ?
            ORDER BY id DESC
            LIMIT ?
        """), (customer_id, limit)).fetchall()
    return [dict(r) for r in reversed(rows)]


@_retry_db_operation
def get_conversation_today(customer_id: str) -> List[Dict]:
    """Return today's messages (local date) for a customer, oldest first."""
    today = datetime.now().strftime("%Y-%m-%d")
    with _get_conn() as conn:
        rows = conn.execute(_sql("""
            SELECT id, direction, message, msg_type, timestamp
            FROM chat_logs
            WHERE customer_id = ? AND timestamp LIKE ?
            ORDER BY timestamp ASC, id ASC
        """), (customer_id, f"{today}%")).fetchall()
    return [dict(r) for r in rows]


@_retry_db_operation
def get_daily_stats(date: str = "") -> List[Dict]:
    """Return per-shop statistics for one day.

    Args:
        date: ``'YYYY-MM-DD'``; defaults to today.

    Returns:
        One dict per ``acc_id`` (shop), ordered by unique customers
        descending.
    """
    if not date:
        date = datetime.now().strftime("%Y-%m-%d")
    with _get_conn() as conn:
        # platform is aggregated: a bare column that is not in GROUP BY is
        # rejected by MySQL's default ONLY_FULL_GROUP_BY mode.
        rows = conn.execute(_sql("""
            SELECT acc_id,
                   MAX(platform) AS platform,
                   COUNT(DISTINCT customer_id) AS unique_customers,
                   COUNT(*) AS total_msgs,
                   SUM(direction='in') AS recv,
                   SUM(direction='out') AS sent,
                   MIN(timestamp) AS first_msg,
                   MAX(timestamp) AS last_msg
            FROM chat_logs
            WHERE timestamp LIKE ?
            GROUP BY acc_id
            ORDER BY unique_customers DESC
        """), (f"{date}%",)).fetchall()
    return [dict(r) for r in rows]


@_retry_db_operation
def get_daily_conversations(date: str = "") -> List[Dict]:
    """Return one summary row per (shop, customer) for one day.

    Each row carries a ``snippet`` that concatenates that day's messages,
    each cut to 40 characters and prefixed with the speaker.  No per-customer
    message limit is applied here.
    NOTE(review): on MySQL the snippet length is presumably bounded by the
    server's ``group_concat_max_len`` setting - confirm if long snippets
    matter.

    Args:
        date: ``'YYYY-MM-DD'``; defaults to today.
    """
    if not date:
        date = datetime.now().strftime("%Y-%m-%d")
    if _is_mysql():
        query = """
            SELECT acc_id, customer_id,
                   MAX(customer_name) AS customer_name,
                   COUNT(*) AS msg_count,
                   GROUP_CONCAT(
                       CONCAT(CASE WHEN direction='in' THEN '买:' ELSE '客:' END, LEFT(message,40))
                       SEPARATOR ' | '
                   ) AS snippet
            FROM chat_logs
            WHERE timestamp LIKE ?
            GROUP BY acc_id, customer_id
            ORDER BY acc_id, MAX(timestamp) DESC
        """
    else:
        query = """
            SELECT acc_id, customer_id,
                   MAX(customer_name) AS customer_name,
                   COUNT(*) AS msg_count,
                   GROUP_CONCAT(
                       CASE WHEN direction='in' THEN '买:' || SUBSTR(message,1,40)
                            ELSE '客:' || SUBSTR(message,1,40) END,
                       ' | '
                   ) AS snippet
            FROM chat_logs
            WHERE timestamp LIKE ?
            GROUP BY acc_id, customer_id
            ORDER BY acc_id, MAX(timestamp) DESC
        """
    with _get_conn() as conn:
        rows = conn.execute(_sql(query), (f"{date}%",)).fetchall()
    return [dict(r) for r in rows]


@_retry_db_operation
def search_messages(keyword: str, customer_id: Optional[str] = None, limit: int = 50) -> List[Dict]:
    """Find messages containing ``keyword`` as a literal substring.

    ``%`` and ``_`` inside ``keyword`` are matched literally rather than as
    LIKE wildcards.

    Args:
        keyword: Text to look for.
        customer_id: Restrict the search to this customer when given.
        limit: Maximum number of rows, newest first.
    """
    pattern = _like_contains(keyword)
    if customer_id:
        where = "customer_id = ? AND message LIKE ? ESCAPE '!'"
        params = (customer_id, pattern, limit)
    else:
        where = "message LIKE ? ESCAPE '!'"
        params = (pattern, limit)
    # Only fixed SQL fragments are interpolated; all values stay parameters.
    query = f"""
        SELECT customer_id, customer_name, direction, message, timestamp
        FROM chat_logs
        WHERE {where}
        ORDER BY timestamp DESC
        LIMIT ?
    """
    with _get_conn() as conn:
        rows = conn.execute(_sql(query), params).fetchall()
    return [dict(r) for r in rows]


@_retry_db_operation
def get_latest_messages(limit: int = 20) -> List[Dict]:
    """Return the ``limit`` most recently inserted messages, newest first."""
    with _get_conn() as conn:
        rows = conn.execute(_sql("""
            SELECT id, customer_id, customer_name, direction, message, timestamp
            FROM chat_logs
            ORDER BY id DESC
            LIMIT ?
        """), (limit,)).fetchall()
    return [dict(r) for r in rows]


# ========== Orders ==========

_ORDER_COLUMNS = (
    "(customer_id, acc_id, order_id, order_status, product_title, "
    "amount, quantity, buyer_note, created_at, updated_at)"
)


@_retry_db_operation
def upsert_order(
    customer_id: str,
    order_id: str,
    order_status: str = "",
    acc_id: str = "",
    product_title: str = "",
    amount: float = 0.0,
    quantity: int = 0,
    buyer_note: str = "",
):
    """Insert an order row, or update it if it already exists.

    Rows are unique per ``(order_id, order_status)``.  On update the original
    ``created_at`` is kept and ``updated_at`` is refreshed.
    """
    ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    values = (customer_id, acc_id, order_id, order_status, product_title,
              amount, quantity, buyer_note, ts, ts)
    if _is_mysql():
        query = (
            "INSERT INTO customer_orders " + _ORDER_COLUMNS + " "
            "VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s) "
            "ON DUPLICATE KEY UPDATE customer_id=VALUES(customer_id), acc_id=VALUES(acc_id), "
            "product_title=VALUES(product_title), amount=VALUES(amount), quantity=VALUES(quantity), "
            "buyer_note=VALUES(buyer_note), updated_at=VALUES(updated_at)"
        )
    elif sqlite3.sqlite_version_info >= (3, 24, 0):
        # True upsert: unlike INSERT OR REPLACE this keeps id and created_at,
        # which matches the MySQL branch.
        query = (
            "INSERT INTO customer_orders " + _ORDER_COLUMNS + " "
            "VALUES (?,?,?,?,?,?,?,?,?,?) "
            "ON CONFLICT(order_id, order_status) DO UPDATE SET "
            "customer_id=excluded.customer_id, acc_id=excluded.acc_id, "
            "product_title=excluded.product_title, amount=excluded.amount, "
            "quantity=excluded.quantity, buyer_note=excluded.buyer_note, "
            "updated_at=excluded.updated_at"
        )
    else:
        # SQLite older than 3.24 has no UPSERT syntax.
        query = (
            "INSERT OR REPLACE INTO customer_orders " + _ORDER_COLUMNS + " "
            "VALUES (?,?,?,?,?,?,?,?,?,?)"
        )
    with _get_conn() as conn:
        conn.execute(query, values)
        conn.commit()


@_retry_db_operation
def get_customer_orders(customer_id: str, limit: int = 10) -> List[Dict]:
    """Return a customer's order rows, most recently updated first."""
    with _get_conn() as conn:
        rows = conn.execute(_sql("""
            SELECT order_id, order_status, product_title, amount, quantity,
                   buyer_note, created_at, updated_at
            FROM customer_orders
            WHERE customer_id = ?
            ORDER BY updated_at DESC
            LIMIT ?
        """), (customer_id, limit)).fetchall()
    return [dict(r) for r in rows]


@_retry_db_operation
def get_order_by_id(order_id: str) -> List[Dict]:
    """Return every status record of one order, oldest update first."""
    with _get_conn() as conn:
        rows = conn.execute(_sql("""
            SELECT customer_id, order_id, order_status, product_title, amount,
                   quantity, buyer_note, created_at, updated_at
            FROM customer_orders
            WHERE order_id = ?
            ORDER BY updated_at ASC
        """), (order_id,)).fetchall()
    return [dict(r) for r in rows]