[ROCm][Bugfix] Fix Aiter RMSNorm (#23412)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
vllmellm
2025-09-10 21:08:03 +08:00
committed by GitHub
parent 0ae43dbf8c
commit 7c195d43da
3 changed files with 108 additions and 36 deletions

View File

@@ -13,13 +13,15 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (dispatch_topk_func,
vllm_topk_softmax)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled)
from vllm.model_executor.layers.layernorm import (
RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm,
rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm)
from vllm.model_executor.layers.layernorm import (RMSNorm,
dispatch_rocm_rmsnorm_func,
fused_add_rms_norm, rms_norm)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
cutlass_scaled_mm, dispatch_w8a8_blockscale_func, w8a8_block_fp8_matmul)
from vllm.platforms import current_platform
RMS_NORM_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16]
# Registered subclass for test
@CustomOp.register("relu3")
@@ -149,24 +151,27 @@ def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
@pytest.mark.parametrize("add_residual", [True, False])
@pytest.mark.parametrize("dtype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
@pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="AITER is a feature exclusive for ROCm")
def test_rms_norm_dispatch(add_residual: bool, use_rocm_aiter: str,
use_rocm_aiter_norm: str, monkeypatch):
def test_rms_norm_dispatch(add_residual: bool, dtype: torch.dtype,
use_rocm_aiter: str, use_rocm_aiter_norm: str,
monkeypatch):
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
monkeypatch.setenv("VLLM_ROCM_USE_AITER_RMSNORM", use_rocm_aiter_norm)
rms_norm_func = dispatch_cuda_rmsnorm_func(add_residual)
rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype)
if not add_residual:
if current_platform.is_rocm() and int(use_rocm_aiter) and int(
use_rocm_aiter_norm):
assert rms_norm_func == rocm_aiter_rms_norm
else:
assert rms_norm_func == rms_norm
elif current_platform.is_rocm() and int(use_rocm_aiter) and int(
use_rocm_aiter_norm):
assert rms_norm_func == rocm_aiter_fused_add_rms_norm
else:
should_use_rocm_aiter = current_platform.is_rocm() and int(use_rocm_aiter) \
and int(use_rocm_aiter_norm) and dtype in RMS_NORM_SUPPORTED_DTYPES
if add_residual and should_use_rocm_aiter:
assert rms_norm_func == torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
elif should_use_rocm_aiter:
assert rms_norm_func == torch.ops.vllm.rocm_aiter_rms_norm
elif add_residual:
assert rms_norm_func == fused_add_rms_norm
else:
assert rms_norm_func == rms_norm