# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging

import pytest
import regex as re

from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode

from .common import FUSION_LOG_PATTERNS, AttentionBackendCase, Matches


def run_model(
    compile_config: int | CompilationConfig, model: str, **model_kwargs
) -> None:
    """Run a model with the given compilation config for E2E fusion tests.

    Args:
        compile_config: Either a full ``CompilationConfig`` or a bare
            compilation-mode int, which is wrapped in a fresh config.
        model: Model name/path forwarded to ``LLM``.
        **model_kwargs: Extra ``LLM`` kwargs; ``tensor_parallel_size`` and
            ``disable_custom_all_reduce`` defaults may be overridden here.

    Side effect: mutates ``compile_config`` (when it is a CompilationConfig)
    by defaulting ``cudagraph_mode`` and copying back the engine-resolved
    ``compile_ranges_split_points`` so the caller can compute compile ranges.
    """
    compilation_config = (
        compile_config
        if isinstance(compile_config, CompilationConfig)
        else CompilationConfig(mode=compile_config)
    )
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    # Greedy sampling: deterministic output, suitable for log-based checks.
    sampling_params = SamplingParams(temperature=0)

    # Allow override from model_kwargs
    model_kwargs = {"tensor_parallel_size": 1, **model_kwargs}
    model_kwargs = {"disable_custom_all_reduce": True, **model_kwargs}

    # No cudagraphs by default
    if compilation_config.cudagraph_mode is None:
        compilation_config.cudagraph_mode = CUDAGraphMode.NONE

    llm = LLM(
        model=model,
        compilation_config=compilation_config,
        **model_kwargs,
    )
    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

    # Get the compile ranges split points after vllm config post init
    # in order to compute compile ranges correctly
    compilation_config.compile_ranges_split_points = (
        llm.llm_engine.vllm_config.compilation_config.compile_ranges_split_points
    )


@pytest.fixture
def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
    """Fixture returning a runner that executes a model E2E and verifies,
    via captured subprocess logs, that the expected fusion passes fired
    the expected number of times.
    """

    def run(
        model_name: str,
        matches: Matches,
        model_kwargs: dict,
        attn_backend: AttentionBackendCase,
        compilation_config: dict,
        matches_check: list[str],
        use_deepgemm: bool = False,
        tp_size: int = 1,
    ) -> None:
        """Run ``model_name`` with fusion passes enabled and assert, for each
        pass name in ``matches_check``, that the per-rank/per-range fusion
        counts logged by the compiler match the expectations in ``matches``.
        """
        monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "1" if use_deepgemm else "0")
        # Disable compile cache to make sure custom passes run.
        # Otherwise, we can't verify fusion happened through the logs.
        monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

        # To capture subprocess logs, we need to know whether spawn or fork is used.
        # Force spawn as it is more general.
        monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")

        # Backend-specific kwargs first so the caller's kwargs take precedence.
        model_kwargs = {**attn_backend.model_kwargs, **model_kwargs}
        model_kwargs["attention_config"] = {"backend": attn_backend.backend.name}
        model_kwargs["tensor_parallel_size"] = tp_size

        # Always compile the full graph instead of piecewise
        if not compilation_config["use_inductor_graph_partition"]:
            compilation_config["splitting_ops"] = []

        full_compilation_config = CompilationConfig(
            cudagraph_mode=CUDAGraphMode.NONE,
            mode=CompilationMode.VLLM_COMPILE,
            # Force-disable inductor caches for the same reason as the
            # VLLM_DISABLE_COMPILE_CACHE env var above.
            inductor_compile_config={"force_disable_caches": True},
            **compilation_config,
        )

        with caplog_mp_spawn(logging.DEBUG) as log_holder:
            run_model(full_compilation_config, model_name, **model_kwargs)

        # run_model copied the engine-resolved split points back into the
        # config, so get_compile_ranges() is now meaningful.
        num_compile_ranges = len(full_compilation_config.get_compile_ranges())
        assert num_compile_ranges in [1, 2, 3]
        print(f"Compile ranges: {full_compilation_config.get_compile_ranges()}")

        print("Fusion results:")
        # Iterate through all so printing happens before asserting
        log_matches_dict = {}
        for match_name, pattern in FUSION_LOG_PATTERNS.items():
            log_matches_dict[match_name] = list(pattern.findall(log_holder.text))
            print(f"- {match_name}={','.join(log_matches_dict[match_name])}")

        # Now check the matches
        for match_name in matches_check:
            # Fusion counts logged by each rank, one entry per activated range.
            log_matches = list(int(ms) for ms in log_matches_dict[match_name])

            # AR+RMS skips the largest range; SP skips the smallest.
            # When both are enabled, AR+RMS activation count is
            # model-dependent (hidden_size affects threshold), so derive
            # from log data.
            if (
                match_name == "ar_rms_fusion"
                and "sequence_parallel" in matches_check
                and num_compile_ranges >= 2
            ):
                # Entry count must be a whole number of ranges per rank.
                assert (
                    len(log_matches) >= tp_size and len(log_matches) % tp_size == 0
                ), (
                    f"Expected multiple of {tp_size} ar_rms log entries, "
                    f"found {len(log_matches)}"
                )
                num_ranges_activated = len(log_matches) // tp_size
            elif (
                match_name in ("ar_rms_fusion", "sequence_parallel")
                and num_compile_ranges >= 2
            ):
                # Each of these passes skips exactly one range (see above).
                num_ranges_activated = num_compile_ranges - 1
            else:
                num_ranges_activated = num_compile_ranges

            # Expect one log entry per rank per activated range.
            n_expected = tp_size * num_ranges_activated
            assert len(log_matches) == n_expected, (
                f"Could not find {n_expected} {match_name} "
                f"(found {len(log_matches)}) in:\n {log_holder.text}"
            )

            expected_matches = getattr(matches, match_name)
            if match_name == "rms_quant_fusion" and "ar_rms_fusion" in matches_check:
                # AR+rms+quant takes precedence over rms+quant if activated.
                # That means we get full matching where ar+rms+quant was not
                # activated, and less where it was (only the smallest range).
                assert sum(m == expected_matches for m in log_matches) == tp_size * (
                    num_ranges_activated - 1
                ), "Expecting full rms+quant fusion where ar+rms+quant not activated"
                assert all(
                    expected_matches - matches.ar_rms_fusion <= m <= expected_matches
                    for m in log_matches
                ), (
                    f"Expecting at least {expected_matches - matches.ar_rms_fusion} "
                    f"where ar+rms+quant was activated"
                )
            elif (
                match_name == "async_tp"
                and "sequence_parallel" in matches_check
                and num_compile_ranges >= 2
            ):
                # AsyncTP only finds patterns on ranges where SP ran.
                n_sp_ranges = num_compile_ranges - 1
                assert (
                    sum(m == expected_matches for m in log_matches)
                    == tp_size * n_sp_ranges
                ), (
                    f"Expecting {expected_matches} async_tp on "
                    f"{tp_size * n_sp_ranges} SP-range entries, "
                    f"found: {log_matches}"
                )
                assert sum(m == 0 for m in log_matches) == tp_size, (
                    f"Expecting 0 async_tp on {tp_size} small-range entries "
                    f"(no SP), found: {log_matches}"
                )
            elif (
                match_name == "ar_rms_fusion"
                and "sequence_parallel" in matches_check
                and num_compile_ranges >= 2
            ):
                # SP consumes allreduce patterns first, so AR+RMS finds
                # full matches only on the smallest range (no SP).
                assert sum(m == expected_matches for m in log_matches) == tp_size, (
                    f"Expecting {expected_matches} ar_rms on "
                    f"{tp_size} small-range entries, found: {log_matches}"
                )
                assert sum(m == 0 for m in log_matches) == tp_size * (
                    num_ranges_activated - 1
                ), (
                    f"Expecting 0 ar_rms on "
                    f"{tp_size * (num_ranges_activated - 1)} large-range "
                    f"entries (SP took precedence), found: {log_matches}"
                )
            else:
                # Every activated range on every rank must report the exact
                # expected fusion count.
                expected_matches_list = [expected_matches] * n_expected
                assert sorted(log_matches) == expected_matches_list, (
                    f"{match_name} expected: {expected_matches_list}, "
                    f"found: {sorted(log_matches)}"
                )

            # Cross-check: ranges NOT activated must have logged an explicit
            # "Skipping <Pass>" message, once per rank per skipped range.
            if match_name == "ar_rms_fusion" and num_compile_ranges >= 2:
                log_matches = re.findall(
                    r"pass_manager.py:\d+] Skipping "
                    r".*AllReduceFusionPass.* with compile range",
                    log_holder.text,
                )
                n_expected = tp_size * (num_compile_ranges - num_ranges_activated)
                assert len(log_matches) == n_expected, (
                    f'Could not find {n_expected} "Skipping AllReduceFusionPass" '
                    f"(found {len(log_matches)}) in:\n {log_holder.text}"
                )
            if match_name == "sequence_parallel" and num_compile_ranges >= 2:
                log_matches = re.findall(
                    r"pass_manager.py:\d+] Skipping "
                    r".*SequenceParallelismPass.* with compile range",
                    log_holder.text,
                )
                n_expected = tp_size * (num_compile_ranges - num_ranges_activated)
                assert len(log_matches) == n_expected, (
                    f'Could not find {n_expected} "Skipping SequenceParallelismPass" '
                    f"(found {len(log_matches)}) in:\n {log_holder.text}"
                )

    return run