[ROCm] [CI] Add new fusion test cases that are relevant to vLLM IR Ops (#34307)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
@@ -182,8 +182,24 @@ TEST_KERNELS = ROCM_KERNELS if current_platform.is_rocm() else CUDA_KERNELS
|
||||
"model_class, enable_quant_fp8_custom_op, force_kernel",
|
||||
list(itertools.product([TestSiluMulFp8QuantModel], [True, False], TEST_KERNELS))
|
||||
+ [
|
||||
(TestSiluMulNvfp4QuantModel, False, None),
|
||||
(TestSiluMulGroupFp8QuantModel, False, None),
|
||||
pytest.param(
|
||||
TestSiluMulNvfp4QuantModel,
|
||||
False,
|
||||
None,
|
||||
marks=pytest.mark.skipif(
|
||||
not current_platform.is_cuda(), reason="CUDA only"
|
||||
),
|
||||
),
|
||||
# GroupFP8Quant fusion only works with AITER on ROCm.
|
||||
# enable_quant_fp8_custom_op must also be True for this case.
|
||||
pytest.param(
|
||||
TestSiluMulGroupFp8QuantModel,
|
||||
True,
|
||||
None,
|
||||
marks=pytest.mark.skipif(
|
||||
not current_platform.is_rocm(), reason="ROCm only"
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
@@ -201,6 +217,7 @@ def test_fusion_silu_and_mul_quant(
|
||||
enable_silu_mul_custom_op: bool,
|
||||
enable_quant_fp8_custom_op: bool,
|
||||
force_kernel: FP8ScaledMMLinearKernel | None,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
):
|
||||
if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported():
|
||||
pytest.skip("NVFP4 is not supported on this GPU.")
|
||||
@@ -227,13 +244,16 @@ def test_fusion_silu_and_mul_quant(
|
||||
),
|
||||
)
|
||||
|
||||
with set_current_vllm_config(config):
|
||||
with set_current_vllm_config(config), monkeypatch.context() as m:
|
||||
fusion_passes = [ActivationQuantFusionPass(config)]
|
||||
if IS_AITER_FOUND:
|
||||
if IS_AITER_FOUND and model_class is TestSiluMulGroupFp8QuantModel:
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.compilation.passes.fusion.rocm_aiter_fusion import (
|
||||
RocmAiterSiluMulFp8GroupQuantFusionPass,
|
||||
)
|
||||
|
||||
m.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
rocm_aiter_ops.refresh_env_variables()
|
||||
fusion_passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)]
|
||||
|
||||
passes = [NoOpEliminationPass(config), *fusion_passes, PostCleanupPass(config)]
|
||||
|
||||
Reference in New Issue
Block a user