Default to generation_config from model (#12622)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-03-08 07:46:15 +01:00
committed by GitHub
parent 3b9c6c6947
commit 47512b3200
7 changed files with 27 additions and 26 deletions

View File

@@ -289,7 +289,7 @@ def test_uses_mrope(model_id, uses_mrope):
def test_generation_config_loading():
model_id = "Qwen/Qwen2.5-1.5B-Instruct"
# When set generation_config to None, the default generation config
# When set generation_config to "vllm", the default generation config
# will not be loaded.
model_config = ModelConfig(model_id,
task="auto",
@@ -298,7 +298,7 @@ def test_generation_config_loading():
trust_remote_code=False,
seed=0,
dtype="float16",
generation_config=None)
generation_config="vllm")
assert model_config.get_diff_sampling_param() == {}
# When set generation_config to "auto", the default generation config
@@ -340,7 +340,7 @@ def test_generation_config_loading():
assert model_config.get_diff_sampling_param() == override_result
# When generation_config is set to None and override_generation_config
# When generation_config is set to "vllm" and override_generation_config
# is set, the override_generation_config should be used directly.
model_config = ModelConfig(
model_id,
@@ -350,7 +350,7 @@ def test_generation_config_loading():
trust_remote_code=False,
seed=0,
dtype="float16",
generation_config=None,
generation_config="vllm",
override_generation_config=override_generation_config)
assert model_config.get_diff_sampling_param() == override_generation_config