Files
tw/legacy/scripts/multi_process_launcher.py

210 lines
6.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""
多进程异步并行启动器
按客户 ID hash 分配到不同进程,实现真正的并行处理
"""
import os
import sys
import signal
import logging
from multiprocessing import Process, cpu_count
from typing import List, Dict
import hashlib
# Make the parent of this script's directory importable so that the
# project-local ``core`` package (imported by the workers) resolves when the
# script is executed directly.
sys.path.insert(
    0,
    os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)),
)
logging.basicConfig(
    format='[%(asctime)s] %(levelname)s: %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
class WorkerProcess:
    """Handle for one worker child process serving a subset of customers.

    Each instance holds at most one ``multiprocessing.Process`` at a time in
    ``self.process``; ``start`` may be called again after that child has
    exited to launch a replacement process.
    """

    def __init__(self, worker_id: int, shard_keys: List[str], num_workers: int, enable_agent: bool = True):
        """
        Args:
            worker_id: index of this worker; also used in the process name.
            shard_keys: customer keys assigned to this worker; exported to the
                child as comma-joined ``AI_CS_SHARD_KEYS`` and set on the
                client as a set.
            num_workers: total number of workers (clamped to at least 1).
            enable_agent: forwarded to ``QingjianAPIClient(enable_agent=...)``.
        """
        self.worker_id = worker_id
        self.shard_keys = shard_keys
        self.num_workers = max(1, int(num_workers))
        self.enable_agent = enable_agent
        self.process = None  # current child Process, or None before start()

    def start(self):
        """Launch (or relaunch) the child process running ``_run``."""
        self.process = Process(
            target=self._run,
            args=(self.worker_id, self.shard_keys, self.num_workers, self.enable_agent),
            name=f"ai-cs-worker-{self.worker_id}"
        )
        self.process.start()
        logger.info(f"Worker {self.worker_id} 启动 (PID: {self.process.pid})")

    def _run(self, worker_id: int, shard_keys: List[str], num_workers: int, enable_agent: bool):
        """Child-process entry point (executes inside the worker process).

        Publishes the shard assignment through environment variables, builds
        the WebSocket client and runs ``client.connect()`` on a fresh asyncio
        event loop until it returns. If the client raises, the traceback is
        logged and the child exits with status 1 so the failure shows up in
        ``Process.exitcode``.
        """
        # With the "fork" start method the child inherits the parent's Python
        # signal handlers. A worker forked after the parent installed its own
        # SIGINT/SIGTERM handlers would run that shutdown logic instead of
        # exiting on terminate(). Restore the interpreter defaults so SIGTERM
        # kills the worker and Ctrl-C raises KeyboardInterrupt as intended.
        signal.signal(signal.SIGTERM, signal.SIG_DFL)
        signal.signal(signal.SIGINT, signal.default_int_handler)
        try:
            # Expose the shard assignment to code running in this process.
            os.environ['AI_CS_WORKER_ID'] = str(worker_id)
            os.environ['AI_CS_WORKER_COUNT'] = str(max(1, int(num_workers)))
            os.environ['AI_CS_SHARD_KEYS'] = ','.join(shard_keys)
            # Imported lazily, inside the child process only.
            from core.websocket_client import QingjianAPIClient
            logger.info(f"Worker {worker_id} 初始化 Agent...")
            client = QingjianAPIClient(enable_agent=enable_agent)
            # Restrict this client to the customers assigned to this worker.
            client.shard_keys = set(shard_keys)
            logger.info(f"Worker {worker_id} 开始处理消息...")
            import asyncio
            asyncio.run(client.connect())
        except KeyboardInterrupt:
            logger.info(f"Worker {worker_id} 收到退出信号")
        except Exception as e:
            logger.exception(f"Worker {worker_id} 异常:{e}")
            # Non-zero exit code so the parent can tell a crash from a clean exit.
            sys.exit(1)

    def stop(self, timeout: float = 5.0):
        """Terminate the child, escalating to ``kill()`` if it does not exit.

        Does nothing when no child was started or it has already exited.

        Args:
            timeout: seconds to wait after ``terminate()`` (and again after
                ``kill()``) for the child to exit.
        """
        proc = self.process
        if proc is None or not proc.is_alive():
            return
        proc.terminate()
        proc.join(timeout)
        if proc.is_alive():
            # terminate() sends SIGTERM, which the child may handle or block;
            # kill() uses SIGKILL on POSIX, which cannot be ignored. Leaving a
            # non-daemon child running would make the parent hang at exit.
            logger.warning(f"Worker {self.worker_id} 未在 {timeout}s 内退出,强制 kill")
            proc.kill()
            proc.join(timeout)
        logger.info(f"Worker {self.worker_id} 已停止")
class Coordinator:
    """Coordinator: launches one WorkerProcess per shard and keeps them alive.

    ``start`` blocks in a monitoring loop that restarts dead workers until
    SIGINT/SIGTERM (or ``stop``) clears ``running``. Every worker still alive
    is stopped before ``start`` returns or raises.
    """

    # Default number of seconds between two health checks of the workers.
    MONITOR_INTERVAL = 10.0

    def __init__(self, num_workers: int = None, enable_agent: bool = True):
        """
        Args:
            num_workers: number of worker processes; ``None`` or ``0`` selects
                ``max(2, cpu_count())``.
            enable_agent: forwarded to every WorkerProcess.

        Raises:
            ValueError: if ``num_workers`` is negative (it would otherwise
                yield zero workers and a monitor loop that never ends).
        """
        if num_workers is not None and num_workers < 0:
            raise ValueError(
                f"num_workers must be a positive integer (or None/0 for the default), got {num_workers}"
            )
        self.num_workers = num_workers or max(2, cpu_count())
        self.workers: List[WorkerProcess] = []
        self.running = False
        self._stopping = False
        self.enable_agent = enable_agent

    def _get_shard_key(self, acc_id: str, from_id: str) -> int:
        """Return the shard index in ``[0, num_workers)`` for a shop/customer pair.

        Uses MD5 of ``"acc_id:from_id"`` rather than the built-in ``hash``,
        whose value for ``str`` is randomized per process, so the mapping is
        stable across processes and restarts.
        """
        key = f"{acc_id}:{from_id}"
        hash_value = int(hashlib.md5(key.encode()).hexdigest(), 16)
        return hash_value % self.num_workers

    def _load_customer_shards(self) -> Dict[int, List[str]]:
        """Load the customer-to-shard assignment.

        Returns:
            ``{shard_id: [customer_key1, customer_key2, ...]}`` with one
            (currently always empty) list per worker.
        """
        # Placeholder: the customer list should eventually come from the
        # database (active customers only).
        shards = {i: [] for i in range(self.num_workers)}
        # TODO: load the active customer list from the database, e.g.:
        # customers = db.query(...).all()
        # for customer in customers:
        #     shard_id = self._get_shard_key(customer.acc_id, customer.from_id)
        #     shards[shard_id].append(f"{customer.acc_id}:{customer.from_id}")
        logger.info(f"已加载 {sum(len(v) for v in shards.values())} 个客户分片")
        return shards

    def start(self):
        """Launch all workers, install signal handlers, then block monitoring them.

        Whatever ends the run (a stop signal, KeyboardInterrupt during
        startup, or an unexpected error), any worker still alive is stopped
        before this method returns or re-raises.
        """
        logger.info(f"启动协调器,工作进程数:{self.num_workers}")
        shards = self._load_customer_shards()
        try:
            for worker_id in range(self.num_workers):
                worker = WorkerProcess(
                    worker_id=worker_id,
                    shard_keys=shards.get(worker_id, []),
                    num_workers=self.num_workers,
                    enable_agent=self.enable_agent
                )
                worker.start()
                self.workers.append(worker)
            self.running = True
            # Installed after the initial workers are started, so children
            # forked above keep the default signal handlers.
            signal.signal(signal.SIGINT, self._signal_handler)
            signal.signal(signal.SIGTERM, self._signal_handler)
            self._monitor_workers()
        finally:
            # multiprocessing joins non-daemon children at interpreter exit,
            # so a worker left running here would make the parent hang.
            # stop() returns immediately when everything is already stopped.
            self.stop()

    def _monitor_workers(self, interval: float = None):
        """Restart dead workers every ``interval`` seconds while ``running``.

        Args:
            interval: seconds between checks; defaults to ``MONITOR_INTERVAL``.
        """
        import time
        if interval is None:
            interval = self.MONITOR_INTERVAL
        while self.running:
            for worker in self.workers:
                # A signal handler may have stopped everything mid-scan; do not
                # resurrect workers that were just terminated on purpose.
                if not self.running:
                    break
                if worker.process and not worker.process.is_alive():
                    logger.warning(f"Worker {worker.worker_id} 已退出,尝试重启...")
                    worker.start()
            # Sleep in short steps: after a signal handler returns, time.sleep
            # resumes the remaining time (PEP 475), so a single long sleep
            # would delay shutdown by up to a full interval.
            remaining = interval
            while self.running and remaining > 0:
                step = min(1.0, remaining)
                time.sleep(step)
                remaining -= step

    def _signal_handler(self, signum, frame):
        """SIGINT/SIGTERM handler: stop all workers once; repeated signals are ignored."""
        if self._stopping:
            return
        self._stopping = True
        logger.info(f"收到信号 {signum},正在停止所有工作进程...")
        self.stop()

    def stop(self):
        """Stop all workers and end the monitor loop.

        A no-op once a stop has completed and no worker is alive any more.
        """
        if self._stopping and not self.running and not any(w.process and w.process.is_alive() for w in self.workers):
            return
        self._stopping = True
        self.running = False
        for worker in self.workers:
            worker.stop()
        logger.info("所有工作进程已停止")
def main():
    """Command-line entry point: parse options and run the Coordinator.

    Blocks until the coordinator shuts down. On an unexpected error the
    traceback is logged, any started workers are stopped and the process
    exits with status 1.
    """
    import argparse
    parser = argparse.ArgumentParser(description='AI 客服多进程启动器')
    parser.add_argument(
        '--workers',
        type=int,
        default=None,
        help='工作进程数(默认 CPU 核心数,至少 2)'
    )
    parser.add_argument(
        '--no-agent',
        action='store_true',
        help='禁用 Agent(以 enable_agent=False 启动各工作进程)'
    )
    args = parser.parse_args()
    # 0, like omitting the flag, selects the default worker count; negative
    # values are rejected with a usage message instead of yielding no workers.
    if args.workers is not None and args.workers < 0:
        parser.error('--workers 不能为负数')
    logger.info("=" * 60)
    logger.info("AI 客服系统 - 多进程异步并行模式")
    logger.info("=" * 60)
    coordinator = Coordinator(num_workers=args.workers, enable_agent=not args.no_agent)
    try:
        coordinator.start()
    except KeyboardInterrupt:
        logger.info("收到退出信号")
        coordinator.stop()
    except Exception as e:
        logger.exception(f"启动失败:{e}")
        try:
            # Workers are non-daemon processes that multiprocessing joins at
            # interpreter exit; stop them first or sys.exit would hang.
            coordinator.stop()
        finally:
            sys.exit(1)


# Guard required for multiprocessing's spawn/forkserver start methods, which
# re-import this module in every child process.
if __name__ == "__main__":
    main()