[ROCm] [CI] Add new fusion test cases that are relevant to vLLM IR Ops (#34307)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
TJian
2026-03-03 22:24:21 +08:00
committed by GitHub
parent ea463978bb
commit fb7fdc49c4
10 changed files with 217 additions and 61 deletions

View File

@@ -182,8 +182,24 @@ TEST_KERNELS = ROCM_KERNELS if current_platform.is_rocm() else CUDA_KERNELS
"model_class, enable_quant_fp8_custom_op, force_kernel",
list(itertools.product([TestSiluMulFp8QuantModel], [True, False], TEST_KERNELS))
+ [
(TestSiluMulNvfp4QuantModel, False, None),
(TestSiluMulGroupFp8QuantModel, False, None),
pytest.param(
TestSiluMulNvfp4QuantModel,
False,
None,
marks=pytest.mark.skipif(
not current_platform.is_cuda(), reason="CUDA only"
),
),
# GroupFP8Quant fusion only works with AITER on ROCm.
# and the enable_quant_fp8_custom_op must be True.
pytest.param(
TestSiluMulGroupFp8QuantModel,
True,
None,
marks=pytest.mark.skipif(
not current_platform.is_rocm(), reason="ROCm only"
),
),
],
)
@pytest.mark.skipif(
@@ -201,6 +217,7 @@ def test_fusion_silu_and_mul_quant(
enable_silu_mul_custom_op: bool,
enable_quant_fp8_custom_op: bool,
force_kernel: FP8ScaledMMLinearKernel | None,
monkeypatch: pytest.MonkeyPatch,
):
if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported():
pytest.skip("NVFP4 is not supported on this GPU.")
@@ -227,13 +244,16 @@ def test_fusion_silu_and_mul_quant(
),
)
with set_current_vllm_config(config):
with set_current_vllm_config(config), monkeypatch.context() as m:
fusion_passes = [ActivationQuantFusionPass(config)]
if IS_AITER_FOUND:
if IS_AITER_FOUND and model_class is TestSiluMulGroupFp8QuantModel:
from vllm._aiter_ops import rocm_aiter_ops
from vllm.compilation.passes.fusion.rocm_aiter_fusion import (
RocmAiterSiluMulFp8GroupQuantFusionPass,
)
m.setenv("VLLM_ROCM_USE_AITER", "1")
rocm_aiter_ops.refresh_env_variables()
fusion_passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)]
passes = [NoOpEliminationPass(config), *fusion_passes, PostCleanupPass(config)]