diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py
index 1ece3e4df..e265a088a 100644
--- a/vllm/entrypoints/cli/serve.py
+++ b/vllm/entrypoints/cli/serve.py
@@ -225,7 +225,7 @@ def run_headless(args: argparse.Namespace):
     )
 
     try:
-        engine_manager.join_first()
+        engine_manager.monitor_engine_liveness()
     finally:
         timeout = None
         if shutdown_requested:
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py
index b9a3c7545..1d73c12ed 100644
--- a/vllm/v1/engine/core_client.py
+++ b/vllm/v1/engine/core_client.py
@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import asyncio
 import contextlib
-import multiprocessing
 import queue
 import sys
 import uuid
@@ -640,34 +639,20 @@ class MPClient(EngineCoreClient):
     def start_engine_core_monitor(self):
         """Start a monitor thread for engine core processes."""
         engine_manager = self.resources.engine_manager
-        if (
-            engine_manager is None
-            or not hasattr(engine_manager, "processes")
-            or not engine_manager.processes
-        ):
+        if engine_manager is None:
             # No engine processes to monitor
             return
 
-        engine_processes = engine_manager.processes
         self_ref = weakref.ref(self)
 
         # Monitor engine core process liveness. If any die unexpectedly,
-        # logs an error, shuts down the client and invokes the failure
-        # callback to inform the engine.
+        # marks the engine as dead, and shuts down the client.
         def monitor_engine_cores():
-            sentinels = [proc.sentinel for proc in engine_processes]
-            died = multiprocessing.connection.wait(sentinels)
+            engine_manager.monitor_engine_liveness()
             _self = self_ref()
             if not _self or not _self._finalizer.alive or _self.resources.engine_dead:
                 return
             _self.resources.engine_dead = True
-            proc_name = next(
-                proc.name for proc in engine_processes if proc.sentinel == died[0]
-            )
-            logger.error(
-                "Engine core proc %s died unexpectedly, shutting down client.",
-                proc_name,
-            )
             _self.shutdown()
             # Note: For MPClient, we don't have a failure callback mechanism
             # like MultiprocExecutor, but we set engine_dead flag which will
@@ -1634,6 +1619,9 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
         parallel_config = self.vllm_config.parallel_config
         ip, coord_store_port = self._setup_elastic_ep_reconfig_bootstrap()
 
+        removed_dp_size = cur_data_parallel_size - new_data_parallel_size
+        assert isinstance(self.resources.engine_manager, CoreEngineActorManager)
+        self.resources.engine_manager.remove_run_refs_for_scale_down(removed_dp_size)
         reconfig_futures = []
         for cur_dp_rank, engine in enumerate(self.core_engines):
             reconfig_request = ReconfigureDistributedRequest(
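The `monitor_engine_liveness()` that replaces the inlined sentinel logic above boils down to `multiprocessing.connection.wait` on process sentinels plus an event for graceful stop (see the `CoreEngineProcManager` changes in the next file). A minimal, runnable sketch of that pattern; the `worker` and `monitor_liveness` names are illustrative, not vLLM APIs:

```python
import multiprocessing as mp
import threading
import time
from multiprocessing import connection


def worker(delay: float, code: int) -> None:
    time.sleep(delay)
    raise SystemExit(code)  # multiprocessing turns this into the exitcode


def monitor_liveness(procs: list, stopped: threading.Event) -> str | None:
    """Wait on process sentinels; return the name of the first crashed proc."""
    sentinel_to_proc = {p.sentinel: p for p in procs}
    failed = None
    while sentinel_to_proc and not stopped.is_set():
        # The 1s timeout lets the loop notice a requested shutdown promptly.
        died = connection.wait(list(sentinel_to_proc), timeout=1)
        for sentinel in died:
            proc = sentinel_to_proc.pop(sentinel)
            if proc.exitcode != 0 and failed is None:
                failed = proc.name
        if died:
            break  # any exit triggers shutdown, as in the patch
    return failed


if __name__ == "__main__":
    stopped = threading.Event()
    procs = [mp.Process(target=worker, args=(2.0, 0), name="engine-0"),
             mp.Process(target=worker, args=(0.2, 3), name="engine-1")]
    for p in procs:
        p.start()
    print("failed:", monitor_liveness(procs, stopped))  # -> engine-1
    for p in procs:
        p.join()
```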
diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py
index 90ec47edb..0ce0ed88e 100644
--- a/vllm/v1/engine/utils.py
+++ b/vllm/v1/engine/utils.py
@@ -11,7 +11,7 @@ from enum import Enum, auto
 from multiprocessing import Process, connection
 from multiprocessing.process import BaseProcess
 from multiprocessing.queues import Queue
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, cast
 from unittest.mock import patch
 
 import msgspec
@@ -133,6 +133,8 @@
         )
 
         self._finalizer = weakref.finalize(self, shutdown, self.processes)
+        self.manager_stopped = threading.Event()
+        self.failed_proc_name: str | None = None
 
         try:
             for proc, local_dp_rank in zip(self.processes, local_dp_ranks):
@@ -154,12 +156,31 @@
 
     def shutdown(self, timeout: float | None = None) -> None:
         """Shutdown engine core processes with configurable timeout."""
+        self.manager_stopped.set()
         if self._finalizer.detach() is not None:
             shutdown(self.processes, timeout=timeout)
 
-    def join_first(self):
-        """Wait for any process to exit."""
-        connection.wait(proc.sentinel for proc in self.processes)
+    def monitor_engine_liveness(self) -> None:
+        """Monitor engine core process liveness."""
+
+        sentinel_to_proc = {proc.sentinel: proc for proc in self.processes}
+        sentinels = set(sentinel_to_proc.keys())
+
+        while sentinels and not self.manager_stopped.is_set():
+            died_sentinels = connection.wait(sentinels, timeout=1)
+
+            for sentinel in died_sentinels:
+                proc = sentinel_to_proc.pop(cast(int, sentinel))
+                exitcode = proc.exitcode
+                if exitcode != 0 and not self.manager_stopped.is_set():
+                    self.failed_proc_name = proc.name
+            if died_sentinels:
+                # Any engine exit currently triggers a shutdown. Future work
+                # (e.g., elastic and fault-tolerant EP) will add finer-grained
+                # handling for different exit scenarios.
+                break
+
+        self.shutdown()
 
     def sentinels(self) -> list:
         return [proc.sentinel for proc in self.processes]
@@ -298,6 +319,8 @@
         self.log_stats = log_stats
         local_engine_count = vllm_config.parallel_config.data_parallel_size_local
         world_size = vllm_config.parallel_config.world_size
+        self.manager_stopped = threading.Event()
+        self.failed_proc_name: str | None = None
 
         if ray.is_initialized():
             logger.info("Ray is already initialized. Skipping Ray initialization.")
@@ -395,8 +418,11 @@
             ray.get(refs)
 
         self.run_refs = []
+        self.actor_run_ref_dict = dict()
         for actor in self.local_engine_actors + self.remote_engine_actors:
-            self.run_refs.append(actor.run.remote())
+            ref = actor.run.remote()
+            self.run_refs.append(ref)
+            self.actor_run_ref_dict[actor] = ref
 
     @staticmethod
     def create_dp_placement_groups(
@@ -776,7 +802,9 @@
         ) + self.remote_engine_actors[-(len(placement_groups) - new_local_engines) :]
 
         for actor in actors:
-            self.run_refs.append(actor.run.remote())
+            ref = actor.run.remote()
+            self.run_refs.append(ref)
+            self.actor_run_ref_dict[actor] = ref
 
         cur_vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
         # Update old_vllm_config with new data_parallel_size_local if any new
@@ -805,12 +833,59 @@
             self.remote_engine_actors.pop()
             ray.util.remove_placement_group(pg)
 
+    def remove_run_refs_for_scale_down(self, removed_dp_size: int) -> None:
+        if removed_dp_size <= 0:
+            return
+        flags = self.placement_group_is_local[-removed_dp_size:]
+        li = len(self.local_engine_actors) - 1
+        ri = len(self.remote_engine_actors) - 1
+        for is_local in reversed(flags):
+            if is_local:
+                actor = self.local_engine_actors[li]
+                li -= 1
+            else:
+                actor = self.remote_engine_actors[ri]
+                ri -= 1
+            ref = self.actor_run_ref_dict.pop(actor)
+            self.run_refs.remove(ref)
+
     def get_run_refs(self):
         return self.run_refs
 
+    def monitor_engine_liveness(self) -> None:
+        import ray
+
+        while not self.manager_stopped.is_set():
+            actor_run_refs = list(self.get_run_refs())
+            if not actor_run_refs:
+                logger.info(
+                    "No engine actors left to monitor; "
+                    "stopping liveness monitoring."
+                )
+                break
+            actor_done_refs, _ = ray.wait(actor_run_refs, timeout=5)
+            unexpected_failure = False
+            for actor_ref in actor_done_refs:
+                if self.manager_stopped.is_set():
+                    break
+                if actor_ref not in self.get_run_refs():
+                    # The run refs may have been updated by elastic scale-down.
+                    continue
+                try:
+                    ray.get(actor_ref)
+                except ray.exceptions.RayActorError:
+                    self.failed_proc_name = f"Actor {actor_ref}"
+                    unexpected_failure = True
+
+            if unexpected_failure:
+                break
+
+        self.shutdown()
+
     def shutdown(self, timeout: float | None = None) -> None:
         import ray
 
+        self.manager_stopped.set()
         for actor in self.local_engine_actors + self.remote_engine_actors:
             ray.kill(actor)
         for pg in self.created_placement_groups:
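On the Ray path, `CoreEngineActorManager.monitor_engine_liveness` above tells a clean actor exit apart from a crash by calling `ray.get` on completed run refs: a normal return means a clean exit, while a dead actor raises `RayActorError`. A condensed sketch of that classification; the `Engine` actor and `watch` helper are illustrative, not vLLM code, and require `ray` to be installed:

```python
import os

import ray
from ray.exceptions import RayActorError


@ray.remote
class Engine:
    def run(self, fail: bool) -> None:
        if fail:
            os._exit(1)  # simulate an engine core crash; kills the actor process


def watch(run_refs: list) -> str | None:
    """Wait for run refs to finish; return a label for the first crashed actor."""
    pending = list(run_refs)
    while pending:
        done, pending = ray.wait(pending, timeout=5)
        for ref in done:
            try:
                ray.get(ref)  # returns normally on a clean actor exit
            except RayActorError:
                return f"Actor {ref}"
    return None


if __name__ == "__main__":
    ray.init()
    actors = [Engine.remote() for _ in range(2)]
    refs = [actor.run.remote(fail=(i == 1)) for i, actor in enumerate(actors)]
    print("failed:", watch(refs))
    ray.shutdown()
```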
diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py
index 1aa36b1a5..eb81a3c88 100644
--- a/vllm/v1/utils.py
+++ b/vllm/v1/utils.py
@@ -3,6 +3,7 @@
 import argparse
 import contextlib
 import multiprocessing
+import threading
 import time
 import weakref
 from collections.abc import Callable, Sequence
@@ -269,8 +270,6 @@ def wait_for_completion_or_failure(
         coordinator: The coordinator for data parallel.
     """
-    from vllm.v1.engine.utils import CoreEngineActorManager, CoreEngineProcManager
-
     try:
         logger.info("Waiting for API servers to complete ...")
         # Create a mapping of sentinels to their corresponding processes
@@ -282,33 +281,40 @@ def wait_for_completion_or_failure(
 
         if coordinator:
             sentinel_to_proc[coordinator.proc.sentinel] = coordinator.proc
 
-        actor_run_refs = []
-        if isinstance(engine_manager, CoreEngineProcManager):
-            for proc in engine_manager.processes:
-                sentinel_to_proc[proc.sentinel] = proc
-        elif isinstance(engine_manager, CoreEngineActorManager):
-            actor_run_refs = engine_manager.get_run_refs()
+        if engine_manager:
+            core_shutdown_recv, core_shutdown_send = connection.Pipe(duplex=False)
+
+            def monitor_engines():
+                try:
+                    engine_manager.monitor_engine_liveness()
+                finally:
+                    core_shutdown_send.close()
+                    core_shutdown_recv.close()
+
+            # Start the engine liveness monitor thread
+            threading.Thread(target=monitor_engines, daemon=True).start()
+            sentinel_to_proc[core_shutdown_recv] = None  # type: ignore[assignment]
 
         # Check if any process terminates
-        while sentinel_to_proc or actor_run_refs:
-            # Wait for any process to terminate
-            ready_sentinels: list[Any] = connection.wait(sentinel_to_proc, timeout=5)
+        while sentinel_to_proc:
+            # Wait for any process to terminate (or engine shutdown signal)
+            ready_sentinels: list[Any] = connection.wait(sentinel_to_proc)
 
             # Process any terminated processes
             for sentinel in ready_sentinels:
                 proc = sentinel_to_proc.pop(sentinel)
 
                 # Check if process exited with error
-                if proc.exitcode != 0:
+                if proc is not None and proc.exitcode != 0:
                     raise RuntimeError(
                         f"Process {proc.name} (PID: {proc.pid}) "
                         f"died with exit code {proc.exitcode}"
                     )
-
-            if actor_run_refs:
-                import ray
-
-                _, actor_run_refs = ray.wait(actor_run_refs, timeout=5)
+                if engine_manager and engine_manager.failed_proc_name is not None:
+                    raise RuntimeError(
+                        f"Engine core process {engine_manager.failed_proc_name} "
+                        "died unexpectedly."
+                    )
     except KeyboardInterrupt:
         logger.info("Received KeyboardInterrupt, shutting down API servers...")
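The pipe registered in `sentinel_to_proc` above works because `multiprocessing.connection.wait` treats a `Connection` as ready once the other end is closed, so the monitor thread can wake the main loop without writing any data; this is also why the `timeout=5` polling could be dropped. A self-contained sketch of the wake-up trick (illustrative names, not vLLM code):

```python
import threading
import time
from multiprocessing import connection


def main() -> None:
    # One-way pipe whose read end doubles as a wait() sentinel.
    recv, send = connection.Pipe(duplex=False)

    def monitor() -> None:
        time.sleep(1)  # stand-in for engine_manager.monitor_engine_liveness()
        send.close()   # closing the write end makes `recv` ready in wait()

    threading.Thread(target=monitor, daemon=True).start()

    # In the patch this dict also maps API-server process sentinels to their
    # Process objects; the pipe's read end maps to None.
    sentinels: dict = {recv: None}
    while sentinels:
        for ready in connection.wait(list(sentinels)):
            sentinels.pop(ready)
            print("engine shutdown signal received")


if __name__ == "__main__":
    main()
```

Because the blocking `connection.wait` call now has no timeout, the loop wakes only when an API-server process actually exits or the monitor signals engine shutdown, rather than polling every five seconds.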