fix: prevent duplicate replies in multi-worker routing
This commit is contained in:
@@ -5,6 +5,7 @@ import re
|
|||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
|
import hashlib
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -108,6 +109,7 @@ class QingjianAPIClient:
|
|||||||
# 多进程分片支持
|
# 多进程分片支持
|
||||||
self.shard_keys: set = set() # 本进程负责的客户 key 集合
|
self.shard_keys: set = set() # 本进程负责的客户 key 集合
|
||||||
self.worker_id = int(os.getenv('AI_CS_WORKER_ID', '0'))
|
self.worker_id = int(os.getenv('AI_CS_WORKER_ID', '0'))
|
||||||
|
self.worker_count = max(1, int(os.getenv('AI_CS_WORKER_COUNT', '1')))
|
||||||
|
|
||||||
# 初始化 Agent
|
# 初始化 Agent
|
||||||
if self.enable_agent:
|
if self.enable_agent:
|
||||||
@@ -168,6 +170,22 @@ class QingjianAPIClient:
|
|||||||
self._customer_locks[key] = asyncio.Lock()
|
self._customer_locks[key] = asyncio.Lock()
|
||||||
return self._customer_locks[key]
|
return self._customer_locks[key]
|
||||||
|
|
||||||
|
def _is_owned_by_this_worker(self, customer_key: str) -> bool:
    """Decide whether this worker process should handle ``customer_key``.

    Fallback routing for multi-process deployments, so that a given
    customer's messages are handled by a single worker. Rules, in order:

    1. If ``self.shard_keys`` is non-empty, ownership is simply membership
       in that explicit shard set.
    2. If ``self.worker_count`` is 1 or less, this worker owns every key.
    3. Otherwise the key is hashed (first 32 bits of its MD5 hex digest)
       modulo ``self.worker_count``; the worker whose ``self.worker_id``
       equals the result owns the key.

    Args:
        customer_key: Customer identifier. Non-string values are hashed via
            their ``str()`` form so they are still spread across workers.

    Returns:
        True if this worker should process messages for the key.
    """
    if self.shard_keys:
        return customer_key in self.shard_keys
    if self.worker_count <= 1:
        return True

    try:
        # MD5 serves only as a stable hash that gives the same value in
        # every process (the builtin hash() of str is salted per process);
        # it is not used for security.
        digest = hashlib.md5(str(customer_key).encode("utf-8")).hexdigest()
    except (TypeError, ValueError):
        # The key could not be hashed (e.g. UnicodeEncodeError on lone
        # surrogates, or MD5 rejected by the runtime). Fall back to a fixed
        # owner so exactly one worker still handles the message.
        return self.worker_id == 0
    return int(digest[:8], 16) % self.worker_count == self.worker_id
|
||||||
|
|
||||||
async def _agent_reply_serialized(self, data: dict):
|
async def _agent_reply_serialized(self, data: dict):
|
||||||
"""同客户串行 + 全局并发限制,再执行 agent_reply"""
|
"""同客户串行 + 全局并发限制,再执行 agent_reply"""
|
||||||
key = self._customer_key(data)
|
key = self._customer_key(data)
|
||||||
@@ -208,12 +226,10 @@ class QingjianAPIClient:
|
|||||||
try:
|
try:
|
||||||
data = json.loads(message)
|
data = json.loads(message)
|
||||||
|
|
||||||
# 多进程分片检查:只处理分配给本进程的客户
|
# 多进程分片检查:确保同一客户只由一个 worker 处理
|
||||||
if self.shard_keys:
|
customer_key = self._customer_key(data)
|
||||||
customer_key = self._customer_key(data)
|
if not self._is_owned_by_this_worker(customer_key):
|
||||||
if customer_key not in self.shard_keys:
|
return
|
||||||
# 不属于本进程的客户,跳过
|
|
||||||
return
|
|
||||||
|
|
||||||
timestamp = self.get_time()
|
timestamp = self.get_time()
|
||||||
|
|
||||||
|
|||||||
@@ -24,9 +24,10 @@ logger = logging.getLogger(__name__)
|
|||||||
class WorkerProcess:
|
class WorkerProcess:
|
||||||
"""工作进程"""
|
"""工作进程"""
|
||||||
|
|
||||||
def __init__(self, worker_id: int, shard_keys: List[str], enable_agent: bool = True):
|
def __init__(self, worker_id: int, shard_keys: List[str], num_workers: int, enable_agent: bool = True):
|
||||||
self.worker_id = worker_id
|
self.worker_id = worker_id
|
||||||
self.shard_keys = shard_keys
|
self.shard_keys = shard_keys
|
||||||
|
self.num_workers = max(1, int(num_workers))
|
||||||
self.enable_agent = enable_agent
|
self.enable_agent = enable_agent
|
||||||
self.process = None
|
self.process = None
|
||||||
|
|
||||||
@@ -34,17 +35,18 @@ class WorkerProcess:
|
|||||||
"""启动工作进程"""
|
"""启动工作进程"""
|
||||||
self.process = Process(
|
self.process = Process(
|
||||||
target=self._run,
|
target=self._run,
|
||||||
args=(self.worker_id, self.shard_keys, self.enable_agent),
|
args=(self.worker_id, self.shard_keys, self.num_workers, self.enable_agent),
|
||||||
name=f"ai-cs-worker-{self.worker_id}"
|
name=f"ai-cs-worker-{self.worker_id}"
|
||||||
)
|
)
|
||||||
self.process.start()
|
self.process.start()
|
||||||
logger.info(f"Worker {self.worker_id} 启动 (PID: {self.process.pid})")
|
logger.info(f"Worker {self.worker_id} 启动 (PID: {self.process.pid})")
|
||||||
|
|
||||||
def _run(self, worker_id: int, shard_keys: List[str], enable_agent: bool):
|
def _run(self, worker_id: int, shard_keys: List[str], num_workers: int, enable_agent: bool):
|
||||||
"""工作进程入口"""
|
"""工作进程入口"""
|
||||||
try:
|
try:
|
||||||
# 设置进程环境变量
|
# 设置进程环境变量
|
||||||
os.environ['AI_CS_WORKER_ID'] = str(worker_id)
|
os.environ['AI_CS_WORKER_ID'] = str(worker_id)
|
||||||
|
os.environ['AI_CS_WORKER_COUNT'] = str(max(1, int(num_workers)))
|
||||||
os.environ['AI_CS_SHARD_KEYS'] = ','.join(shard_keys)
|
os.environ['AI_CS_SHARD_KEYS'] = ','.join(shard_keys)
|
||||||
|
|
||||||
# 导入并启动 WebSocket 客户端
|
# 导入并启动 WebSocket 客户端
|
||||||
@@ -120,6 +122,7 @@ class Coordinator:
|
|||||||
worker = WorkerProcess(
|
worker = WorkerProcess(
|
||||||
worker_id=worker_id,
|
worker_id=worker_id,
|
||||||
shard_keys=shards.get(worker_id, []),
|
shard_keys=shards.get(worker_id, []),
|
||||||
|
num_workers=self.num_workers,
|
||||||
enable_agent=self.enable_agent
|
enable_agent=self.enable_agent
|
||||||
)
|
)
|
||||||
worker.start()
|
worker.start()
|
||||||
|
|||||||
Reference in New Issue
Block a user