Fix hf_overrides_fn when it modifies model_type (#35200)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -161,7 +161,16 @@ class HFConfigParser(ConfigParserBase):
|
|||||||
)
|
)
|
||||||
# Allow hf_overrides to override model_type before checking _CONFIG_REGISTRY
|
# Allow hf_overrides to override model_type before checking _CONFIG_REGISTRY
|
||||||
if (hf_overrides := kwargs.pop("hf_overrides", None)) is not None:
|
if (hf_overrides := kwargs.pop("hf_overrides", None)) is not None:
|
||||||
model_type = hf_overrides.get("model_type", model_type)
|
if isinstance(hf_overrides, dict) and "model_type" in hf_overrides:
|
||||||
|
model_type = hf_overrides["model_type"]
|
||||||
|
elif callable(hf_overrides):
|
||||||
|
# If hf_overrides doesn't modify model_type, it will be passed straight
|
||||||
|
# through and remain unchanged by this elif block
|
||||||
|
dummy_model_type = f"dummy_{model_type}"
|
||||||
|
dummy_kwargs = dict(architectures=[""], model_type=dummy_model_type)
|
||||||
|
dummy_config = PretrainedConfig(**dummy_kwargs)
|
||||||
|
dummy_model_type = hf_overrides(dummy_config).model_type
|
||||||
|
model_type = dummy_model_type.removeprefix("dummy_")
|
||||||
|
|
||||||
if model_type in _CONFIG_REGISTRY:
|
if model_type in _CONFIG_REGISTRY:
|
||||||
config_class = _CONFIG_REGISTRY[model_type]
|
config_class = _CONFIG_REGISTRY[model_type]
|
||||||
@@ -634,7 +643,7 @@ def get_config(
|
|||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
code_revision=code_revision,
|
code_revision=code_revision,
|
||||||
hf_overrides=hf_overrides_kw,
|
hf_overrides=hf_overrides_kw or hf_overrides_fn,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -79,10 +79,10 @@ class ModelArchConfigConvertorBase:
|
|||||||
if getattr(self.hf_text_config, "hidden_size_per_head", None) is not None:
|
if getattr(self.hf_text_config, "hidden_size_per_head", None) is not None:
|
||||||
return self.hf_text_config.hidden_size_per_head
|
return self.hf_text_config.hidden_size_per_head
|
||||||
|
|
||||||
|
if (total_num_attention_heads := self.get_total_num_attention_heads()) == 0:
|
||||||
|
return 0
|
||||||
# FIXME(woosuk): This may not be true for all models.
|
# FIXME(woosuk): This may not be true for all models.
|
||||||
return (
|
return self.get_hidden_size() // total_num_attention_heads
|
||||||
self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_total_num_kv_heads(self) -> int:
|
def get_total_num_kv_heads(self) -> int:
|
||||||
attributes = [
|
attributes = [
|
||||||
@@ -96,7 +96,7 @@ class ModelArchConfigConvertorBase:
|
|||||||
]
|
]
|
||||||
# For non-grouped-query attention models, the number of KV heads is
|
# For non-grouped-query attention models, the number of KV heads is
|
||||||
# equal to the number of attention heads.
|
# equal to the number of attention heads.
|
||||||
default_factory = lambda: self.hf_text_config.num_attention_heads
|
default_factory = self.get_total_num_attention_heads
|
||||||
return getattr_iter(
|
return getattr_iter(
|
||||||
self.hf_text_config, attributes, default_factory=default_factory
|
self.hf_text_config, attributes, default_factory=default_factory
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user