[ROCm] [CI] Add new fusion test cases that are relevant to vLLM IR Ops (#34307)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
@@ -5,6 +5,7 @@ from collections.abc import Callable
|
||||
import pytest
|
||||
|
||||
from vllm.config import PassConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import is_flashinfer_fp8_blockscale_gemm_supported
|
||||
|
||||
from .common import (
|
||||
@@ -16,6 +17,8 @@ from .common import (
|
||||
)
|
||||
from .models import (
|
||||
FLASHINFER_ATTN,
|
||||
ROCM_AITER_UNIFIED_ATTN,
|
||||
ROCM_ATTN,
|
||||
TRITON_ATTN,
|
||||
llama3_8b_fp4,
|
||||
llama3_8b_fp8,
|
||||
@@ -29,12 +32,33 @@ from .models import (
|
||||
"model_name, matches_fn, model_kwargs, hf_overrides, use_deepgemm",
|
||||
[
|
||||
(*llama3_8b_fp8, False),
|
||||
(*llama4_scout_fp8, False),
|
||||
(*qwen3_a3b_fp8, False),
|
||||
(*qwen3_a3b_fp8, True),
|
||||
pytest.param(
|
||||
*llama4_scout_fp8,
|
||||
False,
|
||||
marks=pytest.mark.skipif(
|
||||
not current_platform.is_cuda(),
|
||||
reason="Llama4 Scout FP8 only supported on CUDA",
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
*qwen3_a3b_fp8,
|
||||
True,
|
||||
marks=pytest.mark.skipif(
|
||||
not current_platform.is_cuda(), reason="DeepGemm only supported on CUDA"
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"attn_backend",
|
||||
[
|
||||
TRITON_ATTN,
|
||||
FLASHINFER_ATTN,
|
||||
ROCM_ATTN,
|
||||
ROCM_AITER_UNIFIED_ATTN,
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("attn_backend", [TRITON_ATTN, FLASHINFER_ATTN])
|
||||
@pytest.mark.parametrize("n_layers", [6])
|
||||
@pytest.mark.parametrize("custom_ops", custom_ops_combos("quant_fp8", "rms_norm"))
|
||||
@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
|
||||
@@ -81,6 +105,8 @@ def test_tp1_fp8_fusions(
|
||||
),
|
||||
)
|
||||
|
||||
use_aiter = current_platform.is_rocm() and ("qwen" in model_name.lower())
|
||||
|
||||
matches_check = [
|
||||
"rms_quant_fusion",
|
||||
"act_quant_fusion",
|
||||
@@ -88,6 +114,15 @@ def test_tp1_fp8_fusions(
|
||||
"attn_quant_fusion",
|
||||
]
|
||||
|
||||
if use_aiter:
|
||||
matches_check[0] = "aiter_rms_quant_fusion"
|
||||
|
||||
matches = matches._replace(aiter_rms_quant_fusion=matches.rms_quant_fusion)
|
||||
# TODO: enable the `norm_rope_fusion` check for the AITER path.
|
||||
# On ROCm norm_rope_fusion is only supported without
|
||||
# enabling AITER.
|
||||
matches_check.remove("norm_rope_fusion")
|
||||
|
||||
run_e2e_fusion_test(
|
||||
model_name,
|
||||
matches,
|
||||
@@ -96,6 +131,7 @@ def test_tp1_fp8_fusions(
|
||||
compilation_config,
|
||||
matches_check,
|
||||
use_deepgemm=use_deepgemm,
|
||||
use_aiter=use_aiter,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user