newtw66
This commit is contained in:
@@ -11,6 +11,13 @@ from db.chat_log_db import get_conversation, get_customer_orders
|
|||||||
logger = logging.getLogger("cs_agent")
|
logger = logging.getLogger("cs_agent")
|
||||||
|
|
||||||
|
|
||||||
|
class TransferSuccessException(Exception):
|
||||||
|
"""转接成功后抛出此异常,用于提前终止 AI 处理流程"""
|
||||||
|
def __init__(self, transfer_cmd: str):
|
||||||
|
self.transfer_cmd = transfer_cmd
|
||||||
|
super().__init__(transfer_cmd)
|
||||||
|
|
||||||
|
|
||||||
async def transfer_to_human_tool(ctx: RunContext[Any], reason: str = Field(description="转人工的原因")) -> str:
|
async def transfer_to_human_tool(ctx: RunContext[Any], reason: str = Field(description="转人工的原因")) -> str:
|
||||||
"""
|
"""
|
||||||
【核心工具】执行转人工逻辑。
|
【核心工具】执行转人工逻辑。
|
||||||
@@ -22,8 +29,9 @@ async def transfer_to_human_tool(ctx: RunContext[Any], reason: str = Field(descr
|
|||||||
|
|
||||||
if designer_name:
|
if designer_name:
|
||||||
magic_cmd = f"正在为您转接|[转移会话],{designer_name},无原因"
|
magic_cmd = f"正在为您转接|[转移会话],{designer_name},无原因"
|
||||||
logger.info(f"[Tool] 成功呼叫设计师: {designer_name}")
|
logger.info(f"[Tool] 成功呼叫设计师: {designer_name},立即触发转接")
|
||||||
return magic_cmd
|
# 抛出异常以提前终止 AI 后续处理,节省等待时间
|
||||||
|
raise TransferSuccessException(magic_cmd)
|
||||||
else:
|
else:
|
||||||
hour = datetime.now().hour
|
hour = datetime.now().hour
|
||||||
logger.warning(f"[Tool] 派单失败:设计师们不在位 (当前{hour}点)")
|
logger.warning(f"[Tool] 派单失败:设计师们不在位 (当前{hour}点)")
|
||||||
|
|||||||
@@ -37,6 +37,18 @@ _OUTBOUND_BLOCK_MARKERS = (
|
|||||||
'[{"name":',
|
'[{"name":',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 历史记录格式检测模式(AI 转述历史时容易泄露)
|
||||||
|
_HISTORY_LEAK_PATTERNS = [
|
||||||
|
r'\[\d{4}-\d{2}-\d{2}[^\]]*\]\s*(客户|客服)[::]', # [2026-03-07 12:00:00] 客户:
|
||||||
|
r'\[\d{2}:\d{2}:\d{2}\]\s*(客户|客服|我)[::]', # [12:00:00] 客户:
|
||||||
|
r'(根据|查看|查询|翻看)(历史|聊天|对话)(记录|内容)', # 根据历史记录
|
||||||
|
r'历史(记录|对话|消息)(显示|表明|中)', # 历史记录显示
|
||||||
|
r'之前的(聊天|对话|记录)(中|里|显示)', # 之前的聊天中
|
||||||
|
r'共\d+条(历史|对话)?消息', # 共30条历史消息
|
||||||
|
r'订单号[::]\s*\d{10,}', # 订单号:xxxxxxxxxx
|
||||||
|
r'(状态|金额|数量)[::].*(状态|金额|数量)[::]', # 状态:xxx 金额:xxx 连续出现
|
||||||
|
]
|
||||||
|
|
||||||
class SystemOrchestrator:
|
class SystemOrchestrator:
|
||||||
"""
|
"""
|
||||||
全系统总编排:具备转接冷却、防抖合并、多消息去重、以及精准日志。
|
全系统总编排:具备转接冷却、防抖合并、多消息去重、以及精准日志。
|
||||||
@@ -84,6 +96,11 @@ class SystemOrchestrator:
|
|||||||
if any(marker in cleaned for marker in _OUTBOUND_BLOCK_MARKERS):
|
if any(marker in cleaned for marker in _OUTBOUND_BLOCK_MARKERS):
|
||||||
logger.warning("[Orchestrator] 拦截到内部内容外发,替换为安全兜底回复")
|
logger.warning("[Orchestrator] 拦截到内部内容外发,替换为安全兜底回复")
|
||||||
return "我在帮你看记录,稍等哈"
|
return "我在帮你看记录,稍等哈"
|
||||||
|
# 检查历史记录泄露模式
|
||||||
|
for pattern in _HISTORY_LEAK_PATTERNS:
|
||||||
|
if re.search(pattern, cleaned):
|
||||||
|
logger.warning(f"[Orchestrator] 检测到历史记录泄露模式: {pattern[:30]}...")
|
||||||
|
return "我在帮你看记录,稍等哈"
|
||||||
return cleaned
|
return cleaned
|
||||||
|
|
||||||
async def on_raw_message_received(self, platform: str, raw_data: dict):
|
async def on_raw_message_received(self, platform: str, raw_data: dict):
|
||||||
@@ -243,11 +260,17 @@ class SystemOrchestrator:
|
|||||||
|
|
||||||
async def _debounced_process(self, session_key: str, user_id: str, platform: str):
|
async def _debounced_process(self, session_key: str, user_id: str, platform: str):
|
||||||
try:
|
try:
|
||||||
|
# 记录开始时间(防抖前)
|
||||||
|
process_start = time.time()
|
||||||
|
|
||||||
await asyncio.sleep(self._debounce_seconds)
|
await asyncio.sleep(self._debounce_seconds)
|
||||||
async with self._get_user_lock(session_key):
|
async with self._get_user_lock(session_key):
|
||||||
messages = self._pending_messages.pop(session_key, [])
|
messages = self._pending_messages.pop(session_key, [])
|
||||||
if not messages: return
|
if not messages: return
|
||||||
|
|
||||||
|
debounce_elapsed = time.time() - process_start
|
||||||
|
logger.info(f"[计时] user={user_id} 防抖等待完成: {debounce_elapsed:.1f}s")
|
||||||
|
|
||||||
# A. 合并与元数据修复(去重:同一防抖窗口内完全相同的内容只保留一条)
|
# A. 合并与元数据修复(去重:同一防抖窗口内完全相同的内容只保留一条)
|
||||||
seen_contents = set()
|
seen_contents = set()
|
||||||
unique_parts = []
|
unique_parts = []
|
||||||
@@ -278,9 +301,12 @@ class SystemOrchestrator:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# B. 持久化
|
# B. 持久化
|
||||||
|
db_start = time.time()
|
||||||
db_content = combined_content
|
db_content = combined_content
|
||||||
if all_image_urls: db_content = f"【系统:已收到{len(all_image_urls)}张图】\n{combined_content}"
|
if all_image_urls: db_content = f"【系统:已收到{len(all_image_urls)}张图】\n{combined_content}"
|
||||||
await repo.save_chat(platform, user_id, db_content, "in", acc_id=acc_id, image_urls=all_image_urls)
|
await repo.save_chat(platform, user_id, db_content, "in", acc_id=acc_id, image_urls=all_image_urls)
|
||||||
|
db_elapsed = time.time() - db_start
|
||||||
|
logger.info(f"[计时] user={user_id} 消息入库: {db_elapsed:.2f}s")
|
||||||
|
|
||||||
# B2. 后台图片分析(不阻塞主流程,用于数据标定)
|
# B2. 后台图片分析(不阻塞主流程,用于数据标定)
|
||||||
if all_image_urls:
|
if all_image_urls:
|
||||||
@@ -302,9 +328,17 @@ class SystemOrchestrator:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# D. 正常流程:调用AI思考
|
# D. 正常流程:调用AI思考
|
||||||
|
history_start = time.time()
|
||||||
history = await repo.get_chat_history(user_id, limit=10, acc_id=acc_id)
|
history = await repo.get_chat_history(user_id, limit=10, acc_id=acc_id)
|
||||||
if history and history[-1].get('content') == db_content: history = history[:-1]
|
if history and history[-1].get('content') == db_content: history = history[:-1]
|
||||||
|
history_elapsed = time.time() - history_start
|
||||||
|
logger.info(f"[计时] user={user_id} 查询历史: {history_elapsed:.2f}s (共{len(history)}条)")
|
||||||
|
|
||||||
|
ai_start = time.time()
|
||||||
std_res = await self.brain.think_and_reply(final_msg, history=history)
|
std_res = await self.brain.think_and_reply(final_msg, history=history)
|
||||||
|
ai_elapsed = time.time() - ai_start
|
||||||
|
total_elapsed = time.time() - process_start
|
||||||
|
logger.info(f"[计时] user={user_id} AI思考: {ai_elapsed:.1f}s | 总耗时: {total_elapsed:.1f}s")
|
||||||
|
|
||||||
# E. 发送并记录时间
|
# E. 发送并记录时间
|
||||||
if std_res.should_reply:
|
if std_res.should_reply:
|
||||||
|
|||||||
@@ -2,15 +2,20 @@ import os
|
|||||||
import re
|
import re
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
from typing import List, Optional, Any, Dict
|
from typing import List, Optional, Any, Dict
|
||||||
from pydantic_ai import Agent, RunContext
|
from pydantic_ai import Agent, RunContext
|
||||||
from pydantic_ai.models.openai import OpenAIChatModel
|
from pydantic_ai.models.openai import OpenAIChatModel
|
||||||
from pydantic_ai.providers.openai import OpenAIProvider
|
from pydantic_ai.providers.openai import OpenAIProvider
|
||||||
from core.schema import StandardMessage, StandardResponse
|
from core.schema import StandardMessage, StandardResponse
|
||||||
from core.agent_tools import register_agent_tools
|
from core.agent_tools import register_agent_tools, TransferSuccessException
|
||||||
|
|
||||||
logger = logging.getLogger("cs_agent")
|
logger = logging.getLogger("cs_agent")
|
||||||
|
|
||||||
|
# 日志详细程度:设置环境变量 AI_LOG_LEVEL=debug 可获得完整日志
|
||||||
|
_LOG_FULL_PROMPT = os.getenv("AI_LOG_LEVEL", "").lower() == "debug"
|
||||||
|
_LOG_CLIP_LIMIT = int(os.getenv("AI_LOG_CLIP", "2000")) # 日志截断长度
|
||||||
|
|
||||||
from core.skill_manager import skill_manager
|
from core.skill_manager import skill_manager
|
||||||
|
|
||||||
|
|
||||||
@@ -21,6 +26,18 @@ _INTERNAL_TOOL_MARKERS = (
|
|||||||
"【订单详情】",
|
"【订单详情】",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 历史记录格式检测模式(AI 转述历史时容易泄露)
|
||||||
|
_HISTORY_LEAK_PATTERNS = [
|
||||||
|
r'\[\d{4}-\d{2}-\d{2}[^\]]*\]\s*(客户|客服)[::]', # [2026-03-07 12:00:00] 客户:
|
||||||
|
r'\[\d{2}:\d{2}:\d{2}\]\s*(客户|客服|我)[::]', # [12:00:00] 客户:
|
||||||
|
r'(根据|查看|查询|翻看)(历史|聊天|对话)(记录|内容)', # 根据历史记录
|
||||||
|
r'历史(记录|对话|消息)(显示|表明|中)', # 历史记录显示
|
||||||
|
r'之前的(聊天|对话|记录)(中|里|显示)', # 之前的聊天中
|
||||||
|
r'共\d+条(历史|对话)?消息', # 共30条历史消息
|
||||||
|
r'订单号[::]\s*\d{10,}', # 订单号:xxxxxxxxxx
|
||||||
|
r'(状态|金额|数量)[::].*(状态|金额|数量)[::]', # 状态:xxx 金额:xxx 连续出现
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def _clip(text: str, limit: int = 1200) -> str:
|
def _clip(text: str, limit: int = 1200) -> str:
|
||||||
if text is None:
|
if text is None:
|
||||||
@@ -61,10 +78,17 @@ def _sanitize_reply_text(reply_text: str) -> str:
|
|||||||
text = re.sub(r'[\[\]]{2,}', '', text)
|
text = re.sub(r'[\[\]]{2,}', '', text)
|
||||||
text = text.strip()
|
text = text.strip()
|
||||||
|
|
||||||
|
# 检查固定标记
|
||||||
if any(marker in text for marker in _INTERNAL_TOOL_MARKERS):
|
if any(marker in text for marker in _INTERNAL_TOOL_MARKERS):
|
||||||
logger.warning("[Brain] 拦截到工具原文泄露,降级为安全兜底回复")
|
logger.warning("[Brain] 拦截到工具原文泄露,降级为安全兜底回复")
|
||||||
return "我在帮你看记录,稍等哈"
|
return "我在帮你看记录,稍等哈"
|
||||||
|
|
||||||
|
# 检查历史记录泄露模式(AI 转述历史内容)
|
||||||
|
for pattern in _HISTORY_LEAK_PATTERNS:
|
||||||
|
if re.search(pattern, text):
|
||||||
|
logger.warning(f"[Brain] 检测到历史记录泄露模式: {pattern[:30]}...")
|
||||||
|
return "我在帮你看记录,稍等哈"
|
||||||
|
|
||||||
return text.strip()
|
return text.strip()
|
||||||
|
|
||||||
|
|
||||||
@@ -187,12 +211,51 @@ class CustomerServiceBrain:
|
|||||||
recent_context = "【近期对话回顾】\n" + "\n".join(lines) + "\n----------------\n"
|
recent_context = "【近期对话回顾】\n" + "\n".join(lines) + "\n----------------\n"
|
||||||
|
|
||||||
full_input = f"【当前客户ID:{msg.user_id}】\n{recent_context}现在的对话:{user_content}"
|
full_input = f"【当前客户ID:{msg.user_id}】\n{recent_context}现在的对话:{user_content}"
|
||||||
logger.info(
|
start_time = time.time()
|
||||||
f"[PROMPT->AI] user={msg.user_id} acc={msg.acc_id} images={len(msg.image_urls)}\n"
|
|
||||||
f"{_clip(full_input)}"
|
# ===== 详细日志:发给 AI 的提示词 =====
|
||||||
)
|
logger.info(f"[AI提示词] user={msg.user_id} acc={msg.acc_id} images={len(msg.image_urls)}\n{full_input}")
|
||||||
|
if history:
|
||||||
|
history_preview = "\n".join([f" {h.get('role','?')}: {str(h.get('content',''))[:50]}" for h in history[-4:]])
|
||||||
|
logger.info(f"[AI历史上下文] 共{len(history)}条:\n{history_preview}")
|
||||||
|
|
||||||
result = await self.agent.run(full_input, message_history=history)
|
# 尝试运行 AI,捕获转接成功异常以提前终止
|
||||||
|
try:
|
||||||
|
result = await self.agent.run(full_input, message_history=history)
|
||||||
|
except TransferSuccessException as e:
|
||||||
|
# 转接工具成功后立即返回,无需等待 AI 继续生成
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
logger.info(f"[Brain] 转接成功(提前终止,耗时{elapsed:.1f}s): {e.transfer_cmd[:60]}")
|
||||||
|
return StandardResponse(
|
||||||
|
reply_content=e.transfer_cmd,
|
||||||
|
need_transfer=True,
|
||||||
|
metadata={"acc_id": msg.acc_id, "acc_type": msg.acc_type}
|
||||||
|
)
|
||||||
|
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
logger.info(f"[Brain] AI处理完成,总耗时{elapsed:.1f}s")
|
||||||
|
|
||||||
|
# ===== 详细日志:AI 的思考过程和工具调用 =====
|
||||||
|
try:
|
||||||
|
all_msgs = result.all_messages()
|
||||||
|
for idx, m in enumerate(all_msgs):
|
||||||
|
msg_kind = getattr(m, 'kind', type(m).__name__)
|
||||||
|
if hasattr(m, 'parts'):
|
||||||
|
for part in m.parts:
|
||||||
|
part_kind = getattr(part, 'part_kind', '')
|
||||||
|
if part_kind == 'tool-call':
|
||||||
|
tool_name = getattr(part, 'tool_name', '?')
|
||||||
|
tool_args = getattr(part, 'args', {})
|
||||||
|
logger.info(f"[AI思考] 步骤{idx+1} 调用工具: {tool_name}({tool_args})")
|
||||||
|
elif part_kind == 'tool-return':
|
||||||
|
content = str(getattr(part, 'content', ''))[:200]
|
||||||
|
logger.info(f"[AI思考] 步骤{idx+1} 工具返回: {content}")
|
||||||
|
elif part_kind == 'text':
|
||||||
|
content = str(getattr(part, 'content', ''))[:150]
|
||||||
|
if content.strip():
|
||||||
|
logger.info(f"[AI思考] 步骤{idx+1} 文本输出: {content}")
|
||||||
|
except Exception as log_err:
|
||||||
|
logger.debug(f"[AI思考日志] 解析失败: {log_err}")
|
||||||
|
|
||||||
# --- 转接指令:直接从工具返回截获,不经过 AI 二次加工 ---
|
# --- 转接指令:直接从工具返回截获,不经过 AI 二次加工 ---
|
||||||
transfer_cmd = ""
|
transfer_cmd = ""
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
"""
|
"""
|
||||||
聊天记录数据库(SQLite)
|
聊天记录数据库(SQLite / MySQL)
|
||||||
每条消息独立存储,按客户ID分开,支持查询和展示。
|
每条消息独立存储,按客户ID分开,支持查询和展示。
|
||||||
|
支持 MySQL 连接池以提高性能。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import os
|
import os
|
||||||
|
import threading
|
||||||
|
from queue import Queue, Empty
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List, Dict, Optional
|
from typing import List, Dict, Optional
|
||||||
|
|
||||||
@@ -16,6 +19,84 @@ _MYSQL_USER = os.getenv("MYSQL_USER", "root")
|
|||||||
_MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
|
_MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
|
||||||
_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs")
|
_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs")
|
||||||
|
|
||||||
|
# ========== MySQL 连接池 ==========
|
||||||
|
_POOL_SIZE = int(os.getenv("MYSQL_POOL_SIZE", "50"))
|
||||||
|
_mysql_pool: Optional[Queue] = None
|
||||||
|
_pool_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def _create_mysql_conn():
|
||||||
|
"""创建单个 MySQL 连接"""
|
||||||
|
import pymysql
|
||||||
|
return pymysql.connect(
|
||||||
|
host=_MYSQL_HOST,
|
||||||
|
port=_MYSQL_PORT,
|
||||||
|
user=_MYSQL_USER,
|
||||||
|
password=_MYSQL_PASSWORD,
|
||||||
|
database=_MYSQL_DATABASE,
|
||||||
|
charset="utf8mb4",
|
||||||
|
cursorclass=pymysql.cursors.DictCursor,
|
||||||
|
autocommit=False,
|
||||||
|
connect_timeout=10,
|
||||||
|
read_timeout=30,
|
||||||
|
write_timeout=30,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _init_mysql_pool():
|
||||||
|
"""初始化 MySQL 连接池"""
|
||||||
|
global _mysql_pool
|
||||||
|
with _pool_lock:
|
||||||
|
if _mysql_pool is None:
|
||||||
|
_mysql_pool = Queue(maxsize=_POOL_SIZE)
|
||||||
|
for _ in range(_POOL_SIZE):
|
||||||
|
try:
|
||||||
|
conn = _create_mysql_conn()
|
||||||
|
_mysql_pool.put(conn)
|
||||||
|
except Exception:
|
||||||
|
pass # 启动时连接失败不阻塞,后续会重建
|
||||||
|
|
||||||
|
|
||||||
|
def _get_pooled_conn(timeout: float = 5.0):
|
||||||
|
"""从连接池获取连接"""
|
||||||
|
global _mysql_pool
|
||||||
|
if _mysql_pool is None:
|
||||||
|
_init_mysql_pool()
|
||||||
|
|
||||||
|
try:
|
||||||
|
conn = _mysql_pool.get(timeout=timeout)
|
||||||
|
# 检查连接是否有效
|
||||||
|
try:
|
||||||
|
conn.ping(reconnect=True)
|
||||||
|
except Exception:
|
||||||
|
# 连接失效,创建新连接
|
||||||
|
try:
|
||||||
|
conn.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
conn = _create_mysql_conn()
|
||||||
|
return conn
|
||||||
|
except Empty:
|
||||||
|
# 池空了,创建新连接(不放回池)
|
||||||
|
return _create_mysql_conn()
|
||||||
|
|
||||||
|
|
||||||
|
def _return_conn(conn):
|
||||||
|
"""归还连接到池"""
|
||||||
|
global _mysql_pool
|
||||||
|
if _mysql_pool is None:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
if _mysql_pool.qsize() < _POOL_SIZE:
|
||||||
|
_mysql_pool.put_nowait(conn)
|
||||||
|
else:
|
||||||
|
conn.close()
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
conn.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class _CompatResult:
|
class _CompatResult:
|
||||||
def __init__(self, rows=None, rowcount: int = 0, lastrowid: int = 0):
|
def __init__(self, rows=None, rowcount: int = 0, lastrowid: int = 0):
|
||||||
@@ -31,10 +112,11 @@ class _CompatResult:
|
|||||||
|
|
||||||
|
|
||||||
class _PyMySQLCompatConn:
|
class _PyMySQLCompatConn:
|
||||||
"""让 pymysql 连接兼容 sqlite 的 conn.execute 用法。"""
|
"""让 pymysql 连接兼容 sqlite 的 conn.execute 用法,支持连接池。"""
|
||||||
|
|
||||||
def __init__(self, conn):
|
def __init__(self, conn, use_pool: bool = True):
|
||||||
self._conn = conn
|
self._conn = conn
|
||||||
|
self._use_pool = use_pool
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
return self
|
return self
|
||||||
@@ -45,7 +127,11 @@ class _PyMySQLCompatConn:
|
|||||||
self._conn.rollback()
|
self._conn.rollback()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
self._conn.close()
|
# 归还连接到池而不是关闭
|
||||||
|
if self._use_pool:
|
||||||
|
_return_conn(self._conn)
|
||||||
|
else:
|
||||||
|
self._conn.close()
|
||||||
|
|
||||||
def execute(self, query: str, args=None):
|
def execute(self, query: str, args=None):
|
||||||
cur = self._conn.cursor()
|
cur = self._conn.cursor()
|
||||||
@@ -59,7 +145,10 @@ class _PyMySQLCompatConn:
|
|||||||
self._conn.commit()
|
self._conn.commit()
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
self._conn.close()
|
if self._use_pool:
|
||||||
|
_return_conn(self._conn)
|
||||||
|
else:
|
||||||
|
self._conn.close()
|
||||||
|
|
||||||
def _is_mysql() -> bool:
|
def _is_mysql() -> bool:
|
||||||
return _DB_TYPE in ("mysql", "mariadb")
|
return _DB_TYPE in ("mysql", "mariadb")
|
||||||
@@ -68,20 +157,22 @@ def _sql(query: str) -> str:
|
|||||||
return query.replace("?", "%s") if _is_mysql() else query
|
return query.replace("?", "%s") if _is_mysql() else query
|
||||||
|
|
||||||
|
|
||||||
def _get_conn() -> sqlite3.Connection:
|
def _get_conn(max_retries: int = 3, retry_delay: float = 0.5) -> sqlite3.Connection:
|
||||||
|
"""获取数据库连接,MySQL 使用连接池"""
|
||||||
if _is_mysql():
|
if _is_mysql():
|
||||||
import pymysql
|
import time
|
||||||
conn = pymysql.connect(
|
last_error = None
|
||||||
host=_MYSQL_HOST,
|
for attempt in range(max_retries):
|
||||||
port=_MYSQL_PORT,
|
try:
|
||||||
user=_MYSQL_USER,
|
conn = _get_pooled_conn(timeout=5.0)
|
||||||
password=_MYSQL_PASSWORD,
|
return _PyMySQLCompatConn(conn, use_pool=True)
|
||||||
database=_MYSQL_DATABASE,
|
except Exception as e:
|
||||||
charset="utf8mb4",
|
last_error = e
|
||||||
cursorclass=pymysql.cursors.DictCursor,
|
if attempt < max_retries - 1:
|
||||||
autocommit=False,
|
time.sleep(retry_delay * (attempt + 1))
|
||||||
)
|
continue
|
||||||
return _PyMySQLCompatConn(conn)
|
raise
|
||||||
|
raise last_error
|
||||||
os.makedirs(os.path.dirname(_DB_PATH), exist_ok=True)
|
os.makedirs(os.path.dirname(_DB_PATH), exist_ok=True)
|
||||||
conn = sqlite3.connect(_DB_PATH)
|
conn = sqlite3.connect(_DB_PATH)
|
||||||
conn.row_factory = sqlite3.Row
|
conn.row_factory = sqlite3.Row
|
||||||
@@ -192,8 +283,38 @@ def init_db():
|
|||||||
init_db()
|
init_db()
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 重试装饰器 ==========
|
||||||
|
|
||||||
|
def _retry_db_operation(func):
|
||||||
|
"""数据库操作重试装饰器,处理连接丢失等临时错误"""
|
||||||
|
import functools
|
||||||
|
import time
|
||||||
|
@functools.wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
max_retries = 3
|
||||||
|
last_error = None
|
||||||
|
for attempt in range(max_retries):
|
||||||
|
try:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
err_str = str(e).lower()
|
||||||
|
# 判断是否为可重试的连接错误
|
||||||
|
is_conn_error = any(k in err_str for k in [
|
||||||
|
"lost connection", "gone away", "connection reset",
|
||||||
|
"can't connect", "connection refused", "2013", "2006"
|
||||||
|
])
|
||||||
|
if is_conn_error and attempt < max_retries - 1:
|
||||||
|
last_error = e
|
||||||
|
time.sleep(0.5 * (attempt + 1))
|
||||||
|
continue
|
||||||
|
raise
|
||||||
|
raise last_error
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
# ========== 写入 ==========
|
# ========== 写入 ==========
|
||||||
|
|
||||||
|
@_retry_db_operation
|
||||||
def log_message(
|
def log_message(
|
||||||
customer_id: str,
|
customer_id: str,
|
||||||
message: str,
|
message: str,
|
||||||
@@ -251,6 +372,7 @@ def get_customers(limit: int = 100) -> List[Dict]:
|
|||||||
return [dict(r) for r in rows]
|
return [dict(r) for r in rows]
|
||||||
|
|
||||||
|
|
||||||
|
@_retry_db_operation
|
||||||
def get_conversation(customer_id: str, limit: int = 200, acc_id: str = "") -> List[Dict]:
|
def get_conversation(customer_id: str, limit: int = 200, acc_id: str = "") -> List[Dict]:
|
||||||
"""返回某客户的最近对话记录(按时间升序)"""
|
"""返回某客户的最近对话记录(按时间升序)"""
|
||||||
# 忽略 acc_id 过滤,实现全店铺记忆
|
# 忽略 acc_id 过滤,实现全店铺记忆
|
||||||
@@ -398,6 +520,7 @@ def get_latest_messages(limit: int = 20) -> List[Dict]:
|
|||||||
|
|
||||||
# ========== 订单相关 ==========
|
# ========== 订单相关 ==========
|
||||||
|
|
||||||
|
@_retry_db_operation
|
||||||
def upsert_order(
|
def upsert_order(
|
||||||
customer_id: str,
|
customer_id: str,
|
||||||
order_id: str,
|
order_id: str,
|
||||||
@@ -431,6 +554,7 @@ def upsert_order(
|
|||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
|
|
||||||
|
@_retry_db_operation
|
||||||
def get_customer_orders(customer_id: str, limit: int = 10) -> List[Dict]:
|
def get_customer_orders(customer_id: str, limit: int = 10) -> List[Dict]:
|
||||||
"""查询某客户的订单记录(按时间倒序)"""
|
"""查询某客户的订单记录(按时间倒序)"""
|
||||||
with _get_conn() as conn:
|
with _get_conn() as conn:
|
||||||
|
|||||||
@@ -13,6 +13,9 @@ _MYSQL_USER = os.getenv("MYSQL_USER", "root")
|
|||||||
_MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
|
_MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "")
|
||||||
_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs")
|
_MYSQL_DATABASE = os.getenv("MYSQL_DATABASE", "ai_cs")
|
||||||
|
|
||||||
|
# 复用 chat_log_db 的连接池
|
||||||
|
from db.chat_log_db import _get_pooled_conn, _return_conn
|
||||||
|
|
||||||
|
|
||||||
def _is_mysql() -> bool:
|
def _is_mysql() -> bool:
|
||||||
return _DB_TYPE in ("mysql", "mariadb")
|
return _DB_TYPE in ("mysql", "mariadb")
|
||||||
@@ -171,6 +174,23 @@ class CustomerProfile:
|
|||||||
self.image_analysis_history = []
|
self.image_analysis_history = []
|
||||||
|
|
||||||
|
|
||||||
|
class _PooledMySQLConn:
|
||||||
|
"""包装 pymysql 连接,支持连接池归还"""
|
||||||
|
def __init__(self, conn):
|
||||||
|
self._conn = conn
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self._conn
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc, tb):
|
||||||
|
if exc_type:
|
||||||
|
try:
|
||||||
|
self._conn.rollback()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
_return_conn(self._conn)
|
||||||
|
|
||||||
|
|
||||||
class CustomerDatabase:
|
class CustomerDatabase:
|
||||||
"""客户数据库"""
|
"""客户数据库"""
|
||||||
|
|
||||||
@@ -180,17 +200,8 @@ class CustomerDatabase:
|
|||||||
self._ensure_db()
|
self._ensure_db()
|
||||||
|
|
||||||
def _get_mysql_conn(self):
|
def _get_mysql_conn(self):
|
||||||
import pymysql
|
"""从连接池获取 MySQL 连接"""
|
||||||
return pymysql.connect(
|
return _PooledMySQLConn(_get_pooled_conn(timeout=5.0))
|
||||||
host=_MYSQL_HOST,
|
|
||||||
port=_MYSQL_PORT,
|
|
||||||
user=_MYSQL_USER,
|
|
||||||
password=_MYSQL_PASSWORD,
|
|
||||||
database=_MYSQL_DATABASE,
|
|
||||||
charset="utf8mb4",
|
|
||||||
cursorclass=pymysql.cursors.DictCursor,
|
|
||||||
autocommit=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _ensure_db(self):
|
def _ensure_db(self):
|
||||||
if _is_mysql():
|
if _is_mysql():
|
||||||
@@ -284,24 +295,41 @@ class CustomerDatabase:
|
|||||||
data.pop('customer_id', None)
|
data.pop('customer_id', None)
|
||||||
return CustomerProfile(customer_id=customer_id, **data)
|
return CustomerProfile(customer_id=customer_id, **data)
|
||||||
|
|
||||||
def save_customer(self, profile: CustomerProfile):
|
def save_customer(self, profile: CustomerProfile, max_retries: int = 3):
|
||||||
|
"""保存客户画像(带重试机制)"""
|
||||||
|
import time
|
||||||
profile.last_update = datetime.now().isoformat()
|
profile.last_update = datetime.now().isoformat()
|
||||||
if _is_mysql():
|
if _is_mysql():
|
||||||
with self._get_mysql_conn() as conn:
|
last_error = None
|
||||||
with conn.cursor() as cur:
|
for attempt in range(max_retries):
|
||||||
cur.execute(
|
try:
|
||||||
"""
|
with self._get_mysql_conn() as conn:
|
||||||
REPLACE INTO customer_profiles (customer_id, profile_json, last_update)
|
with conn.cursor() as cur:
|
||||||
VALUES (%s, %s, %s)
|
cur.execute(
|
||||||
""",
|
"""
|
||||||
(
|
REPLACE INTO customer_profiles (customer_id, profile_json, last_update)
|
||||||
profile.customer_id,
|
VALUES (%s, %s, %s)
|
||||||
json.dumps(asdict(profile), ensure_ascii=False),
|
""",
|
||||||
datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
(
|
||||||
),
|
profile.customer_id,
|
||||||
)
|
json.dumps(asdict(profile), ensure_ascii=False),
|
||||||
conn.commit()
|
datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||||
return
|
),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
last_error = e
|
||||||
|
err_str = str(e).lower()
|
||||||
|
is_conn_error = any(k in err_str for k in [
|
||||||
|
"lost connection", "gone away", "connection reset",
|
||||||
|
"can't connect", "connection refused", "2013", "2006"
|
||||||
|
])
|
||||||
|
if is_conn_error and attempt < max_retries - 1:
|
||||||
|
time.sleep(0.5 * (attempt + 1))
|
||||||
|
continue
|
||||||
|
raise
|
||||||
|
raise last_error
|
||||||
customers = self._load_customers()
|
customers = self._load_customers()
|
||||||
customers[profile.customer_id] = asdict(profile)
|
customers[profile.customer_id] = asdict(profile)
|
||||||
self._save_customers(customers)
|
self._save_customers(customers)
|
||||||
|
|||||||
BIN
db/image_tasks.db
Normal file
BIN
db/image_tasks.db
Normal file
Binary file not shown.
Reference in New Issue
Block a user