79 lines
2.8 KiB
Python
79 lines
2.8 KiB
Python
import logging
|
|
import asyncio
|
|
from typing import Optional, List, Any
|
|
from datetime import datetime
|
|
from db.customer_db import db as customer_db
|
|
from db.image_tasks_db import db as task_db
|
|
from db.chat_log_db import log_message, get_conversation
|
|
|
|
# Module logger, registered under the name "cs_agent". logging.getLogger
# returns the same Logger object for the same name, so this is shared with
# any other code that requests "cs_agent".
logger: logging.Logger = logging.getLogger("cs_agent")
class DataRepository:
    """Async facade over the synchronous chat-log, customer and task stores.

    Every public method is a coroutine that runs the underlying blocking
    database call in a worker thread via ``asyncio.to_thread``, so the event
    loop is not blocked while that I/O is in progress.
    """

    def __init__(self):
        # Stored on the instance; the customer/task methods below read them
        # through ``self``.
        self.customer_db = customer_db
        self.task_db = task_db

    # --- Chat log (async wrappers) ---

    @staticmethod
    def _join_image_urls(image_urls: Optional[List[str]]) -> str:
        """Flatten image URLs into the newline-separated string that is stored.

        ``None`` or an empty list yields ``""``. A single string is stored
        as-is instead of being joined character by character. ``None`` and
        empty entries are skipped, so they neither raise nor leave blank lines.
        """
        if not image_urls:
            return ""
        if isinstance(image_urls, str):
            # "\n".join("abc") would produce "a\nb\nc" -- keep the URL intact.
            return image_urls
        return "\n".join(str(url) for url in image_urls if url)

    async def save_chat(
        self,
        platform: str,
        user_id: str,
        content: str,
        direction: str,
        acc_id: str = "",
        image_urls: Optional[List[str]] = None,
    ):
        """Persist one chat message without blocking the event loop.

        Args:
            platform: Platform identifier, forwarded to ``log_message``.
            user_id: Forwarded to ``log_message`` as ``customer_id``.
            content: Message text.
            direction: Stored as given. ``get_chat_history`` treats ``"in"``
                as a user message and any other value as an assistant message.
            acc_id: Account identifier, forwarded as-is (default ``""``).
            image_urls: Optional image URLs, stored as one newline-separated
                string (``""`` when there are none).

        Returns:
            Whatever ``log_message`` returns.
        """
        urls_str = self._join_image_urls(image_urls)
        return await asyncio.to_thread(
            log_message,
            customer_id=user_id,
            message=content,
            direction=direction,
            platform=platform,
            acc_id=acc_id,
            image_urls=urls_str,
        )

    @staticmethod
    def _row_to_message(row) -> dict:
        """Convert one stored chat row into a ``{role, content, timestamp}`` dict."""
        # NOTE(review): assumes rows are mappings supporting both ``[]`` and
        # ``.get()`` (e.g. dicts); sqlite3.Row has no ``.get()`` -- confirm
        # against get_conversation.
        role = "user" if row["direction"] == "in" else "assistant"
        return {
            "role": role,
            "content": row["message"],
            "timestamp": row.get("timestamp", ""),
        }

    async def get_chat_history(self, user_id: str, limit: int = 10, acc_id: str = "") -> List[dict]:
        """Fetch recent chat messages for *user_id* as role-tagged dicts.

        Args:
            user_id: Customer whose conversation is read.
            limit: Forwarded to ``get_conversation`` (default 10).
            acc_id: Forwarded to ``get_conversation`` (default ``""``).

        Returns:
            One ``{"role", "content", "timestamp"}`` dict per row, in the order
            ``get_conversation`` returned them. ``[]`` when it returns nothing.
        """
        rows = await asyncio.to_thread(get_conversation, user_id, limit=limit, acc_id=acc_id)
        # Treat a None result as "no history" instead of failing on iteration.
        return [self._row_to_message(row) for row in rows or []]

    # --- Customers (async wrappers) ---

    async def get_customer(self, platform: str, user_id: str):
        """Look up a customer by the composite key ``"<platform>:<user_id>"``.

        Returns:
            Whatever ``customer_db.get_customer`` returns for that key.
        """
        # NOTE(review): chat and task methods pass user_id and platform
        # separately, while customers are keyed by the combined string --
        # confirm this asymmetry is intended.
        customer_key = f"{platform}:{user_id}"
        return await asyncio.to_thread(self.customer_db.get_customer, customer_key)

    # --- Tasks (async wrappers) ---

    async def create_task(self, platform: str, user_id: str, image_url: str, operation: str, requirements: str = ""):
        """Create a new image task with status ``"pending"``.

        Returns:
            Whatever ``task_db.add_task`` returns.
        """
        return await asyncio.to_thread(
            self.task_db.add_task,
            customer_id=user_id,
            platform=platform,
            original_image=image_url,
            operation=operation,
            requirements=requirements,
            status="pending",
        )

    async def update_task_price(self, platform: str, user_id: str, price: float):
        """Record the agreed (deal) price for this user's task.

        Which task row is updated is decided by ``task_db.update_price``,
        which is not visible here.
        """
        # NOTE(review): arguments are passed positionally as (user_id,
        # platform, price) -- the reverse of this method's (platform, user_id)
        # order; verify against update_price's signature.
        return await asyncio.to_thread(self.task_db.update_price, user_id, platform, price)

    async def update_task_outcome(self, platform: str, user_id: str, outcome: str):
        """Record the final outcome for this user's task.

        Which task row is updated is decided by ``task_db.update_outcome``,
        which is not visible here.
        """
        # NOTE(review): same positional (user_id, platform, ...) ordering as
        # update_task_price -- verify against update_outcome's signature.
        return await asyncio.to_thread(self.task_db.update_outcome, user_id, platform, outcome)
# Module-level shared repository instance (global async repository singleton),
# created once when this module is imported.
repo: DataRepository = DataRepository()