2025-02-02 14:58:18 -05:00
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
2025-06-03 11:20:17 -07:00
|
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
2025-02-02 14:58:18 -05:00
|
|
|
|
2024-10-17 14:36:37 -04:00
|
|
|
import pytest
|
2025-05-14 10:08:20 +08:00
|
|
|
import torch
|
2024-10-17 14:36:37 -04:00
|
|
|
|
2025-11-10 17:20:53 +01:00
|
|
|
from vllm._aiter_ops import rocm_aiter_ops
|
2025-11-27 04:55:58 -05:00
|
|
|
from vllm.config import (
|
|
|
|
|
CompilationConfig,
|
|
|
|
|
VllmConfig,
|
|
|
|
|
get_cached_compilation_config,
|
|
|
|
|
set_current_vllm_config,
|
|
|
|
|
)
|
2026-01-22 00:38:04 +08:00
|
|
|
from vllm.model_executor.custom_op import CustomOp, op_registry
|
2024-10-17 14:36:37 -04:00
|
|
|
from vllm.model_executor.layers.activation import (
|
|
|
|
|
GeluAndMul,
|
|
|
|
|
ReLUSquaredActivation,
|
|
|
|
|
SiluAndMul,
|
|
|
|
|
)
|
2026-01-18 11:40:49 -05:00
|
|
|
from vllm.model_executor.layers.fused_moe.router.fused_topk_router import (
|
2026-01-21 14:49:51 -08:00
|
|
|
dispatch_topk_sigmoid_func,
|
|
|
|
|
dispatch_topk_softmax_func,
|
|
|
|
|
vllm_topk_sigmoid,
|
2025-05-14 18:03:11 +08:00
|
|
|
vllm_topk_softmax,
|
|
|
|
|
)
|
2025-09-10 21:08:03 +08:00
|
|
|
from vllm.model_executor.layers.layernorm import (
|
|
|
|
|
RMSNorm,
|
|
|
|
|
dispatch_rocm_rmsnorm_func,
|
|
|
|
|
fused_add_rms_norm,
|
|
|
|
|
)
|
2025-03-22 13:36:14 +08:00
|
|
|
from vllm.platforms import current_platform
|
2024-10-17 14:36:37 -04:00
|
|
|
|
2025-09-10 21:08:03 +08:00
|
|
|
# Dtypes for which test_rms_norm_dispatch expects the ROCm AITER RMSNorm
# kernel to be dispatched; other dtypes are expected to fall back to
# fused_add_rms_norm.
RMS_NORM_SUPPORTED_DTYPES = [
    torch.float16,
    torch.bfloat16,
]
|
|
|
|
|
|
2024-10-17 14:36:37 -04:00
|
|
|
|
|
|
|
|
# Registered subclass for test
@CustomOp.register("relu3")
class Relu3(ReLUSquaredActivation):
    """ReLUSquaredActivation registered under its own custom-op name "relu3".

    Used by test_enabled_ops to check that a registered subclass is
    enabled/disabled by its own name rather than by its parent's.
    """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    "env, compilation_mode, backend, ops_enabled, default_on",
    [
        # Default values based on compilation mode and backend
        # - All by default with the eager backend (no Inductor compilation)
        (None, 0, "eager", [True] * 4, True),
        (None, 1, "eager", [True] * 4, True),
        (None, 2, "eager", [True] * 4, True),
        (None, 3, "eager", [True] * 4, True),
        # - All by default at mode 0 even with the Inductor backend
        #   (mode 0 presumably skips compilation entirely)
        (None, 0, "inductor", [True] * 4, True),
        # - None by default with the Inductor backend at mode >= 1
        (None, 1, "inductor", [False] * 4, False),
        (None, 2, "inductor", [False] * 4, False),
        (None, 3, "inductor", [False] * 4, False),
        # Explicitly enabling/disabling via custom_ops:
        # "all"/"none" set the default for ops not named explicitly and
        # override the mode/backend default; "+op"/"-op" enable/disable a
        # single op.
        #
        # All but SiluAndMul
        ("+rms_norm,-silu_and_mul", 0, "inductor", [1, 0, 1, 1], True),
        # Only ReLU3
        ("none,-rms_norm,+relu3", 1, "eager", [0, 0, 0, 1], False),
        # All but SiluAndMul
        ("all,-silu_and_mul", 2, "inductor", [1, 0, 1, 1], True),
        # All but ReLU3 (even if ReLU2 is on)
        ("-relu3,+relu2", 3, "eager", [1, 1, 1, 0], True),
        # RMSNorm and SiluAndMul
        ("none,-relu3,+rms_norm,+silu_and_mul", 3, "eager", [1, 1, 0, 0], False),
        # All but RMSNorm
        ("-rms_norm", 3, "eager", [0, 1, 1, 1], True),
        # Only ReLU3
        ("none,+relu3", 3, "inductor", [0, 0, 0, 1], False),
        # All but RMSNorm
        ("all,-rms_norm", 3, "inductor", [0, 1, 1, 1], True),
    ],
)
def test_enabled_ops(
    env: str | None,
    compilation_mode: int,
    backend: str,
    ops_enabled: list[int],
    default_on: bool,
):
    """Check which custom ops are enabled for a given compilation config.

    Args:
        env: Comma-separated ``custom_ops`` spec, or None for no overrides.
        compilation_mode: Value passed as ``CompilationConfig(mode=...)``.
        backend: Value passed as ``CompilationConfig(backend=...)``.
        ops_enabled: Expected enabled state (truthy/falsy) for RMSNorm,
            SiluAndMul, GeluAndMul and the registered ``Relu3`` subclass,
            in that order.
        default_on: Expected value of ``CustomOp.default_on()``.

    Each op is checked both through a fresh instance and through its
    entry in ``op_registry``.
    """
    custom_ops = env.split(",") if env else []
    vllm_config = VllmConfig(
        compilation_config=CompilationConfig(
            backend=backend, mode=compilation_mode, custom_ops=custom_ops
        )
    )
    # Drop any compilation config cached by an earlier test so the config
    # built above is the one consulted inside the context below.
    get_cached_compilation_config.cache_clear()
    with set_current_vllm_config(vllm_config):
        assert CustomOp.default_on() == default_on

        expected = [bool(x) for x in ops_enabled]
        # (registry key, factory) per op, in the same order as ops_enabled.
        # Ops are built lazily so construction order matches a per-op check.
        cases = [
            ("rms_norm", lambda: RMSNorm(1024)),
            ("silu_and_mul", SiluAndMul),
            ("gelu_and_mul", GeluAndMul),
            # If registered, subclasses should follow their own name
            ("relu3", Relu3),
        ]
        for (name, make_op), want in zip(cases, expected, strict=True):
            assert make_op().enabled() == want, name
            assert op_registry[name].enabled() == want, name

        # Unregistered subclass
        class SiluAndMul2(SiluAndMul):
            pass

        # Subclasses should not require registration: an unregistered
        # subclass reports the same enabled state as its parent.
        assert SiluAndMul2().enabled() == SiluAndMul().enabled()
|
2024-10-17 14:36:37 -04:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    "env", ["all,none", "all,+rms_norm,all", "+rms_norm,-rms_norm"]
)
def test_enabled_ops_invalid(env: str):
    """Contradictory ``custom_ops`` specs must raise.

    Covers combining "all" with "none", repeating "all", and both enabling
    and disabling the same op. It is not fixed here whether the error
    surfaces while building the config or while querying an op's enabled
    state, so both happen inside ``pytest.raises``.
    """
    with pytest.raises(Exception):  # noqa
        vllm_config = VllmConfig(
            compilation_config=CompilationConfig(custom_ops=env.split(","))
        )
        # Match test_enabled_ops: drop any compilation config cached by an
        # earlier test so this (invalid) config is the one consulted.
        get_cached_compilation_config.cache_clear()
        with set_current_vllm_config(vllm_config):
            RMSNorm(1024).enabled()
|
2025-03-22 13:36:14 +08:00
|
|
|
|
|
|
|
|
|
2025-11-10 17:20:53 +01:00
|
|
|
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
)
def test_topk_softmax_dispatch(use_rocm_aiter: bool):
    """Top-k softmax dispatch selects the AITER kernel only on ROCm when asked.

    Otherwise the generic ``vllm_topk_softmax`` implementation is expected.
    """
    dispatched = dispatch_topk_softmax_func(use_rocm_aiter)

    expected = (
        rocm_aiter_ops.topk_softmax
        if current_platform.is_rocm() and use_rocm_aiter
        else vllm_topk_softmax
    )
    assert dispatched == expected
|
|
|
|
|
|
|
|
|
|
|
2026-01-21 14:49:51 -08:00
|
|
|
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
)
def test_topk_sigmoid_dispatch(use_rocm_aiter: bool):
    """Top-k sigmoid dispatch selects the AITER kernel only on ROCm when asked.

    Otherwise the generic ``vllm_topk_sigmoid`` implementation is expected.
    """
    dispatched = dispatch_topk_sigmoid_func(use_rocm_aiter)

    expected = (
        rocm_aiter_ops.topk_sigmoid
        if current_platform.is_rocm() and use_rocm_aiter
        else vllm_topk_sigmoid
    )
    assert dispatched == expected
|
|
|
|
|
|
|
|
|
|
|
2026-03-31 22:15:05 -04:00
|
|
|
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("use_rocm_aiter", [True, False])
@pytest.mark.skipif(
    not current_platform.is_rocm(), reason="AITER is a feature exclusive for ROCm"
)
def test_rms_norm_dispatch(dtype: torch.dtype, use_rocm_aiter: bool):
    """``dispatch_rocm_rmsnorm_func`` returns a residual-add RMSNorm kernel.

    The AITER kernel ``rocm_aiter_ops.rms_norm2d_with_add`` is expected only
    on ROCm, when AITER is requested, and for a dtype listed in
    ``RMS_NORM_SUPPORTED_DTYPES``. In every other case the dispatcher should
    fall back to ``fused_add_rms_norm``.

    The dispatcher takes no residual flag, so the former (unused)
    ``add_residual`` parametrization has been dropped.
    """
    rms_norm_func = dispatch_rocm_rmsnorm_func(dtype, use_rocm_aiter)

    should_use_rocm_aiter = (
        current_platform.is_rocm()
        and use_rocm_aiter
        and dtype in RMS_NORM_SUPPORTED_DTYPES
    )

    if should_use_rocm_aiter:
        assert rms_norm_func == rocm_aiter_ops.rms_norm2d_with_add
    else:
        assert rms_norm_func == fused_add_rms_norm
|