[ROCm][Bugfix] Fix Aiter RMSNorm (#23412)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
@@ -13,13 +13,15 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (dispatch_topk_func,
|
||||
vllm_topk_softmax)
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
is_rocm_aiter_moe_enabled)
|
||||
from vllm.model_executor.layers.layernorm import (
|
||||
RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm,
|
||||
rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm)
|
||||
from vllm.model_executor.layers.layernorm import (RMSNorm,
|
||||
dispatch_rocm_rmsnorm_func,
|
||||
fused_add_rms_norm, rms_norm)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
cutlass_scaled_mm, dispatch_w8a8_blockscale_func, w8a8_block_fp8_matmul)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
RMS_NORM_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16]
|
||||
|
||||
|
||||
# Registered subclass for test
|
||||
@CustomOp.register("relu3")
|
||||
@@ -149,24 +151,27 @@ def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("add_residual", [True, False])
|
||||
@pytest.mark.parametrize("dtype",
|
||||
[torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
|
||||
@pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])
|
||||
@pytest.mark.skipif(not current_platform.is_rocm(),
|
||||
reason="AITER is a feature exclusive for ROCm")
|
||||
def test_rms_norm_dispatch(add_residual: bool, use_rocm_aiter: str,
|
||||
use_rocm_aiter_norm: str, monkeypatch):
|
||||
def test_rms_norm_dispatch(add_residual: bool, dtype: torch.dtype,
|
||||
use_rocm_aiter: str, use_rocm_aiter_norm: str,
|
||||
monkeypatch):
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER_RMSNORM", use_rocm_aiter_norm)
|
||||
rms_norm_func = dispatch_cuda_rmsnorm_func(add_residual)
|
||||
rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype)
|
||||
|
||||
if not add_residual:
|
||||
if current_platform.is_rocm() and int(use_rocm_aiter) and int(
|
||||
use_rocm_aiter_norm):
|
||||
assert rms_norm_func == rocm_aiter_rms_norm
|
||||
else:
|
||||
assert rms_norm_func == rms_norm
|
||||
elif current_platform.is_rocm() and int(use_rocm_aiter) and int(
|
||||
use_rocm_aiter_norm):
|
||||
assert rms_norm_func == rocm_aiter_fused_add_rms_norm
|
||||
else:
|
||||
should_use_rocm_aiter = current_platform.is_rocm() and int(use_rocm_aiter) \
|
||||
and int(use_rocm_aiter_norm) and dtype in RMS_NORM_SUPPORTED_DTYPES
|
||||
|
||||
if add_residual and should_use_rocm_aiter:
|
||||
assert rms_norm_func == torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
|
||||
elif should_use_rocm_aiter:
|
||||
assert rms_norm_func == torch.ops.vllm.rocm_aiter_rms_norm
|
||||
elif add_residual:
|
||||
assert rms_norm_func == fused_add_rms_norm
|
||||
else:
|
||||
assert rms_norm_func == rms_norm
|
||||
|
||||
Reference in New Issue
Block a user