import os
import pickle
import signal
import sys
import time
import weakref
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
from multiprocessing.process import BaseProcess
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import cloudpickle
import psutil
import zmq

from vllm.config import VllmConfig
from vllm.distributed import (destroy_distributed_environment,
                              destroy_model_parallel)
from vllm.distributed.device_communicators.shm_broadcast import (Handle,
                                                                 MessageQueue)
from vllm.executor.multiproc_worker_utils import (
    _add_prefix, set_multiprocessing_worker_envs)
from vllm.logger import init_logger
from vllm.utils import (get_distributed_init_method, get_mp_context,
                        get_open_port, get_open_zmq_ipc_path, zmq_socket_ctx)
from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)

POLLING_TIMEOUT_MS = 5000
POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000


class MultiprocExecutor(Executor):

    def __init__(self, vllm_config: VllmConfig) -> None:
        # Call self.shutdown at exit to clean up
        # and ensure workers will be terminated.
        self._finalizer = weakref.finalize(self, self.shutdown)

        # The child processes will send SIGUSR1 when unrecoverable
        # errors happen.
        def sigusr1_handler(signum, frame):
            logger.fatal(
                "MultiprocExecutor got fatal signal from worker processes, "
                "shutting down. See stack trace above for root cause issue.")
            # Propagate error up to parent process.
            parent_process = psutil.Process().parent()
            parent_process.send_signal(signal.SIGUSR1)
            self.shutdown()

        signal.signal(signal.SIGUSR1, sigusr1_handler)

        self.vllm_config = vllm_config
        self.parallel_config = vllm_config.parallel_config

        self.world_size = self.parallel_config.world_size
        tensor_parallel_size = self.parallel_config.tensor_parallel_size
        assert self.world_size == tensor_parallel_size, (
            f"world_size ({self.world_size}) must be equal to the "
            f"tensor_parallel_size ({tensor_parallel_size}). "
            f"Pipeline parallelism is not yet implemented in v1")

        # Set multiprocessing envs that are common to V0 and V1
        set_multiprocessing_worker_envs(self.parallel_config)

        # Multiprocessing-based executor does not support multi-node setting.
        # Since it only works for single node, we can use the loopback address
        # 127.0.0.1 for communication.
        distributed_init_method = get_distributed_init_method(
            "127.0.0.1", get_open_port())

        # Initialize worker and set up message queues for SchedulerOutputs
        # and ModelRunnerOutputs
        self.rpc_broadcast_mq = MessageQueue(self.world_size, self.world_size)
        scheduler_output_handle = self.rpc_broadcast_mq.export_handle()

        # Create workers
        self.workers: List[WorkerProcHandle] = []
        for rank in range(self.world_size):
            worker = WorkerProc.make_worker_process(vllm_config, rank, rank,
                                                    distributed_init_method,
                                                    scheduler_output_handle)
            self.workers.append(worker)

        # Ensure message queues are ready. Will deadlock if re-ordered.
        # Must be kept consistent with the WorkerProc.
        self.rpc_broadcast_mq.wait_until_ready()
        for w in self.workers:
            w.worker_response_mq.wait_until_ready()

    def initialize(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize the KV caches and begin the model execution loop of the
        underlying workers.
        """
        self.collective_rpc("initialize_cache", args=(kv_cache_config, ))
        self.collective_rpc("compile_or_warm_up_model")
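
    # A minimal lifecycle sketch for this executor (assuming a ready-made
    # `vllm_config`; in practice the v1 EngineCore drives these calls, and
    # how `kv_cache_config` is built from the spec and available memory is
    # an assumption here, not shown in this file):
    #
    #     executor = MultiprocExecutor(vllm_config)
    #     available_bytes = executor.determine_available_memory()
    #     kv_cache_spec = executor.get_kv_cache_spec()
    #     executor.initialize(kv_cache_config)
    #     output = executor.execute_model(scheduler_output)
    #     executor.shutdown()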
""" self.collective_rpc("initialize_cache", args=(kv_cache_config, )) self.collective_rpc("compile_or_warm_up_model") def determine_available_memory(self) -> int: """ Determine the available memory (in bytes) for KV cache by invoking the underlying worker. """ memory_sizes = self.collective_rpc("determine_available_memory") # Since we use a shared centralized controller, we take the minimum # memory size across all workers to make sure all the memory # operators can be applied to all workers. return min(memory_sizes) def get_kv_cache_spec(self) -> KVCacheSpec: """ Get all kv cache needed by the model by invoking the underlying worker. """ kv_cache_specs = self.collective_rpc("get_kv_cache_spec") assert all(s == kv_cache_specs[0] for s in kv_cache_specs) return kv_cache_specs[0] def collective_rpc(self, method: Union[str, Callable], timeout: Optional[float] = None, args: Tuple = (), kwargs: Optional[Dict] = None) -> List[Any]: """ Execute an RPC call on workers. Args: method: Name of the worker method to execute timeout: Maximum time in seconds to wait for execution. Rases a TimeoutError on timeout. None means wait indefinitely. args: Positional arguments to pass to the worker method kwargs: Keyword arguments to pass to the worker method Returns: List of results from each worker """ start_time = time.monotonic() kwargs = kwargs or {} try: if isinstance(method, str): send_method = method else: send_method = cloudpickle.dumps( method, protocol=pickle.HIGHEST_PROTOCOL) self.rpc_broadcast_mq.enqueue((send_method, args, kwargs)) responses = [None] * self.world_size for w in self.workers: dequeue_timeout = timeout - (time.monotonic() - start_time ) if timeout is not None else None status, result = w.worker_response_mq.dequeue( timeout=dequeue_timeout) if status != WorkerProc.ResponseStatus.SUCCESS: if isinstance(result, Exception): raise result else: raise RuntimeError("Worker failed") responses[w.rank] = result return responses except TimeoutError as e: raise TimeoutError(f"RPC call to {method} timed out.") from e except Exception as e: # Re-raise any other exceptions raise e def execute_model( self, scheduler_output, ) -> ModelRunnerOutput: model_output = self.collective_rpc("execute_model", args=(scheduler_output, ))[0] return model_output def profile(self, is_start: bool = True): self.collective_rpc("profile", args=(is_start, )) return def _ensure_worker_termination(self): """Ensure that all worker processes are terminated. Assumes workers have received termination requests. Waits for processing, then sends termination and kill signals if needed.""" def wait_for_termination(procs, timeout): if not time: # If we are in late stage shutdown, the interpreter may replace # `time` with `None`. 

    def _ensure_worker_termination(self):
        """Ensure that all worker processes are terminated. Assumes workers
        have received termination requests. Waits for processing, then sends
        termination and kill signals if needed."""

        def wait_for_termination(procs, timeout):
            if not time:
                # If we are in late stage shutdown, the interpreter may
                # replace `time` with `None`.
                return all(not proc.is_alive() for proc in procs)
            start_time = time.time()
            while time.time() - start_time < timeout:
                if all(not proc.is_alive() for proc in procs):
                    return True
                time.sleep(0.1)
            return False

        # Send SIGTERM if still running
        active_procs = [w.proc for w in self.workers if w.proc.is_alive()]
        for p in active_procs:
            p.terminate()
        if not wait_for_termination(active_procs, 4):
            # Send SIGKILL if still running
            active_procs = [p for p in active_procs if p.is_alive()]
            for p in active_procs:
                p.kill()

        self._cleanup_sockets()

    def _cleanup_sockets(self):
        for w in self.workers:
            # Remove the zmq ipc socket file
            socket_path = w.ready_path.replace("ipc://", "")
            if os and os.path.exists(socket_path):
                os.remove(socket_path)

    def shutdown(self):
        """Properly shut down the executor and its workers"""
        if not getattr(self, 'shutting_down', False):
            self.shutting_down = True
            for w in self.workers:
                w.worker_response_mq = None
            self._ensure_worker_termination()

        self.rpc_broadcast_mq = None

    def check_health(self) -> None:
        self.collective_rpc("check_health", timeout=10)


@dataclass
class WorkerProcHandle:
    proc: BaseProcess
    rank: int
    ready_path: str
    worker_response_mq: MessageQueue  # The worker process writes to this MQ


class WorkerProc:
    """Wrapper that runs one Worker in a separate process."""

    READY_STR = "READY"

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        input_shm_handle: Handle,
        ready_path: str,
    ):
        self.rank = rank
        wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank)
        # TODO: move `init_worker` to executor level as a collective rpc call
        all_kwargs: List[Dict] = [
            {} for _ in range(vllm_config.parallel_config.world_size)
        ]
        all_kwargs[rank] = {
            "vllm_config": vllm_config,
            "local_rank": local_rank,
            "rank": rank,
            "distributed_init_method": distributed_init_method,
        }
        wrapper.init_worker(all_kwargs)
        self.worker = wrapper.worker

        pid = os.getpid()
        _add_prefix(sys.stdout, f"VllmWorker rank={rank}", pid)
        _add_prefix(sys.stderr, f"VllmWorker rank={rank}", pid)

        # Initialize MessageQueue for receiving SchedulerOutput
        self.rpc_broadcast_mq = MessageQueue.create_from_handle(
            input_shm_handle, self.worker.rank)

        # Initializes a message queue for sending the model output
        self.worker_response_mq = MessageQueue(1, 1)
        worker_response_mq_handle = self.worker_response_mq.export_handle()

        # Send readiness signal to EngineCore process.
        with zmq_socket_ctx(ready_path, zmq.constants.PUSH) as ready_socket:
            payload = pickle.dumps(worker_response_mq_handle,
                                   protocol=pickle.HIGHEST_PROTOCOL)
            ready_socket.send_string(WorkerProc.READY_STR)
            ready_socket.send(payload)

        self.worker.init_device()
        self.worker.load_model()
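
    # The readiness signal above is a two-part ZMQ PUSH/PULL handshake with
    # the parent's wait_for_startup() below: first the READY string, then
    # the pickled Handle for this worker's response queue. A sketch of the
    # parent-side receive (mirroring wait_for_startup, minus the polling):
    #
    #     with zmq_socket_ctx(ready_path, zmq.constants.PULL) as sock:
    #         assert sock.recv_string() == WorkerProc.READY_STR
    #         handle = pickle.loads(sock.recv(copy=False).buffer)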
    @staticmethod
    def make_worker_process(
            vllm_config: VllmConfig,
            local_rank: int,
            rank: int,
            distributed_init_method: str,
            input_shm_handle,  # Receive SchedulerOutput
    ) -> WorkerProcHandle:
        context = get_mp_context()

        # ZMQ path for worker to send ready message and shm_broadcast handle
        # back to core process.
        ready_path = get_open_zmq_ipc_path()

        process_kwargs = {
            "vllm_config": vllm_config,
            "local_rank": local_rank,
            "rank": rank,
            "distributed_init_method": distributed_init_method,
            "input_shm_handle": input_shm_handle,
            "ready_path": ready_path,
        }
        # Run the Worker busy loop in a background process.
        proc = context.Process(target=WorkerProc.worker_main,
                               kwargs=process_kwargs,
                               daemon=True)
        proc.start()

        # Wait for startup
        worker_response_mq_handle = WorkerProc.wait_for_startup(
            proc, ready_path)

        worker_response_mq = MessageQueue.create_from_handle(
            worker_response_mq_handle, 0)

        return WorkerProcHandle(proc, rank, ready_path, worker_response_mq)

    def shutdown(self):
        self.rpc_broadcast_mq = None
        self.worker_response_mq = None
        destroy_model_parallel()
        destroy_distributed_environment()

    @staticmethod
    def worker_main(*args, **kwargs):
        """Worker initialization and execution loop.
        This runs in a background process."""

        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the worker
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        worker = None
        try:
            worker = WorkerProc(*args, **kwargs)

            # Ensure message queues are ready. Will deadlock if re-ordered.
            # Must be kept consistent with the Executor
            worker.rpc_broadcast_mq.wait_until_ready()
            worker.worker_response_mq.wait_until_ready()

            worker.worker_busy_loop()

        except SystemExit:
            logger.debug("Worker interrupted.")

        except Exception:
            # worker_busy_loop sends exceptions to the Executor for
            # shutdown, but if there is an error in startup or an error
            # with IPC itself, we need to alert the parent.
            psutil.Process().parent().send_signal(signal.SIGUSR1)
            raise

        finally:
            # Clean up once worker exits busy loop
            if worker is not None:
                worker.shutdown()
                worker = None

    @staticmethod
    def wait_for_startup(
        proc: BaseProcess,
        ready_path: str,
    ) -> Optional[Handle]:
        """Wait until the Worker is ready."""
        with zmq_socket_ctx(ready_path, zmq.constants.PULL) as socket:

            # Wait for Worker to send READY.
            while socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
                logger.debug("Waiting for WorkerProc to startup.")

                if not proc.is_alive():
                    raise RuntimeError("WorkerProc failed to start.")

            message = socket.recv_string()
            assert message == WorkerProc.READY_STR
            handle_frame = socket.recv(copy=False)
            handle = pickle.loads(handle_frame.buffer)
            return handle

    class ResponseStatus(Enum):
        SUCCESS = auto()
        FAILURE = auto()

    def worker_busy_loop(self):
        """Main busy loop for Multiprocessing Workers"""
        while True:
            method, args, kwargs = self.rpc_broadcast_mq.dequeue()

            try:
                if isinstance(method, str):
                    func = getattr(self.worker, method)
                elif isinstance(method, bytes):
                    func = partial(cloudpickle.loads(method), self.worker)
                output = func(*args, **kwargs)
            except Exception as e:
                self.worker_response_mq.enqueue(
                    (WorkerProc.ResponseStatus.FAILURE, e))
                logger.exception("WorkerProc hit an exception: %s", e)
                continue

            self.worker_response_mq.enqueue(
                (WorkerProc.ResponseStatus.SUCCESS, output))
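

# Note that worker_busy_loop dispatches both kinds of RPC payload: method
# names (str) are looked up on the Worker, while cloudpickled callables
# (bytes) receive the Worker as their first argument via functools.partial.
# A hedged sketch of the callable form (`dump_rank` is a hypothetical
# function, not part of vLLM):
#
#     def dump_rank(worker):
#         return worker.rank
#
#     ranks = executor.collective_rpc(dump_rank)  # e.g. [0, 1, ..., N-1]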