[ROCm] [CI] Add new fusion test cases that are relevant to vLLM IR Ops (#34307)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com> Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com> Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
@@ -610,6 +610,8 @@ steps:
|
||||
--ignore=lora/test_qwen3moe_tp.py
|
||||
parallelism: 4
|
||||
|
||||
##### .buildkite/test_areas/pytorch.yaml #####
|
||||
# corresponds to .buildkite/test_areas/pytorch.yaml
|
||||
- label: PyTorch Compilation Unit Tests # 15min
|
||||
timeout_in_minutes: 30
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
@@ -627,6 +629,20 @@ steps:
|
||||
# they do not suffer from https://github.com/vllm-project/vllm/issues/28965
|
||||
- "find compile/ -maxdepth 1 -name 'test_*.py' -exec pytest -s -v {} \\\\;"
|
||||
|
||||
# corresponds to .buildkite/test_areas/pytorch.yaml
|
||||
- label: PyTorch Compilation Passes Unit Tests
|
||||
timeout_in_minutes: 20
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
agent_pool: mi325_1
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/compile/passes
|
||||
commands:
|
||||
# TODO: clean up this comment if not needed. It is used to
|
||||
# keep track of the tests changes during vLLM IR Ops refactoring.
|
||||
# Use `find` to launch multiple instances of pytest.
|
||||
- "find compile/passes -maxdepth 1 -name 'test_*.py' -exec pytest -s -v {} \\\\;"
|
||||
|
||||
- label: PyTorch Fullgraph Smoke Test # 15min
|
||||
timeout_in_minutes: 30
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
@@ -1211,41 +1227,6 @@ steps:
|
||||
- pytest -v -s tests/kernels/moe/test_flashinfer.py
|
||||
- pytest -v -s tests/kernels/moe/test_cutedsl_moe.py
|
||||
|
||||
- label: Blackwell Fusion and Compile Tests # 30 min
|
||||
timeout_in_minutes: 40
|
||||
working_dir: "/vllm-workspace/"
|
||||
gpu: b200
|
||||
source_file_dependencies:
|
||||
- csrc/quantization/fp4/
|
||||
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
|
||||
- vllm/v1/attention/backends/flashinfer.py
|
||||
- vllm/v1/worker/
|
||||
- vllm/v1/cudagraph_dispatcher.py
|
||||
- vllm/compilation/
|
||||
# can affect pattern matching
|
||||
- vllm/model_executor/layers/layernorm.py
|
||||
- vllm/model_executor/layers/activation.py
|
||||
- vllm/model_executor/layers/quantization/input_quant_fp8.py
|
||||
- tests/compile/passes/test_fusion_attn.py
|
||||
- tests/compile/passes/test_silu_mul_quant_fusion.py
|
||||
- tests/compile/passes/distributed/test_fusion_all_reduce.py
|
||||
- tests/compile/fullgraph/test_full_graph.py
|
||||
commands:
|
||||
- nvidia-smi
|
||||
- pytest -v -s tests/compile/passes/test_fusion_attn.py
|
||||
- pytest -v -s tests/compile/passes/test_silu_mul_quant_fusion.py
|
||||
# this runner has 2 GPUs available even though num_gpus=2 is not set
|
||||
- pytest -v -s tests/compile/passes/distributed/test_fusion_all_reduce.py
|
||||
|
||||
# # Limit to Inductor partition, no custom ops, and allreduce & attn fusion to reduce running time
|
||||
# # Wrap with quotes to escape yaml
|
||||
# - "pytest -v -s tests/compile/distributed/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and not +quant_fp8 and not +rms_norm'"
|
||||
# Old E2E tests were removed in https://github.com/vllm-project/vllm/pull/33293
|
||||
# in favor of new tests in fusions_e2e. We avoid replicating the new jobs in this file as it's deprecated.
|
||||
|
||||
# test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40)
|
||||
- pytest -v -s tests/compile/fullgraph/test_full_graph.py::test_fp8_kv_scale_compile
|
||||
|
||||
- label: Blackwell GPT-OSS Eval
|
||||
timeout_in_minutes: 60
|
||||
working_dir: "/vllm-workspace/"
|
||||
@@ -1371,7 +1352,6 @@ steps:
|
||||
- pytest -v -s ./compile/test_wrapper.py
|
||||
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
|
||||
- VLLM_TEST_SAME_HOST=1 VLLM_TEST_WITH_DEFAULT_DEVICE_SET=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
|
||||
- pytest -v -s compile/correctness_e2e/test_sequence_parallel.py
|
||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown
|
||||
- pytest -v -s v1/worker/test_worker_memory_snapshot.py
|
||||
|
||||
@@ -1601,16 +1581,16 @@ steps:
|
||||
commands:
|
||||
- VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/passes/distributed/test_async_tp.py
|
||||
- pytest -v -s tests/compile/passes/distributed/test_sequence_parallelism.py
|
||||
- pytest -v -s tests/compile/passes/distributed/test_fusion_all_reduce.py
|
||||
# TODO: this test is not supported on ROCm, there are aiter kernels for this.
|
||||
# - pytest -v -s tests/compile/passes/distributed/test_fusion_all_reduce.py
|
||||
#- pytest -v -s tests/compile/distributed/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm
|
||||
# - "VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'"
|
||||
# Old E2E tests were removed in https://github.com/vllm-project/vllm/pull/33293
|
||||
# in favor of new tests in fusions_e2e. We avoid replicating the new jobs in this file as it's deprecated.
|
||||
|
||||
- VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/correctness_e2e/test_sequence_parallel.py
|
||||
- pytest -v -s tests/distributed/test_context_parallel.py
|
||||
- HIP_VISIBLE_DEVICES=0,1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model=Qwen/Qwen1.5-MoE-A2.7B -tp=1 -dp=2 --max-model-len=2048 --all2all-backend=allgather_reducescatter --disable-nccl-for-dp-synchronization
|
||||
- pytest -v -s tests/v1/distributed/test_dbo.py
|
||||
# this test is not supported on ROCm
|
||||
# - pytest -v -s tests/v1/distributed/test_dbo.py
|
||||
|
||||
##### B200 test #####
|
||||
- label: Distributed Tests (B200) # optional
|
||||
@@ -1721,6 +1701,93 @@ steps:
|
||||
commands:
|
||||
- bash .buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh 0.8 1319 8040
|
||||
|
||||
##### .buildkite/test_areas/compile.yaml #####
|
||||
# Slowly setting up the tests so that it is also easier for the
|
||||
# CI team to review and upstream to the pipelinev2.
|
||||
# The following tests are important for vLLM IR Ops refactoring,
|
||||
# which affects fusion passes on ROCm. So we have to
|
||||
# enable them as as soon as possible.
|
||||
|
||||
## TODO: Enable the test in this group
|
||||
# # corresponds to .buildkite/test_areas/compile.yaml
|
||||
# - label: Fusion and Compile Unit Tests (2xMI325 GPUs)
|
||||
# timeout_in_minutes: 20
|
||||
# working_dir: "/vllm-workspace/"
|
||||
# mirror_hardwares: [amdexperimental, amdproduction, tj]
|
||||
# agent_pool: mi325_1 # changed to 1 GPU until the fusion all reduce is enabled then only revert back to 2 GPUs
|
||||
# source_file_dependencies:
|
||||
# - csrc/quantization/fp4/
|
||||
# - vllm/model_executor/layers/quantization/
|
||||
# - vllm/model_executor/layers/layernorm.py
|
||||
# - vllm/model_executor/layers/activation.py
|
||||
# - vllm/model_executor/layers/attention/attention.py
|
||||
# - vllm/v1/attention/backends/flashinfer.py
|
||||
# - vllm/compilation/ # TODO(luka) limit to vllm/compilation/passes
|
||||
# - tests/compile/test_fusion_attn.py
|
||||
# - tests/compile/test_silu_mul_quant_fusion.py
|
||||
# - tests/compile/distributed/test_fusion_all_reduce.py
|
||||
# - tests/compile/fullgraph/test_full_graph.py
|
||||
# commands:
|
||||
# - rocm-smi
|
||||
# # we run all backend tests on ROCm
|
||||
# # These two tests are covered in "PyTorch Compilation Passes Unit Tests"
|
||||
# # - "pytest -v -s tests/compile/passes/test_fusion_attn.py"
|
||||
# # - "pytest -v -s tests/compile/passes/test_silu_mul_quant_fusion.py"
|
||||
# # TODO: this test is not supported on ROCm, there are aiter kernels for this.
|
||||
# # - pytest -v -s tests/compile/passes/distributed/test_fusion_all_reduce.py
|
||||
# # TODO: find out more details
|
||||
# # - pytest -v -s tests/compile/fullgraph/test_full_graph.py::test_fp8_kv_scale_compile
|
||||
|
||||
# corresponds to .buildkite/test_areas/compile.yaml
|
||||
- label: Fusion E2E Quick (MI325)
|
||||
timeout_in_minutes: 15
|
||||
working_dir: "/vllm-workspace/"
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
agent_pool: mi325_1
|
||||
num_devices: 1
|
||||
source_file_dependencies:
|
||||
- csrc/quantization/
|
||||
- vllm/model_executor/
|
||||
- vllm/v1/attention/
|
||||
- vllm/compilation/
|
||||
- tests/compile/fusions_e2e/
|
||||
commands:
|
||||
- rocm-smi
|
||||
# Run all models and attn backends but only Inductor partition and native custom ops
|
||||
- "pytest -v -s tests/compile/fusions_e2e/test_tp1_quant.py -k 'inductor_partition and not +rms_norm and not +quant_fp8'"
|
||||
# Different from CUDA, Qwen requires +rms_norm and +quant_fp8 as rms+quant fusion is only supported on AITER
|
||||
- "pytest -v -s tests/compile/fusions_e2e/test_tp1_quant.py -k 'inductor_partition and +rms_norm and +quant_fp8 and qwen3'"
|
||||
|
||||
# corresponds to .buildkite/test_areas/compile.yaml
|
||||
- label: Fusion E2E Config Sweep (MI325)
|
||||
timeout_in_minutes: 30
|
||||
working_dir: "/vllm-workspace/"
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
agent_pool: mi325_1
|
||||
num_devices: 1
|
||||
source_file_dependencies:
|
||||
- csrc/quantization/
|
||||
- vllm/compilation/
|
||||
# can affect pattern matching
|
||||
- vllm/model_executor/layers/layernorm.py
|
||||
- vllm/model_executor/layers/activation.py
|
||||
- vllm/model_executor/layers/attention/attention.py
|
||||
- vllm/model_executor/layers/quantization/input_quant_fp8.py
|
||||
- tests/compile/fusions_e2e/
|
||||
commands:
|
||||
- rocm-smi
|
||||
# Run just llama3 (fp8) for all config combinations
|
||||
- pytest -v -s tests/compile/fusions_e2e/test_tp1_quant.py -k "llama-3"
|
||||
|
||||
## There are no ops on ROCm for these tests.
|
||||
## The test still passes but the logs are not useful.
|
||||
## fused ops just call torch.ops.symm_mem which
|
||||
## exists in ROCm even though they don't work
|
||||
# - label: AsyncTP Correctness Tests (2xMI325 GPUs)
|
||||
# - label: Fusion E2E TP2 Quick (MI325)
|
||||
# - label: Fusion E2E TP2 AsyncTP Config Sweep (MI325)
|
||||
# - label: Fusion E2E TP2 (MI325)
|
||||
# - label: Sequence Parallel Correctness Tests (2xMI325 GPUs)
|
||||
|
||||
|
||||
#####################################################################################################################################
|
||||
|
||||
@@ -13,6 +13,7 @@ from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
class Matches(NamedTuple):
|
||||
# simple pointwise
|
||||
aiter_rms_quant_fusion: int = 0
|
||||
rms_quant_fusion: int = 0
|
||||
act_quant_fusion: int = 0
|
||||
norm_rope_fusion: int = 0
|
||||
@@ -82,6 +83,9 @@ INDUCTOR_GRAPH_PARTITION = [
|
||||
]
|
||||
|
||||
FUSION_LOG_PATTERNS: dict[str, re.Pattern] = {
|
||||
"aiter_rms_quant_fusion": re.compile(
|
||||
r"RocmAiterRMSNormQuantFusionPass Replaced (\d+) patterns"
|
||||
),
|
||||
"rms_quant_fusion": re.compile(r"rms_quant_fusion.py:\d+] Replaced (\d+) patterns"),
|
||||
"act_quant_fusion": re.compile(r"act_quant_fusion.py:\d+] Replaced (\d+) patterns"),
|
||||
"norm_rope_fusion": re.compile(
|
||||
|
||||
@@ -63,9 +63,14 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
|
||||
compilation_config: dict,
|
||||
matches_check: list[str],
|
||||
use_deepgemm: bool = False,
|
||||
use_aiter: bool = False,
|
||||
tp_size: int = 1,
|
||||
):
|
||||
monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "1" if use_deepgemm else "0")
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1" if use_aiter else "0")
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
|
||||
rocm_aiter_ops.refresh_env_variables()
|
||||
|
||||
# Disable, compile cache to make sure custom passes run.
|
||||
# Otherwise, we can't verify fusion happened through the logs.
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
|
||||
from vllm._aiter_ops import is_aiter_found_and_supported
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import has_flashinfer
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
@@ -24,6 +26,24 @@ TRITON_ATTN = pytest.param(
|
||||
AttentionBackendCase(backend=AttentionBackendEnum.TRITON_ATTN), id="TRITON_ATTN"
|
||||
)
|
||||
|
||||
ROCM_ATTN = pytest.param(
|
||||
AttentionBackendCase(backend=AttentionBackendEnum.ROCM_ATTN),
|
||||
id="ROCM_ATTN",
|
||||
marks=pytest.mark.skipif(
|
||||
not current_platform.is_rocm(),
|
||||
reason="ROCm attention only for AMD",
|
||||
),
|
||||
)
|
||||
|
||||
ROCM_AITER_UNIFIED_ATTN = pytest.param(
|
||||
AttentionBackendCase(backend=AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN),
|
||||
id="ROCM_AITER_UNIFIED_ATTN",
|
||||
marks=pytest.mark.skipif(
|
||||
not is_aiter_found_and_supported(),
|
||||
reason="ROCM_AITER_UNIFIED_ATTN only for AMD when AITER is installed",
|
||||
),
|
||||
)
|
||||
|
||||
# Models
|
||||
llama3_8b = ModelFusionInfo(
|
||||
model_name="meta-llama/Llama-3.1-8B-Instruct",
|
||||
@@ -49,7 +69,6 @@ llama3_8b_fp8 = ModelFusionInfo(
|
||||
llama3_8b_fp4 = ModelFusionInfo(
|
||||
model_name="nvidia/Llama-3.1-8B-Instruct-FP4",
|
||||
matches=lambda n_layers: Matches(
|
||||
rms_quant_fusion=0,
|
||||
act_quant_fusion=n_layers,
|
||||
attn_quant_fusion=n_layers,
|
||||
ar_rms_fusion=n_layers * 2 + 1,
|
||||
@@ -79,7 +98,6 @@ llama4_scout_fp4 = ModelFusionInfo(
|
||||
model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-NVFP4",
|
||||
hf_overrides=lambda n_layers: {"text_config": {"num_hidden_layers": n_layers}},
|
||||
matches=lambda n_layers: Matches(
|
||||
rms_quant_fusion=0,
|
||||
attn_quant_fusion=n_layers,
|
||||
ar_rms_fusion=n_layers * 2,
|
||||
sequence_parallel=n_layers * 2,
|
||||
|
||||
@@ -5,6 +5,7 @@ from collections.abc import Callable
|
||||
import pytest
|
||||
|
||||
from vllm.config import PassConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import is_flashinfer_fp8_blockscale_gemm_supported
|
||||
|
||||
from .common import (
|
||||
@@ -16,6 +17,8 @@ from .common import (
|
||||
)
|
||||
from .models import (
|
||||
FLASHINFER_ATTN,
|
||||
ROCM_AITER_UNIFIED_ATTN,
|
||||
ROCM_ATTN,
|
||||
TRITON_ATTN,
|
||||
llama3_8b_fp4,
|
||||
llama3_8b_fp8,
|
||||
@@ -29,12 +32,33 @@ from .models import (
|
||||
"model_name, matches_fn, model_kwargs, hf_overrides, use_deepgemm",
|
||||
[
|
||||
(*llama3_8b_fp8, False),
|
||||
(*llama4_scout_fp8, False),
|
||||
(*qwen3_a3b_fp8, False),
|
||||
(*qwen3_a3b_fp8, True),
|
||||
pytest.param(
|
||||
*llama4_scout_fp8,
|
||||
False,
|
||||
marks=pytest.mark.skipif(
|
||||
not current_platform.is_cuda(),
|
||||
reason="Llama4 Scout FP8 only supported on CUDA",
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
*qwen3_a3b_fp8,
|
||||
True,
|
||||
marks=pytest.mark.skipif(
|
||||
not current_platform.is_cuda(), reason="DeepGemm only supported on CUDA"
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"attn_backend",
|
||||
[
|
||||
TRITON_ATTN,
|
||||
FLASHINFER_ATTN,
|
||||
ROCM_ATTN,
|
||||
ROCM_AITER_UNIFIED_ATTN,
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("attn_backend", [TRITON_ATTN, FLASHINFER_ATTN])
|
||||
@pytest.mark.parametrize("n_layers", [6])
|
||||
@pytest.mark.parametrize("custom_ops", custom_ops_combos("quant_fp8", "rms_norm"))
|
||||
@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
|
||||
@@ -81,6 +105,8 @@ def test_tp1_fp8_fusions(
|
||||
),
|
||||
)
|
||||
|
||||
use_aiter = current_platform.is_rocm() and ("qwen" in model_name.lower())
|
||||
|
||||
matches_check = [
|
||||
"rms_quant_fusion",
|
||||
"act_quant_fusion",
|
||||
@@ -88,6 +114,15 @@ def test_tp1_fp8_fusions(
|
||||
"attn_quant_fusion",
|
||||
]
|
||||
|
||||
if use_aiter:
|
||||
matches_check[0] = "aiter_rms_quant_fusion"
|
||||
|
||||
matches = matches._replace(aiter_rms_quant_fusion=matches.rms_quant_fusion)
|
||||
# TODO: enable the `norm_rope_fusion` test,
|
||||
# On ROCm norm_rope_fusion is only supported without
|
||||
# enabling AITER.
|
||||
matches_check.remove("norm_rope_fusion")
|
||||
|
||||
run_e2e_fusion_test(
|
||||
model_name,
|
||||
matches,
|
||||
@@ -96,6 +131,7 @@ def test_tp1_fp8_fusions(
|
||||
compilation_config,
|
||||
matches_check,
|
||||
use_deepgemm=use_deepgemm,
|
||||
use_aiter=use_aiter,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from collections.abc import Callable
|
||||
import pytest
|
||||
|
||||
from vllm.config import PassConfig
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ...utils import multi_gpu_test
|
||||
from .common import (
|
||||
@@ -26,6 +27,8 @@ from .models import (
|
||||
qwen3_a3b_fp8,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -5,6 +5,7 @@ from collections.abc import Callable
|
||||
import pytest
|
||||
|
||||
from vllm.config import PassConfig
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ...utils import multi_gpu_test
|
||||
from .common import (
|
||||
@@ -23,6 +24,8 @@ from .models import (
|
||||
qwen3_a3b,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -36,6 +36,8 @@ from vllm.platforms import current_platform
|
||||
from vllm.utils.system_utils import update_environment_variables
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
|
||||
pytestmark = pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
|
||||
@@ -182,8 +182,24 @@ TEST_KERNELS = ROCM_KERNELS if current_platform.is_rocm() else CUDA_KERNELS
|
||||
"model_class, enable_quant_fp8_custom_op, force_kernel",
|
||||
list(itertools.product([TestSiluMulFp8QuantModel], [True, False], TEST_KERNELS))
|
||||
+ [
|
||||
(TestSiluMulNvfp4QuantModel, False, None),
|
||||
(TestSiluMulGroupFp8QuantModel, False, None),
|
||||
pytest.param(
|
||||
TestSiluMulNvfp4QuantModel,
|
||||
False,
|
||||
None,
|
||||
marks=pytest.mark.skipif(
|
||||
not current_platform.is_cuda(), reason="CUDA only"
|
||||
),
|
||||
),
|
||||
# GroupFP8Quant fusion only works with AITER on ROCm.
|
||||
# and the enable_quant_fp8_custom_op must be True.
|
||||
pytest.param(
|
||||
TestSiluMulGroupFp8QuantModel,
|
||||
True,
|
||||
None,
|
||||
marks=pytest.mark.skipif(
|
||||
not current_platform.is_rocm(), reason="ROCm only"
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
@@ -201,6 +217,7 @@ def test_fusion_silu_and_mul_quant(
|
||||
enable_silu_mul_custom_op: bool,
|
||||
enable_quant_fp8_custom_op: bool,
|
||||
force_kernel: FP8ScaledMMLinearKernel | None,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
):
|
||||
if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported():
|
||||
pytest.skip("NVFP4 is not supported on this GPU.")
|
||||
@@ -227,13 +244,16 @@ def test_fusion_silu_and_mul_quant(
|
||||
),
|
||||
)
|
||||
|
||||
with set_current_vllm_config(config):
|
||||
with set_current_vllm_config(config), monkeypatch.context() as m:
|
||||
fusion_passes = [ActivationQuantFusionPass(config)]
|
||||
if IS_AITER_FOUND:
|
||||
if IS_AITER_FOUND and model_class is TestSiluMulGroupFp8QuantModel:
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.compilation.passes.fusion.rocm_aiter_fusion import (
|
||||
RocmAiterSiluMulFp8GroupQuantFusionPass,
|
||||
)
|
||||
|
||||
m.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
rocm_aiter_ops.refresh_env_variables()
|
||||
fusion_passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)]
|
||||
|
||||
passes = [NoOpEliminationPass(config), *fusion_passes, PostCleanupPass(config)]
|
||||
|
||||
@@ -5,7 +5,6 @@ import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch import fx
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from torch._ops import OpOverload
|
||||
|
||||
import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
@@ -15,6 +14,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
QuantKey,
|
||||
ScaleDesc,
|
||||
kFp8Dynamic128Sym,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@@ -312,7 +312,9 @@ class RocmAiterRMSNormQuantFusionPass(VllmPatternMatcherPass):
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
logger.debug(
|
||||
"%s Replaced %s patterns", self.__class__.__name__, self.matched_count
|
||||
)
|
||||
|
||||
def uuid(self) -> str:
|
||||
fusion_patterns = [
|
||||
@@ -332,9 +334,11 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
|
||||
|
||||
FUSED_SILU_MUL_QUANT_OP = rocm_aiter_ops.get_act_mul_fused_fp8_group_quant_op()
|
||||
|
||||
def __init__(self, quant_op: OpOverload) -> None:
|
||||
def __init__(self) -> None:
|
||||
self.silu_and_mul_matcher = MatcherSiluAndMul()
|
||||
self.quant_op = quant_op
|
||||
self.quant_matcher = MatcherQuantFP8(
|
||||
quant_key=kFp8Dynamic128Sym, match_rocm_aiter=True
|
||||
)
|
||||
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
return [
|
||||
@@ -346,7 +350,7 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
|
||||
input: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
at1 = self.silu_and_mul_matcher(input)
|
||||
at2 = self.quant_op(at1, 128)
|
||||
at2 = self.quant_matcher(at1)
|
||||
return at2[0], at2[1]
|
||||
|
||||
def replacement(
|
||||
@@ -370,11 +374,6 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
|
||||
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
|
||||
"""
|
||||
|
||||
AITER_GROUP_FP8_QUANT_OP = rocm_aiter_ops.get_group_quant_op()
|
||||
TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default
|
||||
|
||||
QUANT_OPS = [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]
|
||||
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig) -> None:
|
||||
super().__init__(config)
|
||||
@@ -383,8 +382,7 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
|
||||
pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass"
|
||||
)
|
||||
|
||||
for quant_op in self.QUANT_OPS:
|
||||
AiterSiluMulFp8GroupQuantPattern(quant_op).register(self.patterns)
|
||||
AiterSiluMulFp8GroupQuantPattern().register(self.patterns)
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user