Files
tw/core/orchestrator.py

157 lines
7.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import logging
import asyncio
import re
import time
from typing import Optional, List, Any, Dict
from collections import deque
from core.schema import StandardMessage, StandardResponse
from core.adapters.qianniu_adapter import QianniuAdapter
from core.pydantic_ai_agent_v2 import CustomerServiceBrain
from core.events.event_bus import bus
from core.repository import repo
# Module logger; all orchestrator messages are emitted on the "cs_agent" logger.
logger = logging.getLogger("cs_agent")
class SystemOrchestrator:
    """Top-level orchestrator for the customer-service message pipeline.

    Responsibilities:

    * de-duplicate inbound messages by ``msg_id`` (only the most recent 200
      IDs are remembered);
    * debounce bursts of messages from the same user and merge them into one
      ``StandardMessage`` before asking the brain for a reply;
    * after a transfer reply, keep a cooldown during which the brain is told a
      transfer is already pending;
    * persist inbound and outbound chat lines via ``repo``.
    """

    def __init__(self, ws_client=None):
        """Wire up the platform adapter, the reply brain, and per-user state.

        Args:
            ws_client: Client passed through to ``QianniuAdapter``; may be None.
        """
        self.ws_client = ws_client
        self.qianniu_adapter = QianniuAdapter(ws_client)
        self.brain = CustomerServiceBrain()
        # 1. Message-ID de-duplication window (bounded: oldest IDs fall off).
        self._processed_msg_ids = deque(maxlen=200)
        # 2. Transfer cooldown store (customer_id -> time.time() of last transfer).
        self._last_transfer_time: Dict[str, float] = {}
        # Seconds after a transfer during which the brain is told not to transfer again.
        self._transfer_cooldown_seconds = 60.0
        # 3. Debounce configuration.
        self._debounce_seconds = 5.0
        # user_id -> debounce task that is STILL SLEEPING. A task removes itself
        # from here once its window elapses, so only sleeping tasks get cancelled.
        self._debounce_tasks: Dict[str, asyncio.Task] = {}
        # Strong references to every spawned task: the event loop keeps only weak
        # references, and a task detached from _debounce_tasks must not be collected.
        self._background_tasks: set = set()
        self._pending_messages: Dict[str, List[StandardMessage]] = {}
        self._user_locks: Dict[str, asyncio.Lock] = {}
        bus.subscribe("MESSAGE_OUTBOUND", self.handle_outbound_event)

    def _get_user_lock(self, user_id: str) -> asyncio.Lock:
        """Return the lock serialising processing for ``user_id``, creating it on first use."""
        if user_id not in self._user_locks:
            self._user_locks[user_id] = asyncio.Lock()
        return self._user_locks[user_id]

    async def on_raw_message_received(self, platform: str, raw_data: dict):
        """Pipeline entry point for every raw platform message.

        Only the "qianniu" platform is handled. Heartbeats (no text, no images)
        are dropped; merchant-side ("out") messages and order notifications are
        stored without a reply; everything else is de-duplicated by ``msg_id``
        and queued for debounced processing. Errors are logged, never raised.
        """
        try:
            if platform != "qianniu":
                return
            std_msg, direction = await self.qianniu_adapter.translate_inbound(raw_data)
            # Heartbeat filter: no text and no images.
            if not std_msg.content.strip() and not std_msg.image_urls:
                return
            # Manual reply typed by the merchant: store silently.
            if direction == "out":
                await repo.save_chat(platform, std_msg.user_id, std_msg.content, "out", acc_id=std_msg.acc_id)
                return
            # Order notification: update task price/outcome, then store silently.
            if "[系统订单信息]" in std_msg.content:
                await self._handle_order_packet(platform, std_msg)
                await repo.save_chat(platform, std_msg.user_id, std_msg.content, "in", acc_id=std_msg.acc_id)
                return
            # Message-ID de-duplication.
            if std_msg.msg_id:
                if std_msg.msg_id in self._processed_msg_ids:
                    return
                self._processed_msg_ids.append(std_msg.msg_id)
            # Debounce: restart the user's timer and queue the message.
            user_id = std_msg.user_id
            sleeping_task = self._debounce_tasks.get(user_id)
            if sleeping_task is not None:
                # Safe: tasks leave _debounce_tasks as soon as their window
                # elapses, so this never aborts a reply already in progress.
                sleeping_task.cancel()
            self._pending_messages.setdefault(user_id, []).append(std_msg)
            task = asyncio.create_task(self._debounced_process(user_id, platform))
            self._debounce_tasks[user_id] = task
            self._background_tasks.add(task)
            task.add_done_callback(self._background_tasks.discard)
        except Exception as e:
            logger.exception("[Orchestrator] 处理失败: %s", e)

    async def _handle_order_packet(self, platform: str, msg: StandardMessage):
        """Update the user's task from an order notification (best effort).

        Stores the order amount when present, then marks the outcome
        "deal_success" when the buyer has paid, or "refunded" when the text
        mentions a refund / closed / cancelled order. Failures are logged but
        not propagated, so the caller still stores the message.
        """
        try:
            price_match = re.search(r"订单金额:金额:\s*([\d\.]+)元", msg.content)
            if price_match:
                await repo.update_task_price(platform, msg.user_id, float(price_match.group(1)))
            if "买家已付款" in msg.content:
                await repo.update_task_outcome(platform, msg.user_id, "deal_success")
            elif any(k in msg.content for k in ["退款", "已关闭", "已取消"]):
                await repo.update_task_outcome(platform, msg.user_id, "refunded")
        except Exception:
            # Deliberately best effort, but no longer silent.
            logger.exception("[Orchestrator] 订单信息处理失败: user_id=%s", msg.user_id)

    async def _debounced_process(self, user_id: str, platform: str):
        """Wait out the debounce window, then merge, persist and answer the user's queued messages.

        The task can be cancelled by a newer message only while it is still
        sleeping. Once the window elapses it detaches itself from
        ``_debounce_tasks``, so later messages start a new task instead of
        aborting this reply mid-flight. Processing per user is serialised by
        the user's lock.
        """
        try:
            await asyncio.sleep(self._debounce_seconds)
            # Detach before doing real work. There is no await between the sleep
            # returning and this check, so no new message can interleave here.
            if self._debounce_tasks.get(user_id) is asyncio.current_task():
                del self._debounce_tasks[user_id]
            async with self._get_user_lock(user_id):
                messages = self._pending_messages.pop(user_id, [])
                if not messages:
                    return

                # A. Merge the burst; account metadata comes from the newest message.
                combined_content = "\n".join(m.content for m in messages if m.content.strip())
                # Order-preserving de-duplication of image URLs across the burst.
                all_image_urls = list(dict.fromkeys(url for m in messages for url in m.image_urls))
                latest = messages[-1]
                acc_id = latest.acc_id
                acc_type = latest.acc_type
                # The merged message still needs a msg_id, otherwise StandardMessage validation fails.
                merged_msg_id = latest.msg_id if latest.msg_id else f"merged-{user_id}-{int(time.time() * 1000)}"
                final_msg = StandardMessage(
                    platform=platform,
                    msg_id=merged_msg_id,
                    user_id=user_id,
                    content=combined_content,
                    image_urls=all_image_urls,
                    acc_id=acc_id,
                    acc_type=acc_type,
                )

                # B. Persist the inbound burst, prefixed with an image count when images were sent.
                db_content = combined_content
                if all_image_urls:
                    db_content = f"【系统:已收到{len(all_image_urls)}张图】\n{combined_content}"
                await repo.save_chat(platform, user_id, db_content, "in", acc_id=acc_id)

                # C. Cooldown: was a transfer sent to this user within the cooldown window?
                elapsed = time.time() - self._last_transfer_time.get(user_id, 0)
                is_in_cooldown = elapsed < self._transfer_cooldown_seconds

                # D. Think. If the newest history entry is the line just saved, drop it
                #    so the current message is not passed twice (assumes history is
                #    oldest-first — confirm against repo.get_chat_history).
                history = await repo.get_chat_history(user_id, limit=10)
                if history and history[-1]['content'] == db_content:
                    history = history[:-1]
                if is_in_cooldown:
                    # Tell the brain a transfer is already pending so it does not call the transfer tool again.
                    final_msg.content = f"【系统:当前已向设计师发出转接请求,请勿再次调用转接工具】\n{final_msg.content}"
                std_res = await self.brain.think_and_reply(final_msg, history=history)

                # E. Send, persist, and remember when a transfer reply was sent.
                if std_res.should_reply:
                    # Attach account metadata so the outbound path knows which account/customer it is for.
                    std_res.metadata = {"acc_id": acc_id, "acc_type": acc_type}
                    await self.qianniu_adapter.translate_outbound(std_res, user_id)
                    await repo.save_chat(platform, user_id, std_res.reply_content, "out", acc_id=acc_id)
                    if "[转移会话]" in std_res.reply_content:
                        self._last_transfer_time[user_id] = time.time()
        except asyncio.CancelledError:
            # Superseded during the debounce window (or shutdown): propagate the
            # cancellation instead of silently completing.
            raise
        except Exception as e:
            logger.exception("[Orchestrator] 处理失败: %s", e)

    async def handle_outbound_event(self, user_id: str, platform: str, response: StandardResponse):
        """Forward a response published on the "MESSAGE_OUTBOUND" bus event to the qianniu adapter."""
        if platform == "qianniu":
            await self.qianniu_adapter.translate_outbound(response, user_id)
# Process-wide singleton; stays None until init_orchestrator() is called.
orchestrator: Optional[SystemOrchestrator] = None


def init_orchestrator(ws_client):
    """Build a SystemOrchestrator for ``ws_client``, publish it as the module singleton, and return it.

    Every call constructs a fresh instance and overwrites ``orchestrator``;
    the previous instance is not torn down.
    NOTE(review): the constructor subscribes to the event bus, so repeated
    calls presumably leave earlier handlers subscribed — confirm bus semantics.
    """
    global orchestrator
    new_orchestrator = SystemOrchestrator(ws_client)
    orchestrator = new_orchestrator
    return new_orchestrator