[Attention] Add FlashInfer Sparse MLA backend (#33451)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
This commit is contained in:
Matthew Bonanni
2026-02-12 12:21:54 -05:00
committed by GitHub
parent 334c715e0f
commit f2c47886fd
24 changed files with 1181 additions and 408 deletions

View File

@@ -8,14 +8,13 @@ This module provides helpers for running MLA backends without
needing full VllmConfig integration.
"""
import importlib
import numpy as np
import torch
from batch_spec import parse_batch_spec
from common import (
BenchmarkResult,
MockHfConfig,
MockIndexer,
MockKVBProj,
MockLayer,
setup_mla_dims,
@@ -62,6 +61,7 @@ def create_minimal_vllm_config(
block_size: int = 128,
max_num_seqs: int = 256,
mla_dims: dict | None = None,
index_topk: int | None = None,
) -> VllmConfig:
"""
Create minimal VllmConfig for MLA benchmarks.
@@ -73,6 +73,8 @@ def create_minimal_vllm_config(
max_num_seqs: Maximum number of sequences
mla_dims: Optional custom MLA dimensions dict. If not provided, uses
setup_mla_dims(model_name)
index_topk: Optional topk value for sparse MLA backends. If provided,
the config will include index_topk for sparse attention.
Returns:
VllmConfig for benchmarking
@@ -82,7 +84,7 @@ def create_minimal_vllm_config(
mla_dims = setup_mla_dims(model_name)
# Create mock HF config first (avoids downloading from HuggingFace)
mock_hf_config = MockHfConfig(mla_dims)
mock_hf_config = MockHfConfig(mla_dims, index_topk=index_topk)
# Create a temporary minimal config.json to avoid HF downloads
# This ensures consistent ModelConfig construction without network access
@@ -120,16 +122,12 @@ def create_minimal_vllm_config(
seed=0,
max_model_len=32768,
quantization=None,
quantization_param_path=None,
enforce_eager=False,
max_context_len_to_capture=None,
max_seq_len_to_capture=8192,
max_logprobs=20,
disable_sliding_window=False,
skip_tokenizer_init=True,
served_model_name=None,
limit_mm_per_prompt=None,
use_async_output_proc=True,
config_format="auto",
)
finally:
@@ -180,56 +178,65 @@ def create_minimal_vllm_config(
# ============================================================================
# Backend name to class name prefix mapping
_BACKEND_NAME_MAP = {
"flashattn_mla": "FlashAttnMLA",
"flashmla": "FlashMLA",
"flashinfer_mla": "FlashInferMLA",
"cutlass_mla": "CutlassMLA",
}
# Special properties that differ from defaults
# Backend-specific properties that can't be inferred from the backend class
# Keys are AttentionBackendEnum names (uppercase)
_BACKEND_PROPERTIES = {
"flashmla": {
"FLASHMLA": {
"query_format": "concat", # Single concatenated tensor (vs tuple)
"block_size": 64, # FlashMLA uses fixed block size
},
"flashinfer_mla": {
"block_size": 64, # FlashInfer MLA only supports 32 or 64
"FLASHMLA_SPARSE": {
"query_format": "concat", # Single concatenated tensor (vs tuple)
},
}
def _get_backend_config(backend: str) -> dict:
"""
Get backend configuration using naming conventions.
Get backend configuration from AttentionBackendEnum.
All MLA backends follow the pattern:
- Module: vllm.v1.attention.backends.mla.{backend}
- Impl: {Name}Impl
- Metadata: {Name}Metadata (or MLACommonMetadata)
- DecodeMetadata: {Name}DecodeMetadata (or MLACommonDecodeMetadata)
- MetadataBuilder: {Name}MetadataBuilder
Uses the registry to get the backend class and extract configuration
from its methods (get_impl_cls, get_builder_cls, is_sparse, etc.).
Args:
backend: Backend name matching AttentionBackendEnum exactly
(e.g., "FLASHMLA_SPARSE")
Returns:
Dict with backend configuration
"""
if backend not in _BACKEND_NAME_MAP:
raise ValueError(f"Unknown backend: {backend}")
from vllm.v1.attention.backends.registry import AttentionBackendEnum
name = _BACKEND_NAME_MAP[backend]
try:
backend_enum = AttentionBackendEnum[backend]
backend_class = backend_enum.get_class()
except (KeyError, ValueError) as e:
valid_backends = [e.name for e in AttentionBackendEnum if e.name != "CUSTOM"]
raise ValueError(
f"Unknown backend: {backend}. "
f"Valid MLA backends: {[b for b in valid_backends if 'MLA' in b]}"
) from e
# Get block size from backend class
block_sizes = backend_class.get_supported_kernel_block_sizes()
# Use first supported block size (backends typically support one for MLA)
block_size = block_sizes[0] if block_sizes else None
if hasattr(block_size, "value"):
# Handle MultipleOf enum
block_size = None
# Check if sparse via class method if available
is_sparse = getattr(backend_class, "is_sparse", lambda: False)()
# Get properties that can't be inferred
props = _BACKEND_PROPERTIES.get(backend, {})
# Check if backend uses common metadata (FlashInfer, CUTLASS)
uses_common = backend in ("flashinfer_mla", "cutlass_mla")
return {
"module": f"vllm.v1.attention.backends.mla.{backend}",
"impl_class": f"{name}Impl",
"metadata_class": "MLACommonMetadata" if uses_common else f"{name}Metadata",
"decode_metadata_class": "MLACommonDecodeMetadata"
if uses_common
else f"{name}DecodeMetadata",
"builder_class": f"{name}MetadataBuilder",
"backend_class": backend_class,
"impl_class": backend_class.get_impl_cls(),
"builder_class": backend_class.get_builder_cls(),
"query_format": props.get("query_format", "tuple"),
"block_size": props.get("block_size", None),
"block_size": block_size,
"is_sparse": is_sparse,
}
@@ -447,22 +454,26 @@ def _create_backend_impl(
mla_dims: dict,
vllm_config: VllmConfig,
device: torch.device,
max_num_tokens: int = 8192,
index_topk: int | None = None,
):
"""
Create backend implementation instance.
Args:
backend_cfg: Backend configuration dict
backend_cfg: Backend configuration dict from _get_backend_config()
mla_dims: MLA dimension configuration
vllm_config: VllmConfig instance
device: Target device
max_num_tokens: Maximum number of tokens for sparse indexer buffer
index_topk: Topk value for sparse MLA backends
Returns:
Tuple of (impl, layer, builder_instance)
Tuple of (impl, layer, builder_instance, indexer)
"""
# Import backend classes
backend_module = importlib.import_module(backend_cfg["module"])
impl_class = getattr(backend_module, backend_cfg["impl_class"])
# Get classes from backend config (already resolved by _get_backend_config)
impl_class = backend_cfg["impl_class"]
builder_class = backend_cfg["builder_class"]
# Calculate scale
scale = 1.0 / np.sqrt(mla_dims["qk_nope_head_dim"] + mla_dims["qk_rope_head_dim"])
@@ -474,26 +485,44 @@ def _create_backend_impl(
v_head_dim=mla_dims["v_head_dim"],
)
# Create indexer for sparse backends
indexer = None
if backend_cfg.get("is_sparse", False):
if index_topk is None:
index_topk = 2048 # Default topk for sparse MLA
indexer = MockIndexer(
max_num_tokens=max_num_tokens,
topk_tokens=index_topk,
device=device,
)
# Build impl kwargs
impl_kwargs = {
"num_heads": mla_dims["num_q_heads"],
"head_size": mla_dims["head_dim"],
"scale": scale,
"num_kv_heads": mla_dims["num_kv_heads"],
"alibi_slopes": None,
"sliding_window": None,
"kv_cache_dtype": "auto",
"logits_soft_cap": None,
"attn_type": "decoder",
"kv_sharing_target_layer_name": None,
"q_lora_rank": None,
"kv_lora_rank": mla_dims["kv_lora_rank"],
"qk_nope_head_dim": mla_dims["qk_nope_head_dim"],
"qk_rope_head_dim": mla_dims["qk_rope_head_dim"],
"qk_head_dim": mla_dims["qk_nope_head_dim"] + mla_dims["qk_rope_head_dim"],
"v_head_dim": mla_dims["v_head_dim"],
"kv_b_proj": mock_kv_b_proj,
}
# Add indexer for sparse backends
if indexer is not None:
impl_kwargs["indexer"] = indexer
# Create impl
impl = impl_class(
num_heads=mla_dims["num_q_heads"],
head_size=mla_dims["head_dim"],
scale=scale,
num_kv_heads=mla_dims["num_kv_heads"],
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype="auto",
logits_soft_cap=None,
attn_type="decoder",
kv_sharing_target_layer_name=None,
q_lora_rank=None,
kv_lora_rank=mla_dims["kv_lora_rank"],
qk_nope_head_dim=mla_dims["qk_nope_head_dim"],
qk_rope_head_dim=mla_dims["qk_rope_head_dim"],
qk_head_dim=mla_dims["qk_nope_head_dim"] + mla_dims["qk_rope_head_dim"],
v_head_dim=mla_dims["v_head_dim"],
kv_b_proj=mock_kv_b_proj,
)
impl = impl_class(**impl_kwargs)
# Initialize DCP attributes
if not hasattr(impl, "dcp_world_size") or impl.dcp_world_size in (None, -1):
@@ -515,9 +544,7 @@ def _create_backend_impl(
# Create builder instance if needed
builder_instance = None
if backend_cfg["builder_class"]:
builder_class = getattr(backend_module, backend_cfg["builder_class"])
if builder_class:
# Populate static_forward_context so builder can find the layer
# MockLayer inherits from AttentionLayerBase, so isinstance checks pass
vllm_config.compilation_config.static_forward_context = {"placeholder": layer}
@@ -529,7 +556,7 @@ def _create_backend_impl(
device=device,
)
return impl, layer, builder_instance
return impl, layer, builder_instance, indexer
# ============================================================================
@@ -594,6 +621,7 @@ def _run_single_benchmark(
backend_cfg: dict,
mla_dims: dict,
device: torch.device,
indexer=None,
) -> BenchmarkResult:
"""
Run a single benchmark iteration.
@@ -606,6 +634,7 @@ def _run_single_benchmark(
backend_cfg: Backend configuration dict
mla_dims: MLA dimension configuration
device: Target device
indexer: Optional MockIndexer for sparse backends
Returns:
BenchmarkResult with timing statistics
@@ -613,7 +642,9 @@ def _run_single_benchmark(
# Parse batch spec
requests = parse_batch_spec(config.batch_spec)
q_lens = [r.q_len for r in requests]
kv_lens = [r.kv_len for r in requests]
total_q = sum(q_lens)
max_kv_len = max(kv_lens)
# Determine block size
block_size = backend_cfg["block_size"] or config.block_size
@@ -641,8 +672,16 @@ def _run_single_benchmark(
torch.bfloat16,
)
# Determine which forward method to use based on metadata
if metadata.decode is not None:
# Fill indexer with random indices for sparse backends
is_sparse = backend_cfg.get("is_sparse", False)
if is_sparse and indexer is not None:
indexer.fill_random_indices(total_q, max_kv_len)
# Determine which forward method to use
if is_sparse:
# Sparse backends use forward_mqa
forward_fn = lambda: impl.forward_mqa(decode_inputs, kv_cache, metadata, layer)
elif metadata.decode is not None:
forward_fn = lambda: impl._forward_decode(
decode_inputs, kv_cache, metadata, layer
)
@@ -693,11 +732,13 @@ def _run_single_benchmark(
def _run_mla_benchmark_batched(
backend: str,
configs_with_params: list[tuple], # [(config, threshold, num_splits), ...]
index_topk: int = 2048,
) -> list[BenchmarkResult]:
"""
Unified batched MLA benchmark runner for all backends.
Works for: flashattn_mla, flashmla, flashinfer_mla, cutlass_mla
Works for: flashattn_mla, flashmla, flashinfer_mla, cutlass_mla,
flashinfer_mla_sparse, flashmla_sparse
This function reuses backend initialization across multiple benchmarks
to avoid setup/teardown overhead.
@@ -707,6 +748,7 @@ def _run_mla_benchmark_batched(
configs_with_params: List of (config, threshold, num_splits) tuples
- threshold: reorder_batch_threshold (FlashAttn/FlashMLA only)
- num_splits: num_kv_splits (CUTLASS only)
index_topk: Topk value for sparse MLA backends (default 2048)
Returns:
List of BenchmarkResult objects
@@ -730,19 +772,27 @@ def _run_mla_benchmark_batched(
if mla_dims is None:
mla_dims = setup_mla_dims("deepseek-v3")
# Determine if this is a sparse backend
is_sparse = backend_cfg.get("is_sparse", False)
# Create and set vLLM config for MLA (reused across all benchmarks)
vllm_config = create_minimal_vllm_config(
model_name="deepseek-v3", # Used only for model path
block_size=block_size,
mla_dims=mla_dims, # Use custom dims from config or default
index_topk=index_topk if is_sparse else None,
)
results = []
with set_current_vllm_config(vllm_config):
# Create backend impl, layer, and builder (reused across benchmarks)
impl, layer, builder_instance = _create_backend_impl(
backend_cfg, mla_dims, vllm_config, device
# Create backend impl, layer, builder, and indexer (reused across benchmarks)
impl, layer, builder_instance, indexer = _create_backend_impl(
backend_cfg,
mla_dims,
vllm_config,
device,
index_topk=index_topk if is_sparse else None,
)
# Run each benchmark with the shared impl
@@ -768,6 +818,7 @@ def _run_mla_benchmark_batched(
backend_cfg,
mla_dims,
device,
indexer=indexer,
)
results.append(result)
@@ -793,20 +844,24 @@ def run_mla_benchmark(
config,
reorder_batch_threshold: int | None = None,
num_kv_splits: int | None = None,
index_topk: int = 2048,
) -> BenchmarkResult | list[BenchmarkResult]:
"""
Unified MLA benchmark runner for all backends.
Works for: flashattn_mla, flashmla, flashinfer_mla, cutlass_mla
Works for: flashattn_mla, flashmla, flashinfer_mla, cutlass_mla,
flashinfer_mla_sparse, flashmla_sparse
Always uses batched execution internally for optimal performance.
Args:
backend: Backend name (flashattn_mla, flashmla, flashinfer_mla, cutlass_mla)
backend: Backend name (flashattn_mla, flashmla, flashinfer_mla, cutlass_mla,
flashinfer_mla_sparse, flashmla_sparse)
config: BenchmarkConfig or list of (BenchmarkConfig, param) tuples
reorder_batch_threshold: Threshold override for FlashAttn/FlashMLA
(single config mode only)
num_kv_splits: Number of KV splits for CUTLASS (single config mode only)
index_topk: Topk value for sparse MLA backends (default 2048)
Returns:
BenchmarkResult (single mode) or list of BenchmarkResult (batched mode)
@@ -816,9 +871,9 @@ def run_mla_benchmark(
# Already in batched format
if len(config) > 0 and isinstance(config[0], tuple):
# Format: [(cfg, param), ...] where param is threshold or num_splits
if backend in ("flashattn_mla", "flashmla"):
if backend in ("flashattn_mla", "flashmla", "flashmla_sparse"):
configs_with_params = [(cfg, param, None) for cfg, param in config]
else: # cutlass_mla or flashinfer_mla
else: # cutlass_mla, flashinfer_mla, or sparse backends
configs_with_params = [(cfg, None, param) for cfg, param in config]
else:
# Format: [cfg, ...] - just configs
@@ -830,7 +885,7 @@ def run_mla_benchmark(
return_single = True
# Use unified batched execution
results = _run_mla_benchmark_batched(backend, configs_with_params)
results = _run_mla_benchmark_batched(backend, configs_with_params, index_topk)
# Return single result or list based on input
return results[0] if return_single else results