Signed-off-by: Mark McLoughlin <markmc@redhat.com>
This commit is contained in:
@@ -226,8 +226,6 @@ class EngineCoreRequestType(enum.Enum):
|
||||
UTILITY = b"\x03"
|
||||
# Sentinel used within EngineCoreProc.
|
||||
EXECUTOR_FAILED = b"\x04"
|
||||
# Sentinel to wake up input_queue.get() during shutdown.
|
||||
WAKEUP = b"\x05"
|
||||
|
||||
|
||||
class ReconfigureDistributedRequest(msgspec.Struct):
|
||||
|
||||
@@ -264,15 +264,16 @@ class AsyncLLM(EngineClient):
|
||||
def __del__(self):
|
||||
self.shutdown()
|
||||
|
||||
def shutdown(self, timeout: float | None = None) -> None:
|
||||
def shutdown(self):
|
||||
"""Shutdown, cleaning up the background proc and IPC."""
|
||||
|
||||
shutdown_prometheus()
|
||||
|
||||
if renderer := getattr(self, "renderer", None):
|
||||
renderer.shutdown()
|
||||
|
||||
if engine_core := getattr(self, "engine_core", None):
|
||||
engine_core.shutdown(timeout=timeout)
|
||||
engine_core.shutdown()
|
||||
|
||||
handler = getattr(self, "output_handler", None)
|
||||
if handler is not None:
|
||||
|
||||
@@ -104,10 +104,8 @@ class DPCoordinator:
|
||||
"""Returns tuple of ZMQ input address, output address."""
|
||||
return self.coord_in_address, self.coord_out_address
|
||||
|
||||
def shutdown(self, timeout: float | None = None) -> None:
|
||||
"""Shutdown coordinator process with configurable timeout."""
|
||||
if self._finalizer.detach() is not None:
|
||||
shutdown([self.proc], timeout=timeout)
|
||||
def close(self):
|
||||
self._finalizer()
|
||||
|
||||
|
||||
class EngineState:
|
||||
|
||||
@@ -9,7 +9,6 @@ from collections import defaultdict, deque
|
||||
from collections.abc import Callable, Generator
|
||||
from concurrent.futures import Future
|
||||
from contextlib import ExitStack, contextmanager
|
||||
from enum import IntEnum
|
||||
from functools import partial
|
||||
from inspect import isclass, signature
|
||||
from logging import DEBUG
|
||||
@@ -62,7 +61,6 @@ from vllm.v1.engine import (
|
||||
from vllm.v1.engine.utils import (
|
||||
EngineHandshakeMetadata,
|
||||
EngineZmqAddresses,
|
||||
SignalCallback,
|
||||
get_device_indices,
|
||||
)
|
||||
from vllm.v1.executor import Executor
|
||||
@@ -773,12 +771,6 @@ class EngineCore:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class EngineShutdownState(IntEnum):
|
||||
RUNNING = 0
|
||||
REQUESTED = 1
|
||||
SHUTTING_DOWN = 2
|
||||
|
||||
|
||||
class EngineCoreProc(EngineCore):
|
||||
"""ZMQ-wrapper for running EngineCore in background process."""
|
||||
|
||||
@@ -806,7 +798,6 @@ class EngineCoreProc(EngineCore):
|
||||
self.engine_index = engine_index
|
||||
identity = self.engine_index.to_bytes(length=2, byteorder="little")
|
||||
self.engines_running = False
|
||||
self.shutdown_state = EngineShutdownState.RUNNING
|
||||
|
||||
with self._perform_handshakes(
|
||||
handshake_address,
|
||||
@@ -1037,11 +1028,25 @@ class EngineCoreProc(EngineCore):
|
||||
def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs):
|
||||
"""Launch EngineCore busy loop in background process."""
|
||||
|
||||
# Signal handler used for graceful termination.
|
||||
# SystemExit exception is only raised once to allow this and worker
|
||||
# processes to terminate without error
|
||||
shutdown_requested = False
|
||||
|
||||
# Ensure we can serialize transformer config after spawning
|
||||
maybe_register_config_serialize_by_value()
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
nonlocal shutdown_requested
|
||||
if not shutdown_requested:
|
||||
shutdown_requested = True
|
||||
raise SystemExit()
|
||||
|
||||
# Either SIGTERM or SIGINT will terminate the engine_core
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
engine_core: EngineCoreProc | None = None
|
||||
signal_callback: SignalCallback | None = None
|
||||
try:
|
||||
vllm_config: VllmConfig = kwargs["vllm_config"]
|
||||
parallel_config: ParallelConfig = vllm_config.parallel_config
|
||||
@@ -1089,22 +1094,6 @@ class EngineCoreProc(EngineCore):
|
||||
engine_core = EngineCoreProc(*args, engine_index=dp_rank, **kwargs)
|
||||
|
||||
assert engine_core is not None
|
||||
|
||||
def wakeup_engine():
|
||||
# Wakes up idle engine via input_queue when shutdown is requested
|
||||
# Not safe in a signal handler - we may interrupt the main thread
|
||||
# while it is holding the non-reentrant input_queue.mutex
|
||||
engine_core.input_queue.put_nowait((EngineCoreRequestType.WAKEUP, None))
|
||||
|
||||
signal_callback = SignalCallback(wakeup_engine)
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
engine_core.shutdown_state = EngineShutdownState.REQUESTED
|
||||
signal_callback.trigger()
|
||||
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
engine_core.run_busy_loop()
|
||||
|
||||
except SystemExit:
|
||||
@@ -1118,10 +1107,6 @@ class EngineCoreProc(EngineCore):
|
||||
engine_core._send_engine_dead()
|
||||
raise e
|
||||
finally:
|
||||
signal.signal(signal.SIGTERM, signal.SIG_DFL)
|
||||
signal.signal(signal.SIGINT, signal.SIG_DFL)
|
||||
if signal_callback is not None:
|
||||
signal_callback.stop()
|
||||
if engine_core is not None:
|
||||
engine_core.shutdown()
|
||||
|
||||
@@ -1136,25 +1121,21 @@ class EngineCoreProc(EngineCore):
|
||||
or bool(self.batch_queue)
|
||||
)
|
||||
|
||||
def is_running(self) -> bool:
|
||||
"""Returns true if shutdown has not been requested."""
|
||||
return self.shutdown_state == EngineShutdownState.RUNNING
|
||||
|
||||
def run_busy_loop(self):
|
||||
"""Core busy loop of the EngineCore."""
|
||||
while self._handle_shutdown():
|
||||
|
||||
# Loop until process is sent a SIGINT or SIGTERM
|
||||
while True:
|
||||
# 1) Poll the input queue until there is work to do.
|
||||
self._process_input_queue()
|
||||
# 2) Step the engine core and return the outputs.
|
||||
self._process_engine_step()
|
||||
|
||||
raise SystemExit
|
||||
|
||||
def _process_input_queue(self):
|
||||
"""Exits when an engine step needs to be performed."""
|
||||
|
||||
waited = False
|
||||
while not self.has_work() and self.is_running():
|
||||
while not self.has_work():
|
||||
# Notify callbacks waiting for engine to become idle.
|
||||
self._notify_idle_state_callbacks()
|
||||
if self.input_queue.empty():
|
||||
@@ -1206,60 +1187,18 @@ class EngineCoreProc(EngineCore):
|
||||
callback = self._idle_state_callbacks.pop()
|
||||
callback(self)
|
||||
|
||||
def _handle_shutdown(self) -> bool:
|
||||
# Check if shutdown was requested and handle it
|
||||
if self.shutdown_state == EngineShutdownState.RUNNING:
|
||||
return True
|
||||
|
||||
if self.shutdown_state == EngineShutdownState.REQUESTED:
|
||||
shutdown_timeout = self.vllm_config.shutdown_timeout
|
||||
|
||||
logger.info("Shutdown initiated (timeout=%d)", shutdown_timeout)
|
||||
|
||||
if shutdown_timeout == 0:
|
||||
num_requests = self.scheduler.get_num_unfinished_requests()
|
||||
if num_requests > 0:
|
||||
logger.info("Aborting %d requests", num_requests)
|
||||
aborted_reqs = self.scheduler.finish_requests(
|
||||
None, RequestStatus.FINISHED_ABORTED
|
||||
)
|
||||
self._send_abort_outputs(aborted_reqs)
|
||||
else:
|
||||
num_requests = self.scheduler.get_num_unfinished_requests()
|
||||
if num_requests > 0:
|
||||
logger.info(
|
||||
"Draining %d in-flight requests (timeout=%ds)",
|
||||
num_requests,
|
||||
shutdown_timeout,
|
||||
)
|
||||
|
||||
self.shutdown_state = EngineShutdownState.SHUTTING_DOWN
|
||||
|
||||
# Exit when no work remaining
|
||||
if not self.has_work():
|
||||
logger.info("Shutdown complete")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _handle_client_request(
|
||||
self, request_type: EngineCoreRequestType, request: Any
|
||||
) -> None:
|
||||
"""Dispatch request from client."""
|
||||
|
||||
if request_type == EngineCoreRequestType.WAKEUP:
|
||||
return
|
||||
elif request_type == EngineCoreRequestType.ADD:
|
||||
if request_type == EngineCoreRequestType.ADD:
|
||||
req, request_wave = request
|
||||
if self._reject_add_in_shutdown(req):
|
||||
return
|
||||
self.add_request(req, request_wave)
|
||||
elif request_type == EngineCoreRequestType.ABORT:
|
||||
self.abort_requests(request)
|
||||
elif request_type == EngineCoreRequestType.UTILITY:
|
||||
client_idx, call_id, method_name, args = request
|
||||
if self._reject_utility_in_shutdown(client_idx, call_id, method_name):
|
||||
return
|
||||
output = UtilityOutput(call_id)
|
||||
# Lazily look-up utility method so that failure will be handled/returned.
|
||||
get_result = lambda: (method := getattr(self, method_name)) and method(
|
||||
@@ -1276,27 +1215,6 @@ class EngineCoreProc(EngineCore):
|
||||
"Unrecognized input request type encountered: %s", request_type
|
||||
)
|
||||
|
||||
def _reject_add_in_shutdown(self, request: Request) -> bool:
|
||||
if self.shutdown_state == EngineShutdownState.RUNNING:
|
||||
return False
|
||||
|
||||
logger.info("Rejecting request %s (server shutting down)", request.request_id)
|
||||
self._send_abort_outputs_to_client([request.request_id], request.client_index)
|
||||
return True
|
||||
|
||||
def _reject_utility_in_shutdown(
|
||||
self, client_idx: int, call_id: int, method_name: str
|
||||
) -> bool:
|
||||
if self.shutdown_state == EngineShutdownState.RUNNING:
|
||||
return False
|
||||
|
||||
logger.warning("Rejecting utility call %s (server shutting down)", method_name)
|
||||
output = UtilityOutput(call_id, failure_message="Server shutting down")
|
||||
self.output_queue.put_nowait(
|
||||
(client_idx, EngineCoreOutputs(utility_output=output))
|
||||
)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _invoke_utility_method(
|
||||
name: str, get_result: Callable, output: UtilityOutput, enqueue_output: Callable
|
||||
@@ -1510,7 +1428,22 @@ class EngineCoreProc(EngineCore):
|
||||
logger.exception(
|
||||
"Unexpected error pre-processing request %s", request.request_id
|
||||
)
|
||||
self._send_error_outputs_to_client([request.request_id], request.client_index)
|
||||
self.output_queue.put_nowait(
|
||||
(
|
||||
request.client_index,
|
||||
EngineCoreOutputs(
|
||||
engine_index=self.engine_index,
|
||||
finished_requests={request.request_id},
|
||||
outputs=[
|
||||
EngineCoreOutput(
|
||||
request_id=request.request_id,
|
||||
new_token_ids=[],
|
||||
finish_reason=FinishReason.ERROR,
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
def pause_scheduler(
|
||||
self, mode: PauseMode = "abort", clear_cache: bool = True
|
||||
@@ -1553,26 +1486,6 @@ class EngineCoreProc(EngineCore):
|
||||
self._idle_state_callbacks.append(partial(engine_idle_callback, future=future))
|
||||
return future
|
||||
|
||||
def _send_finish_outputs_to_client(
|
||||
self, req_ids: list[str], client_index: int, finish_reason: FinishReason
|
||||
) -> None:
|
||||
outputs = [
|
||||
EngineCoreOutput(req_id, [], finish_reason=finish_reason)
|
||||
for req_id in req_ids
|
||||
]
|
||||
eco = EngineCoreOutputs(finished_requests=req_ids, outputs=outputs)
|
||||
self.output_queue.put_nowait((client_index, eco))
|
||||
|
||||
def _send_abort_outputs_to_client(
|
||||
self, req_ids: list[str], client_index: int
|
||||
) -> None:
|
||||
self._send_finish_outputs_to_client(req_ids, client_index, FinishReason.ABORT)
|
||||
|
||||
def _send_error_outputs_to_client(
|
||||
self, req_ids: list[str], client_index: int
|
||||
) -> None:
|
||||
self._send_finish_outputs_to_client(req_ids, client_index, FinishReason.ERROR)
|
||||
|
||||
def _send_abort_outputs(self, aborted_reqs: list[tuple[str, int]]) -> None:
|
||||
# TODO(nick) this will be moved inside the scheduler
|
||||
if aborted_reqs:
|
||||
@@ -1581,7 +1494,12 @@ class EngineCoreProc(EngineCore):
|
||||
for req_id, client_index in aborted_reqs:
|
||||
by_client[client_index].add(req_id)
|
||||
for client_index, req_ids in by_client.items():
|
||||
self._send_abort_outputs_to_client(list(req_ids), client_index)
|
||||
outputs = [
|
||||
EngineCoreOutput(req_id, [], finish_reason=FinishReason.ABORT)
|
||||
for req_id in req_ids
|
||||
]
|
||||
eco = EngineCoreOutputs(finished_requests=req_ids, outputs=outputs)
|
||||
self.output_queue.put_nowait((client_index, eco))
|
||||
|
||||
|
||||
class DPEngineCoreProc(EngineCoreProc):
|
||||
@@ -1699,7 +1617,7 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
"""Core busy loop of the EngineCore for data parallel case."""
|
||||
|
||||
# Loop until process is sent a SIGINT or SIGTERM
|
||||
while self._handle_shutdown():
|
||||
while True:
|
||||
# 1) Poll the input queue until there is work to do.
|
||||
self._process_input_queue()
|
||||
|
||||
@@ -1747,8 +1665,6 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
self.current_wave += 1
|
||||
self.step_counter = 0
|
||||
|
||||
raise SystemExit
|
||||
|
||||
def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
|
||||
# Optimization - only perform finish-sync all-reduce every 32 steps.
|
||||
self.step_counter += 1
|
||||
|
||||
@@ -128,7 +128,7 @@ class EngineCoreClient(ABC):
|
||||
return AsyncMPClient(*client_args)
|
||||
|
||||
@abstractmethod
|
||||
def shutdown(self, timeout: float | None = None) -> None: ...
|
||||
def shutdown(self): ...
|
||||
|
||||
def get_output(self) -> EngineCoreOutputs:
|
||||
raise NotImplementedError
|
||||
@@ -298,7 +298,7 @@ class InprocClient(EngineCoreClient):
|
||||
if len(request_ids) > 0:
|
||||
self.engine_core.abort_requests(request_ids)
|
||||
|
||||
def shutdown(self, timeout: float | None = None) -> None:
|
||||
def shutdown(self) -> None:
|
||||
self.engine_core.shutdown()
|
||||
|
||||
def profile(self, is_start: bool = True, profile_prefix: str | None = None) -> None:
|
||||
@@ -390,9 +390,9 @@ class BackgroundResources:
|
||||
|
||||
self.engine_dead = True
|
||||
if self.engine_manager is not None:
|
||||
self.engine_manager.shutdown()
|
||||
self.engine_manager.close()
|
||||
if self.coordinator is not None:
|
||||
self.coordinator.shutdown()
|
||||
self.coordinator.close()
|
||||
|
||||
if isinstance(self.output_socket, zmq.asyncio.Socket):
|
||||
# Async case.
|
||||
@@ -568,7 +568,10 @@ class MPClient(EngineCoreClient):
|
||||
)
|
||||
|
||||
with launch_core_engines(
|
||||
vllm_config, executor_class, log_stats, addresses
|
||||
vllm_config,
|
||||
executor_class,
|
||||
log_stats,
|
||||
addresses,
|
||||
) as (engine_manager, coordinator, addresses):
|
||||
self.resources.coordinator = coordinator
|
||||
self.resources.engine_manager = engine_manager
|
||||
@@ -634,12 +637,9 @@ class MPClient(EngineCoreClient):
|
||||
if not success:
|
||||
self._finalizer()
|
||||
|
||||
def shutdown(self, timeout: float | None = None) -> None:
|
||||
"""Shutdown engine manager under timeout and clean up resources."""
|
||||
if self._finalizer.detach() is not None:
|
||||
if self.resources.engine_manager is not None:
|
||||
self.resources.engine_manager.shutdown(timeout=timeout)
|
||||
self.resources()
|
||||
def shutdown(self):
|
||||
# Terminate background resources.
|
||||
self._finalizer()
|
||||
|
||||
def _format_exception(self, e: Exception) -> Exception:
|
||||
"""If errored, use EngineDeadError so root cause is clear."""
|
||||
@@ -683,7 +683,7 @@ class MPClient(EngineCoreClient):
|
||||
sentinels = [proc.sentinel for proc in engine_processes]
|
||||
died = multiprocessing.connection.wait(sentinels)
|
||||
_self = self_ref()
|
||||
if not _self or not _self._finalizer.alive or _self.resources.engine_dead:
|
||||
if not _self or _self.resources.engine_dead:
|
||||
return
|
||||
_self.resources.engine_dead = True
|
||||
proc_name = next(
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import threading
|
||||
import weakref
|
||||
from collections.abc import Callable, Iterator
|
||||
from dataclasses import dataclass
|
||||
@@ -152,12 +151,11 @@ class CoreEngineProcManager:
|
||||
finally:
|
||||
# Kill other procs if not all are running.
|
||||
if self.finished_procs():
|
||||
self.shutdown()
|
||||
self.close()
|
||||
|
||||
def shutdown(self, timeout: float | None = None) -> None:
|
||||
"""Shutdown engine core processes with configurable timeout."""
|
||||
if self._finalizer.detach() is not None:
|
||||
shutdown(self.processes, timeout=timeout)
|
||||
def close(self):
|
||||
"""Shutdown all procs."""
|
||||
self._finalizer()
|
||||
|
||||
def join_first(self):
|
||||
"""Wait for any process to exit."""
|
||||
@@ -175,33 +173,6 @@ class CoreEngineProcManager:
|
||||
}
|
||||
|
||||
|
||||
class SignalCallback:
|
||||
"""Safely trigger a callback from signal handler context via a dedicated thread."""
|
||||
|
||||
def __init__(self, callback: Callable[[], None]):
|
||||
self._callback = callback
|
||||
self._event = threading.Event()
|
||||
self._stopped = False
|
||||
self._thread = threading.Thread(
|
||||
target=self._run,
|
||||
daemon=True,
|
||||
name="signal-callback",
|
||||
)
|
||||
self._thread.start()
|
||||
|
||||
def _run(self):
|
||||
self._event.wait()
|
||||
if not self._stopped:
|
||||
self._callback()
|
||||
|
||||
def trigger(self):
|
||||
self._event.set()
|
||||
|
||||
def stop(self):
|
||||
self._stopped = True
|
||||
self._event.set()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_device_control_env_var(
|
||||
vllm_config: VllmConfig, local_dp_rank: int
|
||||
@@ -797,7 +768,7 @@ class CoreEngineActorManager:
|
||||
def get_run_refs(self):
|
||||
return self.run_refs
|
||||
|
||||
def shutdown(self, timeout: float | None = None) -> None:
|
||||
def close(self):
|
||||
import ray
|
||||
|
||||
for actor in self.local_engine_actors + self.remote_engine_actors:
|
||||
|
||||
@@ -220,10 +220,8 @@ class APIServerProcessManager:
|
||||
# The extra processes are managed by their owners
|
||||
self._finalizer = weakref.finalize(self, shutdown, self.processes)
|
||||
|
||||
def shutdown(self, timeout: float | None = None) -> None:
|
||||
"""Shutdown API server processes with configurable timeout"""
|
||||
if self._finalizer.detach() is not None:
|
||||
shutdown(self.processes, timeout=timeout)
|
||||
def close(self) -> None:
|
||||
self._finalizer()
|
||||
|
||||
|
||||
def wait_for_completion_or_failure(
|
||||
@@ -290,30 +288,25 @@ def wait_for_completion_or_failure(
|
||||
except Exception as e:
|
||||
logger.exception("Exception occurred while running API servers: %s", str(e))
|
||||
raise
|
||||
finally:
|
||||
logger.info("Terminating remaining processes ...")
|
||||
api_server_manager.close()
|
||||
if coordinator:
|
||||
coordinator.close()
|
||||
if engine_manager:
|
||||
engine_manager.close()
|
||||
|
||||
|
||||
# Note(rob): shutdown function cannot be a bound method,
|
||||
# else the gc cannot collect the object.
|
||||
def shutdown(procs: list[BaseProcess], timeout: float | None = None) -> None:
|
||||
"""Shutdown processes with timeout.
|
||||
|
||||
Args:
|
||||
procs: List of processes to shutdown
|
||||
timeout: Maximum time in seconds to wait for graceful shutdown
|
||||
"""
|
||||
if timeout is None:
|
||||
timeout = 0.0
|
||||
|
||||
# Allow at least 5 seconds for remaining procs to terminate.
|
||||
timeout = max(timeout, 5.0)
|
||||
|
||||
def shutdown(procs: list[BaseProcess]):
|
||||
# Shutdown the process.
|
||||
for proc in procs:
|
||||
if proc.is_alive():
|
||||
proc.terminate()
|
||||
|
||||
# Allow time for remaining procs to terminate.
|
||||
deadline = time.monotonic() + timeout
|
||||
# Allow 5 seconds for remaining procs to terminate.
|
||||
deadline = time.monotonic() + 5
|
||||
for proc in procs:
|
||||
remaining = deadline - time.monotonic()
|
||||
if remaining <= 0:
|
||||
|
||||
Reference in New Issue
Block a user