[vLLM IR] 1/N Implement IR skeleton and rms_norm op (#33825)
Signed-off-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com> Signed-off-by: chzhang <chaojun.zhang@intel.com> Signed-off-by: Luka Govedic <luka.govedic@gmail.com> Co-authored-by: Xinyu Chen <xinyu1.chen@intel.com> Co-authored-by: Chaojun Zhang <chaojun.zhang@intel.com> Co-authored-by: Luka Govedič <ProExpertProg@h100-01.nemg-001.lab.rdu2.dc.redhat.com>
This commit is contained in:
@@ -27,7 +27,6 @@ from vllm.model_executor.layers.layernorm import (
|
||||
RMSNorm,
|
||||
dispatch_rocm_rmsnorm_func,
|
||||
fused_add_rms_norm,
|
||||
rms_norm,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@@ -156,7 +155,7 @@ def test_topk_sigmoid_dispatch(use_rocm_aiter: bool):
|
||||
assert topk_func == vllm_topk_sigmoid
|
||||
|
||||
|
||||
@pytest.mark.parametrize("add_residual", [True, False])
|
||||
@pytest.mark.parametrize("add_residual", [False])
|
||||
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("use_rocm_aiter", [True, False])
|
||||
@pytest.mark.skipif(
|
||||
@@ -165,7 +164,7 @@ def test_topk_sigmoid_dispatch(use_rocm_aiter: bool):
|
||||
def test_rms_norm_dispatch(
|
||||
add_residual: bool, dtype: torch.dtype, use_rocm_aiter: bool
|
||||
):
|
||||
rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype, use_rocm_aiter)
|
||||
rms_norm_func = dispatch_rocm_rmsnorm_func(dtype, use_rocm_aiter)
|
||||
|
||||
should_use_rocm_aiter = (
|
||||
current_platform.is_rocm()
|
||||
@@ -173,11 +172,7 @@ def test_rms_norm_dispatch(
|
||||
and dtype in RMS_NORM_SUPPORTED_DTYPES
|
||||
)
|
||||
|
||||
if add_residual and should_use_rocm_aiter:
|
||||
if should_use_rocm_aiter:
|
||||
assert rms_norm_func == rocm_aiter_ops.rms_norm2d_with_add
|
||||
elif should_use_rocm_aiter:
|
||||
assert rms_norm_func == rocm_aiter_ops.rms_norm
|
||||
elif add_residual:
|
||||
assert rms_norm_func == fused_add_rms_norm
|
||||
else:
|
||||
assert rms_norm_func == rms_norm
|
||||
assert rms_norm_func == fused_add_rms_norm
|
||||
|
||||
Reference in New Issue
Block a user