Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-05 15:06:22 +01:00
Committed by: GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions
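For context, ruff subsumes both of the replaced tools: its formatter takes over from yapf, and its "I" lint rules take over from isort. Typical invocations (illustrative only, not taken from this repository's CI configuration):

ruff format .
ruff check --select I --fix .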


@@ -24,13 +24,14 @@ aiter_available = importlib.util.find_spec("aiter") is not None
 pytestmark = pytest.mark.skipif(
     not (current_platform.is_rocm() and aiter_available),
-    reason="AITER ops are only available on ROCm with aiter package installed")
+    reason="AITER ops are only available on ROCm with aiter package installed",
+)


 def test_rocm_aiter_biased_grouped_topk_custom_op_registration():
     """Test that the custom op is correctly registered."""
     # Check if the op exists in torch.ops.vllm
-    assert hasattr(torch.ops.vllm, 'rocm_aiter_biased_grouped_topk')
+    assert hasattr(torch.ops.vllm, "rocm_aiter_biased_grouped_topk")
     # Check if the op is callable
     assert callable(torch.ops.vllm.rocm_aiter_biased_grouped_topk)
@@ -39,7 +40,7 @@ def test_rocm_aiter_biased_grouped_topk_custom_op_registration():
 def test_rocm_aiter_grouped_topk_custom_op_registration():
     """Test that the custom op is correctly registered."""
     # Check if the op exists in torch.ops.vllm
-    assert hasattr(torch.ops.vllm, 'rocm_aiter_grouped_topk')
+    assert hasattr(torch.ops.vllm, "rocm_aiter_grouped_topk")
     # Check if the op is callable
     assert callable(torch.ops.vllm.rocm_aiter_grouped_topk)
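The two registration tests above rely on ops registered through torch.library becoming attributes of the torch.ops.<namespace> object. A minimal self-contained sketch of that mechanism, assuming the PyTorch >= 2.4 torch.library.custom_op API; the demo namespace and toy_add op are hypothetical, not part of vLLM:

import torch


@torch.library.custom_op("demo::toy_add", mutates_args=())
def toy_add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Real (eager) implementation of the hypothetical op.
    return a + b


@toy_add.register_fake
def _(a, b):
    # Fake implementation: shapes and dtypes only, no computation.
    return torch.empty_like(a)


# The same two checks the tests above run against torch.ops.vllm.
assert hasattr(torch.ops.demo, "toy_add")
assert callable(torch.ops.demo.toy_add)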
@@ -56,25 +57,29 @@ def test_rocm_aiter_biased_grouped_topk_torch_compile_compatibility():
     renormalize = True
     scale_factor = 1.0
-    gating_output = torch.randn((token, expert),
-                                dtype=torch.bfloat16,
-                                device="cuda")
-    e_score_correction_bias = torch.randn((expert, ),
-                                          dtype=torch.bfloat16,
-                                          device="cuda")
+    gating_output = torch.randn((token, expert), dtype=torch.bfloat16, device="cuda")
+    e_score_correction_bias = torch.randn(
+        (expert,), dtype=torch.bfloat16, device="cuda"
+    )

     device = gating_output.device
     topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
-    topk_weights = torch.empty((token, topk),
-                               dtype=torch.float32,
-                               device=device)
+    topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)

     # Define a function that uses the op
-    def biased_grouped_topk_fn(gating_output, e_score_correction_bias,
-                               topk_weights, topk_ids):
+    def biased_grouped_topk_fn(
+        gating_output, e_score_correction_bias, topk_weights, topk_ids
+    ):
         return torch.ops.vllm.rocm_aiter_biased_grouped_topk(
-            gating_output, e_score_correction_bias, topk_weights, topk_ids,
-            num_expert_group, topk_group, renormalize, scale_factor)
+            gating_output,
+            e_score_correction_bias,
+            topk_weights,
+            topk_ids,
+            num_expert_group,
+            topk_group,
+            renormalize,
+            scale_factor,
+        )

     # Verify the op's fake implementation
     torch.library.opcheck(
@@ -84,51 +89,49 @@ def test_rocm_aiter_biased_grouped_topk_torch_compile_compatibility():
             "num_expert_group": num_expert_group,
             "topk_group": topk_group,
             "need_renorm": renormalize,
-            "routed_scaling_factor": scale_factor
+            "routed_scaling_factor": scale_factor,
         },
-        test_utils=("test_faketensor"))
+        test_utils=("test_faketensor"),
+    )

     # Compile the function with appropriate settings
-    compiled_fn = torch.compile(biased_grouped_topk_fn,
-                                fullgraph=True,
-                                backend="inductor",
-                                mode="reduce-overhead",
-                                dynamic=False)
+    compiled_fn = torch.compile(
+        biased_grouped_topk_fn,
+        fullgraph=True,
+        backend="inductor",
+        mode="reduce-overhead",
+        dynamic=False,
+    )

-    topk_weights_original = torch.empty((token, topk),
-                                        dtype=torch.float32,
-                                        device=device)
-    topk_ids_original = torch.empty((token, topk),
-                                    dtype=torch.int32,
-                                    device=device)
+    topk_weights_original = torch.empty(
+        (token, topk), dtype=torch.float32, device=device
+    )
+    topk_ids_original = torch.empty((token, topk), dtype=torch.int32, device=device)

-    topk_weights_compiled = torch.empty((token, topk),
-                                        dtype=torch.float32,
-                                        device=device)
-    topk_ids_compiled = torch.empty((token, topk),
-                                    dtype=torch.int32,
-                                    device=device)
+    topk_weights_compiled = torch.empty(
+        (token, topk), dtype=torch.float32, device=device
+    )
+    topk_ids_compiled = torch.empty((token, topk), dtype=torch.int32, device=device)

     # Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode)
-    biased_grouped_topk_fn(gating_output, e_score_correction_bias,
-                           topk_weights_original, topk_ids_original)
-    compiled_fn(gating_output, e_score_correction_bias, topk_weights_compiled,
-                topk_ids_compiled)
+    biased_grouped_topk_fn(
+        gating_output, e_score_correction_bias, topk_weights_original, topk_ids_original
+    )
+    compiled_fn(
+        gating_output, e_score_correction_bias, topk_weights_compiled, topk_ids_compiled
+    )

     # Sort the results for comparison since the order might not be deterministic
     topk_ids_original, indices_original = torch.sort(topk_ids_original)
-    topk_weights_original = torch.gather(topk_weights_original, 1,
-                                         indices_original)
+    topk_weights_original = torch.gather(topk_weights_original, 1, indices_original)
     topk_ids_compiled, indices_compiled = torch.sort(topk_ids_compiled)
-    topk_weights_compiled = torch.gather(topk_weights_compiled, 1,
-                                         indices_compiled)
+    topk_weights_compiled = torch.gather(topk_weights_compiled, 1, indices_compiled)

     # Verify results match
-    assert torch.allclose(topk_weights_original,
-                          topk_weights_compiled,
-                          rtol=1e-2,
-                          atol=1e-2)
+    assert torch.allclose(
+        topk_weights_original, topk_weights_compiled, rtol=1e-2, atol=1e-2
+    )
     assert torch.allclose(topk_ids_original, topk_ids_compiled)
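The sort-then-gather step above canonicalizes both result sets before comparison, because top-k selection may emit its (id, weight) pairs in any order. A standalone sketch of the same pattern on toy data (all values hypothetical):

import torch

# Two result sets containing the same pairs in different column order.
ids_a = torch.tensor([[3, 1], [0, 2]])
weights_a = torch.tensor([[0.7, 0.3], [0.6, 0.4]])
ids_b = torch.tensor([[1, 3], [2, 0]])
weights_b = torch.tensor([[0.3, 0.7], [0.4, 0.6]])

# Sort ids row-wise, then permute the weights with the same indices.
ids_a, idx_a = torch.sort(ids_a)
weights_a = torch.gather(weights_a, 1, idx_a)
ids_b, idx_b = torch.sort(ids_b)
weights_b = torch.gather(weights_b, 1, idx_b)

assert torch.equal(ids_a, ids_b)
assert torch.allclose(weights_a, weights_b)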
@@ -144,73 +147,73 @@ def test_rocm_aiter_grouped_topk_torch_compile_compatibility():
     scoring_func = "softmax"
     scale_factor = 1.0
-    gating_output = torch.randn((token, expert),
-                                dtype=torch.bfloat16,
-                                device="cuda")
+    gating_output = torch.randn((token, expert), dtype=torch.bfloat16, device="cuda")

     device = gating_output.device
     topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
-    topk_weights = torch.empty((token, topk),
-                               dtype=torch.float32,
-                               device=device)
+    topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)

     # Define a function that uses the op
     def grouped_topk_fn(gating_output, topk_weights, topk_ids, scoring_func):
         return torch.ops.vllm.rocm_aiter_grouped_topk(
-            gating_output, topk_weights, topk_ids, num_expert_group,
-            topk_group, renormalize, scoring_func, scale_factor)
+            gating_output,
+            topk_weights,
+            topk_ids,
+            num_expert_group,
+            topk_group,
+            renormalize,
+            scoring_func,
+            scale_factor,
+        )

     # Verify the op's fake implementation
-    torch.library.opcheck(torch.ops.vllm.rocm_aiter_grouped_topk,
-                          (gating_output, topk_weights, topk_ids),
-                          kwargs={
-                              "num_expert_group": num_expert_group,
-                              "topk_group": topk_group,
-                              "need_renorm": renormalize,
-                              "scoring_func": scoring_func,
-                              "routed_scaling_factor": scale_factor
-                          },
-                          test_utils=("test_faketensor"))
+    torch.library.opcheck(
+        torch.ops.vllm.rocm_aiter_grouped_topk,
+        (gating_output, topk_weights, topk_ids),
+        kwargs={
+            "num_expert_group": num_expert_group,
+            "topk_group": topk_group,
+            "need_renorm": renormalize,
+            "scoring_func": scoring_func,
+            "routed_scaling_factor": scale_factor,
+        },
+        test_utils=("test_faketensor"),
+    )

     # Compile the function with appropriate settings
-    compiled_fn = torch.compile(grouped_topk_fn,
-                                fullgraph=True,
-                                backend="inductor",
-                                mode="reduce-overhead",
-                                dynamic=False)
+    compiled_fn = torch.compile(
+        grouped_topk_fn,
+        fullgraph=True,
+        backend="inductor",
+        mode="reduce-overhead",
+        dynamic=False,
+    )

-    topk_weights_original = torch.empty((token, topk),
-                                        dtype=torch.float32,
-                                        device=device)
-    topk_ids_original = torch.empty((token, topk),
-                                    dtype=torch.int32,
-                                    device=device)
+    topk_weights_original = torch.empty(
+        (token, topk), dtype=torch.float32, device=device
+    )
+    topk_ids_original = torch.empty((token, topk), dtype=torch.int32, device=device)

-    topk_weights_compiled = torch.empty((token, topk),
-                                        dtype=torch.float32,
-                                        device=device)
-    topk_ids_compiled = torch.empty((token, topk),
-                                    dtype=torch.int32,
-                                    device=device)
+    topk_weights_compiled = torch.empty(
+        (token, topk), dtype=torch.float32, device=device
+    )
+    topk_ids_compiled = torch.empty((token, topk), dtype=torch.int32, device=device)

     # Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode)
-    grouped_topk_fn(gating_output, topk_weights_original, topk_ids_original,
-                    scoring_func)
-    compiled_fn(gating_output, topk_weights_compiled, topk_ids_compiled,
-                scoring_func)
+    grouped_topk_fn(
+        gating_output, topk_weights_original, topk_ids_original, scoring_func
+    )
+    compiled_fn(gating_output, topk_weights_compiled, topk_ids_compiled, scoring_func)

     # Sort the results for comparison since the order might not be deterministic
     topk_ids_original, indices_original = torch.sort(topk_ids_original)
-    topk_weights_original = torch.gather(topk_weights_original, 1,
-                                         indices_original)
+    topk_weights_original = torch.gather(topk_weights_original, 1, indices_original)
     topk_ids_compiled, indices_compiled = torch.sort(topk_ids_compiled)
-    topk_weights_compiled = torch.gather(topk_weights_compiled, 1,
-                                         indices_compiled)
+    topk_weights_compiled = torch.gather(topk_weights_compiled, 1, indices_compiled)

     # Verify results match
-    assert torch.allclose(topk_weights_original,
-                          topk_weights_compiled,
-                          rtol=1e-2,
-                          atol=1e-2)
+    assert torch.allclose(
+        topk_weights_original, topk_weights_compiled, rtol=1e-2, atol=1e-2
+    )
     assert torch.allclose(topk_ids_original, topk_ids_compiled)
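For reference, torch.library.opcheck as used in both tests validates a registered op against its fake implementation. A minimal standalone sketch, reusing the hypothetical demo namespace from the earlier example (not vLLM code):

import torch


@torch.library.custom_op("demo::toy_scale", mutates_args=())
def toy_scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor


@toy_scale.register_fake
def _(x, factor):
    return torch.empty_like(x)


# Run only the fake-tensor consistency check, mirroring
# test_utils=("test_faketensor") in the tests above.
torch.library.opcheck(
    torch.ops.demo.toy_scale,
    (torch.randn(4),),
    kwargs={"factor": 2.0},
    test_utils=("test_faketensor",),
)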