[Bugfix] Treat generation_config max_tokens as default not ceiling (#34063)

Signed-off-by: almogtavor <almogtavor@gmail.com>
Almog Tavor authored on 2026-02-16 17:58:24 +02:00 · committed by GitHub
parent a3205beffb
commit 72d5951d02
7 changed files with 157 additions and 17 deletions
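The behavior change is easiest to see as code. Below is a minimal sketch of the policy the new tests pin down; the name and keyword arguments of get_max_tokens match the test calls in this commit, but the body is inferred from the expected values rather than copied from the vLLM source, so treat it as an illustration, not the shipped implementation:

def get_max_tokens(
    max_model_len: int,
    max_tokens: int | None,
    input_length: int,
    default_sampling_params: dict,
    override_max_tokens: int | None = None,
) -> int:
    # Sketch inferred from this commit's tests; not the actual vLLM code.
    # The context window is always the hard ceiling on output length.
    hard_limit = max_model_len - input_length

    if max_tokens is None:
        # No per-request value: the user's explicit override, if any, serves
        # as the default; otherwise the model author's generation_config
        # value does.
        requested = (
            override_max_tokens
            if override_max_tokens is not None
            else default_sampling_params.get("max_tokens")
        )
    else:
        # The request names a budget. Only an explicit user override may cap
        # it; the author's generation_config default must not (#34005).
        requested = (
            max_tokens
            if override_max_tokens is None
            else min(max_tokens, override_max_tokens)
        )

    return hard_limit if requested is None else min(requested, hard_limit)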


@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from vllm.entrypoints.utils import sanitize_message
+from vllm.entrypoints.utils import get_max_tokens, sanitize_message
def test_sanitize_message():
@@ -8,3 +9,74 @@ def test_sanitize_message():
sanitize_message("<_io.BytesIO object at 0x7a95e299e750>")
== "<_io.BytesIO object>"
)
class TestGetMaxTokens:
    """Tests for get_max_tokens() to ensure generation_config's max_tokens
    acts as a default when it comes from the model author, and as a ceiling
    only when explicitly set by the user."""

    def test_default_sampling_params_used_when_no_request_max_tokens(self):
        """When the user doesn't specify max_tokens, the generation_config
        default should apply."""
        result = get_max_tokens(
            max_model_len=24000,
            max_tokens=None,
            input_length=100,
            default_sampling_params={"max_tokens": 2048},
        )
        assert result == 2048

    def test_request_max_tokens_not_capped_by_default_sampling_params(self):
        """When the user specifies max_tokens in the request, the model
        author's generation_config max_tokens must NOT cap it (fixes #34005)."""
        result = get_max_tokens(
            max_model_len=24000,
            max_tokens=5000,
            input_length=100,
            default_sampling_params={"max_tokens": 2048},
        )
        assert result == 5000

    def test_override_max_tokens_caps_request(self):
        """When the user explicitly sets an override, it acts as a ceiling
        on the requested max_tokens."""
        result = get_max_tokens(
            max_model_len=24000,
            max_tokens=5000,
            input_length=100,
            default_sampling_params={"max_tokens": 2048},
            override_max_tokens=2048,
        )
        assert result == 2048

    def test_override_max_tokens_used_as_default(self):
        """When the request has no max_tokens, the override still applies
        as the default."""
        result = get_max_tokens(
            max_model_len=24000,
            max_tokens=None,
            input_length=100,
            default_sampling_params={"max_tokens": 2048},
            override_max_tokens=2048,
        )
        assert result == 2048

    def test_max_model_len_still_caps_output(self):
        """max_model_len - input_length is always the hard ceiling."""
        result = get_max_tokens(
            max_model_len=3000,
            max_tokens=5000,
            input_length=100,
            default_sampling_params={"max_tokens": 2048},
        )
        assert result == 2900  # 3000 - 100

    def test_request_max_tokens_smaller_than_default(self):
        """When the user explicitly requests fewer tokens than the
        generation_config default, the smaller request should be respected."""
        result = get_max_tokens(
            max_model_len=24000,
            max_tokens=512,
            input_length=100,
            default_sampling_params={"max_tokens": 2048},
        )
        assert result == 512
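As a cross-check, the sketch shown above the diff reproduces every expected value in these tests (same caveat: the body is inferred, not necessarily the shipped one):

assert get_max_tokens(24000, None, 100, {"max_tokens": 2048}) == 2048  # gen_config default applies
assert get_max_tokens(24000, 5000, 100, {"max_tokens": 2048}) == 5000  # default is not a ceiling
assert get_max_tokens(24000, 5000, 100, {"max_tokens": 2048}, 2048) == 2048  # override caps
assert get_max_tokens(24000, None, 100, {"max_tokens": 2048}, 2048) == 2048  # override as default
assert get_max_tokens(3000, 5000, 100, {"max_tokens": 2048}) == 2900  # 3000 - 100 hard cap
assert get_max_tokens(24000, 512, 100, {"max_tokens": 2048}) == 512  # smaller request wins

The distinction matters because generation_config ships with the model checkpoint: treating its max_tokens as a ceiling silently truncated larger user requests (#34005), whereas an explicit user-set override is a deliberate limit and may still cap them.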