[ROCm] [CI] Add new fusion test cases that are relevant to vLLM IR Ops (#34307)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
TJian
2026-03-03 22:24:21 +08:00
committed by GitHub
parent ea463978bb
commit fb7fdc49c4
10 changed files with 217 additions and 61 deletions

View File

@@ -5,6 +5,7 @@ from collections.abc import Callable
import pytest
from vllm.config import PassConfig
from vllm.platforms import current_platform
from vllm.utils.flashinfer import is_flashinfer_fp8_blockscale_gemm_supported
from .common import (
@@ -16,6 +17,8 @@ from .common import (
)
from .models import (
FLASHINFER_ATTN,
ROCM_AITER_UNIFIED_ATTN,
ROCM_ATTN,
TRITON_ATTN,
llama3_8b_fp4,
llama3_8b_fp8,
@@ -29,12 +32,33 @@ from .models import (
"model_name, matches_fn, model_kwargs, hf_overrides, use_deepgemm",
[
(*llama3_8b_fp8, False),
(*llama4_scout_fp8, False),
(*qwen3_a3b_fp8, False),
(*qwen3_a3b_fp8, True),
pytest.param(
*llama4_scout_fp8,
False,
marks=pytest.mark.skipif(
not current_platform.is_cuda(),
reason="Llama4 Scout FP8 only supported on CUDA",
),
),
pytest.param(
*qwen3_a3b_fp8,
True,
marks=pytest.mark.skipif(
not current_platform.is_cuda(), reason="DeepGemm only supported on CUDA"
),
),
],
)
@pytest.mark.parametrize(
"attn_backend",
[
TRITON_ATTN,
FLASHINFER_ATTN,
ROCM_ATTN,
ROCM_AITER_UNIFIED_ATTN,
],
)
@pytest.mark.parametrize("attn_backend", [TRITON_ATTN, FLASHINFER_ATTN])
@pytest.mark.parametrize("n_layers", [6])
@pytest.mark.parametrize("custom_ops", custom_ops_combos("quant_fp8", "rms_norm"))
@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
@@ -81,6 +105,8 @@ def test_tp1_fp8_fusions(
),
)
use_aiter = current_platform.is_rocm() and ("qwen" in model_name.lower())
matches_check = [
"rms_quant_fusion",
"act_quant_fusion",
@@ -88,6 +114,15 @@ def test_tp1_fp8_fusions(
"attn_quant_fusion",
]
if use_aiter:
matches_check[0] = "aiter_rms_quant_fusion"
matches = matches._replace(aiter_rms_quant_fusion=matches.rms_quant_fusion)
# TODO: enable the `norm_rope_fusion` test,
# On ROCm norm_rope_fusion is only supported without
# enabling AITER.
matches_check.remove("norm_rope_fusion")
run_e2e_fusion_test(
model_name,
matches,
@@ -96,6 +131,7 @@ def test_tp1_fp8_fusions(
compilation_config,
matches_check,
use_deepgemm=use_deepgemm,
use_aiter=use_aiter,
)