[Attention] Add FlashInfer Sparse MLA backend (#33451)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Authored by Matthew Bonanni on 2026-02-12 12:21:54 -05:00, committed by GitHub
parent 334c715e0f
commit f2c47886fd
24 changed files with 1181 additions and 408 deletions


@@ -40,29 +40,29 @@ from vllm.v1.kv_cache_interface import FullAttentionSpec
 # ============================================================================
-_BACKEND_CONFIG = {
-    "flash": {
-        "module": "vllm.v1.attention.backends.flash_attn",
-        "backend_class": "FlashAttentionBackend",
-    },
-    "triton": {
-        "module": "vllm.v1.attention.backends.triton_attn",
-        "backend_class": "TritonAttentionBackend",
-    },
-    "flashinfer": {
-        "module": "vllm.v1.attention.backends.flashinfer",
-        "backend_class": "FlashInferBackend",
-    },
-}
-
 def _get_backend_config(backend: str) -> dict:
-    if backend not in _BACKEND_CONFIG:
-        raise ValueError(
-            f"Unknown backend: {backend}. "
-            f"Available: {', '.join(_BACKEND_CONFIG.keys())}"
-        )
-    return _BACKEND_CONFIG[backend]
+    """
+    Get backend configuration from AttentionBackendEnum.
+
+    Args:
+        backend: Backend name matching AttentionBackendEnum exactly
+            (e.g., "FLASH_ATTN", "TRITON_ATTN", "FLASHINFER")
+
+    Returns:
+        Dict with backend_class
+    """
+    from vllm.v1.attention.backends.registry import AttentionBackendEnum
+
+    try:
+        backend_enum = AttentionBackendEnum[backend]
+        backend_class = backend_enum.get_class()
+    except (KeyError, ValueError) as e:
+        valid_backends = [b.name for b in AttentionBackendEnum if b.name != "CUSTOM"]
+        raise ValueError(
+            f"Unknown backend: {backend}. Valid backends: {valid_backends}"
+        ) from e
+    return {"backend_class": backend_class}

 @contextmanager
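This hunk replaces the hand-maintained `_BACKEND_CONFIG` dict with a lookup through vLLM's `AttentionBackendEnum`: callers now pass exact enum member names ("FLASH_ATTN", "FLASHINFER") and get the backend class object back directly. A minimal, self-contained sketch of that lookup-and-resolve pattern (the enum below is a hypothetical stand-in for illustration, not vLLM's real registry):

```python
from enum import Enum


class _DemoBackend:
    """Stand-in for a real attention backend class."""


class DemoBackendEnum(Enum):
    # Hypothetical miniature of AttentionBackendEnum; the real enum lives
    # in vllm.v1.attention.backends.registry.
    FLASH_ATTN = "flash_attn"
    FLASHINFER = "flashinfer"

    def get_class(self) -> type:
        # The real get_class() resolves and imports the backend class;
        # here we just hand back the stand-in.
        return _DemoBackend


def get_backend_config(backend: str) -> dict:
    try:
        backend_class = DemoBackendEnum[backend].get_class()
    except KeyError as e:
        valid = [b.name for b in DemoBackendEnum]
        raise ValueError(f"Unknown backend: {backend}. Valid backends: {valid}") from e
    return {"backend_class": backend_class}


print(get_backend_config("FLASHINFER"))  # {'backend_class': <class '_DemoBackend'>}
```

One payoff of the change: a typo in a backend name now fails with the full list of valid enum members, rather than whatever keys the local dict happened to contain.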
@@ -205,10 +205,7 @@ def _create_backend_impl(
     dtype: torch.dtype,
 ):
     """Create backend implementation instance."""
-    import importlib
-
-    backend_module = importlib.import_module(backend_cfg["module"])
-    backend_class = getattr(backend_module, backend_cfg["backend_class"])
+    backend_class = backend_cfg["backend_class"]

     scale = get_attention_scale(config.head_dim)
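Since the registry now returns class objects, `backend_cfg["backend_class"]` is already a class and the importlib round-trip (module path string, import, then `getattr`) can be deleted. For reference, the removed pattern in stdlib-only form, with `collections.OrderedDict` standing in for a backend class:

```python
import collections
import importlib

# The pattern the deleted lines implemented: resolve a class from
# ("module", "class name") strings at runtime.
module = importlib.import_module("collections")
cls = getattr(module, "OrderedDict")
assert cls is collections.OrderedDict  # same object, just a slower route
```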
@@ -247,7 +244,7 @@ def _create_metadata_builder(
     # Flashinfer needs get_per_layer_parameters mocked since we don't have
     # real model layers registered
-    if backend_name == "flashinfer":
+    if backend_name == "FLASHINFER":
         import unittest.mock

         from vllm.v1.attention.backends.utils import PerLayerParameters
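The comment above refers to patching `get_per_layer_parameters` with `unittest.mock` so the FLASHINFER builder can run without registered model layers. The mechanics are the standard context-manager patch; a stdlib-only illustration (the real benchmark patches vLLM's helper, not `os.cpu_count`):

```python
import os
import unittest.mock

# Inside the with-block, any caller of os.cpu_count sees the substituted
# return value; on exit the original function is restored.
with unittest.mock.patch("os.cpu_count", return_value=1):
    assert os.cpu_count() == 1  # patched
# original os.cpu_count is back in place here
```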
@@ -438,7 +435,7 @@ def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult:
     """
     Run standard attention benchmark with real kernels.

-    Supports: flash, triton, flashinfer
+    Supports: FLASH_ATTN, TRITON_ATTN, FLASHINFER

     Args:
         config: Benchmark configuration
@@ -453,7 +450,7 @@ def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult:
     requests = parse_batch_spec(config.batch_spec)

-    if config.backend == "flashinfer":
+    if config.backend == "FLASHINFER":
         requests = reorder_for_flashinfer(requests)

     q_lens = [r.q_len for r in requests]
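`reorder_for_flashinfer` is the benchmark's own helper, not shown in this diff. A hedged, self-contained sketch of the decode-before-prefill partition it plausibly performs (the single-token-query criterion and the `Request` shape below are assumptions for illustration):

```python
from dataclasses import dataclass


@dataclass
class Request:
    q_len: int   # query tokens this step (1 for a decode)
    kv_len: int  # context length already in the KV cache


def reorder_for_flashinfer(requests: list[Request]) -> list[Request]:
    # Stable partition: decode requests (q_len == 1) first, then prefills,
    # matching the decode/prefill split that FlashInfer-style planning expects.
    decodes = [r for r in requests if r.q_len == 1]
    prefills = [r for r in requests if r.q_len > 1]
    return decodes + prefills


batch = [Request(8, 128), Request(1, 64), Request(1, 32)]
print(reorder_for_flashinfer(batch))  # the two decodes move to the front
```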