[BugFix] Fix engine hanging after KV cache initialization failure (#35478)
Signed-off-by: Shiyan Deng <dsy842974287@meta.com> Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
import os
|
||||
import queue
|
||||
import signal
|
||||
@@ -117,9 +118,17 @@ class EngineCore:
|
||||
self._eep_scale_up_before_kv_init()
|
||||
|
||||
# Setup KV Caches and update CacheConfig after profiling.
|
||||
num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
|
||||
vllm_config
|
||||
)
|
||||
try:
|
||||
num_gpu_blocks, num_cpu_blocks, kv_cache_config = (
|
||||
self._initialize_kv_caches(vllm_config)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"EngineCore failed during KV cache initialization; "
|
||||
"shutting down executor."
|
||||
)
|
||||
self.model_executor.shutdown()
|
||||
raise
|
||||
|
||||
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||
@@ -958,29 +967,49 @@ class EngineCoreProc(EngineCore):
|
||||
addresses = self.startup_handshake(
|
||||
handshake_socket, local_client, headless, parallel_config_to_update
|
||||
)
|
||||
yield addresses
|
||||
exc_during_init = False
|
||||
try:
|
||||
yield addresses
|
||||
except Exception:
|
||||
exc_during_init = True
|
||||
raise
|
||||
finally:
|
||||
if exc_during_init:
|
||||
# Send FAILED status so the front-end detects init
|
||||
# failure immediately via ZMQ instead of waiting for
|
||||
# process sentinel (which may be delayed by cleanup).
|
||||
with contextlib.suppress(Exception):
|
||||
handshake_socket.send(
|
||||
msgspec.msgpack.encode(
|
||||
{
|
||||
"status": "FAILED",
|
||||
"local": local_client,
|
||||
"headless": headless,
|
||||
}
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Send ready message.
|
||||
num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
|
||||
# We pass back the coordinator stats update address
|
||||
# here for the external LB case for our colocated
|
||||
# front-end to use (coordinator only runs with rank 0).
|
||||
dp_stats_address = self.frontend_stats_publish_address
|
||||
|
||||
# Send ready message.
|
||||
num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
|
||||
# We pass back the coordinator stats update address here for the
|
||||
# external LB case for our colocated front-end to use (coordinator
|
||||
# only runs with rank 0).
|
||||
dp_stats_address = self.frontend_stats_publish_address
|
||||
# Include config hash for DP configuration validation
|
||||
ready_msg = {
|
||||
"status": "READY",
|
||||
"local": local_client,
|
||||
"headless": headless,
|
||||
"num_gpu_blocks": num_gpu_blocks,
|
||||
"dp_stats_address": dp_stats_address,
|
||||
}
|
||||
if vllm_config.parallel_config.data_parallel_size > 1:
|
||||
ready_msg["parallel_config_hash"] = (
|
||||
vllm_config.parallel_config.compute_hash()
|
||||
)
|
||||
|
||||
# Include config hash for DP configuration validation
|
||||
ready_msg = {
|
||||
"status": "READY",
|
||||
"local": local_client,
|
||||
"headless": headless,
|
||||
"num_gpu_blocks": num_gpu_blocks,
|
||||
"dp_stats_address": dp_stats_address,
|
||||
}
|
||||
if vllm_config.parallel_config.data_parallel_size > 1:
|
||||
ready_msg["parallel_config_hash"] = (
|
||||
vllm_config.parallel_config.compute_hash()
|
||||
)
|
||||
|
||||
handshake_socket.send(msgspec.msgpack.encode(ready_msg))
|
||||
handshake_socket.send(msgspec.msgpack.encode(ready_msg))
|
||||
|
||||
@staticmethod
|
||||
def startup_handshake(
|
||||
|
||||
@@ -1101,6 +1101,11 @@ def wait_for_engine_startup(
|
||||
|
||||
start_pending[0 if local else 1] -= 1
|
||||
engine.state = CoreEngineState.READY
|
||||
elif status == "FAILED":
|
||||
raise RuntimeError(
|
||||
f"Engine core {eng_index} reported initialization failure. "
|
||||
"See root cause above."
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Unexpected {status} message for "
|
||||
|
||||
Reference in New Issue
Block a user