[torch.compile] Refactor Attention Quant Fusion Pass and Remove Boilerplate (#37373)
Signed-off-by: BadrBasowid <badr.basowid@gmail.com> Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
|
||||
import pytest
|
||||
import regex as re
|
||||
@@ -52,6 +53,16 @@ def run_model(compile_config: int | CompilationConfig, model: str, **model_kwarg
|
||||
llm.llm_engine.vllm_config.compilation_config.compile_ranges_endpoints
|
||||
)
|
||||
|
||||
# Fetch match table from each worker via RPC and sum across workers.
|
||||
worker_tables = llm.llm_engine.engine_core.collective_rpc(
|
||||
"get_compilation_match_table"
|
||||
)
|
||||
combined: defaultdict[str, int] = defaultdict(int)
|
||||
for table in worker_tables:
|
||||
for k, v in table.items():
|
||||
combined[k] += v
|
||||
return dict(combined)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
|
||||
@@ -113,7 +124,7 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
|
||||
)
|
||||
|
||||
with caplog_mp_spawn(logging.DEBUG) as log_holder:
|
||||
run_model(full_compilation_config, model_name, **model_kwargs)
|
||||
match_table = run_model(full_compilation_config, model_name, **model_kwargs)
|
||||
|
||||
num_compile_ranges = len(full_compilation_config.get_compile_ranges())
|
||||
assert num_compile_ranges in [1, 2, 3]
|
||||
@@ -155,11 +166,14 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
|
||||
else:
|
||||
num_ranges_activated = num_compile_ranges
|
||||
|
||||
# TODO: Remove log counting in unit tests
|
||||
# once all matchers implement VllmFusionPatternMatcherPass
|
||||
n_expected = tp_size * num_ranges_activated
|
||||
assert len(log_matches) == n_expected, (
|
||||
f"Could not find {n_expected} {match_name} "
|
||||
f"(found {len(log_matches)}) in:\n {log_holder.text}"
|
||||
)
|
||||
if match_name != "attn_quant_fusion":
|
||||
assert len(log_matches) == n_expected, (
|
||||
f"Could not find {n_expected} {match_name} "
|
||||
f"(found {len(log_matches)}) in:\n {log_holder.text}"
|
||||
)
|
||||
|
||||
expected_matches = getattr(matches, match_name)
|
||||
|
||||
@@ -215,6 +229,13 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
|
||||
f"{tp_size * (num_ranges_activated - 1)} large-range "
|
||||
f"entries (SP took precedence), found: {log_matches}"
|
||||
)
|
||||
|
||||
elif match_name == "attn_quant_fusion":
|
||||
actual_match = match_table.get(match_name, 0)
|
||||
assert actual_match == expected_matches * n_expected, (
|
||||
f"Could not find {expected_matches * n_expected} "
|
||||
f"{match_name} (found {actual_match})."
|
||||
)
|
||||
else:
|
||||
expected_matches_list = [expected_matches] * n_expected
|
||||
assert sorted(log_matches) == expected_matches_list, (
|
||||
|
||||
Reference in New Issue
Block a user