[Frontend][Core] Re-add shutdown timeout - allowing in-flight requests to finish (#36666)

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Signed-off-by: Nick Hill <nickhill123@gmail.com>
Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
Mark McLoughlin
2026-03-13 19:10:06 +00:00
committed by GitHub
parent 5a3f1eb62f
commit 7afe0faab1
14 changed files with 762 additions and 96 deletions

View File

@@ -1,14 +1,20 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Integration tests for shutdown behavior, timeout, and signal handling."""
import asyncio
import signal
import subprocess
import sys
import time
from dataclasses import dataclass, field
import httpx
import openai
import psutil
import pytest
from tests.utils import RemoteOpenAIServer
from vllm.platforms import current_platform
from vllm.utils.network_utils import get_open_port
@@ -18,6 +24,101 @@ MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
_IS_ROCM = current_platform.is_rocm()
_SERVER_STARTUP_TIMEOUT = 120
_PROCESS_EXIT_TIMEOUT = 15
_SHUTDOWN_DETECTION_TIMEOUT = 10
_CHILD_CLEANUP_TIMEOUT = 10
def _get_child_pids(parent_pid: int) -> list[int]:
    """Return the PIDs of all descendants (recursive) of ``parent_pid``.

    Returns an empty list when psutil reports that the process does not exist.
    """
    pids: list[int] = []
    try:
        for child in psutil.Process(parent_pid).children(recursive=True):
            pids.append(child.pid)
    except psutil.NoSuchProcess:
        return []
    return pids
async def _assert_children_cleaned_up(
    child_pids: list[int],
    timeout: float = _CHILD_CLEANUP_TIMEOUT,
):
    """Wait for child processes to exit and fail if any remain.

    Args:
        child_pids: PIDs to watch. An empty list returns immediately.
        timeout: Maximum number of seconds to keep polling. The processes are
            always checked at least once, even when ``timeout`` is zero or
            negative.

    A process counts as gone when psutil cannot find it, when it is not
    running, or when it is a zombie. If any PID is still alive once the
    deadline has passed, the test is failed via ``pytest.fail``.
    """
    if not child_pids:
        return

    def _is_alive(pid: int) -> bool:
        try:
            proc = psutil.Process(pid)
            return proc.is_running() and proc.status() != psutil.STATUS_ZOMBIE
        except psutil.NoSuchProcess:
            return False

    # Monotonic clock: the deadline must not move if the wall clock is adjusted.
    deadline = time.monotonic() + timeout
    while True:
        # Evaluate before consulting the deadline so that ``still_alive`` is
        # always bound and always reflects the most recent check.
        still_alive = [pid for pid in child_pids if _is_alive(pid)]
        if not still_alive:
            return
        if time.monotonic() >= deadline:
            break
        await asyncio.sleep(0.5)

    pytest.fail(
        f"Child processes {still_alive} still alive after {timeout}s. "
        f"Process cleanup may not be working correctly."
    )
@dataclass
class ShutdownState:
    """Mutable flags and counters shared between a test and its request loop."""

    # Set once a request fails with an HTTP 503 status error.
    got_503: bool = False
    # Set once a request fails with an HTTP 500 status error.
    got_500: bool = False
    # Completions that returned after the SIGTERM event was set.
    requests_after_sigterm: int = 0
    # Completions in which some choice reported finish_reason == "abort".
    aborted_requests: int = 0
    # Connection-level failures seen by the request loop.
    connection_errors: int = 0
    # The test sets this to True to make the request workers stop looping.
    stop_requesting: bool = False
    # Descriptions of any other API or unexpected errors.
    errors: list[str] = field(default_factory=list)
async def _concurrent_request_loop(
    client: openai.AsyncOpenAI,
    state: ShutdownState,
    sigterm_sent: asyncio.Event | None = None,
    concurrency: int = 10,
):
    """Keep the server busy with ``concurrency`` parallel completion workers.

    Each worker issues completion requests in a loop until
    ``state.stop_requesting`` becomes true, recording outcomes on ``state``.
    A worker stops early on a connection error seen after ``sigterm_sent`` is
    set, or on any unexpected exception. Workers still pending when the
    gather finishes or is cancelled are cancelled.
    """

    def _sigterm_seen() -> bool:
        return sigterm_sent is not None and sigterm_sent.is_set()

    async def _worker():
        while not state.stop_requesting:
            try:
                completion = await client.completions.create(
                    model=MODEL_NAME,
                    prompt="Write a story: ",
                    max_tokens=200,
                )
                if _sigterm_seen():
                    state.requests_after_sigterm += 1
                # NOTE(review): indentation was lost in the source view; this
                # check is assumed to sit beside (not inside) the SIGTERM
                # check above - confirm against the real file.
                finish_reasons = [c.finish_reason for c in completion.choices]
                if "abort" in finish_reasons:
                    state.aborted_requests += 1
            except openai.APIStatusError as exc:
                status = exc.status_code
                if status == 503:
                    state.got_503 = True
                elif status == 500:
                    state.got_500 = True
                else:
                    state.errors.append(f"API error: {exc}")
            except (openai.APIConnectionError, httpx.RemoteProtocolError):
                state.connection_errors += 1
                # Once SIGTERM has been sent, a dropped connection means the
                # server is going away, so this worker is finished.
                if _sigterm_seen():
                    break
            except Exception as exc:
                state.errors.append(f"Unexpected error: {exc}")
                break
            await asyncio.sleep(0.01)

    workers = [asyncio.create_task(_worker()) for _ in range(concurrency)]
    try:
        await asyncio.gather(*workers, return_exceptions=True)
    finally:
        for worker in workers:
            if not worker.done():
                worker.cancel()
@pytest.mark.asyncio
@@ -103,3 +204,361 @@ async def test_shutdown_on_engine_failure():
return_code = proc.wait(timeout=_PROCESS_EXIT_TIMEOUT)
assert return_code is not None
@pytest.mark.asyncio
async def test_wait_timeout_completes_requests():
    """Verify wait timeout: new requests rejected, in-flight requests complete."""
    # Flag/value pairs, kept one pair per line.
    server_args = [
        "--dtype", "bfloat16",
        "--max-model-len", "256",
        "--enforce-eager",
        "--gpu-memory-utilization", "0.05",
        "--max-num-seqs", "4",
        "--shutdown-timeout", "30",
    ]

    with RemoteOpenAIServer(MODEL_NAME, server_args) as remote_server:
        client = remote_server.get_async_client()
        proc = remote_server.proc
        # Snapshot the children before SIGTERM so cleanup can be checked later.
        child_pids = _get_child_pids(proc.pid)

        state = ShutdownState()
        sigterm_sent = asyncio.Event()
        load = _concurrent_request_loop(client, state, sigterm_sent, concurrency=10)
        request_task = asyncio.create_task(load)

        # Give the workers time to get requests in flight.
        await asyncio.sleep(0.5)
        proc.send_signal(signal.SIGTERM)
        sigterm_sent.set()

        try:
            await asyncio.wait_for(request_task, timeout=_SHUTDOWN_DETECTION_TIMEOUT)
        except asyncio.TimeoutError:
            pass
        finally:
            state.stop_requesting = True
            if not request_task.done():
                request_task.cancel()
            await asyncio.gather(request_task, return_exceptions=True)

        # wait timeout should complete in-flight requests
        completed_after_sigterm = state.requests_after_sigterm > 0
        assert completed_after_sigterm, (
            f"Wait timeout should complete in-flight requests. "
            f"503: {state.got_503}, 500: {state.got_500}, "
            f"conn_errors: {state.connection_errors}, errors: {state.errors}"
        )
        # server must stop accepting new requests (503, 500, or connection close)
        stopped_accepting = (
            state.got_503 or state.got_500 or state.connection_errors > 0
        )
        assert stopped_accepting, (
            f"Server should stop accepting requests. "
            f"completed: {state.requests_after_sigterm}, errors: {state.errors}"
        )

        await _assert_children_cleaned_up(child_pids)
@pytest.mark.asyncio
@pytest.mark.parametrize("wait_for_engine_idle", [0.0, 2.0])
async def test_abort_timeout_exits_quickly(wait_for_engine_idle: float):
    """Verify that ``--shutdown-timeout 0`` exits promptly after SIGTERM.

    With a positive ``wait_for_engine_idle`` the server first serves two
    requests and is left idle for that many seconds before SIGTERM is sent.
    """
    server_args = [
        "--dtype", "bfloat16",
        "--max-model-len", "256",
        "--enforce-eager",
        "--gpu-memory-utilization", "0.05",
        "--max-num-seqs", "4",
        "--shutdown-timeout", "0",
    ]

    with RemoteOpenAIServer(MODEL_NAME, server_args) as remote_server:
        proc = remote_server.proc
        child_pids = _get_child_pids(proc.pid)

        if wait_for_engine_idle > 0:
            client = remote_server.get_async_client()
            # Send requests to ensure engine is fully initialized
            for _ in range(2):
                await client.completions.create(
                    model=MODEL_NAME,
                    prompt="Test request: ",
                    max_tokens=10,
                )
            # Wait for engine to become idle
            await asyncio.sleep(wait_for_engine_idle)

        start_time = time.time()
        proc.send_signal(signal.SIGTERM)

        # abort timeout (0) should exit promptly: poll up to 20 times, 0.1s apart
        polls_left = 20
        while polls_left and proc.poll() is None:
            time.sleep(0.1)
            polls_left -= 1

        still_running = proc.poll() is None
        if still_running:
            proc.kill()
            proc.wait(timeout=5)
            pytest.fail("Process did not exit after SIGTERM with abort timeout")

        exit_time = time.time() - start_time
        assert exit_time < 2, f"Default shutdown took too long: {exit_time:.1f}s"
        acceptable_codes = (0, -15, None)
        assert proc.returncode in acceptable_codes, f"Unexpected: {proc.returncode}"

        await _assert_children_cleaned_up(child_pids)
@pytest.mark.asyncio
async def test_wait_timeout_with_short_duration():
    """Verify server exits cleanly with a short wait timeout.

    A background request loop keeps the server busy while SIGTERM is sent.
    The server must exit within ``wait_timeout + 10`` seconds with a return
    code of 0 or -15, and leave no child processes behind.
    """
    wait_timeout = 3
    server_args = [
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "256",
        "--enforce-eager",
        "--gpu-memory-utilization",
        "0.05",
        "--max-num-seqs",
        "4",
        "--shutdown-timeout",
        str(wait_timeout),
    ]

    with RemoteOpenAIServer(MODEL_NAME, server_args) as remote_server:
        client = remote_server.get_async_client()
        proc = remote_server.proc
        child_pids = _get_child_pids(proc.pid)

        state = ShutdownState()
        request_task = asyncio.create_task(
            _concurrent_request_loop(client, state, concurrency=3)
        )

        await asyncio.sleep(0.5)

        # Monotonic clock: elapsed time must not be skewed by clock changes.
        start_time = time.monotonic()
        proc.send_signal(signal.SIGTERM)

        # server should exit within wait_timeout + buffer
        max_wait = wait_timeout + 15
        for _ in range(int(max_wait * 10)):
            if proc.poll() is not None:
                break
            # Yield to the event loop while waiting. A blocking time.sleep()
            # here would freeze request_task, so no requests would be issued
            # or completed by the client during the shutdown window.
            await asyncio.sleep(0.1)

        exit_time = time.monotonic() - start_time

        state.stop_requesting = True
        if not request_task.done():
            request_task.cancel()
        await asyncio.gather(request_task, return_exceptions=True)

        if proc.poll() is None:
            proc.kill()
            proc.wait(timeout=5)
            pytest.fail(f"Process did not exit within {max_wait}s after SIGTERM")

        assert exit_time < wait_timeout + 10, (
            f"Took too long to exit ({exit_time:.1f}s), expected <{wait_timeout + 10}s"
        )
        assert proc.returncode in (0, -15, None), f"Unexpected: {proc.returncode}"

        await _assert_children_cleaned_up(child_pids)
@pytest.mark.asyncio
async def test_abort_timeout_fails_inflight_requests():
    """Verify abort timeout (0) immediately aborts in-flight requests."""
    server_args = [
        "--dtype", "bfloat16",
        "--max-model-len", "256",
        "--enforce-eager",
        "--gpu-memory-utilization", "0.05",
        "--max-num-seqs", "4",
        "--shutdown-timeout", "0",
    ]

    with RemoteOpenAIServer(MODEL_NAME, server_args) as remote_server:
        client = remote_server.get_async_client()
        proc = remote_server.proc
        child_pids = _get_child_pids(proc.pid)

        state = ShutdownState()
        sigterm_sent = asyncio.Event()
        load = _concurrent_request_loop(client, state, sigterm_sent, concurrency=10)
        request_task = asyncio.create_task(load)

        # Let requests get in flight before signalling.
        await asyncio.sleep(0.5)
        proc.send_signal(signal.SIGTERM)
        sigterm_sent.set()

        try:
            await asyncio.wait_for(request_task, timeout=5)
        except asyncio.TimeoutError:
            pass
        finally:
            state.stop_requesting = True
            if not request_task.done():
                request_task.cancel()
            await asyncio.gather(request_task, return_exceptions=True)

        # With abort timeout (0), requests should be aborted (finish_reason='abort')
        # or rejected (connection errors or API errors)
        saw_failure = (
            state.aborted_requests > 0
            or state.connection_errors > 0
            or state.got_500
            or state.got_503
        )
        assert saw_failure, (
            f"Abort timeout should cause request aborts or failures. "
            f"aborted: {state.aborted_requests}, "
            f"503: {state.got_503}, 500: {state.got_500}, "
            f"conn_errors: {state.connection_errors}, "
            f"completed: {state.requests_after_sigterm}"
        )

        # Verify fast shutdown: poll up to 100 times, 0.1s apart.
        start_time = time.time()
        polls_left = 100
        while polls_left and proc.poll() is None:
            time.sleep(0.1)
            polls_left -= 1
        exit_time = time.time() - start_time
        assert exit_time < 10, f"Abort timeout shutdown took too long: {exit_time:.1f}s"

        await _assert_children_cleaned_up(child_pids)
@pytest.mark.asyncio
async def test_request_rejection_during_shutdown():
    """Verify new requests are rejected with error during shutdown."""
    server_args = [
        "--dtype", "bfloat16",
        "--max-model-len", "256",
        "--enforce-eager",
        "--gpu-memory-utilization", "0.05",
        "--max-num-seqs", "4",
        "--shutdown-timeout", "30",
    ]
    rejection_errors = (
        openai.APIStatusError,
        openai.APIConnectionError,
        httpx.RemoteProtocolError,
    )

    with RemoteOpenAIServer(MODEL_NAME, server_args) as remote_server:
        client = remote_server.get_async_client()
        proc = remote_server.proc
        child_pids = _get_child_pids(proc.pid)

        proc.send_signal(signal.SIGTERM)
        # Let the server observe the signal before probing it.
        await asyncio.sleep(1.0)

        # Try to send new requests - they should be rejected
        rejected_count = 0
        attempts_left = 10
        while attempts_left > 0:
            attempts_left -= 1
            try:
                await client.completions.create(
                    model=MODEL_NAME, prompt="Hello", max_tokens=10
                )
            except rejection_errors:
                rejected_count += 1
            await asyncio.sleep(0.1)

        assert rejected_count > 0, (
            f"Expected requests to be rejected during shutdown, "
            f"but {rejected_count} were rejected out of 10"
        )

        await _assert_children_cleaned_up(child_pids)
@pytest.mark.asyncio
async def test_multi_api_server_shutdown():
    """Verify shutdown works with multiple API servers."""
    server_args = [
        "--dtype", "bfloat16",
        "--max-model-len", "256",
        "--enforce-eager",
        "--gpu-memory-utilization", "0.05",
        "--max-num-seqs", "4",
        "--shutdown-timeout", "30",
        "--api-server-count", "2",
    ]

    with RemoteOpenAIServer(MODEL_NAME, server_args, auto_port=True) as remote_server:
        client = remote_server.get_async_client()
        proc = remote_server.proc
        child_pids = _get_child_pids(proc.pid)
        num_children = len(child_pids)
        assert num_children >= 2, (
            f"Expected at least 2 child processes, got {len(child_pids)}"
        )

        state = ShutdownState()
        sigterm_sent = asyncio.Event()

        # Start concurrent requests across both API servers
        load = _concurrent_request_loop(client, state, sigterm_sent, concurrency=8)
        request_task = asyncio.create_task(load)
        await asyncio.sleep(0.5)

        # Send SIGTERM to parent - should propagate to all children
        proc.send_signal(signal.SIGTERM)
        sigterm_sent.set()

        try:
            await asyncio.wait_for(request_task, timeout=_SHUTDOWN_DETECTION_TIMEOUT)
        except asyncio.TimeoutError:
            pass
        finally:
            state.stop_requesting = True
            if not request_task.done():
                request_task.cancel()
            await asyncio.gather(request_task, return_exceptions=True)

        # Poll for exit: 300 polls, 0.1s apart (up to 30 seconds).
        polls_left = 300
        while polls_left and proc.poll() is None:
            time.sleep(0.1)
            polls_left -= 1

        if proc.poll() is None:
            proc.kill()
            proc.wait(timeout=5)
            pytest.fail("Process did not exit after SIGTERM")

        await _assert_children_cleaned_up(child_pids)

View File

@@ -79,7 +79,7 @@ def test_api_server_process_manager_init(api_server_args, with_stats_update):
finally:
# Always clean up the processes
print("Cleaning up processes...")
manager.close()
manager.shutdown()
# Give processes time to terminate
time.sleep(0.2)
@@ -111,6 +111,8 @@ def test_wait_for_completion_or_failure(api_server_args):
wait_for_completion_or_failure(api_server_manager=manager)
except Exception as e:
result["exception"] = e
finally:
manager.shutdown()
# Start a thread to run wait_for_completion_or_failure
wait_thread = threading.Thread(target=run_with_exception_capture, daemon=True)
@@ -143,7 +145,7 @@ def test_wait_for_completion_or_failure(api_server_args):
assert not proc.is_alive(), f"Process {i} should not be alive"
finally:
manager.close()
manager.shutdown()
time.sleep(0.2)
@@ -174,11 +176,14 @@ def test_normal_completion(api_server_args):
# since all processes have already
# terminated, it should return immediately
# with no error
wait_for_completion_or_failure(api_server_manager=manager)
try:
wait_for_completion_or_failure(api_server_manager=manager)
finally:
manager.shutdown()
finally:
# Clean up just in case
manager.close()
manager.shutdown()
time.sleep(0.2)
@@ -201,7 +206,7 @@ def test_external_process_monitoring(api_server_args):
def __init__(self, proc):
self.proc = proc
def close(self):
def shutdown(self):
if self.proc.is_alive():
self.proc.terminate()
self.proc.join(timeout=0.5)
@@ -226,6 +231,9 @@ def test_external_process_monitoring(api_server_args):
)
except Exception as e:
result["exception"] = e
finally:
manager.shutdown()
mock_coordinator.shutdown()
# Start a thread to run wait_for_completion_or_failure
wait_thread = threading.Thread(target=run_with_exception_capture, daemon=True)
@@ -259,6 +267,6 @@ def test_external_process_monitoring(api_server_args):
finally:
# Clean up
manager.close()
mock_coordinator.close()
manager.shutdown()
mock_coordinator.shutdown()
time.sleep(0.2)

View File

@@ -327,6 +327,12 @@ class VllmConfig:
weight_transfer_config: WeightTransferConfig | None = None
"""The configurations for weight transfer during RL training."""
shutdown_timeout: int = Field(default=0, ge=0)
"""Shutdown grace period for in-flight requests. Shutdown will be delayed for
up to this amount of time to allow already-running requests to complete. Any
remaining requests are aborted once the timeout is reached.
"""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,

View File

@@ -606,6 +606,8 @@ class EngineArgs:
kv_offloading_backend: KVOffloadingBackend = CacheConfig.kv_offloading_backend
tokens_only: bool = False
shutdown_timeout: int = 0
weight_transfer_config: WeightTransferConfig | None = get_field(
VllmConfig,
"weight_transfer_config",
@@ -1308,6 +1310,14 @@ class EngineArgs:
default=False,
action=argparse.BooleanOptionalAction,
)
parser.add_argument(
"--shutdown-timeout",
type=int,
default=0,
help="Shutdown timeout in seconds. 0 = abort, >0 = wait.",
)
return parser
@classmethod
@@ -1916,6 +1926,7 @@ class EngineArgs:
optimization_level=self.optimization_level,
performance_mode=self.performance_mode,
weight_transfer_config=self.weight_transfer_config,
shutdown_timeout=self.shutdown_timeout,
)
return config

View File

@@ -200,6 +200,11 @@ class EngineClient(ABC):
"""Return whether the engine is currently paused."""
...
@abstractmethod
def shutdown(self, timeout: float | None = None) -> None:
"""Shutdown the engine with optional timeout."""
...
async def scale_elastic_ep(
self, new_data_parallel_size: int, drain_timeout: int = 300
) -> None:

View File

@@ -3,6 +3,7 @@
import argparse
import signal
import time
import uvloop
@@ -222,8 +223,12 @@ def run_headless(args: argparse.Namespace):
try:
engine_manager.join_first()
finally:
timeout = None
if shutdown_requested:
timeout = vllm_config.shutdown_timeout
logger.info("Waiting up to %d seconds for processes to exit", timeout)
engine_manager.shutdown(timeout=timeout)
logger.info("Shutting down.")
engine_manager.close()
def run_multi_api_server(args: argparse.Namespace):
@@ -234,6 +239,19 @@ def run_multi_api_server(args: argparse.Namespace):
if num_api_servers > 1:
setup_multiprocess_prometheus()
shutdown_requested = False
# Catch SIGTERM and SIGINT to allow graceful shutdown.
def signal_handler(signum, frame):
nonlocal shutdown_requested
logger.debug("Received %d signal.", signum)
if not shutdown_requested:
shutdown_requested = True
raise SystemExit
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
listen_address, sock = setup_server(args)
engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
@@ -295,11 +313,29 @@ def run_multi_api_server(args: argparse.Namespace):
api_server_manager = APIServerProcessManager(**api_server_manager_kwargs)
# Wait for API servers
wait_for_completion_or_failure(
api_server_manager=api_server_manager,
engine_manager=local_engine_manager,
coordinator=coordinator,
)
try:
wait_for_completion_or_failure(
api_server_manager=api_server_manager,
engine_manager=local_engine_manager,
coordinator=coordinator,
)
finally:
timeout = shutdown_by = None
if shutdown_requested:
timeout = vllm_config.shutdown_timeout
shutdown_by = time.monotonic() + timeout
logger.info("Waiting up to %d seconds for processes to exit", timeout)
def to_timeout(deadline: float | None) -> float | None:
return (
deadline if deadline is None else max(deadline - time.monotonic(), 0.0)
)
api_server_manager.shutdown(timeout=timeout)
if local_engine_manager:
local_engine_manager.shutdown(timeout=to_timeout(shutdown_by))
if coordinator:
coordinator.shutdown(timeout=to_timeout(shutdown_by))
def run_api_server_worker_proc(

View File

@@ -4,6 +4,7 @@
import asyncio
import signal
import socket
from functools import partial
from typing import Any
import uvicorn
@@ -91,12 +92,10 @@ async def serve_http(
)
)
shutdown_event = asyncio.Event()
def signal_handler() -> None:
# Prevents the uvicorn signal handler from exiting early.
server_task.cancel()
watchdog_task.cancel()
if ssl_cert_refresher:
ssl_cert_refresher.stop()
shutdown_event.set()
async def dummy_shutdown() -> None:
pass
@@ -104,6 +103,24 @@ async def serve_http(
loop.add_signal_handler(signal.SIGINT, signal_handler)
loop.add_signal_handler(signal.SIGTERM, signal_handler)
async def handle_shutdown() -> None:
await shutdown_event.wait()
engine_client = app.state.engine_client
timeout = engine_client.vllm_config.shutdown_timeout
await loop.run_in_executor(
None, partial(engine_client.shutdown, timeout=timeout)
)
server.should_exit = True
server_task.cancel()
watchdog_task.cancel()
if ssl_cert_refresher:
ssl_cert_refresher.stop()
shutdown_task = loop.create_task(handle_shutdown())
try:
await server_task
return dummy_shutdown()
@@ -120,6 +137,7 @@ async def serve_http(
logger.info("Shutting down FastAPI HTTP server.")
return server.shutdown()
finally:
shutdown_task.cancel()
watchdog_task.cancel()

View File

@@ -226,6 +226,8 @@ class EngineCoreRequestType(enum.Enum):
UTILITY = b"\x03"
# Sentinel used within EngineCoreProc.
EXECUTOR_FAILED = b"\x04"
# Sentinel to wake up input_queue.get() during shutdown.
WAKEUP = b"\x05"
class ReconfigureDistributedRequest(msgspec.Struct):

View File

@@ -264,16 +264,15 @@ class AsyncLLM(EngineClient):
def __del__(self):
self.shutdown()
def shutdown(self):
def shutdown(self, timeout: float | None = None) -> None:
"""Shutdown, cleaning up the background proc and IPC."""
shutdown_prometheus()
if renderer := getattr(self, "renderer", None):
renderer.shutdown()
if engine_core := getattr(self, "engine_core", None):
engine_core.shutdown()
engine_core.shutdown(timeout=timeout)
handler = getattr(self, "output_handler", None)
if handler is not None:

View File

@@ -104,8 +104,10 @@ class DPCoordinator:
"""Returns tuple of ZMQ input address, output address."""
return self.coord_in_address, self.coord_out_address
def close(self):
self._finalizer()
def shutdown(self, timeout: float | None = None) -> None:
"""Shutdown coordinator process with configurable timeout."""
if self._finalizer.detach() is not None:
shutdown([self.proc], timeout=timeout)
class EngineState:

View File

@@ -9,6 +9,7 @@ from collections import defaultdict, deque
from collections.abc import Callable, Generator
from concurrent.futures import Future
from contextlib import ExitStack, contextmanager
from enum import IntEnum
from functools import partial
from inspect import isclass, signature
from logging import DEBUG
@@ -61,6 +62,7 @@ from vllm.v1.engine import (
from vllm.v1.engine.utils import (
EngineHandshakeMetadata,
EngineZmqAddresses,
SignalCallback,
get_device_indices,
)
from vllm.v1.executor import Executor
@@ -765,6 +767,12 @@ class EngineCore:
raise NotImplementedError
class EngineShutdownState(IntEnum):
    """Shutdown progress of an engine core process."""

    # Normal operation; no shutdown has been requested.
    RUNNING = 0
    # A shutdown signal arrived but the busy loop has not acted on it yet.
    REQUESTED = 1
    # The busy loop has handled the request and runs until no work remains.
    SHUTTING_DOWN = 2
class EngineCoreProc(EngineCore):
"""ZMQ-wrapper for running EngineCore in background process."""
@@ -792,6 +800,7 @@ class EngineCoreProc(EngineCore):
self.engine_index = engine_index
identity = self.engine_index.to_bytes(length=2, byteorder="little")
self.engines_running = False
self.shutdown_state = EngineShutdownState.RUNNING
with self._perform_handshakes(
handshake_address,
@@ -1020,25 +1029,11 @@ class EngineCoreProc(EngineCore):
def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs):
"""Launch EngineCore busy loop in background process."""
# Signal handler used for graceful termination.
# SystemExit exception is only raised once to allow this and worker
# processes to terminate without error
shutdown_requested = False
# Ensure we can serialize transformer config after spawning
maybe_register_config_serialize_by_value()
def signal_handler(signum, frame):
nonlocal shutdown_requested
if not shutdown_requested:
shutdown_requested = True
raise SystemExit()
# Either SIGTERM or SIGINT will terminate the engine_core
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
engine_core: EngineCoreProc | None = None
signal_callback: SignalCallback | None = None
try:
vllm_config: VllmConfig = kwargs["vllm_config"]
parallel_config: ParallelConfig = vllm_config.parallel_config
@@ -1078,6 +1073,22 @@ class EngineCoreProc(EngineCore):
engine_core = EngineCoreProc(*args, engine_index=dp_rank, **kwargs)
assert engine_core is not None
def wakeup_engine():
# Wakes up idle engine via input_queue when shutdown is requested
# Not safe in a signal handler - we may interrupt the main thread
# while it is holding the non-reentrant input_queue.mutex
engine_core.input_queue.put_nowait((EngineCoreRequestType.WAKEUP, None))
signal_callback = SignalCallback(wakeup_engine)
def signal_handler(signum, frame):
engine_core.shutdown_state = EngineShutdownState.REQUESTED
signal_callback.trigger()
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
engine_core.run_busy_loop()
except SystemExit:
@@ -1091,6 +1102,10 @@ class EngineCoreProc(EngineCore):
engine_core._send_engine_dead()
raise e
finally:
signal.signal(signal.SIGTERM, signal.SIG_DFL)
signal.signal(signal.SIGINT, signal.SIG_DFL)
if signal_callback is not None:
signal_callback.stop()
if engine_core is not None:
engine_core.shutdown()
@@ -1105,21 +1120,25 @@ class EngineCoreProc(EngineCore):
or bool(self.batch_queue)
)
def is_running(self) -> bool:
    """Return True while the engine is still in the ``RUNNING`` shutdown state.

    Any other state means a shutdown has been requested or is in progress.
    """
    current = self.shutdown_state
    return current == EngineShutdownState.RUNNING
def run_busy_loop(self):
"""Core busy loop of the EngineCore."""
# Loop until process is sent a SIGINT or SIGTERM
while True:
while self._handle_shutdown():
# 1) Poll the input queue until there is work to do.
self._process_input_queue()
# 2) Step the engine core and return the outputs.
self._process_engine_step()
raise SystemExit
def _process_input_queue(self):
"""Exits when an engine step needs to be performed."""
waited = False
while not self.has_work():
while not self.has_work() and self.is_running():
# Notify callbacks waiting for engine to become idle.
self._notify_idle_state_callbacks()
if self.input_queue.empty():
@@ -1171,18 +1190,60 @@ class EngineCoreProc(EngineCore):
callback = self._idle_state_callbacks.pop()
callback(self)
def _handle_shutdown(self) -> bool:
    """Process a pending shutdown request and say whether to keep looping.

    Returns:
        True while the busy loop should continue: either no shutdown was
        requested or there is still work left. False once shutdown is under
        way and no work remains.
    """
    if self.shutdown_state == EngineShutdownState.RUNNING:
        return True

    if self.shutdown_state == EngineShutdownState.REQUESTED:
        # First pass after the request: abort or start draining.
        shutdown_timeout = self.vllm_config.shutdown_timeout
        logger.info("Shutdown initiated (timeout=%d)", shutdown_timeout)

        abort_immediately = shutdown_timeout == 0
        pending = self.scheduler.get_num_unfinished_requests()
        if abort_immediately:
            if pending > 0:
                logger.info("Aborting %d requests", pending)
            aborted = self.scheduler.finish_requests(
                None, RequestStatus.FINISHED_ABORTED
            )
            self._send_abort_outputs(aborted)
        elif pending > 0:
            logger.info(
                "Draining %d in-flight requests (timeout=%ds)",
                pending,
                shutdown_timeout,
            )

        self.shutdown_state = EngineShutdownState.SHUTTING_DOWN

    # Keep stepping while work remains; exit once there is none.
    if self.has_work():
        return True
    logger.info("Shutdown complete")
    return False
def _handle_client_request(
self, request_type: EngineCoreRequestType, request: Any
) -> None:
"""Dispatch request from client."""
if request_type == EngineCoreRequestType.ADD:
if request_type == EngineCoreRequestType.WAKEUP:
return
elif request_type == EngineCoreRequestType.ADD:
req, request_wave = request
if self._reject_add_in_shutdown(req):
return
self.add_request(req, request_wave)
elif request_type == EngineCoreRequestType.ABORT:
self.abort_requests(request)
elif request_type == EngineCoreRequestType.UTILITY:
client_idx, call_id, method_name, args = request
if self._reject_utility_in_shutdown(client_idx, call_id, method_name):
return
output = UtilityOutput(call_id)
# Lazily look-up utility method so that failure will be handled/returned.
get_result = lambda: (method := getattr(self, method_name)) and method(
@@ -1199,6 +1260,27 @@ class EngineCoreProc(EngineCore):
"Unrecognized input request type encountered: %s", request_type
)
def _reject_add_in_shutdown(self, request: Request) -> bool:
    """Reject a new request if the engine is no longer in the RUNNING state.

    On rejection, an abort output for the request is sent to its client.

    Returns:
        True if the request was rejected, False if it should be added.
    """
    shutting_down = self.shutdown_state != EngineShutdownState.RUNNING
    if shutting_down:
        logger.info(
            "Rejecting request %s (server shutting down)", request.request_id
        )
        self._send_abort_outputs_to_client(
            [request.request_id], request.client_index
        )
    return shutting_down
def _reject_utility_in_shutdown(
    self, client_idx: int, call_id: int, method_name: str
) -> bool:
    """Reject a utility call if the engine is no longer in the RUNNING state.

    On rejection, a ``UtilityOutput`` carrying a failure message is queued
    for the calling client.

    Returns:
        True if the call was rejected, False if it should proceed.
    """
    if self.shutdown_state != EngineShutdownState.RUNNING:
        logger.warning(
            "Rejecting utility call %s (server shutting down)", method_name
        )
        failure = UtilityOutput(call_id, failure_message="Server shutting down")
        reply = EngineCoreOutputs(utility_output=failure)
        self.output_queue.put_nowait((client_idx, reply))
        return True
    return False
@staticmethod
def _invoke_utility_method(
name: str, get_result: Callable, output: UtilityOutput, enqueue_output: Callable
@@ -1412,22 +1494,7 @@ class EngineCoreProc(EngineCore):
logger.exception(
"Unexpected error pre-processing request %s", request.request_id
)
self.output_queue.put_nowait(
(
request.client_index,
EngineCoreOutputs(
engine_index=self.engine_index,
finished_requests={request.request_id},
outputs=[
EngineCoreOutput(
request_id=request.request_id,
new_token_ids=[],
finish_reason=FinishReason.ERROR,
)
],
),
)
)
self._send_error_outputs_to_client([request.request_id], request.client_index)
def pause_scheduler(
self, mode: PauseMode = "abort", clear_cache: bool = True
@@ -1470,6 +1537,26 @@ class EngineCoreProc(EngineCore):
self._idle_state_callbacks.append(partial(engine_idle_callback, future=future))
return future
def _send_finish_outputs_to_client(
    self, req_ids: list[str], client_index: int, finish_reason: FinishReason
) -> None:
    """Queue one batch that marks every id in ``req_ids`` as finished.

    Each request gets an ``EngineCoreOutput`` with no new tokens and the
    given ``finish_reason``. The batch is queued for ``client_index``.
    """
    finished = []
    for req_id in req_ids:
        finished.append(EngineCoreOutput(req_id, [], finish_reason=finish_reason))
    batch = EngineCoreOutputs(finished_requests=req_ids, outputs=finished)
    self.output_queue.put_nowait((client_index, batch))
def _send_abort_outputs_to_client(
    self, req_ids: list[str], client_index: int
) -> None:
    """Report ``req_ids`` to ``client_index`` as finished with ``ABORT``."""
    reason = FinishReason.ABORT
    self._send_finish_outputs_to_client(req_ids, client_index, reason)
def _send_error_outputs_to_client(
    self, req_ids: list[str], client_index: int
) -> None:
    """Report ``req_ids`` to ``client_index`` as finished with ``ERROR``."""
    reason = FinishReason.ERROR
    self._send_finish_outputs_to_client(req_ids, client_index, reason)
def _send_abort_outputs(self, aborted_reqs: list[tuple[str, int]]) -> None:
# TODO(nick) this will be moved inside the scheduler
if aborted_reqs:
@@ -1478,12 +1565,7 @@ class EngineCoreProc(EngineCore):
for req_id, client_index in aborted_reqs:
by_client[client_index].add(req_id)
for client_index, req_ids in by_client.items():
outputs = [
EngineCoreOutput(req_id, [], finish_reason=FinishReason.ABORT)
for req_id in req_ids
]
eco = EngineCoreOutputs(finished_requests=req_ids, outputs=outputs)
self.output_queue.put_nowait((client_index, eco))
self._send_abort_outputs_to_client(list(req_ids), client_index)
class DPEngineCoreProc(EngineCoreProc):
@@ -1601,7 +1683,7 @@ class DPEngineCoreProc(EngineCoreProc):
"""Core busy loop of the EngineCore for data parallel case."""
# Loop until process is sent a SIGINT or SIGTERM
while True:
while self._handle_shutdown():
# 1) Poll the input queue until there is work to do.
self._process_input_queue()
@@ -1649,6 +1731,8 @@ class DPEngineCoreProc(EngineCoreProc):
self.current_wave += 1
self.step_counter = 0
raise SystemExit
def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
# Optimization - only perform finish-sync all-reduce every 32 steps.
self.step_counter += 1

View File

@@ -128,7 +128,7 @@ class EngineCoreClient(ABC):
return AsyncMPClient(*client_args)
@abstractmethod
def shutdown(self): ...
def shutdown(self, timeout: float | None = None) -> None: ...
def get_output(self) -> EngineCoreOutputs:
raise NotImplementedError
@@ -298,7 +298,7 @@ class InprocClient(EngineCoreClient):
if len(request_ids) > 0:
self.engine_core.abort_requests(request_ids)
def shutdown(self) -> None:
def shutdown(self, timeout: float | None = None) -> None:
self.engine_core.shutdown()
def profile(self, is_start: bool = True, profile_prefix: str | None = None) -> None:
@@ -390,9 +390,9 @@ class BackgroundResources:
self.engine_dead = True
if self.engine_manager is not None:
self.engine_manager.close()
self.engine_manager.shutdown()
if self.coordinator is not None:
self.coordinator.close()
self.coordinator.shutdown()
if isinstance(self.output_socket, zmq.asyncio.Socket):
# Async case.
@@ -581,10 +581,7 @@ class MPClient(EngineCoreClient):
)
with launch_core_engines(
vllm_config,
executor_class,
log_stats,
addresses,
vllm_config, executor_class, log_stats, addresses
) as (engine_manager, coordinator, addresses):
self.resources.coordinator = coordinator
self.resources.engine_manager = engine_manager
@@ -649,9 +646,12 @@ class MPClient(EngineCoreClient):
if not success:
self._finalizer()
def shutdown(self):
# Terminate background resources.
self._finalizer()
def shutdown(self, timeout: float | None = None) -> None:
"""Shutdown engine manager under timeout and clean up resources."""
if self._finalizer.detach() is not None:
if self.resources.engine_manager is not None:
self.resources.engine_manager.shutdown(timeout=timeout)
self.resources()
def _format_exception(self, e: Exception) -> Exception:
"""If errored, use EngineDeadError so root cause is clear."""
@@ -695,7 +695,7 @@ class MPClient(EngineCoreClient):
sentinels = [proc.sentinel for proc in engine_processes]
died = multiprocessing.connection.wait(sentinels)
_self = self_ref()
if not _self or _self.resources.engine_dead:
if not _self or not _self._finalizer.alive or _self.resources.engine_dead:
return
_self.resources.engine_dead = True
proc_name = next(

View File

@@ -3,8 +3,9 @@
import contextlib
import os
import threading
import weakref
from collections.abc import Iterator
from collections.abc import Callable, Iterator
from dataclasses import dataclass
from enum import Enum, auto
from multiprocessing import Process, connection
@@ -146,11 +147,12 @@ class CoreEngineProcManager:
finally:
# Kill other procs if not all are running.
if self.finished_procs():
self.close()
self.shutdown()
def close(self):
"""Shutdown all procs."""
self._finalizer()
def shutdown(self, timeout: float | None = None) -> None:
"""Shutdown engine core processes with configurable timeout."""
if self._finalizer.detach() is not None:
shutdown(self.processes, timeout=timeout)
def join_first(self):
"""Wait for any process to exit."""
@@ -168,6 +170,33 @@ class CoreEngineProcManager:
}
class SignalCallback:
    """Run a callback on a helper thread instead of inside a signal handler.

    A daemon thread is started on construction and blocks on an event.
    ``trigger()`` sets the event and the thread then invokes the callback
    once. ``stop()`` also sets the event, but marks the instance as stopped
    first so the callback is skipped.
    """

    def __init__(self, callback: Callable[[], None]):
        self._callback = callback
        self._event = threading.Event()
        self._stopped = False
        worker = threading.Thread(
            target=self._run,
            daemon=True,
            name="signal-callback",
        )
        self._thread = worker
        worker.start()

    def _run(self):
        # Block until trigger() or stop() wakes us up.
        self._event.wait()
        if self._stopped:
            return
        self._callback()

    def trigger(self):
        """Wake the helper thread so that it invokes the callback."""
        self._event.set()

    def stop(self):
        """Release the helper thread without invoking the callback."""
        # The flag must be set before the event, because the thread reads it
        # as soon as it wakes.
        self._stopped = True
        self._event.set()
@contextlib.contextmanager
def set_device_control_env_var(
vllm_config: VllmConfig, local_dp_rank: int
@@ -763,7 +792,7 @@ class CoreEngineActorManager:
def get_run_refs(self):
return self.run_refs
def close(self):
def shutdown(self, timeout: float | None = None) -> None:
import ray
for actor in self.local_engine_actors + self.remote_engine_actors:

View File

@@ -220,8 +220,10 @@ class APIServerProcessManager:
# The extra processes are managed by their owners
self._finalizer = weakref.finalize(self, shutdown, self.processes)
def close(self) -> None:
self._finalizer()
def shutdown(self, timeout: float | None = None) -> None:
"""Shutdown API server processes with configurable timeout"""
if self._finalizer.detach() is not None:
shutdown(self.processes, timeout=timeout)
def wait_for_completion_or_failure(
@@ -288,25 +290,30 @@ def wait_for_completion_or_failure(
except Exception as e:
logger.exception("Exception occurred while running API servers: %s", str(e))
raise
finally:
logger.info("Terminating remaining processes ...")
api_server_manager.close()
if coordinator:
coordinator.close()
if engine_manager:
engine_manager.close()
# Note(rob): shutdown function cannot be a bound method,
# else the gc cannot collect the object.
def shutdown(procs: list[BaseProcess]):
def shutdown(procs: list[BaseProcess], timeout: float | None = None) -> None:
"""Shutdown processes with timeout.
Args:
procs: List of processes to shutdown
timeout: Maximum time in seconds to wait for graceful shutdown
"""
if timeout is None:
timeout = 0.0
# Allow at least 5 seconds for remaining procs to terminate.
timeout = max(timeout, 5.0)
# Shutdown the process.
for proc in procs:
if proc.is_alive():
proc.terminate()
# Allow 5 seconds for remaining procs to terminate.
deadline = time.monotonic() + 5
# Allow time for remaining procs to terminate.
deadline = time.monotonic() + timeout
for proc in procs:
remaining = deadline - time.monotonic()
if remaining <= 0: