diff --git a/tests/v1/e2e/general/test_context_length.py b/tests/v1/e2e/general/test_context_length.py
index 0ac40bec3..c9dc8354f 100644
--- a/tests/v1/e2e/general/test_context_length.py
+++ b/tests/v1/e2e/general/test_context_length.py
@@ -15,6 +15,7 @@ import pytest
 
 from tests.conftest import VllmRunner
 from tests.utils import create_new_process_for_each_test
+from vllm.exceptions import VLLMValidationError
 
 
 @create_new_process_for_each_test()
@@ -61,3 +62,42 @@ def test_decoder_max_context_length_validation(
         with pytest.raises(ValueError) as excinfo:
             vllm_model.generate_greedy(prompt_ids, max_tokens)
         assert expected_msg in str(excinfo.value)
+
+
+@create_new_process_for_each_test()
+@pytest.mark.parametrize("model", ["JackFram/llama-160m"])
+def test_auto_fit_max_model_len_rejects_oversized_input(
+    model: str,
+    vllm_runner: type[VllmRunner],
+) -> None:
+    """When max_model_len='auto' and KV cache memory is very limited,
+    the engine auto-fits max_model_len to a small value. The frontend
+    must see this reduced value and reject prompts that exceed it,
+    rather than accepting them and hanging."""
+
+    # Use a tiny KV cache budget to force auto-fit to a very small
+    # max_model_len (e.g. ~16 tokens).
+    kv_cache_bytes = 1_000_000  # 1 MB
+
+    with vllm_runner(
+        model_name=model,
+        max_model_len=-1,
+        max_num_seqs=1,
+        enforce_eager=True,
+        kv_cache_memory_bytes=kv_cache_bytes,
+        load_format="dummy",
+    ) as vllm_model:
+        auto_fitted_len = (
+            vllm_model.llm.llm_engine.vllm_config.model_config.max_model_len
+        )
+        # Sanity check: auto-fit should have reduced it well below the
+        # model's native context length.
+        assert auto_fitted_len < 2048, (
+            f"Expected auto-fit to reduce max_model_len significantly, "
+            f"but got {auto_fitted_len}"
+        )
+
+        # A prompt longer than the auto-fitted length must be rejected.
+        oversized_prompt = [[43] * (auto_fitted_len + 10)]
+        with pytest.raises(VLLMValidationError, match="Please reduce the length"):
+            vllm_model.generate_greedy(oversized_prompt, max_tokens=4)
diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py
index 5e08ae35f..d2a9ceb2d 100644
--- a/tests/v1/engine/test_engine_core_client.py
+++ b/tests/v1/engine/test_engine_core_client.py
@@ -114,7 +114,7 @@ def test_mp_client_uses_env_timeout(monkeypatch: pytest.MonkeyPatch):
             return 1
 
         def recv_multipart(self):
-            return (b"\x00\x00", b"ready")
+            return (b"\x00\x00", b"")
 
     class DummySocket:
         def send_multipart(self, _msg, *, copy: bool = False, track: bool = False):
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index 114d45fc4..344690e73 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -63,6 +63,20 @@ class FinishReason(enum.IntEnum):
         return FINISH_REASON_STRINGS[self.value]
 
 
+class EngineCoreReadyResponse(
+    msgspec.Struct,
+    array_like=True,  # type: ignore[call-arg]
+    omit_defaults=True,  # type: ignore[call-arg]
+):
+    """Sent from EngineCore to the frontend during the ready handshake.
+
+    Contains post-initialization config that may differ from the original
+    values (e.g. max_model_len after KV cache auto-fitting).
+    """
+
+    max_model_len: int | None = None
+
+
 class EngineCoreRequest(
     msgspec.Struct,
     array_like=True,  # type: ignore[call-arg]
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index ff5f924a1..fbf956265 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -52,6 +52,7 @@ from vllm.v1.engine import (
     EEPNotificationType,
     EngineCoreOutput,
     EngineCoreOutputs,
+    EngineCoreReadyResponse,
     EngineCoreRequest,
     EngineCoreRequestType,
     FinishReason,
@@ -1385,11 +1386,15 @@ class EngineCoreProc(EngineCore):
 
         # Register sockets with poller.
         poller = zmq.Poller()
+        ready_response = EngineCoreReadyResponse(
+            max_model_len=self.vllm_config.model_config.max_model_len,
+        )
+        ready_payload = msgspec.msgpack.encode(ready_response)
         for input_socket in input_sockets:
             # Send initial message to each input socket - this is required
             # before the front-end ROUTER socket can send input messages
             # back to us.
-            input_socket.send(b"")
+            input_socket.send(ready_payload)
             poller.register(input_socket, zmq.POLLIN)
 
         if coord_socket is not None:
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py
index 1d73c12ed..c4e631a41 100644
--- a/vllm/v1/engine/core_client.py
+++ b/vllm/v1/engine/core_client.py
@@ -35,6 +35,7 @@ from vllm.v1.engine import (
     EEP_NOTIFICATION_CALL_ID,
     EEPNotificationType,
     EngineCoreOutputs,
+    EngineCoreReadyResponse,
     EngineCoreRequest,
     EngineCoreRequestType,
     PauseMode,
@@ -456,6 +457,22 @@ class ElasticScalingCache:
     pending_notifications: dict[EEPNotificationType, set[int]]
 
 
+_ready_response_decoder = msgspec.msgpack.Decoder(EngineCoreReadyResponse)
+
+
+def _apply_ready_response(payload: bytes, vllm_config: VllmConfig) -> None:
+    """Decode an EngineCoreReadyResponse and sync any post-initialization
+    config changes (e.g. auto-fitted max_model_len) back to the frontend."""
+    if not payload:
+        return
+    response = _ready_response_decoder.decode(payload)
+    if response.max_model_len is not None:
+        vllm_config.model_config.max_model_len = min(
+            vllm_config.model_config.max_model_len,
+            response.max_model_len,
+        )
+
+
 class MPClient(EngineCoreClient):
     """
     MPClient: base client for multi-proc EngineCore.
@@ -589,8 +606,9 @@ class MPClient(EngineCoreClient):
                     f"timeout, set the environment variable: "
                     f"VLLM_ENGINE_READY_TIMEOUT_S="
                 )
-            identity, _ = sync_input_socket.recv_multipart()
+            identity, payload = sync_input_socket.recv_multipart()
             identities.remove(identity)
+            _apply_ready_response(payload, vllm_config)
 
         self.core_engine: EngineIdentity = self.core_engines[0]
         self.utility_results: dict[int, AnyFuture] = {}
@@ -1582,8 +1600,9 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
                     f"timeout, set the environment variable: "
                     f"VLLM_ENGINE_READY_TIMEOUT_S="
                 )
-            identity, _ = sync_input_socket.recv_multipart()
+            identity, payload = sync_input_socket.recv_multipart()
             new_engine_identities.discard(identity)
+            _apply_ready_response(payload, self.vllm_config)
 
         # NOTE(yongji): Before we schedule any requests on the new workers,
         # we should wait for them to switch to the new setup.