diff --git a/docs/design/attention_backends.md b/docs/design/attention_backends.md
index 7c60a136f..ae9dfb02b 100644
--- a/docs/design/attention_backends.md
+++ b/docs/design/attention_backends.md
@@ -127,8 +127,8 @@ Priority is **1 = highest** (tried first).
 | 3 | `FLASH_ATTN_MLA` |
 | 4 | `FLASHMLA` |
 | 5 | `TRITON_MLA` |
-| 6 | `FLASHMLA_SPARSE` |
-| 7 | `FLASHINFER_MLA_SPARSE` |
+| 6 | `FLASHINFER_MLA_SPARSE`**\*** |
+| 7 | `FLASHMLA_SPARSE` |
 
 **Ampere/Hopper (SM 8.x-9.x):**
 
@@ -140,6 +140,8 @@ Priority is **1 = highest** (tried first).
 | 4 | `TRITON_MLA` |
 | 5 | `FLASHMLA_SPARSE` |
 
+> **\*** For sparse MLA, FP8 KV cache always prefers `FLASHINFER_MLA_SPARSE`. With BF16 KV cache, `FLASHINFER_MLA_SPARSE` is preferred for low query-head counts (<= 16), while `FLASHMLA_SPARSE` is preferred otherwise.
+>
 > **Note:** ROCm and CPU platforms have their own selection logic. See the platform-specific documentation for details.
 
 ## Legend
diff --git a/tools/pre_commit/generate_attention_backend_docs.py b/tools/pre_commit/generate_attention_backend_docs.py
index 2df46db81..078404f21 100644
--- a/tools/pre_commit/generate_attention_backend_docs.py
+++ b/tools/pre_commit/generate_attention_backend_docs.py
@@ -1262,14 +1262,23 @@ When no backend is specified (the default):
 """
 
 
-def _priority_table(title: str, backends: list[str]) -> list[str]:
+def _priority_table(
+    title: str,
+    backends: list[str],
+    annotations: dict[str, str] | None = None,
+) -> list[str]:
     """Generate a priority table for a list of backends."""
+
+    def _fmt(b: str) -> str:
+        suffix = annotations.get(b, "") if annotations else ""
+        return f"`{b}`{suffix}"
+
     return [
         f"**{title}:**",
         "",
         "| Priority | Backend |",
         "| -------- | ------- |",
-        *[f"| {i} | `{b}` |" for i, b in enumerate(backends, 1)],
+        *[f"| {i} | {_fmt(b)} |" for i, b in enumerate(backends, 1)],
         "",
     ]
 
@@ -1298,11 +1307,25 @@ def generate_priority_section(priorities: dict[str, list[str]]) -> str:
 
     lines.extend(["### MLA Attention (DeepSeek-style)", ""])
 
+    mla_sm100_annotations = {
+        "FLASHINFER_MLA_SPARSE": "**\\***",
+    }
     if "mla_sm100" in priorities:
-        lines.extend(_priority_table(sm100, priorities["mla_sm100"]))
+        lines.extend(
+            _priority_table(sm100, priorities["mla_sm100"], mla_sm100_annotations)
+        )
     if "mla_default" in priorities:
         lines.extend(_priority_table(ampere, priorities["mla_default"]))
 
+    if "mla_sm100" in priorities:
+        lines.append(
+            "> **\\*** For sparse MLA, FP8 KV cache always prefers "
+            "`FLASHINFER_MLA_SPARSE`. With BF16 KV cache, `FLASHINFER_MLA_SPARSE` "
+            "is preferred for low query-head counts (<= 16), while "
+            "`FLASHMLA_SPARSE` is preferred otherwise."
+        )
+        lines.append(">")
+
     lines.append(
         "> **Note:** ROCm and CPU platforms have their own selection logic. "
         "See the platform-specific documentation for details."
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 2025c41ab..8bf6c8e4b 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -4,6 +4,8 @@
 pynvml. However, it should not initialize cuda context.
 """
 
+from __future__ import annotations
+
 import os
 from collections.abc import Callable
 from datetime import timedelta
@@ -49,21 +51,34 @@ def _get_backend_priorities(
     use_mla: bool,
     device_capability: DeviceCapability,
     num_heads: int | None = None,
+    kv_cache_dtype: CacheDType | None = None,
 ) -> list[AttentionBackendEnum]:
     """Get backend priorities with lazy import to avoid circular dependency."""
     if use_mla:
         if device_capability.major == 10:
-            # Prefer FlashInfer at low head counts (FlashMLA uses padding)
-            if num_heads is not None and num_heads <= 16:
+            # Sparse MLA backend priorities
+            # See https://github.com/vllm-project/vllm/issues/35807 for
+            # benchmark results
+            if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
+                # Prefer FlashInfer for fp8 kv cache
                 sparse_backends = [
                     AttentionBackendEnum.FLASHINFER_MLA_SPARSE,
                     AttentionBackendEnum.FLASHMLA_SPARSE,
                 ]
             else:
-                sparse_backends = [
-                    AttentionBackendEnum.FLASHMLA_SPARSE,
-                    AttentionBackendEnum.FLASHINFER_MLA_SPARSE,
-                ]
+                # BF16 KV Cache
+                # Prefer FlashInfer at low head counts (FlashMLA uses padding)
+                if num_heads is not None and num_heads <= 16:
+                    sparse_backends = [
+                        AttentionBackendEnum.FLASHINFER_MLA_SPARSE,
+                        AttentionBackendEnum.FLASHMLA_SPARSE,
+                    ]
+                else:
+                    sparse_backends = [
+                        AttentionBackendEnum.FLASHMLA_SPARSE,
+                        AttentionBackendEnum.FLASHINFER_MLA_SPARSE,
+                    ]
+
             return [
                 AttentionBackendEnum.FLASHINFER_MLA,
                 AttentionBackendEnum.CUTLASS_MLA,
@@ -165,7 +180,7 @@ class CudaPlatformBase(Platform):
         pass
 
     @classmethod
-    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
+    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         parallel_config = vllm_config.parallel_config
         model_config = vllm_config.model_config
 
@@ -198,11 +213,11 @@
     def get_valid_backends(
         cls,
         device_capability: DeviceCapability,
-        attn_selector_config: "AttentionSelectorConfig",
+        attn_selector_config: AttentionSelectorConfig,
         num_heads: int | None = None,
     ) -> tuple[
-        list[tuple["AttentionBackendEnum", int]],
-        dict["AttentionBackendEnum", tuple[int, list[str]]],
+        list[tuple[AttentionBackendEnum, int]],
+        dict[AttentionBackendEnum, tuple[int, list[str]]],
     ]:
         valid_backends_priorities = []
         invalid_reasons: dict[AttentionBackendEnum, tuple[int, list[str]]] = {}
@@ -211,6 +226,7 @@
             attn_selector_config.use_mla,
             device_capability,
             num_heads,
+            attn_selector_config.kv_cache_dtype,
         )
         for priority, backend in enumerate(backend_priorities):
             try:
@@ -231,8 +247,8 @@ class CudaPlatformBase(Platform):
     @classmethod
     def get_attn_backend_cls(
         cls,
-        selected_backend: "AttentionBackendEnum | None",
-        attn_selector_config: "AttentionSelectorConfig",
+        selected_backend: AttentionBackendEnum | None,
+        attn_selector_config: AttentionSelectorConfig,
         num_heads: int | None = None,
     ) -> str:
         device_capability = cls.get_device_capability()
@@ -324,7 +340,7 @@ class CudaPlatformBase(Platform):
         return selected_backend.get_path()
 
     @classmethod
-    def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
+    def get_supported_vit_attn_backends(cls) -> list[AttentionBackendEnum]:
         if cls.has_device_capability(80):
             return [
                 AttentionBackendEnum.FLASH_ATTN,
@@ -345,8 +361,8 @@ class CudaPlatformBase(Platform):
         cls,
         head_size: int,
         dtype: torch.dtype,
-        backend: "AttentionBackendEnum | None" = None,
-    ) -> "AttentionBackendEnum":
+        backend: AttentionBackendEnum | None = None,
+    ) -> AttentionBackendEnum:
         if backend is not None:
             assert backend in cls.get_supported_vit_attn_backends(), (
                 f"Backend {backend} is not supported for vit attention. "