diff --git a/core/agent_tools.py b/core/agent_tools.py index ec08a1a..9a04daa 100644 --- a/core/agent_tools.py +++ b/core/agent_tools.py @@ -11,6 +11,13 @@ from db.chat_log_db import get_conversation, get_customer_orders logger = logging.getLogger("cs_agent") +class TransferSuccessException(Exception): + """转接成功后抛出此异常,用于提前终止 AI 处理流程""" + def __init__(self, transfer_cmd: str): + self.transfer_cmd = transfer_cmd + super().__init__(transfer_cmd) + + async def transfer_to_human_tool(ctx: RunContext[Any], reason: str = Field(description="转人工的原因")) -> str: """ 【核心工具】执行转人工逻辑。 @@ -22,8 +29,9 @@ async def transfer_to_human_tool(ctx: RunContext[Any], reason: str = Field(descr if designer_name: magic_cmd = f"正在为您转接|[转移会话],{designer_name},无原因" - logger.info(f"[Tool] 成功呼叫设计师: {designer_name}") - return magic_cmd + logger.info(f"[Tool] 成功呼叫设计师: {designer_name},立即触发转接") + # 抛出异常以提前终止 AI 后续处理,节省等待时间 + raise TransferSuccessException(magic_cmd) else: hour = datetime.now().hour logger.warning(f"[Tool] 派单失败:设计师们不在位 (当前{hour}点)") diff --git a/core/orchestrator.py b/core/orchestrator.py index fa10acb..59f955d 100644 --- a/core/orchestrator.py +++ b/core/orchestrator.py @@ -37,6 +37,18 @@ _OUTBOUND_BLOCK_MARKERS = ( '[{"name":', ) +# 历史记录格式检测模式(AI 转述历史时容易泄露) +_HISTORY_LEAK_PATTERNS = [ + r'\[\d{4}-\d{2}-\d{2}[^\]]*\]\s*(客户|客服)[::]', # [2026-03-07 12:00:00] 客户: + r'\[\d{2}:\d{2}:\d{2}\]\s*(客户|客服|我)[::]', # [12:00:00] 客户: + r'(根据|查看|查询|翻看)(历史|聊天|对话)(记录|内容)', # 根据历史记录 + r'历史(记录|对话|消息)(显示|表明|中)', # 历史记录显示 + r'之前的(聊天|对话|记录)(中|里|显示)', # 之前的聊天中 + r'共\d+条(历史|对话)?消息', # 共30条历史消息 + r'订单号[::]\s*\d{10,}', # 订单号:xxxxxxxxxx + r'(状态|金额|数量)[::].*(状态|金额|数量)[::]', # 状态:xxx 金额:xxx 连续出现 +] + class SystemOrchestrator: """ 全系统总编排:具备转接冷却、防抖合并、多消息去重、以及精准日志。 @@ -84,6 +96,11 @@ class SystemOrchestrator: if any(marker in cleaned for marker in _OUTBOUND_BLOCK_MARKERS): logger.warning("[Orchestrator] 拦截到内部内容外发,替换为安全兜底回复") return "我在帮你看记录,稍等哈" + # 检查历史记录泄露模式 + for pattern in _HISTORY_LEAK_PATTERNS: + if re.search(pattern, cleaned): + logger.warning(f"[Orchestrator] 检测到历史记录泄露模式: {pattern[:30]}...") + return "我在帮你看记录,稍等哈" return cleaned async def on_raw_message_received(self, platform: str, raw_data: dict): @@ -243,11 +260,17 @@ class SystemOrchestrator: async def _debounced_process(self, session_key: str, user_id: str, platform: str): try: + # 记录开始时间(防抖前) + process_start = time.time() + await asyncio.sleep(self._debounce_seconds) async with self._get_user_lock(session_key): messages = self._pending_messages.pop(session_key, []) if not messages: return + debounce_elapsed = time.time() - process_start + logger.info(f"[计时] user={user_id} 防抖等待完成: {debounce_elapsed:.1f}s") + # A. 合并与元数据修复(去重:同一防抖窗口内完全相同的内容只保留一条) seen_contents = set() unique_parts = [] @@ -278,9 +301,12 @@ class SystemOrchestrator: ) # B. 持久化 + db_start = time.time() db_content = combined_content if all_image_urls: db_content = f"【系统:已收到{len(all_image_urls)}张图】\n{combined_content}" await repo.save_chat(platform, user_id, db_content, "in", acc_id=acc_id, image_urls=all_image_urls) + db_elapsed = time.time() - db_start + logger.info(f"[计时] user={user_id} 消息入库: {db_elapsed:.2f}s") # B2. 后台图片分析(不阻塞主流程,用于数据标定) if all_image_urls: @@ -302,9 +328,17 @@ class SystemOrchestrator: ) else: # D. 正常流程:调用AI思考 + history_start = time.time() history = await repo.get_chat_history(user_id, limit=10, acc_id=acc_id) if history and history[-1].get('content') == db_content: history = history[:-1] + history_elapsed = time.time() - history_start + logger.info(f"[计时] user={user_id} 查询历史: {history_elapsed:.2f}s (共{len(history)}条)") + + ai_start = time.time() std_res = await self.brain.think_and_reply(final_msg, history=history) + ai_elapsed = time.time() - ai_start + total_elapsed = time.time() - process_start + logger.info(f"[计时] user={user_id} AI思考: {ai_elapsed:.1f}s | 总耗时: {total_elapsed:.1f}s") # E. 发送并记录时间 if std_res.should_reply: diff --git a/core/pydantic_ai_agent_v2.py b/core/pydantic_ai_agent_v2.py index 952245a..fdb2b64 100644 --- a/core/pydantic_ai_agent_v2.py +++ b/core/pydantic_ai_agent_v2.py @@ -2,15 +2,20 @@ import os import re import hashlib import logging +import time from typing import List, Optional, Any, Dict from pydantic_ai import Agent, RunContext from pydantic_ai.models.openai import OpenAIChatModel from pydantic_ai.providers.openai import OpenAIProvider from core.schema import StandardMessage, StandardResponse -from core.agent_tools import register_agent_tools +from core.agent_tools import register_agent_tools, TransferSuccessException logger = logging.getLogger("cs_agent") +# 日志详细程度:设置环境变量 AI_LOG_LEVEL=debug 可获得完整日志 +_LOG_FULL_PROMPT = os.getenv("AI_LOG_LEVEL", "").lower() == "debug" +_LOG_CLIP_LIMIT = int(os.getenv("AI_LOG_CLIP", "2000")) # 日志截断长度 + from core.skill_manager import skill_manager @@ -21,6 +26,18 @@ _INTERNAL_TOOL_MARKERS = ( "【订单详情】", ) +# 历史记录格式检测模式(AI 转述历史时容易泄露) +_HISTORY_LEAK_PATTERNS = [ + r'\[\d{4}-\d{2}-\d{2}[^\]]*\]\s*(客户|客服)[::]', # [2026-03-07 12:00:00] 客户: + r'\[\d{2}:\d{2}:\d{2}\]\s*(客户|客服|我)[::]', # [12:00:00] 客户: + r'(根据|查看|查询|翻看)(历史|聊天|对话)(记录|内容)', # 根据历史记录 + r'历史(记录|对话|消息)(显示|表明|中)', # 历史记录显示 + r'之前的(聊天|对话|记录)(中|里|显示)', # 之前的聊天中 + r'共\d+条(历史|对话)?消息', # 共30条历史消息 + r'订单号[::]\s*\d{10,}', # 订单号:xxxxxxxxxx + r'(状态|金额|数量)[::].*(状态|金额|数量)[::]', # 状态:xxx 金额:xxx 连续出现 +] + def _clip(text: str, limit: int = 1200) -> str: if text is None: @@ -61,10 +78,17 @@ def _sanitize_reply_text(reply_text: str) -> str: text = re.sub(r'[\[\]]{2,}', '', text) text = text.strip() + # 检查固定标记 if any(marker in text for marker in _INTERNAL_TOOL_MARKERS): logger.warning("[Brain] 拦截到工具原文泄露,降级为安全兜底回复") return "我在帮你看记录,稍等哈" + # 检查历史记录泄露模式(AI 转述历史内容) + for pattern in _HISTORY_LEAK_PATTERNS: + if re.search(pattern, text): + logger.warning(f"[Brain] 检测到历史记录泄露模式: {pattern[:30]}...") + return "我在帮你看记录,稍等哈" + return text.strip() @@ -187,12 +211,51 @@ class CustomerServiceBrain: recent_context = "【近期对话回顾】\n" + "\n".join(lines) + "\n----------------\n" full_input = f"【当前客户ID:{msg.user_id}】\n{recent_context}现在的对话:{user_content}" - logger.info( - f"[PROMPT->AI] user={msg.user_id} acc={msg.acc_id} images={len(msg.image_urls)}\n" - f"{_clip(full_input)}" - ) + start_time = time.time() + + # ===== 详细日志:发给 AI 的提示词 ===== + logger.info(f"[AI提示词] user={msg.user_id} acc={msg.acc_id} images={len(msg.image_urls)}\n{full_input}") + if history: + history_preview = "\n".join([f" {h.get('role','?')}: {str(h.get('content',''))[:50]}" for h in history[-4:]]) + logger.info(f"[AI历史上下文] 共{len(history)}条:\n{history_preview}") - result = await self.agent.run(full_input, message_history=history) + # 尝试运行 AI,捕获转接成功异常以提前终止 + try: + result = await self.agent.run(full_input, message_history=history) + except TransferSuccessException as e: + # 转接工具成功后立即返回,无需等待 AI 继续生成 + elapsed = time.time() - start_time + logger.info(f"[Brain] 转接成功(提前终止,耗时{elapsed:.1f}s): {e.transfer_cmd[:60]}") + return StandardResponse( + reply_content=e.transfer_cmd, + need_transfer=True, + metadata={"acc_id": msg.acc_id, "acc_type": msg.acc_type} + ) + + elapsed = time.time() - start_time + logger.info(f"[Brain] AI处理完成,总耗时{elapsed:.1f}s") + + # ===== 详细日志:AI 的思考过程和工具调用 ===== + try: + all_msgs = result.all_messages() + for idx, m in enumerate(all_msgs): + msg_kind = getattr(m, 'kind', type(m).__name__) + if hasattr(m, 'parts'): + for part in m.parts: + part_kind = getattr(part, 'part_kind', '') + if part_kind == 'tool-call': + tool_name = getattr(part, 'tool_name', '?') + tool_args = getattr(part, 'args', {}) + logger.info(f"[AI思考] 步骤{idx+1} 调用工具: {tool_name}({tool_args})") + elif part_kind == 'tool-return': + content = str(getattr(part, 'content', ''))[:200] + logger.info(f"[AI思考] 步骤{idx+1} 工具返回: {content}") + elif part_kind == 'text': + content = str(getattr(part, 'content', ''))[:150] + if content.strip(): + logger.info(f"[AI思考] 步骤{idx+1} 文本输出: {content}") + except Exception as log_err: + logger.debug(f"[AI思考日志] 解析失败: {log_err}") # --- 转接指令:直接从工具返回截获,不经过 AI 二次加工 --- transfer_cmd = "" diff --git a/db/chat_log_db.py b/db/chat_log_db.py index 75671f6..ca127ef 100755 --- a/db/chat_log_db.py +++ b/db/chat_log_db.py @@ -1,10 +1,13 @@ """ -聊天记录数据库(SQLite) +聊天记录数据库(SQLite / MySQL) 每条消息独立存储,按客户ID分开,支持查询和展示。 +支持 MySQL 连接池以提高性能。 """ import sqlite3 import os +import threading +from queue import Queue, Empty from datetime import datetime from typing import List, Dict, Optional @@ -16,6 +19,84 @@ _MYSQL_USER = os.getenv("MYSQL_USER", "root") _MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "") _MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs") +# ========== MySQL 连接池 ========== +_POOL_SIZE = int(os.getenv("MYSQL_POOL_SIZE", "50")) +_mysql_pool: Optional[Queue] = None +_pool_lock = threading.Lock() + + +def _create_mysql_conn(): + """创建单个 MySQL 连接""" + import pymysql + return pymysql.connect( + host=_MYSQL_HOST, + port=_MYSQL_PORT, + user=_MYSQL_USER, + password=_MYSQL_PASSWORD, + database=_MYSQL_DATABASE, + charset="utf8mb4", + cursorclass=pymysql.cursors.DictCursor, + autocommit=False, + connect_timeout=10, + read_timeout=30, + write_timeout=30, + ) + + +def _init_mysql_pool(): + """初始化 MySQL 连接池""" + global _mysql_pool + with _pool_lock: + if _mysql_pool is None: + _mysql_pool = Queue(maxsize=_POOL_SIZE) + for _ in range(_POOL_SIZE): + try: + conn = _create_mysql_conn() + _mysql_pool.put(conn) + except Exception: + pass # 启动时连接失败不阻塞,后续会重建 + + +def _get_pooled_conn(timeout: float = 5.0): + """从连接池获取连接""" + global _mysql_pool + if _mysql_pool is None: + _init_mysql_pool() + + try: + conn = _mysql_pool.get(timeout=timeout) + # 检查连接是否有效 + try: + conn.ping(reconnect=True) + except Exception: + # 连接失效,创建新连接 + try: + conn.close() + except Exception: + pass + conn = _create_mysql_conn() + return conn + except Empty: + # 池空了,创建新连接(不放回池) + return _create_mysql_conn() + + +def _return_conn(conn): + """归还连接到池""" + global _mysql_pool + if _mysql_pool is None: + return + try: + if _mysql_pool.qsize() < _POOL_SIZE: + _mysql_pool.put_nowait(conn) + else: + conn.close() + except Exception: + try: + conn.close() + except Exception: + pass + class _CompatResult: def __init__(self, rows=None, rowcount: int = 0, lastrowid: int = 0): @@ -31,10 +112,11 @@ class _CompatResult: class _PyMySQLCompatConn: - """让 pymysql 连接兼容 sqlite 的 conn.execute 用法。""" + """让 pymysql 连接兼容 sqlite 的 conn.execute 用法,支持连接池。""" - def __init__(self, conn): + def __init__(self, conn, use_pool: bool = True): self._conn = conn + self._use_pool = use_pool def __enter__(self): return self @@ -45,7 +127,11 @@ class _PyMySQLCompatConn: self._conn.rollback() except Exception: pass - self._conn.close() + # 归还连接到池而不是关闭 + if self._use_pool: + _return_conn(self._conn) + else: + self._conn.close() def execute(self, query: str, args=None): cur = self._conn.cursor() @@ -59,7 +145,10 @@ class _PyMySQLCompatConn: self._conn.commit() def close(self): - self._conn.close() + if self._use_pool: + _return_conn(self._conn) + else: + self._conn.close() def _is_mysql() -> bool: return _DB_TYPE in ("mysql", "mariadb") @@ -68,20 +157,22 @@ def _sql(query: str) -> str: return query.replace("?", "%s") if _is_mysql() else query -def _get_conn() -> sqlite3.Connection: +def _get_conn(max_retries: int = 3, retry_delay: float = 0.5) -> sqlite3.Connection: + """获取数据库连接,MySQL 使用连接池""" if _is_mysql(): - import pymysql - conn = pymysql.connect( - host=_MYSQL_HOST, - port=_MYSQL_PORT, - user=_MYSQL_USER, - password=_MYSQL_PASSWORD, - database=_MYSQL_DATABASE, - charset="utf8mb4", - cursorclass=pymysql.cursors.DictCursor, - autocommit=False, - ) - return _PyMySQLCompatConn(conn) + import time + last_error = None + for attempt in range(max_retries): + try: + conn = _get_pooled_conn(timeout=5.0) + return _PyMySQLCompatConn(conn, use_pool=True) + except Exception as e: + last_error = e + if attempt < max_retries - 1: + time.sleep(retry_delay * (attempt + 1)) + continue + raise + raise last_error os.makedirs(os.path.dirname(_DB_PATH), exist_ok=True) conn = sqlite3.connect(_DB_PATH) conn.row_factory = sqlite3.Row @@ -192,8 +283,38 @@ def init_db(): init_db() +# ========== 重试装饰器 ========== + +def _retry_db_operation(func): + """数据库操作重试装饰器,处理连接丢失等临时错误""" + import functools + import time + @functools.wraps(func) + def wrapper(*args, **kwargs): + max_retries = 3 + last_error = None + for attempt in range(max_retries): + try: + return func(*args, **kwargs) + except Exception as e: + err_str = str(e).lower() + # 判断是否为可重试的连接错误 + is_conn_error = any(k in err_str for k in [ + "lost connection", "gone away", "connection reset", + "can't connect", "connection refused", "2013", "2006" + ]) + if is_conn_error and attempt < max_retries - 1: + last_error = e + time.sleep(0.5 * (attempt + 1)) + continue + raise + raise last_error + return wrapper + + # ========== 写入 ========== +@_retry_db_operation def log_message( customer_id: str, message: str, @@ -251,6 +372,7 @@ def get_customers(limit: int = 100) -> List[Dict]: return [dict(r) for r in rows] +@_retry_db_operation def get_conversation(customer_id: str, limit: int = 200, acc_id: str = "") -> List[Dict]: """返回某客户的最近对话记录(按时间升序)""" # 忽略 acc_id 过滤,实现全店铺记忆 @@ -398,6 +520,7 @@ def get_latest_messages(limit: int = 20) -> List[Dict]: # ========== 订单相关 ========== +@_retry_db_operation def upsert_order( customer_id: str, order_id: str, @@ -431,6 +554,7 @@ def upsert_order( conn.commit() +@_retry_db_operation def get_customer_orders(customer_id: str, limit: int = 10) -> List[Dict]: """查询某客户的订单记录(按时间倒序)""" with _get_conn() as conn: diff --git a/db/customer_db.py b/db/customer_db.py index 20ea2f1..b3e96b3 100755 --- a/db/customer_db.py +++ b/db/customer_db.py @@ -13,6 +13,9 @@ _MYSQL_USER = os.getenv("MYSQL_USER", "root") _MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "") _MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs") +# 复用 chat_log_db 的连接池 +from db.chat_log_db import _get_pooled_conn, _return_conn + def _is_mysql() -> bool: return _DB_TYPE in ("mysql", "mariadb") @@ -171,6 +174,23 @@ class CustomerProfile: self.image_analysis_history = [] +class _PooledMySQLConn: + """包装 pymysql 连接,支持连接池归还""" + def __init__(self, conn): + self._conn = conn + + def __enter__(self): + return self._conn + + def __exit__(self, exc_type, exc, tb): + if exc_type: + try: + self._conn.rollback() + except Exception: + pass + _return_conn(self._conn) + + class CustomerDatabase: """客户数据库""" @@ -180,17 +200,8 @@ class CustomerDatabase: self._ensure_db() def _get_mysql_conn(self): - import pymysql - return pymysql.connect( - host=_MYSQL_HOST, - port=_MYSQL_PORT, - user=_MYSQL_USER, - password=_MYSQL_PASSWORD, - database=_MYSQL_DATABASE, - charset="utf8mb4", - cursorclass=pymysql.cursors.DictCursor, - autocommit=False, - ) + """从连接池获取 MySQL 连接""" + return _PooledMySQLConn(_get_pooled_conn(timeout=5.0)) def _ensure_db(self): if _is_mysql(): @@ -284,24 +295,41 @@ class CustomerDatabase: data.pop('customer_id', None) return CustomerProfile(customer_id=customer_id, **data) - def save_customer(self, profile: CustomerProfile): + def save_customer(self, profile: CustomerProfile, max_retries: int = 3): + """保存客户画像(带重试机制)""" + import time profile.last_update = datetime.now().isoformat() if _is_mysql(): - with self._get_mysql_conn() as conn: - with conn.cursor() as cur: - cur.execute( - """ - REPLACE INTO customer_profiles (customer_id, profile_json, last_update) - VALUES (%s, %s, %s) - """, - ( - profile.customer_id, - json.dumps(asdict(profile), ensure_ascii=False), - datetime.now().strftime("%Y-%m-%d %H:%M:%S"), - ), - ) - conn.commit() - return + last_error = None + for attempt in range(max_retries): + try: + with self._get_mysql_conn() as conn: + with conn.cursor() as cur: + cur.execute( + """ + REPLACE INTO customer_profiles (customer_id, profile_json, last_update) + VALUES (%s, %s, %s) + """, + ( + profile.customer_id, + json.dumps(asdict(profile), ensure_ascii=False), + datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + ), + ) + conn.commit() + return + except Exception as e: + last_error = e + err_str = str(e).lower() + is_conn_error = any(k in err_str for k in [ + "lost connection", "gone away", "connection reset", + "can't connect", "connection refused", "2013", "2006" + ]) + if is_conn_error and attempt < max_retries - 1: + time.sleep(0.5 * (attempt + 1)) + continue + raise + raise last_error customers = self._load_customers() customers[profile.customer_id] = asdict(profile) self._save_customers(customers) diff --git a/db/image_tasks.db b/db/image_tasks.db new file mode 100644 index 0000000..8d5bcff Binary files /dev/null and b/db/image_tasks.db differ