[Attention] Add FlashInfer Sparse MLA backend (#33451)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
@@ -40,29 +40,29 @@ from vllm.v1.kv_cache_interface import FullAttentionSpec
 # ============================================================================
-_BACKEND_CONFIG = {
-    "flash": {
-        "module": "vllm.v1.attention.backends.flash_attn",
-        "backend_class": "FlashAttentionBackend",
-    },
-    "triton": {
-        "module": "vllm.v1.attention.backends.triton_attn",
-        "backend_class": "TritonAttentionBackend",
-    },
-    "flashinfer": {
-        "module": "vllm.v1.attention.backends.flashinfer",
-        "backend_class": "FlashInferBackend",
-    },
-}
-
-
 def _get_backend_config(backend: str) -> dict:
-    if backend not in _BACKEND_CONFIG:
-        raise ValueError(
-            f"Unknown backend: {backend}. "
-            f"Available: {', '.join(_BACKEND_CONFIG.keys())}"
-        )
-    return _BACKEND_CONFIG[backend]
+    """
+    Get backend configuration from AttentionBackendEnum.
+
+    Args:
+        backend: Backend name matching AttentionBackendEnum exactly
+            (e.g., "FLASH_ATTN", "TRITON_ATTN", "FLASHINFER")
+
+    Returns:
+        Dict with backend_class
+    """
+    from vllm.v1.attention.backends.registry import AttentionBackendEnum
+
+    try:
+        backend_enum = AttentionBackendEnum[backend]
+        backend_class = backend_enum.get_class()
+    except (KeyError, ValueError) as e:
+        valid_backends = [b.name for b in AttentionBackendEnum if b.name != "CUSTOM"]
+        raise ValueError(
+            f"Unknown backend: {backend}. Valid backends: {valid_backends}"
+        ) from e
+
+    return {"backend_class": backend_class}
 
 
 @contextmanager
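
A quick sketch of how the rewritten helper behaves at a hypothetical call site; the enum member names and the error text come from the hunk above, everything else is illustrative:

    # Hypothetical usage; backend names must match AttentionBackendEnum members.
    cfg = _get_backend_config("FLASH_ATTN")
    backend_class = cfg["backend_class"]  # a class object, e.g. FlashAttentionBackend

    try:
        _get_backend_config("flash")  # old lowercase keys are no longer accepted
    except ValueError as err:
        print(err)  # "Unknown backend: flash. Valid backends: [...]"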
@@ -205,10 +205,7 @@ def _create_backend_impl(
     dtype: torch.dtype,
 ):
     """Create backend implementation instance."""
-    import importlib
-
-    backend_module = importlib.import_module(backend_cfg["module"])
-    backend_class = getattr(backend_module, backend_cfg["backend_class"])
+    backend_class = backend_cfg["backend_class"]
 
     scale = get_attention_scale(config.head_dim)
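
Since AttentionBackendEnum.get_class() already resolves the backend class object, the importlib round-trip removed above becomes dead weight. A sketch of the simplified call site (hypothetical wiring):

    # backend_cfg comes from _get_backend_config; "backend_class" now holds
    # the class itself rather than a string naming it.
    backend_cfg = _get_backend_config("TRITON_ATTN")
    backend_class = backend_cfg["backend_class"]
    print(backend_class.__name__)  # e.g. "TritonAttentionBackend"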
@@ -247,7 +244,7 @@ def _create_metadata_builder(
 
     # Flashinfer needs get_per_layer_parameters mocked since we don't have
     # real model layers registered
-    if backend_name == "flashinfer":
+    if backend_name == "FLASHINFER":
         import unittest.mock
 
         from vllm.v1.attention.backends.utils import PerLayerParameters
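
For context, a minimal sketch of the kind of mock this hunk refers to. The patch target string and the PerLayerParameters fields (window_left, logits_soft_cap, sm_scale) are assumptions about vLLM internals, not taken from this diff; verify them against your checkout:

    import unittest.mock

    from vllm.v1.attention.backends.utils import PerLayerParameters

    def _uniform_params(*args, **kwargs):
        # Same parameters for every layer; enough for a kernel benchmark with
        # no real model layers registered. Field names here are assumed.
        return {"mock_layer": PerLayerParameters(
            window_left=-1, logits_soft_cap=None, sm_scale=1.0 / (128 ** 0.5))}

    # The patch target is an assumption; point it at wherever the FlashInfer
    # metadata builder resolves get_per_layer_parameters from.
    with unittest.mock.patch(
        "vllm.v1.attention.backends.flashinfer.get_per_layer_parameters",
        _uniform_params,
    ):
        ...  # build FlashInfer attention metadata here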
@@ -438,7 +435,7 @@ def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult:
     """
     Run standard attention benchmark with real kernels.
 
-    Supports: flash, triton, flashinfer
+    Supports: FLASH_ATTN, TRITON_ATTN, FLASHINFER
 
     Args:
         config: Benchmark configuration
@@ -453,7 +450,7 @@ def run_attention_benchmark(config: BenchmarkConfig) -> BenchmarkResult:
 
     requests = parse_batch_spec(config.batch_spec)
 
-    if config.backend == "flashinfer":
+    if config.backend == "FLASHINFER":
         requests = reorder_for_flashinfer(requests)
 
     q_lens = [r.q_len for r in requests]
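
reorder_for_flashinfer itself is not shown in this diff. One plausible implementation, assuming FlashInfer wants pure-decode requests (q_len == 1) grouped ahead of prefills, in line with the decode/prefill split vLLM's FlashInfer metadata path expects:

    def reorder_for_flashinfer(requests):
        # Stable partition: decodes (q_len == 1) first, prefills after.
        # Treat this as a sketch, not the benchmark's actual code.
        decodes = [r for r in requests if r.q_len == 1]
        prefills = [r for r in requests if r.q_len > 1]
        return decodes + prefills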