[UX] Support nested dicts in hf_overrides (#25727)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin
2025-10-06 23:19:16 -04:00
committed by GitHub
parent 2111b4643c
commit c6873c4e6d
2 changed files with 88 additions and 1 deletions

View File

@@ -367,6 +367,51 @@ class ModelConfig:
assert_hashable(str_factors)
return hashlib.sha256(str(factors).encode()).hexdigest()
def _update_nested(
    self,
    target: Union["PretrainedConfig", dict[str, Any]],
    updates: dict[str, Any],
) -> None:
    """Merge ``updates`` into ``target`` in place, descending where possible.

    Args:
        target: The object being modified. A ``dict`` is written through
            item assignment; anything else is written through ``setattr``.
        updates: Mapping of key/attribute name to the new value. A dict
            value is merged into the existing entry of the same name when
            that entry is itself a dict or an object carrying a
            ``__dict__``; in every other case the value replaces (or
            creates) the entry outright.
    """
    # The container kind of ``target`` is fixed for the whole call.
    target_is_dict = isinstance(target, dict)

    for name, new_value in updates.items():
        # Only dict-valued updates are candidates for a recursive merge,
        # so the current entry is looked up for those alone.
        existing = None
        if isinstance(new_value, dict):
            if target_is_dict:
                existing = target.get(name)
            else:
                existing = getattr(target, name, None)

        mergeable = existing is not None and (
            isinstance(existing, dict) or hasattr(existing, "__dict__")
        )

        if mergeable:
            # Descend so sibling entries not named in the update survive.
            self._update_nested(existing, new_value)
        elif target_is_dict:
            target[name] = new_value
        else:
            setattr(target, name, new_value)
def _apply_dict_overrides(
    self,
    config: "PretrainedConfig",
    overrides: dict[str, Any],
) -> None:
    """Apply each entry of ``overrides`` to ``config`` in place.

    Args:
        config: The config object whose attributes are overridden.
        overrides: Mapping of attribute name to override value. When the
            current attribute is a ``PretrainedConfig`` instance the value
            is merged into it via ``_update_nested``; otherwise (including
            when the attribute is absent) the value is assigned as-is.
    """
    # Imported locally, as in the original implementation.
    from transformers import PretrainedConfig

    for name, override in overrides.items():
        current = getattr(config, name, None)

        if not isinstance(current, PretrainedConfig):
            # Not a sub-config: treat the value as a dict-valued parameter
            # and set it directly.
            setattr(config, name, override)
            continue

        # Sub-config: update it recursively rather than replacing it.
        self._update_nested(current, override)
def __post_init__(
self,
# Multimodal config init vars
@@ -419,8 +464,17 @@ class ModelConfig:
if callable(self.hf_overrides):
hf_overrides_kw = {}
hf_overrides_fn = self.hf_overrides
dict_overrides: dict[str, Any] = {}
else:
hf_overrides_kw = self.hf_overrides
# Separate dict overrides from flat ones
# We'll determine how to apply dict overrides after loading the config
hf_overrides_kw = {}
dict_overrides = {}
for key, value in self.hf_overrides.items():
if isinstance(value, dict):
dict_overrides[key] = value
else:
hf_overrides_kw[key] = value
hf_overrides_fn = None
if self.rope_scaling:
@@ -478,6 +532,8 @@ class ModelConfig:
)
self.hf_config = hf_config
if dict_overrides:
self._apply_dict_overrides(hf_config, dict_overrides)
self.hf_text_config = get_hf_text_config(self.hf_config)
self.attention_chunk_size = getattr(
self.hf_text_config, "attention_chunk_size", None