[torch.compile] Refactor Attention Quant Fusion Pass and Remove Boilerplate (#37373)
Signed-off-by: BadrBasowid <badr.basowid@gmail.com>
Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
|
||||
import pytest
|
||||
import regex as re
|
||||
@@ -52,6 +53,16 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg
|
||||
llm.llm_engine.vllm_config.compilation_config.compile_ranges_endpoints
|
||||
)
|
||||
|
||||
# Fetch match table from each worker via RPC and sum across workers.
|
||||
worker_tables = llm.llm_engine.engine_core.collective_rpc(
|
||||
"get_compilation_match_table"
|
||||
)
|
||||
combined: defaultdict[str, int] = defaultdict(int)
|
||||
for table in worker_tables:
|
||||
for k, v in table.items():
|
||||
combined[k] += v
|
||||
return dict(combined)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
|
||||
@@ -113,7 +124,7 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
|
||||
)
|
||||
|
||||
with caplog_mp_spawn(logging.DEBUG) as log_holder:
|
||||
run_model(full_compilation_config, model_name, **model_kwargs)
|
||||
match_table = run_model(full_compilation_config, model_name, **model_kwargs)
|
||||
|
||||
num_compile_ranges = len(full_compilation_config.get_compile_ranges())
|
||||
assert num_compile_ranges in [1, 2, 3]
|
||||
@@ -155,11 +166,14 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
|
||||
else:
|
||||
num_ranges_activated = num_compile_ranges
|
||||
|
||||
# TODO: Remove log counting in unit tests
|
||||
# once all matchers implement VllmFusionPatternMatcherPass
|
||||
n_expected = tp_size * num_ranges_activated
|
||||
assert len(log_matches) == n_expected, (
|
||||
f"Could not find {n_expected} {match_name} "
|
||||
f"(found {len(log_matches)}) in:\n {log_holder.text}"
|
||||
)
|
||||
if match_name != "attn_quant_fusion":
|
||||
assert len(log_matches) == n_expected, (
|
||||
f"Could not find {n_expected} {match_name} "
|
||||
f"(found {len(log_matches)}) in:\n {log_holder.text}"
|
||||
)
|
||||
|
||||
expected_matches = getattr(matches, match_name)
|
||||
|
||||
@@ -215,6 +229,13 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
|
||||
f"{tp_size * (num_ranges_activated - 1)} large-range "
|
||||
f"entries (SP took precedence), found: {log_matches}"
|
||||
)
|
||||
|
||||
elif match_name == "attn_quant_fusion":
|
||||
actual_match = match_table.get(match_name, 0)
|
||||
assert actual_match == expected_matches * n_expected, (
|
||||
f"Could not find {expected_matches * n_expected} "
|
||||
f"{match_name} (found {actual_match})."
|
||||
)
|
||||
else:
|
||||
expected_matches_list = [expected_matches] * n_expected
|
||||
assert sorted(log_matches) == expected_matches_list, (
|
||||
|
||||
@@ -9,7 +9,10 @@ from tests.compile.backend import LazyInitPass, TestBackend
|
||||
from tests.utils import TestFP8Layer, flat_product
|
||||
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
|
||||
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
from vllm.compilation.passes.fusion.attn_quant_fusion import ATTN_OP, AttnFusionPass
|
||||
from vllm.compilation.passes.fusion.attn_quant_fusion import (
|
||||
ATTN_OP,
|
||||
AttnQuantFusionPass,
|
||||
)
|
||||
from vllm.compilation.passes.fusion.matcher_utils import QUANT_OPS
|
||||
from vllm.compilation.passes.fx_utils import find_op_nodes
|
||||
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
|
||||
@@ -384,7 +387,7 @@ def test_attention_quant_pattern(
|
||||
|
||||
# Create test backend with fusion passes enabled
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
|
||||
attn_pass = LazyInitPass(AttnQuantFusionPass, vllm_config)
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
|
||||
test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass)
|
||||
@@ -434,7 +437,7 @@ def test_attention_quant_pattern(
|
||||
# Only output quant ops are fused into attention.
|
||||
test_backend.check_before_ops([quant_op], fully_replaced=quant_key is kNvfp4Dynamic)
|
||||
|
||||
# access the underlying `AttnFusionPass` on the `LazyInitPass`
|
||||
# access the underlying `AttnQuantFusionPass` on the `LazyInitPass`
|
||||
assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)
|
||||
|
||||
# Check attention ops in the graph before and after fusion
|
||||
|
||||
Reference in New Issue
Block a user