[Frontend] Support override generation config in args (#12409)
Signed-off-by: liuyanyi <wolfsonliu@163.com>
This commit is contained in:
@@ -281,3 +281,73 @@ def test_uses_mrope(model_id, uses_mrope):
|
||||
)
|
||||
|
||||
assert config.uses_mrope == uses_mrope
|
||||
|
||||
|
||||
def test_generation_config_loading():
    """Exercise the `generation_config` / `override_generation_config` args.

    Four configurations are built for the same model and the result of
    `get_diff_sampling_param()` is compared against the expected dict:

    1. `generation_config=None`: expects an empty dict.
    2. `generation_config="auto"`: expects the model's own defaults.
    3. `"auto"` plus an override dict: expects the defaults with the
       overridden keys replaced.
    4. `None` plus an override dict: expects exactly the override dict.
    """
    repo = "Qwen/Qwen2.5-1.5B-Instruct"

    def make_config(**generation_kwargs):
        # Every case shares these arguments; only the generation-related
        # keyword arguments differ between cases.
        return ModelConfig(repo,
                           task="auto",
                           tokenizer=repo,
                           tokenizer_mode="auto",
                           trust_remote_code=False,
                           seed=0,
                           dtype="float16",
                           **generation_kwargs)

    # Case 1: with generation_config=None the default generation config
    # is not loaded.
    cfg = make_config(generation_config=None)
    assert cfg.get_diff_sampling_param() == {}

    # Case 2: with generation_config="auto" the default generation config
    # is loaded.
    cfg = make_config(generation_config="auto")
    model_defaults = {
        "repetition_penalty": 1.1,
        "temperature": 0.7,
        "top_p": 0.8,
        "top_k": 20,
    }
    assert cfg.get_diff_sampling_param() == model_defaults

    # Case 3: user-supplied overrides take precedence over the loaded
    # defaults.
    user_overrides = {"temperature": 0.5, "top_k": 5}
    cfg = make_config(generation_config="auto",
                      override_generation_config=user_overrides)
    merged = {**model_defaults, **user_overrides}
    assert cfg.get_diff_sampling_param() == merged

    # Case 4: with generation_config=None and overrides given, the
    # overrides are used directly.
    cfg = make_config(generation_config=None,
                      override_generation_config=user_overrides)
    assert cfg.get_diff_sampling_param() == user_overrides
|
||||
|
||||
Reference in New Issue
Block a user