Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -7,16 +7,24 @@ import torch
|
||||
|
||||
from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.activation import (GeluAndMul,
|
||||
ReLUSquaredActivation,
|
||||
SiluAndMul)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (dispatch_topk_func,
|
||||
vllm_topk_softmax)
|
||||
from vllm.model_executor.layers.activation import (
|
||||
GeluAndMul,
|
||||
ReLUSquaredActivation,
|
||||
SiluAndMul,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
dispatch_topk_func,
|
||||
vllm_topk_softmax,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
is_rocm_aiter_moe_enabled)
|
||||
from vllm.model_executor.layers.layernorm import (RMSNorm,
|
||||
dispatch_rocm_rmsnorm_func,
|
||||
fused_add_rms_norm, rms_norm)
|
||||
is_rocm_aiter_moe_enabled,
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import (
|
||||
RMSNorm,
|
||||
dispatch_rocm_rmsnorm_func,
|
||||
fused_add_rms_norm,
|
||||
rms_norm,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
RMS_NORM_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16]
|
||||
@@ -65,14 +73,21 @@ class Relu3(ReLUSquaredActivation):
|
||||
("-silu_and_mul,+relu3", 3, True, [0, 0, 0, 1], False),
|
||||
# All but RMSNorm
|
||||
("all,-rms_norm", 4, True, [0, 1, 1, 1], True),
|
||||
])
|
||||
def test_enabled_ops(env: Optional[str], torch_level: int, use_inductor: bool,
|
||||
ops_enabled: list[int], default_on: bool):
|
||||
custom_ops = env.split(',') if env else []
|
||||
],
|
||||
)
|
||||
def test_enabled_ops(
|
||||
env: Optional[str],
|
||||
torch_level: int,
|
||||
use_inductor: bool,
|
||||
ops_enabled: list[int],
|
||||
default_on: bool,
|
||||
):
|
||||
custom_ops = env.split(",") if env else []
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(use_inductor=bool(use_inductor),
|
||||
level=torch_level,
|
||||
custom_ops=custom_ops))
|
||||
compilation_config=CompilationConfig(
|
||||
use_inductor=bool(use_inductor), level=torch_level, custom_ops=custom_ops
|
||||
)
|
||||
)
|
||||
with set_current_vllm_config(vllm_config):
|
||||
assert CustomOp.default_on() == default_on
|
||||
|
||||
@@ -100,11 +115,13 @@ def test_enabled_ops(env: Optional[str], torch_level: int, use_inductor: bool,
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env", ["all,none", "all,+rms_norm,all", "+rms_norm,-rms_norm"])
|
||||
"env", ["all,none", "all,+rms_norm,all", "+rms_norm,-rms_norm"]
|
||||
)
|
||||
def test_enabled_ops_invalid(env: str):
|
||||
with pytest.raises(Exception): # noqa
|
||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||
custom_ops=env.split(",")))
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(custom_ops=env.split(","))
|
||||
)
|
||||
with set_current_vllm_config(vllm_config):
|
||||
RMSNorm(1024).enabled()
|
||||
|
||||
@@ -116,28 +133,38 @@ def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
|
||||
is_rocm_aiter_moe_enabled.cache_clear()
|
||||
if current_platform.is_rocm() and int(use_rocm_aiter):
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
rocm_aiter_topk_softmax)
|
||||
rocm_aiter_topk_softmax,
|
||||
)
|
||||
|
||||
assert topk_func == rocm_aiter_topk_softmax
|
||||
else:
|
||||
assert topk_func == vllm_topk_softmax
|
||||
|
||||
|
||||
@pytest.mark.parametrize("add_residual", [True, False])
|
||||
@pytest.mark.parametrize("dtype",
|
||||
[torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
|
||||
@pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])
|
||||
@pytest.mark.skipif(not current_platform.is_rocm(),
|
||||
reason="AITER is a feature exclusive for ROCm")
|
||||
def test_rms_norm_dispatch(add_residual: bool, dtype: torch.dtype,
|
||||
use_rocm_aiter: str, use_rocm_aiter_norm: str,
|
||||
monkeypatch):
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_rocm(), reason="AITER is a feature exclusive for ROCm"
|
||||
)
|
||||
def test_rms_norm_dispatch(
|
||||
add_residual: bool,
|
||||
dtype: torch.dtype,
|
||||
use_rocm_aiter: str,
|
||||
use_rocm_aiter_norm: str,
|
||||
monkeypatch,
|
||||
):
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER_RMSNORM", use_rocm_aiter_norm)
|
||||
rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype)
|
||||
|
||||
should_use_rocm_aiter = current_platform.is_rocm() and int(use_rocm_aiter) \
|
||||
and int(use_rocm_aiter_norm) and dtype in RMS_NORM_SUPPORTED_DTYPES
|
||||
should_use_rocm_aiter = (
|
||||
current_platform.is_rocm()
|
||||
and int(use_rocm_aiter)
|
||||
and int(use_rocm_aiter_norm)
|
||||
and dtype in RMS_NORM_SUPPORTED_DTYPES
|
||||
)
|
||||
|
||||
if add_residual and should_use_rocm_aiter:
|
||||
assert rms_norm_func == torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
|
||||
|
||||
Reference in New Issue
Block a user