From baa1ca8c648cd80f3268ae860403f853424f7397 Mon Sep 17 00:00:00 2001
From: jimi <1847930177@qq.com>
Date: Sat, 28 Feb 2026 19:11:32 +0800
Subject: [PATCH] fix: prevent duplicate replies in multi-worker routing

---
Review notes (this region is ignored by git-am; not part of the commit):

* Formatting: this patch had been collapsed onto two physical lines. The
  line structure is restored here. Leading indentation, inline-comment
  spacing and the position of blank context lines were NOT recoverable
  from the collapsed text; they are reconstructed assuming 4-space
  indentation, with blank context lines placed so every hunk matches its
  @@ header counts. Run `git apply --check` against the tree, or
  regenerate with `git format-patch`, before relying on it.

* What the patch does:
  core/websocket_client.py
    . reads AI_CS_WORKER_COUNT into self.worker_count, clamped to >= 1;
    . adds _is_owned_by_this_worker(customer_key): when self.shard_keys
      is non-empty, ownership is membership in that set; otherwise a
      single worker owns everything; otherwise the first 8 hex digits of
      md5(customer_key) modulo worker_count must equal worker_id; if
      that computation raises, only worker 0 claims the key;
    . the message handler now runs this check for every message instead
      of only when shard_keys is set.
  scripts/multi_process_launcher.py
    . WorkerProcess takes num_workers, forwards it to _run, and exports
      it as AI_CS_WORKER_COUNT; Coordinator passes self.num_workers.

* Points to confirm before merging:
    . WorkerProcess.__init__ gains a required parameter placed before
      enable_agent. The only call site visible in this patch uses
      keywords; any positional caller passing enable_agent third would
      now bind it to num_workers. Consider a trailing `num_workers=1`.
    . int(os.getenv('AI_CS_WORKER_COUNT', '1')) raises ValueError when
      the variable is set to a non-numeric value.
    . If AI_CS_WORKER_ID >= AI_CS_WORKER_COUNT, no key can hash to that
      worker in the fallback path.
    . `except Exception` in the hash fallback is broad; presumably meant
      for a non-str customer_key. Narrowing it would avoid masking
      unrelated errors.

 core/websocket_client.py          | 28 ++++++++++++++++++++++------
 scripts/multi_process_launcher.py |  9 ++++++---
 2 files changed, 28 insertions(+), 9 deletions(-)

diff --git a/core/websocket_client.py b/core/websocket_client.py
index c504e6b..bc876e4 100755
--- a/core/websocket_client.py
+++ b/core/websocket_client.py
@@ -5,6 +5,7 @@ import re
 import logging
 import random
 import time
+import hashlib
 from collections import deque
 from datetime import datetime
 from pathlib import Path
@@ -108,6 +109,7 @@ class QingjianAPIClient:
         # 多进程分片支持
         self.shard_keys: set = set()  # 本进程负责的客户 key 集合
         self.worker_id = int(os.getenv('AI_CS_WORKER_ID', '0'))
+        self.worker_count = max(1, int(os.getenv('AI_CS_WORKER_COUNT', '1')))
 
         # 初始化 Agent
         if self.enable_agent:
@@ -168,6 +170,22 @@ class QingjianAPIClient:
             self._customer_locks[key] = asyncio.Lock()
         return self._customer_locks[key]
 
+    def _is_owned_by_this_worker(self, customer_key: str) -> bool:
+        """
+        多进程兜底路由:
+        - 若显式分片存在,用显式分片;
+        - 否则按 customer_key 哈希到固定 worker,避免多进程重复处理同一消息。
+        """
+        if self.shard_keys:
+            return customer_key in self.shard_keys
+        if self.worker_count <= 1:
+            return True
+        try:
+            h = int(hashlib.md5(customer_key.encode("utf-8")).hexdigest()[:8], 16)
+            return (h % self.worker_count) == self.worker_id
+        except Exception:
+            return self.worker_id == 0
+
     async def _agent_reply_serialized(self, data: dict):
         """同客户串行 + 全局并发限制,再执行 agent_reply"""
         key = self._customer_key(data)
@@ -208,12 +226,10 @@ class QingjianAPIClient:
         try:
             data = json.loads(message)
 
-            # 多进程分片检查:只处理分配给本进程的客户
-            if self.shard_keys:
-                customer_key = self._customer_key(data)
-                if customer_key not in self.shard_keys:
-                    # 不属于本进程的客户,跳过
-                    return
+            # 多进程分片检查:确保同一客户只由一个 worker 处理
+            customer_key = self._customer_key(data)
+            if not self._is_owned_by_this_worker(customer_key):
+                return
 
             timestamp = self.get_time()
 
diff --git a/scripts/multi_process_launcher.py b/scripts/multi_process_launcher.py
index f944537..46cfe0b 100644
--- a/scripts/multi_process_launcher.py
+++ b/scripts/multi_process_launcher.py
@@ -24,9 +24,10 @@ logger = logging.getLogger(__name__)
 class WorkerProcess:
     """工作进程"""
 
-    def __init__(self, worker_id: int, shard_keys: List[str], enable_agent: bool = True):
+    def __init__(self, worker_id: int, shard_keys: List[str], num_workers: int, enable_agent: bool = True):
         self.worker_id = worker_id
         self.shard_keys = shard_keys
+        self.num_workers = max(1, int(num_workers))
         self.enable_agent = enable_agent
         self.process = None
 
@@ -34,17 +35,18 @@ class WorkerProcess:
         """启动工作进程"""
         self.process = Process(
             target=self._run,
-            args=(self.worker_id, self.shard_keys, self.enable_agent),
+            args=(self.worker_id, self.shard_keys, self.num_workers, self.enable_agent),
             name=f"ai-cs-worker-{self.worker_id}"
         )
         self.process.start()
         logger.info(f"Worker {self.worker_id} 启动 (PID: {self.process.pid})")
 
-    def _run(self, worker_id: int, shard_keys: List[str], enable_agent: bool):
+    def _run(self, worker_id: int, shard_keys: List[str], num_workers: int, enable_agent: bool):
         """工作进程入口"""
         try:
             # 设置进程环境变量
             os.environ['AI_CS_WORKER_ID'] = str(worker_id)
+            os.environ['AI_CS_WORKER_COUNT'] = str(max(1, int(num_workers)))
             os.environ['AI_CS_SHARD_KEYS'] = ','.join(shard_keys)
 
             # 导入并启动 WebSocket 客户端
@@ -120,6 +122,7 @@ class Coordinator:
             worker = WorkerProcess(
                 worker_id=worker_id,
                 shard_keys=shards.get(worker_id, []),
+                num_workers=self.num_workers,
                 enable_agent=self.enable_agent
             )
             worker.start()