fix: prevent duplicate replies in multi-worker routing

This commit is contained in:
2026-02-28 19:11:32 +08:00
parent cbe1f19311
commit baa1ca8c64
2 changed files with 28 additions and 9 deletions

View File

@@ -5,6 +5,7 @@ import re
import logging import logging
import random import random
import time import time
import hashlib
from collections import deque from collections import deque
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
@@ -108,6 +109,7 @@ class QingjianAPIClient:
# 多进程分片支持 # 多进程分片支持
self.shard_keys: set = set() # 本进程负责的客户 key 集合 self.shard_keys: set = set() # 本进程负责的客户 key 集合
self.worker_id = int(os.getenv('AI_CS_WORKER_ID', '0')) self.worker_id = int(os.getenv('AI_CS_WORKER_ID', '0'))
self.worker_count = max(1, int(os.getenv('AI_CS_WORKER_COUNT', '1')))
# 初始化 Agent # 初始化 Agent
if self.enable_agent: if self.enable_agent:
@@ -168,6 +170,22 @@ class QingjianAPIClient:
self._customer_locks[key] = asyncio.Lock() self._customer_locks[key] = asyncio.Lock()
return self._customer_locks[key] return self._customer_locks[key]
def _is_owned_by_this_worker(self, customer_key: str) -> bool:
    """Decide whether this worker process should handle ``customer_key``.

    Fallback routing for multi-process deployments:

    - When an explicit shard set is configured (``self.shard_keys`` is
      non-empty), ownership is plain membership in that set.
    - Otherwise, with more than one worker, the key is hashed onto a fixed
      worker so the same message is not processed by several processes.

    Returns:
        True if this worker owns the customer key, False otherwise.
    """
    explicit_shards = self.shard_keys
    if explicit_shards:
        return customer_key in explicit_shards

    if self.worker_count <= 1:
        # A single worker owns every customer.
        return True

    try:
        # The first four digest bytes, read big-endian, are the same number
        # as the first eight hex digits of the MD5 hex digest.
        digest = hashlib.md5(customer_key.encode("utf-8")).digest()
        bucket = int.from_bytes(digest[:4], "big") % self.worker_count
        return bucket == self.worker_id
    except Exception:
        # Best effort: if the key cannot be hashed, only worker 0 claims it.
        return self.worker_id == 0
async def _agent_reply_serialized(self, data: dict): async def _agent_reply_serialized(self, data: dict):
"""同客户串行 + 全局并发限制,再执行 agent_reply""" """同客户串行 + 全局并发限制,再执行 agent_reply"""
key = self._customer_key(data) key = self._customer_key(data)
@@ -208,11 +226,9 @@ class QingjianAPIClient:
try: try:
data = json.loads(message) data = json.loads(message)
# 多进程分片检查:只处理分配给本进程的客户 # 多进程分片检查:确保同一客户只由一个 worker 处理
if self.shard_keys:
customer_key = self._customer_key(data) customer_key = self._customer_key(data)
if customer_key not in self.shard_keys: if not self._is_owned_by_this_worker(customer_key):
# 不属于本进程的客户,跳过
return return
timestamp = self.get_time() timestamp = self.get_time()

View File

@@ -24,9 +24,10 @@ logger = logging.getLogger(__name__)
class WorkerProcess: class WorkerProcess:
"""工作进程""" """工作进程"""
def __init__(self, worker_id: int, shard_keys: List[str], enable_agent: bool = True): def __init__(self, worker_id: int, shard_keys: List[str], num_workers: int, enable_agent: bool = True):
self.worker_id = worker_id self.worker_id = worker_id
self.shard_keys = shard_keys self.shard_keys = shard_keys
self.num_workers = max(1, int(num_workers))
self.enable_agent = enable_agent self.enable_agent = enable_agent
self.process = None self.process = None
@@ -34,17 +35,18 @@ class WorkerProcess:
"""启动工作进程""" """启动工作进程"""
self.process = Process( self.process = Process(
target=self._run, target=self._run,
args=(self.worker_id, self.shard_keys, self.enable_agent), args=(self.worker_id, self.shard_keys, self.num_workers, self.enable_agent),
name=f"ai-cs-worker-{self.worker_id}" name=f"ai-cs-worker-{self.worker_id}"
) )
self.process.start() self.process.start()
logger.info(f"Worker {self.worker_id} 启动 (PID: {self.process.pid})") logger.info(f"Worker {self.worker_id} 启动 (PID: {self.process.pid})")
def _run(self, worker_id: int, shard_keys: List[str], enable_agent: bool): def _run(self, worker_id: int, shard_keys: List[str], num_workers: int, enable_agent: bool):
"""工作进程入口""" """工作进程入口"""
try: try:
# 设置进程环境变量 # 设置进程环境变量
os.environ['AI_CS_WORKER_ID'] = str(worker_id) os.environ['AI_CS_WORKER_ID'] = str(worker_id)
os.environ['AI_CS_WORKER_COUNT'] = str(max(1, int(num_workers)))
os.environ['AI_CS_SHARD_KEYS'] = ','.join(shard_keys) os.environ['AI_CS_SHARD_KEYS'] = ','.join(shard_keys)
# 导入并启动 WebSocket 客户端 # 导入并启动 WebSocket 客户端
@@ -120,6 +122,7 @@ class Coordinator:
worker = WorkerProcess( worker = WorkerProcess(
worker_id=worker_id, worker_id=worker_id,
shard_keys=shards.get(worker_id, []), shard_keys=shards.get(worker_id, []),
num_workers=self.num_workers,
enable_agent=self.enable_agent enable_agent=self.enable_agent
) )
worker.start() worker.start()