[Bugfix] Treat generation_config max_tokens as default not ceiling (#34063)

Signed-off-by: almogtavor <almogtavor@gmail.com>
This commit is contained in:
Almog Tavor
2026-02-16 17:58:24 +02:00
committed by GitHub
parent a3205beffb
commit 72d5951d02
7 changed files with 157 additions and 17 deletions

View File

@@ -526,6 +526,7 @@ class MockModelConfig:
allowed_media_domains: list[str] | None = None
encoder_config = None
generation_config: str = "auto"
override_generation_config: dict[str, Any] = field(default_factory=dict)
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False
@@ -651,12 +652,10 @@ async def test_serving_chat_should_set_correct_max_tokens():
assert mock_engine.generate.call_args.args[1].max_tokens == 10
# Setting server's max_tokens in the generation_config.json
# lower than context_window - prompt_tokens
# Model author's generation_config.json sets max_tokens (auto, no override)
# — should act as fallback only, not ceiling
mock_model_config = MockModelConfig()
mock_model_config.diff_sampling_param = {
"max_tokens": 10 # Setting server-side max_tokens limit
}
mock_model_config.diff_sampling_param = {"max_tokens": 10}
# Reinitialize the engine with new settings
mock_engine = MagicMock(spec=AsyncLLM)
@@ -680,7 +679,50 @@ async def test_serving_chat_should_set_correct_max_tokens():
assert mock_engine.generate.call_args.args[1].max_tokens == 10
# Test Case 2: Request's max_tokens set higher than server accepts
# Test Case 2: Request's max_tokens set higher than generation_config
# default so request-provided max_tokens takes precedence
req.max_tokens = 15
with suppress(Exception):
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 15
# Test Case 3: Request's max_tokens set lower than generation_config default — request value is respected
req.max_tokens = 5
with suppress(Exception):
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 5
# User explicitly sets max_tokens via --override-generation-config
# — should act as a ceiling
mock_model_config = MockModelConfig()
mock_model_config.diff_sampling_param = {"max_tokens": 10}
mock_model_config.override_generation_config = {"max_new_tokens": 10}
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.errored = False
mock_engine.model_config = mock_model_config
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_chat = _build_serving_chat(mock_engine)
# Test Case 3.1: No max_tokens — uses override as default
req = ChatCompletionRequest(
model=MODEL_NAME,
messages=[{"role": "user", "content": "what is 1+1?"}],
)
with suppress(Exception):
await serving_chat.create_chat_completion(req)
assert mock_engine.generate.call_args.args[1].max_tokens == 10
# Test Case 3.2: Request max_tokens higher — capped by user ceiling from override
req.max_tokens = 15
with suppress(Exception):
@@ -688,7 +730,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
assert mock_engine.generate.call_args.args[1].max_tokens == 10
# Test Case 3: Request's max_tokens set lower than server accepts
# Test Case 3.3: Request max_tokens lower — respected
req.max_tokens = 5
with suppress(Exception):
@@ -699,9 +741,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
# Setting server's max_tokens in the generation_config.json
# higher than context_window - prompt_tokens
mock_model_config = MockModelConfig()
mock_model_config.diff_sampling_param = {
"max_tokens": 200 # Setting server-side max_tokens limit
}
mock_model_config.diff_sampling_param = {"max_tokens": 200}
# Reinitialize the engine with new settings
mock_engine = MagicMock(spec=AsyncLLM)