Default to generation_config from model (#12622)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-03-08 07:46:15 +01:00
committed by GitHub
parent 3b9c6c6947
commit 47512b3200
7 changed files with 27 additions and 26 deletions

View File

@@ -255,7 +255,7 @@ class ModelConfig:
override_neuron_config: Optional[dict[str, Any]] = None,
override_pooler_config: Optional["PoolerConfig"] = None,
logits_processor_pattern: Optional[str] = None,
generation_config: Optional[str] = None,
generation_config: str = "auto",
enable_sleep_mode: bool = False,
override_generation_config: Optional[dict[str, Any]] = None,
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
@@ -951,7 +951,7 @@ class ModelConfig:
return self.multimodal_config
def try_get_generation_config(self) -> dict[str, Any]:
if self.generation_config is None or self.generation_config == "auto":
if self.generation_config in ("auto", "vllm"):
config = try_get_generation_config(
self.hf_config_path or self.model,
trust_remote_code=self.trust_remote_code,
@@ -971,17 +971,14 @@ class ModelConfig:
def get_diff_sampling_param(self) -> dict[str, Any]:
"""
This method returns a dictionary containing the parameters
that differ from the default sampling parameters, but only
if `generation_config` is set. If `generation_config` is not
set, an empty dictionary is returned.
that differ from the default sampling parameters. If
`generation_config` is `"vllm"`, an empty dictionary is returned.
Returns:
dict[str, Any]: A dictionary with the differing sampling
parameters if `generation_config` is set, otherwise an
empty dictionary.
parameters, or an empty dictionary if `generation_config` is `"vllm"`.
"""
if self.generation_config is None:
# When generation_config is not set
if self.generation_config == "vllm":
config = {}
else:
config = self.try_get_generation_config()