diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py
index 7d0b513aa..1d96b05ac 100644
--- a/tests/entrypoints/openai/test_serving_chat.py
+++ b/tests/entrypoints/openai/test_serving_chat.py
@@ -526,6 +526,7 @@ class MockModelConfig:
     allowed_media_domains: list[str] | None = None
     encoder_config = None
     generation_config: str = "auto"
+    override_generation_config: dict[str, Any] = field(default_factory=dict)
     media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
     skip_tokenizer_init: bool = False
     is_encoder_decoder: bool = False
@@ -651,12 +652,10 @@ async def test_serving_chat_should_set_correct_max_tokens():
 
     assert mock_engine.generate.call_args.args[1].max_tokens == 10
 
-    # Setting server's max_tokens in the generation_config.json
-    # lower than context_window - prompt_tokens
+    # The model author's generation_config.json sets max_tokens (auto, no
+    # override), so it should act as a fallback only, not as a ceiling
     mock_model_config = MockModelConfig()
-    mock_model_config.diff_sampling_param = {
-        "max_tokens": 10  # Setting server-side max_tokens limit
-    }
+    mock_model_config.diff_sampling_param = {"max_tokens": 10}
 
     # Reinitialize the engine with new settings
     mock_engine = MagicMock(spec=AsyncLLM)
@@ -680,7 +679,50 @@ async def test_serving_chat_should_set_correct_max_tokens():
     assert mock_engine.generate.call_args.args[1].max_tokens == 10
 
-    # Test Case 2: Request's max_tokens set higher than server accepts
+    # Test Case 2: Request's max_tokens set higher than the generation_config
+    # default, so the request-provided max_tokens takes precedence
+    req.max_tokens = 15
+
+    with suppress(Exception):
+        await serving_chat.create_chat_completion(req)
+
+    assert mock_engine.generate.call_args.args[1].max_tokens == 15
+
+    # Test Case 3: Request's max_tokens set lower than the generation_config default
+    req.max_tokens = 5
+
+    with suppress(Exception):
+        await serving_chat.create_chat_completion(req)
+
+    assert mock_engine.generate.call_args.args[1].max_tokens == 5
+
+    # Test Case 4: The user explicitly sets max_tokens via
+    # --override-generation-config, so it should act as a ceiling
+    mock_model_config = MockModelConfig()
+    mock_model_config.diff_sampling_param = {"max_tokens": 10}
+    mock_model_config.override_generation_config = {"max_new_tokens": 10}
+
+    mock_engine = MagicMock(spec=AsyncLLM)
+    mock_engine.errored = False
+    mock_engine.model_config = mock_model_config
+    mock_engine.input_processor = MagicMock()
+    mock_engine.io_processor = MagicMock()
+    mock_engine.renderer = _build_renderer(mock_engine.model_config)
+
+    serving_chat = _build_serving_chat(mock_engine)
+
+    # Test Case 4.1: No request max_tokens, so the override is used as the default
+    req = ChatCompletionRequest(
+        model=MODEL_NAME,
+        messages=[{"role": "user", "content": "what is 1+1?"}],
+    )
+
+    with suppress(Exception):
+        await serving_chat.create_chat_completion(req)
+
+    assert mock_engine.generate.call_args.args[1].max_tokens == 10
+
+    # Test Case 4.2: Request max_tokens set higher, capped by the override ceiling
     req.max_tokens = 15
 
     with suppress(Exception):
         await serving_chat.create_chat_completion(req)
@@ -688,7 +730,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
     assert mock_engine.generate.call_args.args[1].max_tokens == 10
 
-    # Test Case 3: Request's max_tokens set lower than server accepts
+    # Test Case 4.3: Request max_tokens set lower, so it is respected
     req.max_tokens = 5
 
     with suppress(Exception):
         await serving_chat.create_chat_completion(req)
@@ -699,9 +741,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
     # Setting server's max_tokens in the generation_config.json
     # higher than context_window - prompt_tokens
     mock_model_config = MockModelConfig()
-    mock_model_config.diff_sampling_param = {
-        "max_tokens": 200  # Setting server-side max_tokens limit
-    }
+    mock_model_config.diff_sampling_param = {"max_tokens": 200}
 
     # Reinitialize the engine with new settings
     mock_engine = MagicMock(spec=AsyncLLM)
diff --git a/tests/entrypoints/test_utils.py b/tests/entrypoints/test_utils.py
index dc1101840..e071bacb7 100644
--- a/tests/entrypoints/test_utils.py
+++ b/tests/entrypoints/test_utils.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from vllm.entrypoints.utils import sanitize_message
+
+from vllm.entrypoints.utils import get_max_tokens, sanitize_message
 
 
 def test_sanitize_message():
@@ -8,3 +9,74 @@ def test_sanitize_message():
         sanitize_message("<_io.BytesIO object at 0x7a95e299e750>")
         == "<_io.BytesIO object>"
     )
+
+
+class TestGetMaxTokens:
+    """Tests for get_max_tokens() to ensure generation_config's max_tokens
+    acts as a default when it comes from the model author, and as a ceiling
+    when explicitly set by the user."""
+
+    def test_default_sampling_params_used_when_no_request_max_tokens(self):
+        """When the user doesn't specify max_tokens, the generation_config
+        default should apply."""
+        result = get_max_tokens(
+            max_model_len=24000,
+            max_tokens=None,
+            input_length=100,
+            default_sampling_params={"max_tokens": 2048},
+        )
+        assert result == 2048
+
+    def test_request_max_tokens_not_capped_by_default_sampling_params(self):
+        """When the user specifies max_tokens in the request, the model
+        author's generation_config max_tokens must NOT cap it (fixes #34005)."""
+        result = get_max_tokens(
+            max_model_len=24000,
+            max_tokens=5000,
+            input_length=100,
+            default_sampling_params={"max_tokens": 2048},
+        )
+        assert result == 5000
+
+    def test_override_max_tokens_caps_request(self):
+        """When the user explicitly sets max_tokens, it acts as a ceiling."""
+        result = get_max_tokens(
+            max_model_len=24000,
+            max_tokens=5000,
+            input_length=100,
+            default_sampling_params={"max_tokens": 2048},
+            override_max_tokens=2048,
+        )
+        assert result == 2048
+
+    def test_override_max_tokens_used_as_default(self):
+        """When no request max_tokens, the override still applies as a default."""
+        result = get_max_tokens(
+            max_model_len=24000,
+            max_tokens=None,
+            input_length=100,
+            default_sampling_params={"max_tokens": 2048},
+            override_max_tokens=2048,
+        )
+        assert result == 2048
+
+    def test_max_model_len_still_caps_output(self):
+        """max_model_len - input_length is always the hard ceiling."""
+        result = get_max_tokens(
+            max_model_len=3000,
+            max_tokens=5000,
+            input_length=100,
+            default_sampling_params={"max_tokens": 2048},
+        )
+        assert result == 2900  # 3000 - 100
+
+    def test_request_max_tokens_smaller_than_default(self):
+        """When the user explicitly requests fewer tokens than the
+        generation_config default, that should be respected."""
+        result = get_max_tokens(
+            max_model_len=24000,
+            max_tokens=512,
+            input_length=100,
+            default_sampling_params={"max_tokens": 2048},
+        )
+        assert result == 512
diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py
index 7b54e6daf..f1523cdc6 100644
--- a/vllm/entrypoints/openai/chat_completion/serving.py
+++ b/vllm/entrypoints/openai/chat_completion/serving.py
@@ -145,6 +145,12 @@ class OpenAIServingChat(OpenAIServing):
         self.enable_prompt_tokens_details = enable_prompt_tokens_details
         self.enable_force_include_usage = enable_force_include_usage
         self.default_sampling_params = self.model_config.get_diff_sampling_param()
+        mc = self.model_config
+        self.override_max_tokens = (
+            self.default_sampling_params.get("max_tokens")
+            if mc.generation_config not in ("auto", "vllm")
+            else getattr(mc, "override_generation_config", {}).get("max_new_tokens")
+        )
         self.use_harmony = self.model_config.hf_config.model_type == "gpt_oss"
         if self.use_harmony:
             if "stop_token_ids" not in self.default_sampling_params:
@@ -389,6 +395,7 @@ class OpenAIServingChat(OpenAIServing):
             else request.max_tokens,
             self._extract_prompt_len(engine_prompt),
             self.default_sampling_params,
+            self.override_max_tokens,
         )
 
         sampling_params: SamplingParams | BeamSearchParams
diff --git a/vllm/entrypoints/openai/completion/serving.py b/vllm/entrypoints/openai/completion/serving.py
index 994cc094a..acbb95868 100644
--- a/vllm/entrypoints/openai/completion/serving.py
+++ b/vllm/entrypoints/openai/completion/serving.py
@@ -70,6 +70,12 @@ class OpenAIServingCompletion(OpenAIServing):
 
         self.enable_force_include_usage = enable_force_include_usage
         self.default_sampling_params = self.model_config.get_diff_sampling_param()
+        mc = self.model_config
+        self.override_max_tokens = (
+            self.default_sampling_params.get("max_tokens")
+            if mc.generation_config not in ("auto", "vllm")
+            else getattr(mc, "override_generation_config", {}).get("max_new_tokens")
+        )
 
     async def render_completion_request(
         self,
@@ -164,6 +170,7 @@ class OpenAIServingCompletion(OpenAIServing):
             request.max_tokens,
             self._extract_prompt_len(engine_prompt),
             self.default_sampling_params,
+            self.override_max_tokens,
         )
 
         sampling_params: SamplingParams | BeamSearchParams
diff --git a/vllm/entrypoints/openai/engine/serving.py b/vllm/entrypoints/openai/engine/serving.py
index 1484fca5b..d99daf739 100644
--- a/vllm/entrypoints/openai/engine/serving.py
+++ b/vllm/entrypoints/openai/engine/serving.py
@@ -1174,6 +1174,7 @@ class OpenAIServing:
             context.request.max_output_tokens,
             self._extract_prompt_len(engine_prompt),
             self.default_sampling_params,  # type: ignore
+            self.override_max_tokens,  # type: ignore
         )
 
         # OPTIMIZATION
diff --git a/vllm/entrypoints/openai/responses/serving.py b/vllm/entrypoints/openai/responses/serving.py
index 0d9ef135a..39dd2fb79 100644
--- a/vllm/entrypoints/openai/responses/serving.py
+++ b/vllm/entrypoints/openai/responses/serving.py
@@ -229,6 +229,12 @@ class OpenAIServingResponses(OpenAIServing):
 
         self.enable_force_include_usage = enable_force_include_usage
         self.default_sampling_params = self.model_config.get_diff_sampling_param()
+        mc = self.model_config
+        self.override_max_tokens = (
+            self.default_sampling_params.get("max_tokens")
+            if mc.generation_config not in ("auto", "vllm")
+            else getattr(mc, "override_generation_config", {}).get("max_new_tokens")
+        )
 
         # If False (default), the "store" option is (silently) ignored and the
         # response is not stored. If True, the response is stored in memory.
@@ -446,6 +452,7 @@ class OpenAIServingResponses(OpenAIServing):
             request.max_output_tokens,
             self._extract_prompt_len(engine_prompt),
             self.default_sampling_params,
+            self.override_max_tokens,
         )
 
         sampling_params = request.to_sampling_params(
diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py
index 98822b9c6..34df85f37 100644
--- a/vllm/entrypoints/utils.py
+++ b/vllm/entrypoints/utils.py
@@ -177,17 +177,23 @@ def get_max_tokens(
     max_model_len: int,
     max_tokens: int | None,
     input_length: int,
     default_sampling_params: dict,
+    override_max_tokens: int | None = None,
 ) -> int:
-    default_max_tokens = max_model_len - input_length
-    max_output_tokens = current_platform.get_max_output_tokens(input_length)
+    model_max_tokens = max_model_len - input_length
+    platform_max_tokens = current_platform.get_max_output_tokens(input_length)
+    fallback_max_tokens = (
+        max_tokens
+        if max_tokens is not None
+        else default_sampling_params.get("max_tokens")
+    )
     return min(
         val
         for val in (
-            default_max_tokens,
-            max_tokens,
-            max_output_tokens,
-            default_sampling_params.get("max_tokens"),
+            model_max_tokens,
+            fallback_max_tokens,
+            override_max_tokens,
+            platform_max_tokens,
         )
         if val is not None
     )
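
Reviewer note: below is a standalone, runnable sketch of the precedence the
reworked get_max_tokens() implements. The platform-specific cap from
current_platform.get_max_output_tokens() is omitted here purely for
illustration, and get_max_tokens_sketch is a hypothetical name, not part of
this diff.

    # Sketch of the new precedence: the context window is a hard cap, the
    # generation_config value is only a fallback, and an explicit user
    # override is a ceiling.
    def get_max_tokens_sketch(
        max_model_len: int,
        max_tokens: int | None,
        input_length: int,
        default_sampling_params: dict,
        override_max_tokens: int | None = None,
    ) -> int:
        # Never generate past the model's context window.
        model_max_tokens = max_model_len - input_length
        # generation_config's max_tokens only fills in when the request
        # leaves max_tokens unset; it no longer caps an explicit value.
        fallback_max_tokens = (
            max_tokens
            if max_tokens is not None
            else default_sampling_params.get("max_tokens")
        )
        # An explicit user override always joins the min(), i.e. it caps.
        return min(
            val
            for val in (model_max_tokens, fallback_max_tokens, override_max_tokens)
            if val is not None
        )

    # Mirrors the expectations in TestGetMaxTokens above:
    assert get_max_tokens_sketch(24000, None, 100, {"max_tokens": 2048}) == 2048
    assert get_max_tokens_sketch(24000, 5000, 100, {"max_tokens": 2048}) == 5000
    assert get_max_tokens_sketch(24000, 5000, 100, {"max_tokens": 2048}, 2048) == 2048
    assert get_max_tokens_sketch(3000, 5000, 100, {"max_tokens": 2048}) == 2900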
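
And a similar sketch of the ceiling derivation each serving class now performs
at init time (derive_override_max_tokens is a hypothetical helper name; the
attribute and key names match this diff):

    # A --generation-config value other than "auto"/"vllm" means the user
    # pointed the server at an explicit generation config, so its max_tokens
    # is treated as a deliberate ceiling; otherwise only the max_new_tokens
    # from --override-generation-config caps output.
    def derive_override_max_tokens(mc, default_sampling_params: dict) -> int | None:
        if mc.generation_config not in ("auto", "vllm"):
            return default_sampling_params.get("max_tokens")
        # getattr() guards model configs that predate override_generation_config,
        # e.g. the MockModelConfig in the tests.
        return getattr(mc, "override_generation_config", {}).get("max_new_tokens")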