Fix hf_overrides_fn when it modifies model_type (#35200)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2026-03-10 15:03:18 +00:00
committed by GitHub
parent 106ff69c4e
commit d88f28da05
2 changed files with 15 additions and 6 deletions

View File

@@ -161,7 +161,16 @@ class HFConfigParser(ConfigParserBase):
)
# Allow hf_overrides to override model_type before checking _CONFIG_REGISTRY
if (hf_overrides := kwargs.pop("hf_overrides", None)) is not None:
model_type = hf_overrides.get("model_type", model_type)
if isinstance(hf_overrides, dict) and "model_type" in hf_overrides:
model_type = hf_overrides["model_type"]
elif callable(hf_overrides):
# If hf_overrides doesn't modify model_type, the dummy value passes
# straight through and model_type is left unchanged by this elif block
dummy_model_type = f"dummy_{model_type}"
dummy_kwargs = dict(architectures=[""], model_type=dummy_model_type)
dummy_config = PretrainedConfig(**dummy_kwargs)
dummy_model_type = hf_overrides(dummy_config).model_type
model_type = dummy_model_type.removeprefix("dummy_")
if model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[model_type]
@@ -634,7 +643,7 @@ def get_config(
trust_remote_code=trust_remote_code,
revision=revision,
code_revision=code_revision,
hf_overrides=hf_overrides_kw,
hf_overrides=hf_overrides_kw or hf_overrides_fn,
**kwargs,
)

View File

@@ -79,10 +79,10 @@ class ModelArchConfigConvertorBase:
if getattr(self.hf_text_config, "hidden_size_per_head", None) is not None:
return self.hf_text_config.hidden_size_per_head
if (total_num_attention_heads := self.get_total_num_attention_heads()) == 0:
return 0
# FIXME(woosuk): This may not be true for all models.
return (
self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads
)
return self.get_hidden_size() // total_num_attention_heads
def get_total_num_kv_heads(self) -> int:
attributes = [
@@ -96,7 +96,7 @@ class ModelArchConfigConvertorBase:
]
# For non-grouped-query attention models, the number of KV heads is
# equal to the number of attention heads.
default_factory = lambda: self.hf_text_config.num_attention_heads
default_factory = self.get_total_num_attention_heads
return getattr_iter(
self.hf_text_config, attributes, default_factory=default_factory
)