import asyncio
import logging
import re
from datetime import datetime
from typing import Any, List, Optional

from db.chat_log_db import get_conversation, log_message
from db.customer_db import db as customer_db
from db.image_tasks_db import db as task_db

logger = logging.getLogger("cs_agent")

# NOTE(review): this file reached review with part of its text lost (it looks
# like a markup stripper removed everything from a '<' up to the '>' of the
# '->' return annotation). The lost span covered:
#   * the tail of _OUTBOUND_BLOCK_MARKERS (a fifth entry began with a quote
#     character and was cut off right after it),
#   * the entire definition of _HISTORY_LEAK_PATTERNS,
#   * the ``def`` header of _sanitize_outbound_archive_text.
# Only what the surviving text establishes is reconstructed below. Restore the
# missing markers and patterns from version control before relying on this.

# Substrings that mark internal-only content; an outbound text message that
# contains any of them is replaced before being archived.
_OUTBOUND_BLOCK_MARKERS = (
    "【历史记录摘要】",
    "【详细记录】",
    "【订单摘要】",
    "【订单详情】",
    # TODO(review): further markers were lost in extraction — restore them.
)

# Regex pattern strings (each is passed to re.search and sliced for logging).
# TODO(review): the original patterns were lost in extraction. This is left
# empty rather than guessed, so the regex-based leak check is inert until the
# real patterns are restored.
_HISTORY_LEAK_PATTERNS = ()

# Reply archived in place of any outbound text that trips a leak check.
_SAFE_FALLBACK_REPLY = "我在帮你看记录,稍等哈"


def _sanitize_outbound_archive_text(content: Any) -> str:
    """Return the text that is safe to archive for an outbound message.

    NOTE(review): the signature is reconstructed — the name comes from the
    call in ``DataRepository.save_chat``, the parameter name from the body,
    and the return annotation from the surviving ``str:`` fragment. The
    ``Any`` annotation is an assumption; confirm against version control.

    Args:
        content: The outbound message content. It is converted with ``str()``
            and stripped before any check.

    Returns:
        ``""`` when ``content`` is falsy; the stripped text unchanged when it
        contains ``"[转移会话]"`` or trips no check; otherwise the fixed
        fallback reply.
    """
    if not content:
        return ""

    cleaned = str(content).strip()

    # Messages carrying the session-transfer tag bypass every leak check.
    if "[转移会话]" in cleaned:
        return cleaned

    if any(marker in cleaned for marker in _OUTBOUND_BLOCK_MARKERS):
        logger.warning("[Repository] 拦截到内部内容写入外发记录,替换为安全兜底回复")
        return _SAFE_FALLBACK_REPLY

    for pattern in _HISTORY_LEAK_PATTERNS:
        if re.search(pattern, cleaned):
            # Only the first 30 characters of the pattern are logged.
            logger.warning(
                "[Repository] 检测到历史记录泄露模式,拦截出站入库: %s...",
                pattern[:30],
            )
            return _SAFE_FALLBACK_REPLY

    return cleaned


class DataRepository:
    """Asynchronous data repository.

    Every method hands the underlying synchronous call to
    ``asyncio.to_thread`` so that blocking I/O does not stall the event loop.
    """

    def __init__(self):
        self.customer_db = customer_db
        self.task_db = task_db

    # --- Chat logs (async) ---
    async def save_chat(
        self,
        platform: str,
        user_id: str,
        content: str,
        direction: str,
        acc_id: str = "",
        image_urls: Optional[List[str]] = None,
        msg_type: int = 0,
    ):
        """Persist one chat message via ``log_message``.

        Args:
            platform: Forwarded to ``log_message`` as ``platform``.
            user_id: Forwarded to ``log_message`` as ``customer_id``.
            content: Message text. It is sanitized first when ``direction`` is
                ``"out"`` and ``msg_type`` evaluates to 0.
            direction: Message direction; ``"out"`` triggers sanitizing.
            acc_id: Forwarded to ``log_message`` as ``acc_id``.
            image_urls: Optional list of image URLs, stored as a single
                newline-separated string (empty string when absent).
            msg_type: Message type code; a falsy value is treated as 0.

        Returns:
            Whatever ``log_message`` returns.
        """
        if direction == "out" and int(msg_type or 0) == 0:
            content = _sanitize_outbound_archive_text(content)

        # Flatten the image URL list into a newline-separated string.
        urls_str = "\n".join(image_urls) if image_urls else ""

        return await asyncio.to_thread(
            log_message,
            customer_id=user_id,
            message=content,
            direction=direction,
            platform=platform,
            acc_id=acc_id,
            msg_type=msg_type,
            image_urls=urls_str,
        )

    async def get_chat_history(
        self, user_id: str, limit: int = 10, acc_id: str = ""
    ) -> List[dict]:
        """Fetch conversation rows and reshape them into role/content dicts.

        Rows whose ``direction`` is ``"in"`` get role ``"user"``; every other
        row gets role ``"assistant"``. Missing ``msg_type``, ``image_urls`` and
        ``timestamp`` keys default to ``0``, ``""`` and ``""``.

        Args:
            user_id: Passed positionally to ``get_conversation``.
            limit: Passed to ``get_conversation`` as ``limit``.
            acc_id: Passed to ``get_conversation`` as ``acc_id``.

        Returns:
            One dict per row, in the order ``get_conversation`` returned them.
        """
        rows = await asyncio.to_thread(
            get_conversation, user_id, limit=limit, acc_id=acc_id
        )
        return [
            {
                "role": "user" if r["direction"] == "in" else "assistant",
                "content": r["message"],
                "msg_type": r.get("msg_type", 0),
                "image_urls": r.get("image_urls", ""),
                "timestamp": r.get("timestamp", ""),
            }
            for r in rows
        ]

    # --- Customers (async) ---
    async def get_customer(self, platform: str, user_id: str):
        """Look up a customer under the composite key ``"{platform}:{user_id}"``."""
        customer_key = f"{platform}:{user_id}"
        return await asyncio.to_thread(self.customer_db.get_customer, customer_key)

    # --- Tasks (async) ---
    async def create_task(
        self,
        platform: str,
        user_id: str,
        image_url: str,
        operation: str,
        requirements: str = "",
    ):
        """Create an image task with status ``"pending"`` via ``task_db.add_task``."""
        return await asyncio.to_thread(
            self.task_db.add_task,
            customer_id=user_id,
            platform=platform,
            original_image=image_url,
            operation=operation,
            requirements=requirements,
            status="pending",
        )

    async def update_task_price(self, platform: str, user_id: str, price: float):
        """Record the agreed price via ``task_db.update_price``.

        Note the argument order passed down is ``(user_id, platform, price)``.
        """
        return await asyncio.to_thread(
            self.task_db.update_price, user_id, platform, price
        )

    async def update_task_outcome(self, platform: str, user_id: str, outcome: str):
        """Record the final outcome via ``task_db.update_outcome``.

        Note the argument order passed down is ``(user_id, platform, outcome)``.
        """
        return await asyncio.to_thread(
            self.task_db.update_outcome, user_id, platform, outcome
        )


# Module-level repository singleton.
repo = DataRepository()