[torch.compile] Refactor Attention Quant Fusion Pass and Remove Boilerplate (#37373)

Signed-off-by: BadrBasowid <badr.basowid@gmail.com>
Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
BadrBasowid
2026-04-01 02:15:50 +08:00
committed by GitHub
parent 07edd551cc
commit 077a9a8e37
7 changed files with 275 additions and 212 deletions

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
from collections import defaultdict
import pytest
import regex as re
@@ -52,6 +53,16 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg
llm.llm_engine.vllm_config.compilation_config.compile_ranges_endpoints
)
# Fetch match table from each worker via RPC and sum across workers.
worker_tables = llm.llm_engine.engine_core.collective_rpc(
"get_compilation_match_table"
)
combined: defaultdict[str, int] = defaultdict(int)
for table in worker_tables:
for k, v in table.items():
combined[k] += v
return dict(combined)
@pytest.fixture
def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
@@ -113,7 +124,7 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
)
with caplog_mp_spawn(logging.DEBUG) as log_holder:
run_model(full_compilation_config, model_name, **model_kwargs)
match_table = run_model(full_compilation_config, model_name, **model_kwargs)
num_compile_ranges = len(full_compilation_config.get_compile_ranges())
assert num_compile_ranges in [1, 2, 3]
@@ -155,11 +166,14 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
else:
num_ranges_activated = num_compile_ranges
# TODO: Remove log counting in unit tests
# once all matchers implement VllmFusionPatternMatcherPass
n_expected = tp_size * num_ranges_activated
assert len(log_matches) == n_expected, (
f"Could not find {n_expected} {match_name} "
f"(found {len(log_matches)}) in:\n {log_holder.text}"
)
if match_name != "attn_quant_fusion":
assert len(log_matches) == n_expected, (
f"Could not find {n_expected} {match_name} "
f"(found {len(log_matches)}) in:\n {log_holder.text}"
)
expected_matches = getattr(matches, match_name)
@@ -215,6 +229,13 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
f"{tp_size * (num_ranges_activated - 1)} large-range "
f"entries (SP took precedence), found: {log_matches}"
)
elif match_name == "attn_quant_fusion":
actual_match = match_table.get(match_name, 0)
assert actual_match == expected_matches * n_expected, (
f"Could not find {expected_matches * n_expected} "
f"{match_name} (found {actual_match})."
)
else:
expected_matches_list = [expected_matches] * n_expected
assert sorted(log_matches) == expected_matches_list, (