[Perf] Set FlashInfer sparse MLA as default backend for FP8 KV cache (#37252)
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
This commit is contained in:
@@ -127,8 +127,8 @@ Priority is **1 = highest** (tried first).
|
||||
| 3 | `FLASH_ATTN_MLA` |
|
||||
| 4 | `FLASHMLA` |
|
||||
| 5 | `TRITON_MLA` |
|
||||
| 6 | `FLASHMLA_SPARSE` |
|
||||
| 7 | `FLASHINFER_MLA_SPARSE` |
|
||||
| 6 | `FLASHINFER_MLA_SPARSE`**\*** |
|
||||
| 7 | `FLASHMLA_SPARSE` |
|
||||
|
||||
**Ampere/Hopper (SM 8.x-9.x):**
|
||||
|
||||
@@ -140,6 +140,8 @@ Priority is **1 = highest** (tried first).
|
||||
| 4 | `TRITON_MLA` |
|
||||
| 5 | `FLASHMLA_SPARSE` |
|
||||
|
||||
> **\*** For sparse MLA, FP8 KV cache always prefers `FLASHINFER_MLA_SPARSE`. With BF16 KV cache, `FLASHINFER_MLA_SPARSE` is preferred for low query-head counts (<= 16), while `FLASHMLA_SPARSE` is preferred otherwise.
|
||||
>
|
||||
> **Note:** ROCm and CPU platforms have their own selection logic. See the platform-specific documentation for details.
|
||||
|
||||
## Legend
|
||||
|
||||
@@ -1262,14 +1262,23 @@ When no backend is specified (the default):
|
||||
"""
|
||||
|
||||
|
||||
def _priority_table(title: str, backends: list[str]) -> list[str]:
|
||||
def _priority_table(
|
||||
title: str,
|
||||
backends: list[str],
|
||||
annotations: dict[str, str] | None = None,
|
||||
) -> list[str]:
|
||||
"""Generate a priority table for a list of backends."""
|
||||
|
||||
def _fmt(b: str) -> str:
|
||||
suffix = annotations.get(b, "") if annotations else ""
|
||||
return f"`{b}`{suffix}"
|
||||
|
||||
return [
|
||||
f"**{title}:**",
|
||||
"",
|
||||
"| Priority | Backend |",
|
||||
"| -------- | ------- |",
|
||||
*[f"| {i} | `{b}` |" for i, b in enumerate(backends, 1)],
|
||||
*[f"| {i} | {_fmt(b)} |" for i, b in enumerate(backends, 1)],
|
||||
"",
|
||||
]
|
||||
|
||||
@@ -1298,11 +1307,25 @@ def generate_priority_section(priorities: dict[str, list[str]]) -> str:
|
||||
|
||||
lines.extend(["### MLA Attention (DeepSeek-style)", ""])
|
||||
|
||||
mla_sm100_annotations = {
|
||||
"FLASHINFER_MLA_SPARSE": "**\\***",
|
||||
}
|
||||
if "mla_sm100" in priorities:
|
||||
lines.extend(_priority_table(sm100, priorities["mla_sm100"]))
|
||||
lines.extend(
|
||||
_priority_table(sm100, priorities["mla_sm100"], mla_sm100_annotations)
|
||||
)
|
||||
if "mla_default" in priorities:
|
||||
lines.extend(_priority_table(ampere, priorities["mla_default"]))
|
||||
|
||||
if "mla_sm100" in priorities:
|
||||
lines.append(
|
||||
"> **\\*** For sparse MLA, FP8 KV cache always prefers "
|
||||
"`FLASHINFER_MLA_SPARSE`. With BF16 KV cache, `FLASHINFER_MLA_SPARSE` "
|
||||
"is preferred for low query-head counts (<= 16), while "
|
||||
"`FLASHMLA_SPARSE` is preferred otherwise."
|
||||
)
|
||||
lines.append(">")
|
||||
|
||||
lines.append(
|
||||
"> **Note:** ROCm and CPU platforms have their own selection logic. "
|
||||
"See the platform-specific documentation for details."
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
pynvml. However, it should not initialize cuda context.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
@@ -49,21 +51,34 @@ def _get_backend_priorities(
|
||||
use_mla: bool,
|
||||
device_capability: DeviceCapability,
|
||||
num_heads: int | None = None,
|
||||
kv_cache_dtype: CacheDType | None = None,
|
||||
) -> list[AttentionBackendEnum]:
|
||||
"""Get backend priorities with lazy import to avoid circular dependency."""
|
||||
if use_mla:
|
||||
if device_capability.major == 10:
|
||||
# Prefer FlashInfer at low head counts (FlashMLA uses padding)
|
||||
if num_heads is not None and num_heads <= 16:
|
||||
# Sparse MLA backend priorities
|
||||
# See https://github.com/vllm-project/vllm/issues/35807 for
|
||||
# benchmark results
|
||||
if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
|
||||
# Prefer FlashInfer for fp8 kv cache
|
||||
sparse_backends = [
|
||||
AttentionBackendEnum.FLASHINFER_MLA_SPARSE,
|
||||
AttentionBackendEnum.FLASHMLA_SPARSE,
|
||||
]
|
||||
else:
|
||||
sparse_backends = [
|
||||
AttentionBackendEnum.FLASHMLA_SPARSE,
|
||||
AttentionBackendEnum.FLASHINFER_MLA_SPARSE,
|
||||
]
|
||||
# Non-FP8 KV cache (e.g. BF16, or dtype not specified)
|
||||
# Prefer FlashInfer at low head counts (FlashMLA uses padding)
|
||||
if num_heads is not None and num_heads <= 16:
|
||||
sparse_backends = [
|
||||
AttentionBackendEnum.FLASHINFER_MLA_SPARSE,
|
||||
AttentionBackendEnum.FLASHMLA_SPARSE,
|
||||
]
|
||||
else:
|
||||
sparse_backends = [
|
||||
AttentionBackendEnum.FLASHMLA_SPARSE,
|
||||
AttentionBackendEnum.FLASHINFER_MLA_SPARSE,
|
||||
]
|
||||
|
||||
return [
|
||||
AttentionBackendEnum.FLASHINFER_MLA,
|
||||
AttentionBackendEnum.CUTLASS_MLA,
|
||||
@@ -165,7 +180,7 @@ class CudaPlatformBase(Platform):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
|
||||
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
|
||||
parallel_config = vllm_config.parallel_config
|
||||
model_config = vllm_config.model_config
|
||||
|
||||
@@ -198,11 +213,11 @@ class CudaPlatformBase(Platform):
|
||||
def get_valid_backends(
|
||||
cls,
|
||||
device_capability: DeviceCapability,
|
||||
attn_selector_config: "AttentionSelectorConfig",
|
||||
attn_selector_config: AttentionSelectorConfig,
|
||||
num_heads: int | None = None,
|
||||
) -> tuple[
|
||||
list[tuple["AttentionBackendEnum", int]],
|
||||
dict["AttentionBackendEnum", tuple[int, list[str]]],
|
||||
list[tuple[AttentionBackendEnum, int]],
|
||||
dict[AttentionBackendEnum, tuple[int, list[str]]],
|
||||
]:
|
||||
valid_backends_priorities = []
|
||||
invalid_reasons: dict[AttentionBackendEnum, tuple[int, list[str]]] = {}
|
||||
@@ -211,6 +226,7 @@ class CudaPlatformBase(Platform):
|
||||
attn_selector_config.use_mla,
|
||||
device_capability,
|
||||
num_heads,
|
||||
attn_selector_config.kv_cache_dtype,
|
||||
)
|
||||
for priority, backend in enumerate(backend_priorities):
|
||||
try:
|
||||
@@ -231,8 +247,8 @@ class CudaPlatformBase(Platform):
|
||||
@classmethod
|
||||
def get_attn_backend_cls(
|
||||
cls,
|
||||
selected_backend: "AttentionBackendEnum | None",
|
||||
attn_selector_config: "AttentionSelectorConfig",
|
||||
selected_backend: AttentionBackendEnum | None,
|
||||
attn_selector_config: AttentionSelectorConfig,
|
||||
num_heads: int | None = None,
|
||||
) -> str:
|
||||
device_capability = cls.get_device_capability()
|
||||
@@ -324,7 +340,7 @@ class CudaPlatformBase(Platform):
|
||||
return selected_backend.get_path()
|
||||
|
||||
@classmethod
|
||||
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
|
||||
def get_supported_vit_attn_backends(cls) -> list[AttentionBackendEnum]:
|
||||
if cls.has_device_capability(80):
|
||||
return [
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
@@ -345,8 +361,8 @@ class CudaPlatformBase(Platform):
|
||||
cls,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
backend: "AttentionBackendEnum | None" = None,
|
||||
) -> "AttentionBackendEnum":
|
||||
backend: AttentionBackendEnum | None = None,
|
||||
) -> AttentionBackendEnum:
|
||||
if backend is not None:
|
||||
assert backend in cls.get_supported_vit_attn_backends(), (
|
||||
f"Backend {backend} is not supported for vit attention. "
|
||||
|
||||
Reference in New Issue
Block a user