Fix hf_override_fn when it modifies model_type (#35200)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2026-03-10 15:03:18 +00:00
committed by GitHub
parent 106ff69c4e
commit d88f28da05
2 changed files with 15 additions and 6 deletions

View File

@@ -161,7 +161,16 @@ class HFConfigParser(ConfigParserBase):
) )
# Allow hf_overrides to override model_type before checking _CONFIG_REGISTRY # Allow hf_overrides to override model_type before checking _CONFIG_REGISTRY
if (hf_overrides := kwargs.pop("hf_overrides", None)) is not None: if (hf_overrides := kwargs.pop("hf_overrides", None)) is not None:
model_type = hf_overrides.get("model_type", model_type) if isinstance(hf_overrides, dict) and "model_type" in hf_overrides:
model_type = hf_overrides["model_type"]
elif callable(hf_overrides):
# If hf_overrides doesn't modify model_type, it will be passed straight
# through and remain unchanged by this elif block
dummy_model_type = f"dummy_{model_type}"
dummy_kwargs = dict(architectures=[""], model_type=dummy_model_type)
dummy_config = PretrainedConfig(**dummy_kwargs)
dummy_model_type = hf_overrides(dummy_config).model_type
model_type = dummy_model_type.removeprefix("dummy_")
if model_type in _CONFIG_REGISTRY: if model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[model_type] config_class = _CONFIG_REGISTRY[model_type]
@@ -634,7 +643,7 @@ def get_config(
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
revision=revision, revision=revision,
code_revision=code_revision, code_revision=code_revision,
hf_overrides=hf_overrides_kw, hf_overrides=hf_overrides_kw or hf_overrides_fn,
**kwargs, **kwargs,
) )

View File

@@ -79,10 +79,10 @@ class ModelArchConfigConvertorBase:
if getattr(self.hf_text_config, "hidden_size_per_head", None) is not None: if getattr(self.hf_text_config, "hidden_size_per_head", None) is not None:
return self.hf_text_config.hidden_size_per_head return self.hf_text_config.hidden_size_per_head
if (total_num_attention_heads := self.get_total_num_attention_heads()) == 0:
return 0
# FIXME(woosuk): This may not be true for all models. # FIXME(woosuk): This may not be true for all models.
return ( return self.get_hidden_size() // total_num_attention_heads
self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads
)
def get_total_num_kv_heads(self) -> int: def get_total_num_kv_heads(self) -> int:
attributes = [ attributes = [
@@ -96,7 +96,7 @@ class ModelArchConfigConvertorBase:
] ]
# For non-grouped-query attention models, the number of KV heads is # For non-grouped-query attention models, the number of KV heads is
# equal to the number of attention heads. # equal to the number of attention heads.
default_factory = lambda: self.hf_text_config.num_attention_heads default_factory = self.get_total_num_attention_heads
return getattr_iter( return getattr_iter(
self.hf_text_config, attributes, default_factory=default_factory self.hf_text_config, attributes, default_factory=default_factory
) )