[CI][torch.compile] Reduce e2e fusion test time (#33293)

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: ProExpertProg <luka.govedic@gmail.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Luka Govedič
2026-02-04 19:09:03 -05:00
committed by GitHub
parent 439afa4eea
commit 4d9513537d
17 changed files with 1068 additions and 821 deletions
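The import hunk below references a shared tests/compile/fusion_test_utils module. For orientation, here is a rough, hypothetical sketch of the shape such a helper module could take; only the names come from the import list in the diff, while the bodies and values are illustrative assumptions rather than the actual module contents:

```python
# Hypothetical sketch of the tests/compile/fusion_test_utils helpers referenced
# below; names come from the import list, bodies/values are assumptions.
from dataclasses import dataclass

from vllm.platforms import current_platform


@dataclass
class Matches:
    # Expected pattern-matcher hit count reported by the attention+quant fusion pass.
    attention_fusion: int = 0


# Custom-op toggles for QuantFP8: custom CUDA kernel vs. native torch implementation.
CUSTOM_OPS_FP8 = ["+quant_fp8", "-quant_fp8"]

# Per-model parametrization entries; real entries presumably pair a model name and
# kwargs with an attention backend and the expected Matches counts.
MODELS_FP8: list[tuple] = []
MODELS_FP4: list[tuple] = []


def is_blackwell() -> bool:
    # Rough approximation: Blackwell datacenter GPUs report compute capability 10.0.
    return current_platform.has_device_capability(100)


def has_cuda_graph_wrapper_metadata() -> bool:
    # CUDAGraphWrapperMetadata only exists on sufficiently new torch builds.
    try:
        from torch._inductor.utils import CUDAGraphWrapperMetadata  # noqa: F401
    except ImportError:
        return False
    return True


def run_model(compilation_config, model_name: str, **model_kwargs):
    # The real helper presumably builds an LLM with the given compilation config,
    # runs a few prompts, and returns the generated text for comparison.
    raise NotImplementedError
```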


@@ -1,23 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import logging
from typing import Any
import pytest
import regex as re
import torch._dynamo
from tests.compile.backend import LazyInitPass, TestBackend
from tests.compile.fusion_test_utils import (
    CUSTOM_OPS_FP8,
    MODELS_FP4,
    MODELS_FP8,
    Matches,
    has_cuda_graph_wrapper_metadata,
    is_blackwell,
    run_model,
)
from tests.utils import flat_product
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
@@ -31,7 +19,6 @@ from vllm.config import (
    CacheConfig,
    CompilationConfig,
    CompilationMode,
    CUDAGraphMode,
    ModelConfig,
    PassConfig,
    SchedulerConfig,
@@ -47,7 +34,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer
from vllm.utils.torch_utils import is_torch_equal_or_newer
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.kv_cache_interface import AttentionSpec
@@ -501,88 +487,3 @@ def test_attention_quant_pattern(
    # Check that results are close
    torch.testing.assert_close(result_unfused, result_fused, atol=1e-2, rtol=1e-2)


@pytest.mark.parametrize(
    "model_name, model_kwargs, backend, matches, custom_ops",
    # Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
    list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8))
    # quant_fp4 only has the custom impl
    + list(flat_product(MODELS_FP4, [""])),
)
@pytest.mark.parametrize(
    "inductor_graph_partition",
    [
        pytest.param(
            True,
            marks=pytest.mark.skipif(
                not has_cuda_graph_wrapper_metadata(),
                reason="This test requires "
                "torch._inductor.utils.CUDAGraphWrapperMetadata to run",
            ),
        ),
        False,
    ],
)
def test_attn_quant(
    model_name: str,
    model_kwargs: dict[str, Any],
    backend: AttentionBackendEnum,
    matches: Matches,
    custom_ops: str,
    inductor_graph_partition: bool,
    caplog_mp_spawn,
    monkeypatch,
):
    if not current_platform.has_device_capability(90):
        pytest.skip("test_attn_quant requires H100 (SM90) or B200 (SM100) GPU")
    if backend == AttentionBackendEnum.FLASHINFER and (
        not is_blackwell() or not has_flashinfer()
    ):
        pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
    if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
        pytest.skip("Inductor graph partition requires torch>=2.9")

    custom_ops_list = custom_ops.split(",") if custom_ops else []

    if inductor_graph_partition:
        mode = CUDAGraphMode.FULL_AND_PIECEWISE
        splitting_ops: list[str] | None = None
    else:
        # FIXME: Llama-4-Scout-17B-16E-Instruct-FP8 + FlashInfer + Blackwell ends at
        # CUDAGraphMode.NONE here because it derives an attention backend that
        # does not support full cudagraphs
        mode = CUDAGraphMode.FULL_DECODE_ONLY
        splitting_ops = []

    # Disable the compile cache to make sure custom passes run.
    # Otherwise, we can't verify fusion happened through the logs.
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

    # To capture subprocess logs, we need to know whether spawn or fork is used.
    # Force spawn as it is more general.
    monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")

    model_kwargs["attention_config"] = {"backend": backend.name}
    compilation_config = CompilationConfig(
        # Testing properties
        custom_ops=custom_ops_list,
        use_inductor_graph_partition=inductor_graph_partition,
        cudagraph_mode=mode,
        splitting_ops=splitting_ops,
        # Common
        mode=CompilationMode.VLLM_COMPILE,
        pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
        # Inductor caches custom passes by default as well via uuid
        inductor_compile_config={"force_disable_caches": True},
    )

    with caplog_mp_spawn(logging.DEBUG) as log_holder:
        run_model(compilation_config, model_name, **model_kwargs)

    log_matches = re.findall(
        r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
        log_holder.text,
    )
    assert len(log_matches) == 1, log_holder.text
    assert int(log_matches[0]) == matches.attention_fusion
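
The parametrize matrix above is assembled with flat_product from tests.utils. As a minimal sketch of what such a helper could do (an assumption about its behavior, not the actual implementation): take the cartesian product of its arguments and flatten tuple entries so each combination becomes one flat parameter tuple.

```python
# Minimal sketch of a flat_product-style helper; the real tests.utils.flat_product
# may differ. Tuple items in each product combination are flattened in place.
import itertools
from collections.abc import Iterable, Iterator


def flat_product(*iterables: Iterable) -> Iterator[tuple]:
    for combo in itertools.product(*iterables):
        flat: list = []
        for item in combo:
            if isinstance(item, tuple):
                flat.extend(item)
            else:
                flat.append(item)
        yield tuple(flat)


# Hypothetical usage mirroring the parametrize call above: each MODELS_FP8 entry
# is assumed to be a (model_name, model_kwargs, backend, matches) tuple and each
# CUSTOM_OPS_FP8 entry a custom_ops string, yielding 5-element parameter tuples.
params = list(flat_product([("m", {}, "FLASH_ATTN", None)], ["+quant_fp8"]))
assert params == [("m", {}, "FLASH_ATTN", None, "+quant_fp8")]
```

Under that assumption, each FP8 model is paired with both the custom and torch QuantFP8 implementations, while the FP4 models are paired only with the empty custom-ops string, matching the comments in the decorator.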