From 57dd967d58aaf439d5199f9f90ab3f19120211ec Mon Sep 17 00:00:00 2001 From: jimi <1847930177@qq.com> Date: Sun, 1 Mar 2026 17:50:59 +0800 Subject: [PATCH] feat: add full-context AI outbound reply guard before send --- core/websocket_client.py | 128 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 125 insertions(+), 3 deletions(-) diff --git a/core/websocket_client.py b/core/websocket_client.py index e9ce953..e1f707e 100755 --- a/core/websocket_client.py +++ b/core/websocket_client.py @@ -116,13 +116,14 @@ from utils.metrics_tracker import emit as metrics_emit # 导入 Agent 模块 try: - from core.pydantic_ai_agent import CustomerServiceAgent, CustomerMessage, _get_shop_type + from core.pydantic_ai_agent import CustomerServiceAgent, CustomerMessage, AgentDeps, _get_shop_type from db.customer_db import db from core.workflow import workflow AGENT_AVAILABLE = True except Exception as e: AGENT_AVAILABLE = False workflow = None + AgentDeps = None _get_shop_type = lambda acc_id, goods_name: "find_image" import traceback logger.info(f"警告: Agent 模块导入失败: {e}") @@ -160,6 +161,7 @@ class QingjianAPIClient: self._pending_images: dict = {} self._pending_image_tasks: dict = {} self._auto_quote_tasks: dict = {} # customer_key -> asyncio.Task + self._auto_quote_done_sig: dict = {} # customer_key -> signature(同一批内容仅自动触发一次) # 旧版“看图即报价”快速链路(默认关闭,避免与 Agent 批量收集逻辑并发打架) self._legacy_fast_quote_enabled = os.getenv("LEGACY_FAST_IMAGE_QUOTE", "false").lower() in ("1", "true", "yes") self._system_inquiry_rules = self._load_system_inquiry_rules() @@ -171,6 +173,7 @@ class QingjianAPIClient: or "http://139.199.3.75:18789/api/callback" ) self._tianwang_agent_name = os.getenv("TIANWANG_AGENT_NAME", "终结者").strip() or "终结者" + self._reply_guard_enabled = os.getenv("AI_REPLY_GUARD_ENABLED", "true").lower() in ("1", "true", "yes") # 延迟加载任务模块(避免循环导入) self.task_scheduler = None @@ -1143,6 +1146,14 @@ class QingjianAPIClient: task.cancel() self._activity_log("auto_quote_cancel", key=key, reason=reason or "unknown") + @staticmethod + def _build_auto_quote_signature(state: Any) -> str: + """为待报价内容生成稳定签名,用于避免同一批内容反复自动触发。""" + urls = list(getattr(state, "pending_image_urls", []) or []) + reqs = list(getattr(state, "pending_requirements", []) or []) + req_tail = reqs[-6:] if len(reqs) > 6 else reqs + return "||".join(urls) + "##" + "||".join(req_tail) + async def _maybe_schedule_auto_quote(self, data: dict): """ 智能兜底:客户发图后若长时间不再补充消息,自动触发一次报价,避免会话卡住。 @@ -1158,9 +1169,18 @@ class QingjianAPIClient: state = self.agent._get_conversation_state(cid) if not state or not getattr(state, "pending_image_urls", None): self._cancel_auto_quote_task(key, reason="no_pending_images") + self._auto_quote_done_sig.pop(key, None) return if state.quote_phase not in {"collecting", "waiting_result"}: return + current_sig = self._build_auto_quote_signature(state) + if current_sig and self._auto_quote_done_sig.get(key) == current_sig: + self._activity_log( + "auto_quote_skip_duplicate", + key=key, + pending_count=len(state.pending_image_urls), + ) + return try: idle_seconds = max(8, int(os.getenv("AUTO_QUOTE_IDLE_SECONDS", "18"))) except Exception: @@ -1168,13 +1188,19 @@ class QingjianAPIClient: self._cancel_auto_quote_task(key, reason="reschedule") - async def _delayed_auto_quote(capture_key: str, capture_data: dict, wait_s: int): + async def _delayed_auto_quote(capture_key: str, capture_data: dict, wait_s: int, capture_sig: str): await asyncio.sleep(wait_s) async with self._get_customer_lock(capture_key): capture_cid = capture_data.get('from_id', '') st = self.agent._get_conversation_state(capture_cid) if not st or not st.pending_image_urls: + self._auto_quote_done_sig.pop(capture_key, None) return + # 内容变化时,放弃旧触发(会在新一轮消息后重新调度)。 + if self._build_auto_quote_signature(st) != capture_sig: + return + # 标记本批次已自动触发,避免同内容循环“马上报价”。 + self._auto_quote_done_sig[capture_key] = capture_sig # 直接置为可报价,然后走“发完了,报价吧”触发既有报价链路 self.agent._mark_quote_ready(st) self.agent._sync_pending_quote_state(capture_cid, st) @@ -1206,7 +1232,7 @@ class QingjianAPIClient: reply=response.reply, ) - task = asyncio.create_task(_delayed_auto_quote(key, dict(data), idle_seconds)) + task = asyncio.create_task(_delayed_auto_quote(key, dict(data), idle_seconds, current_sig)) self._auto_quote_tasks[key] = task self._activity_log( "auto_quote_scheduled", @@ -1888,6 +1914,25 @@ class QingjianAPIClient: customer_id = original_msg.get("from_id", "") customer_name = original_msg.get("from_name", "") + allow_send, checked_reply, guard_reason = await self._ai_guard_outbound_reply( + original_msg=original_msg, + reply_content=str(reply_content), + ) + self._activity_log( + "reply_guard_decision", + trace_id=trace_id, + acc_id=shop_id, + customer_id=customer_id, + result="ok" if allow_send else "blocked", + reason=guard_reason, + original_reply=str(reply_content), + final_reply=str(checked_reply or ""), + ) + if not allow_send: + logger.info(f"回复被AI质检拦截: {guard_reason}") + return + reply_content = checked_reply or str(reply_content) + reply = { "msg_id": "", "acc_id": shop_id, @@ -1954,6 +1999,83 @@ class QingjianAPIClient: out = out.replace("。", "。") return out + async def _ai_guard_outbound_reply(self, original_msg: dict, reply_content: str) -> tuple[bool, str, str]: + """ + 专用AI质检:发送前判断“这句是否该发”,可拦截或改写。 + 读取当前客户在当前店铺的完整对话上下文。 + """ + text = (reply_content or "").strip() + if not text: + return False, "", "empty_reply" + if text.startswith("话术|") or "[转移会话]" in text: + return True, text, "command_bypass" + if not self._reply_guard_enabled or not self.enable_agent or not self.agent or not AgentDeps: + return True, text, "guard_disabled" + try: + from db.chat_log_db import get_conversation + + acc_id = str(original_msg.get("acc_id", "") or "") + customer_id = str(original_msg.get("from_id", "") or "") + if not customer_id: + return True, text, "no_customer_id" + + # 默认读取较大窗口,尽量覆盖完整上下文;可用环境变量继续放大。 + try: + max_rows = max(50, int(os.getenv("AI_REPLY_GUARD_CONTEXT_ROWS", "500"))) + except Exception: + max_rows = 500 + rows = get_conversation(customer_id=customer_id, limit=max_rows) or [] + shop_rows = [r for r in rows if str(r.get("acc_id", "") or "") == acc_id] if acc_id else rows + + context_lines = [] + for r in shop_rows: + role = "客" if (r.get("direction") == "in") else "服" + msg = self.to_chinese((r.get("message") or "").strip()) + if msg: + context_lines.append(f"{role}:{msg}") + context_text = "\n".join(context_lines) if context_lines else "无历史" + + deps = AgentDeps( + msg_id=str(original_msg.get("msg_id", "") or "reply_guard"), + acc_id=acc_id, + from_id=customer_id, + platform=str(original_msg.get("acc_type", "") or ""), + ) + prompt = ( + "你是淘宝客服回复质检器。目标:判断候选回复是否和上下文一致,是否会造成重复触发式答复。\n" + "必须检查:\n" + "1) 是否答非所问;\n" + "2) 是否重复说“马上报价/继续发图”但当前上下文不需要;\n" + "3) 是否与历史状态冲突;\n" + "4) 语气是否自然可直接发给客户。\n" + "若不合适,给可直接发送的一句改写。\n" + "只输出 JSON:{\"allow\":true/false,\"rewrite\":\"...\",\"reason\":\"...\"}\n\n" + f"完整上下文(当前店铺):\n{context_text}\n\n" + f"客户当前消息:{self.to_chinese(original_msg.get('msg', '') or '')}\n" + f"候选回复:{text}\n" + ) + result = await self.agent.agent_natural_reply.run(prompt, deps=deps, message_history=[]) + raw = str(getattr(result, "output", "") or "").strip() + if not raw: + return True, text, "guard_empty_output" + import json as _json + import re as _re + + m = _re.search(r"\{[\s\S]*\}", raw) + if not m: + return True, text, "guard_non_json" + obj = _json.loads(m.group(0)) + allow = bool(obj.get("allow", True)) + rewrite = str(obj.get("rewrite", "") or "").strip() + reason = str(obj.get("reason", "") or "").strip() or "guard_decision" + if allow: + return True, (rewrite or text), reason + if rewrite: + return True, rewrite, reason + return False, "", reason + except Exception as e: + return True, text, f"guard_error:{e}" + async def send_text(self, cy_id, acc_type, content): """ 主动发送文本消息