Fix hf_override_fn when it modifies model_type (#35200)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -161,7 +161,16 @@ class HFConfigParser(ConfigParserBase):
|
||||
)
|
||||
# Allow hf_overrides to override model_type before checking _CONFIG_REGISTRY
|
||||
if (hf_overrides := kwargs.pop("hf_overrides", None)) is not None:
|
||||
model_type = hf_overrides.get("model_type", model_type)
|
||||
if isinstance(hf_overrides, dict) and "model_type" in hf_overrides:
|
||||
model_type = hf_overrides["model_type"]
|
||||
elif callable(hf_overrides):
|
||||
# If hf_overrides doesn't modify model_type, it will be passed straight
|
||||
# through and remain unchanged by this elif block
|
||||
dummy_model_type = f"dummy_{model_type}"
|
||||
dummy_kwargs = dict(architectures=[""], model_type=dummy_model_type)
|
||||
dummy_config = PretrainedConfig(**dummy_kwargs)
|
||||
dummy_model_type = hf_overrides(dummy_config).model_type
|
||||
model_type = dummy_model_type.removeprefix("dummy_")
|
||||
|
||||
if model_type in _CONFIG_REGISTRY:
|
||||
config_class = _CONFIG_REGISTRY[model_type]
|
||||
@@ -634,7 +643,7 @@ def get_config(
|
||||
trust_remote_code=trust_remote_code,
|
||||
revision=revision,
|
||||
code_revision=code_revision,
|
||||
hf_overrides=hf_overrides_kw,
|
||||
hf_overrides=hf_overrides_kw or hf_overrides_fn,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -79,10 +79,10 @@ class ModelArchConfigConvertorBase:
|
||||
if getattr(self.hf_text_config, "hidden_size_per_head", None) is not None:
|
||||
return self.hf_text_config.hidden_size_per_head
|
||||
|
||||
if (total_num_attention_heads := self.get_total_num_attention_heads()) == 0:
|
||||
return 0
|
||||
# FIXME(woosuk): This may not be true for all models.
|
||||
return (
|
||||
self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads
|
||||
)
|
||||
return self.get_hidden_size() // total_num_attention_heads
|
||||
|
||||
def get_total_num_kv_heads(self) -> int:
|
||||
attributes = [
|
||||
@@ -96,7 +96,7 @@ class ModelArchConfigConvertorBase:
|
||||
]
|
||||
# For non-grouped-query attention models, the number of KV heads is
|
||||
# equal to the number of attention heads.
|
||||
default_factory = lambda: self.hf_text_config.num_attention_heads
|
||||
default_factory = self.get_total_num_attention_heads
|
||||
return getattr_iter(
|
||||
self.hf_text_config, attributes, default_factory=default_factory
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user