# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Shared utilities for fusion tests (e.g. test_fusion_attn.py)."""

from __future__ import annotations

import itertools
from collections.abc import Iterable
from typing import Any, NamedTuple

from tests.v1.attention.utils import AttentionBackendEnum
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CUDAGraphMode
from vllm.platforms import current_platform

is_blackwell = lambda: current_platform.is_device_capability_family(100)
"""Are we running on Blackwell? A lot of tests depend on it."""


def has_cuda_graph_wrapper_metadata() -> bool:
    from importlib import import_module

    try:
        module = import_module("torch._inductor.utils")
        module.CUDAGraphWrapperMetadata  # noqa: B018
    except AttributeError:
        return False
    return True


class Matches(NamedTuple):
    attention_fusion: int = 0
    allreduce_fusion: int = 0
    sequence_parallel: int = 0
    async_tp: int = 0
    rms_quant_norm_fusion: int = 0


class ModelBackendTestCase(NamedTuple):
    model_name: str
    model_kwargs: dict[str, Any]
    backend: AttentionBackendEnum
    matches: Matches


# E2E model test cases
MODELS_FP8: list[ModelBackendTestCase] = []
MODELS_FP4: list[ModelBackendTestCase] = []
MODELS: list[ModelBackendTestCase] = []  # tp-only (unquantized)
MODELS_GROUP_FP8: list[ModelBackendTestCase] = []

if current_platform.is_cuda():
    MODELS_FP8 = [
        ModelBackendTestCase(
            # Use smaller model for L40s in CI
            model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
            model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
            backend=AttentionBackendEnum.TRITON_ATTN,
            matches=Matches(
                attention_fusion=32,
                allreduce_fusion=65,
                sequence_parallel=65,
                async_tp=128,
            ),
        ),
        ModelBackendTestCase(
            model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
            model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
            # TODO FlashInfer attn broken on Hopper with kvcache=fp8:
            # https://github.com/vllm-project/vllm/issues/28568
            backend=AttentionBackendEnum.FLASHINFER
            if is_blackwell()
            else AttentionBackendEnum.TRITON_ATTN,
            matches=Matches(
                attention_fusion=48,
                allreduce_fusion=96,
                sequence_parallel=96,
                async_tp=95,  # mlp is moe, no fusion there
            ),
        ),
    ]

    MODELS_FP4 = [
        ModelBackendTestCase(
            model_name="nvidia/Llama-3.1-8B-Instruct-FP4",
            model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
            backend=AttentionBackendEnum.FLASHINFER,
            matches=Matches(
                attention_fusion=32,
                allreduce_fusion=65,
                sequence_parallel=65,
                async_tp=128,
            ),
        ),
    ]

    # TP only (unquantized models)
    MODELS = [
        ModelBackendTestCase(
            model_name="meta-llama/Llama-3.1-8B-Instruct",
            model_kwargs=dict(max_model_len=1024),
            backend=AttentionBackendEnum.TRITON_ATTN,
            matches=Matches(
                attention_fusion=0,
                allreduce_fusion=65,
                sequence_parallel=65,
                async_tp=128,
            ),
        ),
        ModelBackendTestCase(
            model_name="Qwen/Qwen3-30B-A3B",
            model_kwargs=dict(max_model_len=1024),
            backend=AttentionBackendEnum.TRITON_ATTN,
            matches=Matches(
                attention_fusion=0,
                allreduce_fusion=97,
                sequence_parallel=97,
                async_tp=96,  # MLP is MoE, half the fusions of dense
            ),
        ),
    ]

    MODELS_GROUP_FP8 = [
        ModelBackendTestCase(
            model_name="Qwen/Qwen3-30B-A3B-FP8",
            model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
            backend=AttentionBackendEnum.TRITON_ATTN,
            matches=Matches(
                rms_quant_norm_fusion=48,
            ),
        ),
    ]

elif current_platform.is_rocm():
    MODELS_FP8 = [
        ModelBackendTestCase(
            model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
            model_kwargs=dict(max_model_len=1024),
            backend=AttentionBackendEnum.TRITON_ATTN,
            matches=Matches(attention_fusion=32),
        ),
        ModelBackendTestCase(
            model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
            model_kwargs=dict(max_model_len=1024),
            backend=AttentionBackendEnum.ROCM_ATTN,
            matches=Matches(attention_fusion=32),
        ),
        ModelBackendTestCase(
            model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
            model_kwargs=dict(max_model_len=1024),
            backend=AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
            matches=Matches(attention_fusion=32),
        ),
    ]

# Custom ops toggle lists for parametrization
CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"]
CUSTOM_OPS_RMS_NORM = ["-rms_norm", "+rms_norm"]
CUSTOM_OPS_QUANT_RMS_NORM = ["+quant_fp8,+rms_norm"]


def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
    """Generate all combinations of custom ops for parametrization."""
    for op_list in itertools.product(*custom_ops_lists):
        yield ",".join(op_list)


def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs):
    """Run a model with the given compilation config for E2E fusion tests."""
    compilation_config = (
        compile_config
        if isinstance(compile_config, CompilationConfig)
        else CompilationConfig(mode=compile_config)
    )

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    sampling_params = SamplingParams(temperature=0)

    # Allow override from model_kwargs
    model_kwargs = {"tensor_parallel_size": 1, **model_kwargs}
    model_kwargs = {"disable_custom_all_reduce": True, **model_kwargs}

    # No cudagraphs by default
    if compilation_config.cudagraph_mode is None:
        compilation_config.cudagraph_mode = CUDAGraphMode.NONE

    llm = LLM(
        model=model,
        compilation_config=compilation_config,
        **model_kwargs,
    )
    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

    # Get the compile ranges split points after vllm config post init
    # in order to compute compile ranges correctly
    compilation_config.compile_ranges_split_points = (
        llm.llm_engine.vllm_config.compilation_config.compile_ranges_split_points
    )
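

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not consumed by the helpers above): how a
# fusion test module typically combines the model tables, the custom-op
# toggles, and run_model(). The PassConfig flag names below are assumptions
# about the installed vLLM version; real tests additionally verify the
# `Matches` counts against the fusion-pass logs and select the attention
# backend from the test case (e.g. via the VLLM_ATTENTION_BACKEND env var).
import pytest

from vllm.config import PassConfig


@pytest.mark.parametrize("model_name, model_kwargs, backend, matches", MODELS_FP8)
@pytest.mark.parametrize(
    "custom_ops", custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM)
)
def _example_attn_quant_fusion(model_name, model_kwargs, backend, matches, custom_ops):
    # Hypothetical test body: enable the attention+quant fusion passes (noop
    # elimination is assumed to be required for fusion) and run the model
    # end-to-end with the requested custom ops toggled on/off.
    # `backend` and `matches` are unused here; real tests use them for backend
    # selection and for asserting the expected fusion counts.
    compile_config = CompilationConfig(
        custom_ops=custom_ops.split(","),
        pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True),
    )
    run_model(compile_config, model_name, **model_kwargs)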