diff --git a/tests/entrypoints/test_utils.py b/tests/entrypoints/test_utils.py index e071bacb7..725938339 100644 --- a/tests/entrypoints/test_utils.py +++ b/tests/entrypoints/test_utils.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest + from vllm.entrypoints.utils import get_max_tokens, sanitize_message @@ -80,3 +82,15 @@ class TestGetMaxTokens: default_sampling_params={"max_tokens": 2048}, ) assert result == 512 + + def test_input_length_exceeds_max_model_len(self): + with pytest.raises( + ValueError, + match="Input length .* exceeds model's maximum context length .*", + ): + get_max_tokens( + max_model_len=100, + max_tokens=50, + input_length=150, + default_sampling_params={"max_tokens": 2048}, + ) diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py index 7c158a17c..9550a41bb 100644 --- a/vllm/entrypoints/utils.py +++ b/vllm/entrypoints/utils.py @@ -178,6 +178,11 @@ def get_max_tokens( default_sampling_params: dict, override_max_tokens: int | None = None, ) -> int: + if max_model_len < input_length: + raise ValueError( + f"Input length ({input_length}) exceeds model's maximum " + f"context length ({max_model_len})." + ) model_max_tokens = max_model_len - input_length platform_max_tokens = current_platform.get_max_output_tokens(input_length) fallback_max_tokens = (