From d88f28da05b12bc7d63ebe3dcedf445ecb274343 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Tue, 10 Mar 2026 15:03:18 +0000 Subject: [PATCH] Fix `hf_override_fn` when it modifies `model_type` (#35200) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/transformers_utils/config.py | 13 +++++++++++-- .../model_arch_config_convertor.py | 8 ++++---- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 99d8b5dcc..dd22ed544 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -161,7 +161,16 @@ class HFConfigParser(ConfigParserBase): ) # Allow hf_overrides to override model_type before checking _CONFIG_REGISTRY if (hf_overrides := kwargs.pop("hf_overrides", None)) is not None: - model_type = hf_overrides.get("model_type", model_type) + if isinstance(hf_overrides, dict) and "model_type" in hf_overrides: + model_type = hf_overrides["model_type"] + elif callable(hf_overrides): + # If hf_overrides doesn't modify model_type, it will be passed straight + # through and remain unchanged by this elif block + dummy_model_type = f"dummy_{model_type}" + dummy_kwargs = dict(architectures=[""], model_type=dummy_model_type) + dummy_config = PretrainedConfig(**dummy_kwargs) + dummy_model_type = hf_overrides(dummy_config).model_type + model_type = dummy_model_type.removeprefix("dummy_") if model_type in _CONFIG_REGISTRY: config_class = _CONFIG_REGISTRY[model_type] @@ -634,7 +643,7 @@ def get_config( trust_remote_code=trust_remote_code, revision=revision, code_revision=code_revision, - hf_overrides=hf_overrides_kw, + hf_overrides=hf_overrides_kw or hf_overrides_fn, **kwargs, ) diff --git a/vllm/transformers_utils/model_arch_config_convertor.py b/vllm/transformers_utils/model_arch_config_convertor.py index bb45f137e..4444469dc 100644 --- a/vllm/transformers_utils/model_arch_config_convertor.py +++ b/vllm/transformers_utils/model_arch_config_convertor.py @@ -79,10 +79,10 @@ class ModelArchConfigConvertorBase: if getattr(self.hf_text_config, "hidden_size_per_head", None) is not None: return self.hf_text_config.hidden_size_per_head + if (total_num_attention_heads := self.get_total_num_attention_heads()) == 0: + return 0 # FIXME(woosuk): This may not be true for all models. - return ( - self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads - ) + return self.get_hidden_size() // total_num_attention_heads def get_total_num_kv_heads(self) -> int: attributes = [ @@ -96,7 +96,7 @@ class ModelArchConfigConvertorBase: ] # For non-grouped-query attention models, the number of KV heads is # equal to the number of attention heads. - default_factory = lambda: self.hf_text_config.num_attention_heads + default_factory = self.get_total_num_attention_heads return getattr_iter( self.hf_text_config, attributes, default_factory=default_factory )