Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -10,7 +10,6 @@ pytestmark = pytest.mark.cpu_test
|
||||
|
||||
|
||||
class ModuleWithBatchNorm(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.bn = torch.nn.BatchNorm1d(2)
|
||||
@@ -20,7 +19,6 @@ class ModuleWithBatchNorm(torch.nn.Module):
|
||||
|
||||
|
||||
class ModuleWithNestedBatchNorm(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.nested_mod = ModuleWithBatchNorm()
|
||||
@@ -67,9 +65,11 @@ def test_module_with_child_containing_batchnorm_can_autoload():
|
||||
new_mod = ModuleWithNestedBatchNorm()
|
||||
|
||||
assert not torch.all(
|
||||
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
|
||||
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean
|
||||
)
|
||||
assert not torch.all(
|
||||
new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
|
||||
new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var
|
||||
)
|
||||
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0
|
||||
|
||||
loader = AutoWeightsLoader(new_mod)
|
||||
@@ -77,9 +77,9 @@ def test_module_with_child_containing_batchnorm_can_autoload():
|
||||
|
||||
# Ensure the stats are updated
|
||||
assert torch.all(
|
||||
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
|
||||
assert torch.all(
|
||||
new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
|
||||
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean
|
||||
)
|
||||
assert torch.all(new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
|
||||
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
|
||||
|
||||
|
||||
@@ -101,9 +101,11 @@ def test_module_skip_prefix():
|
||||
new_mod = ModuleWithNestedBatchNorm()
|
||||
|
||||
assert not torch.all(
|
||||
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
|
||||
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean
|
||||
)
|
||||
assert not torch.all(
|
||||
new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
|
||||
new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var
|
||||
)
|
||||
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0
|
||||
|
||||
loader = AutoWeightsLoader(new_mod, skip_prefixes=["prefix."])
|
||||
@@ -111,9 +113,9 @@ def test_module_skip_prefix():
|
||||
|
||||
# Ensure the stats are updated
|
||||
assert torch.all(
|
||||
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
|
||||
assert torch.all(
|
||||
new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
|
||||
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean
|
||||
)
|
||||
assert torch.all(new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
|
||||
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
|
||||
|
||||
|
||||
@@ -137,9 +139,11 @@ def test_module_skip_substr():
|
||||
new_mod = ModuleWithNestedBatchNorm()
|
||||
|
||||
assert not torch.all(
|
||||
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
|
||||
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean
|
||||
)
|
||||
assert not torch.all(
|
||||
new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
|
||||
new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var
|
||||
)
|
||||
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0
|
||||
|
||||
loader = AutoWeightsLoader(new_mod, skip_substrs=["substr."])
|
||||
@@ -147,7 +151,7 @@ def test_module_skip_substr():
|
||||
|
||||
# Ensure the stats are updated
|
||||
assert torch.all(
|
||||
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
|
||||
assert torch.all(
|
||||
new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
|
||||
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean
|
||||
)
|
||||
assert torch.all(new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
|
||||
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
|
||||
|
||||
Reference in New Issue
Block a user