331 lines
12 KiB
Python
331 lines
12 KiB
Python
"""
|
||
AI Agent conversation test script.

Loads chat logs from the database and tests the quality of the AI replies.
|
||
"""
|
||
import sqlite3
|
||
import asyncio
|
||
import sys
|
||
from pathlib import Path
|
||
from datetime import datetime
|
||
|
||
# ANSI escape sequences used by cprint(), keyed by a logical color name.
COLORS = dict(
    header="\033[95m\033[1m",
    customer="\033[94m",
    agent="\033[92m",
    system="\033[90m",
    price="\033[93m",
    error="\033[91m",
    cyan="\033[96m",
    reset="\033[0m",
)
|
||
|
||
# Some Windows PowerShell environments default to GBK. Switch stdout/stderr to
# UTF-8 (replacing unencodable characters) so the Unicode output of this test
# script does not fail.
for stream_name in ("stdout", "stderr"):
    stream = getattr(sys, stream_name, None)
    if not stream or not hasattr(stream, "reconfigure"):
        continue
    try:
        stream.reconfigure(encoding="utf-8", errors="replace")
    except Exception:
        # Best effort only: keep the stream's current settings on failure.
        pass
|
||
|
||
# The project root is two directories above this file. It is put on sys.path so
# project imports resolve when running as `uv run tests/test_ai_chat.py`.
PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
DB_PATH = Path(PROJECT_ROOT).joinpath("db", "chat_log_db", "chats.db")
if sys.path.count(PROJECT_ROOT) == 0:
    sys.path.insert(0, PROJECT_ROOT)
|
||
|
||
def cprint(text, color='reset'):
    """Print ``text`` wrapped in the escape code for ``color``, then reset.

    An unknown ``color`` name falls back to no leading escape code.
    """
    start_code = COLORS.get(color, '')
    print(f"{start_code}{text}{COLORS['reset']}")
|
||
|
||
def check_database():
    """Check the chat-log database and list its most active customers.

    Prints the total number of rows in ``chat_logs`` and up to 20 customers
    ordered by message count (descending).

    Returns:
        A list of ``(customer_id, customer_name, message_count,
        last_timestamp)`` rows, or ``None`` when the table is empty or any
        error occurs while reading the database.
    """
    conn = None
    try:
        conn = sqlite3.connect(DB_PATH)
        count = conn.execute("SELECT COUNT(*) FROM chat_logs").fetchone()[0]

        if count == 0:
            cprint("\n✗ 数据库为空,没有聊天记录", 'error')
            cprint("提示:需要先有一些聊天记录才能测试", 'system')
            return None

        cprint(f"\n✓ 数据库连接成功!共 {count} 条聊天记录", 'system')

        # Fetch the customer list.
        customers = conn.execute("""
            SELECT customer_id, customer_name, COUNT(*) as cnt, MAX(timestamp) as last
            FROM chat_logs
            GROUP BY customer_id
            ORDER BY cnt DESC
            LIMIT 20
        """).fetchall()

        cprint(f"\n找到 {len(customers)} 个客户:", 'cyan')
        for i, (cid, name, cnt, last) in enumerate(customers, 1):
            # str() keeps the ':30s' format spec valid when the id is not a
            # string (e.g. an integer or NULL column value).
            label = str(name or cid)
            cprint(f" {i:2d}. {label:30s} | {cnt:4d}条 | 最后:{last}", 'customer')

        return customers

    except Exception as e:
        cprint(f"\n✗ 数据库检查失败:{e}", 'error')
        return None
    finally:
        # Close the connection on every path, including query failures.
        if conn is not None:
            conn.close()
|
||
|
||
async def test_customer_conversation(customer_id, customer_name, limit=5):
    """Replay one customer's logged messages through the AI agent.

    Rows with ``direction == 'in'`` are sent to the agent, and the reply,
    latency and detected flags are printed. Rows with ``direction == 'out'``
    are echoed as historical replies. Other directions are ignored.

    Args:
        customer_id: Value matched against ``chat_logs.customer_id``.
        customer_name: Display name; ``customer_id`` is shown in the header
            when it is empty.
        limit: Maximum number of rows to load, oldest timestamp first.
    """
    import time  # local import: only needed for latency measurement

    cprint(f"\n{'='*70}", 'cyan')
    cprint(f"测试客户:{customer_name or customer_id}", 'header')
    cprint(f"{'='*70}\n", 'cyan')

    # Load the conversation; close the connection even if the query raises.
    conn = sqlite3.connect(DB_PATH)
    try:
        conversations = conn.execute("""
            SELECT direction, message, timestamp
            FROM chat_logs
            WHERE customer_id = ?
            ORDER BY timestamp ASC
            LIMIT ?
        """, (customer_id, limit)).fetchall()
    finally:
        conn.close()

    if not conversations:
        cprint(" 该客户没有对话记录", 'system')
        return

    # Initialise the AI agent.
    try:
        from core.pydantic_ai_agent import CustomerServiceAgent, CustomerMessage
        agent = CustomerServiceAgent(skills_dir="skills")
        cprint("✓ AI Agent 已加载", 'system')
    except Exception as e:
        cprint(f"✗ AI Agent 加载失败:{e}", 'error')
        return

    # Replay the conversation.
    total = len(conversations)
    for i, (direction, message, timestamp) in enumerate(conversations, 1):
        if direction == 'in':
            # Customer message.
            cprint(f"\n【消息 {i}/{total}】{timestamp}", 'system')
            cprint(f"客户:{message}", 'customer')

            test_msg = CustomerMessage(
                msg_id=f"test_{i}",
                acc_id="test_shop",
                msg=message,
                from_id=customer_id,
                from_name=customer_name or "测试",
                cy_id=customer_id,
                acc_type="AliWorkbench",
                msg_type=0,
                cy_name=customer_name or "测试",
                goods_name="专业找图",
                goods_order=""
            )

            # Time the reply with a monotonic clock so wall-clock adjustments
            # cannot distort the measured latency.
            started = time.perf_counter()
            try:
                response = await agent.process_message(test_msg)
                elapsed = (time.perf_counter() - started) * 1000

                if response.should_reply:
                    cprint(f"AI [{elapsed:.0f}ms]: {response.reply}", 'agent')

                    # Flag replies that mention a price or request a hand-off.
                    if any(kw in response.reply for kw in ['元', '块', '价格']):
                        cprint(" ↳ [价格信息]", 'price')
                    if response.need_transfer:
                        cprint(" ↳ [转人工]", 'error')
                else:
                    cprint("[AI 静默]", 'system')

            except Exception as e:
                cprint(f"✗ AI 回复失败:{e}", 'error')

        elif direction == 'out':
            cprint(f"\n[历史回复] {timestamp}", 'system')
            cprint(f"客服:{message}", 'system')

    cprint(f"\n{'='*70}", 'cyan')
|
||
|
||
async def test_all_customers(customers, limit_per_customer=5):
    """Replay the first messages of every customer through the AI agent.

    Only rows with ``direction == 'in'`` are sent to the agent. A summary with
    the number of processed messages, replies and the reply rate is printed
    at the end.

    Args:
        customers: Rows of ``(customer_id, customer_name, message_count,
            last_timestamp)``; the last field is not used here.
        limit_per_customer: Maximum number of rows loaded per customer,
            oldest timestamp first.
    """
    import time  # local import: only needed for latency measurement

    cprint(f"\n{'='*70}", 'header')
    cprint(f" 开始批量测试 {len(customers)} 个客户", 'header')
    cprint(f" 每个客户测试前 {limit_per_customer} 条消息", 'header')
    cprint(f"{'='*70}\n", 'header')

    total_msgs = 0
    total_replies = 0
    # Created lazily on the first customer that is actually tested, so a
    # skipped first customer cannot leave the agent undefined for later ones.
    agent = None

    for i, (cid, name, cnt, _) in enumerate(customers, 1):
        cprint(f"\n\n{'='*70}", 'cyan')
        cprint(f"进度:{i}/{len(customers)} - {name or cid} ({cnt}条消息)", 'cyan')
        cprint(f"{'='*70}", 'cyan')

        if cnt == 0:
            cprint(" 跳过(无消息记录)", 'system')
            continue

        # Load the conversation; close the connection even if the query raises.
        conn = sqlite3.connect(DB_PATH)
        try:
            conversations = conn.execute("""
                SELECT direction, message, timestamp
                FROM chat_logs
                WHERE customer_id = ?
                ORDER BY timestamp ASC
                LIMIT ?
            """, (cid, limit_per_customer)).fetchall()
        finally:
            conn.close()

        # Initialise the AI agent (once).
        if agent is None:
            try:
                from core.pydantic_ai_agent import CustomerServiceAgent, CustomerMessage
                agent = CustomerServiceAgent(skills_dir="skills")
                cprint("✓ AI Agent 已加载", 'system')
            except Exception as e:
                cprint(f"✗ AI Agent 加载失败:{e}", 'error')
                return

        # Replay the conversation.
        for j, (direction, message, timestamp) in enumerate(conversations, 1):
            if direction != 'in':
                continue
            total_msgs += 1

            test_msg = CustomerMessage(
                msg_id=f"test_{i}_{j}",
                acc_id="test_shop",
                msg=message,
                from_id=cid,
                from_name=name or "测试",
                cy_id=cid,
                acc_type="AliWorkbench",
                msg_type=0,
                cy_name=name or "测试",
                goods_name="专业找图",
                goods_order=""
            )

            # Time the reply with a monotonic clock so wall-clock adjustments
            # cannot distort the measured latency.
            started = time.perf_counter()
            try:
                response = await agent.process_message(test_msg)
                elapsed = (time.perf_counter() - started) * 1000

                if response.should_reply:
                    total_replies += 1
                    cprint(f"\n[{i}/{len(customers)}] {name or cid} - 消息 {j}", 'system')
                    cprint(f"客户:{message}", 'customer')
                    cprint(f"AI [{elapsed:.0f}ms]: {response.reply}", 'agent')

                    # Flag replies that mention a price or request a hand-off.
                    if any(kw in response.reply for kw in ['元', '块', '价格']):
                        cprint(" ↳ [价格信息]", 'price')
                    if response.need_transfer:
                        cprint(" ↳ [转人工]", 'error')
                else:
                    cprint(f"\n[{i}/{len(customers)}] [AI 静默]", 'system')

            except Exception as e:
                cprint(f"✗ AI 回复失败:{e}", 'error')

        # Short pause between customers.
        await asyncio.sleep(0.5)

    # Summary.
    cprint(f"\n\n{'='*70}", 'header')
    cprint(" 批量测试完成!", 'header')
    cprint(f"{'='*70}", 'header')
    cprint("\n统计:", 'system')
    cprint(f" 测试客户数:{len(customers)}", 'cyan')
    cprint(f" 处理消息数:{total_msgs}", 'cyan')
    cprint(f" AI 回复数:{total_replies}", 'cyan')
    if total_msgs > 0:
        reply_rate = (total_replies / total_msgs) * 100
        cprint(f" 回复率:{reply_rate:.1f}%", 'cyan')
|
||
|
||
async def main():
    """Run the interactive test menu.

    Checks the database, then lets the user pick one of three modes:
    interactive per-customer testing, a batch run over all listed customers,
    or a quick run over the first five. Entering ``q`` quits. A closed stdin
    (EOF) is treated the same as quitting.
    """
    cprint("="*70, 'header')
    cprint(" AI Agent 对话测试", 'header')
    cprint(" 从数据库加载聊天记录,测试 AI 回复效果", 'header')
    cprint("="*70, 'header')

    # Check the database.
    customers = check_database()
    if not customers:
        return

    # Choose the test mode.
    cprint("\n请选择测试模式:", 'cyan')
    cprint(" 1. 交互式测试 (手动选择客户)", 'customer')
    cprint(" 2. 批量测试所有客户 (自动)", 'agent')
    cprint(" 3. 快速测试前 5 个客户", 'price')
    cprint(" q. 退出", 'system')

    try:
        mode = input("\n选择:").strip().lower()
    except EOFError:
        # No more input available: behave as if the user chose to quit.
        mode = 'q'

    if mode == 'q':
        cprint("\n测试结束!", 'system')
        return

    try:
        if mode == '1':
            # Interactive testing.
            cprint(f"\n请输入客户编号 (1-{len(customers)}) 进行测试:", 'cyan')

            while True:
                try:
                    choice = input("\n选择:").strip()
                except EOFError:
                    # Without this, a closed stdin would hit the generic
                    # handler below on every iteration and loop forever.
                    cprint("\n测试结束!", 'system')
                    return
                except KeyboardInterrupt:
                    cprint("\n\n测试中断", 'error')
                    return

                if choice.lower() == 'q':
                    cprint("\n测试结束!", 'system')
                    return

                # Only the number parsing is guarded by ValueError, so errors
                # raised while running a test are not misreported as bad input.
                try:
                    choice_num = int(choice)
                except ValueError:
                    cprint("请输入有效数字或 q 退出", 'error')
                    continue

                if not 1 <= choice_num <= len(customers):
                    cprint(f"请输入 1-{len(customers)} 之间的数字", 'error')
                    continue

                cid, name, cnt, _ = customers[choice_num - 1]
                try:
                    await test_customer_conversation(cid, name or cid, limit=min(cnt, 10))
                except KeyboardInterrupt:
                    cprint("\n\n测试中断", 'error')
                    return
                except Exception as e:
                    cprint(f"错误:{e}", 'error')

        elif mode == '2':
            # Batch-test every listed customer.
            await test_all_customers(customers, limit_per_customer=5)

        elif mode == '3':
            # Quick test of the first five customers.
            top_5 = customers[:5]
            cprint("\n快速测试前 5 个客户...", 'cyan')
            await test_all_customers(top_5, limit_per_customer=5)

        else:
            cprint("无效的选择", 'error')

    except KeyboardInterrupt:
        cprint("\n\n测试中断", 'error')
    except Exception as e:
        cprint(f"错误:{e}", 'error')
|
||
|
||
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # asyncio.run() can re-raise KeyboardInterrupt after cancelling the
        # main task; report it instead of ending with a traceback.
        cprint("\n\n测试中断", 'error')
    except Exception as e:
        cprint(f"\n程序异常:{e}", 'error')
|