[Perf] Set Flashinfer sparse MLA as default backend for FP8 kv cache (#37252)

Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
This commit is contained in:
Wei Zhao
2026-03-17 16:09:20 -04:00
committed by GitHub
parent e78821b438
commit b36adfa349
3 changed files with 61 additions and 20 deletions

View File

@@ -127,8 +127,8 @@ Priority is **1 = highest** (tried first).
| 3 | `FLASH_ATTN_MLA` |
| 4 | `FLASHMLA` |
| 5 | `TRITON_MLA` |
| 6 | `FLASHMLA_SPARSE` |
| 7 | `FLASHINFER_MLA_SPARSE` |
| 6 | `FLASHINFER_MLA_SPARSE`**\*** |
| 7 | `FLASHMLA_SPARSE` |
**Ampere/Hopper (SM 8.x-9.x):**
@@ -140,6 +140,8 @@ Priority is **1 = highest** (tried first).
| 4 | `TRITON_MLA` |
| 5 | `FLASHMLA_SPARSE` |
> **\*** For sparse MLA, FP8 KV cache always prefers `FLASHINFER_MLA_SPARSE`. With BF16 KV cache, `FLASHINFER_MLA_SPARSE` is preferred for low query-head counts (<= 16), while `FLASHMLA_SPARSE` is preferred otherwise.
>
> **Note:** ROCm and CPU platforms have their own selection logic. See the platform-specific documentation for details.
## Legend

View File

@@ -1262,14 +1262,23 @@ When no backend is specified (the default):
"""
def _priority_table(title: str, backends: list[str]) -> list[str]:
def _priority_table(
title: str,
backends: list[str],
annotations: dict[str, str] | None = None,
) -> list[str]:
"""Generate a priority table for a list of backends."""
def _fmt(b: str) -> str:
suffix = annotations.get(b, "") if annotations else ""
return f"`{b}`{suffix}"
return [
f"**{title}:**",
"",
"| Priority | Backend |",
"| -------- | ------- |",
*[f"| {i} | `{b}` |" for i, b in enumerate(backends, 1)],
*[f"| {i} | {_fmt(b)} |" for i, b in enumerate(backends, 1)],
"",
]
@@ -1298,11 +1307,25 @@ def generate_priority_section(priorities: dict[str, list[str]]) -> str:
lines.extend(["### MLA Attention (DeepSeek-style)", ""])
mla_sm100_annotations = {
"FLASHINFER_MLA_SPARSE": "**\\***",
}
if "mla_sm100" in priorities:
lines.extend(_priority_table(sm100, priorities["mla_sm100"]))
lines.extend(
_priority_table(sm100, priorities["mla_sm100"], mla_sm100_annotations)
)
if "mla_default" in priorities:
lines.extend(_priority_table(ampere, priorities["mla_default"]))
if "mla_sm100" in priorities:
lines.append(
"> **\\*** For sparse MLA, FP8 KV cache always prefers "
"`FLASHINFER_MLA_SPARSE`. With BF16 KV cache, `FLASHINFER_MLA_SPARSE` "
"is preferred for low query-head counts (<= 16), while "
"`FLASHMLA_SPARSE` is preferred otherwise."
)
lines.append(">")
lines.append(
"> **Note:** ROCm and CPU platforms have their own selection logic. "
"See the platform-specific documentation for details."

View File

@@ -4,6 +4,8 @@
pynvml. However, it should not initialize cuda context.
"""
from __future__ import annotations
import os
from collections.abc import Callable
from datetime import timedelta
@@ -49,21 +51,34 @@ def _get_backend_priorities(
use_mla: bool,
device_capability: DeviceCapability,
num_heads: int | None = None,
kv_cache_dtype: CacheDType | None = None,
) -> list[AttentionBackendEnum]:
"""Get backend priorities with lazy import to avoid circular dependency."""
if use_mla:
if device_capability.major == 10:
# Prefer FlashInfer at low head counts (FlashMLA uses padding)
if num_heads is not None and num_heads <= 16:
# Sparse MLA backend priorities
# See https://github.com/vllm-project/vllm/issues/35807 for
# benchmark results
if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
# Prefer FlashInfer for fp8 kv cache
sparse_backends = [
AttentionBackendEnum.FLASHINFER_MLA_SPARSE,
AttentionBackendEnum.FLASHMLA_SPARSE,
]
else:
sparse_backends = [
AttentionBackendEnum.FLASHMLA_SPARSE,
AttentionBackendEnum.FLASHINFER_MLA_SPARSE,
]
# BF16 KV Cache
# Prefer FlashInfer at low head counts (FlashMLA uses padding)
if num_heads is not None and num_heads <= 16:
sparse_backends = [
AttentionBackendEnum.FLASHINFER_MLA_SPARSE,
AttentionBackendEnum.FLASHMLA_SPARSE,
]
else:
sparse_backends = [
AttentionBackendEnum.FLASHMLA_SPARSE,
AttentionBackendEnum.FLASHINFER_MLA_SPARSE,
]
return [
AttentionBackendEnum.FLASHINFER_MLA,
AttentionBackendEnum.CUTLASS_MLA,
@@ -165,7 +180,7 @@ class CudaPlatformBase(Platform):
pass
@classmethod
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
parallel_config = vllm_config.parallel_config
model_config = vllm_config.model_config
@@ -198,11 +213,11 @@ class CudaPlatformBase(Platform):
def get_valid_backends(
cls,
device_capability: DeviceCapability,
attn_selector_config: "AttentionSelectorConfig",
attn_selector_config: AttentionSelectorConfig,
num_heads: int | None = None,
) -> tuple[
list[tuple["AttentionBackendEnum", int]],
dict["AttentionBackendEnum", tuple[int, list[str]]],
list[tuple[AttentionBackendEnum, int]],
dict[AttentionBackendEnum, tuple[int, list[str]]],
]:
valid_backends_priorities = []
invalid_reasons: dict[AttentionBackendEnum, tuple[int, list[str]]] = {}
@@ -211,6 +226,7 @@ class CudaPlatformBase(Platform):
attn_selector_config.use_mla,
device_capability,
num_heads,
attn_selector_config.kv_cache_dtype,
)
for priority, backend in enumerate(backend_priorities):
try:
@@ -231,8 +247,8 @@ class CudaPlatformBase(Platform):
@classmethod
def get_attn_backend_cls(
cls,
selected_backend: "AttentionBackendEnum | None",
attn_selector_config: "AttentionSelectorConfig",
selected_backend: AttentionBackendEnum | None,
attn_selector_config: AttentionSelectorConfig,
num_heads: int | None = None,
) -> str:
device_capability = cls.get_device_capability()
@@ -324,7 +340,7 @@ class CudaPlatformBase(Platform):
return selected_backend.get_path()
@classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
def get_supported_vit_attn_backends(cls) -> list[AttentionBackendEnum]:
if cls.has_device_capability(80):
return [
AttentionBackendEnum.FLASH_ATTN,
@@ -345,8 +361,8 @@ class CudaPlatformBase(Platform):
cls,
head_size: int,
dtype: torch.dtype,
backend: "AttentionBackendEnum | None" = None,
) -> "AttentionBackendEnum":
backend: AttentionBackendEnum | None = None,
) -> AttentionBackendEnum:
if backend is not None:
assert backend in cls.get_supported_vit_attn_backends(), (
f"Backend {backend} is not supported for vit attention. "