Files
tw/tests/test_ai_chat.py

331 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
AI Agent 对话测试脚本
从数据库加载聊天记录,测试 AI 回复效果
"""
import sqlite3
import asyncio
import sys
from pathlib import Path
from datetime import datetime
# ANSI escape sequences used to colorize console output, keyed by role name.
# 'reset' restores the terminal's default attributes.
COLORS = dict(
    header='\033[95m\033[1m',
    customer='\033[94m',
    agent='\033[92m',
    system='\033[90m',
    price='\033[93m',
    error='\033[91m',
    cyan='\033[96m',
    reset='\033[0m',
)
# Windows PowerShell defaults to GBK in some environments.
# Re-encode stdout/stderr as UTF-8 (replacing unencodable characters) so the
# Unicode text printed by this script does not break the output streams.
for stream_name in ("stdout", "stderr"):
    stream = getattr(sys, stream_name, None)
    if not stream or not hasattr(stream, "reconfigure"):
        # Stream is missing or does not support in-place reconfiguration.
        continue
    try:
        stream.reconfigure(encoding="utf-8", errors="replace")
    except Exception:
        # Best effort only: keep the stream as-is when it cannot be changed.
        pass
# Ensure project root is importable when running as `uv run tests/test_ai_chat.py`.
# The project root is taken to be the parent of the directory holding this file.
_SCRIPT_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = str(_SCRIPT_DIR.parent)
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

# SQLite file that the functions in this script open to read chat logs.
DB_PATH = Path(PROJECT_ROOT).joinpath("db", "chat_log_db", "chats.db")
def cprint(text, color='reset'):
    """Print *text* wrapped in an ANSI color sequence.

    Args:
        text: Value to print; it is formatted the same way an f-string would.
        color: Key into ``COLORS``. An unknown key adds no color prefix.

    The ``COLORS['reset']`` sequence is always appended after the text.
    """
    prefix = COLORS.get(color, '')
    line = "{}{}{}".format(prefix, text, COLORS['reset'])
    print(line)
def check_database():
    """Inspect the chat-log database and list the most active customers.

    Prints the total number of rows in ``chat_logs`` and then up to 20
    customers ordered by message count, highest first.

    Returns:
        A list of ``(customer_id, customer_name, message_count,
        last_timestamp)`` rows, or ``None`` when the table is empty or the
        database cannot be queried. Errors are printed, never raised.
    """
    conn = None
    try:
        conn = sqlite3.connect(DB_PATH)
        count = conn.execute("SELECT COUNT(*) FROM chat_logs").fetchone()[0]
        if count == 0:
            cprint("\n✗ 数据库为空,没有聊天记录", 'error')
            cprint("提示:需要先有一些聊天记录才能测试", 'system')
            return None
        cprint(f"\n✓ 数据库连接成功!共 {count} 条聊天记录", 'system')

        # Fetch the customer list: one row per customer, busiest first.
        customers = conn.execute("""
            SELECT customer_id, customer_name, COUNT(*) as cnt, MAX(timestamp) as last
            FROM chat_logs
            GROUP BY customer_id
            ORDER BY cnt DESC
            LIMIT 20
        """).fetchall()

        cprint(f"\n找到 {len(customers)} 个客户:", 'cyan')
        for i, (cid, name, cnt, last) in enumerate(customers, 1):
            # str() keeps the ':30s' format spec valid when the id is not text.
            label = str(name or cid)
            cprint(f" {i:2d}. {label:30s} | {cnt:4d}条 | 最后:{last}", 'customer')
        return customers
    except Exception as e:
        cprint(f"\n✗ 数据库检查失败:{e}", 'error')
        return None
    finally:
        # Close the connection on every path, including query failures.
        if conn is not None:
            conn.close()
async def test_customer_conversation(customer_id, customer_name, limit=5):
    """Replay one customer's stored conversation through the AI agent.

    Loads the earliest ``limit`` rows for ``customer_id`` from ``chat_logs``.
    Each inbound message (direction ``'in'``) is sent to the agent and the
    reply is printed with its latency; each outbound message (``'out'``) is
    printed as the historical human reply.

    Args:
        customer_id: Value matched against ``chat_logs.customer_id``.
        customer_name: Display name; falls back to ``customer_id`` when falsy.
        limit: Maximum number of rows to load.

    Returns:
        None. Agent load and reply failures are printed, not raised.
    """
    cprint(f"\n{'='*70}", 'cyan')
    cprint(f"测试客户:{customer_name or customer_id}", 'header')
    cprint(f"{'='*70}\n", 'cyan')

    # Load the conversation; close the connection even if the query fails.
    conn = sqlite3.connect(DB_PATH)
    try:
        conversations = conn.execute("""
            SELECT direction, message, timestamp
            FROM chat_logs
            WHERE customer_id = ?
            ORDER BY timestamp ASC
            LIMIT ?
        """, (customer_id, limit)).fetchall()
    finally:
        conn.close()

    if not conversations:
        cprint(" 该客户没有对话记录", 'system')
        return

    # Initialize the AI agent. The import is local so that a missing or broken
    # project package is reported here instead of failing at module import.
    try:
        from core.pydantic_ai_agent import CustomerServiceAgent, CustomerMessage
        agent = CustomerServiceAgent(skills_dir="skills")
        cprint("✓ AI Agent 已加载", 'system')
    except Exception as e:
        cprint(f"✗ AI Agent 加载失败:{e}", 'error')
        return

    # Substrings that mark a reply as containing price information.
    # NOTE(review): the two currency symbols replace empty-string entries,
    # which matched every reply; confirm they are the intended markers.
    price_markers = ('¥', '￥', '价格')
    total = len(conversations)

    # Replay the conversation in chronological order.
    for i, (direction, message, timestamp) in enumerate(conversations, 1):
        if direction == 'in':
            # Customer message.
            cprint(f"\n【消息 {i}/{total}】 {timestamp}", 'system')
            cprint(f"客户:{message}", 'customer')

            # Build the test message handed to the agent.
            test_msg = CustomerMessage(
                msg_id=f"test_{i}",
                acc_id="test_shop",
                msg=message,
                from_id=customer_id,
                from_name=customer_name or "测试",
                cy_id=customer_id,
                acc_type="AliWorkbench",
                msg_type=0,
                cy_name=customer_name or "测试",
                goods_name="专业找图",
                goods_order=""
            )

            # Ask the agent for a reply and time the call.
            start = datetime.now()
            try:
                response = await agent.process_message(test_msg)
                elapsed = (datetime.now() - start).total_seconds() * 1000
                if response.should_reply:
                    cprint(f"AI [{elapsed:.0f}ms]: {response.reply}", 'agent')
                    # Flag special content in the reply.
                    if any(marker in response.reply for marker in price_markers):
                        cprint(" ↳ [价格信息]", 'price')
                    if response.need_transfer:
                        cprint(" ↳ [转人工]", 'error')
                else:
                    cprint("[AI 静默]", 'system')
            except Exception as e:
                cprint(f"✗ AI 回复失败:{e}", 'error')
        elif direction == 'out':
            # Historical reply from the human agent, shown for comparison.
            cprint(f"\n[历史回复] {timestamp}", 'system')
            cprint(f"客服:{message}", 'system')

    cprint(f"\n{'='*70}", 'cyan')
async def test_all_customers(customers, limit_per_customer=5):
    """Replay stored conversations for several customers and print statistics.

    For each customer the earliest ``limit_per_customer`` rows are loaded
    from ``chat_logs``; every inbound message (direction ``'in'``) is sent to
    the AI agent. After all customers are processed the number of messages,
    the number of replies and the reply rate are printed.

    Args:
        customers: Iterable-with-length of ``(customer_id, customer_name,
            message_count, last_timestamp)`` rows.
        limit_per_customer: Maximum number of rows loaded per customer.

    Returns:
        None. If the agent cannot be loaded the run stops before any
        statistics are printed.
    """
    cprint(f"\n{'='*70}", 'header')
    cprint(f" 开始批量测试 {len(customers)} 个客户", 'header')
    cprint(f" 每个客户测试前 {limit_per_customer} 条消息", 'header')
    cprint(f"{'='*70}\n", 'header')

    total_msgs = 0
    total_replies = 0
    # Created on first use, so skipping the first customer cannot leave the
    # agent unbound for the customers that follow.
    agent = None

    # Substrings that mark a reply as containing price information.
    # NOTE(review): the two currency symbols replace empty-string entries,
    # which matched every reply; confirm they are the intended markers.
    price_markers = ('¥', '￥', '价格')

    for i, (cid, name, cnt, _) in enumerate(customers, 1):
        cprint(f"\n\n{'='*70}", 'cyan')
        cprint(f"进度:{i}/{len(customers)} - {name or cid} ({cnt}条消息)", 'cyan')
        cprint(f"{'='*70}", 'cyan')

        if cnt == 0:
            cprint(" 跳过(无消息记录)", 'system')
            continue

        # Load the conversation; close the connection even if the query fails.
        conn = sqlite3.connect(DB_PATH)
        try:
            conversations = conn.execute("""
                SELECT direction, message, timestamp
                FROM chat_logs
                WHERE customer_id = ?
                ORDER BY timestamp ASC
                LIMIT ?
            """, (cid, limit_per_customer)).fetchall()
        finally:
            conn.close()

        # Initialize the AI agent exactly once, the first time it is needed.
        if agent is None:
            try:
                from core.pydantic_ai_agent import CustomerServiceAgent, CustomerMessage
                agent = CustomerServiceAgent(skills_dir="skills")
                cprint("✓ AI Agent 已加载", 'system')
            except Exception as e:
                cprint(f"✗ AI Agent 加载失败:{e}", 'error')
                return

        # Replay the conversation; only inbound messages are sent to the agent.
        for j, (direction, message, timestamp) in enumerate(conversations, 1):
            if direction != 'in':
                continue
            total_msgs += 1

            # Build the test message handed to the agent.
            test_msg = CustomerMessage(
                msg_id=f"test_{i}_{j}",
                acc_id="test_shop",
                msg=message,
                from_id=cid,
                from_name=name or "测试",
                cy_id=cid,
                acc_type="AliWorkbench",
                msg_type=0,
                cy_name=name or "测试",
                goods_name="专业找图",
                goods_order=""
            )

            # Ask the agent for a reply and time the call.
            start = datetime.now()
            try:
                response = await agent.process_message(test_msg)
                elapsed = (datetime.now() - start).total_seconds() * 1000
                if response.should_reply:
                    total_replies += 1
                    cprint(f"\n[{i}/{len(customers)}] {name or cid} - 消息 {j}", 'system')
                    cprint(f"客户:{message}", 'customer')
                    cprint(f"AI [{elapsed:.0f}ms]: {response.reply}", 'agent')
                    # Flag special content in the reply.
                    if any(marker in response.reply for marker in price_markers):
                        cprint(" ↳ [价格信息]", 'price')
                    if response.need_transfer:
                        cprint(" ↳ [转人工]", 'error')
                else:
                    cprint(f"\n[{i}/{len(customers)}] [AI 静默]", 'system')
            except Exception as e:
                cprint(f"✗ AI 回复失败:{e}", 'error')

        # Short pause between customers.
        await asyncio.sleep(0.5)

    # Summary statistics.
    cprint(f"\n\n{'='*70}", 'header')
    cprint(" 批量测试完成!", 'header')
    cprint(f"{'='*70}", 'header')
    cprint("\n统计:", 'system')
    cprint(f" 测试客户数:{len(customers)}", 'cyan')
    cprint(f" 处理消息数:{total_msgs}", 'cyan')
    cprint(f" AI 回复数:{total_replies}", 'cyan')
    if total_msgs > 0:
        reply_rate = (total_replies / total_msgs) * 100
        cprint(f" 回复率:{reply_rate:.1f}%", 'cyan')
async def _interactive_session(customers):
    """Prompt repeatedly for a customer number and replay that conversation.

    The loop ends when the user enters ``q``, presses Ctrl+C, or stdin
    reaches end-of-file.
    """
    cprint(f"\n请输入客户编号 (1-{len(customers)}) 进行测试:", 'cyan')
    while True:
        try:
            choice = input("\n选择:").strip()
            if choice.lower() == 'q':
                cprint("\n测试结束!", 'system')
                return
            choice_num = int(choice)
            if 1 <= choice_num <= len(customers):
                cid, name, cnt, _ = customers[choice_num - 1]
                await test_customer_conversation(cid, name or cid, limit=min(cnt, 10))
            else:
                cprint(f"请输入 1-{len(customers)} 之间的数字", 'error')
        except ValueError:
            cprint("请输入有效数字或 q 退出", 'error')
        except EOFError:
            # stdin is closed: every further input() would fail the same way,
            # so stop instead of looping on the error forever.
            cprint("\n测试结束!", 'system')
            return
        except KeyboardInterrupt:
            cprint("\n\n测试中断", 'error')
            return
        except Exception as e:
            cprint(f"错误:{e}", 'error')


async def main():
    """Run the interactive test menu.

    Prints a banner, checks the database, then asks for a mode:
    ``1`` interactive per-customer test, ``2`` batch test of all listed
    customers, ``3`` batch test of the first five, ``q`` quit.
    End-of-file at the mode prompt is treated like ``q``.
    """
    cprint("="*70, 'header')
    cprint(" AI Agent 对话测试", 'header')
    cprint(" 从数据库加载聊天记录,测试 AI 回复效果", 'header')
    cprint("="*70, 'header')

    # Check the database; nothing to do without customers.
    customers = check_database()
    if not customers:
        return

    # Choose the test mode.
    cprint("\n请选择测试模式:", 'cyan')
    cprint(" 1. 交互式测试 (手动选择客户)", 'customer')
    cprint(" 2. 批量测试所有客户 (自动)", 'agent')
    cprint(" 3. 快速测试前 5 个客户", 'price')
    cprint(" q. 退出", 'system')
    try:
        mode = input("\n选择:").strip().lower()
    except EOFError:
        # No interactive stdin available: behave as if the user chose to quit.
        mode = 'q'

    if mode == 'q':
        cprint("\n测试结束!", 'system')
        return

    try:
        if mode == '1':
            # Interactive test.
            await _interactive_session(customers)
        elif mode == '2':
            # Batch test of every listed customer.
            await test_all_customers(customers, limit_per_customer=5)
        elif mode == '3':
            # Quick test of the first five customers.
            top_5 = customers[:5]
            cprint("\n快速测试前 5 个客户...", 'cyan')
            await test_all_customers(top_5, limit_per_customer=5)
        else:
            cprint("无效的选择", 'error')
    except KeyboardInterrupt:
        cprint("\n\n测试中断", 'error')
    except Exception as e:
        cprint(f"错误:{e}", 'error')
if __name__ == "__main__":
    # Script entry point: run the async menu and report failures as one line.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # KeyboardInterrupt is not an Exception subclass, so it needs its own
        # handler to end the run with a message instead of a traceback.
        cprint("\n\n测试中断", 'error')
    except Exception as e:
        cprint(f"\n程序异常:{e}", 'error')