[Bugfix] Fix Crashing When Loading Modules With Batchnorm Stats (#15813)

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Alex Brooks authored on 2025-03-31 07:23:53 -06:00, committed by GitHub
parent 3aa2b6a637 · commit c2e7507ad4
2 changed files with 103 additions and 0 deletions


@@ -158,6 +158,26 @@ class AutoWeightsLoader:
            yield weight_qualname

    def _add_loadable_non_param_tensors(self, module: nn.Module,
                                        child_params: Dict[str, torch.Tensor]):
        """
        Add tensor names that are not in the model params that may be in the
        safetensors, e.g., batch normalization stats.
        """
        if isinstance(module, (
                nn.BatchNorm1d,
                nn.BatchNorm2d,
                nn.BatchNorm3d,
                nn.LazyBatchNorm1d,
                nn.LazyBatchNorm2d,
                nn.LazyBatchNorm3d,
                nn.SyncBatchNorm,
        )):
            module_state_dict = module.state_dict()
            for stat_name in ("running_mean", "running_var",
                              "num_batches_tracked"):
                child_params[stat_name] = module_state_dict[stat_name]

    def _load_module(
        self,
        base_prefix: str,
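
Background for the change (not part of the patch): PyTorch registers batchnorm
running statistics as buffers rather than parameters, so they appear in
state_dict() but not in named_parameters(), which was the only source the
loader consulted before this fix. A minimal, self-contained sketch of that gap:

import torch.nn as nn

bn = nn.BatchNorm2d(8)

param_names = {name for name, _ in bn.named_parameters(recurse=False)}
state_names = set(bn.state_dict().keys())

print(param_names)                # {'weight', 'bias'}
print(state_names - param_names)  # {'running_mean', 'running_var', 'num_batches_tracked'}

Checkpoints saved from such a module therefore contain tensor names the loader
previously had no target for, which is what triggered the crash.
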
@@ -186,6 +206,10 @@ class AutoWeightsLoader:
        child_modules = dict(module.named_children())
        child_params = dict(module.named_parameters(recurse=False))

        # Add missing tensors the weight loader needs to be able to load
        # that aren't registered as params, e.g., batchnorm statistics.
        self._add_loadable_non_param_tensors(module, child_params)

        for child_prefix, child_weights in self._groupby_prefix(weights):
            prefix = self._get_qualname(base_prefix, child_prefix)
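
A minimal sketch of the effect of the patch, using a hypothetical
load_child_weights helper rather than the actual vLLM loading path: once the
batchnorm buffers are added to child_params, checkpoint entries such as
running_mean have a destination tensor; without them the name lookup fails.

import torch
import torch.nn as nn

def load_child_weights(module: nn.Module, weights, include_bn_stats: bool):
    # Simplified stand-in for the per-module loop in AutoWeightsLoader._load_module.
    child_params = dict(module.named_parameters(recurse=False))
    if include_bn_stats:
        # The fix: expose the batchnorm buffers under their checkpoint names.
        module_state_dict = module.state_dict()
        for name in ("running_mean", "running_var", "num_batches_tracked"):
            child_params[name] = module_state_dict[name]
    for name, tensor in weights:
        child_params[name].data.copy_(tensor)  # KeyError here without the stats

bn = nn.BatchNorm1d(4)
ckpt = [
    ("weight", torch.ones(4)),
    ("bias", torch.zeros(4)),
    ("running_mean", torch.full((4,), 0.5)),
    ("running_var", torch.full((4,), 2.0)),
    ("num_batches_tracked", torch.tensor(100)),
]
load_child_weights(bn, ckpt, include_bn_stats=True)    # succeeds
# load_child_weights(bn, ckpt, include_bn_stats=False)  # raises KeyError: 'running_mean'
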