diff --git a/benchmarks/attention_benchmarks/benchmark.py b/benchmarks/attention_benchmarks/benchmark.py index ba11fca74..de56cbac8 100644 --- a/benchmarks/attention_benchmarks/benchmark.py +++ b/benchmarks/attention_benchmarks/benchmark.py @@ -43,6 +43,7 @@ from common import ( ModelParameterSweep, ParameterSweep, ResultsFormatter, + batch_spec_sort_key, is_mla_backend, ) @@ -218,10 +219,13 @@ def run_model_parameter_sweep( by_param_and_spec[key].append(r) break - # Sort by param value then spec + # Sort by param value then spec (batch_size, q_len, kv_len) sorted_keys = sorted( by_param_and_spec.keys(), - key=lambda x: (int(x[0]) if x[0].isdigit() else x[0], x[1]), + key=lambda x: ( + int(x[0]) if x[0].isdigit() else x[0], + batch_spec_sort_key(x[1]), + ), ) current_param_value = None @@ -330,7 +334,7 @@ def run_parameter_sweep( by_spec[spec] = [] by_spec[spec].append(r) - for spec in sorted(by_spec.keys()): + for spec in sorted(by_spec.keys(), key=batch_spec_sort_key): results = by_spec[spec] best = min(results, key=lambda r: r.mean_time) console.print( @@ -496,15 +500,18 @@ def main(): if "description" in yaml_config: console.print(f"[dim]{yaml_config['description']}[/]") - # Override args with YAML values - # (YAML takes precedence unless CLI arg was explicitly set) - # Backend(s) - if "backend" in yaml_config: - args.backend = yaml_config["backend"] - args.backends = None - elif "backends" in yaml_config: - args.backends = yaml_config["backends"] - args.backend = None + # Override args with YAML values, but CLI args take precedence + # Check if CLI provided backends (they would be non-None and not default) + cli_backends_provided = args.backends is not None or args.backend is not None + + # Backend(s) - only use YAML if CLI didn't specify + if not cli_backends_provided: + if "backend" in yaml_config: + args.backend = yaml_config["backend"] + args.backends = None + elif "backends" in yaml_config: + args.backends = yaml_config["backends"] + args.backend = None # 
Check for special modes if "mode" in yaml_config: @@ -544,13 +551,15 @@ def main(): args.num_kv_heads = model.get("num_kv_heads", args.num_kv_heads) args.block_size = model.get("block_size", args.block_size) - # Benchmark settings - if "benchmark" in yaml_config: - bench = yaml_config["benchmark"] - args.device = bench.get("device", args.device) - args.repeats = bench.get("repeats", args.repeats) - args.warmup_iters = bench.get("warmup_iters", args.warmup_iters) - args.profile_memory = bench.get("profile_memory", args.profile_memory) + # Benchmark settings (top-level keys) + if "device" in yaml_config: + args.device = yaml_config["device"] + if "repeats" in yaml_config: + args.repeats = yaml_config["repeats"] + if "warmup_iters" in yaml_config: + args.warmup_iters = yaml_config["warmup_iters"] + if "profile_memory" in yaml_config: + args.profile_memory = yaml_config["profile_memory"] # Parameter sweep configuration if "parameter_sweep" in yaml_config: diff --git a/benchmarks/attention_benchmarks/common.py b/benchmarks/attention_benchmarks/common.py index 190b2f977..1de8bb0a5 100644 --- a/benchmarks/attention_benchmarks/common.py +++ b/benchmarks/attention_benchmarks/common.py @@ -16,13 +16,32 @@ from batch_spec import get_batch_type, parse_batch_spec from rich.console import Console from rich.table import Table + +def batch_spec_sort_key(spec: str) -> tuple[int, int, int]: + """ + Extract sorting key from batch spec: (batch_size, max_q_len, max_kv_len). + + This ensures results are sorted by batch size first, then query length, + then sequence length, rather than alphabetically. 
+ """ + try: + requests = parse_batch_spec(spec) + batch_size = len(requests) + max_q_len = max(r.q_len for r in requests) if requests else 0 + max_kv_len = max(r.kv_len for r in requests) if requests else 0 + return (batch_size, max_q_len, max_kv_len) + except Exception: + # Fallback for unparseable specs + return (0, 0, 0) + + # Mock classes for vLLM attention infrastructure class MockHfConfig: """Mock HuggingFace config that satisfies vLLM's requirements.""" - def __init__(self, mla_dims: dict): + def __init__(self, mla_dims: dict, index_topk: int | None = None): self.num_attention_heads = mla_dims["num_q_heads"] self.num_key_value_heads = mla_dims["num_kv_heads"] self.hidden_size = mla_dims["head_dim"] * mla_dims["num_q_heads"] @@ -33,6 +52,8 @@ class MockHfConfig: self.qk_rope_head_dim = mla_dims["qk_rope_head_dim"] self.v_head_dim = mla_dims["v_head_dim"] self.qk_head_dim = mla_dims["qk_nope_head_dim"] + mla_dims["qk_rope_head_dim"] + if index_topk is not None: + self.index_topk = index_topk def get_text_config(self): return self @@ -83,6 +104,38 @@ class MockKVBProj: return (result,) # Return as tuple to match ColumnParallelLinear API +class MockIndexer: + """Mock Indexer for sparse MLA backends. + + Provides topk_indices_buffer that sparse MLA backends use to determine + which KV cache slots to attend to for each token. 
+ """ + + def __init__( + self, + max_num_tokens: int, + topk_tokens: int, + device: torch.device, + ): + self.topk_tokens = topk_tokens + self.topk_indices_buffer = torch.zeros( + (max_num_tokens, topk_tokens), + dtype=torch.int32, + device=device, + ) + + def fill_random_indices(self, num_tokens: int, max_kv_len: int): + """Fill topk_indices_buffer with random valid indices for benchmarking.""" + indices = torch.randint( + 0, + max_kv_len, + (num_tokens, self.topk_tokens), + dtype=torch.int32, + device=self.topk_indices_buffer.device, + ) + self.topk_indices_buffer[:num_tokens] = indices + + class MockLayer(AttentionLayerBase): """Mock attention layer with scale parameters and impl. @@ -327,6 +380,9 @@ class ResultsFormatter: specs_order.append(spec) by_spec[spec][r.config.backend] = r + # Sort specs by (batch_size, q_len, kv_len) instead of alphabetically + specs_order = sorted(by_spec.keys(), key=batch_spec_sort_key) + # Create shortened backend names for display def shorten_backend_name(name: str) -> str: """Shorten long backend names for table display.""" @@ -493,10 +549,11 @@ def get_attention_scale(head_dim: int) -> float: def is_mla_backend(backend: str) -> bool: """ - Check if backend is an MLA backend using the backend's is_mla() property. + Check if backend is an MLA backend using the AttentionBackendEnum. 
Args: - backend: Backend name (e.g., "CUTLASS_MLA", "FLASHINFER_MLA") + backend: Backend name matching AttentionBackendEnum exactly + (e.g., "FLASHMLA_SPARSE") Returns: True if the backend is an MLA backend, False otherwise @@ -504,7 +561,8 @@ def is_mla_backend(backend: str) -> bool: from vllm.v1.attention.backends.registry import AttentionBackendEnum try: - backend_class = AttentionBackendEnum[backend.upper()].get_class() + backend_enum = AttentionBackendEnum[backend] + backend_class = backend_enum.get_class() return backend_class.is_mla() - except (KeyError, ValueError, ImportError): + except (KeyError, ValueError, ImportError, AttributeError): return False diff --git a/benchmarks/attention_benchmarks/configs/mla_decode.yaml b/benchmarks/attention_benchmarks/configs/mla_decode.yaml index aaf4eec9b..d758654db 100644 --- a/benchmarks/attention_benchmarks/configs/mla_decode.yaml +++ b/benchmarks/attention_benchmarks/configs/mla_decode.yaml @@ -3,7 +3,7 @@ model: name: "deepseek-v3" num_layers: 60 - num_q_heads: 128 + num_q_heads: 128 # Base value, can be swept for TP simulation num_kv_heads: 1 # MLA uses single latent KV head_dim: 576 kv_lora_rank: 512 @@ -12,6 +12,13 @@ model: v_head_dim: 128 block_size: 128 # CUTLASS MLA and FlashAttn MLA use 128 +# Model parameter sweep: simulate tensor parallelism by varying num_q_heads +# TP=1: 128 heads, TP=2: 64 heads, TP=4: 32 heads, TP=8: 16 heads +model_parameter_sweep: + param_name: "num_q_heads" + values: [128, 64, 32, 16] + label_format: "{backend}_{value}h" + batch_specs: # Small batches, varying sequence lengths - "16q1s512" # 16 requests, 512 KV cache @@ -34,28 +41,30 @@ batch_specs: # Very large batches - "128q1s1k" # 128 requests, 1k KV cache - "128q1s2k" # 128 requests, 2k KV cache + - "128q1s4k" # 128 requests, 4k KV cache + - "128q1s8k" # 128 requests, 8k KV cache # Long context - "32q1s16k" # 32 requests, 16k KV cache - "32q1s32k" # 32 requests, 32k KV cache backends: - - cutlass_mla - - flashinfer_mla - - 
flashattn_mla # Hopper only - - flashmla # Hopper only + - CUTLASS_MLA + - FLASHINFER_MLA + - FLASH_ATTN_MLA # Hopper only + - FLASHMLA # Hopper only device: "cuda:0" -repeats: 5 -warmup_iters: 3 +repeats: 100 +warmup_iters: 10 profile_memory: true # Backend-specific tuning -cutlass_mla: +CUTLASS_MLA: num_kv_splits: auto # or specific value like 4, 8, 16 -flashattn_mla: +FLASH_ATTN_MLA: reorder_batch_threshold: 512 -flashmla: +FLASHMLA: reorder_batch_threshold: 1 diff --git a/benchmarks/attention_benchmarks/configs/mla_mixed_batch.yaml b/benchmarks/attention_benchmarks/configs/mla_mixed_batch.yaml index ad3c0dced..b555d90cb 100644 --- a/benchmarks/attention_benchmarks/configs/mla_mixed_batch.yaml +++ b/benchmarks/attention_benchmarks/configs/mla_mixed_batch.yaml @@ -45,10 +45,10 @@ batch_specs: - "4q4k_60q1s4k" # 4 prefill + 60 decode backends: - - cutlass_mla - - flashinfer_mla - - flashattn_mla # Hopper only - - flashmla # Hopper only + - CUTLASS_MLA + - FLASHINFER_MLA + - FLASH_ATTN_MLA # Hopper only + - FLASHMLA # Hopper only device: "cuda:0" repeats: 5 diff --git a/benchmarks/attention_benchmarks/configs/mla_prefill.yaml b/benchmarks/attention_benchmarks/configs/mla_prefill.yaml new file mode 100644 index 000000000..ef6b2cb07 --- /dev/null +++ b/benchmarks/attention_benchmarks/configs/mla_prefill.yaml @@ -0,0 +1,62 @@ +# MLA prefill-only benchmark configuration for sparse backends + +model: + name: "deepseek-v3" + num_layers: 60 + num_q_heads: 128 + num_kv_heads: 1 + head_dim: 576 + kv_lora_rank: 512 + qk_nope_head_dim: 128 + qk_rope_head_dim: 64 + v_head_dim: 128 + block_size: 128 + +# Model parameter sweep: simulate tensor parallelism by varying num_q_heads +# TP=1: 128 heads, TP=2: 64 heads, TP=4: 32 heads, TP=8: 16 heads +model_parameter_sweep: + param_name: "num_q_heads" + values: [128, 64, 32, 16] + label_format: "{backend}_{value}h" + +batch_specs: + # Pure prefill + - "1q512" + - "1q1k" + - "1q2k" + - "1q4k" + - "1q8k" + + # Batched pure prefill + - 
"2q512" + - "2q1k" + - "2q2k" + - "2q4k" + - "2q8k" + - "4q512" + - "4q1k" + - "4q2k" + - "4q4k" + - "4q8k" + - "8q512" + - "8q1k" + - "8q2k" + - "8q4k" + - "8q8k" + + # Extend + - "1q512s4k" + - "1q512s8k" + - "1q1ks8k" + - "1q2ks8k" + - "1q2ks16k" + - "1q4ks16k" + +backends: + - FLASHMLA_SPARSE + - FLASHINFER_MLA_SPARSE + +device: "cuda:0" +repeats: 10 +warmup_iters: 3 +profile_memory: true diff --git a/benchmarks/attention_benchmarks/configs/reorder_threshold.yaml b/benchmarks/attention_benchmarks/configs/reorder_threshold.yaml index 1ea0a12b5..0d76ef0a3 100644 --- a/benchmarks/attention_benchmarks/configs/reorder_threshold.yaml +++ b/benchmarks/attention_benchmarks/configs/reorder_threshold.yaml @@ -6,7 +6,7 @@ description: "Decode vs Prefill pipeline crossover analysis" # Test FlashAttn MLA -backend: flashattn_mla +backend: FLASH_ATTN_MLA # Mode: decode_vs_prefill comparison (special sweep mode) # For each batch spec, we'll test both decode and prefill pipelines @@ -62,11 +62,10 @@ model: block_size: 128 # Benchmark settings -benchmark: - device: "cuda:0" - repeats: 15 # More repeats for spec decode variance - warmup_iters: 5 - profile_memory: false +device: "cuda:0" +repeats: 15 # More repeats for spec decode variance +warmup_iters: 5 +profile_memory: false # Output output: diff --git a/benchmarks/attention_benchmarks/configs/speculative_decode.yaml b/benchmarks/attention_benchmarks/configs/speculative_decode.yaml index 56d2428fe..47b6d3604 100644 --- a/benchmarks/attention_benchmarks/configs/speculative_decode.yaml +++ b/benchmarks/attention_benchmarks/configs/speculative_decode.yaml @@ -41,18 +41,17 @@ batch_specs: # Backends that support query length > 1 backends: - - flashattn_mla # reorder_batch_threshold = 512 - - flashmla # reorder_batch_threshold = 1 (tunable) + - FLASH_ATTN_MLA # reorder_batch_threshold = 512 + - FLASHMLA # reorder_batch_threshold = 1 (tunable) # FlashInfer-MLA also supports uniform spec-as-decode but with different mechanism -# - 
flashinfer_mla +# - FLASHINFER_MLA # Benchmark settings -benchmark: - device: "cuda:0" - repeats: 10 # More repeats for statistical significance - warmup_iters: 5 - profile_memory: false +device: "cuda:0" +repeats: 10 # More repeats for statistical significance +warmup_iters: 5 +profile_memory: false # Test these threshold values for optimization parameter_sweep: diff --git a/benchmarks/attention_benchmarks/configs/standard_attention.yaml b/benchmarks/attention_benchmarks/configs/standard_attention.yaml index 591db6837..deb5a4b27 100644 --- a/benchmarks/attention_benchmarks/configs/standard_attention.yaml +++ b/benchmarks/attention_benchmarks/configs/standard_attention.yaml @@ -36,11 +36,11 @@ batch_specs: - "q1ks2k" # 1k query, 2k sequence - "2q1ks4k" # 2 requests: 1k query, 4k sequence -# Available backends: flash, triton, flashinfer +# Available backends: FLASH_ATTN, TRITON_ATTN, FLASHINFER backends: - - flash - - triton - - flashinfer + - FLASH_ATTN + - TRITON_ATTN + - FLASHINFER device: "cuda:0" repeats: 5 diff --git a/benchmarks/attention_benchmarks/mla_runner.py b/benchmarks/attention_benchmarks/mla_runner.py index 2c6c3aaac..ffcfa4572 100644 --- a/benchmarks/attention_benchmarks/mla_runner.py +++ b/benchmarks/attention_benchmarks/mla_runner.py @@ -8,14 +8,13 @@ This module provides helpers for running MLA backends without needing full VllmConfig integration. """ -import importlib - import numpy as np import torch from batch_spec import parse_batch_spec from common import ( BenchmarkResult, MockHfConfig, + MockIndexer, MockKVBProj, MockLayer, setup_mla_dims, @@ -62,6 +61,7 @@ def create_minimal_vllm_config( block_size: int = 128, max_num_seqs: int = 256, mla_dims: dict | None = None, + index_topk: int | None = None, ) -> VllmConfig: """ Create minimal VllmConfig for MLA benchmarks. @@ -73,6 +73,8 @@ def create_minimal_vllm_config( max_num_seqs: Maximum number of sequences mla_dims: Optional custom MLA dimensions dict. 
If not provided, uses setup_mla_dims(model_name) + index_topk: Optional topk value for sparse MLA backends. If provided, + the config will include index_topk for sparse attention. Returns: VllmConfig for benchmarking @@ -82,7 +84,7 @@ def create_minimal_vllm_config( mla_dims = setup_mla_dims(model_name) # Create mock HF config first (avoids downloading from HuggingFace) - mock_hf_config = MockHfConfig(mla_dims) + mock_hf_config = MockHfConfig(mla_dims, index_topk=index_topk) # Create a temporary minimal config.json to avoid HF downloads # This ensures consistent ModelConfig construction without network access @@ -120,16 +122,12 @@ def create_minimal_vllm_config( seed=0, max_model_len=32768, quantization=None, - quantization_param_path=None, enforce_eager=False, - max_context_len_to_capture=None, - max_seq_len_to_capture=8192, max_logprobs=20, disable_sliding_window=False, skip_tokenizer_init=True, served_model_name=None, limit_mm_per_prompt=None, - use_async_output_proc=True, config_format="auto", ) finally: @@ -180,56 +178,65 @@ def create_minimal_vllm_config( # ============================================================================ -# Backend name to class name prefix mapping -_BACKEND_NAME_MAP = { - "flashattn_mla": "FlashAttnMLA", - "flashmla": "FlashMLA", - "flashinfer_mla": "FlashInferMLA", - "cutlass_mla": "CutlassMLA", -} - -# Special properties that differ from defaults +# Backend-specific properties that can't be inferred from the backend class +# Keys are AttentionBackendEnum names (uppercase) _BACKEND_PROPERTIES = { - "flashmla": { + "FLASHMLA": { "query_format": "concat", # Single concatenated tensor (vs tuple) - "block_size": 64, # FlashMLA uses fixed block size }, - "flashinfer_mla": { - "block_size": 64, # FlashInfer MLA only supports 32 or 64 + "FLASHMLA_SPARSE": { + "query_format": "concat", # Single concatenated tensor (vs tuple) }, } def _get_backend_config(backend: str) -> dict: """ - Get backend configuration using naming conventions. 
+ Get backend configuration from AttentionBackendEnum. - All MLA backends follow the pattern: - - Module: vllm.v1.attention.backends.mla.{backend} - - Impl: {Name}Impl - - Metadata: {Name}Metadata (or MLACommonMetadata) - - DecodeMetadata: {Name}DecodeMetadata (or MLACommonDecodeMetadata) - - MetadataBuilder: {Name}MetadataBuilder + Uses the registry to get the backend class and extract configuration + from its methods (get_impl_cls, get_builder_cls, is_sparse, etc.). + + Args: + backend: Backend name matching AttentionBackendEnum exactly + (e.g., "FLASHMLA_SPARSE") + + Returns: + Dict with backend configuration """ - if backend not in _BACKEND_NAME_MAP: - raise ValueError(f"Unknown backend: {backend}") + from vllm.v1.attention.backends.registry import AttentionBackendEnum - name = _BACKEND_NAME_MAP[backend] + try: + backend_enum = AttentionBackendEnum[backend] + backend_class = backend_enum.get_class() + except (KeyError, ValueError) as e: + valid_backends = [e.name for e in AttentionBackendEnum if e.name != "CUSTOM"] + raise ValueError( + f"Unknown backend: {backend}. 
" + f"Valid MLA backends: {[b for b in valid_backends if 'MLA' in b]}" + ) from e + + # Get block size from backend class + block_sizes = backend_class.get_supported_kernel_block_sizes() + # Use first supported block size (backends typically support one for MLA) + block_size = block_sizes[0] if block_sizes else None + if hasattr(block_size, "value"): + # Handle MultipleOf enum + block_size = None + + # Check if sparse via class method if available + is_sparse = getattr(backend_class, "is_sparse", lambda: False)() + + # Get properties that can't be inferred props = _BACKEND_PROPERTIES.get(backend, {}) - # Check if backend uses common metadata (FlashInfer, CUTLASS) - uses_common = backend in ("flashinfer_mla", "cutlass_mla") - return { - "module": f"vllm.v1.attention.backends.mla.{backend}", - "impl_class": f"{name}Impl", - "metadata_class": "MLACommonMetadata" if uses_common else f"{name}Metadata", - "decode_metadata_class": "MLACommonDecodeMetadata" - if uses_common - else f"{name}DecodeMetadata", - "builder_class": f"{name}MetadataBuilder", + "backend_class": backend_class, + "impl_class": backend_class.get_impl_cls(), + "builder_class": backend_class.get_builder_cls(), "query_format": props.get("query_format", "tuple"), - "block_size": props.get("block_size", None), + "block_size": block_size, + "is_sparse": is_sparse, } @@ -447,22 +454,26 @@ def _create_backend_impl( mla_dims: dict, vllm_config: VllmConfig, device: torch.device, + max_num_tokens: int = 8192, + index_topk: int | None = None, ): """ Create backend implementation instance. 
Args: - backend_cfg: Backend configuration dict + backend_cfg: Backend configuration dict from _get_backend_config() mla_dims: MLA dimension configuration vllm_config: VllmConfig instance device: Target device + max_num_tokens: Maximum number of tokens for sparse indexer buffer + index_topk: Topk value for sparse MLA backends Returns: - Tuple of (impl, layer, builder_instance) + Tuple of (impl, layer, builder_instance, indexer) """ - # Import backend classes - backend_module = importlib.import_module(backend_cfg["module"]) - impl_class = getattr(backend_module, backend_cfg["impl_class"]) + # Get classes from backend config (already resolved by _get_backend_config) + impl_class = backend_cfg["impl_class"] + builder_class = backend_cfg["builder_class"] # Calculate scale scale = 1.0 / np.sqrt(mla_dims["qk_nope_head_dim"] + mla_dims["qk_rope_head_dim"]) @@ -474,26 +485,44 @@ def _create_backend_impl( v_head_dim=mla_dims["v_head_dim"], ) + # Create indexer for sparse backends + indexer = None + if backend_cfg.get("is_sparse", False): + if index_topk is None: + index_topk = 2048 # Default topk for sparse MLA + indexer = MockIndexer( + max_num_tokens=max_num_tokens, + topk_tokens=index_topk, + device=device, + ) + + # Build impl kwargs + impl_kwargs = { + "num_heads": mla_dims["num_q_heads"], + "head_size": mla_dims["head_dim"], + "scale": scale, + "num_kv_heads": mla_dims["num_kv_heads"], + "alibi_slopes": None, + "sliding_window": None, + "kv_cache_dtype": "auto", + "logits_soft_cap": None, + "attn_type": "decoder", + "kv_sharing_target_layer_name": None, + "q_lora_rank": None, + "kv_lora_rank": mla_dims["kv_lora_rank"], + "qk_nope_head_dim": mla_dims["qk_nope_head_dim"], + "qk_rope_head_dim": mla_dims["qk_rope_head_dim"], + "qk_head_dim": mla_dims["qk_nope_head_dim"] + mla_dims["qk_rope_head_dim"], + "v_head_dim": mla_dims["v_head_dim"], + "kv_b_proj": mock_kv_b_proj, + } + + # Add indexer for sparse backends + if indexer is not None: + impl_kwargs["indexer"] = indexer 
+ # Create impl - impl = impl_class( - num_heads=mla_dims["num_q_heads"], - head_size=mla_dims["head_dim"], - scale=scale, - num_kv_heads=mla_dims["num_kv_heads"], - alibi_slopes=None, - sliding_window=None, - kv_cache_dtype="auto", - logits_soft_cap=None, - attn_type="decoder", - kv_sharing_target_layer_name=None, - q_lora_rank=None, - kv_lora_rank=mla_dims["kv_lora_rank"], - qk_nope_head_dim=mla_dims["qk_nope_head_dim"], - qk_rope_head_dim=mla_dims["qk_rope_head_dim"], - qk_head_dim=mla_dims["qk_nope_head_dim"] + mla_dims["qk_rope_head_dim"], - v_head_dim=mla_dims["v_head_dim"], - kv_b_proj=mock_kv_b_proj, - ) + impl = impl_class(**impl_kwargs) # Initialize DCP attributes if not hasattr(impl, "dcp_world_size") or impl.dcp_world_size in (None, -1): @@ -515,9 +544,7 @@ def _create_backend_impl( # Create builder instance if needed builder_instance = None - if backend_cfg["builder_class"]: - builder_class = getattr(backend_module, backend_cfg["builder_class"]) - + if builder_class: # Populate static_forward_context so builder can find the layer # MockLayer inherits from AttentionLayerBase, so isinstance checks pass vllm_config.compilation_config.static_forward_context = {"placeholder": layer} @@ -529,7 +556,7 @@ def _create_backend_impl( device=device, ) - return impl, layer, builder_instance + return impl, layer, builder_instance, indexer # ============================================================================ @@ -594,6 +621,7 @@ def _run_single_benchmark( backend_cfg: dict, mla_dims: dict, device: torch.device, + indexer=None, ) -> BenchmarkResult: """ Run a single benchmark iteration. 
@@ -606,6 +634,7 @@ def _run_single_benchmark( backend_cfg: Backend configuration dict mla_dims: MLA dimension configuration device: Target device + indexer: Optional MockIndexer for sparse backends Returns: BenchmarkResult with timing statistics @@ -613,7 +642,9 @@ def _run_single_benchmark( # Parse batch spec requests = parse_batch_spec(config.batch_spec) q_lens = [r.q_len for r in requests] + kv_lens = [r.kv_len for r in requests] total_q = sum(q_lens) + max_kv_len = max(kv_lens) # Determine block size block_size = backend_cfg["block_size"] or config.block_size @@ -641,8 +672,16 @@ def _run_single_benchmark( torch.bfloat16, ) - # Determine which forward method to use based on metadata - if metadata.decode is not None: + # Fill indexer with random indices for sparse backends + is_sparse = backend_cfg.get("is_sparse", False) + if is_sparse and indexer is not None: + indexer.fill_random_indices(total_q, max_kv_len) + + # Determine which forward method to use + if is_sparse: + # Sparse backends use forward_mqa + forward_fn = lambda: impl.forward_mqa(decode_inputs, kv_cache, metadata, layer) + elif metadata.decode is not None: forward_fn = lambda: impl._forward_decode( decode_inputs, kv_cache, metadata, layer ) @@ -693,11 +732,13 @@ def _run_single_benchmark( def _run_mla_benchmark_batched( backend: str, configs_with_params: list[tuple], # [(config, threshold, num_splits), ...] + index_topk: int = 2048, ) -> list[BenchmarkResult]: """ Unified batched MLA benchmark runner for all backends. - Works for: flashattn_mla, flashmla, flashinfer_mla, cutlass_mla + Works for: flashattn_mla, flashmla, flashinfer_mla, cutlass_mla, + flashinfer_mla_sparse, flashmla_sparse This function reuses backend initialization across multiple benchmarks to avoid setup/teardown overhead. 
@@ -707,6 +748,7 @@ def _run_mla_benchmark_batched( configs_with_params: List of (config, threshold, num_splits) tuples - threshold: reorder_batch_threshold (FlashAttn/FlashMLA only) - num_splits: num_kv_splits (CUTLASS only) + index_topk: Topk value for sparse MLA backends (default 2048) Returns: List of BenchmarkResult objects @@ -730,19 +772,27 @@ def _run_mla_benchmark_batched( if mla_dims is None: mla_dims = setup_mla_dims("deepseek-v3") + # Determine if this is a sparse backend + is_sparse = backend_cfg.get("is_sparse", False) + # Create and set vLLM config for MLA (reused across all benchmarks) vllm_config = create_minimal_vllm_config( model_name="deepseek-v3", # Used only for model path block_size=block_size, mla_dims=mla_dims, # Use custom dims from config or default + index_topk=index_topk if is_sparse else None, ) results = [] with set_current_vllm_config(vllm_config): - # Create backend impl, layer, and builder (reused across benchmarks) - impl, layer, builder_instance = _create_backend_impl( - backend_cfg, mla_dims, vllm_config, device + # Create backend impl, layer, builder, and indexer (reused across benchmarks) + impl, layer, builder_instance, indexer = _create_backend_impl( + backend_cfg, + mla_dims, + vllm_config, + device, + index_topk=index_topk if is_sparse else None, ) # Run each benchmark with the shared impl @@ -768,6 +818,7 @@ def _run_mla_benchmark_batched( backend_cfg, mla_dims, device, + indexer=indexer, ) results.append(result) @@ -793,20 +844,24 @@ def run_mla_benchmark( config, reorder_batch_threshold: int | None = None, num_kv_splits: int | None = None, + index_topk: int = 2048, ) -> BenchmarkResult | list[BenchmarkResult]: """ Unified MLA benchmark runner for all backends. - Works for: flashattn_mla, flashmla, flashinfer_mla, cutlass_mla + Works for: flashattn_mla, flashmla, flashinfer_mla, cutlass_mla, + flashinfer_mla_sparse, flashmla_sparse Always uses batched execution internally for optimal performance. 
Args: - backend: Backend name (flashattn_mla, flashmla, flashinfer_mla, cutlass_mla) + backend: Backend name (flashattn_mla, flashmla, flashinfer_mla, cutlass_mla, + flashinfer_mla_sparse, flashmla_sparse) config: BenchmarkConfig or list of (BenchmarkConfig, param) tuples reorder_batch_threshold: Threshold override for FlashAttn/FlashMLA (single config mode only) num_kv_splits: Number of KV splits for CUTLASS (single config mode only) + index_topk: Topk value for sparse MLA backends (default 2048) Returns: BenchmarkResult (single mode) or list of BenchmarkResult (batched mode) @@ -816,9 +871,9 @@ def run_mla_benchmark( # Already in batched format if len(config) > 0 and isinstance(config[0], tuple): # Format: [(cfg, param), ...] where param is threshold or num_splits - if backend in ("flashattn_mla", "flashmla"): + if backend in ("flashattn_mla", "flashmla", "flashmla_sparse"): configs_with_params = [(cfg, param, None) for cfg, param in config] - else: # cutlass_mla or flashinfer_mla + else: # cutlass_mla, flashinfer_mla, or sparse backends configs_with_params = [(cfg, None, param) for cfg, param in config] else: # Format: [cfg, ...] 
- just configs @@ -830,7 +885,7 @@ def run_mla_benchmark( return_single = True # Use unified batched execution - results = _run_mla_benchmark_batched(backend, configs_with_params) + results = _run_mla_benchmark_batched(backend, configs_with_params, index_topk) # Return single result or list based on input return results[0] if return_single else results diff --git a/benchmarks/attention_benchmarks/runner.py b/benchmarks/attention_benchmarks/runner.py index 79bfca681..6457a599a 100644 --- a/benchmarks/attention_benchmarks/runner.py +++ b/benchmarks/attention_benchmarks/runner.py @@ -40,29 +40,29 @@ from vllm.v1.kv_cache_interface import FullAttentionSpec # ============================================================================ -_BACKEND_CONFIG = { - "flash": { - "module": "vllm.v1.attention.backends.flash_attn", - "backend_class": "FlashAttentionBackend", - }, - "triton": { - "module": "vllm.v1.attention.backends.triton_attn", - "backend_class": "TritonAttentionBackend", - }, - "flashinfer": { - "module": "vllm.v1.attention.backends.flashinfer", - "backend_class": "FlashInferBackend", - }, -} - - def _get_backend_config(backend: str) -> dict: - if backend not in _BACKEND_CONFIG: + """ + Get backend configuration from AttentionBackendEnum. + + Args: + backend: Backend name matching AttentionBackendEnum exactly + (e.g., "FLASH_ATTN", "TRITON_ATTN", "FLASHINFER") + + Returns: + Dict with backend_class + """ + from vllm.v1.attention.backends.registry import AttentionBackendEnum + + try: + backend_enum = AttentionBackendEnum[backend] + backend_class = backend_enum.get_class() + except (KeyError, ValueError) as e: + valid_backends = [b.name for b in AttentionBackendEnum if b.name != "CUSTOM"] raise ValueError( - f"Unknown backend: {backend}. " - f"Available: {', '.join(_BACKEND_CONFIG.keys())}" - ) - return _BACKEND_CONFIG[backend] + f"Unknown backend: {backend}. 
Valid backends: {valid_backends}" + ) from e + + return {"backend_class": backend_class} @contextmanager @@ -205,10 +205,7 @@ def _create_backend_impl( dtype: torch.dtype, ): """Create backend implementation instance.""" - import importlib - - backend_module = importlib.import_module(backend_cfg["module"]) - backend_class = getattr(backend_module, backend_cfg["backend_class"]) + backend_class = backend_cfg["backend_class"] scale = get_attention_scale(config.head_dim) @@ -247,7 +244,7 @@ def _create_metadata_builder( # Flashinfer needs get_per_layer_parameters mocked since we don't have # real model layers registered - if backend_name == "flashinfer": + if backend_name == "FLASHINFER": import unittest.mock from vllm.v1.attention.backends.utils import PerLayerParameters @@ -438,7 +435,7 @@ def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult: """ Run standard attention benchmark with real kernels. - Supports: flash, triton, flashinfer + Supports: FLASH_ATTN, TRITON_ATTN, FLASHINFER Args: config: Benchmark configuration @@ -453,7 +450,7 @@ def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult: requests = parse_batch_spec(config.batch_spec) - if config.backend == "flashinfer": + if config.backend == "FLASHINFER": requests = reorder_for_flashinfer(requests) q_lens = [r.q_len for r in requests] diff --git a/docs/design/attention_backends.md b/docs/design/attention_backends.md index b551e31db..3244ce7cc 100644 --- a/docs/design/attention_backends.md +++ b/docs/design/attention_backends.md @@ -128,6 +128,7 @@ Priority is **1 = highest** (tried first). | 4 | `FLASHMLA` | | 5 | `TRITON_MLA` | | 6 | `FLASHMLA_SPARSE` | +| 7 | `FLASHINFER_MLA_SPARSE` | **Ampere/Hopper (SM 8.x-9.x):** @@ -204,6 +205,7 @@ configuration. 
|---------|--------|-----------|-------------|------------|------|--------|-----------|-----|-----------------|--------------| | `CUTLASS_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 128 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 10.x | | `FLASHINFER_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | 10.x | +| `FLASHINFER_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | 32, 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 10.x | | `FLASHMLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 64 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x-10.x | | `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x | | `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x | diff --git a/tests/v1/attention/test_sparse_mla_backends.py b/tests/v1/attention/test_sparse_mla_backends.py index e4ffd12ca..fe9ca8289 100644 --- a/tests/v1/attention/test_sparse_mla_backends.py +++ b/tests/v1/attention/test_sparse_mla_backends.py @@ -1,11 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Unit tests for the FlashMLA sparse backend utilities.""" +"""Unit tests for the sparse MLA backends and utilities.""" import math from types import MethodType, SimpleNamespace -import numpy as np import pytest import torch @@ -25,6 +24,9 @@ from vllm.config import set_current_vllm_config from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.platforms import current_platform from vllm.utils.math_utils import cdiv +from vllm.v1.attention.backends.mla.flashinfer_mla_sparse import ( + FlashInferMLASparseBackend, +) from vllm.v1.attention.backends.mla.flashmla_sparse import ( FlashMLASparseBackend, triton_convert_req_index_to_global_index, @@ -156,32 +158,48 @@ def _quantize_dequantize_fp8_ds_mla( return dequant_kv_c, dequant_k_pe -@pytest.mark.parametrize("batch_name", 
list(SPARSE_BACKEND_BATCH_SPECS.keys())) -@pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"]) -@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4]) -@pytest.mark.skipif( - torch.cuda.get_device_capability() < (9, 0), - reason="FlashMLASparseBackend requires CUDA 9.0 or higher", +@pytest.mark.parametrize( + "backend_cls", + [FlashMLASparseBackend, FlashInferMLASparseBackend], + ids=["FlashMLA", "FlashInfer"], ) +@pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys())) +@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_ds_mla"]) +@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4]) +@pytest.mark.parametrize("block_size", [32, 64]) def test_sparse_backend_decode_correctness( default_vllm_config, dist_init, + backend_cls, batch_name, kv_cache_dtype, tensor_parallel_size, + block_size, workspace_init, ): - if current_platform.is_rocm(): - pytest.skip("ROCm does not support fp8_ds_mla data type for kv cache.") + if kv_cache_dtype not in backend_cls.supported_kv_cache_dtypes: + pytest.skip(f"{backend_cls.get_name()} does not support {kv_cache_dtype}") - if not torch.cuda.is_available(): - pytest.skip("CUDA is required for sparse MLA decode test") + supported_block_sizes = backend_cls.get_supported_kernel_block_sizes() + if block_size not in supported_block_sizes: + pytest.skip( + f"{backend_cls.get_name()} does not support block_size={block_size}" + ) + + if backend_cls == FlashMLASparseBackend: + ok, reason = flashmla.is_flashmla_sparse_supported() + if not ok: + pytest.skip(reason) + elif backend_cls == FlashInferMLASparseBackend: + if not current_platform.has_device_capability(100): + pytest.skip("FlashInferMLASparseBackend requires SM 10.0 or higher") + + batch_spec = SPARSE_BACKEND_BATCH_SPECS[batch_name] + use_fp8_ds_mla_quantization = kv_cache_dtype == "fp8_ds_mla" device = torch.device("cuda") dtype = torch.bfloat16 - batch_spec = SPARSE_BACKEND_BATCH_SPECS[batch_name] - # Model hyper-parameters 
(kept intentionally small for the unit test) total_num_heads = 128 # Compute per-rank heads for simulated TP @@ -192,11 +210,10 @@ def test_sparse_backend_decode_correctness( qk_rope_head_dim = 64 v_head_dim = 128 head_size = kv_lora_rank + qk_rope_head_dim - topk_tokens = 2048 + topk_tokens = 128 max_seqlen = max(batch_spec.seq_lens) total_cache_tokens = sum(batch_spec.seq_lens) - block_size = 64 # Note: We use TP=1 to avoid multi-GPU requirements in CI. # The test simulates head partitioning via mocked methods below. @@ -247,11 +264,55 @@ def test_sparse_backend_decode_correctness( seq_lens = batch_spec.seq_lens query_lens = batch_spec.query_lens + # Pre-compute positions and sparse indices for all tokens. + # We need these BEFORE computing the reference to use sparse attention masks. + total_query_tokens = sum(query_lens) + positions = [] + for i in range(batch_spec.batch_size): + s_len = seq_lens[i] + q_len = query_lens[i] + ctx_len = s_len - q_len + for q_idx in range(q_len): + positions.append(ctx_len + q_idx) + + # Create sparse indices with UNIQUE per-token offsets to catch bugs where + # the kernel uses wrong indices for some tokens (e.g., due to incorrect + # tensor shapes like [1, num_tokens, ...] instead of [num_tokens, 1, ...]). + # Also include -1 masked indices to verify the kernel handles them correctly. 
+ sparse_indices = torch.empty( + total_query_tokens, topk_tokens, dtype=torch.int32, device=device + ) + for tok_idx in range(total_query_tokens): + max_valid_idx = positions[tok_idx] + offset = tok_idx * 7 # Prime number for varied offsets + # Use only half the topk indices as valid, mask the rest with -1 + # This tests that the kernel correctly ignores -1 indices + num_valid = min(topk_tokens // 2, max_valid_idx + 1) + if num_valid > 0: + valid_range = torch.arange(num_valid, device=device, dtype=torch.int32) + tok_indices = (valid_range + offset) % (max_valid_idx + 1) + # Pad with -1 for the remaining positions + tok_indices = torch.cat( + [ + tok_indices, + torch.full( + (topk_tokens - num_valid,), -1, device=device, dtype=torch.int32 + ), + ] + ) + else: + tok_indices = torch.full( + (topk_tokens,), -1, device=device, dtype=torch.int32 + ) + tok_indices[0] = 0 # At least one valid index + sparse_indices[tok_idx] = tok_indices + all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], [] kv_c_contexts, k_pe_contexts = [], [] reference_outputs = [] kv_cache_scale = torch.tensor(1.0, dtype=torch.float32, device=device) + global_token_idx = 0 for i in range(batch_spec.batch_size): s_len = seq_lens[i] @@ -268,40 +329,53 @@ def test_sparse_backend_decode_correctness( kv_c_full = torch.rand(s_len, kv_lora_rank, dtype=dtype, device=device) k_pe_full = torch.rand(s_len, 1, qk_rope_head_dim, dtype=dtype, device=device) - # SM100 (Blackwell) uses float -> e8m0 -> bf16 scale conversion - # which truncates scales to powers of 2. Simulate this in reference. 
- is_sm100 = torch.cuda.get_device_capability()[0] >= 10 - kv_c_full, k_pe_full = _quantize_dequantize_fp8_ds_mla( - kv_c_full, - k_pe_full.squeeze(1), - block_size=vllm_config.cache_config.block_size, - scale=kv_cache_scale, - simulate_sm100_e8m0_scales=is_sm100, - ) + if use_fp8_ds_mla_quantization: + is_sm100 = torch.cuda.get_device_capability()[0] >= 10 + kv_c_full, k_pe_squeezed = _quantize_dequantize_fp8_ds_mla( + kv_c_full, + k_pe_full.squeeze(1), + block_size=block_size, + scale=kv_cache_scale, + simulate_sm100_e8m0_scales=is_sm100, + ) + k_pe_full = k_pe_squeezed.unsqueeze(1) q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1) ql_nope = torch.einsum("qnh,lnh->qnl", q_nope, W_UK) q_mqa = torch.cat([ql_nope, q_pe], dim=-1) - k_mqa = torch.cat([kv_c_full, k_pe_full], dim=-1) - k_mqa = k_mqa.unsqueeze(1).expand(-1, num_heads, -1) - v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_heads, -1) + k_mqa = torch.cat([kv_c_full, k_pe_full.squeeze(1)], dim=-1) + v_mqa = kv_c_full - attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device) - causal_mask = torch.tril(torch.ones(q_len, q_len, device=device)) - attn_mask[:, ctx_len:] = causal_mask + # Compute sparse SDPA reference per query token using its sparse indices + for q_idx in range(q_len): + tok_sparse_idx = sparse_indices[global_token_idx] + valid_mask = tok_sparse_idx >= 0 + valid_indices = tok_sparse_idx[valid_mask].long() - q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2) - k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2) - v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2) + q_tok = q_mqa[q_idx : q_idx + 1] # [1, num_heads, head_dim] + k_sparse = k_mqa[valid_indices] # [num_valid, head_dim] + v_sparse = v_mqa[valid_indices] # [num_valid, kv_lora_rank] - sdpa_out = torch.nn.functional.scaled_dot_product_attention( - q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale - ) - sdpa_out = sdpa_out.transpose(1, 2).squeeze(0) + k_sparse = k_sparse.unsqueeze(1).expand(-1, num_heads, 
-1) + v_sparse = v_sparse.unsqueeze(1).expand(-1, num_heads, -1) - sdpa_out = torch.einsum("qnl,lnv->qnv", sdpa_out, W_UV) - reference_outputs.append(sdpa_out.flatten(start_dim=-2)) + # SDPA: [1, num_heads, 1, head_dim] x [1, num_heads, num_valid, head_dim] + q_sdpa_in = q_tok.unsqueeze(0).transpose(1, 2) + k_sdpa_in = k_sparse.unsqueeze(0).transpose(1, 2) + v_sdpa_in = v_sparse.unsqueeze(0).transpose(1, 2) + + sdpa_out = torch.nn.functional.scaled_dot_product_attention( + q_sdpa_in, k_sdpa_in, v_sdpa_in, scale=scale + ) + sdpa_out = sdpa_out.transpose(1, 2).squeeze( + 0 + ) # [1, num_heads, kv_lora_rank] + + sdpa_out = torch.einsum("qnl,lnv->qnv", sdpa_out, W_UV) + reference_outputs.append(sdpa_out.flatten(start_dim=-2)) + + global_token_idx += 1 all_q_vllm.append(q_c) all_kv_c_vllm.append(kv_c_full[ctx_len:]) @@ -334,42 +408,18 @@ def test_sparse_backend_decode_correctness( num_blocks=vllm_config.cache_config.num_gpu_blocks, common_attn_metadata=common_attn_metadata, randomize_blocks=False, - kv_cache_dtype=vllm_config.cache_config.cache_dtype, + kv_cache_dtype=kv_cache_dtype if use_fp8_ds_mla_quantization else "auto", scale=kv_cache_scale, ) - builder_cls = FlashMLASparseBackend.get_builder_cls() + builder_cls = backend_cls.get_builder_cls() builder = builder_cls(kv_cache_spec, ["placeholder"], vllm_config, device) metadata = builder.build( common_prefix_len=0, common_attn_metadata=common_attn_metadata ) - starts = np.asarray(common_attn_metadata.query_start_loc_cpu, dtype=np.int32) - seg_lengths = np.diff(starts) - positions = np.arange(starts[-1], dtype=np.int32) - np.repeat( - starts[:-1], seg_lengths - ) - seq_lengths = np.asarray(common_attn_metadata.seq_lens.cpu(), dtype=np.int32) - prefix_lengths = seq_lengths - seg_lengths - positions += np.repeat(prefix_lengths, seg_lengths) - - pos_gpu = torch.as_tensor(positions, device=device, dtype=torch.int32) - topk = metadata.topk_tokens - debug_indices = torch.arange(topk, device=device, 
dtype=torch.int32).unsqueeze(0) - token_positions = pos_gpu.unsqueeze(1) - causal_mask = debug_indices <= token_positions - debug_indices = torch.where( - causal_mask, debug_indices, torch.full_like(debug_indices, -1) - ) - - # FlashMLASparseImpl now reads top-k indices from the indexer-provided - # buffer, so emulate that contract with a simple namespace mock. - debug_indices = debug_indices.expand(metadata.num_actual_tokens, -1).clone() - mock_indexer = SimpleNamespace(topk_indices_buffer=debug_indices) - - ok, reason = flashmla.is_flashmla_sparse_supported() - if not ok: - pytest.skip(reason) + # Use the pre-computed sparse_indices for the mock indexer + mock_indexer = SimpleNamespace(topk_indices_buffer=sparse_indices) kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1) kv_b_proj_weight = kv_b_proj_weight.view( @@ -383,7 +433,7 @@ def test_sparse_backend_decode_correctness( ).to(device=device, dtype=dtype) mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T.contiguous()) - impl_cls = FlashMLASparseBackend.get_impl_cls() + impl_cls = backend_cls.get_impl_cls() with set_current_vllm_config(vllm_config): impl = impl_cls( num_heads=num_heads, @@ -441,7 +491,7 @@ def test_sparse_backend_decode_correctness( # FP8 quantization introduces some error, but should be within reasonable bounds # BF16 (auto) should be very accurate, FP8 allows slightly more tolerance - if kv_cache_dtype == "fp8_ds_mla": + if kv_cache_dtype.startswith("fp8"): torch.testing.assert_close(backend_output, sdpa_reference, rtol=0.05, atol=0.05) else: torch.testing.assert_close(backend_output, sdpa_reference, rtol=0.01, atol=0.01) @@ -636,3 +686,63 @@ def test_triton_convert_req_index_to_global_index_with_prefill_workspace(block_s def test_split_prefill_chunks(seq_lens, max_buf, expected): out = split_prefill_chunks(seq_lens, max_buf) assert out == expected + + +def test_triton_convert_returns_valid_counts(): + """Test that return_valid_counts correctly counts non-negative indices.""" + 
device = torch.device("cuda") + num_tokens = 8 + num_requests = 2 + max_blocks_per_req = 10 + block_size = 64 + num_topk_tokens = 128 + + req_id = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32, device=device) + block_table = torch.arange( + num_requests * max_blocks_per_req, dtype=torch.int32, device=device + ).view(num_requests, max_blocks_per_req) + + # Create token indices with varying numbers of valid entries + # Token 0: 64 valid, 64 invalid (-1) + # Token 1: 32 valid, 96 invalid + # Token 2: 128 valid (all) + # Token 3: 1 valid, 127 invalid + # etc. + token_indices = torch.full( + (num_tokens, num_topk_tokens), -1, dtype=torch.int32, device=device + ) + expected_valid = [] + for i in range(num_tokens): + num_valid = [64, 32, 128, 1, 64, 32, 128, 1][i] + token_indices[i, :num_valid] = torch.arange( + num_valid, dtype=torch.int32, device=device + ) % (block_size * max_blocks_per_req) + expected_valid.append(num_valid) + + expected_valid_tensor = torch.tensor( + expected_valid, dtype=torch.int32, device=device + ) + + # Test with return_valid_counts=True + result, valid_counts = triton_convert_req_index_to_global_index( + req_id, + block_table, + token_indices, + BLOCK_SIZE=block_size, + NUM_TOPK_TOKENS=num_topk_tokens, + return_valid_counts=True, + ) + + torch.testing.assert_close(valid_counts, expected_valid_tensor, rtol=0, atol=0) + + # Test that return_valid_counts=False returns only the indices + result_only = triton_convert_req_index_to_global_index( + req_id, + block_table, + token_indices, + BLOCK_SIZE=block_size, + NUM_TOPK_TOKENS=num_topk_tokens, + return_valid_counts=False, + ) + assert isinstance(result_only, torch.Tensor) + torch.testing.assert_close(result_only, result, rtol=0, atol=0) diff --git a/tools/pre_commit/generate_attention_backend_docs.py b/tools/pre_commit/generate_attention_backend_docs.py index eb68deb1b..3aca49f94 100644 --- a/tools/pre_commit/generate_attention_backend_docs.py +++ 
b/tools/pre_commit/generate_attention_backend_docs.py @@ -901,10 +901,50 @@ def parse_cuda_priority_lists() -> dict[str, list[str]]: def _get_backends_from_return(stmts: list) -> list[str]: - """Extract backend names from return statements in a list of statements.""" + """Extract backend names from return statements in a list of statements. + + Handles starred unpacking (e.g. ``*sparse_backends``) by resolving the + variable from assignments found in the same statement list. When the + variable is conditionally assigned (inside an ``if/else``), the ``else`` + branch value is used as the representative default. + """ + # Collect variable assignments so we can resolve starred expressions. + # For conditional assignments, last-written (else branch) wins. + var_assigns: dict[str, list[str]] = {} + for stmt in stmts: + if isinstance(stmt, ast.Assign) and isinstance(stmt.value, ast.List): + for target in stmt.targets: + if isinstance(target, ast.Name): + var_assigns[target.id] = [ + e.attr for e in stmt.value.elts if isinstance(e, ast.Attribute) + ] + elif isinstance(stmt, ast.If): + for branch in (stmt.body, stmt.orelse): + for branch_stmt in branch: + if isinstance(branch_stmt, ast.Assign) and isinstance( + branch_stmt.value, ast.List + ): + for target in branch_stmt.targets: + if isinstance(target, ast.Name): + var_assigns[target.id] = [ + e.attr + for e in branch_stmt.value.elts + if isinstance(e, ast.Attribute) + ] + for stmt in stmts: if isinstance(stmt, ast.Return) and isinstance(stmt.value, ast.List): - return [e.attr for e in stmt.value.elts if isinstance(e, ast.Attribute)] + backends: list[str] = [] + for e in stmt.value.elts: + if isinstance(e, ast.Attribute): + backends.append(e.attr) + elif ( + isinstance(e, ast.Starred) + and isinstance(e.value, ast.Name) + and e.value.id in var_assigns + ): + backends.extend(var_assigns[e.value.id]) + return backends return [] diff --git a/vllm/model_executor/layers/attention/mla_attention.py 
b/vllm/model_executor/layers/attention/mla_attention.py index c44bf1f16..98ff02e9d 100644 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -334,6 +334,7 @@ class MLAAttention(nn.Module, AttentionLayerBase): block_size, use_mla=True, use_sparse=use_sparse, + num_heads=self.num_heads, ) if ( diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 3edc83b15..b3d6b0ed6 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -129,6 +129,7 @@ class CpuPlatform(Platform): cls, selected_backend: "AttentionBackendEnum", attn_selector_config: "AttentionSelectorConfig", + num_heads: int | None = None, ) -> str: if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN: logger.info("Cannot use %s backend on CPU.", selected_backend) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 0c0bd7db3..b7efe24dc 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -45,17 +45,29 @@ torch.backends.cuda.enable_cudnn_sdp(False) def _get_backend_priorities( use_mla: bool, device_capability: DeviceCapability, + num_heads: int | None = None, ) -> list[AttentionBackendEnum]: """Get backend priorities with lazy import to avoid circular dependency.""" if use_mla: if device_capability.major == 10: + # Prefer FlashInfer at low head counts (FlashMLA uses padding) + if num_heads is not None and num_heads <= 16: + sparse_backends = [ + AttentionBackendEnum.FLASHINFER_MLA_SPARSE, + AttentionBackendEnum.FLASHMLA_SPARSE, + ] + else: + sparse_backends = [ + AttentionBackendEnum.FLASHMLA_SPARSE, + AttentionBackendEnum.FLASHINFER_MLA_SPARSE, + ] return [ AttentionBackendEnum.FLASHINFER_MLA, AttentionBackendEnum.CUTLASS_MLA, AttentionBackendEnum.FLASH_ATTN_MLA, AttentionBackendEnum.FLASHMLA, AttentionBackendEnum.TRITON_MLA, - AttentionBackendEnum.FLASHMLA_SPARSE, + *sparse_backends, ] else: return [ @@ -182,6 +194,8 @@ class CudaPlatformBase(Platform): use_flashmla = 
False use_cutlass_mla = False use_flashinfer_mla = False + use_flashmla_sparse = False + use_flashinfer_mla_sparse = False from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported @@ -217,6 +231,10 @@ class CudaPlatformBase(Platform): use_flashmla = backend == AttentionBackendEnum.FLASHMLA use_cutlass_mla = backend == AttentionBackendEnum.CUTLASS_MLA use_flashinfer_mla = backend == AttentionBackendEnum.FLASHINFER_MLA + use_flashmla_sparse = backend == AttentionBackendEnum.FLASHMLA_SPARSE + use_flashinfer_mla_sparse = ( + backend == AttentionBackendEnum.FLASHINFER_MLA_SPARSE + ) if ( use_flashmla @@ -242,12 +260,24 @@ class CudaPlatformBase(Platform): "Forcing kv cache block size to 64 for FlashInferMLA backend." ) - # TODO(Chen): remove this hacky code - if use_sparse and cache_config.block_size != 64: - cache_config.block_size = 64 - logger.info( - "Forcing kv cache block size to 64 for FlashMLASparse backend." - ) + if use_sparse: + if not (use_flashmla_sparse or use_flashinfer_mla_sparse): + use_flashmla_sparse = True + + if use_flashmla_sparse and cache_config.block_size != 64: + cache_config.block_size = 64 + logger.info( + "Forcing kv cache block size to 64 for FlashMLASparse backend." + ) + elif use_flashinfer_mla_sparse and cache_config.block_size not in ( + 32, + 64, + ): + cache_config.block_size = 64 + logger.info( + "Forcing kv cache block size to 64 for FlashInferMLASparse " + "backend." 
+ ) scheduler_config = vllm_config.scheduler_config # Note: model_config may be None during testing @@ -276,6 +306,7 @@ class CudaPlatformBase(Platform): cls, device_capability: DeviceCapability, attn_selector_config: "AttentionSelectorConfig", + num_heads: int | None = None, ) -> tuple[ list[tuple["AttentionBackendEnum", int]], dict["AttentionBackendEnum", list[str]], @@ -284,7 +315,9 @@ class CudaPlatformBase(Platform): invalid_reasons = {} backend_priorities = _get_backend_priorities( - attn_selector_config.use_mla, device_capability + attn_selector_config.use_mla, + device_capability, + num_heads, ) for priority, backend in enumerate(backend_priorities): try: @@ -307,6 +340,7 @@ class CudaPlatformBase(Platform): cls, selected_backend: "AttentionBackendEnum", attn_selector_config: "AttentionSelectorConfig", + num_heads: int | None = None, ) -> str: device_capability = cls.get_device_capability() assert device_capability is not None @@ -336,6 +370,7 @@ class CudaPlatformBase(Platform): valid_backends_priorities, invalid_reasons = cls.get_valid_backends( device_capability=device_capability, attn_selector_config=attn_selector_config, + num_heads=num_heads, ) reasons_str = ( "{" diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 27f5ea517..4595b599b 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -233,6 +233,7 @@ class Platform: cls, selected_backend: "AttentionBackendEnum", attn_selector_config: "AttentionSelectorConfig", + num_heads: int | None = None, ) -> str: """Get the attention backend class of a device.""" return "" diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index b463c80a1..808d21400 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -265,6 +265,7 @@ class RocmPlatform(Platform): cls, selected_backend: "AttentionBackendEnum", attn_selector_config: "AttentionSelectorConfig", + num_heads: int | None = None, ) -> str: from vllm._aiter_ops import rocm_aiter_ops diff --git 
a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 3a0ea8b12..8daa2d47f 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -48,6 +48,7 @@ class XPUPlatform(Platform): cls, selected_backend: "AttentionBackendEnum", attn_selector_config: "AttentionSelectorConfig", + num_heads: int | None = None, ) -> str: from vllm.v1.attention.backends.utils import set_kv_cache_layout diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py b/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py new file mode 100644 index 000000000..21a0d99c2 --- /dev/null +++ b/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py @@ -0,0 +1,353 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""FlashInfer MLA Sparse Attention Backend. + +This backend uses the FlashInfer TRT-LLM MLA kernel with sparse_mla_top_k +for models like DeepSeek-V3.2 that use index-based sparse attention. + +For sparse MLA: +- block_tables shape changes from [batch_size, max_num_blocks] (dense) + to [batch_size, q_len_per_request, sparse_mla_top_k] (sparse) +- The sparse indices represent physical cache slot positions to attend to +- sparse_mla_top_k parameter must be set to the topk value +""" + +from dataclasses import dataclass +from typing import TYPE_CHECKING, ClassVar + +import numpy as np +import torch +from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla + +from vllm.config import VllmConfig +from vllm.config.cache import CacheDType +from vllm.logger import init_logger +from vllm.model_executor.layers.attention.mla_attention import ( + get_mla_dims, +) +from vllm.platforms.interface import DeviceCapability +from vllm.v1.attention.backend import ( + AttentionBackend, + AttentionCGSupport, + AttentionLayer, + AttentionMetadata, + AttentionMetadataBuilder, + AttentionType, + CommonAttentionMetadata, + MultipleOf, + SparseMLAAttentionImpl, +) +from vllm.v1.attention.backends.mla.sparse_utils import ( + 
triton_convert_req_index_to_global_index, +) +from vllm.v1.attention.backends.utils import KVCacheLayoutType +from vllm.v1.kv_cache_interface import AttentionSpec + +if TYPE_CHECKING: + from vllm.model_executor.models.deepseek_v2 import Indexer + +logger = init_logger(__name__) + +FLASHINFER_MLA_SPARSE_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024 + + +class FlashInferMLASparseBackend(AttentionBackend): + """FlashInfer MLA backend with sparse attention support. + + This backend uses the FlashInfer TRT-LLM MLA kernel with sparse_mla_top_k + for models like DeepSeek-V3.2 that use index-based sparse attention. + """ + + accept_output_buffer: bool = True + supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ + "auto", + "bfloat16", + ] + + @staticmethod + def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: + return [32, 64] + + @staticmethod + def get_name() -> str: + return "FLASHINFER_MLA_SPARSE" + + @staticmethod + def get_impl_cls() -> type["FlashInferMLASparseImpl"]: + return FlashInferMLASparseImpl + + @staticmethod + def get_builder_cls() -> type["FlashInferMLASparseMetadataBuilder"]: + return FlashInferMLASparseMetadataBuilder + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [576] + + @classmethod + def is_mla(cls) -> bool: + return True + + @classmethod + def is_sparse(cls) -> bool: + return True + + @classmethod + def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + # FlashInfer sparse MLA targets Blackwell (SM 10.x) + return capability.major == 10 + + @classmethod + def supports_combination( + cls, + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: CacheDType | None, + block_size: int, + use_mla: bool, + has_sink: bool, + use_sparse: bool, + device_capability: DeviceCapability, + ) -> str | None: + # FlashInfer MLA sparse kernel requires qk_nope_head_dim == 128 + from vllm.config import 
get_current_vllm_config + + vllm_config = get_current_vllm_config() + if vllm_config.model_config is not None: + hf_text_config = vllm_config.model_config.hf_text_config + qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1) + if qk_nope_head_dim != 128: + return ( + f"FlashInfer MLA Sparse kernel requires qk_nope_head_dim == 128, " + f"but got {qk_nope_head_dim}" + ) + # Check for index_topk which indicates sparse model + if not hasattr(hf_text_config, "index_topk"): + return "FlashInfer MLA Sparse requires model with index_topk config" + return None + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, # assumed to be 1 for MLA + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + return (num_blocks, block_size, head_size) + + @classmethod + def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None": + return "HND" + + +@dataclass +class FlashInferMLASparseMetadata(AttentionMetadata): + """Attention metadata for FlashInfer MLA Sparse backend.""" + + num_reqs: int + max_query_len: int + max_seq_len: int + num_actual_tokens: int + + # Query start locations + query_start_loc: torch.Tensor + slot_mapping: torch.Tensor + block_table: torch.Tensor + req_id_per_token: torch.Tensor + + # Sequence lengths for all requests (context + query) + seq_lens: torch.Tensor + + # Sparse-specific + block_size: int = 64 + topk_tokens: int = 2048 + + +class FlashInferMLASparseMetadataBuilder( + AttentionMetadataBuilder[FlashInferMLASparseMetadata] +): + """Builder for FlashInfer MLA Sparse attention metadata.""" + + _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH + + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ) -> None: + self.vllm_config = vllm_config + self.layer_names = layer_names + self.kv_cache_spec = kv_cache_spec + self.model_config = vllm_config.model_config + 
self.device = device + + self.mla_dims = get_mla_dims(self.model_config) + self.topk_tokens = vllm_config.model_config.hf_config.index_topk + + self.req_id_per_token_buffer = torch.empty( + (vllm_config.scheduler_config.max_num_batched_tokens,), + dtype=torch.int32, + device=device, + ) + + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> FlashInferMLASparseMetadata: + cm = common_attn_metadata + num_tokens = cm.num_actual_tokens + + # Build req_id_per_token mapping + starts = np.asarray(cm.query_start_loc_cpu, dtype=np.int32) + seg_lengths = np.diff(starts) + req_id_per_token = np.repeat( + np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths + ) + + # Zero-fill for cudagraphs + self.req_id_per_token_buffer.fill_(0) + self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_( + torch.from_numpy(req_id_per_token), non_blocking=True + ) + req_id_per_token_tensor = self.req_id_per_token_buffer[:num_tokens] + + return FlashInferMLASparseMetadata( + num_reqs=cm.num_reqs, + max_query_len=cm.max_query_len, + max_seq_len=cm.max_seq_len, + num_actual_tokens=cm.num_actual_tokens, + query_start_loc=cm.query_start_loc, + slot_mapping=cm.slot_mapping, + block_table=cm.block_table_tensor, + req_id_per_token=req_id_per_token_tensor, + seq_lens=cm.seq_lens, + block_size=self.kv_cache_spec.block_size, + topk_tokens=self.topk_tokens, + ) + + +# Global workspace buffer (lazily initialized) +_fi_sparse_workspace: torch.Tensor | None = None + + +def _get_workspace_buffer(device: torch.device) -> torch.Tensor: + global _fi_sparse_workspace + if _fi_sparse_workspace is None: + _fi_sparse_workspace = torch.zeros( + FLASHINFER_MLA_SPARSE_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=device, + ) + return _fi_sparse_workspace + + +class FlashInferMLASparseImpl(SparseMLAAttentionImpl[FlashInferMLASparseMetadata]): + """FlashInfer MLA Sparse implementation. 
+ + Uses the TRT-LLM MLA kernel with sparse_mla_top_k parameter for + sparse attention computation. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: list[float] | None, + sliding_window: int | None, + kv_cache_dtype: str, + logits_soft_cap: float | None, + attn_type: str, + kv_sharing_target_layer_name: str | None, + # MLA Specific Arguments + topk_indice_buffer: torch.Tensor | None = None, + indexer: "Indexer | None" = None, + **mla_args, + ) -> None: + unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] + if any(unsupported_features): + raise NotImplementedError( + "FlashInferMLASparseImpl does not support one of the following: " + "alibi_slopes, sliding_window, logits_soft_cap" + ) + + if attn_type != AttentionType.DECODER: + raise NotImplementedError( + "Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashInferMLASparseImpl" + ) + + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + self.kv_cache_dtype = kv_cache_dtype + + # MLA-specific dimensions + self.kv_lora_rank: int = mla_args["kv_lora_rank"] + self.qk_nope_head_dim: int = mla_args["qk_nope_head_dim"] + self.qk_rope_head_dim: int = mla_args["qk_rope_head_dim"] + + assert indexer is not None, "Indexer required for sparse MLA" + self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer + + self._workspace_buffer: torch.Tensor | None = None + self.bmm1_scale: float | None = None + self.bmm2_scale: float | None = None + + def forward_mqa( + self, + q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: FlashInferMLASparseMetadata, + layer: AttentionLayer, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + if isinstance(q, tuple): + q = torch.cat(q, dim=-1) + + num_actual_toks = q.shape[0] + + assert self.topk_indices_buffer is not 
None + topk_indices = self.topk_indices_buffer[:num_actual_toks] + + topk_indices_physical, seq_lens = triton_convert_req_index_to_global_index( + attn_metadata.req_id_per_token[:num_actual_toks], + attn_metadata.block_table, + topk_indices, + BLOCK_SIZE=attn_metadata.block_size, + NUM_TOPK_TOKENS=topk_indices.shape[1], + return_valid_counts=True, + ) + + if self._workspace_buffer is None: + self._workspace_buffer = _get_workspace_buffer(q.device) + + if self.bmm1_scale is None: + self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale + if self.bmm2_scale is None: + self.bmm2_scale = layer._v_scale_float + + o = trtllm_batch_decode_with_kv_cache_mla( + query=q.unsqueeze(1), + kv_cache=kv_c_and_k_pe_cache.unsqueeze(1), + workspace_buffer=self._workspace_buffer, + qk_nope_head_dim=self.qk_nope_head_dim, + kv_lora_rank=self.kv_lora_rank, + qk_rope_head_dim=self.qk_rope_head_dim, + block_tables=topk_indices_physical.unsqueeze(1), + seq_lens=seq_lens, + max_seq_len=attn_metadata.topk_tokens, + bmm1_scale=self.bmm1_scale, + bmm2_scale=self.bmm2_scale, + sparse_mla_top_k=attn_metadata.topk_tokens, + ) + return o.view(-1, o.shape[-2], o.shape[-1]), None diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 80e402a4d..799c77d73 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -15,7 +15,6 @@ from vllm.model_executor.layers.attention.mla_attention import ( ) from vllm.platforms import current_platform from vllm.platforms.interface import DeviceCapability -from vllm.triton_utils import tl, triton from vllm.v1.attention.backend import ( AttentionBackend, AttentionCGSupport, @@ -26,6 +25,9 @@ from vllm.v1.attention.backend import ( MultipleOf, SparseMLAAttentionImpl, ) +from vllm.v1.attention.backends.mla.sparse_utils import ( + triton_convert_req_index_to_global_index, +) from vllm.v1.attention.backends.utils import ( 
reshape_attn_output_for_spec_decode, reshape_query_for_spec_decode, @@ -203,166 +205,6 @@ class FlashMLASparseMetadata(AttentionMetadata): fp8_use_mixed_batch: bool = False -# Kernel with prefill workspace support -@triton.jit -def _convert_req_index_to_global_index_kernel( - req_id_ptr, # int32 [num_tokens] - block_table_ptr, # int32 [num_requests, max_num_blocks_per_req] - token_indices_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS] - out_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS] - prefill_request_id_ptr, # int32 [num_tokens], -1 for decode, >=0 for prefill - workspace_starts_ptr, # int32 [num_prefill_reqs+1] or nullptr - # shapes (compile-time where possible) - max_num_blocks_per_req: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - BLOCK_N: tl.constexpr, # tile width along columns - HAS_PREFILL: tl.constexpr, - # strides (in elements) - bt_stride0, - bt_stride1, - ti_stride0, - ti_stride1, - out_stride0, - out_stride1, -): - # program_id(0) -> token_id (row) - # program_id(1) -> tile index along columns - token_id = tl.program_id(0) - tile_id = tl.program_id(1) - - # Each program covers BLOCK_N consecutive columns - indice_id = tile_id * BLOCK_N + tl.arange(0, BLOCK_N) - - # Load request id for this token (no mask: grid is exact) - req = tl.load(req_id_ptr + token_id) - - # Load token indices for this tile - ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1 - tok = tl.load(ti_ptr) # int32 - - # Only token == -1 should propagate as -1 - is_invalid_tok = tok < 0 - is_prefill = False - if HAS_PREFILL: - prefill_req_id = tl.load(prefill_request_id_ptr + token_id) - is_prefill = prefill_req_id >= 0 - # Compute block id and in-block offset - block_id = tok // BLOCK_SIZE - inblock_off = tok % BLOCK_SIZE - - # Guard block_table access - valid_block = (block_id < max_num_blocks_per_req) & (block_id >= 0) - bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1 - is_invalid_tok |= ~valid_block - base = tl.load(bt_ptr, mask=valid_block & 
~is_prefill, other=0) - out_val = base * BLOCK_SIZE + inblock_off - - # Override with prefill output if prefill is enabled - if HAS_PREFILL: - workspace_start = tl.load( - workspace_starts_ptr + prefill_req_id, mask=is_prefill, other=0 - ) - prefill_out = workspace_start + tok - out_val = tl.where(is_prefill, prefill_out, out_val) - out_val = tl.where(is_invalid_tok, -1, out_val) - - # Store results - out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1 - tl.store(out_ptr_ij, out_val) - - -def triton_convert_req_index_to_global_index( - req_id: torch.Tensor, # int32 [num_tokens] - block_table: torch.Tensor, # int32 [num_requests, max_num_blocks_per_req] - token_indices: torch.Tensor, # int32 [num_tokens, NUM_TOPK_TOKENS] - BLOCK_SIZE: int = 64, - NUM_TOPK_TOKENS: int = 2048, - BLOCK_N: int = 128, # tile width along columns - HAS_PREFILL_WORKSPACE: bool = False, - prefill_workspace_request_ids: torch.Tensor | None = None, - prefill_workspace_starts: torch.Tensor | None = None, -): - """ - out[token_id, indice_id] = - block_table[req_id[token_id], - token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE - + token_indices[token_id, indice_id] % BLOCK_SIZE - - Only when token_indices[token_id, indice_id] == -1 do we output -1. - For safety, we also output -1 if the derived block_id would be - out-of-bounds. - - When HAS_PREFILL_WORKSPACE is True, prefill tokens are mapped to workspace offsets - instead of global cache slots. prefill_workspace_request_ids and - prefill_workspace_starts must be provided. 
- - prefill_workspace_request_ids: int32 [num_tokens], -1 for decode else - prefill request index (maps to prefill_workspace_starts) - prefill_workspace_starts: int32 [num_prefills], 0-indexed workspace - starts for each prefill request - """ - assert req_id.dtype == torch.int32 - assert block_table.dtype == torch.int32 - assert token_indices.dtype == torch.int32 - assert token_indices.shape[1] == NUM_TOPK_TOKENS - assert NUM_TOPK_TOKENS % BLOCK_N == 0, ( - f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by BLOCK_N ({BLOCK_N})" - ) - - if HAS_PREFILL_WORKSPACE: - assert prefill_workspace_request_ids is not None - assert prefill_workspace_starts is not None - assert prefill_workspace_request_ids.dtype == torch.int32 - assert prefill_workspace_starts.dtype == torch.int32 - - num_tokens = req_id.shape[0] - max_num_blocks_per_req = block_table.shape[1] - tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N - - # Ensure contiguous tensors on the same device - req_id_c = req_id.contiguous() - block_table_c = block_table.contiguous() - token_indices_c = token_indices.contiguous() - out = torch.empty_like(token_indices_c) - - # Strides in elements - bt_stride0, bt_stride1 = block_table_c.stride() - ti_stride0, ti_stride1 = token_indices_c.stride() - out_stride0, out_stride1 = out.stride() - - # Prepare prefill pointers - if HAS_PREFILL_WORKSPACE: - assert prefill_workspace_request_ids is not None # for mypy - assert prefill_workspace_starts is not None # for mypy - assert prefill_workspace_request_ids.is_contiguous() - assert prefill_workspace_starts.is_contiguous() - - # Exact 2D grid: tokens × column tiles - grid = (num_tokens, tiles_per_row) - - _convert_req_index_to_global_index_kernel[grid]( - req_id_c, - block_table_c, - token_indices_c, - out, - prefill_workspace_request_ids, - prefill_workspace_starts, - # shapes / constexprs - max_num_blocks_per_req, - BLOCK_SIZE, - BLOCK_N, - HAS_PREFILL_WORKSPACE, - # strides - bt_stride0, - bt_stride1, - ti_stride0, - ti_stride1, 
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utility functions for sparse MLA backends."""

import torch

from vllm.triton_utils import tl, triton


# Kernel with prefill workspace support and valid count tracking.
#
# Launched on a 2D grid (token row, column tile). Each program converts
# BLOCK_N consecutive request-local token indices of one row into global
# indices and, when COUNT_VALID is set, adds the number of valid entries in
# its tile to valid_count_ptr[token_id].
@triton.jit
def _convert_req_index_to_global_index_kernel(
    req_id_ptr,  # int32 [num_tokens]
    block_table_ptr,  # int32 [num_requests, max_num_blocks_per_req]
    token_indices_ptr,  # int32 [num_tokens, NUM_TOPK_TOKENS]
    out_ptr,  # int32 [num_tokens, NUM_TOPK_TOKENS]
    valid_count_ptr,  # int32 [num_tokens] - output valid count per row
    prefill_request_id_ptr,  # int32 [num_tokens], -1 for decode, >=0 for prefill
    workspace_starts_ptr,  # int32 [num_prefill_reqs+1] or nullptr
    # shapes (compile-time where possible)
    max_num_blocks_per_req: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_N: tl.constexpr,  # tile width along columns
    HAS_PREFILL: tl.constexpr,
    COUNT_VALID: tl.constexpr,  # whether to count valid indices
    # strides (in elements)
    bt_stride0,
    bt_stride1,
    ti_stride0,
    ti_stride1,
    out_stride0,
    out_stride1,
):
    # program_id(0) -> token_id (row)
    # program_id(1) -> tile index along columns
    token_id = tl.program_id(0)
    tile_id = tl.program_id(1)

    # Each program covers BLOCK_N consecutive columns
    indice_id = tile_id * BLOCK_N + tl.arange(0, BLOCK_N)

    # Load request id for this token (no mask: grid is exact)
    req = tl.load(req_id_ptr + token_id)

    # Load token indices for this tile (unmasked: the caller asserts that
    # NUM_TOPK_TOKENS is a multiple of BLOCK_N)
    ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1
    tok = tl.load(ti_ptr)  # int32

    # Any negative index (the -1 padding value) propagates as -1
    is_invalid_tok = tok < 0
    is_prefill = False
    if HAS_PREFILL:
        # One prefill request id per token row; >= 0 marks a prefill token
        prefill_req_id = tl.load(prefill_request_id_ptr + token_id)
        is_prefill = prefill_req_id >= 0
    # Compute block id and in-block offset
    block_id = tok // BLOCK_SIZE
    inblock_off = tok % BLOCK_SIZE

    # Guard block_table access
    valid_block = (block_id < max_num_blocks_per_req) & (block_id >= 0)
    bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1
    # NOTE(review): the block-range check also marks prefill rows invalid,
    # although their output comes from the workspace path below - confirm
    # this is intended.
    is_invalid_tok |= ~valid_block
    # The block table is only read for in-range, non-prefill entries
    base = tl.load(bt_ptr, mask=valid_block & ~is_prefill, other=0)
    out_val = base * BLOCK_SIZE + inblock_off

    # Override with prefill output if prefill is enabled
    if HAS_PREFILL:
        workspace_start = tl.load(
            workspace_starts_ptr + prefill_req_id, mask=is_prefill, other=0
        )
        prefill_out = workspace_start + tok
        out_val = tl.where(is_prefill, prefill_out, out_val)
    # Invalid entries always win, for decode and prefill rows alike
    out_val = tl.where(is_invalid_tok, -1, out_val)

    # Store results
    out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1
    tl.store(out_ptr_ij, out_val)

    # Count valid indices in this tile and atomically add to row total
    # (several tiles of the same row update the same counter)
    if COUNT_VALID:
        tile_valid_count = tl.sum((~is_invalid_tok).to(tl.int32))
        tl.atomic_add(valid_count_ptr + token_id, tile_valid_count)


def triton_convert_req_index_to_global_index(
    req_id: torch.Tensor,  # int32 [num_tokens]
    block_table: torch.Tensor,  # int32 [num_requests, max_num_blocks_per_req]
    token_indices: torch.Tensor,  # int32 [num_tokens, NUM_TOPK_TOKENS]
    BLOCK_SIZE: int = 64,
    NUM_TOPK_TOKENS: int = 2048,
    BLOCK_N: int = 128,  # tile width along columns
    HAS_PREFILL_WORKSPACE: bool = False,
    prefill_workspace_request_ids: torch.Tensor | None = None,
    prefill_workspace_starts: torch.Tensor | None = None,
    return_valid_counts: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """
    Convert per-request top-k token indices into global cache slot indices.

    out[token_id, indice_id] =
        block_table[req_id[token_id],
            token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE
        + token_indices[token_id, indice_id] % BLOCK_SIZE

    A negative token_indices[token_id, indice_id] (the -1 padding value)
    yields -1. For safety, we also output -1 if the derived block_id would be
    out-of-bounds.

    When HAS_PREFILL_WORKSPACE is True, prefill tokens are mapped to workspace offsets
    instead of global cache slots. prefill_workspace_request_ids and
    prefill_workspace_starts must be provided.

    prefill_workspace_request_ids: int32 [num_tokens], -1 for decode else
        prefill request index (maps to prefill_workspace_starts)
    prefill_workspace_starts: int32 [num_prefills], 0-indexed workspace
        starts for each prefill request

    When return_valid_counts is True, also returns the count of valid (non -1)
    output indices per row. The counts are accumulated inside the same kernel
    launch (one atomic add per column tile into a zero-initialized buffer),
    so no second pass over the indices is needed.

    Returns:
        out, an int32 tensor shaped like token_indices, when
        return_valid_counts is False; otherwise the tuple
        (out, valid_counts) where valid_counts is int32 [num_tokens].

    Raises:
        AssertionError: on a non-int32 input, a NUM_TOPK_TOKENS that does not
            match token_indices.shape[1] or is not divisible by BLOCK_N, or
            missing / non-contiguous prefill tensors when
            HAS_PREFILL_WORKSPACE is True.
    """
    assert req_id.dtype == torch.int32
    assert block_table.dtype == torch.int32
    assert token_indices.dtype == torch.int32
    assert token_indices.shape[1] == NUM_TOPK_TOKENS
    # The kernel loads and stores whole tiles without masking
    assert NUM_TOPK_TOKENS % BLOCK_N == 0, (
        f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by BLOCK_N ({BLOCK_N})"
    )

    if HAS_PREFILL_WORKSPACE:
        assert prefill_workspace_request_ids is not None
        assert prefill_workspace_starts is not None
        assert prefill_workspace_request_ids.dtype == torch.int32
        assert prefill_workspace_starts.dtype == torch.int32

    num_tokens = req_id.shape[0]
    max_num_blocks_per_req = block_table.shape[1]
    tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N

    # Make the inputs contiguous (device placement is not checked here)
    req_id_c = req_id.contiguous()
    block_table_c = block_table.contiguous()
    token_indices_c = token_indices.contiguous()
    out = torch.empty_like(token_indices_c)

    # Allocate valid count buffer if needed (must be zero-initialized for atomics)
    valid_counts: torch.Tensor | None = None
    if return_valid_counts:
        valid_counts = torch.zeros(
            num_tokens, dtype=torch.int32, device=token_indices.device
        )

    # Strides in elements
    bt_stride0, bt_stride1 = block_table_c.stride()
    ti_stride0, ti_stride1 = token_indices_c.stride()
    out_stride0, out_stride1 = out.stride()

    # Prepare prefill pointers (passed as-is, so they must already be contiguous)
    if HAS_PREFILL_WORKSPACE:
        assert prefill_workspace_request_ids is not None  # for mypy
        assert prefill_workspace_starts is not None  # for mypy
        assert prefill_workspace_request_ids.is_contiguous()
        assert prefill_workspace_starts.is_contiguous()

    # Exact 2D grid: tokens × column tiles
    grid = (num_tokens, tiles_per_row)

    # valid_counts and the prefill tensors may be None; the kernel only
    # dereferences them under COUNT_VALID / HAS_PREFILL respectively.
    _convert_req_index_to_global_index_kernel[grid](
        req_id_c,
        block_table_c,
        token_indices_c,
        out,
        valid_counts,
        prefill_workspace_request_ids,
        prefill_workspace_starts,
        # shapes / constexprs
        max_num_blocks_per_req,
        BLOCK_SIZE,
        BLOCK_N,
        HAS_PREFILL_WORKSPACE,
        return_valid_counts,
        # strides
        bt_stride0,
        bt_stride1,
        ti_stride0,
        ti_stride1,
        out_stride0,
        out_stride1,
    )

    if return_valid_counts:
        assert valid_counts is not None
        return out, valid_counts
    return out
+ "FlashInferMLASparseBackend" + ) TRITON_MLA = "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend" CUTLASS_MLA = "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend" FLASHMLA = "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend" diff --git a/vllm/v1/attention/selector.py b/vllm/v1/attention/selector.py index e364c3235..9580c1d5f 100644 --- a/vllm/v1/attention/selector.py +++ b/vllm/v1/attention/selector.py @@ -53,6 +53,7 @@ def get_attn_backend( use_sparse: bool = False, use_mm_prefix: bool = False, attn_type: str | None = None, + num_heads: int | None = None, ) -> type[AttentionBackend]: """Selects which attention backend to use and lazily imports it.""" @@ -66,7 +67,6 @@ def get_attn_backend( from vllm.config import get_current_vllm_config vllm_config = get_current_vllm_config() - backend_enum = vllm_config.attention_config.backend attn_selector_config = AttentionSelectorConfig( head_size=head_size, @@ -81,8 +81,9 @@ def get_attn_backend( ) return _cached_get_attn_backend( - backend=backend_enum, + backend=vllm_config.attention_config.backend, attn_selector_config=attn_selector_config, + num_heads=num_heads, ) @@ -90,12 +91,14 @@ def get_attn_backend( def _cached_get_attn_backend( backend, attn_selector_config: AttentionSelectorConfig, + num_heads: int | None = None, ) -> type[AttentionBackend]: from vllm.platforms import current_platform attention_cls = current_platform.get_attn_backend_cls( backend, attn_selector_config=attn_selector_config, + num_heads=num_heads, ) if not attention_cls: raise ValueError(