Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -10,7 +10,6 @@ pytestmark = pytest.mark.cpu_test
class ModuleWithBatchNorm(torch.nn.Module):
def __init__(self):
super().__init__()
self.bn = torch.nn.BatchNorm1d(2)
@@ -20,7 +19,6 @@ class ModuleWithBatchNorm(torch.nn.Module):
class ModuleWithNestedBatchNorm(torch.nn.Module):
def __init__(self):
super().__init__()
self.nested_mod = ModuleWithBatchNorm()
@@ -67,9 +65,11 @@ def test_module_with_child_containing_batchnorm_can_autoload():
new_mod = ModuleWithNestedBatchNorm()
assert not torch.all(
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean
)
assert not torch.all(
new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var
)
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0
loader = AutoWeightsLoader(new_mod)
@@ -77,9 +77,9 @@ def test_module_with_child_containing_batchnorm_can_autoload():
# Ensure the stats are updated
assert torch.all(
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
assert torch.all(
new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean
)
assert torch.all(new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
@@ -101,9 +101,11 @@ def test_module_skip_prefix():
new_mod = ModuleWithNestedBatchNorm()
assert not torch.all(
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean
)
assert not torch.all(
new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var
)
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0
loader = AutoWeightsLoader(new_mod, skip_prefixes=["prefix."])
@@ -111,9 +113,9 @@ def test_module_skip_prefix():
# Ensure the stats are updated
assert torch.all(
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
assert torch.all(
new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean
)
assert torch.all(new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
@@ -137,9 +139,11 @@ def test_module_skip_substr():
new_mod = ModuleWithNestedBatchNorm()
assert not torch.all(
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean
)
assert not torch.all(
new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var
)
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0
loader = AutoWeightsLoader(new_mod, skip_substrs=["substr."])
@@ -147,7 +151,7 @@ def test_module_skip_substr():
# Ensure the stats are updated
assert torch.all(
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean)
assert torch.all(
new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean
)
assert torch.all(new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1