Files
tw/core/repository.py

135 lines
4.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import logging
import asyncio
import re
from typing import Optional, List, Any
from datetime import datetime
from db.customer_db import db as customer_db
from db.image_tasks_db import db as task_db
from db.chat_log_db import log_message, get_conversation
# Logger shared by this module.
logger = logging.getLogger("cs_agent")

# Substrings that must never be archived as an outbound reply: internal
# section headers plus what look like model-reasoning / tool-call fragments.
# A hit makes _sanitize_outbound_archive_text substitute its fallback reply.
_OUTBOUND_BLOCK_MARKERS = (
    "【历史记录摘要】",
    "【详细记录】",
    "【订单摘要】",
    "【订单详情】",
    "<think",
    "think_never_used",
    '[{"name":',
)

# A message that consists solely of a session-transfer command:
#   正在为您转接|[转移会话],<non-empty field>,<rest of line>
# Such a message is archived unchanged.
_TRANSFER_COMMAND_RE = re.compile(r"^\s*正在为您转接\|\[转移会话\],[^,\r\n]+,[^\r\n]*\s*$")

# Regexes (kept as plain strings; the sanitizer logs a slice of the matching
# pattern) that indicate chat-history or order data leaking into a reply.
# Wherever a colon is expected, both the ASCII colon and the full-width colon
# U+FF1A are accepted, because Chinese text uses either. U+FF1A is written as
# a regex escape so no look-alike character appears in the source.
_HISTORY_LEAK_PATTERNS = [
    # "[YYYY-MM-DD ...] 客户:" style transcript line
    r'\[\d{4}-\d{2}-\d{2}[^\]]*\]\s*(客户|客服)[:\uFF1A]',
    # "[HH:MM:SS] 客户:" style transcript line
    r'\[\d{2}:\d{2}:\d{2}\]\s*(客户|客服|我)[:\uFF1A]',
    # Phrases that refer to looking up the history / conversation
    r'(根据|查看|查询|翻看)(历史|聊天|对话)(记录|内容)',
    r'历史(记录|对话|消息)(显示|表明|中)',
    r'之前的(聊天|对话|记录)(中|里|显示)',
    # "<N>条消息" message counts
    r'\d+条(历史|对话)?消息',
    # Order number followed by 10+ digits
    r'订单号[:\uFF1A]\s*\d{10,}',
    # Two "field: value" pairs in a row (status / amount / quantity)
    r'(状态|金额|数量)[:\uFF1A].*(状态|金额|数量)[:\uFF1A]',
]
def _sanitize_outbound_archive_text(content: str) -> str:
    """Return the text that may safely be archived for an outbound message.

    Falsy input yields ``""``. Otherwise the input is converted to ``str`` and
    stripped, then:

    * a message that is entirely a transfer command is returned unchanged;
    * a transfer tag embedded in other text, any blocked marker, or any
      history-leak pattern causes a warning to be logged and a fixed fallback
      reply to be returned instead;
    * anything else is returned stripped.
    """
    if not content:
        return ""

    # The same canned reply replaces every kind of rejected message.
    fallback = "我在帮你看记录,稍等哈"
    text = str(content).strip()

    # A pure transfer command is legitimate and passes through as-is.
    if _TRANSFER_COMMAND_RE.fullmatch(text) is not None:
        return text

    # The transfer tag mixed into ordinary text is not a valid command.
    if "[转移会话]" in text:
        logger.warning("[Repository] 检测到混入正文的转接指令,拦截出站入库")
        return fallback

    for marker in _OUTBOUND_BLOCK_MARKERS:
        if marker in text:
            logger.warning("[Repository] 拦截到内部内容写入外发记录,替换为安全兜底回复")
            return fallback

    # First leak pattern that matches, if any.
    leaked = next((p for p in _HISTORY_LEAK_PATTERNS if re.search(p, text)), None)
    if leaked is not None:
        logger.warning(f"[Repository] 检测到历史记录泄露模式,拦截出站入库: {leaked[:30]}...")
        return fallback

    return text
class DataRepository:
    """Async data repository.

    Each method hands the underlying synchronous database call to
    ``asyncio.to_thread`` so the caller's event loop is not blocked by it.
    """

    def __init__(self):
        self.customer_db = customer_db
        self.task_db = task_db

    # --- Chat log (async) ---

    @staticmethod
    def _join_image_urls(image_urls: Any) -> str:
        """Flatten *image_urls* into one newline-separated string.

        ``None`` or an empty collection gives ``""``. A bare string is treated
        as a single URL. Empty or ``None`` entries are skipped, and remaining
        entries are converted with ``str``.
        """
        if not image_urls:
            return ""
        if isinstance(image_urls, str):
            # Joining a str would put every character on its own line.
            return image_urls
        return "\n".join(str(url) for url in image_urls if url)

    async def save_chat(
        self,
        platform: str,
        user_id: str,
        content: str,
        direction: str,
        acc_id: str = "",
        image_urls: Optional[List[str]] = None,
        msg_type: int = 0,
    ) -> Any:
        """Persist one chat message via ``log_message``.

        Args:
            platform: Platform identifier, passed through to ``log_message``.
            user_id: Stored as ``customer_id``.
            content: Message text. Outbound messages whose ``msg_type`` is 0
                (or empty) are sanitized before being stored.
            direction: ``"out"`` triggers sanitization; any other value is
                stored as given.
            acc_id: Account identifier, passed through.
            image_urls: Image URLs; stored as one newline-separated string.
            msg_type: Message type code, passed through.

        Returns:
            Whatever ``log_message`` returns.
        """
        # msg_type 0 is presumably plain text -- TODO confirm. Other types are
        # archived without sanitization.
        if direction == "out" and int(msg_type or 0) == 0:
            content = _sanitize_outbound_archive_text(content)

        return await asyncio.to_thread(
            log_message,
            customer_id=user_id,
            message=content,
            direction=direction,
            platform=platform,
            acc_id=acc_id,
            msg_type=msg_type,
            image_urls=self._join_image_urls(image_urls),
        )

    async def get_chat_history(self, user_id: str, limit: int = 10, acc_id: str = "") -> List[dict]:
        """Fetch stored messages via ``get_conversation`` as role-tagged dicts.

        Rows whose ``direction`` is ``"in"`` get role ``"user"``; all others
        get ``"assistant"``. A ``None`` result is treated as no rows.
        """
        rows = await asyncio.to_thread(get_conversation, user_id, limit=limit, acc_id=acc_id)
        return [
            {
                "role": "user" if row["direction"] == "in" else "assistant",
                "content": row["message"],
                "msg_type": row.get("msg_type", 0),
                "image_urls": row.get("image_urls", ""),
                "timestamp": row.get("timestamp", ""),
            }
            for row in rows or []
        ]

    # --- Customers (async) ---

    async def get_customer(self, platform: str, user_id: str) -> Any:
        """Look up a customer under the key ``"<platform>:<user_id>"``."""
        customer_key = f"{platform}:{user_id}"
        return await asyncio.to_thread(self.customer_db.get_customer, customer_key)

    # --- Tasks (async) ---

    async def create_task(
        self,
        platform: str,
        user_id: str,
        image_url: str,
        operation: str,
        requirements: str = "",
    ) -> Any:
        """Add a task with status ``"pending"`` and return ``add_task``'s result."""
        return await asyncio.to_thread(
            self.task_db.add_task,
            customer_id=user_id,
            platform=platform,
            original_image=image_url,
            operation=operation,
            requirements=requirements,
            status="pending",
        )

    async def update_task_price(self, platform: str, user_id: str, price: float) -> Any:
        """Record the final deal price for the user's task."""
        return await asyncio.to_thread(self.task_db.update_price, user_id, platform, price)

    async def update_task_outcome(self, platform: str, user_id: str, outcome: str) -> Any:
        """Record the final outcome for the user's task."""
        return await asyncio.to_thread(self.task_db.update_outcome, user_id, platform, outcome)
# Global async repository instance, created once at import time and meant to
# be shared (used as a singleton) by importers of this module.
repo = DataRepository()