[UX] Support nested dicts in hf_overrides (#25727)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -367,6 +367,51 @@ class ModelConfig:
|
||||
assert_hashable(str_factors)
|
||||
return hashlib.sha256(str(factors).encode()).hexdigest()
|
||||
|
||||
def _update_nested(
|
||||
self,
|
||||
target: Union["PretrainedConfig", dict[str, Any]],
|
||||
updates: dict[str, Any],
|
||||
) -> None:
|
||||
"""Recursively updates a config or dict with nested updates."""
|
||||
for key, value in updates.items():
|
||||
if isinstance(value, dict):
|
||||
# Get the nested target
|
||||
if isinstance(target, dict):
|
||||
nested_target = target.get(key)
|
||||
else:
|
||||
nested_target = getattr(target, key, None)
|
||||
|
||||
# If nested target exists and can be updated recursively
|
||||
if nested_target is not None and (
|
||||
isinstance(nested_target, dict)
|
||||
or hasattr(nested_target, "__dict__")
|
||||
):
|
||||
self._update_nested(nested_target, value)
|
||||
continue
|
||||
|
||||
# Set the value (base case)
|
||||
if isinstance(target, dict):
|
||||
target[key] = value
|
||||
else:
|
||||
setattr(target, key, value)
|
||||
|
||||
def _apply_dict_overrides(
|
||||
self,
|
||||
config: "PretrainedConfig",
|
||||
overrides: dict[str, Any],
|
||||
) -> None:
|
||||
"""Apply dict overrides, handling both nested configs and dict values."""
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
for key, value in overrides.items():
|
||||
attr = getattr(config, key, None)
|
||||
if attr is not None and isinstance(attr, PretrainedConfig):
|
||||
# It's a nested config - recursively update it
|
||||
self._update_nested(attr, value)
|
||||
else:
|
||||
# It's a dict-valued parameter - set it directly
|
||||
setattr(config, key, value)
|
||||
|
||||
def __post_init__(
|
||||
self,
|
||||
# Multimodal config init vars
|
||||
@@ -419,8 +464,17 @@ class ModelConfig:
|
||||
if callable(self.hf_overrides):
|
||||
hf_overrides_kw = {}
|
||||
hf_overrides_fn = self.hf_overrides
|
||||
dict_overrides: dict[str, Any] = {}
|
||||
else:
|
||||
hf_overrides_kw = self.hf_overrides
|
||||
# Separate dict overrides from flat ones
|
||||
# We'll determine how to apply dict overrides after loading the config
|
||||
hf_overrides_kw = {}
|
||||
dict_overrides = {}
|
||||
for key, value in self.hf_overrides.items():
|
||||
if isinstance(value, dict):
|
||||
dict_overrides[key] = value
|
||||
else:
|
||||
hf_overrides_kw[key] = value
|
||||
hf_overrides_fn = None
|
||||
|
||||
if self.rope_scaling:
|
||||
@@ -478,6 +532,8 @@ class ModelConfig:
|
||||
)
|
||||
|
||||
self.hf_config = hf_config
|
||||
if dict_overrides:
|
||||
self._apply_dict_overrides(hf_config, dict_overrides)
|
||||
self.hf_text_config = get_hf_text_config(self.hf_config)
|
||||
self.attention_chunk_size = getattr(
|
||||
self.hf_text_config, "attention_chunk_size", None
|
||||
|
||||
Reference in New Issue
Block a user