[Bugfix] Treat generation_config max_tokens as default not ceiling (#34063)
Signed-off-by: almogtavor <almogtavor@gmail.com>
This commit is contained in:
@@ -526,6 +526,7 @@ class MockModelConfig:
|
||||
allowed_media_domains: list[str] | None = None
|
||||
encoder_config = None
|
||||
generation_config: str = "auto"
|
||||
override_generation_config: dict[str, Any] = field(default_factory=dict)
|
||||
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
|
||||
skip_tokenizer_init: bool = False
|
||||
is_encoder_decoder: bool = False
|
||||
@@ -651,12 +652,10 @@ async def test_serving_chat_should_set_correct_max_tokens():
|
||||
|
||||
assert mock_engine.generate.call_args.args[1].max_tokens == 10
|
||||
|
||||
# Setting server's max_tokens in the generation_config.json
|
||||
# lower than context_window - prompt_tokens
|
||||
# Model author's generation_config.json sets max_tokens (auto, no override)
|
||||
# — should act as fallback only, not ceiling
|
||||
mock_model_config = MockModelConfig()
|
||||
mock_model_config.diff_sampling_param = {
|
||||
"max_tokens": 10 # Setting server-side max_tokens limit
|
||||
}
|
||||
mock_model_config.diff_sampling_param = {"max_tokens": 10}
|
||||
|
||||
# Reinitialize the engine with new settings
|
||||
mock_engine = MagicMock(spec=AsyncLLM)
|
||||
@@ -680,7 +679,50 @@ async def test_serving_chat_should_set_correct_max_tokens():
|
||||
|
||||
assert mock_engine.generate.call_args.args[1].max_tokens == 10
|
||||
|
||||
# Test Case 2: Request's max_tokens set higher than server accepts
|
||||
# Test Case 2: Request's max_tokens set higher than generation_config
|
||||
# default so request-provided max_tokens takes precedence
|
||||
req.max_tokens = 15
|
||||
|
||||
with suppress(Exception):
|
||||
await serving_chat.create_chat_completion(req)
|
||||
|
||||
assert mock_engine.generate.call_args.args[1].max_tokens == 15
|
||||
|
||||
# Test Case 3: Request's max_tokens set lower than generation_config default — request value is respected
|
||||
req.max_tokens = 5
|
||||
|
||||
with suppress(Exception):
|
||||
await serving_chat.create_chat_completion(req)
|
||||
|
||||
assert mock_engine.generate.call_args.args[1].max_tokens == 5
|
||||
|
||||
# User explicitly sets max_tokens via --override-generation-config
|
||||
# — should act as a ceiling
|
||||
mock_model_config = MockModelConfig()
|
||||
mock_model_config.diff_sampling_param = {"max_tokens": 10}
|
||||
mock_model_config.override_generation_config = {"max_new_tokens": 10}
|
||||
|
||||
mock_engine = MagicMock(spec=AsyncLLM)
|
||||
mock_engine.errored = False
|
||||
mock_engine.model_config = mock_model_config
|
||||
mock_engine.input_processor = MagicMock()
|
||||
mock_engine.io_processor = MagicMock()
|
||||
mock_engine.renderer = _build_renderer(mock_engine.model_config)
|
||||
|
||||
serving_chat = _build_serving_chat(mock_engine)
|
||||
|
||||
# Test Case 3.1: No max_tokens — uses override as default
|
||||
req = ChatCompletionRequest(
|
||||
model=MODEL_NAME,
|
||||
messages=[{"role": "user", "content": "what is 1+1?"}],
|
||||
)
|
||||
|
||||
with suppress(Exception):
|
||||
await serving_chat.create_chat_completion(req)
|
||||
|
||||
assert mock_engine.generate.call_args.args[1].max_tokens == 10
|
||||
|
||||
# Test Case 3.2: Request max_tokens higher — capped by user ceiling from override
|
||||
req.max_tokens = 15
|
||||
|
||||
with suppress(Exception):
|
||||
@@ -688,7 +730,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
|
||||
|
||||
assert mock_engine.generate.call_args.args[1].max_tokens == 10
|
||||
|
||||
# Test Case 3: Request's max_tokens set lower than server accepts
|
||||
# Test Case 3.3: Request max_tokens lower — respected
|
||||
req.max_tokens = 5
|
||||
|
||||
with suppress(Exception):
|
||||
@@ -699,9 +741,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
|
||||
# Setting server's max_tokens in the generation_config.json
|
||||
# higher than context_window - prompt_tokens
|
||||
mock_model_config = MockModelConfig()
|
||||
mock_model_config.diff_sampling_param = {
|
||||
"max_tokens": 200 # Setting server-side max_tokens limit
|
||||
}
|
||||
mock_model_config.diff_sampling_param = {"max_tokens": 200}
|
||||
|
||||
# Reinitialize the engine with new settings
|
||||
mock_engine = MagicMock(spec=AsyncLLM)
|
||||
|
||||
Reference in New Issue
Block a user