[Misc] Fix skipped max-model-len validation when deriving max model length from tokenizer config (#19660)
Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
This commit is contained in:
committed by
GitHub
parent
367871a469
commit
b692e9cd07
@@ -438,3 +438,31 @@ def test_load_config_pt_load_map_location(pt_load_map_location):
|
||||
config = VllmConfig(load_config=load_config)
|
||||
|
||||
assert config.load_config.pt_load_map_location == pt_load_map_location
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("model_id", "max_model_len", "expected_max_len", "should_raise"), [
        # No user override: fall back to the tokenizer-derived limit (512).
        ("BAAI/bge-reranker-base", None, 512, False),
        # Override below the derived limit is accepted unchanged.
        ("BAAI/bge-reranker-base", 256, 256, False),
        # Override above the derived limit must be rejected.
        ("BAAI/bge-reranker-base", 513, 512, True),
    ])
def test_get_and_verify_max_len(model_id, max_model_len, expected_max_len,
                                should_raise):
    """Test get_and_verify_max_len with different configurations."""
    config = ModelConfig(
        model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
    )

    if should_raise:
        # A requested length larger than the model supports must raise,
        # even when the limit was derived from the tokenizer config.
        with pytest.raises(ValueError):
            config.get_and_verify_max_len(max_model_len)
    else:
        resolved_len = config.get_and_verify_max_len(max_model_len)
        assert resolved_len == expected_max_len
|
||||
|
||||
Reference in New Issue
Block a user