Files
tw2/qingjian_cs/app/store.py
jimi 36e9082d33
Some checks failed
Pre-commit / run (ubuntu-latest) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_en (ubuntu-latest, 3.10) (push) Has been cancelled
Deploy Sphinx documentation to Pages / build_zh (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (macos-15, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (ubuntu-latest, 3.12) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.10) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.11) (push) Has been cancelled
Python Unittest Coverage / test (windows-latest, 3.12) (push) Has been cancelled
fix: add recent dialogue loader to conversation store
2026-03-03 12:53:36 +08:00

309 lines
12 KiB
Python

from __future__ import annotations

import json
import re
import sqlite3
from contextlib import closing
from pathlib import Path
from typing import Any

from .config import (
    MYSQL_DATABASE,
    MYSQL_HOST,
    MYSQL_PASSWORD,
    MYSQL_PORT,
    MYSQL_TABLE_PREFIX,
    MYSQL_USER,
    STORE_BACKEND,
    STORE_SQLITE_PATH,
)
from .state_machine import migrate_state_schema
# Fallback SQLite database file: "qingjian_cs.db" located one directory above
# the directory that contains this module.
DB_PATH = Path(__file__).resolve().parents[1].joinpath("qingjian_cs.db")
def _safe_prefix(v: str) -> str:
p = re.sub(r"[^a-zA-Z0-9_]", "", (v or "").strip())
return p or "qjcs_"
class ConversationStore:
def __init__(self, backend: str | None = None, db_path: str | None = None) -> None:
self.backend = (backend or STORE_BACKEND or "sqlite").lower()
self.db_path = db_path or STORE_SQLITE_PATH or str(DB_PATH)
self.prefix = _safe_prefix(MYSQL_TABLE_PREFIX)
self.sessions_table = f"{self.prefix}sessions"
self.events_table = f"{self.prefix}events"
self._init_db()
def _sqlite_conn(self):
return sqlite3.connect(self.db_path)
def _mysql_conn(self):
import pymysql
return pymysql.connect(
host=MYSQL_HOST,
port=MYSQL_PORT,
user=MYSQL_USER,
password=MYSQL_PASSWORD,
database=MYSQL_DATABASE,
charset="utf8mb4",
autocommit=False,
)
def _conn(self):
if self.backend == "mysql":
return self._mysql_conn()
return self._sqlite_conn()
def _init_db(self) -> None:
if self.backend == "mysql":
self._init_mysql()
else:
self._init_sqlite()
def _ensure_sqlite_column(self, conn: sqlite3.Connection, table: str, col: str, ddl: str) -> None:
cols = {row[1] for row in conn.execute(f"PRAGMA table_info({table})").fetchall()}
if col not in cols:
conn.execute(f"ALTER TABLE {table} ADD COLUMN {ddl}")
def _init_sqlite(self) -> None:
t_s = self.sessions_table
t_e = self.events_table
with self._sqlite_conn() as c:
c.execute(
f"""
CREATE TABLE IF NOT EXISTS {t_s} (
customer_key TEXT PRIMARY KEY,
acc_id TEXT,
customer_id TEXT,
route TEXT,
state_json TEXT,
after_sales_stage TEXT,
state_version INTEGER DEFAULT 2,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
"""
)
c.execute(
f"""
CREATE TABLE IF NOT EXISTS {t_e} (
id INTEGER PRIMARY KEY AUTOINCREMENT,
customer_key TEXT,
event TEXT,
payload_json TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
"""
)
self._ensure_sqlite_column(c, t_s, "after_sales_stage", "after_sales_stage TEXT")
self._ensure_sqlite_column(c, t_s, "state_version", "state_version INTEGER DEFAULT 2")
def _init_mysql(self) -> None:
t_s = self.sessions_table
t_e = self.events_table
conn = self._mysql_conn()
try:
with conn.cursor() as c:
c.execute(
f"""
CREATE TABLE IF NOT EXISTS {t_s} (
customer_key VARCHAR(255) PRIMARY KEY,
acc_id VARCHAR(255),
customer_id VARCHAR(255),
route VARCHAR(64),
state_json JSON,
after_sales_stage VARCHAR(64),
state_version INT DEFAULT 2,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
INDEX idx_after_sales_stage (after_sales_stage)
) CHARACTER SET utf8mb4
"""
)
c.execute(
f"""
CREATE TABLE IF NOT EXISTS {t_e} (
id BIGINT PRIMARY KEY AUTO_INCREMENT,
customer_key VARCHAR(255),
event VARCHAR(128),
payload_json JSON,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
INDEX idx_customer_key (customer_key),
INDEX idx_event (event)
) CHARACTER SET utf8mb4
"""
)
conn.commit()
finally:
conn.close()
def get_session(self, customer_key: str) -> dict[str, Any]:
t_s = self.sessions_table
conn = self._conn()
try:
with conn.cursor() as c:
if self.backend == "mysql":
c.execute(
f"SELECT acc_id, customer_id, route, state_json, after_sales_stage, state_version FROM {t_s} WHERE customer_key=%s",
(customer_key,),
)
else:
c.execute(
f"SELECT acc_id, customer_id, route, state_json, after_sales_stage, state_version FROM {t_s} WHERE customer_key=?",
(customer_key,),
)
row = c.fetchone()
if not row:
return {"route": "pre_sales", "state": migrate_state_schema({})}
if isinstance(row, dict):
vals = [row.get("acc_id"), row.get("customer_id"), row.get("route"), row.get("state_json"), row.get("after_sales_stage"), row.get("state_version")]
else:
vals = list(row)
raw_state = vals[3]
try:
if isinstance(raw_state, dict):
state = raw_state
else:
state = json.loads(raw_state or "{}")
except Exception:
state = {}
state = migrate_state_schema(state)
if vals[4] and not state.get("after_sales_stage"):
state["after_sales_stage"] = vals[4]
if vals[5] and not state.get("version"):
state["version"] = int(vals[5])
return {
"acc_id": vals[0],
"customer_id": vals[1],
"route": vals[2] or "pre_sales",
"state": state,
}
finally:
conn.close()
def upsert_session(self, customer_key: str, acc_id: str, customer_id: str, route: str, state: dict[str, Any]) -> None:
t_s = self.sessions_table
state = migrate_state_schema(state)
state_json = json.dumps(state or {}, ensure_ascii=False)
after_sales_stage = str(state.get("after_sales_stage", "new") or "new")
state_version = int(state.get("version", 2) or 2)
conn = self._conn()
try:
with conn.cursor() as c:
if self.backend == "mysql":
c.execute(
f"""
INSERT INTO {t_s}(customer_key, acc_id, customer_id, route, state_json, after_sales_stage, state_version)
VALUES(%s,%s,%s,%s,%s,%s,%s)
ON DUPLICATE KEY UPDATE
acc_id=VALUES(acc_id),
customer_id=VALUES(customer_id),
route=VALUES(route),
state_json=VALUES(state_json),
after_sales_stage=VALUES(after_sales_stage),
state_version=VALUES(state_version)
""",
(customer_key, acc_id, customer_id, route, state_json, after_sales_stage, state_version),
)
else:
c.execute(
f"""
INSERT INTO {t_s}(customer_key, acc_id, customer_id, route, state_json, after_sales_stage, state_version)
VALUES(?,?,?,?,?,?,?)
ON CONFLICT(customer_key) DO UPDATE SET
acc_id=excluded.acc_id,
customer_id=excluded.customer_id,
route=excluded.route,
state_json=excluded.state_json,
after_sales_stage=excluded.after_sales_stage,
state_version=excluded.state_version,
updated_at=CURRENT_TIMESTAMP
""",
(customer_key, acc_id, customer_id, route, state_json, after_sales_stage, state_version),
)
conn.commit()
finally:
conn.close()
def append_event(self, customer_key: str, event: str, payload: dict[str, Any]) -> None:
t_e = self.events_table
payload_json = json.dumps(payload or {}, ensure_ascii=False)
conn = self._conn()
try:
with conn.cursor() as c:
if self.backend == "mysql":
c.execute(
f"INSERT INTO {t_e}(customer_key, event, payload_json) VALUES(%s,%s,%s)",
(customer_key, event, payload_json),
)
else:
c.execute(
f"INSERT INTO {t_e}(customer_key, event, payload_json) VALUES(?,?,?)",
(customer_key, event, payload_json),
)
conn.commit()
finally:
conn.close()
def get_recent_dialogue(self, customer_key: str, limit: int = 24) -> list[dict[str, Any]]:
"""
拉取最近对话(按 customer_key 隔离,即 店铺+客户)。
仅返回 customer_message / assistant_message 两类事件。
"""
t_e = self.events_table
n = max(1, int(limit))
conn = self._conn()
try:
with conn.cursor() as c:
if self.backend == "mysql":
c.execute(
f"""
SELECT event, payload_json
FROM {t_e}
WHERE customer_key=%s
AND event IN ('customer_message','assistant_message')
ORDER BY id DESC
LIMIT %s
""",
(customer_key, n),
)
else:
c.execute(
f"""
SELECT event, payload_json
FROM {t_e}
WHERE customer_key=?
AND event IN ('customer_message','assistant_message')
ORDER BY id DESC
LIMIT ?
""",
(customer_key, n),
)
rows = c.fetchall() or []
out: list[dict[str, Any]] = []
for row in reversed(rows):
if isinstance(row, dict):
ev = str(row.get("event", "") or "")
raw = row.get("payload_json")
else:
ev = str(row[0] or "")
raw = row[1]
try:
payload = raw if isinstance(raw, dict) else json.loads(raw or "{}")
except Exception:
payload = {}
msg = str(payload.get("msg", "") or "").strip()
if not msg:
continue
role = "user" if ev == "customer_message" else "assistant"
out.append({"role": role, "text": msg})
return out
finally:
conn.close()