[Attention] Add FlashInfer Sparse MLA backend (#33451)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
This commit is contained in:
Matthew Bonanni
2026-02-12 12:21:54 -05:00
committed by GitHub
parent 334c715e0f
commit f2c47886fd
24 changed files with 1181 additions and 408 deletions

View File

@@ -16,13 +16,32 @@ from batch_spec import get_batch_type, parse_batch_spec
from rich.console import Console
from rich.table import Table
def batch_spec_sort_key(spec: str) -> tuple[int, int, int]:
    """
    Build the ordering key ``(batch_size, max_q_len, max_kv_len)`` for a spec.

    Sorting on this key orders results by number of requests first, then by
    the longest query length, then by the longest KV length, rather than
    lexicographically by the spec string. A spec that cannot be processed
    yields ``(0, 0, 0)``.
    """
    try:
        parsed = parse_batch_spec(spec)
        count = len(parsed)
        if not parsed:
            # No requests: nothing to take a maximum over.
            return (count, 0, 0)
        longest_q = max(req.q_len for req in parsed)
        longest_kv = max(req.kv_len for req in parsed)
        return (count, longest_q, longest_kv)
    except Exception:
        # Best-effort: unparseable specs share the lowest key instead of
        # aborting the whole sort.
        return (0, 0, 0)
# Mock classes for vLLM attention infrastructure
class MockHfConfig:
"""Mock HuggingFace config that satisfies vLLM's requirements."""
def __init__(self, mla_dims: dict):
def __init__(self, mla_dims: dict, index_topk: int | None = None):
self.num_attention_heads = mla_dims["num_q_heads"]
self.num_key_value_heads = mla_dims["num_kv_heads"]
self.hidden_size = mla_dims["head_dim"] * mla_dims["num_q_heads"]
@@ -33,6 +52,8 @@ class MockHfConfig:
self.qk_rope_head_dim = mla_dims["qk_rope_head_dim"]
self.v_head_dim = mla_dims["v_head_dim"]
self.qk_head_dim = mla_dims["qk_nope_head_dim"] + mla_dims["qk_rope_head_dim"]
if index_topk is not None:
self.index_topk = index_topk
def get_text_config(self):
return self
@@ -83,6 +104,38 @@ class MockKVBProj:
return (result,) # Return as tuple to match ColumnParallelLinear API
class MockIndexer:
    """Stand-in indexer for benchmarking sparse MLA backends.

    Exposes ``topk_indices_buffer``, the buffer sparse MLA backends read to
    decide which KV cache slots each token attends to. The buffer is an
    int32 tensor of shape ``(max_num_tokens, topk_tokens)`` that starts out
    filled with zeros.
    """

    def __init__(
        self,
        max_num_tokens: int,
        topk_tokens: int,
        device: torch.device,
    ):
        self.topk_tokens = topk_tokens
        buffer_shape = (max_num_tokens, topk_tokens)
        self.topk_indices_buffer = torch.zeros(
            buffer_shape, dtype=torch.int32, device=device
        )

    def fill_random_indices(self, num_tokens: int, max_kv_len: int):
        """Overwrite the first ``num_tokens`` rows with random indices.

        Each entry is drawn from ``[0, max_kv_len)``; rows beyond
        ``num_tokens`` are left untouched.
        """
        sample_shape = (num_tokens, self.topk_tokens)
        random_rows = torch.randint(
            0,
            max_kv_len,
            sample_shape,
            dtype=torch.int32,
            device=self.topk_indices_buffer.device,
        )
        self.topk_indices_buffer[:num_tokens] = random_rows
class MockLayer(AttentionLayerBase):
"""Mock attention layer with scale parameters and impl.
@@ -327,6 +380,9 @@ class ResultsFormatter:
specs_order.append(spec)
by_spec[spec][r.config.backend] = r
# Sort specs by (batch_size, q_len, kv_len) instead of alphabetically
specs_order = sorted(by_spec.keys(), key=batch_spec_sort_key)
# Create shortened backend names for display
def shorten_backend_name(name: str) -> str:
"""Shorten long backend names for table display."""
@@ -493,10 +549,11 @@ def get_attention_scale(head_dim: int) -> float:
def is_mla_backend(backend: str) -> bool:
"""
Check if backend is an MLA backend using the backend's is_mla() property.
Check if backend is an MLA backend using the AttentionBackendEnum.
Args:
backend: Backend name (e.g., "CUTLASS_MLA", "FLASHINFER_MLA")
backend: Backend name matching AttentionBackendEnum exactly
(e.g., "FLASHMLA_SPARSE")
Returns:
True if the backend is an MLA backend, False otherwise
@@ -504,7 +561,8 @@ def is_mla_backend(backend: str) -> bool:
from vllm.v1.attention.backends.registry import AttentionBackendEnum
try:
backend_class = AttentionBackendEnum[backend.upper()].get_class()
backend_enum = AttentionBackendEnum[backend]
backend_class = backend_enum.get_class()
return backend_class.is_mla()
except (KeyError, ValueError, ImportError):
except (KeyError, ValueError, ImportError, AttributeError):
return False