Signed-off-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: ProExpertProg <luka.govedic@gmail.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
159 lines
6.2 KiB
Python
159 lines
6.2 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
import logging
|
|
|
|
import pytest
|
|
import regex as re
|
|
|
|
from vllm import LLM, SamplingParams
|
|
from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode
|
|
|
|
from .common import FUSION_LOG_PATTERNS, AttentionBackendCase, Matches
|
|
|
|
|
|
def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs):
    """Run a model with the given compilation config for E2E fusion tests.

    Builds an ``LLM`` and greedily generates completions for a small fixed set
    of prompts, printing each prompt/completion pair.

    Args:
        compile_config: either a ready ``CompilationConfig`` (used as-is) or a
            value passed as ``mode=`` to construct a new ``CompilationConfig``.
        model: model identifier forwarded to ``LLM(model=...)``.
        **model_kwargs: extra ``LLM`` keyword arguments. They take precedence
            over the defaults ``tensor_parallel_size=1`` and
            ``disable_custom_all_reduce=True``.

    Side effects:
        The config object handed to ``LLM`` is mutated: ``cudagraph_mode`` is
        set to ``CUDAGraphMode.NONE`` if it was ``None``, and
        ``compile_ranges_split_points`` is copied back from the engine's
        config after it has been constructed. When ``compile_config`` is a
        ``CompilationConfig``, the caller sees both updates.
    """
    if isinstance(compile_config, CompilationConfig):
        compilation_config = compile_config
    else:
        compilation_config = CompilationConfig(mode=compile_config)

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    # temperature=0 -> greedy decoding.
    sampling_params = SamplingParams(temperature=0)

    # Defaults are listed first so anything passed in model_kwargs overrides them.
    llm_kwargs = {
        "disable_custom_all_reduce": True,
        "tensor_parallel_size": 1,
        **model_kwargs,
    }

    # Leave cudagraphs off unless the caller chose a mode explicitly.
    if compilation_config.cudagraph_mode is None:
        compilation_config.cudagraph_mode = CUDAGraphMode.NONE

    llm = LLM(
        model=model,
        compilation_config=compilation_config,
        **llm_kwargs,
    )

    # Show what the model produced for each prompt.
    for request_output in llm.generate(prompts, sampling_params):
        completion = request_output.outputs[0].text
        print(f"Prompt: {request_output.prompt!r}, Generated text: {completion!r}")

    # The engine finalizes compile_ranges_split_points while building its
    # config. Copy it back so the caller's get_compile_ranges() matches what
    # the engine actually used.
    engine_compilation_config = llm.llm_engine.vllm_config.compilation_config
    compilation_config.compile_ranges_split_points = (
        engine_compilation_config.compile_ranges_split_points
    )
|
|
|
|
|
|
@pytest.fixture
def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
    """Fixture returning a runner for end-to-end fusion tests.

    The returned ``run`` callable does the following:

    - sets the environment variables the test needs;
    - compiles and runs a model via ``run_model`` while capturing DEBUG logs
      from spawned worker processes;
    - parses the per-pass match counts with ``FUSION_LOG_PATTERNS`` and
      asserts them against the expected ``Matches``.
    """

    def run(
        model_name: str,
        matches: Matches,
        model_kwargs: dict,
        attn_backend: AttentionBackendCase,
        compilation_config: dict,
        matches_check: list[str],
        use_deepgemm: bool = False,
        tp_size: int = 1,
    ):
        """Run ``model_name`` with fusion passes enabled and check match counts.

        Args:
            model_name: model passed to ``run_model``.
            matches: expected match counts; for each name in ``matches_check``
                the attribute of that name is read.
            model_kwargs: extra ``LLM`` kwargs, merged over
                ``attn_backend.model_kwargs``. ``attention_config`` and
                ``tensor_parallel_size`` are always overwritten.
            attn_backend: attention backend case; its ``backend.name`` becomes
                the ``attention_config`` backend.
            compilation_config: keyword arguments for ``CompilationConfig``.
                Must contain ``use_inductor_graph_partition``. The caller's
                dict is not modified.
            matches_check: names (keys of ``FUSION_LOG_PATTERNS`` and
                attributes of ``matches``) whose counts are asserted.
            use_deepgemm: controls the ``VLLM_USE_DEEP_GEMM`` env var.
            tp_size: tensor-parallel size. Expected log-line counts are
                multiplied by it.
        """
        monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "1" if use_deepgemm else "0")

        # Disable the compile cache to make sure custom passes run.
        # Otherwise, we can't verify through the logs that fusion happened.
        monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

        # To capture subprocess logs, we need to know whether spawn or fork
        # is used. Force spawn as it is more general.
        monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")

        model_kwargs = {**attn_backend.model_kwargs, **model_kwargs}
        model_kwargs["attention_config"] = {"backend": attn_backend.backend.name}
        model_kwargs["tensor_parallel_size"] = tp_size

        # Work on a shallow copy: the caller's dict may be shared (e.g. across
        # parametrized cases) and must not pick up the splitting_ops override.
        compilation_config = dict(compilation_config)

        # Always compile the full graph instead of piecewise
        if not compilation_config["use_inductor_graph_partition"]:
            compilation_config["splitting_ops"] = []

        full_compilation_config = CompilationConfig(
            cudagraph_mode=CUDAGraphMode.NONE,
            mode=CompilationMode.VLLM_COMPILE,
            inductor_compile_config={"force_disable_caches": True},
            **compilation_config,
        )

        with caplog_mp_spawn(logging.DEBUG) as log_holder:
            run_model(full_compilation_config, model_name, **model_kwargs)

        # run_model copies the engine's split points back into
        # full_compilation_config, so these ranges reflect the actual run.
        compile_ranges = full_compilation_config.get_compile_ranges()
        num_compile_ranges = len(compile_ranges)
        assert num_compile_ranges in [1, 2]

        print(f"Compile ranges: {compile_ranges}")
        print("Fusion results:")

        # Iterate through all so printing happens before asserting
        log_matches_dict = {}
        for match_name, pattern in FUSION_LOG_PATTERNS.items():
            log_matches_dict[match_name] = list(pattern.findall(log_holder.text))
            print(f"- {match_name}={','.join(log_matches_dict[match_name])}")

        # Now check the matches
        for match_name in matches_check:
            # ar_rms_fusion is expected to report matches for only one compile
            # range; a "Skipping" log line is expected for each other range
            # (checked at the end of this loop). Every other pass reports
            # once per range.
            num_ranges_activated = (
                1 if match_name == "ar_rms_fusion" else num_compile_ranges
            )
            n_expected = tp_size * num_ranges_activated

            log_matches = [int(ms) for ms in log_matches_dict[match_name]]
            assert len(log_matches) == n_expected, (
                f"Could not find {n_expected} {match_name} "
                f"(found {len(log_matches)}) in:\n {log_holder.text}"
            )

            expected_matches = getattr(matches, match_name)

            if match_name == "rms_quant_fusion" and "ar_rms_fusion" in matches_check:
                # AR+rms+quant takes precedence over rms+quant if activated.
                # That means we get full matching where ar+rms+quant was not
                # activated, and less where it was.
                assert sum(m == expected_matches for m in log_matches) == tp_size * (
                    num_ranges_activated - 1
                ), "Expecting full rms+quant fusion where ar+rms+quant not activated"

                assert all(
                    expected_matches - matches.ar_rms_fusion <= m <= expected_matches
                    for m in log_matches
                ), (
                    f"Expecting at least {expected_matches - matches.ar_rms_fusion} "
                    "where ar+rms+quant was activated"
                )
            else:
                expected_matches_list = [expected_matches] * n_expected
                assert sorted(log_matches) == expected_matches_list, (
                    f"{match_name} expected: {expected_matches_list}, "
                    f"found: {sorted(log_matches)}"
                )

            if match_name == "ar_rms_fusion":
                # Literal '.' and ']' are escaped so the file-name prefix
                # matches exactly.
                skip_logs = re.findall(
                    r"pass_manager\.py:\d+\] Skipping "
                    r".*AllReduceFusionPass.* with compile range",
                    log_holder.text,
                )

                n_expected = tp_size * (num_compile_ranges - num_ranges_activated)
                assert len(skip_logs) == n_expected, (
                    f'Could not find {n_expected} "Skipping AllReduceFusionPass" '
                    f"(found {len(skip_logs)}) in:\n {log_holder.text}"
                )

    return run
|