fix: prevent duplicate replies in multi-worker routing
This commit is contained in:
@@ -5,6 +5,7 @@ import re
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
import hashlib
|
||||
from collections import deque
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
@@ -108,6 +109,7 @@ class QingjianAPIClient:
|
||||
# 多进程分片支持
|
||||
self.shard_keys: set = set() # 本进程负责的客户 key 集合
|
||||
self.worker_id = int(os.getenv('AI_CS_WORKER_ID', '0'))
|
||||
self.worker_count = max(1, int(os.getenv('AI_CS_WORKER_COUNT', '1')))
|
||||
|
||||
# 初始化 Agent
|
||||
if self.enable_agent:
|
||||
@@ -168,6 +170,22 @@ class QingjianAPIClient:
|
||||
self._customer_locks[key] = asyncio.Lock()
|
||||
return self._customer_locks[key]
|
||||
|
||||
def _is_owned_by_this_worker(self, customer_key: str) -> bool:
    """
    Decide whether this worker process should handle ``customer_key``.

    Fallback routing for multi-process deployments:

    - When an explicit shard set (``self.shard_keys``) is non-empty,
      ownership is plain membership in that set.
    - Otherwise, when more than one worker is configured, the key is
      hashed onto a fixed worker index so that only one process
      handles a given customer.

    Returns:
        True if this worker owns the key, False otherwise.
    """
    explicit_shard = self.shard_keys
    if explicit_shard:
        # An explicit assignment always wins over hash routing.
        return customer_key in explicit_shard

    total_workers = self.worker_count
    if total_workers <= 1:
        # Single-worker setup: everything belongs to this process.
        return True

    try:
        # The first 8 hex digits of the MD5 digest give a stable 32-bit
        # bucket number; MD5 is used for distribution only, not security.
        digest_prefix = hashlib.md5(customer_key.encode("utf-8")).hexdigest()[:8]
        bucket = int(digest_prefix, 16) % total_workers
        return bucket == self.worker_id
    except Exception:
        # Best-effort fallback: if the key cannot be hashed, route it to
        # worker 0 so exactly one worker still picks it up.
        return self.worker_id == 0
|
||||
|
||||
async def _agent_reply_serialized(self, data: dict):
|
||||
"""同客户串行 + 全局并发限制,再执行 agent_reply"""
|
||||
key = self._customer_key(data)
|
||||
@@ -208,12 +226,10 @@ class QingjianAPIClient:
|
||||
try:
|
||||
data = json.loads(message)
|
||||
|
||||
# 多进程分片检查:只处理分配给本进程的客户
|
||||
if self.shard_keys:
|
||||
customer_key = self._customer_key(data)
|
||||
if customer_key not in self.shard_keys:
|
||||
# 不属于本进程的客户,跳过
|
||||
return
|
||||
# 多进程分片检查:确保同一客户只由一个 worker 处理
|
||||
customer_key = self._customer_key(data)
|
||||
if not self._is_owned_by_this_worker(customer_key):
|
||||
return
|
||||
|
||||
timestamp = self.get_time()
|
||||
|
||||
|
||||
@@ -24,9 +24,10 @@ logger = logging.getLogger(__name__)
|
||||
class WorkerProcess:
|
||||
"""工作进程"""
|
||||
|
||||
def __init__(self, worker_id: int, shard_keys: List[str], enable_agent: bool = True):
|
||||
def __init__(self, worker_id: int, shard_keys: List[str], num_workers: int, enable_agent: bool = True):
|
||||
self.worker_id = worker_id
|
||||
self.shard_keys = shard_keys
|
||||
self.num_workers = max(1, int(num_workers))
|
||||
self.enable_agent = enable_agent
|
||||
self.process = None
|
||||
|
||||
@@ -34,17 +35,18 @@ class WorkerProcess:
|
||||
"""启动工作进程"""
|
||||
self.process = Process(
|
||||
target=self._run,
|
||||
args=(self.worker_id, self.shard_keys, self.enable_agent),
|
||||
args=(self.worker_id, self.shard_keys, self.num_workers, self.enable_agent),
|
||||
name=f"ai-cs-worker-{self.worker_id}"
|
||||
)
|
||||
self.process.start()
|
||||
logger.info(f"Worker {self.worker_id} 启动 (PID: {self.process.pid})")
|
||||
|
||||
def _run(self, worker_id: int, shard_keys: List[str], enable_agent: bool):
|
||||
def _run(self, worker_id: int, shard_keys: List[str], num_workers: int, enable_agent: bool):
|
||||
"""工作进程入口"""
|
||||
try:
|
||||
# 设置进程环境变量
|
||||
os.environ['AI_CS_WORKER_ID'] = str(worker_id)
|
||||
os.environ['AI_CS_WORKER_COUNT'] = str(max(1, int(num_workers)))
|
||||
os.environ['AI_CS_SHARD_KEYS'] = ','.join(shard_keys)
|
||||
|
||||
# 导入并启动 WebSocket 客户端
|
||||
@@ -120,6 +122,7 @@ class Coordinator:
|
||||
worker = WorkerProcess(
|
||||
worker_id=worker_id,
|
||||
shard_keys=shards.get(worker_id, []),
|
||||
num_workers=self.num_workers,
|
||||
enable_agent=self.enable_agent
|
||||
)
|
||||
worker.start()
|
||||
|
||||
Reference in New Issue
Block a user