[FEAT] [ROCm]: Add AITER RMS Norm (Layer Norm) Feature (#14959)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
@@ -7,7 +7,10 @@ from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.activation import (GeluAndMul,
|
||||
ReLUSquaredActivation,
|
||||
SiluAndMul)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.layernorm import (
|
||||
RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm,
|
||||
rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
# Registered subclass for test
|
||||
@@ -87,3 +90,27 @@ def test_enabled_ops_invalid(env: str):
|
||||
custom_ops=env.split(",")))
|
||||
with set_current_vllm_config(vllm_config):
|
||||
RMSNorm(1024).enabled()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("add_residual", [True, False])
|
||||
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
|
||||
@pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])
|
||||
@pytest.mark.skipif(not current_platform.is_rocm(),
|
||||
reason="AITER is a feature exclusive for ROCm")
|
||||
def test_rms_norm_dispatch(add_residual: bool, use_rocm_aiter: str,
|
||||
use_rocm_aiter_norm: str, monkeypatch):
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER_RMSNORM", use_rocm_aiter_norm)
|
||||
rms_norm_func = dispatch_cuda_rmsnorm_func(add_residual)
|
||||
|
||||
if not add_residual:
|
||||
if current_platform.is_rocm() and int(use_rocm_aiter) and int(
|
||||
use_rocm_aiter_norm):
|
||||
assert rms_norm_func == rocm_aiter_rms_norm
|
||||
else:
|
||||
assert rms_norm_func == rms_norm
|
||||
elif current_platform.is_rocm() and int(use_rocm_aiter) and int(
|
||||
use_rocm_aiter_norm):
|
||||
assert rms_norm_func == rocm_aiter_fused_add_rms_norm
|
||||
else:
|
||||
assert rms_norm_func == fused_add_rms_norm
|
||||
|
||||
Reference in New Issue
Block a user