[Bugfix] Fix Crashing When Loading Modules With Batchnorm Stats (#15813)
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
This commit is contained in:
@@ -158,6 +158,26 @@ class AutoWeightsLoader:
|
||||
|
||||
yield weight_qualname
|
||||
|
||||
def _add_loadable_non_param_tensors(self, module: nn.Module,
|
||||
child_params: Dict[str, torch.Tensor]):
|
||||
"""
|
||||
Add tensor names that are not in the model params that may be in the
|
||||
safetensors, e.g., batch normalization stats.
|
||||
"""
|
||||
if isinstance(module, (
|
||||
nn.BatchNorm1d,
|
||||
nn.BatchNorm2d,
|
||||
nn.BatchNorm3d,
|
||||
nn.LazyBatchNorm1d,
|
||||
nn.LazyBatchNorm2d,
|
||||
nn.LazyBatchNorm3d,
|
||||
nn.SyncBatchNorm,
|
||||
)):
|
||||
module_state_dict = module.state_dict()
|
||||
for stat_name in ("running_mean", "running_var",
|
||||
"num_batches_tracked"):
|
||||
child_params[stat_name] = module_state_dict[stat_name]
|
||||
|
||||
def _load_module(
|
||||
self,
|
||||
base_prefix: str,
|
||||
@@ -186,6 +206,10 @@ class AutoWeightsLoader:
|
||||
child_modules = dict(module.named_children())
|
||||
child_params = dict(module.named_parameters(recurse=False))
|
||||
|
||||
# Add missing tensors the weight loader needs to be able to load
|
||||
# that aren't registered as params, e.g., batchnorm statistics.
|
||||
self._add_loadable_non_param_tensors(module, child_params)
|
||||
|
||||
for child_prefix, child_weights in self._groupby_prefix(weights):
|
||||
prefix = self._get_qualname(base_prefix, child_prefix)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user