[Frontend][Core] Re-add shutdown timeout - allowing in-flight requests to finish (#36666)
Signed-off-by: Mark McLoughlin <markmc@redhat.com> Signed-off-by: Nick Hill <nickhill123@gmail.com> Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com> Co-authored-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -1,14 +1,20 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Integration tests for shutdown behavior, timeout, and signal handling."""
|
||||
|
||||
import asyncio
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import httpx
|
||||
import openai
|
||||
import psutil
|
||||
import pytest
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.network_utils import get_open_port
|
||||
|
||||
@@ -18,6 +24,101 @@ MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
|
||||
_IS_ROCM = current_platform.is_rocm()
|
||||
_SERVER_STARTUP_TIMEOUT = 120
|
||||
_PROCESS_EXIT_TIMEOUT = 15
|
||||
_SHUTDOWN_DETECTION_TIMEOUT = 10
|
||||
_CHILD_CLEANUP_TIMEOUT = 10
|
||||
|
||||
|
||||
def _get_child_pids(parent_pid: int) -> list[int]:
    """Return the PIDs of all descendants (recursive) of ``parent_pid``.

    An empty list is returned when psutil reports that the process does
    not exist.
    """
    descendant_pids: list[int] = []
    try:
        for child in psutil.Process(parent_pid).children(recursive=True):
            descendant_pids.append(child.pid)
    except psutil.NoSuchProcess:
        return []
    return descendant_pids
|
||||
|
||||
|
||||
async def _assert_children_cleaned_up(
    child_pids: list[int],
    timeout: float = _CHILD_CLEANUP_TIMEOUT,
):
    """Wait for child processes to exit and fail if any remain.

    The PIDs are polled about every 0.5 s. A PID counts as gone when psutil
    raises ``NoSuchProcess`` for it, reports it as not running, or reports
    it as a zombie.

    Args:
        child_pids: PIDs to watch. An empty list returns immediately.
        timeout: Seconds to keep polling before the test is failed. The PIDs
            are always checked at least once, even for ``timeout <= 0``.

    Raises:
        The exception raised by ``pytest.fail`` when at least one PID is
        still alive once the timeout has elapsed.
    """
    if not child_pids:
        return

    def _is_alive(pid: int) -> bool:
        # A vanished process raises NoSuchProcess; treat that as "exited".
        try:
            proc = psutil.Process(pid)
            return proc.is_running() and proc.status() != psutil.STATUS_ZOMBIE
        except psutil.NoSuchProcess:
            return False

    # Monotonic clock so wall-clock adjustments cannot stretch or shrink
    # the waiting period.
    deadline = time.monotonic() + timeout
    while True:
        # Check before testing the deadline: this guarantees `still_alive`
        # is always bound and that the failure message reflects a check made
        # after the final sleep, not a stale one.
        still_alive = [pid for pid in child_pids if _is_alive(pid)]
        if not still_alive:
            return
        if time.monotonic() >= deadline:
            break
        await asyncio.sleep(0.5)

    pytest.fail(
        f"Child processes {still_alive} still alive after {timeout}s. "
        f"Process cleanup may not be working correctly."
    )
|
||||
|
||||
|
||||
@dataclass
class ShutdownState:
    """Mutable record of what the request workers observed during a test.

    One instance is shared by all workers started by
    ``_concurrent_request_loop``; the test reads it after the workers stop.
    """

    # --- HTTP status errors seen by any worker ---
    # True once a request failed with status code 503.
    got_503: bool = False
    # True once a request failed with status code 500.
    got_500: bool = False

    # --- Counters ---
    # Requests that returned a response while the SIGTERM event was set.
    requests_after_sigterm: int = 0
    # Responses in which some choice had finish_reason == "abort".
    aborted_requests: int = 0
    # Requests that failed with a connection / protocol error.
    connection_errors: int = 0

    # --- Control ---
    # Set to True by the test to make the workers leave their loop.
    stop_requesting: bool = False

    # Descriptions of errors that fit none of the categories above.
    errors: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
async def _concurrent_request_loop(
    client: openai.AsyncOpenAI,
    state: ShutdownState,
    sigterm_sent: asyncio.Event | None = None,
    concurrency: int = 10,
):
    """Keep the server busy with ``concurrency`` parallel request workers.

    Every worker issues completion requests back to back (10 ms apart) and
    records the outcome on ``state``. A worker stops when
    ``state.stop_requesting`` is set, on an unexpected exception, or on a
    connection error that occurs after ``sigterm_sent`` has been set.
    Workers still pending when this coroutine is left are cancelled.
    """

    def _after_sigterm() -> bool:
        return sigterm_sent is not None and sigterm_sent.is_set()

    async def _worker() -> None:
        while True:
            if state.stop_requesting:
                return
            try:
                completion = await client.completions.create(
                    model=MODEL_NAME,
                    prompt="Write a story: ",
                    max_tokens=200,
                )
                if _after_sigterm():
                    state.requests_after_sigterm += 1
                # A response counts as aborted if any choice was aborted.
                for choice in completion.choices:
                    if choice.finish_reason == "abort":
                        state.aborted_requests += 1
                        break
            except openai.APIStatusError as exc:
                status = exc.status_code
                if status == 503:
                    state.got_503 = True
                elif status == 500:
                    state.got_500 = True
                else:
                    state.errors.append(f"API error: {exc}")
            except (openai.APIConnectionError, httpx.RemoteProtocolError):
                state.connection_errors += 1
                # After SIGTERM a dropped connection ends this worker;
                # before it, the worker keeps retrying.
                if _after_sigterm():
                    return
            except Exception as exc:
                state.errors.append(f"Unexpected error: {exc}")
                return
            await asyncio.sleep(0.01)

    workers = []
    for _ in range(concurrency):
        workers.append(asyncio.create_task(_worker()))

    try:
        await asyncio.gather(*workers, return_exceptions=True)
    finally:
        # Reached with pending workers when this coroutine itself is
        # cancelled; do not leave them running.
        for worker in workers:
            if not worker.done():
                worker.cancel()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -103,3 +204,361 @@ async def test_shutdown_on_engine_failure():
|
||||
|
||||
return_code = proc.wait(timeout=_PROCESS_EXIT_TIMEOUT)
|
||||
assert return_code is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_wait_timeout_completes_requests():
    """Verify wait timeout: new requests rejected, in-flight requests complete.

    The server is started with ``--shutdown-timeout 30`` and receives SIGTERM
    while ten request workers are running against it.
    """
    server_args = [
        "--dtype", "bfloat16",
        "--max-model-len", "256",
        "--enforce-eager",
        "--gpu-memory-utilization", "0.05",
        "--max-num-seqs", "4",
        "--shutdown-timeout", "30",
    ]

    with RemoteOpenAIServer(MODEL_NAME, server_args) as server:
        api_client = server.get_async_client()
        server_proc = server.proc
        # Snapshot the descendants now; they are checked after shutdown.
        descendants = _get_child_pids(server_proc.pid)

        state = ShutdownState()
        sigterm_sent = asyncio.Event()
        load_task = asyncio.create_task(
            _concurrent_request_loop(api_client, state, sigterm_sent, concurrency=10)
        )

        # Let the load run briefly before signalling the server.
        await asyncio.sleep(0.5)
        server_proc.send_signal(signal.SIGTERM)
        sigterm_sent.set()

        try:
            await asyncio.wait_for(load_task, timeout=_SHUTDOWN_DETECTION_TIMEOUT)
        except asyncio.TimeoutError:
            pass
        finally:
            state.stop_requesting = True
            if not load_task.done():
                load_task.cancel()
            await asyncio.gather(load_task, return_exceptions=True)

        # At least one request must have finished after SIGTERM was sent.
        completed_after_signal = state.requests_after_sigterm > 0
        assert completed_after_signal, (
            f"Wait timeout should complete in-flight requests. "
            f"503: {state.got_503}, 500: {state.got_500}, "
            f"conn_errors: {state.connection_errors}, errors: {state.errors}"
        )

        # Rejection may show up as a 503, a 500, or a closed connection.
        stopped_accepting = (
            state.got_503 or state.got_500 or state.connection_errors > 0
        )
        assert stopped_accepting, (
            f"Server should stop accepting requests. "
            f"completed: {state.requests_after_sigterm}, errors: {state.errors}"
        )

        await _assert_children_cleaned_up(descendants)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize("wait_for_engine_idle", [0.0, 2.0])
async def test_abort_timeout_exits_quickly(wait_for_engine_idle: float):
    """Verify the server exits promptly on SIGTERM with ``--shutdown-timeout 0``.

    With ``wait_for_engine_idle > 0`` two requests are served first and the
    test then sleeps for that many seconds before sending the signal;
    otherwise the signal is sent right after startup.
    """
    server_args = [
        "--dtype", "bfloat16",
        "--max-model-len", "256",
        "--enforce-eager",
        "--gpu-memory-utilization", "0.05",
        "--max-num-seqs", "4",
        "--shutdown-timeout", "0",
    ]

    with RemoteOpenAIServer(MODEL_NAME, server_args) as server:
        server_proc = server.proc
        descendants = _get_child_pids(server_proc.pid)

        if wait_for_engine_idle > 0:
            api_client = server.get_async_client()
            # Serve a couple of requests so the engine is fully initialized.
            warmups_left = 2
            while warmups_left:
                await api_client.completions.create(
                    model=MODEL_NAME,
                    prompt="Test request: ",
                    max_tokens=10,
                )
                warmups_left -= 1
            # Then give the engine time to go idle.
            await asyncio.sleep(wait_for_engine_idle)

        signalled_at = time.time()
        server_proc.send_signal(signal.SIGTERM)

        # Poll for exit: at most 20 checks, 0.1 s apart.
        polls_left = 20
        while polls_left > 0 and server_proc.poll() is None:
            time.sleep(0.1)
            polls_left -= 1

        if server_proc.poll() is None:
            # Still running: reap it ourselves before failing the test.
            server_proc.kill()
            server_proc.wait(timeout=5)
            pytest.fail("Process did not exit after SIGTERM with abort timeout")

        exit_time = time.time() - signalled_at
        assert exit_time < 2, f"Default shutdown took too long: {exit_time:.1f}s"
        assert server_proc.returncode in (0, -15, None), (
            f"Unexpected: {server_proc.returncode}"
        )

        await _assert_children_cleaned_up(descendants)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_wait_timeout_with_short_duration():
    """Verify server exits cleanly with a short wait timeout.

    The server runs with ``--shutdown-timeout 3``; three request workers
    are started before SIGTERM is sent.
    """
    wait_timeout = 3
    server_args = [
        "--dtype", "bfloat16",
        "--max-model-len", "256",
        "--enforce-eager",
        "--gpu-memory-utilization", "0.05",
        "--max-num-seqs", "4",
        "--shutdown-timeout", str(wait_timeout),
    ]

    with RemoteOpenAIServer(MODEL_NAME, server_args) as server:
        api_client = server.get_async_client()
        server_proc = server.proc
        descendants = _get_child_pids(server_proc.pid)

        state = ShutdownState()
        load_task = asyncio.create_task(
            _concurrent_request_loop(api_client, state, concurrency=3)
        )

        # Let the load run briefly before signalling the server.
        await asyncio.sleep(0.5)

        signalled_at = time.time()
        server_proc.send_signal(signal.SIGTERM)

        # The server gets wait_timeout plus a buffer to exit.
        # NOTE(review): time.sleep() blocks the event loop, so load_task
        # makes no progress while this loop polls.
        max_wait = wait_timeout + 15
        polls_left = int(max_wait * 10)
        while polls_left > 0 and server_proc.poll() is None:
            time.sleep(0.1)
            polls_left -= 1

        exit_time = time.time() - signalled_at

        # Stop the load generator before evaluating the outcome.
        state.stop_requesting = True
        if not load_task.done():
            load_task.cancel()
        await asyncio.gather(load_task, return_exceptions=True)

        if server_proc.poll() is None:
            server_proc.kill()
            server_proc.wait(timeout=5)
            pytest.fail(f"Process did not exit within {max_wait}s after SIGTERM")

        exited_in_time = exit_time < wait_timeout + 10
        assert exited_in_time, (
            f"Took too long to exit ({exit_time:.1f}s), expected <{wait_timeout + 10}s"
        )
        assert server_proc.returncode in (0, -15, None), (
            f"Unexpected: {server_proc.returncode}"
        )

        await _assert_children_cleaned_up(descendants)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_abort_timeout_fails_inflight_requests():
    """Verify abort timeout (0) immediately aborts in-flight requests.

    Ten request workers run against a server started with
    ``--shutdown-timeout 0``; after SIGTERM at least one of them must see an
    aborted response, a connection error, or a 500/503 status.
    """
    server_args = [
        "--dtype", "bfloat16",
        "--max-model-len", "256",
        "--enforce-eager",
        "--gpu-memory-utilization", "0.05",
        "--max-num-seqs", "4",
        "--shutdown-timeout", "0",
    ]

    with RemoteOpenAIServer(MODEL_NAME, server_args) as server:
        api_client = server.get_async_client()
        server_proc = server.proc
        descendants = _get_child_pids(server_proc.pid)

        state = ShutdownState()
        sigterm_sent = asyncio.Event()
        load_task = asyncio.create_task(
            _concurrent_request_loop(api_client, state, sigterm_sent, concurrency=10)
        )

        # Let the load run briefly before signalling the server.
        await asyncio.sleep(0.5)

        server_proc.send_signal(signal.SIGTERM)
        sigterm_sent.set()

        try:
            await asyncio.wait_for(load_task, timeout=5)
        except asyncio.TimeoutError:
            pass
        finally:
            state.stop_requesting = True
            if not load_task.done():
                load_task.cancel()
            await asyncio.gather(load_task, return_exceptions=True)

        # Any one of these outcomes shows that requests were cut short:
        # finish_reason='abort', a connection error, or an API error status.
        requests_were_cut_short = (
            state.aborted_requests > 0
            or state.connection_errors > 0
            or state.got_500
            or state.got_503
        )
        assert requests_were_cut_short, (
            f"Abort timeout should cause request aborts or failures. "
            f"aborted: {state.aborted_requests}, "
            f"503: {state.got_503}, 500: {state.got_500}, "
            f"conn_errors: {state.connection_errors}, "
            f"completed: {state.requests_after_sigterm}"
        )

        # Fast-shutdown check; the clock starts here, i.e. after the
        # workers above have already been wound down.
        poll_started_at = time.time()
        polls_left = 100
        while polls_left > 0 and server_proc.poll() is None:
            time.sleep(0.1)
            polls_left -= 1

        exit_time = time.time() - poll_started_at
        assert exit_time < 10, f"Abort timeout shutdown took too long: {exit_time:.1f}s"

        await _assert_children_cleaned_up(descendants)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_request_rejection_during_shutdown():
    """Verify new requests are rejected with error during shutdown.

    SIGTERM is sent first; one second later ten requests are attempted and
    at least one of them has to fail with an API or connection error.
    """
    server_args = [
        "--dtype", "bfloat16",
        "--max-model-len", "256",
        "--enforce-eager",
        "--gpu-memory-utilization", "0.05",
        "--max-num-seqs", "4",
        "--shutdown-timeout", "30",
    ]
    # Exception types that count as "request rejected".
    rejection_errors = (
        openai.APIStatusError,
        openai.APIConnectionError,
        httpx.RemoteProtocolError,
    )

    with RemoteOpenAIServer(MODEL_NAME, server_args) as server:
        api_client = server.get_async_client()
        server_proc = server.proc
        descendants = _get_child_pids(server_proc.pid)

        server_proc.send_signal(signal.SIGTERM)

        await asyncio.sleep(1.0)

        # Requests sent from here on are expected to fail.
        rejected_count = 0
        attempts_left = 10
        while attempts_left > 0:
            attempts_left -= 1
            try:
                await api_client.completions.create(
                    model=MODEL_NAME, prompt="Hello", max_tokens=10
                )
            except rejection_errors:
                rejected_count += 1
            await asyncio.sleep(0.1)

        assert rejected_count > 0, (
            f"Expected requests to be rejected during shutdown, "
            f"but {rejected_count} were rejected out of 10"
        )

        await _assert_children_cleaned_up(descendants)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_multi_api_server_shutdown():
    """Verify shutdown works with multiple API servers.

    The server is launched with ``--api-server-count 2``; after SIGTERM the
    parent must exit within about 30 s and none of the child processes
    recorded beforehand may survive.
    """
    server_args = [
        "--dtype", "bfloat16",
        "--max-model-len", "256",
        "--enforce-eager",
        "--gpu-memory-utilization", "0.05",
        "--max-num-seqs", "4",
        "--shutdown-timeout", "30",
        "--api-server-count", "2",
    ]

    with RemoteOpenAIServer(MODEL_NAME, server_args, auto_port=True) as server:
        api_client = server.get_async_client()
        server_proc = server.proc
        child_pids = _get_child_pids(server_proc.pid)

        enough_children = len(child_pids) >= 2
        assert enough_children, (
            f"Expected at least 2 child processes, got {len(child_pids)}"
        )

        state = ShutdownState()
        sigterm_sent = asyncio.Event()

        # Generate concurrent load before the signal is sent.
        load_task = asyncio.create_task(
            _concurrent_request_loop(api_client, state, sigterm_sent, concurrency=8)
        )

        await asyncio.sleep(0.5)

        # Only the parent is signalled; the children are expected to be
        # gone by the end of the test.
        server_proc.send_signal(signal.SIGTERM)
        sigterm_sent.set()

        try:
            await asyncio.wait_for(load_task, timeout=_SHUTDOWN_DETECTION_TIMEOUT)
        except asyncio.TimeoutError:
            pass
        finally:
            state.stop_requesting = True
            if not load_task.done():
                load_task.cancel()
            await asyncio.gather(load_task, return_exceptions=True)

        # Poll for exit: 300 checks, 0.1 s apart (up to 30 seconds).
        polls_left = 300
        while polls_left > 0 and server_proc.poll() is None:
            time.sleep(0.1)
            polls_left -= 1

        if server_proc.poll() is None:
            server_proc.kill()
            server_proc.wait(timeout=5)
            pytest.fail("Process did not exit after SIGTERM")

        await _assert_children_cleaned_up(child_pids)
|
||||
|
||||
@@ -79,7 +79,7 @@ def test_api_server_process_manager_init(api_server_args, with_stats_update):
|
||||
finally:
|
||||
# Always clean up the processes
|
||||
print("Cleaning up processes...")
|
||||
manager.close()
|
||||
manager.shutdown()
|
||||
|
||||
# Give processes time to terminate
|
||||
time.sleep(0.2)
|
||||
@@ -111,6 +111,8 @@ def test_wait_for_completion_or_failure(api_server_args):
|
||||
wait_for_completion_or_failure(api_server_manager=manager)
|
||||
except Exception as e:
|
||||
result["exception"] = e
|
||||
finally:
|
||||
manager.shutdown()
|
||||
|
||||
# Start a thread to run wait_for_completion_or_failure
|
||||
wait_thread = threading.Thread(target=run_with_exception_capture, daemon=True)
|
||||
@@ -143,7 +145,7 @@ def test_wait_for_completion_or_failure(api_server_args):
|
||||
assert not proc.is_alive(), f"Process {i} should not be alive"
|
||||
|
||||
finally:
|
||||
manager.close()
|
||||
manager.shutdown()
|
||||
time.sleep(0.2)
|
||||
|
||||
|
||||
@@ -174,11 +176,14 @@ def test_normal_completion(api_server_args):
|
||||
# since all processes have already
|
||||
# terminated, it should return immediately
|
||||
# with no error
|
||||
wait_for_completion_or_failure(api_server_manager=manager)
|
||||
try:
|
||||
wait_for_completion_or_failure(api_server_manager=manager)
|
||||
finally:
|
||||
manager.shutdown()
|
||||
|
||||
finally:
|
||||
# Clean up just in case
|
||||
manager.close()
|
||||
manager.shutdown()
|
||||
time.sleep(0.2)
|
||||
|
||||
|
||||
@@ -201,7 +206,7 @@ def test_external_process_monitoring(api_server_args):
|
||||
def __init__(self, proc):
|
||||
self.proc = proc
|
||||
|
||||
def close(self):
|
||||
def shutdown(self):
|
||||
if self.proc.is_alive():
|
||||
self.proc.terminate()
|
||||
self.proc.join(timeout=0.5)
|
||||
@@ -226,6 +231,9 @@ def test_external_process_monitoring(api_server_args):
|
||||
)
|
||||
except Exception as e:
|
||||
result["exception"] = e
|
||||
finally:
|
||||
manager.shutdown()
|
||||
mock_coordinator.shutdown()
|
||||
|
||||
# Start a thread to run wait_for_completion_or_failure
|
||||
wait_thread = threading.Thread(target=run_with_exception_capture, daemon=True)
|
||||
@@ -259,6 +267,6 @@ def test_external_process_monitoring(api_server_args):
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
manager.close()
|
||||
mock_coordinator.close()
|
||||
manager.shutdown()
|
||||
mock_coordinator.shutdown()
|
||||
time.sleep(0.2)
|
||||
|
||||
@@ -327,6 +327,12 @@ class VllmConfig:
|
||||
weight_transfer_config: WeightTransferConfig | None = None
|
||||
"""The configurations for weight transfer during RL training."""
|
||||
|
||||
shutdown_timeout: int = Field(default=0, ge=0)
|
||||
"""Shutdown grace period for in-flight requests. Shutdown will be delayed for
|
||||
up to this amount of time to allow already-running requests to complete. Any
|
||||
remaining requests are aborted once the timeout is reached.
|
||||
"""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
|
||||
@@ -606,6 +606,8 @@ class EngineArgs:
|
||||
kv_offloading_backend: KVOffloadingBackend = CacheConfig.kv_offloading_backend
|
||||
tokens_only: bool = False
|
||||
|
||||
shutdown_timeout: int = 0
|
||||
|
||||
weight_transfer_config: WeightTransferConfig | None = get_field(
|
||||
VllmConfig,
|
||||
"weight_transfer_config",
|
||||
@@ -1308,6 +1310,14 @@ class EngineArgs:
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--shutdown-timeout",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Shutdown timeout in seconds. 0 = abort, >0 = wait.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
@classmethod
|
||||
@@ -1916,6 +1926,7 @@ class EngineArgs:
|
||||
optimization_level=self.optimization_level,
|
||||
performance_mode=self.performance_mode,
|
||||
weight_transfer_config=self.weight_transfer_config,
|
||||
shutdown_timeout=self.shutdown_timeout,
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
@@ -200,6 +200,11 @@ class EngineClient(ABC):
|
||||
"""Return whether the engine is currently paused."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def shutdown(self, timeout: float | None = None) -> None:
|
||||
"""Shutdown the engine with optional timeout."""
|
||||
...
|
||||
|
||||
async def scale_elastic_ep(
|
||||
self, new_data_parallel_size: int, drain_timeout: int = 300
|
||||
) -> None:
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
import argparse
|
||||
import signal
|
||||
import time
|
||||
|
||||
import uvloop
|
||||
|
||||
@@ -222,8 +223,12 @@ def run_headless(args: argparse.Namespace):
|
||||
try:
|
||||
engine_manager.join_first()
|
||||
finally:
|
||||
timeout = None
|
||||
if shutdown_requested:
|
||||
timeout = vllm_config.shutdown_timeout
|
||||
logger.info("Waiting up to %d seconds for processes to exit", timeout)
|
||||
engine_manager.shutdown(timeout=timeout)
|
||||
logger.info("Shutting down.")
|
||||
engine_manager.close()
|
||||
|
||||
|
||||
def run_multi_api_server(args: argparse.Namespace):
|
||||
@@ -234,6 +239,19 @@ def run_multi_api_server(args: argparse.Namespace):
|
||||
if num_api_servers > 1:
|
||||
setup_multiprocess_prometheus()
|
||||
|
||||
shutdown_requested = False
|
||||
|
||||
# Catch SIGTERM and SIGINT to allow graceful shutdown.
|
||||
def signal_handler(signum, frame):
|
||||
nonlocal shutdown_requested
|
||||
logger.debug("Received %d signal.", signum)
|
||||
if not shutdown_requested:
|
||||
shutdown_requested = True
|
||||
raise SystemExit
|
||||
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
listen_address, sock = setup_server(args)
|
||||
|
||||
engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
|
||||
@@ -295,11 +313,29 @@ def run_multi_api_server(args: argparse.Namespace):
|
||||
api_server_manager = APIServerProcessManager(**api_server_manager_kwargs)
|
||||
|
||||
# Wait for API servers
|
||||
wait_for_completion_or_failure(
|
||||
api_server_manager=api_server_manager,
|
||||
engine_manager=local_engine_manager,
|
||||
coordinator=coordinator,
|
||||
)
|
||||
try:
|
||||
wait_for_completion_or_failure(
|
||||
api_server_manager=api_server_manager,
|
||||
engine_manager=local_engine_manager,
|
||||
coordinator=coordinator,
|
||||
)
|
||||
finally:
|
||||
timeout = shutdown_by = None
|
||||
if shutdown_requested:
|
||||
timeout = vllm_config.shutdown_timeout
|
||||
shutdown_by = time.monotonic() + timeout
|
||||
logger.info("Waiting up to %d seconds for processes to exit", timeout)
|
||||
|
||||
def to_timeout(deadline: float | None) -> float | None:
|
||||
return (
|
||||
deadline if deadline is None else max(deadline - time.monotonic(), 0.0)
|
||||
)
|
||||
|
||||
api_server_manager.shutdown(timeout=timeout)
|
||||
if local_engine_manager:
|
||||
local_engine_manager.shutdown(timeout=to_timeout(shutdown_by))
|
||||
if coordinator:
|
||||
coordinator.shutdown(timeout=to_timeout(shutdown_by))
|
||||
|
||||
|
||||
def run_api_server_worker_proc(
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
import asyncio
|
||||
import signal
|
||||
import socket
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
import uvicorn
|
||||
@@ -91,12 +92,10 @@ async def serve_http(
|
||||
)
|
||||
)
|
||||
|
||||
shutdown_event = asyncio.Event()
|
||||
|
||||
def signal_handler() -> None:
|
||||
# prevents the uvicorn signal handler to exit early
|
||||
server_task.cancel()
|
||||
watchdog_task.cancel()
|
||||
if ssl_cert_refresher:
|
||||
ssl_cert_refresher.stop()
|
||||
shutdown_event.set()
|
||||
|
||||
async def dummy_shutdown() -> None:
|
||||
pass
|
||||
@@ -104,6 +103,24 @@ async def serve_http(
|
||||
loop.add_signal_handler(signal.SIGINT, signal_handler)
|
||||
loop.add_signal_handler(signal.SIGTERM, signal_handler)
|
||||
|
||||
async def handle_shutdown() -> None:
|
||||
await shutdown_event.wait()
|
||||
|
||||
engine_client = app.state.engine_client
|
||||
timeout = engine_client.vllm_config.shutdown_timeout
|
||||
|
||||
await loop.run_in_executor(
|
||||
None, partial(engine_client.shutdown, timeout=timeout)
|
||||
)
|
||||
|
||||
server.should_exit = True
|
||||
server_task.cancel()
|
||||
watchdog_task.cancel()
|
||||
if ssl_cert_refresher:
|
||||
ssl_cert_refresher.stop()
|
||||
|
||||
shutdown_task = loop.create_task(handle_shutdown())
|
||||
|
||||
try:
|
||||
await server_task
|
||||
return dummy_shutdown()
|
||||
@@ -120,6 +137,7 @@ async def serve_http(
|
||||
logger.info("Shutting down FastAPI HTTP server.")
|
||||
return server.shutdown()
|
||||
finally:
|
||||
shutdown_task.cancel()
|
||||
watchdog_task.cancel()
|
||||
|
||||
|
||||
|
||||
@@ -226,6 +226,8 @@ class EngineCoreRequestType(enum.Enum):
|
||||
UTILITY = b"\x03"
|
||||
# Sentinel used within EngineCoreProc.
|
||||
EXECUTOR_FAILED = b"\x04"
|
||||
# Sentinel to wake up input_queue.get() during shutdown.
|
||||
WAKEUP = b"\x05"
|
||||
|
||||
|
||||
class ReconfigureDistributedRequest(msgspec.Struct):
|
||||
|
||||
@@ -264,16 +264,15 @@ class AsyncLLM(EngineClient):
|
||||
def __del__(self):
|
||||
self.shutdown()
|
||||
|
||||
def shutdown(self):
|
||||
def shutdown(self, timeout: float | None = None) -> None:
|
||||
"""Shutdown, cleaning up the background proc and IPC."""
|
||||
|
||||
shutdown_prometheus()
|
||||
|
||||
if renderer := getattr(self, "renderer", None):
|
||||
renderer.shutdown()
|
||||
|
||||
if engine_core := getattr(self, "engine_core", None):
|
||||
engine_core.shutdown()
|
||||
engine_core.shutdown(timeout=timeout)
|
||||
|
||||
handler = getattr(self, "output_handler", None)
|
||||
if handler is not None:
|
||||
|
||||
@@ -104,8 +104,10 @@ class DPCoordinator:
|
||||
"""Returns tuple of ZMQ input address, output address."""
|
||||
return self.coord_in_address, self.coord_out_address
|
||||
|
||||
def close(self):
|
||||
self._finalizer()
|
||||
def shutdown(self, timeout: float | None = None) -> None:
|
||||
"""Shutdown coordinator process with configurable timeout."""
|
||||
if self._finalizer.detach() is not None:
|
||||
shutdown([self.proc], timeout=timeout)
|
||||
|
||||
|
||||
class EngineState:
|
||||
|
||||
@@ -9,6 +9,7 @@ from collections import defaultdict, deque
|
||||
from collections.abc import Callable, Generator
|
||||
from concurrent.futures import Future
|
||||
from contextlib import ExitStack, contextmanager
|
||||
from enum import IntEnum
|
||||
from functools import partial
|
||||
from inspect import isclass, signature
|
||||
from logging import DEBUG
|
||||
@@ -61,6 +62,7 @@ from vllm.v1.engine import (
|
||||
from vllm.v1.engine.utils import (
|
||||
EngineHandshakeMetadata,
|
||||
EngineZmqAddresses,
|
||||
SignalCallback,
|
||||
get_device_indices,
|
||||
)
|
||||
from vllm.v1.executor import Executor
|
||||
@@ -765,6 +767,12 @@ class EngineCore:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class EngineShutdownState(IntEnum):
    """Shutdown progress of the engine core process."""

    # No shutdown has been requested.
    RUNNING = 0

    # A shutdown has been requested but not acted upon yet.
    REQUESTED = 1

    # The request has been acted upon; shutdown is in progress.
    SHUTTING_DOWN = 2
|
||||
|
||||
|
||||
class EngineCoreProc(EngineCore):
|
||||
"""ZMQ-wrapper for running EngineCore in background process."""
|
||||
|
||||
@@ -792,6 +800,7 @@ class EngineCoreProc(EngineCore):
|
||||
self.engine_index = engine_index
|
||||
identity = self.engine_index.to_bytes(length=2, byteorder="little")
|
||||
self.engines_running = False
|
||||
self.shutdown_state = EngineShutdownState.RUNNING
|
||||
|
||||
with self._perform_handshakes(
|
||||
handshake_address,
|
||||
@@ -1020,25 +1029,11 @@ class EngineCoreProc(EngineCore):
|
||||
def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs):
|
||||
"""Launch EngineCore busy loop in background process."""
|
||||
|
||||
# Signal handler used for graceful termination.
|
||||
# SystemExit exception is only raised once to allow this and worker
|
||||
# processes to terminate without error
|
||||
shutdown_requested = False
|
||||
|
||||
# Ensure we can serialize transformer config after spawning
|
||||
maybe_register_config_serialize_by_value()
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
nonlocal shutdown_requested
|
||||
if not shutdown_requested:
|
||||
shutdown_requested = True
|
||||
raise SystemExit()
|
||||
|
||||
# Either SIGTERM or SIGINT will terminate the engine_core
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
engine_core: EngineCoreProc | None = None
|
||||
signal_callback: SignalCallback | None = None
|
||||
try:
|
||||
vllm_config: VllmConfig = kwargs["vllm_config"]
|
||||
parallel_config: ParallelConfig = vllm_config.parallel_config
|
||||
@@ -1078,6 +1073,22 @@ class EngineCoreProc(EngineCore):
|
||||
engine_core = EngineCoreProc(*args, engine_index=dp_rank, **kwargs)
|
||||
|
||||
assert engine_core is not None
|
||||
|
||||
def wakeup_engine():
|
||||
# Wakes up idle engine via input_queue when shutdown is requested
|
||||
# Not safe in a signal handler - we may interrupt the main thread
|
||||
# while it is holding the non-reentrant input_queue.mutex
|
||||
engine_core.input_queue.put_nowait((EngineCoreRequestType.WAKEUP, None))
|
||||
|
||||
signal_callback = SignalCallback(wakeup_engine)
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
engine_core.shutdown_state = EngineShutdownState.REQUESTED
|
||||
signal_callback.trigger()
|
||||
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
engine_core.run_busy_loop()
|
||||
|
||||
except SystemExit:
|
||||
@@ -1091,6 +1102,10 @@ class EngineCoreProc(EngineCore):
|
||||
engine_core._send_engine_dead()
|
||||
raise e
|
||||
finally:
|
||||
signal.signal(signal.SIGTERM, signal.SIG_DFL)
|
||||
signal.signal(signal.SIGINT, signal.SIG_DFL)
|
||||
if signal_callback is not None:
|
||||
signal_callback.stop()
|
||||
if engine_core is not None:
|
||||
engine_core.shutdown()
|
||||
|
||||
@@ -1105,21 +1120,25 @@ class EngineCoreProc(EngineCore):
|
||||
or bool(self.batch_queue)
|
||||
)
|
||||
|
||||
def is_running(self) -> bool:
    """Report whether the engine is still in the RUNNING state.

    Returns False as soon as a shutdown has been requested or is underway.
    """
    no_shutdown_requested = self.shutdown_state == EngineShutdownState.RUNNING
    return no_shutdown_requested
|
||||
|
||||
def run_busy_loop(self):
|
||||
"""Core busy loop of the EngineCore."""
|
||||
|
||||
# Loop until process is sent a SIGINT or SIGTERM
|
||||
while True:
|
||||
while self._handle_shutdown():
|
||||
# 1) Poll the input queue until there is work to do.
|
||||
self._process_input_queue()
|
||||
# 2) Step the engine core and return the outputs.
|
||||
self._process_engine_step()
|
||||
|
||||
raise SystemExit
|
||||
|
||||
def _process_input_queue(self):
|
||||
"""Exits when an engine step needs to be performed."""
|
||||
|
||||
waited = False
|
||||
while not self.has_work():
|
||||
while not self.has_work() and self.is_running():
|
||||
# Notify callbacks waiting for engine to become idle.
|
||||
self._notify_idle_state_callbacks()
|
||||
if self.input_queue.empty():
|
||||
@@ -1171,18 +1190,60 @@ class EngineCoreProc(EngineCore):
|
||||
callback = self._idle_state_callbacks.pop()
|
||||
callback(self)
|
||||
|
||||
def _handle_shutdown(self) -> bool:
    """Process a pending shutdown request and tell the caller whether to go on.

    Returns:
        True while the engine is RUNNING or still has work to do; False
        once a shutdown is underway and no work remains.
    """
    if self.shutdown_state == EngineShutdownState.RUNNING:
        return True

    if self.shutdown_state == EngineShutdownState.REQUESTED:
        # First pass after the request: log it and, for a zero timeout,
        # abort whatever is still unfinished.
        grace_period = self.vllm_config.shutdown_timeout
        logger.info("Shutdown initiated (timeout=%d)", grace_period)

        pending = self.scheduler.get_num_unfinished_requests()
        if grace_period == 0:
            if pending > 0:
                logger.info("Aborting %d requests", pending)
            aborted = self.scheduler.finish_requests(
                None, RequestStatus.FINISHED_ABORTED
            )
            self._send_abort_outputs(aborted)
        elif pending > 0:
            logger.info(
                "Draining %d in-flight requests (timeout=%ds)",
                pending,
                grace_period,
            )

        self.shutdown_state = EngineShutdownState.SHUTTING_DOWN

    # Shutdown is underway: keep going only while work remains.
    if self.has_work():
        return True

    logger.info("Shutdown complete")
    return False
|
||||
|
||||
def _handle_client_request(
|
||||
self, request_type: EngineCoreRequestType, request: Any
|
||||
) -> None:
|
||||
"""Dispatch request from client."""
|
||||
|
||||
if request_type == EngineCoreRequestType.ADD:
|
||||
if request_type == EngineCoreRequestType.WAKEUP:
|
||||
return
|
||||
elif request_type == EngineCoreRequestType.ADD:
|
||||
req, request_wave = request
|
||||
if self._reject_add_in_shutdown(req):
|
||||
return
|
||||
self.add_request(req, request_wave)
|
||||
elif request_type == EngineCoreRequestType.ABORT:
|
||||
self.abort_requests(request)
|
||||
elif request_type == EngineCoreRequestType.UTILITY:
|
||||
client_idx, call_id, method_name, args = request
|
||||
if self._reject_utility_in_shutdown(client_idx, call_id, method_name):
|
||||
return
|
||||
output = UtilityOutput(call_id)
|
||||
# Lazily look-up utility method so that failure will be handled/returned.
|
||||
get_result = lambda: (method := getattr(self, method_name)) and method(
|
||||
@@ -1199,6 +1260,27 @@ class EngineCoreProc(EngineCore):
|
||||
"Unrecognized input request type encountered: %s", request_type
|
||||
)
|
||||
|
||||
def _reject_add_in_shutdown(self, request: Request) -> bool:
    """Refuse a newly added request when the engine is no longer RUNNING.

    When rejecting, an abort output for the request is sent to the client
    identified by ``request.client_index``.

    Returns:
        True if the request was rejected, False if it may be added.
    """
    if self.shutdown_state != EngineShutdownState.RUNNING:
        logger.info(
            "Rejecting request %s (server shutting down)", request.request_id
        )
        self._send_abort_outputs_to_client(
            [request.request_id], request.client_index
        )
        return True

    return False
|
||||
|
||||
def _reject_utility_in_shutdown(
    self, client_idx: int, call_id: int, method_name: str
) -> bool:
    """Refuse a utility call when the engine is no longer RUNNING.

    When rejecting, a ``UtilityOutput`` carrying a failure message is queued
    for the calling client.

    Returns:
        True if the call was rejected, False if it may proceed.
    """
    if self.shutdown_state != EngineShutdownState.RUNNING:
        logger.warning(
            "Rejecting utility call %s (server shutting down)", method_name
        )
        failure = UtilityOutput(call_id, failure_message="Server shutting down")
        reply = EngineCoreOutputs(utility_output=failure)
        self.output_queue.put_nowait((client_idx, reply))
        return True

    return False
|
||||
|
||||
@staticmethod
|
||||
def _invoke_utility_method(
|
||||
name: str, get_result: Callable, output: UtilityOutput, enqueue_output: Callable
|
||||
@@ -1412,22 +1494,7 @@ class EngineCoreProc(EngineCore):
|
||||
logger.exception(
|
||||
"Unexpected error pre-processing request %s", request.request_id
|
||||
)
|
||||
self.output_queue.put_nowait(
|
||||
(
|
||||
request.client_index,
|
||||
EngineCoreOutputs(
|
||||
engine_index=self.engine_index,
|
||||
finished_requests={request.request_id},
|
||||
outputs=[
|
||||
EngineCoreOutput(
|
||||
request_id=request.request_id,
|
||||
new_token_ids=[],
|
||||
finish_reason=FinishReason.ERROR,
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
)
|
||||
self._send_error_outputs_to_client([request.request_id], request.client_index)
|
||||
|
||||
def pause_scheduler(
|
||||
self, mode: PauseMode = "abort", clear_cache: bool = True
|
||||
@@ -1470,6 +1537,26 @@ class EngineCoreProc(EngineCore):
|
||||
self._idle_state_callbacks.append(partial(engine_idle_callback, future=future))
|
||||
return future
|
||||
|
||||
def _send_finish_outputs_to_client(
    self, req_ids: list[str], client_index: int, finish_reason: FinishReason
) -> None:
    """Queue one message marking every request in ``req_ids`` as finished.

    Each request gets an ``EngineCoreOutput`` with no new tokens and the
    given ``finish_reason``; the batch is put on the output queue for
    ``client_index``.
    """
    per_request = [
        EngineCoreOutput(rid, [], finish_reason=finish_reason) for rid in req_ids
    ]
    message = EngineCoreOutputs(finished_requests=req_ids, outputs=per_request)
    self.output_queue.put_nowait((client_index, message))
|
||||
|
||||
def _send_abort_outputs_to_client(
    self, req_ids: list[str], client_index: int
) -> None:
    """Report the given requests to ``client_index`` as finished with ABORT."""
    self._send_finish_outputs_to_client(
        req_ids=req_ids,
        client_index=client_index,
        finish_reason=FinishReason.ABORT,
    )
|
||||
|
||||
def _send_error_outputs_to_client(
    self, req_ids: list[str], client_index: int
) -> None:
    """Report the given requests to ``client_index`` as finished with ERROR."""
    self._send_finish_outputs_to_client(
        req_ids=req_ids,
        client_index=client_index,
        finish_reason=FinishReason.ERROR,
    )
|
||||
|
||||
def _send_abort_outputs(self, aborted_reqs: list[tuple[str, int]]) -> None:
|
||||
# TODO(nick) this will be moved inside the scheduler
|
||||
if aborted_reqs:
|
||||
@@ -1478,12 +1565,7 @@ class EngineCoreProc(EngineCore):
|
||||
for req_id, client_index in aborted_reqs:
|
||||
by_client[client_index].add(req_id)
|
||||
for client_index, req_ids in by_client.items():
|
||||
outputs = [
|
||||
EngineCoreOutput(req_id, [], finish_reason=FinishReason.ABORT)
|
||||
for req_id in req_ids
|
||||
]
|
||||
eco = EngineCoreOutputs(finished_requests=req_ids, outputs=outputs)
|
||||
self.output_queue.put_nowait((client_index, eco))
|
||||
self._send_abort_outputs_to_client(list(req_ids), client_index)
|
||||
|
||||
|
||||
class DPEngineCoreProc(EngineCoreProc):
|
||||
@@ -1601,7 +1683,7 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
"""Core busy loop of the EngineCore for data parallel case."""
|
||||
|
||||
# Loop until process is sent a SIGINT or SIGTERM
|
||||
while True:
|
||||
while self._handle_shutdown():
|
||||
# 1) Poll the input queue until there is work to do.
|
||||
self._process_input_queue()
|
||||
|
||||
@@ -1649,6 +1731,8 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
self.current_wave += 1
|
||||
self.step_counter = 0
|
||||
|
||||
raise SystemExit
|
||||
|
||||
def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
|
||||
# Optimization - only perform finish-sync all-reduce every 32 steps.
|
||||
self.step_counter += 1
|
||||
|
||||
@@ -128,7 +128,7 @@ class EngineCoreClient(ABC):
|
||||
return AsyncMPClient(*client_args)
|
||||
|
||||
@abstractmethod
|
||||
def shutdown(self): ...
|
||||
def shutdown(self, timeout: float | None = None) -> None: ...
|
||||
|
||||
def get_output(self) -> EngineCoreOutputs:
|
||||
raise NotImplementedError
|
||||
@@ -298,7 +298,7 @@ class InprocClient(EngineCoreClient):
|
||||
if len(request_ids) > 0:
|
||||
self.engine_core.abort_requests(request_ids)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
def shutdown(self, timeout: float | None = None) -> None:
|
||||
self.engine_core.shutdown()
|
||||
|
||||
def profile(self, is_start: bool = True, profile_prefix: str | None = None) -> None:
|
||||
@@ -390,9 +390,9 @@ class BackgroundResources:
|
||||
|
||||
self.engine_dead = True
|
||||
if self.engine_manager is not None:
|
||||
self.engine_manager.close()
|
||||
self.engine_manager.shutdown()
|
||||
if self.coordinator is not None:
|
||||
self.coordinator.close()
|
||||
self.coordinator.shutdown()
|
||||
|
||||
if isinstance(self.output_socket, zmq.asyncio.Socket):
|
||||
# Async case.
|
||||
@@ -581,10 +581,7 @@ class MPClient(EngineCoreClient):
|
||||
)
|
||||
|
||||
with launch_core_engines(
|
||||
vllm_config,
|
||||
executor_class,
|
||||
log_stats,
|
||||
addresses,
|
||||
vllm_config, executor_class, log_stats, addresses
|
||||
) as (engine_manager, coordinator, addresses):
|
||||
self.resources.coordinator = coordinator
|
||||
self.resources.engine_manager = engine_manager
|
||||
@@ -649,9 +646,12 @@ class MPClient(EngineCoreClient):
|
||||
if not success:
|
||||
self._finalizer()
|
||||
|
||||
def shutdown(self):
|
||||
# Terminate background resources.
|
||||
self._finalizer()
|
||||
def shutdown(self, timeout: float | None = None) -> None:
|
||||
"""Shutdown engine manager under timeout and clean up resources."""
|
||||
if self._finalizer.detach() is not None:
|
||||
if self.resources.engine_manager is not None:
|
||||
self.resources.engine_manager.shutdown(timeout=timeout)
|
||||
self.resources()
|
||||
|
||||
def _format_exception(self, e: Exception) -> Exception:
|
||||
"""If errored, use EngineDeadError so root cause is clear."""
|
||||
@@ -695,7 +695,7 @@ class MPClient(EngineCoreClient):
|
||||
sentinels = [proc.sentinel for proc in engine_processes]
|
||||
died = multiprocessing.connection.wait(sentinels)
|
||||
_self = self_ref()
|
||||
if not _self or _self.resources.engine_dead:
|
||||
if not _self or not _self._finalizer.alive or _self.resources.engine_dead:
|
||||
return
|
||||
_self.resources.engine_dead = True
|
||||
proc_name = next(
|
||||
|
||||
@@ -3,8 +3,9 @@
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import threading
|
||||
import weakref
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Callable, Iterator
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from multiprocessing import Process, connection
|
||||
@@ -146,11 +147,12 @@ class CoreEngineProcManager:
|
||||
finally:
|
||||
# Kill other procs if not all are running.
|
||||
if self.finished_procs():
|
||||
self.close()
|
||||
self.shutdown()
|
||||
|
||||
def close(self):
|
||||
"""Shutdown all procs."""
|
||||
self._finalizer()
|
||||
def shutdown(self, timeout: float | None = None) -> None:
    """Shut down the engine core processes, waiting up to ``timeout``.

    Does nothing when ``self._finalizer.detach()`` returns None.
    """
    if self._finalizer.detach() is None:
        return

    # Resolves to the module-level ``shutdown`` helper, not this method.
    shutdown(self.processes, timeout=timeout)
|
||||
|
||||
def join_first(self):
|
||||
"""Wait for any process to exit."""
|
||||
@@ -168,6 +170,33 @@ class CoreEngineProcManager:
|
||||
}
|
||||
|
||||
|
||||
class SignalCallback:
    """Run a callback on a dedicated thread so signal handlers can trigger it.

    A daemon thread is started on construction and blocks on an event.
    ``trigger()`` releases it to invoke the callback; ``stop()`` releases it
    without invoking the callback.
    """

    def __init__(self, callback: Callable[[], None]):
        self._callback = callback
        self._event = threading.Event()
        self._stopped = False
        worker = threading.Thread(
            name="signal-callback",
            target=self._run,
            daemon=True,
        )
        self._thread = worker
        worker.start()

    def _run(self):
        # Block until either trigger() or stop() sets the event.
        self._event.wait()
        if self._stopped:
            return
        self._callback()

    def trigger(self):
        """Wake the worker thread so that it invokes the callback."""
        self._event.set()

    def stop(self):
        """Wake the worker thread and let it exit without the callback."""
        # The flag is set before the event so the worker sees it on wake-up.
        self._stopped = True
        self._event.set()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_device_control_env_var(
|
||||
vllm_config: VllmConfig, local_dp_rank: int
|
||||
@@ -763,7 +792,7 @@ class CoreEngineActorManager:
|
||||
def get_run_refs(self):
|
||||
return self.run_refs
|
||||
|
||||
def close(self):
|
||||
def shutdown(self, timeout: float | None = None) -> None:
|
||||
import ray
|
||||
|
||||
for actor in self.local_engine_actors + self.remote_engine_actors:
|
||||
|
||||
@@ -220,8 +220,10 @@ class APIServerProcessManager:
|
||||
# The extra processes are managed by their owners
|
||||
self._finalizer = weakref.finalize(self, shutdown, self.processes)
|
||||
|
||||
def close(self) -> None:
|
||||
self._finalizer()
|
||||
def shutdown(self, timeout: float | None = None) -> None:
    """Shut down the API server processes, waiting up to ``timeout``.

    Does nothing when ``self._finalizer.detach()`` returns None.
    """
    if self._finalizer.detach() is None:
        return

    # Resolves to the module-level ``shutdown`` helper, not this method.
    shutdown(self.processes, timeout=timeout)
|
||||
|
||||
|
||||
def wait_for_completion_or_failure(
|
||||
@@ -288,25 +290,30 @@ def wait_for_completion_or_failure(
|
||||
except Exception as e:
|
||||
logger.exception("Exception occurred while running API servers: %s", str(e))
|
||||
raise
|
||||
finally:
|
||||
logger.info("Terminating remaining processes ...")
|
||||
api_server_manager.close()
|
||||
if coordinator:
|
||||
coordinator.close()
|
||||
if engine_manager:
|
||||
engine_manager.close()
|
||||
|
||||
|
||||
# Note(rob): shutdown function cannot be a bound method,
|
||||
# else the gc cannot collect the object.
|
||||
def shutdown(procs: list[BaseProcess]):
|
||||
def shutdown(procs: list[BaseProcess], timeout: float | None = None) -> None:
|
||||
"""Shutdown processes with timeout.
|
||||
|
||||
Args:
|
||||
procs: List of processes to shutdown
|
||||
timeout: Maximum time in seconds to wait for graceful shutdown
|
||||
"""
|
||||
if timeout is None:
|
||||
timeout = 0.0
|
||||
|
||||
# Allow at least 5 seconds for remaining procs to terminate.
|
||||
timeout = max(timeout, 5.0)
|
||||
|
||||
# Shutdown the process.
|
||||
for proc in procs:
|
||||
if proc.is_alive():
|
||||
proc.terminate()
|
||||
|
||||
# Allow 5 seconds for remaining procs to terminate.
|
||||
deadline = time.monotonic() + 5
|
||||
# Allow time for remaining procs to terminate.
|
||||
deadline = time.monotonic() + timeout
|
||||
for proc in procs:
|
||||
remaining = deadline - time.monotonic()
|
||||
if remaining <= 0:
|
||||
|
||||
Reference in New Issue
Block a user