[torch.compile] Refactor Attention Quant Fusion Pass and Remove Boilerplate (#37373)

Signed-off-by: BadrBasowid <badr.basowid@gmail.com>
Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
BadrBasowid
2026-04-01 02:15:50 +08:00
committed by GitHub
parent 07edd551cc
commit 077a9a8e37
7 changed files with 275 additions and 212 deletions

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
from collections import defaultdict
import pytest
import regex as re
@@ -52,6 +53,16 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg
llm.llm_engine.vllm_config.compilation_config.compile_ranges_endpoints
)
# Fetch match table from each worker via RPC and sum across workers.
worker_tables = llm.llm_engine.engine_core.collective_rpc(
"get_compilation_match_table"
)
combined: defaultdict[str, int] = defaultdict(int)
for table in worker_tables:
for k, v in table.items():
combined[k] += v
return dict(combined)
@pytest.fixture
def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
@@ -113,7 +124,7 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
)
with caplog_mp_spawn(logging.DEBUG) as log_holder:
run_model(full_compilation_config, model_name, **model_kwargs)
match_table = run_model(full_compilation_config, model_name, **model_kwargs)
num_compile_ranges = len(full_compilation_config.get_compile_ranges())
assert num_compile_ranges in [1, 2, 3]
@@ -155,11 +166,14 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
else:
num_ranges_activated = num_compile_ranges
# TODO: Remove log counting in unit tests
# once all matchers implement VllmFusionPatternMatcherPass
n_expected = tp_size * num_ranges_activated
assert len(log_matches) == n_expected, (
f"Could not find {n_expected} {match_name} "
f"(found {len(log_matches)}) in:\n {log_holder.text}"
)
if match_name != "attn_quant_fusion":
assert len(log_matches) == n_expected, (
f"Could not find {n_expected} {match_name} "
f"(found {len(log_matches)}) in:\n {log_holder.text}"
)
expected_matches = getattr(matches, match_name)
@@ -215,6 +229,13 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
f"{tp_size * (num_ranges_activated - 1)} large-range "
f"entries (SP took precedence), found: {log_matches}"
)
elif match_name == "attn_quant_fusion":
actual_match = match_table.get(match_name, 0)
assert actual_match == expected_matches * n_expected, (
f"Could not find {expected_matches * n_expected} "
f"{match_name} (found {actual_match})."
)
else:
expected_matches_list = [expected_matches] * n_expected
assert sorted(log_matches) == expected_matches_list, (

View File

@@ -9,7 +9,10 @@ from tests.compile.backend import LazyInitPass, TestBackend
from tests.utils import TestFP8Layer, flat_product
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.compilation.passes.fusion.attn_quant_fusion import ATTN_OP, AttnFusionPass
from vllm.compilation.passes.fusion.attn_quant_fusion import (
ATTN_OP,
AttnQuantFusionPass,
)
from vllm.compilation.passes.fusion.matcher_utils import QUANT_OPS
from vllm.compilation.passes.fx_utils import find_op_nodes
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
@@ -384,7 +387,7 @@ def test_attention_quant_pattern(
# Create test backend with fusion passes enabled
noop_pass = NoOpEliminationPass(vllm_config)
attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
attn_pass = LazyInitPass(AttnQuantFusionPass, vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass)
@@ -434,7 +437,7 @@ def test_attention_quant_pattern(
# Only output quant ops are fused into attention.
test_backend.check_before_ops([quant_op], fully_replaced=quant_key is kNvfp4Dynamic)
# access the underlying `AttnFusionPass` on the `LazyInitPass`
# access the underlying `AttnQuantFusionPass` on the `LazyInitPass`
assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)
# Check attention ops in the graph before and after fusion