[Misc] Allow AutoWeightsLoader to skip loading weights with specific substr in name (#18358)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py
2025-05-20 11:20:12 +08:00
committed by GitHub
parent d565e0976f
commit f07a673eb2
18 changed files with 116 additions and 109 deletions

View File

@@ -77,3 +77,73 @@ def test_module_with_child_containing_batchnorm_can_autoload():
assert torch.all(
new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
def test_module_skip_prefix():
"""Ensure the auto weight loader can skip prefix."""
mod = ModuleWithNestedBatchNorm()
# Run some data through the module with batchnorm
mod(torch.Tensor([[1, 2], [3, 4]]))
# Try to load the weights to a new instance
def weight_generator():
# weights needed to be filtered out
redundant_weights = {
"prefix.bn.weight": torch.Tensor([1, 2]),
"prefix.bn.bias": torch.Tensor([3, 4]),
}
yield from (mod.state_dict() | redundant_weights).items()
new_mod = ModuleWithNestedBatchNorm()
assert not torch.all(
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
assert not torch.all(
new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0
loader = AutoWeightsLoader(new_mod, skip_prefixes=["prefix."])
loader.load_weights(weight_generator())
# Ensure the stats are updated
assert torch.all(
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
assert torch.all(
new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
def test_module_skip_substr():
"""Ensure the auto weight loader can skip prefix."""
mod = ModuleWithNestedBatchNorm()
# Run some data through the module with batchnorm
mod(torch.Tensor([[1, 2], [3, 4]]))
# Try to load the weights to a new instance
def weight_generator():
# weights needed to be filtered out
redundant_weights = {
"nested_mod.0.substr.weight": torch.Tensor([1, 2]),
"nested_mod.0.substr.bias": torch.Tensor([3, 4]),
"nested_mod.substr.weight": torch.Tensor([1, 2]),
"nested_mod.substr.bias": torch.Tensor([3, 4]),
}
yield from (mod.state_dict() | redundant_weights).items()
new_mod = ModuleWithNestedBatchNorm()
assert not torch.all(
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
assert not torch.all(
new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0
loader = AutoWeightsLoader(new_mod, skip_substrs=["substr."])
loader.load_weights(weight_generator())
# Ensure the stats are updated
assert torch.all(
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
assert torch.all(
new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1