chore: harden runtime checks and split websocket inbound/outbound flows
This commit is contained in:
20
.gitignore
vendored
Normal file
20
.gitignore
vendored
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
# Python cache/artifacts
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# Virtual environments
|
||||||
|
.venv/
|
||||||
|
venv/
|
||||||
|
|
||||||
|
# Tool caches
|
||||||
|
.ruff_cache/
|
||||||
|
.pytest_cache/
|
||||||
|
.uv_cache/
|
||||||
|
.uv-cache/
|
||||||
|
.uv-cache-test/
|
||||||
|
|
||||||
|
# Runtime artifacts
|
||||||
|
logs/*.log
|
||||||
|
config/.runtime_metrics.jsonl
|
||||||
|
|
||||||
123
core/websocket_inbound_flow.py
Normal file
123
core/websocket_inbound_flow.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger("cs_agent")
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_incoming_message(client, message: str, *, shop_type_resolver):
|
||||||
|
"""处理单条入站消息(从 websocket_client.py 拆出)。"""
|
||||||
|
timestamp = client.get_time()
|
||||||
|
try:
|
||||||
|
data = json.loads(message)
|
||||||
|
|
||||||
|
# 多进程分片检查:确保同一客户只由一个 worker 处理
|
||||||
|
customer_key = client._customer_key(data)
|
||||||
|
if not client._is_owned_by_this_worker(customer_key):
|
||||||
|
return
|
||||||
|
|
||||||
|
timestamp = client.get_time()
|
||||||
|
|
||||||
|
# 保存最后一条消息用于回复
|
||||||
|
client.last_msg = data
|
||||||
|
|
||||||
|
# 打印格式化的消息
|
||||||
|
logger.info(f"\n{'='*50}")
|
||||||
|
logger.info(f"[{timestamp}] 收到新消息:")
|
||||||
|
logger.info(f"{'='*50}")
|
||||||
|
logger.info(f" 消息ID: {data.get('msg_id', 'N/A')}")
|
||||||
|
logger.info(f" 账号ID: {client.to_chinese(data.get('acc_id', 'N/A'))}")
|
||||||
|
logger.info(f" 发送者ID: {client.to_chinese(data.get('from_id', 'N/A'))}")
|
||||||
|
logger.info(f" 发送者名称: {client.to_chinese(data.get('from_name', 'N/A'))}")
|
||||||
|
logger.info(f" 会话ID: {client.to_chinese(data.get('cy_id', 'N/A'))}")
|
||||||
|
logger.info(f" 平台类型: {data.get('acc_type', 'N/A')}")
|
||||||
|
logger.info(f" 消息类型: {client.get_msg_type_name(data.get('msg_type', 0))}")
|
||||||
|
logger.info(f" 消息内容: {client.to_chinese(data.get('msg', 'N/A'))}")
|
||||||
|
|
||||||
|
# 显示商品信息(如果有)
|
||||||
|
if data.get('goods_name'):
|
||||||
|
logger.info(f" 商品名称: {client.to_chinese(data.get('goods_name', ''))}")
|
||||||
|
if data.get('goods_order'):
|
||||||
|
logger.info(f" 订单信息: {client.to_chinese(data.get('goods_order', ''))}")
|
||||||
|
|
||||||
|
logger.info(f"{'='*50}\n")
|
||||||
|
|
||||||
|
# 消息去重:同一条消息不重复处理
|
||||||
|
msg_id = data.get('msg_id', '')
|
||||||
|
if msg_id and msg_id in client._replied_msg_ids:
|
||||||
|
logger.info(f"重复消息,跳过: {msg_id}")
|
||||||
|
return
|
||||||
|
if msg_id:
|
||||||
|
client._replied_msg_ids.append(msg_id) # deque 自动淘汰最旧的
|
||||||
|
|
||||||
|
# 空消息/无效消息过滤(N/A 或关键字段全为空)
|
||||||
|
from_id = data.get('from_id', '')
|
||||||
|
acc_id = data.get('acc_id', '')
|
||||||
|
if not from_id or from_id == 'N/A' or not acc_id or acc_id == 'N/A':
|
||||||
|
logger.info(f"[{client.get_time()}] 空消息跳过(from_id={from_id!r} acc_id={acc_id!r})")
|
||||||
|
return
|
||||||
|
client._log_inbound_once(data)
|
||||||
|
client._fire_and_forget(client._post_tianwang_callback("message_received", data))
|
||||||
|
|
||||||
|
# Gemini 店铺:不回复,直接跳过
|
||||||
|
goods_name = client.to_chinese(data.get('goods_name', '') or '')
|
||||||
|
if shop_type_resolver(acc_id, goods_name) == "gemini_api":
|
||||||
|
logger.info(f"[{client.get_time()}] Gemini 店铺消息,跳过")
|
||||||
|
client._push_chat_to_wechat_safe(
|
||||||
|
data=data,
|
||||||
|
customer_msg=data.get('msg', ''),
|
||||||
|
reply_msg="",
|
||||||
|
goods_name=goods_name,
|
||||||
|
tag="gemini店铺跳过",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# 使用 Agent 自动回复(仅处理文本消息)
|
||||||
|
if client.enable_agent:
|
||||||
|
msg_type = data.get('msg_type', 0)
|
||||||
|
if msg_type == 0:
|
||||||
|
if client._is_transfer_msg(data):
|
||||||
|
# 会话转交 → 主动打招呼
|
||||||
|
logger.info(f"[{client.get_time()}] 收到转交消息,发送问候")
|
||||||
|
greeting = client._pick_transfer_greeting()
|
||||||
|
await client.send_reply(data, greeting)
|
||||||
|
client._push_chat_to_wechat_safe(
|
||||||
|
data=data,
|
||||||
|
customer_msg=data.get('msg', ''),
|
||||||
|
reply_msg=greeting,
|
||||||
|
tag="转交问候",
|
||||||
|
)
|
||||||
|
elif client._is_shop_card(data):
|
||||||
|
# 进店卡片:有历史对话就不回复,没有才打招呼(Gemini 已在上面统一跳过)
|
||||||
|
cid = data.get('from_id', '')
|
||||||
|
acc_id = data.get('acc_id', '')
|
||||||
|
residual_text = client._extract_customer_text_from_shop_card_msg(data.get('msg', ''))
|
||||||
|
if residual_text:
|
||||||
|
logger.info(f"[{client.get_time()}] 进店卡片携带客户文本,转普通消息处理: {residual_text}")
|
||||||
|
patched = dict(data)
|
||||||
|
patched['msg'] = residual_text
|
||||||
|
await client._debounce_agent_reply(patched)
|
||||||
|
elif client._has_chat_history(cid, acc_id=acc_id):
|
||||||
|
logger.info(f"[{client.get_time()}] 进店卡片(已有记录),跳过")
|
||||||
|
else:
|
||||||
|
logger.info(f"[{client.get_time()}] 进店卡片(新客户),发送问候")
|
||||||
|
greeting = "在呢,发图来我看看"
|
||||||
|
await client.send_reply(data, greeting)
|
||||||
|
client._push_chat_to_wechat_safe(
|
||||||
|
data=data,
|
||||||
|
customer_msg=data.get('msg', ''),
|
||||||
|
reply_msg=greeting,
|
||||||
|
goods_name=goods_name,
|
||||||
|
tag="进店卡片问候",
|
||||||
|
)
|
||||||
|
elif await client._handle_system_inquiry(data):
|
||||||
|
logger.info(f"[{client.get_time()}] 系统客服询单消息,已按规则处理")
|
||||||
|
elif client._should_ignore(data):
|
||||||
|
logger.info(f"[{client.get_time()}] 系统通知,跳过回复")
|
||||||
|
else:
|
||||||
|
await client._debounce_agent_reply(data)
|
||||||
|
elif msg_type == 1:
|
||||||
|
# 图片消息直接处理,不走防抖(图片不会连续多发)
|
||||||
|
await client.handle_image_message(data)
|
||||||
|
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.info(f"[{timestamp}] 收到非JSON消息: {message}")
|
||||||
285
core/websocket_outbound_flow.py
Normal file
285
core/websocket_outbound_flow.py
Normal file
@@ -0,0 +1,285 @@
|
|||||||
|
import os
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
async def send_reply_flow(client, original_msg: dict, reply_content: str):
|
||||||
|
"""
|
||||||
|
发送回复消息(从 websocket_client.py 拆出)。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
original_msg: 收到的原始消息字典
|
||||||
|
reply_content: 回复内容(文本或本地文件路径/http地址)
|
||||||
|
"""
|
||||||
|
trace_id = original_msg.get("_trace_id", "")
|
||||||
|
if not client.websocket:
|
||||||
|
client._activity_log(
|
||||||
|
"send_reply_skipped",
|
||||||
|
trace_id=trace_id,
|
||||||
|
reason="websocket_not_connected",
|
||||||
|
acc_id=original_msg.get("acc_id", ""),
|
||||||
|
customer_id=original_msg.get("from_id", ""),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
reply_content = colloquialize_outbound_reply(reply_content)
|
||||||
|
reply_content = await ai_generate_outbound_reply(
|
||||||
|
client=client,
|
||||||
|
original_msg=original_msg,
|
||||||
|
reply_content=str(reply_content or ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 同一客户外发限流:N 秒内最多 1 条
|
||||||
|
try:
|
||||||
|
from config.config import OUTBOUND_PER_CUSTOMER_COOLDOWN_SECONDS
|
||||||
|
cooldown = max(0, int(OUTBOUND_PER_CUSTOMER_COOLDOWN_SECONDS))
|
||||||
|
except Exception:
|
||||||
|
cooldown = 5
|
||||||
|
if cooldown > 0:
|
||||||
|
ckey = f"{original_msg.get('acc_id', '')}:{original_msg.get('from_id', '')}"
|
||||||
|
now_mono = time.monotonic()
|
||||||
|
last = client._last_reply_sent_at.get(ckey, 0.0)
|
||||||
|
if (now_mono - last) < cooldown:
|
||||||
|
client._activity_log(
|
||||||
|
"send_reply_throttled",
|
||||||
|
trace_id=trace_id,
|
||||||
|
key=ckey,
|
||||||
|
cooldown_s=cooldown,
|
||||||
|
msg=str(reply_content),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
client._last_reply_sent_at[ckey] = now_mono
|
||||||
|
|
||||||
|
shop_id = original_msg.get("acc_id", "")
|
||||||
|
|
||||||
|
# 根据轻简API文档:
|
||||||
|
# from_id = 客户ID(收消息方)
|
||||||
|
# cy_id = 非群聊时与 from_id 相同
|
||||||
|
customer_id = original_msg.get("from_id", "")
|
||||||
|
customer_name = original_msg.get("from_name", "")
|
||||||
|
|
||||||
|
allow_send, checked_reply, guard_reason = await ai_guard_outbound_reply(
|
||||||
|
client=client,
|
||||||
|
original_msg=original_msg,
|
||||||
|
reply_content=str(reply_content),
|
||||||
|
)
|
||||||
|
client._activity_log(
|
||||||
|
"reply_guard_decision",
|
||||||
|
trace_id=trace_id,
|
||||||
|
acc_id=shop_id,
|
||||||
|
customer_id=customer_id,
|
||||||
|
result="ok" if allow_send else "blocked",
|
||||||
|
reason=guard_reason,
|
||||||
|
original_reply=str(reply_content),
|
||||||
|
final_reply=str(checked_reply or ""),
|
||||||
|
)
|
||||||
|
if not allow_send:
|
||||||
|
return
|
||||||
|
|
||||||
|
reply_content = checked_reply or str(reply_content)
|
||||||
|
pass_send, _ = client._outbound_arbiter(
|
||||||
|
original_msg=original_msg,
|
||||||
|
reply_content=reply_content,
|
||||||
|
trace_id=trace_id,
|
||||||
|
)
|
||||||
|
if not pass_send:
|
||||||
|
return
|
||||||
|
|
||||||
|
reply = {
|
||||||
|
"msg_id": "",
|
||||||
|
"acc_id": shop_id,
|
||||||
|
"msg": reply_content,
|
||||||
|
"from_id": customer_id,
|
||||||
|
"from_name": customer_name,
|
||||||
|
"cy_id": customer_id,
|
||||||
|
"acc_type": original_msg.get("acc_type", ""),
|
||||||
|
"msg_type": 0,
|
||||||
|
"cy_name": customer_name,
|
||||||
|
}
|
||||||
|
client._log_outbound_once(original_msg, str(reply_content))
|
||||||
|
client._activity_log(
|
||||||
|
"send_reply_attempt",
|
||||||
|
trace_id=trace_id,
|
||||||
|
acc_id=shop_id,
|
||||||
|
customer_id=customer_id,
|
||||||
|
msg=str(reply_content),
|
||||||
|
)
|
||||||
|
reply["_trace_id"] = trace_id
|
||||||
|
await client.send_message(reply)
|
||||||
|
|
||||||
|
|
||||||
|
async def ai_generate_outbound_reply(client, original_msg: dict, reply_content: str) -> str:
|
||||||
|
"""
|
||||||
|
强制全量 AI 出站生成层:
|
||||||
|
- 所有普通文本外发先由 AI 生成最终话术;
|
||||||
|
- 控制命令/纯链接/转接指令直接绕过。
|
||||||
|
"""
|
||||||
|
text = (reply_content or "").strip()
|
||||||
|
if not text:
|
||||||
|
return text
|
||||||
|
if text.startswith("话术|") or "[转移会话]" in text or "TRANSFER_REQUESTED" in text:
|
||||||
|
return text
|
||||||
|
if re.fullmatch(r"https?://\S+", text):
|
||||||
|
return text
|
||||||
|
if not client._force_ai_generate_reply or not client.enable_agent or not client.agent or not client.AgentDeps:
|
||||||
|
return text
|
||||||
|
try:
|
||||||
|
deps = client.AgentDeps(
|
||||||
|
msg_id=str(original_msg.get("msg_id", "") or "outbound_generate"),
|
||||||
|
acc_id=str(original_msg.get("acc_id", "") or ""),
|
||||||
|
from_id=str(original_msg.get("from_id", "") or ""),
|
||||||
|
platform=str(original_msg.get("acc_type", "") or ""),
|
||||||
|
)
|
||||||
|
customer_msg = client.to_chinese(str(original_msg.get("msg", "") or ""))
|
||||||
|
prompt = (
|
||||||
|
"你是淘宝客服外发文案生成器。请根据“回复意图草稿”生成最终发给客户的话。\n"
|
||||||
|
"要求:\n"
|
||||||
|
"1) 保留原意,不新增价格/承诺/流程;\n"
|
||||||
|
"2) 自然像真人聊天,不用固定模板句;\n"
|
||||||
|
"3) 1-2句;\n"
|
||||||
|
"4) 只输出最终回复文本。\n\n"
|
||||||
|
f"客户原话: {customer_msg}\n"
|
||||||
|
f"回复意图草稿: {text}\n"
|
||||||
|
)
|
||||||
|
result = await client.agent.agent_natural_reply.run(prompt, deps=deps, message_history=[])
|
||||||
|
out = str(getattr(result, "output", "") or "").strip()
|
||||||
|
if not out:
|
||||||
|
return text
|
||||||
|
if out.startswith("话术|") or "[转移会话]" in out:
|
||||||
|
return text
|
||||||
|
client._activity_log(
|
||||||
|
"ai_generate_reply",
|
||||||
|
acc_id=str(original_msg.get("acc_id", "") or ""),
|
||||||
|
customer_id=str(original_msg.get("from_id", "") or ""),
|
||||||
|
draft=text[:160],
|
||||||
|
generated=out[:160],
|
||||||
|
)
|
||||||
|
return out
|
||||||
|
except Exception as e:
|
||||||
|
client._activity_log(
|
||||||
|
"ai_generate_reply_error",
|
||||||
|
acc_id=str(original_msg.get("acc_id", "") or ""),
|
||||||
|
customer_id=str(original_msg.get("from_id", "") or ""),
|
||||||
|
error=str(e),
|
||||||
|
)
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def colloquialize_outbound_reply(text: Any) -> Any:
|
||||||
|
"""统一外发口语化处理,避免机械话术。"""
|
||||||
|
if not isinstance(text, str):
|
||||||
|
return text
|
||||||
|
raw = text.strip()
|
||||||
|
if not raw:
|
||||||
|
return text
|
||||||
|
# 控制指令/转接命令不得改写
|
||||||
|
if raw.startswith("话术|") or "[转移会话]" in raw:
|
||||||
|
return text
|
||||||
|
# 纯链接不改
|
||||||
|
if re.fullmatch(r"https?://\S+", raw):
|
||||||
|
return text
|
||||||
|
|
||||||
|
out = raw
|
||||||
|
replacements = {
|
||||||
|
"我这边": "我这边",
|
||||||
|
"请您": "你",
|
||||||
|
"您好": "你好",
|
||||||
|
"稍后": "一会儿",
|
||||||
|
"可以的话": "可以的话",
|
||||||
|
"请稍等": "稍等哈",
|
||||||
|
"先不乱报价": "先不急着给你乱报",
|
||||||
|
"建议转人工评估更稳": "建议转人工看会更稳",
|
||||||
|
"统一报价": "一起报价",
|
||||||
|
"马上安排": "马上给你安排",
|
||||||
|
"确认我就安排": "你点头我就开做",
|
||||||
|
"收到,我看看哈": "收到,我先看下",
|
||||||
|
"收到,我找找刚才那几张": "收到,我把刚才那几张一起看下",
|
||||||
|
"这组图我这边暂时识别不稳定": "这组图我这边识别得不太稳",
|
||||||
|
"这组图我这边暂时识别异常": "这组图我这边刚才识别有点异常",
|
||||||
|
"你可以换一张更清晰的,我再给你准报价。": "你换张更清晰的发我,我再给你报准点。",
|
||||||
|
"你可以换清晰图再发我。": "你换张清晰点的再发我哈。",
|
||||||
|
"你可以稍后再发我。": "你晚点再发我也行。",
|
||||||
|
"收到付款,我马上安排处理,有需要第一时间联系您": "收到付款啦,我马上安排处理,有进展第一时间告诉你",
|
||||||
|
"亲,正在为您转接人工客服,请稍等~": "我这就给你转人工,稍等哈~",
|
||||||
|
}
|
||||||
|
for k, v in replacements.items():
|
||||||
|
out = out.replace(k, v)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
async def ai_guard_outbound_reply(client, original_msg: dict, reply_content: str) -> tuple[bool, str, str]:
|
||||||
|
"""
|
||||||
|
专用AI质检:发送前判断“这句是否该发”,可拦截或改写。
|
||||||
|
读取当前客户在当前店铺的完整对话上下文。
|
||||||
|
"""
|
||||||
|
text = (reply_content or "").strip()
|
||||||
|
if not text:
|
||||||
|
return False, "", "empty_reply"
|
||||||
|
if text.startswith("话术|") or "[转移会话]" in text:
|
||||||
|
return True, text, "command_bypass"
|
||||||
|
if not client._reply_guard_enabled or not client.enable_agent or not client.agent or not client.AgentDeps:
|
||||||
|
return True, text, "guard_disabled"
|
||||||
|
try:
|
||||||
|
from db.chat_log_db import get_conversation
|
||||||
|
import json as _json
|
||||||
|
import re as _re
|
||||||
|
|
||||||
|
acc_id = str(original_msg.get("acc_id", "") or "")
|
||||||
|
customer_id = str(original_msg.get("from_id", "") or "")
|
||||||
|
if not customer_id:
|
||||||
|
return True, text, "no_customer_id"
|
||||||
|
|
||||||
|
# 默认读取较大窗口,尽量覆盖完整上下文;可用环境变量继续放大。
|
||||||
|
try:
|
||||||
|
max_rows = max(50, int(os.getenv("AI_REPLY_GUARD_CONTEXT_ROWS", "500")))
|
||||||
|
except Exception:
|
||||||
|
max_rows = 500
|
||||||
|
rows = get_conversation(customer_id=customer_id, limit=max_rows) or []
|
||||||
|
shop_rows = [r for r in rows if str(r.get("acc_id", "") or "") == acc_id] if acc_id else rows
|
||||||
|
|
||||||
|
context_lines = []
|
||||||
|
for r in shop_rows:
|
||||||
|
role = "客" if (r.get("direction") == "in") else "服"
|
||||||
|
msg = client.to_chinese((r.get("message") or "").strip())
|
||||||
|
if msg:
|
||||||
|
context_lines.append(f"{role}:{msg}")
|
||||||
|
context_text = "\n".join(context_lines) if context_lines else "无历史"
|
||||||
|
|
||||||
|
deps = client.AgentDeps(
|
||||||
|
msg_id=str(original_msg.get("msg_id", "") or "reply_guard"),
|
||||||
|
acc_id=acc_id,
|
||||||
|
from_id=customer_id,
|
||||||
|
platform=str(original_msg.get("acc_type", "") or ""),
|
||||||
|
)
|
||||||
|
prompt = (
|
||||||
|
"你是淘宝客服回复质检器。目标:判断候选回复是否和上下文一致,是否会造成重复触发式答复。\n"
|
||||||
|
"必须检查:\n"
|
||||||
|
"1) 是否答非所问;\n"
|
||||||
|
"2) 是否重复说“马上报价/继续发图”但当前上下文不需要;\n"
|
||||||
|
"3) 是否与历史状态冲突;\n"
|
||||||
|
"4) 语气是否自然可直接发给客户。\n"
|
||||||
|
"若不合适,给可直接发送的一句改写。\n"
|
||||||
|
"只输出 JSON:{\"allow\":true/false,\"rewrite\":\"...\",\"reason\":\"...\"}\n\n"
|
||||||
|
f"完整上下文(当前店铺):\n{context_text}\n\n"
|
||||||
|
f"客户当前消息:{client.to_chinese(original_msg.get('msg', '') or '')}\n"
|
||||||
|
f"候选回复:{text}\n"
|
||||||
|
)
|
||||||
|
result = await client.agent.agent_natural_reply.run(prompt, deps=deps, message_history=[])
|
||||||
|
raw = str(getattr(result, "output", "") or "").strip()
|
||||||
|
if not raw:
|
||||||
|
return True, text, "guard_empty_output"
|
||||||
|
m = _re.search(r"\{[\s\S]*\}", raw)
|
||||||
|
if not m:
|
||||||
|
return True, text, "guard_non_json"
|
||||||
|
obj = _json.loads(m.group(0))
|
||||||
|
allow = bool(obj.get("allow", True))
|
||||||
|
rewrite = str(obj.get("rewrite", "") or "").strip()
|
||||||
|
reason = str(obj.get("reason", "") or "").strip() or "guard_decision"
|
||||||
|
if allow:
|
||||||
|
return True, (rewrite or text), reason
|
||||||
|
if rewrite:
|
||||||
|
return True, rewrite, reason
|
||||||
|
return False, "", reason
|
||||||
|
except Exception as e:
|
||||||
|
return True, text, f"guard_error:{e}"
|
||||||
8
scripts/run_test_ai_chat.ps1
Normal file
8
scripts/run_test_ai_chat.ps1
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
$ErrorActionPreference = "Stop"
|
||||||
|
|
||||||
|
# Use a writable uv cache path on Windows to avoid permission issues
|
||||||
|
# with default cache locations in restricted environments.
|
||||||
|
$env:UV_CACHE_DIR = Join-Path $env:TEMP "uv-cache-tw-runtime"
|
||||||
|
New-Item -ItemType Directory -Force $env:UV_CACHE_DIR | Out-Null
|
||||||
|
|
||||||
|
uv run tests\test_ai_chat.py
|
||||||
@@ -4,6 +4,8 @@ AI Agent 对话测试脚本
|
|||||||
"""
|
"""
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
# 颜色代码
|
# 颜色代码
|
||||||
@@ -18,14 +20,29 @@ COLORS = {
|
|||||||
'reset': '\033[0m',
|
'reset': '\033[0m',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Windows PowerShell defaults to GBK in some environments.
|
||||||
|
# Make stdout/stderr robust for Unicode logs used by this test script.
|
||||||
|
for stream_name in ("stdout", "stderr"):
|
||||||
|
stream = getattr(sys, stream_name, None)
|
||||||
|
if stream and hasattr(stream, "reconfigure"):
|
||||||
|
try:
|
||||||
|
stream.reconfigure(encoding="utf-8", errors="replace")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Ensure project root is importable when running as `uv run tests/test_ai_chat.py`.
|
||||||
|
PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
|
||||||
|
if PROJECT_ROOT not in sys.path:
|
||||||
|
sys.path.insert(0, PROJECT_ROOT)
|
||||||
|
DB_PATH = Path(PROJECT_ROOT) / "db" / "chat_log_db" / "chats.db"
|
||||||
|
|
||||||
def cprint(text, color='reset'):
|
def cprint(text, color='reset'):
|
||||||
print(f"{COLORS.get(color, '')}{text}{COLORS['reset']}")
|
print(f"{COLORS.get(color, '')}{text}{COLORS['reset']}")
|
||||||
|
|
||||||
def check_database():
|
def check_database():
|
||||||
"""检查数据库内容"""
|
"""检查数据库内容"""
|
||||||
db_path = 'db/chat_log_db/chats.db'
|
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(db_path)
|
conn = sqlite3.connect(DB_PATH)
|
||||||
cursor = conn.execute("SELECT COUNT(*) FROM chat_logs")
|
cursor = conn.execute("SELECT COUNT(*) FROM chat_logs")
|
||||||
count = cursor.fetchone()[0]
|
count = cursor.fetchone()[0]
|
||||||
|
|
||||||
@@ -65,7 +82,7 @@ async def test_customer_conversation(customer_id, customer_name, limit=5):
|
|||||||
cprint(f"{'='*70}\n", 'cyan')
|
cprint(f"{'='*70}\n", 'cyan')
|
||||||
|
|
||||||
# 获取对话记录
|
# 获取对话记录
|
||||||
conn = sqlite3.connect('db/chat_log_db/chats.db')
|
conn = sqlite3.connect(DB_PATH)
|
||||||
cursor = conn.execute("""
|
cursor = conn.execute("""
|
||||||
SELECT direction, message, timestamp
|
SELECT direction, message, timestamp
|
||||||
FROM chat_logs
|
FROM chat_logs
|
||||||
@@ -157,7 +174,7 @@ async def test_all_customers(customers, limit_per_customer=5):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# 获取对话记录
|
# 获取对话记录
|
||||||
conn = sqlite3.connect('db/chat_log_db/chats.db')
|
conn = sqlite3.connect(DB_PATH)
|
||||||
cursor = conn.execute("""
|
cursor = conn.execute("""
|
||||||
SELECT direction, message, timestamp
|
SELECT direction, message, timestamp
|
||||||
FROM chat_logs
|
FROM chat_logs
|
||||||
|
|||||||
@@ -9,6 +9,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
from datetime import datetime, date, timedelta
|
from datetime import datetime, date, timedelta
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@@ -16,6 +17,7 @@ import httpx
|
|||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
logger = logging.getLogger("cs_agent")
|
||||||
|
|
||||||
WECHAT_WEBHOOK = os.getenv("WECHAT_WEBHOOK", "")
|
WECHAT_WEBHOOK = os.getenv("WECHAT_WEBHOOK", "")
|
||||||
SUMMARY_EMAIL = os.getenv("SUMMARY_EMAIL", "") # 收摘要的邮箱
|
SUMMARY_EMAIL = os.getenv("SUMMARY_EMAIL", "") # 收摘要的邮箱
|
||||||
@@ -130,7 +132,7 @@ async def _ai_summary(raw_text: str) -> str:
|
|||||||
async def _send_wechat(content: str):
|
async def _send_wechat(content: str):
|
||||||
"""推送到企业微信群机器人(markdown 格式,单条 ≤4096 字节自动分段)"""
|
"""推送到企业微信群机器人(markdown 格式,单条 ≤4096 字节自动分段)"""
|
||||||
if not WECHAT_WEBHOOK:
|
if not WECHAT_WEBHOOK:
|
||||||
print("[DailySummary] 未配置 WECHAT_WEBHOOK,跳过推送")
|
logger.info("[DailySummary] 未配置 WECHAT_WEBHOOK,跳过推送")
|
||||||
return
|
return
|
||||||
|
|
||||||
# 企业微信单条 markdown 限 4096 字节,超长自动分段
|
# 企业微信单条 markdown 限 4096 字节,超长自动分段
|
||||||
@@ -148,11 +150,11 @@ async def _send_wechat(content: str):
|
|||||||
resp = await client.post(WECHAT_WEBHOOK, json=payload)
|
resp = await client.post(WECHAT_WEBHOOK, json=payload)
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
if data.get("errcode") == 0:
|
if data.get("errcode") == 0:
|
||||||
print(f"[DailySummary] 企业微信推送成功(第{i+1}段)")
|
logger.info("[DailySummary] 企业微信推送成功(第%s段)", i + 1)
|
||||||
else:
|
else:
|
||||||
print(f"[DailySummary] 企业微信推送失败: {data}")
|
logger.warning("[DailySummary] 企业微信推送失败: %s", data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[DailySummary] 企业微信推送异常: {e}")
|
logger.exception("[DailySummary] 企业微信推送异常: %s", e)
|
||||||
|
|
||||||
|
|
||||||
# ──────────────────────────────────────────
|
# ──────────────────────────────────────────
|
||||||
@@ -178,9 +180,9 @@ def _send_email(subject: str, body: str):
|
|||||||
s.starttls()
|
s.starttls()
|
||||||
s.login(email_sender.smtp_user, email_sender.smtp_password)
|
s.login(email_sender.smtp_user, email_sender.smtp_password)
|
||||||
s.sendmail(email_sender.smtp_user, [SUMMARY_EMAIL], msg.as_string())
|
s.sendmail(email_sender.smtp_user, [SUMMARY_EMAIL], msg.as_string())
|
||||||
print(f"[DailySummary] 日报邮件已发送至 {SUMMARY_EMAIL}")
|
logger.info("[DailySummary] 日报邮件已发送至 %s", SUMMARY_EMAIL)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[DailySummary] 日报邮件发送失败: {e}")
|
logger.exception("[DailySummary] 日报邮件发送失败: %s", e)
|
||||||
|
|
||||||
|
|
||||||
# ──────────────────────────────────────────
|
# ──────────────────────────────────────────
|
||||||
@@ -244,7 +246,7 @@ async def send_daily_summary(target_date: str = ""):
|
|||||||
if not target_date:
|
if not target_date:
|
||||||
target_date = datetime.now().strftime("%Y-%m-%d")
|
target_date = datetime.now().strftime("%Y-%m-%d")
|
||||||
|
|
||||||
print(f"[DailySummary] 开始生成 {target_date} 日报...")
|
logger.info("[DailySummary] 开始生成 %s 日报...", target_date)
|
||||||
|
|
||||||
raw_text = _build_stats_text(target_date)
|
raw_text = _build_stats_text(target_date)
|
||||||
ai_text = await _ai_summary(raw_text)
|
ai_text = await _ai_summary(raw_text)
|
||||||
@@ -258,7 +260,7 @@ async def send_daily_summary(target_date: str = ""):
|
|||||||
email_body = f"{ai_text}\n\n{'='*40}\n\n{raw_text}"
|
email_body = f"{ai_text}\n\n{'='*40}\n\n{raw_text}"
|
||||||
_send_email(title, email_body)
|
_send_email(title, email_body)
|
||||||
|
|
||||||
print(f"[DailySummary] 日报推送完成")
|
logger.info("[DailySummary] 日报推送完成")
|
||||||
return ai_text
|
return ai_text
|
||||||
|
|
||||||
|
|
||||||
@@ -268,7 +270,7 @@ async def send_daily_summary(target_date: str = ""):
|
|||||||
|
|
||||||
async def scheduler():
|
async def scheduler():
|
||||||
"""每天 SEND_HOUR:SEND_MINUTE 触发日报"""
|
"""每天 SEND_HOUR:SEND_MINUTE 触发日报"""
|
||||||
print(f"[DailySummary] 定时日报已启动,发送时间 {SEND_HOUR:02d}:{SEND_MINUTE:02d}")
|
logger.info("[DailySummary] 定时日报已启动,发送时间 %02d:%02d", SEND_HOUR, SEND_MINUTE)
|
||||||
sent_today: Optional[str] = None # 记录已发日期,防重复
|
sent_today: Optional[str] = None # 记录已发日期,防重复
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
@@ -280,7 +282,7 @@ async def scheduler():
|
|||||||
try:
|
try:
|
||||||
await send_daily_summary(today)
|
await send_daily_summary(today)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[DailySummary] 日报生成出错: {e}")
|
logger.exception("[DailySummary] 日报生成出错: %s", e)
|
||||||
|
|
||||||
# 每 30 秒检查一次
|
# 每 30 秒检查一次
|
||||||
await asyncio.sleep(30)
|
await asyncio.sleep(30)
|
||||||
@@ -294,5 +296,5 @@ if __name__ == "__main__":
|
|||||||
import sys
|
import sys
|
||||||
target = sys.argv[1] if len(sys.argv) > 1 else ""
|
target = sys.argv[1] if len(sys.argv) > 1 else ""
|
||||||
result = asyncio.run(send_daily_summary(target))
|
result = asyncio.run(send_daily_summary(target))
|
||||||
print("\n=== AI 摘要 ===")
|
logger.info("\n=== AI 摘要 ===")
|
||||||
print(result)
|
logger.info(result)
|
||||||
|
|||||||
@@ -4,12 +4,14 @@
|
|||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
logger = logging.getLogger("cs_agent")
|
||||||
|
|
||||||
_last_push: dict[tuple[str, str], tuple[str, str, float]] = {}
|
_last_push: dict[tuple[str, str], tuple[str, str, float]] = {}
|
||||||
|
|
||||||
@@ -38,6 +40,7 @@ def _get_recent_conversation(customer_id: str, acc_id: str, last_n: int = 8) ->
|
|||||||
from db.chat_log_db import get_recent_conversation
|
from db.chat_log_db import get_recent_conversation
|
||||||
return get_recent_conversation(customer_id, acc_id, limit=last_n)
|
return get_recent_conversation(customer_id, acc_id, limit=last_n)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
logger.debug("[WechatChatLog] 获取近期对话失败,返回空列表", exc_info=True)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
@@ -68,7 +71,7 @@ async def push_chat_to_wechat(
|
|||||||
return
|
return
|
||||||
_last_push[key] = ((customer_msg or ""), (reply_msg or ""), now)
|
_last_push[key] = ((customer_msg or ""), (reply_msg or ""), now)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
logger.debug("[WechatChatLog] 去重检查异常,忽略本次去重", exc_info=True)
|
||||||
reply_msg = _truncate(reply_msg, 300)
|
reply_msg = _truncate(reply_msg, 300)
|
||||||
ts = datetime.now().strftime("%H:%M")
|
ts = datetime.now().strftime("%H:%M")
|
||||||
shop = acc_id or "未知店铺"
|
shop = acc_id or "未知店铺"
|
||||||
@@ -109,11 +112,11 @@ async def push_chat_to_wechat(
|
|||||||
)
|
)
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
if data.get("errcode") == 0:
|
if data.get("errcode") == 0:
|
||||||
pass # 成功静默
|
return
|
||||||
else:
|
else:
|
||||||
print(f"[WechatChatLog] 推送失败: {data}")
|
logger.warning("[WechatChatLog] 推送失败: %s", data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[WechatChatLog] 推送异常: {e}")
|
logger.exception("[WechatChatLog] 推送异常: %s", e)
|
||||||
|
|
||||||
|
|
||||||
async def send_morning_startup():
|
async def send_morning_startup():
|
||||||
@@ -129,14 +132,14 @@ async def send_morning_startup():
|
|||||||
webhook,
|
webhook,
|
||||||
json={"msgtype": "markdown", "markdown": {"content": content}},
|
json={"msgtype": "markdown", "markdown": {"content": content}},
|
||||||
)
|
)
|
||||||
print(f"[WechatChatLog] 早8点启动消息已发送")
|
logger.info("[WechatChatLog] 早8点启动消息已发送")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[WechatChatLog] 启动消息发送失败: {e}")
|
logger.exception("[WechatChatLog] 启动消息发送失败: %s", e)
|
||||||
|
|
||||||
|
|
||||||
async def morning_startup_scheduler():
|
async def morning_startup_scheduler():
|
||||||
"""每天 8:00 发送启动消息"""
|
"""每天 8:00 发送启动消息"""
|
||||||
print("[WechatChatLog] 早8点启动消息定时任务已启动")
|
logger.info("[WechatChatLog] 早8点启动消息定时任务已启动")
|
||||||
sent_today = None
|
sent_today = None
|
||||||
while True:
|
while True:
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
|
|||||||
Reference in New Issue
Block a user