import logging
import asyncio
from typing import Optional, List, Any, Union
from datetime import datetime

from db.customer_db import db as customer_db
from db.image_tasks_db import db as task_db
from db.chat_log_db import log_message, get_conversation

logger = logging.getLogger("cs_agent")


class DataRepository:
    """
    Async data repository.

    Every call into the synchronous database helpers is dispatched through
    ``asyncio.to_thread`` so the blocking I/O runs in a worker thread instead
    of blocking the event loop.
    """

    def __init__(self):
        # Module-level DB handles imported above; kept as attributes so they
        # can be swapped (e.g. in tests) on an instance.
        self.customer_db = customer_db
        self.task_db = task_db

    # --- Chat log (async) ---
    async def save_chat(
        self,
        platform: str,
        user_id: str,
        content: str,
        direction: str,
        acc_id: str = "",
        image_urls: Optional[Union[str, List[str]]] = None,
    ):
        """Persist one chat message without blocking the event loop.

        Args:
            platform: Platform identifier, forwarded to ``log_message``.
            user_id: Forwarded to ``log_message`` as ``customer_id``.
            content: Message text, forwarded as ``message``.
            direction: Forwarded as-is. ``get_chat_history`` treats ``"in"``
                as a user message and anything else as an assistant message.
            acc_id: Account identifier, forwarded to ``log_message``.
            image_urls: Image URLs attached to the message. They are stored as
                a single newline-separated string ("" when absent). A lone URL
                string is also accepted.

        Returns:
            Whatever ``log_message`` returns.
        """
        # The log store takes one newline-separated string. Wrap a bare str
        # first: "\n".join() over a str would put each character on its own line.
        if isinstance(image_urls, str):
            image_urls = [image_urls]
        urls_str = "\n".join(image_urls) if image_urls else ""
        return await asyncio.to_thread(
            log_message,
            customer_id=user_id,
            message=content,
            direction=direction,
            platform=platform,
            acc_id=acc_id,
            image_urls=urls_str,
        )

    async def get_chat_history(self, user_id: str, limit: int = 10, acc_id: str = "") -> List[dict]:
        """Fetch up to ``limit`` stored messages as chat-style role dicts.

        Each row returned by ``get_conversation`` is mapped to
        ``{"role", "content", "timestamp"}``. Rows whose ``direction`` is
        ``"in"`` become ``"user"``; all others become ``"assistant"``. A missing
        ``timestamp`` becomes "". Row order is whatever ``get_conversation``
        returns.
        """
        rows = await asyncio.to_thread(get_conversation, user_id, limit=limit, acc_id=acc_id)
        # NOTE(review): rows are read with both r["..."] and r.get(...), so they
        # are presumably plain dicts; confirm get_conversation's return type.
        return [
            {
                "role": "user" if r["direction"] == "in" else "assistant",
                "content": r["message"],
                "timestamp": r.get("timestamp", ""),
            }
            for r in rows
        ]

    # --- Customers (async) ---
    async def get_customer(self, platform: str, user_id: str):
        """Look up a customer by the composite key ``"{platform}:{user_id}"``."""
        # NOTE(review): create_task stores customer_id=user_id without the
        # platform prefix; confirm the two key schemes are intentionally different.
        customer_key = f"{platform}:{user_id}"
        return await asyncio.to_thread(self.customer_db.get_customer, customer_key)

    # --- Tasks (async) ---
    async def create_task(self, platform: str, user_id: str, image_url: str, operation: str, requirements: str = ""):
        """Create an image task with status ``"pending"`` via ``task_db.add_task``."""
        return await asyncio.to_thread(
            self.task_db.add_task,
            customer_id=user_id,
            platform=platform,
            original_image=image_url,
            operation=operation,
            requirements=requirements,
            status="pending",
        )

    async def update_task_price(self, platform: str, user_id: str, price: float):
        """Record the agreed (deal) price through ``task_db.update_price``."""
        # The underlying helper is called positionally as (user_id, platform, ...),
        # the reverse of this method's (platform, user_id) parameter order.
        return await asyncio.to_thread(self.task_db.update_price, user_id, platform, price)

    async def update_task_outcome(self, platform: str, user_id: str, outcome: str):
        """Record the final outcome through ``task_db.update_outcome``."""
        # Same positional (user_id, platform, ...) order as update_price.
        return await asyncio.to_thread(self.task_db.update_outcome, user_id, platform, outcome)


# Global async repository singleton.
repo = DataRepository()