[BugFix] --max-model-len=-1 causes over-limit requests to hang and starve the entire service (#39102)
Signed-off-by: triangle14 <y1019026570@gmail.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -15,6 +15,7 @@ import pytest
|
||||
|
||||
from tests.conftest import VllmRunner
|
||||
from tests.utils import create_new_process_for_each_test
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
|
||||
@@ -61,3 +62,42 @@ def test_decoder_max_context_length_validation(
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
vllm_model.generate_greedy(prompt_ids, max_tokens)
|
||||
assert expected_msg in str(excinfo.value)
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
@pytest.mark.parametrize("model", ["JackFram/llama-160m"])
def test_auto_fit_max_model_len_rejects_oversized_input(
    model: str,
    vllm_runner: type[VllmRunner],
) -> None:
    """With max_model_len=-1 (auto-fit) and a tiny KV cache budget, the
    engine shrinks max_model_len to whatever fits. The frontend must see
    that reduced limit and reject prompts exceeding it, rather than
    accepting them and hanging."""

    # A 1 MB KV cache budget forces auto-fit down to a handful of tokens
    # (e.g. ~16 for this model).
    tiny_kv_cache_bytes = 1_000_000

    with vllm_runner(
        model_name=model,
        max_model_len=-1,
        max_num_seqs=1,
        enforce_eager=True,
        kv_cache_memory_bytes=tiny_kv_cache_bytes,
        load_format="dummy",
    ) as vllm_model:
        engine_config = vllm_model.llm.llm_engine.vllm_config
        auto_fitted_len = engine_config.model_config.max_model_len

        # Sanity check: unless auto-fit pulled the limit well below the
        # model's native context length, this test exercises nothing.
        assert auto_fitted_len < 2048, (
            f"Expected auto-fit to reduce max_model_len significantly, "
            f"but got {auto_fitted_len}"
        )

        # A prompt longer than the auto-fitted limit must raise — not hang.
        too_long_prompt = [[43] * (auto_fitted_len + 10)]
        with pytest.raises(VLLMValidationError, match="Please reduce the length"):
            vllm_model.generate_greedy(too_long_prompt, max_tokens=4)
|
||||
|
||||
@@ -114,7 +114,7 @@ def test_mp_client_uses_env_timeout(monkeypatch: pytest.MonkeyPatch):
|
||||
return 1
|
||||
|
||||
def recv_multipart(self):
|
||||
return (b"\x00\x00", b"ready")
|
||||
return (b"\x00\x00", b"")
|
||||
|
||||
class DummySocket:
|
||||
def send_multipart(self, _msg, *, copy: bool = False, track: bool = False):
|
||||
|
||||
@@ -63,6 +63,20 @@ class FinishReason(enum.IntEnum):
|
||||
return FINISH_REASON_STRINGS[self.value]
|
||||
|
||||
|
||||
class EngineCoreReadyResponse(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
):
    """Sent from EngineCore to the frontend during the ready handshake.

    Contains post-initialization config that may differ from the original
    values (e.g. max_model_len after KV cache auto-fitting).
    """

    # Engine's post-initialization max_model_len; None means there is no
    # updated value for the frontend to apply.
    max_model_len: int | None = None
|
||||
|
||||
|
||||
class EngineCoreRequest(
|
||||
msgspec.Struct,
|
||||
array_like=True, # type: ignore[call-arg]
|
||||
|
||||
@@ -52,6 +52,7 @@ from vllm.v1.engine import (
|
||||
EEPNotificationType,
|
||||
EngineCoreOutput,
|
||||
EngineCoreOutputs,
|
||||
EngineCoreReadyResponse,
|
||||
EngineCoreRequest,
|
||||
EngineCoreRequestType,
|
||||
FinishReason,
|
||||
@@ -1385,11 +1386,15 @@ class EngineCoreProc(EngineCore):
|
||||
|
||||
# Register sockets with poller.
|
||||
poller = zmq.Poller()
|
||||
ready_response = EngineCoreReadyResponse(
|
||||
max_model_len=self.vllm_config.model_config.max_model_len,
|
||||
)
|
||||
ready_payload = msgspec.msgpack.encode(ready_response)
|
||||
for input_socket in input_sockets:
|
||||
# Send initial message to each input socket - this is required
|
||||
# before the front-end ROUTER socket can send input messages
|
||||
# back to us.
|
||||
input_socket.send(b"")
|
||||
input_socket.send(ready_payload)
|
||||
poller.register(input_socket, zmq.POLLIN)
|
||||
|
||||
if coord_socket is not None:
|
||||
|
||||
@@ -35,6 +35,7 @@ from vllm.v1.engine import (
|
||||
EEP_NOTIFICATION_CALL_ID,
|
||||
EEPNotificationType,
|
||||
EngineCoreOutputs,
|
||||
EngineCoreReadyResponse,
|
||||
EngineCoreRequest,
|
||||
EngineCoreRequestType,
|
||||
PauseMode,
|
||||
@@ -456,6 +457,22 @@ class ElasticScalingCache:
|
||||
pending_notifications: dict[EEPNotificationType, set[int]]
|
||||
|
||||
|
||||
# Shared msgpack decoder, built once and reused for every ready handshake.
_ready_response_decoder = msgspec.msgpack.Decoder(EngineCoreReadyResponse)


def _apply_ready_response(payload: bytes, vllm_config: VllmConfig) -> None:
    """Decode an EngineCoreReadyResponse and propagate any post-init
    config changes (e.g. auto-fitted max_model_len) to the frontend's
    config. An empty payload is a no-op."""
    if not payload:
        # Empty ready frame: nothing to apply.
        return

    response = _ready_response_decoder.decode(payload)
    fitted_len = response.max_model_len
    if fitted_len is None:
        return

    # Only ever clamp the frontend's limit downward; never raise it
    # above what it already has configured.
    model_config = vllm_config.model_config
    model_config.max_model_len = min(model_config.max_model_len, fitted_len)
|
||||
|
||||
|
||||
class MPClient(EngineCoreClient):
|
||||
"""
|
||||
MPClient: base client for multi-proc EngineCore.
|
||||
@@ -589,8 +606,9 @@ class MPClient(EngineCoreClient):
|
||||
f"timeout, set the environment variable: "
|
||||
f"VLLM_ENGINE_READY_TIMEOUT_S=<seconds>"
|
||||
)
|
||||
identity, _ = sync_input_socket.recv_multipart()
|
||||
identity, payload = sync_input_socket.recv_multipart()
|
||||
identities.remove(identity)
|
||||
_apply_ready_response(payload, vllm_config)
|
||||
|
||||
self.core_engine: EngineIdentity = self.core_engines[0]
|
||||
self.utility_results: dict[int, AnyFuture] = {}
|
||||
@@ -1582,8 +1600,9 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
|
||||
f"timeout, set the environment variable: "
|
||||
f"VLLM_ENGINE_READY_TIMEOUT_S=<seconds>"
|
||||
)
|
||||
identity, _ = sync_input_socket.recv_multipart()
|
||||
identity, payload = sync_input_socket.recv_multipart()
|
||||
new_engine_identities.discard(identity)
|
||||
_apply_ready_response(payload, self.vllm_config)
|
||||
|
||||
# NOTE(yongji): Before we schedule any requests on the new workers,
|
||||
# we should wait for them to switch to the new setup.
|
||||
|
||||
Reference in New Issue
Block a user