110 lines
3.8 KiB
Python
110 lines
3.8 KiB
Python
import asyncio
|
||
import json
|
||
import logging
|
||
import os
|
||
from datetime import datetime
|
||
from core.orchestrator import init_orchestrator
|
||
from core.websocket_connection_flow import connect_flow, receive_messages_flow
|
||
from core.websocket_send_flow import send_message_flow
|
||
from utils.observability import emit_activity
|
||
|
||
# Module logger. It uses the fixed name "cs_agent" rather than __name__, presumably so
# records land in the same logger as the rest of the agent code (verify).
logger = logging.getLogger(name="cs_agent")
|
||
|
||
class QingjianAPIClient:
    """Refactored Qingjian API client (full protocol-replica edition).

    Holds WebSocket client state and delegates the real work to helper flows:
    connecting/receiving (``connect_flow`` / ``receive_messages_flow``), sending
    (``send_message_flow``) and message routing (the orchestrator built by
    ``init_orchestrator``).

    Several worker processes can share the traffic: each gets a ``worker_id``
    in ``[0, worker_count)`` and only handles the customers that hash to it
    (see ``_should_handle``).
    """

    def __init__(self, uri=None, enable_agent: bool = True, worker_id: int = -1, worker_count: int = 1):
        """Create the client and its orchestrator.

        Args:
            uri: WebSocket URI. Falls back to ``config.config.QINGJIAN_WS_URI``
                when falsy.
            enable_agent: Stored as ``self.enable_agent``. How it is consumed
                (by the flows/orchestrator) is not visible in this class.
            worker_id: This process's shard index. A negative value disables
                sharding, so every message is handled.
            worker_count: Total number of shards. Values <= 1 also disable
                sharding.
        """
        # NOTE(review): imported inside __init__ rather than at module level, presumably
        # to defer config loading / avoid an import cycle. Verify before moving it.
        from config.config import QINGJIAN_WS_URI

        self.uri = uri or QINGJIAN_WS_URI
        self.websocket = None  # no connection yet; presumably set by connect_flow
        self.running = True
        self.logger = logger
        self.enable_agent = enable_agent

        # Multi-process sharding configuration.
        self.worker_id = worker_id
        self.worker_count = worker_count
        if self.worker_id >= 0:
            self.logger.info(f"[WebSocket] 启用分片模式: Worker {self.worker_id}/{self.worker_count}")
            if self.worker_count > 1 and self.worker_id >= self.worker_count:
                # _should_handle compares hash % worker_count with worker_id, and
                # routes id-less messages to worker 0. An out-of-range id therefore
                # never matches, and this worker would silently handle nothing.
                self.logger.warning(
                    f"[WebSocket] worker_id={self.worker_id} 超出范围 "
                    f"(worker_count={self.worker_count}),该进程不会处理任何消息"
                )

        # Build the orchestrator (the new architecture's central dispatcher). It is
        # handed this client, presumably so it can call back into send().
        self.orchestrator = init_orchestrator(ws_client=self)
        self.logger.info("[WebSocket] 新架构 Orchestrator 已就绪。")

    def _activity_log(self, event: str, **kwargs):
        """Forward an activity event and its extra fields to ``emit_activity``."""
        emit_activity(self.logger, event=event, **kwargs)

    async def connect(self):
        """Open the WebSocket connection via ``connect_flow``."""
        await connect_flow(self)

    async def receive_messages(self):
        """Run the receive loop implemented by ``receive_messages_flow``."""
        await receive_messages_flow(self)

    def _should_handle(self, customer_id: str) -> bool:
        """Sharding decision: does this worker own ``customer_id``?

        - Sharding disabled (``worker_id < 0`` or ``worker_count <= 1``): always True.
        - Empty/missing customer id: only worker 0 handles it, so such messages
          are not processed by several workers.
        - Otherwise: the owner is ``int(md5(customer_id), 16) % worker_count``.
        """
        if self.worker_id < 0 or self.worker_count <= 1:
            return True

        if not customer_id:
            return self.worker_id == 0

        import hashlib

        # Use a stable digest rather than the built-in hash(). str hashes are
        # salted per process, so separate workers would disagree on ownership.
        digest = hashlib.md5(str(customer_id).encode("utf-8")).hexdigest()
        return int(digest, 16) % self.worker_count == self.worker_id

    @staticmethod
    def _preview_raw(message, limit: int = 300) -> str:
        """Return ``message`` as one log-friendly line, truncated to ``limit`` chars plus "..."."""
        preview = str(message).replace("\n", "\\n")
        if len(preview) > limit:
            preview = preview[:limit] + "..."
        return preview

    async def handle_message(self, message):
        """Parse one raw WebSocket frame and hand it to the orchestrator.

        - Frames that are not valid JSON, or whose JSON is not an object, are
          logged and dropped.
        - Frames owned by another worker (see ``_should_handle``) are ignored.
        - Processing errors are logged with a traceback instead of being raised,
          presumably so one bad frame does not stop the caller's receive loop.
        """
        try:
            data = json.loads(message)
        except Exception as e:
            # Malformed input: the raw payload is the useful diagnostic, not a traceback.
            self.logger.error(f"[WebSocket] 处理消息异常: {e} raw={self._preview_raw(message)}")
            return

        if not isinstance(data, dict):
            # Valid JSON but not an object (list, string, number...): there is no
            # customer id to read and nothing for the orchestrator to route.
            self.logger.warning(f"[WebSocket] 忽略非 JSON 对象消息 raw={self._preview_raw(message)}")
            return

        try:
            # Extract the customer id up front so the shard check happens before any processing.
            customer_id = str(data.get("cy_id") or data.get("from_id") or "")
            if not self._should_handle(customer_id):
                return

            await self.orchestrator.on_raw_message_received(platform="qianniu", raw_data=data)
        except Exception as e:
            # Failures here come from routing/processing code; keep the traceback.
            self.logger.error(
                f"[WebSocket] 处理消息异常: {e} raw={self._preview_raw(message)}",
                exc_info=True,
            )

    async def send(self, customer_id: str, acc_id: str, acc_type: str, content: str, msg_type: int = 0):
        """Send ``content`` to a customer.

        The payload layout replicates the legacy protocol exactly (see
        legacy/websocket_outbound_flow.py).

        Args:
            customer_id: Target customer. Written to both ``from_id`` and ``cy_id``.
            acc_id: Account id, written to ``acc_id``.
            acc_type: Account type, written to ``acc_type``.
            content: Message text, written to ``msg``.
            msg_type: Written to ``msg_type`` unchanged (defaults to 0).
        """
        # Note: from_id deliberately carries customer_id. This is a quirk that the
        # reverse-engineered interface requires, not a mistake.
        msg_payload = {
            "msg_id": "",
            "acc_id": acc_id,
            "msg": content,
            "from_id": customer_id,
            "from_name": "",
            "cy_id": customer_id,
            "acc_type": acc_type,
            "msg_type": msg_type,
            "cy_name": "",
        }
        await self.send_message(msg_payload)

    async def send_message(self, message_dict: dict):
        """Low-level send of an already-built payload via ``send_message_flow``."""
        await send_message_flow(self, message_dict)

    def get_time(self):
        """Return the current local wall-clock time formatted as ``HH:MM:SS``."""
        return datetime.now().strftime("%H:%M:%S")

    async def run(self):
        """Connect, then process incoming messages until the receive flow returns."""
        await self.connect()
        await self.receive_messages()
|
||
|
||
if __name__ == "__main__":
    # Standalone entry point: build a client from the default configuration and
    # run it until it is interrupted with Ctrl+C.
    # NOTE(review): no logging configuration is visible in this file. Unless an
    # imported module sets one up, INFO messages (including "已停止") may not be shown.
    qingjian_client = QingjianAPIClient()
    try:
        asyncio.run(qingjian_client.run())
    except KeyboardInterrupt:
        logger.info("已停止")
|