Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-05 15:06:22 +01:00
Committed by: GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions
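The excerpt below shows one of the 1508 reformatted files. The pyproject.toml side of the migration is not part of this excerpt, so what follows is only a sketch of the pattern the diff illustrates: yapf aligned continuation lines under the opening parenthesis, while ruff's formatter (like Black) uses a hanging indent, and a trailing comma (the "magic trailing comma") forces one element per line. Stub names here are invented for illustration:

    # Illustrative only -- not code from this commit; stub names are invented.
    def do_something(a: int, b: int, c: int) -> int:
        return a + b + c

    # yapf (before): continuation lines aligned under the open parenthesis.
    # result = do_something(first_value,
    #                       second_value,
    #                       third_value)

    # ruff format (after): hanging indent; the trailing comma keeps one
    # argument per line even when the call would fit on a single line.
    first_value, second_value, third_value = 1, 2, 3
    result = do_something(
        first_value,
        second_value,
        third_value,
    )

In day-to-day use the replacement is `ruff format` in place of yapf and `ruff check --fix` (with ruff's isort-compatible `I` rules selected) in place of isort; the exact rule selection and line length configured by this commit are not shown in this excerpt.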

@@ -7,16 +7,24 @@ import torch
 from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
 from vllm.model_executor.custom_op import CustomOp
-from vllm.model_executor.layers.activation import (GeluAndMul,
-                                                   ReLUSquaredActivation,
-                                                   SiluAndMul)
-from vllm.model_executor.layers.fused_moe.fused_moe import (dispatch_topk_func,
-                                                            vllm_topk_softmax)
+from vllm.model_executor.layers.activation import (
+    GeluAndMul,
+    ReLUSquaredActivation,
+    SiluAndMul,
+)
+from vllm.model_executor.layers.fused_moe.fused_moe import (
+    dispatch_topk_func,
+    vllm_topk_softmax,
+)
 from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-    is_rocm_aiter_moe_enabled)
-from vllm.model_executor.layers.layernorm import (RMSNorm,
-                                                  dispatch_rocm_rmsnorm_func,
-                                                  fused_add_rms_norm, rms_norm)
+    is_rocm_aiter_moe_enabled,
+)
+from vllm.model_executor.layers.layernorm import (
+    RMSNorm,
+    dispatch_rocm_rmsnorm_func,
+    fused_add_rms_norm,
+    rms_norm,
+)
 from vllm.platforms import current_platform

 RMS_NORM_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16]
@@ -65,14 +73,21 @@ class Relu3(ReLUSquaredActivation):
("-silu_and_mul,+relu3", 3, True, [0, 0, 0, 1], False),
# All but RMSNorm
("all,-rms_norm", 4, True, [0, 1, 1, 1], True),
])
def test_enabled_ops(env: Optional[str], torch_level: int, use_inductor: bool,
ops_enabled: list[int], default_on: bool):
custom_ops = env.split(',') if env else []
],
)
def test_enabled_ops(
env: Optional[str],
torch_level: int,
use_inductor: bool,
ops_enabled: list[int],
default_on: bool,
):
custom_ops = env.split(",") if env else []
vllm_config = VllmConfig(
compilation_config=CompilationConfig(use_inductor=bool(use_inductor),
level=torch_level,
custom_ops=custom_ops))
compilation_config=CompilationConfig(
use_inductor=bool(use_inductor), level=torch_level, custom_ops=custom_ops
)
)
with set_current_vllm_config(vllm_config):
assert CustomOp.default_on() == default_on
@@ -100,11 +115,13 @@ def test_enabled_ops(env: Optional[str], torch_level: int, use_inductor: bool,
 @pytest.mark.parametrize(
-    "env", ["all,none", "all,+rms_norm,all", "+rms_norm,-rms_norm"])
+    "env", ["all,none", "all,+rms_norm,all", "+rms_norm,-rms_norm"]
+)
 def test_enabled_ops_invalid(env: str):
     with pytest.raises(Exception):  # noqa
-        vllm_config = VllmConfig(compilation_config=CompilationConfig(
-            custom_ops=env.split(",")))
+        vllm_config = VllmConfig(
+            compilation_config=CompilationConfig(custom_ops=env.split(","))
+        )
         with set_current_vllm_config(vllm_config):
             RMSNorm(1024).enabled()
@@ -116,28 +133,38 @@ def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
     is_rocm_aiter_moe_enabled.cache_clear()
     if current_platform.is_rocm() and int(use_rocm_aiter):
         from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-            rocm_aiter_topk_softmax)
+            rocm_aiter_topk_softmax,
+        )
+
         assert topk_func == rocm_aiter_topk_softmax
     else:
         assert topk_func == vllm_topk_softmax


 @pytest.mark.parametrize("add_residual", [True, False])
-@pytest.mark.parametrize("dtype",
-                         [torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
 @pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])
-@pytest.mark.skipif(not current_platform.is_rocm(),
-                    reason="AITER is a feature exclusive for ROCm")
-def test_rms_norm_dispatch(add_residual: bool, dtype: torch.dtype,
-                           use_rocm_aiter: str, use_rocm_aiter_norm: str,
-                           monkeypatch):
+@pytest.mark.skipif(
+    not current_platform.is_rocm(), reason="AITER is a feature exclusive for ROCm"
+)
+def test_rms_norm_dispatch(
+    add_residual: bool,
+    dtype: torch.dtype,
+    use_rocm_aiter: str,
+    use_rocm_aiter_norm: str,
+    monkeypatch,
+):
     monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
     monkeypatch.setenv("VLLM_ROCM_USE_AITER_RMSNORM", use_rocm_aiter_norm)

     rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype)
-    should_use_rocm_aiter = current_platform.is_rocm() and int(use_rocm_aiter) \
-        and int(use_rocm_aiter_norm) and dtype in RMS_NORM_SUPPORTED_DTYPES
+    should_use_rocm_aiter = (
+        current_platform.is_rocm()
+        and int(use_rocm_aiter)
+        and int(use_rocm_aiter_norm)
+        and dtype in RMS_NORM_SUPPORTED_DTYPES
+    )

     if add_residual and should_use_rocm_aiter:
         assert rms_norm_func == torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
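
One more pattern from the last hunk, stated generally: ruff's formatter never emits backslash line continuations, so a long boolean expression is wrapped in parentheses and broken with one operand per line, as the `should_use_rocm_aiter` change above shows. A minimal sketch with invented variables (ruff would collapse this back to one line if it fit within the configured line length):

    # Illustrative only -- variable names are invented for the example.
    on_rocm, aiter_enabled, norm_enabled = True, True, False

    # Before: yapf tolerated a backslash continuation.
    # ok = on_rocm and aiter_enabled \
    #     and norm_enabled

    # After: ruff format uses parentheses, one operand per line when wrapping.
    ok = (
        on_rocm
        and aiter_enabled
        and norm_enabled
    )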