# Listing metadata: 131 lines, 4.5 KiB, Python
import logging
|
||
import asyncio
|
||
import re
|
||
from typing import Optional, List, Any
|
||
from datetime import datetime
|
||
from db.customer_db import db as customer_db
|
||
from db.image_tasks_db import db as task_db
|
||
from db.chat_log_db import log_message, get_conversation
|
||
|
||
# Module-level logger, registered under the fixed name "cs_agent".
logger = logging.getLogger("cs_agent")
|
||
|
||
# Substrings that must never appear in an archived outbound message.
# _sanitize_outbound_archive_text replaces any text containing one of them
# with a fixed fallback reply.
_OUTBOUND_BLOCK_MARKERS = (
    # Bracketed section headers ("history summary", "detailed records",
    # "order summary", "order details").
    # NOTE(review): presumably headers of internal prompt/context blocks -- confirm.
    "【历史记录摘要】",
    "【详细记录】",
    "【订单摘要】",
    "【订单详情】",
    # NOTE(review): looks like the opening of a model reasoning tag -- confirm.
    "<think",
    # NOTE(review): origin of this sentinel is not visible in this file.
    "think_never_used",
    # NOTE(review): looks like the start of a serialized tool-call list -- confirm.
    '[{"name":',
)
|
||
|
||
# Regular expressions checked with re.search by
# _sanitize_outbound_archive_text; a hit on any of them causes the outbound
# text to be replaced with the fallback reply.
#
# "\uff1a" is the fullwidth colon. The colon classes previously listed the
# ASCII colon twice, so text punctuated with the fullwidth form slipped
# through; both forms are matched now. The escape is resolved by the ``re``
# module itself, which keeps the pattern safe from encoding normalization.
_HISTORY_LEAK_PATTERNS = [
    # "[YYYY-MM-DD ...] 客户:" / "客服:" transcript line headers
    r'\[\d{4}-\d{2}-\d{2}[^\]]*\]\s*(客户|客服)[:\uff1a]',
    # "[HH:MM:SS] 客户:" / "客服:" / "我:" transcript line headers
    r'\[\d{2}:\d{2}:\d{2}\]\s*(客户|客服|我)[:\uff1a]',
    # Phrases saying the history / chat / conversation records were consulted
    r'(根据|查看|查询|翻看)(历史|聊天|对话)(记录|内容)',
    # "history records/conversation/messages show ..."
    r'历史(记录|对话|消息)(显示|表明|中)',
    # "in the previous chat/conversation/records ..."
    r'之前的(聊天|对话|记录)(中|里|显示)',
    # "N (history/conversation) messages in total"
    r'共\d+条(历史|对话)?消息',
    # "订单号:" followed by at least ten digits
    r'订单号[:\uff1a]\s*\d{10,}',
    # Two "状态/金额/数量:" labelled fields on the same line
    r'(状态|金额|数量)[:\uff1a].*(状态|金额|数量)[:\uff1a]',
]
|
||
|
||
|
||
def _sanitize_outbound_archive_text(content: str) -> str:
    """Return the text that may be archived for an outbound message.

    Args:
        content: Outbound message text about to be stored. Falsy values
            (``None``, ``""``) are accepted.

    Returns:
        ``""`` for falsy input. The stripped text when it contains the
        ``[转移会话]`` marker or triggers none of the checks. Otherwise a
        fixed fallback reply, used when the text contains any entry of
        ``_OUTBOUND_BLOCK_MARKERS`` or matches any regex in
        ``_HISTORY_LEAK_PATTERNS``.
    """
    if not content:
        return ""

    cleaned = str(content).strip()

    # Text carrying the session-transfer marker bypasses every check below.
    if "[转移会话]" in cleaned:
        return cleaned

    fallback = "我在帮你看记录,稍等哈"

    if any(marker in cleaned for marker in _OUTBOUND_BLOCK_MARKERS):
        logger.warning("[Repository] 拦截到内部内容写入外发记录,替换为安全兜底回复")
        return fallback

    for pattern in _HISTORY_LEAK_PATTERNS:
        if re.search(pattern, cleaned):
            # Only the first 30 characters of the pattern are logged.
            logger.warning(
                "[Repository] 检测到历史记录泄露模式,拦截出站入库: %s...",
                pattern[:30],
            )
            return fallback

    return cleaned
|
||
|
||
class DataRepository:
    """Asynchronous data repository.

    Each method hands the underlying synchronous database call to
    ``asyncio.to_thread`` so the blocking I/O runs in a worker thread
    instead of the event loop.
    """

    def __init__(self):
        # Database handles imported at module level.
        self.customer_db = customer_db
        self.task_db = task_db

    # --- Chat log (async) ---

    async def save_chat(
        self,
        platform: str,
        user_id: str,
        content: str,
        direction: str,
        acc_id: str = "",
        image_urls: Optional[List[str]] = None,
        msg_type: int = 0,
    ):
        """Persist one chat message through ``log_message``.

        Args:
            platform: Platform identifier, forwarded as ``platform``.
            user_id: Forwarded as ``customer_id``.
            content: Message text. Outbound text with ``msg_type`` 0 is passed
                through ``_sanitize_outbound_archive_text`` before storage.
            direction: ``"out"`` enables sanitizing; other values are stored
                as given.
            acc_id: Forwarded as ``acc_id``.
            image_urls: Image URLs, stored as one newline-separated string.
                A single bare string is stored as one URL.
            msg_type: Forwarded as ``msg_type``; ``None`` counts as 0 for the
                sanitizing decision only.

        Returns:
            Whatever ``log_message`` returns.
        """
        # NOTE(review): msg_type 0 is presumably plain text -- confirm.
        if direction == "out" and int(msg_type or 0) == 0:
            content = _sanitize_outbound_archive_text(content)

        # Flatten the URL list into a newline-separated string.
        if not image_urls:
            urls_str = ""
        elif isinstance(image_urls, str):
            # Joining a bare string would insert a newline between characters.
            urls_str = image_urls
        else:
            urls_str = "\n".join(image_urls)

        return await asyncio.to_thread(
            log_message,
            customer_id=user_id,
            message=content,
            direction=direction,
            platform=platform,
            acc_id=acc_id,
            msg_type=msg_type,
            image_urls=urls_str,
        )

    async def get_chat_history(self, user_id: str, limit: int = 10, acc_id: str = "") -> List[dict]:
        """Fetch stored messages and reshape them into role/content dicts.

        Rows come from ``get_conversation``. Direction ``"in"`` maps to role
        ``"user"`` and any other direction to ``"assistant"``. A ``None``
        result yields an empty list.
        """
        rows = await asyncio.to_thread(get_conversation, user_id, limit=limit, acc_id=acc_id)
        return [
            {
                "role": "user" if row["direction"] == "in" else "assistant",
                "content": row["message"],
                "msg_type": row.get("msg_type", 0),
                "image_urls": row.get("image_urls", ""),
                "timestamp": row.get("timestamp", ""),
            }
            for row in rows or []
        ]

    # --- Customers (async) ---

    async def get_customer(self, platform: str, user_id: str):
        """Look up a customer under the key ``"<platform>:<user_id>"``."""
        customer_key = f"{platform}:{user_id}"
        return await asyncio.to_thread(self.customer_db.get_customer, customer_key)

    # --- Tasks (async) ---

    async def create_task(self, platform: str, user_id: str, image_url: str, operation: str, requirements: str = ""):
        """Create a task with status ``"pending"`` via ``task_db.add_task``."""
        return await asyncio.to_thread(
            self.task_db.add_task,
            customer_id=user_id,
            platform=platform,
            original_image=image_url,
            operation=operation,
            requirements=requirements,
            status="pending",
        )

    async def update_task_price(self, platform: str, user_id: str, price: float):
        """Record the final deal price via ``task_db.update_price``."""
        # The DB layer takes (user_id, platform, ...) -- note the swapped order.
        return await asyncio.to_thread(self.task_db.update_price, user_id, platform, price)

    async def update_task_outcome(self, platform: str, user_id: str, outcome: str):
        """Record the final outcome via ``task_db.update_outcome``."""
        return await asyncio.to_thread(self.task_db.update_outcome, user_id, platform, outcome)
|
||
|
||
# Global asynchronous repository singleton, created when this module is imported.
repo = DataRepository()
|