[Attention] Add FlashInfer Sparse MLA backend (#33451)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
@@ -8,14 +8,13 @@ This module provides helpers for running MLA backends without
 needing full VllmConfig integration.
 """

-import importlib
-
 import numpy as np
 import torch
 from batch_spec import parse_batch_spec
 from common import (
     BenchmarkResult,
     MockHfConfig,
+    MockIndexer,
     MockKVBProj,
     MockLayer,
     setup_mla_dims,
@@ -62,6 +61,7 @@ def create_minimal_vllm_config(
     block_size: int = 128,
     max_num_seqs: int = 256,
     mla_dims: dict | None = None,
+    index_topk: int | None = None,
 ) -> VllmConfig:
     """
     Create minimal VllmConfig for MLA benchmarks.
@@ -73,6 +73,8 @@ def create_minimal_vllm_config(
         max_num_seqs: Maximum number of sequences
         mla_dims: Optional custom MLA dimensions dict. If not provided, uses
             setup_mla_dims(model_name)
+        index_topk: Optional topk value for sparse MLA backends. If provided,
+            the config will include index_topk for sparse attention.

     Returns:
         VllmConfig for benchmarking
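
For illustration, a minimal usage sketch of the new parameter (the argument values are placeholders; create_minimal_vllm_config and setup_mla_dims are the helpers in this file, and 2048 matches the sparse topk default used later in this diff):

    # Hypothetical call; index_topk=None keeps the dense behavior.
    vllm_config = create_minimal_vllm_config(
        model_name="deepseek-v3",
        block_size=64,
        mla_dims=setup_mla_dims("deepseek-v3"),
        index_topk=2048,  # adds index_topk to the mock HF config for sparse MLA
    )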
@@ -82,7 +84,7 @@ def create_minimal_vllm_config(
         mla_dims = setup_mla_dims(model_name)

     # Create mock HF config first (avoids downloading from HuggingFace)
-    mock_hf_config = MockHfConfig(mla_dims)
+    mock_hf_config = MockHfConfig(mla_dims, index_topk=index_topk)

     # Create a temporary minimal config.json to avoid HF downloads
     # This ensures consistent ModelConfig construction without network access
@@ -120,16 +122,12 @@ def create_minimal_vllm_config(
             seed=0,
             max_model_len=32768,
             quantization=None,
-            quantization_param_path=None,
             enforce_eager=False,
-            max_context_len_to_capture=None,
             max_seq_len_to_capture=8192,
             max_logprobs=20,
             disable_sliding_window=False,
             skip_tokenizer_init=True,
             served_model_name=None,
-            limit_mm_per_prompt=None,
-            use_async_output_proc=True,
             config_format="auto",
         )
     finally:
@@ -180,56 +178,65 @@ def create_minimal_vllm_config(
 # ============================================================================


-# Backend name to class name prefix mapping
-_BACKEND_NAME_MAP = {
-    "flashattn_mla": "FlashAttnMLA",
-    "flashmla": "FlashMLA",
-    "flashinfer_mla": "FlashInferMLA",
-    "cutlass_mla": "CutlassMLA",
-}
-
-# Special properties that differ from defaults
+# Backend-specific properties that can't be inferred from the backend class
+# Keys are AttentionBackendEnum names (uppercase)
 _BACKEND_PROPERTIES = {
-    "flashmla": {
+    "FLASHMLA": {
         "query_format": "concat",  # Single concatenated tensor (vs tuple)
-        "block_size": 64,  # FlashMLA uses fixed block size
     },
-    "flashinfer_mla": {
-        "block_size": 64,  # FlashInfer MLA only supports 32 or 64
+    "FLASHMLA_SPARSE": {
+        "query_format": "concat",  # Single concatenated tensor (vs tuple)
     },
 }


 def _get_backend_config(backend: str) -> dict:
     """
-    Get backend configuration using naming conventions.
+    Get backend configuration from AttentionBackendEnum.

-    All MLA backends follow the pattern:
-    - Module: vllm.v1.attention.backends.mla.{backend}
-    - Impl: {Name}Impl
-    - Metadata: {Name}Metadata (or MLACommonMetadata)
-    - DecodeMetadata: {Name}DecodeMetadata (or MLACommonDecodeMetadata)
-    - MetadataBuilder: {Name}MetadataBuilder
+    Uses the registry to get the backend class and extract configuration
+    from its methods (get_impl_cls, get_builder_cls, is_sparse, etc.).
+
+    Args:
+        backend: Backend name matching AttentionBackendEnum exactly
+            (e.g., "FLASHMLA_SPARSE")

     Returns:
         Dict with backend configuration
     """
-    if backend not in _BACKEND_NAME_MAP:
-        raise ValueError(f"Unknown backend: {backend}")
+    from vllm.v1.attention.backends.registry import AttentionBackendEnum

-    name = _BACKEND_NAME_MAP[backend]
+    try:
+        backend_enum = AttentionBackendEnum[backend]
+        backend_class = backend_enum.get_class()
+    except (KeyError, ValueError) as e:
+        valid_backends = [e.name for e in AttentionBackendEnum if e.name != "CUSTOM"]
+        raise ValueError(
+            f"Unknown backend: {backend}. "
+            f"Valid MLA backends: {[b for b in valid_backends if 'MLA' in b]}"
+        ) from e
+
+    # Get block size from backend class
+    block_sizes = backend_class.get_supported_kernel_block_sizes()
+    # Use first supported block size (backends typically support one for MLA)
+    block_size = block_sizes[0] if block_sizes else None
+    if hasattr(block_size, "value"):
+        # Handle MultipleOf enum
+        block_size = None
+
+    # Check if sparse via class method if available
+    is_sparse = getattr(backend_class, "is_sparse", lambda: False)()

     # Get properties that can't be inferred
     props = _BACKEND_PROPERTIES.get(backend, {})

-    # Check if backend uses common metadata (FlashInfer, CUTLASS)
-    uses_common = backend in ("flashinfer_mla", "cutlass_mla")
-
     return {
-        "module": f"vllm.v1.attention.backends.mla.{backend}",
-        "impl_class": f"{name}Impl",
-        "metadata_class": "MLACommonMetadata" if uses_common else f"{name}Metadata",
-        "decode_metadata_class": "MLACommonDecodeMetadata"
-        if uses_common
-        else f"{name}DecodeMetadata",
-        "builder_class": f"{name}MetadataBuilder",
+        "backend_class": backend_class,
+        "impl_class": backend_class.get_impl_cls(),
+        "builder_class": backend_class.get_builder_cls(),
         "query_format": props.get("query_format", "tuple"),
-        "block_size": props.get("block_size", None),
+        "block_size": block_size,
+        "is_sparse": is_sparse,
     }

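A short sketch of how the registry-driven helper above might be called (the dict keys come from the return statement above; the resolved classes depend on the installed vLLM build, so treat this as illustrative):

    # Hypothetical usage of _get_backend_config.
    cfg = _get_backend_config("FLASHMLA_SPARSE")
    assert cfg["is_sparse"] is True
    impl_cls = cfg["impl_class"]        # from backend_class.get_impl_cls()
    builder_cls = cfg["builder_class"]  # from backend_class.get_builder_cls()
    print(cfg["query_format"], cfg["block_size"])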
@@ -447,22 +454,26 @@ def _create_backend_impl(
     mla_dims: dict,
     vllm_config: VllmConfig,
     device: torch.device,
+    max_num_tokens: int = 8192,
+    index_topk: int | None = None,
 ):
     """
     Create backend implementation instance.

     Args:
-        backend_cfg: Backend configuration dict
+        backend_cfg: Backend configuration dict from _get_backend_config()
         mla_dims: MLA dimension configuration
         vllm_config: VllmConfig instance
         device: Target device
+        max_num_tokens: Maximum number of tokens for sparse indexer buffer
+        index_topk: Topk value for sparse MLA backends

     Returns:
-        Tuple of (impl, layer, builder_instance)
+        Tuple of (impl, layer, builder_instance, indexer)
     """
-    # Import backend classes
-    backend_module = importlib.import_module(backend_cfg["module"])
-    impl_class = getattr(backend_module, backend_cfg["impl_class"])
+    # Get classes from backend config (already resolved by _get_backend_config)
+    impl_class = backend_cfg["impl_class"]
+    builder_class = backend_cfg["builder_class"]

     # Calculate scale
     scale = 1.0 / np.sqrt(mla_dims["qk_nope_head_dim"] + mla_dims["qk_rope_head_dim"])
@@ -474,26 +485,44 @@ def _create_backend_impl(
         v_head_dim=mla_dims["v_head_dim"],
     )

+    # Create indexer for sparse backends
+    indexer = None
+    if backend_cfg.get("is_sparse", False):
+        if index_topk is None:
+            index_topk = 2048  # Default topk for sparse MLA
+        indexer = MockIndexer(
+            max_num_tokens=max_num_tokens,
+            topk_tokens=index_topk,
+            device=device,
+        )
+
+    # Build impl kwargs
+    impl_kwargs = {
+        "num_heads": mla_dims["num_q_heads"],
+        "head_size": mla_dims["head_dim"],
+        "scale": scale,
+        "num_kv_heads": mla_dims["num_kv_heads"],
+        "alibi_slopes": None,
+        "sliding_window": None,
+        "kv_cache_dtype": "auto",
+        "logits_soft_cap": None,
+        "attn_type": "decoder",
+        "kv_sharing_target_layer_name": None,
+        "q_lora_rank": None,
+        "kv_lora_rank": mla_dims["kv_lora_rank"],
+        "qk_nope_head_dim": mla_dims["qk_nope_head_dim"],
+        "qk_rope_head_dim": mla_dims["qk_rope_head_dim"],
+        "qk_head_dim": mla_dims["qk_nope_head_dim"] + mla_dims["qk_rope_head_dim"],
+        "v_head_dim": mla_dims["v_head_dim"],
+        "kv_b_proj": mock_kv_b_proj,
+    }
+
+    # Add indexer for sparse backends
+    if indexer is not None:
+        impl_kwargs["indexer"] = indexer
+
     # Create impl
-    impl = impl_class(
-        num_heads=mla_dims["num_q_heads"],
-        head_size=mla_dims["head_dim"],
-        scale=scale,
-        num_kv_heads=mla_dims["num_kv_heads"],
-        alibi_slopes=None,
-        sliding_window=None,
-        kv_cache_dtype="auto",
-        logits_soft_cap=None,
-        attn_type="decoder",
-        kv_sharing_target_layer_name=None,
-        q_lora_rank=None,
-        kv_lora_rank=mla_dims["kv_lora_rank"],
-        qk_nope_head_dim=mla_dims["qk_nope_head_dim"],
-        qk_rope_head_dim=mla_dims["qk_rope_head_dim"],
-        qk_head_dim=mla_dims["qk_nope_head_dim"] + mla_dims["qk_rope_head_dim"],
-        v_head_dim=mla_dims["v_head_dim"],
-        kv_b_proj=mock_kv_b_proj,
-    )
+    impl = impl_class(**impl_kwargs)

     # Initialize DCP attributes
     if not hasattr(impl, "dcp_world_size") or impl.dcp_world_size in (None, -1):
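
MockIndexer itself lives in the local common module, which this diff does not show. A minimal sketch of the interface the code above relies on (the constructor keywords and fill_random_indices come from the call sites in this file; the buffer layout is an assumption):

    import torch

    class MockIndexerSketch:
        # Assumed shape: a (max_num_tokens, topk_tokens) int32 buffer of
        # per-token KV indices, refilled before each benchmark run.
        def __init__(self, max_num_tokens: int, topk_tokens: int, device: torch.device):
            self.topk_tokens = topk_tokens
            self.topk_indices_buffer = torch.empty(
                max_num_tokens, topk_tokens, dtype=torch.int32, device=device
            )

        def fill_random_indices(self, total_q: int, max_kv_len: int) -> None:
            # Random KV positions stand in for a real indexer's top-k selection.
            self.topk_indices_buffer[:total_q] = torch.randint(
                0, max_kv_len, (total_q, self.topk_tokens),
                dtype=torch.int32, device=self.topk_indices_buffer.device,
            )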
@@ -515,9 +544,7 @@ def _create_backend_impl(

     # Create builder instance if needed
     builder_instance = None
-    if backend_cfg["builder_class"]:
-        builder_class = getattr(backend_module, backend_cfg["builder_class"])
-
+    if builder_class:
         # Populate static_forward_context so builder can find the layer
         # MockLayer inherits from AttentionLayerBase, so isinstance checks pass
         vllm_config.compilation_config.static_forward_context = {"placeholder": layer}
@@ -529,7 +556,7 @@ def _create_backend_impl(
             device=device,
         )

-    return impl, layer, builder_instance
+    return impl, layer, builder_instance, indexer


 # ============================================================================
@@ -594,6 +621,7 @@ def _run_single_benchmark(
     backend_cfg: dict,
     mla_dims: dict,
     device: torch.device,
+    indexer=None,
 ) -> BenchmarkResult:
     """
     Run a single benchmark iteration.
@@ -606,6 +634,7 @@ def _run_single_benchmark(
         backend_cfg: Backend configuration dict
         mla_dims: MLA dimension configuration
         device: Target device
+        indexer: Optional MockIndexer for sparse backends

     Returns:
         BenchmarkResult with timing statistics
@@ -613,7 +642,9 @@ def _run_single_benchmark(
     # Parse batch spec
     requests = parse_batch_spec(config.batch_spec)
     q_lens = [r.q_len for r in requests]
     kv_lens = [r.kv_len for r in requests]
+    total_q = sum(q_lens)
+    max_kv_len = max(kv_lens)

     # Determine block size
     block_size = backend_cfg["block_size"] or config.block_size
@@ -641,8 +672,16 @@ def _run_single_benchmark(
         torch.bfloat16,
     )

-    # Determine which forward method to use based on metadata
-    if metadata.decode is not None:
+    # Fill indexer with random indices for sparse backends
+    is_sparse = backend_cfg.get("is_sparse", False)
+    if is_sparse and indexer is not None:
+        indexer.fill_random_indices(total_q, max_kv_len)
+
+    # Determine which forward method to use
+    if is_sparse:
+        # Sparse backends use forward_mqa
+        forward_fn = lambda: impl.forward_mqa(decode_inputs, kv_cache, metadata, layer)
+    elif metadata.decode is not None:
         forward_fn = lambda: impl._forward_decode(
             decode_inputs, kv_cache, metadata, layer
         )
@@ -693,11 +732,13 @@ def _run_single_benchmark(
 def _run_mla_benchmark_batched(
     backend: str,
     configs_with_params: list[tuple],  # [(config, threshold, num_splits), ...]
+    index_topk: int = 2048,
 ) -> list[BenchmarkResult]:
     """
     Unified batched MLA benchmark runner for all backends.

-    Works for: flashattn_mla, flashmla, flashinfer_mla, cutlass_mla
+    Works for: flashattn_mla, flashmla, flashinfer_mla, cutlass_mla,
+        flashinfer_mla_sparse, flashmla_sparse

     This function reuses backend initialization across multiple benchmarks
     to avoid setup/teardown overhead.
@@ -707,6 +748,7 @@ def _run_mla_benchmark_batched(
         configs_with_params: List of (config, threshold, num_splits) tuples
             - threshold: reorder_batch_threshold (FlashAttn/FlashMLA only)
            - num_splits: num_kv_splits (CUTLASS only)
+        index_topk: Topk value for sparse MLA backends (default 2048)

     Returns:
         List of BenchmarkResult objects
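
For illustration, a hedged sketch of the batched input format described above (BenchmarkConfig and its batch_spec strings are placeholders assumed from the attributes this file reads; the threshold/num_splits slots are None here):

    # Hypothetical batched invocation for a sparse backend.
    configs_with_params = [
        (BenchmarkConfig(batch_spec="decode_64", block_size=64), None, None),
        (BenchmarkConfig(batch_spec="decode_128", block_size=64), None, None),
    ]
    results = _run_mla_benchmark_batched(
        "flashinfer_mla_sparse", configs_with_params, index_topk=2048
    )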
@@ -730,19 +772,27 @@ def _run_mla_benchmark_batched(
     if mla_dims is None:
         mla_dims = setup_mla_dims("deepseek-v3")

+    # Determine if this is a sparse backend
+    is_sparse = backend_cfg.get("is_sparse", False)
+
     # Create and set vLLM config for MLA (reused across all benchmarks)
     vllm_config = create_minimal_vllm_config(
         model_name="deepseek-v3",  # Used only for model path
         block_size=block_size,
         mla_dims=mla_dims,  # Use custom dims from config or default
+        index_topk=index_topk if is_sparse else None,
     )

     results = []

     with set_current_vllm_config(vllm_config):
-        # Create backend impl, layer, and builder (reused across benchmarks)
-        impl, layer, builder_instance = _create_backend_impl(
-            backend_cfg, mla_dims, vllm_config, device
+        # Create backend impl, layer, builder, and indexer (reused across benchmarks)
+        impl, layer, builder_instance, indexer = _create_backend_impl(
+            backend_cfg,
+            mla_dims,
+            vllm_config,
+            device,
+            index_topk=index_topk if is_sparse else None,
         )

         # Run each benchmark with the shared impl
@@ -768,6 +818,7 @@ def _run_mla_benchmark_batched(
                 backend_cfg,
                 mla_dims,
                 device,
+                indexer=indexer,
             )
             results.append(result)

@@ -793,20 +844,24 @@ def run_mla_benchmark(
     config,
     reorder_batch_threshold: int | None = None,
     num_kv_splits: int | None = None,
+    index_topk: int = 2048,
 ) -> BenchmarkResult | list[BenchmarkResult]:
     """
     Unified MLA benchmark runner for all backends.

-    Works for: flashattn_mla, flashmla, flashinfer_mla, cutlass_mla
+    Works for: flashattn_mla, flashmla, flashinfer_mla, cutlass_mla,
+        flashinfer_mla_sparse, flashmla_sparse

     Always uses batched execution internally for optimal performance.

     Args:
-        backend: Backend name (flashattn_mla, flashmla, flashinfer_mla, cutlass_mla)
+        backend: Backend name (flashattn_mla, flashmla, flashinfer_mla, cutlass_mla,
+            flashinfer_mla_sparse, flashmla_sparse)
         config: BenchmarkConfig or list of (BenchmarkConfig, param) tuples
         reorder_batch_threshold: Threshold override for FlashAttn/FlashMLA
             (single config mode only)
         num_kv_splits: Number of KV splits for CUTLASS (single config mode only)
+        index_topk: Topk value for sparse MLA backends (default 2048)

     Returns:
         BenchmarkResult (single mode) or list of BenchmarkResult (batched mode)
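
A usage sketch of the two calling modes the docstring describes (BenchmarkConfig and its batch_spec values are placeholders assumed from this suite):

    # Single-config mode: returns one BenchmarkResult.
    result = run_mla_benchmark(
        "flashinfer_mla_sparse",
        BenchmarkConfig(batch_spec="decode_64", block_size=64),
        index_topk=2048,
    )

    # Batched mode: a list of (config, param) tuples returns a list of results.
    results = run_mla_benchmark(
        "flashmla_sparse",
        [(BenchmarkConfig(batch_spec="decode_64", block_size=64), None)],
    )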
@@ -816,9 +871,9 @@ def run_mla_benchmark(
         # Already in batched format
         if len(config) > 0 and isinstance(config[0], tuple):
             # Format: [(cfg, param), ...] where param is threshold or num_splits
-            if backend in ("flashattn_mla", "flashmla"):
+            if backend in ("flashattn_mla", "flashmla", "flashmla_sparse"):
                 configs_with_params = [(cfg, param, None) for cfg, param in config]
-            else:  # cutlass_mla or flashinfer_mla
+            else:  # cutlass_mla, flashinfer_mla, or sparse backends
                 configs_with_params = [(cfg, None, param) for cfg, param in config]
         else:
             # Format: [cfg, ...] - just configs
@@ -830,7 +885,7 @@ def run_mla_benchmark(
        return_single = True

     # Use unified batched execution
-    results = _run_mla_benchmark_batched(backend, configs_with_params)
+    results = _run_mla_benchmark_batched(backend, configs_with_params, index_topk)

     # Return single result or list based on input
     return results[0] if return_single else results