from __future__ import annotations import json import re import sqlite3 from pathlib import Path from typing import Any from .config import ( MYSQL_DATABASE, MYSQL_HOST, MYSQL_PASSWORD, MYSQL_PORT, MYSQL_TABLE_PREFIX, MYSQL_USER, STORE_BACKEND, STORE_SQLITE_PATH, ) from .state_machine import migrate_state_schema DB_PATH = Path(__file__).resolve().parents[1] / "qingjian_cs.db" def _safe_prefix(v: str) -> str: p = re.sub(r"[^a-zA-Z0-9_]", "", (v or "").strip()) return p or "qjcs_" class ConversationStore: def __init__(self, backend: str | None = None, db_path: str | None = None) -> None: self.backend = (backend or STORE_BACKEND or "sqlite").lower() self.db_path = db_path or STORE_SQLITE_PATH or str(DB_PATH) self.prefix = _safe_prefix(MYSQL_TABLE_PREFIX) self.sessions_table = f"{self.prefix}sessions" self.events_table = f"{self.prefix}events" self._init_db() def _sqlite_conn(self): return sqlite3.connect(self.db_path) def _mysql_conn(self): import pymysql return pymysql.connect( host=MYSQL_HOST, port=MYSQL_PORT, user=MYSQL_USER, password=MYSQL_PASSWORD, database=MYSQL_DATABASE, charset="utf8mb4", autocommit=False, ) def _conn(self): if self.backend == "mysql": return self._mysql_conn() return self._sqlite_conn() def _init_db(self) -> None: if self.backend == "mysql": self._init_mysql() else: self._init_sqlite() def _ensure_sqlite_column(self, conn: sqlite3.Connection, table: str, col: str, ddl: str) -> None: cols = {row[1] for row in conn.execute(f"PRAGMA table_info({table})").fetchall()} if col not in cols: conn.execute(f"ALTER TABLE {table} ADD COLUMN {ddl}") def _init_sqlite(self) -> None: t_s = self.sessions_table t_e = self.events_table with self._sqlite_conn() as c: c.execute( f""" CREATE TABLE IF NOT EXISTS {t_s} ( customer_key TEXT PRIMARY KEY, acc_id TEXT, customer_id TEXT, route TEXT, state_json TEXT, after_sales_stage TEXT, state_version INTEGER DEFAULT 2, updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ) """ ) c.execute( f""" CREATE TABLE IF NOT EXISTS {t_e} ( id INTEGER PRIMARY KEY AUTOINCREMENT, customer_key TEXT, event TEXT, payload_json TEXT, created_at DATETIME DEFAULT CURRENT_TIMESTAMP ) """ ) self._ensure_sqlite_column(c, t_s, "after_sales_stage", "after_sales_stage TEXT") self._ensure_sqlite_column(c, t_s, "state_version", "state_version INTEGER DEFAULT 2") def _init_mysql(self) -> None: t_s = self.sessions_table t_e = self.events_table conn = self._mysql_conn() try: with conn.cursor() as c: c.execute( f""" CREATE TABLE IF NOT EXISTS {t_s} ( customer_key VARCHAR(255) PRIMARY KEY, acc_id VARCHAR(255), customer_id VARCHAR(255), route VARCHAR(64), state_json JSON, after_sales_stage VARCHAR(64), state_version INT DEFAULT 2, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, INDEX idx_after_sales_stage (after_sales_stage) ) CHARACTER SET utf8mb4 """ ) c.execute( f""" CREATE TABLE IF NOT EXISTS {t_e} ( id BIGINT PRIMARY KEY AUTO_INCREMENT, customer_key VARCHAR(255), event VARCHAR(128), payload_json JSON, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, INDEX idx_customer_key (customer_key), INDEX idx_event (event) ) CHARACTER SET utf8mb4 """ ) conn.commit() finally: conn.close() def get_session(self, customer_key: str) -> dict[str, Any]: t_s = self.sessions_table conn = self._conn() try: with conn.cursor() as c: if self.backend == "mysql": c.execute( f"SELECT acc_id, customer_id, route, state_json, after_sales_stage, state_version FROM {t_s} WHERE customer_key=%s", (customer_key,), ) else: c.execute( f"SELECT acc_id, customer_id, route, state_json, after_sales_stage, state_version FROM {t_s} WHERE customer_key=?", (customer_key,), ) row = c.fetchone() if not row: return {"route": "pre_sales", "state": migrate_state_schema({})} if isinstance(row, dict): vals = [row.get("acc_id"), row.get("customer_id"), row.get("route"), row.get("state_json"), row.get("after_sales_stage"), row.get("state_version")] else: vals = list(row) raw_state = vals[3] try: if isinstance(raw_state, dict): state = raw_state else: state = json.loads(raw_state or "{}") except Exception: state = {} state = migrate_state_schema(state) if vals[4] and not state.get("after_sales_stage"): state["after_sales_stage"] = vals[4] if vals[5] and not state.get("version"): state["version"] = int(vals[5]) return { "acc_id": vals[0], "customer_id": vals[1], "route": vals[2] or "pre_sales", "state": state, } finally: conn.close() def upsert_session(self, customer_key: str, acc_id: str, customer_id: str, route: str, state: dict[str, Any]) -> None: t_s = self.sessions_table state = migrate_state_schema(state) state_json = json.dumps(state or {}, ensure_ascii=False) after_sales_stage = str(state.get("after_sales_stage", "new") or "new") state_version = int(state.get("version", 2) or 2) conn = self._conn() try: with conn.cursor() as c: if self.backend == "mysql": c.execute( f""" INSERT INTO {t_s}(customer_key, acc_id, customer_id, route, state_json, after_sales_stage, state_version) VALUES(%s,%s,%s,%s,%s,%s,%s) ON DUPLICATE KEY UPDATE acc_id=VALUES(acc_id), customer_id=VALUES(customer_id), route=VALUES(route), state_json=VALUES(state_json), after_sales_stage=VALUES(after_sales_stage), state_version=VALUES(state_version) """, (customer_key, acc_id, customer_id, route, state_json, after_sales_stage, state_version), ) else: c.execute( f""" INSERT INTO {t_s}(customer_key, acc_id, customer_id, route, state_json, after_sales_stage, state_version) VALUES(?,?,?,?,?,?,?) ON CONFLICT(customer_key) DO UPDATE SET acc_id=excluded.acc_id, customer_id=excluded.customer_id, route=excluded.route, state_json=excluded.state_json, after_sales_stage=excluded.after_sales_stage, state_version=excluded.state_version, updated_at=CURRENT_TIMESTAMP """, (customer_key, acc_id, customer_id, route, state_json, after_sales_stage, state_version), ) conn.commit() finally: conn.close() def append_event(self, customer_key: str, event: str, payload: dict[str, Any]) -> None: t_e = self.events_table payload_json = json.dumps(payload or {}, ensure_ascii=False) conn = self._conn() try: with conn.cursor() as c: if self.backend == "mysql": c.execute( f"INSERT INTO {t_e}(customer_key, event, payload_json) VALUES(%s,%s,%s)", (customer_key, event, payload_json), ) else: c.execute( f"INSERT INTO {t_e}(customer_key, event, payload_json) VALUES(?,?,?)", (customer_key, event, payload_json), ) conn.commit() finally: conn.close() def get_recent_dialogue(self, customer_key: str, limit: int = 24) -> list[dict[str, Any]]: """ 拉取最近对话(按 customer_key 隔离,即 店铺+客户)。 仅返回 customer_message / assistant_message 两类事件。 """ t_e = self.events_table n = max(1, int(limit)) conn = self._conn() try: with conn.cursor() as c: if self.backend == "mysql": c.execute( f""" SELECT event, payload_json FROM {t_e} WHERE customer_key=%s AND event IN ('customer_message','assistant_message') ORDER BY id DESC LIMIT %s """, (customer_key, n), ) else: c.execute( f""" SELECT event, payload_json FROM {t_e} WHERE customer_key=? AND event IN ('customer_message','assistant_message') ORDER BY id DESC LIMIT ? """, (customer_key, n), ) rows = c.fetchall() or [] out: list[dict[str, Any]] = [] for row in reversed(rows): if isinstance(row, dict): ev = str(row.get("event", "") or "") raw = row.get("payload_json") else: ev = str(row[0] or "") raw = row[1] try: payload = raw if isinstance(raw, dict) else json.loads(raw or "{}") except Exception: payload = {} msg = str(payload.get("msg", "") or "").strip() if not msg: continue role = "user" if ev == "customer_message" else "assistant" out.append({"role": role, "text": msg}) return out finally: conn.close()