diff --git a/docs/design/attention_backends.md b/docs/design/attention_backends.md
index 7b643a46b..f407f1ec7 100644
--- a/docs/design/attention_backends.md
+++ b/docs/design/attention_backends.md
@@ -171,9 +171,9 @@ Priority is **1 = highest** (tried first).
 | `FLASH_ATTN` | FA4* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥10.0 |
 | `FLASH_ATTN_DIFFKV` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ✅ | Decoder | Any |
 | `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any |
-| `ROCM_AITER_FA` | | fp16, bf16 | `auto` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder | N/A |
-| `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | All | N/A |
-| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto` | 16, 32, 544 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | All | N/A |
+| `ROCM_AITER_FA` | | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder | N/A |
+| `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | %16 | Any | ✅ | ✅ | ❌ | All | N/A |
+| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 544 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | ✅ | ✅ | ❌ | All | N/A |
 | `TREE_ATTN` | | fp16, bf16 | `auto` | %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | Any |
 | `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ✅ | ❌ | All | Any |
 
@@ -210,7 +210,7 @@ configuration.
 | `FLASHMLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 64 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x-10.x |
 | `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x |
 | `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x |
-| `ROCM_AITER_MLA` | fp16, bf16 | `auto` | 1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
-| `ROCM_AITER_MLA_SPARSE` | bf16 | `auto` | Any | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | N/A |
+| `ROCM_AITER_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
+| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | 1 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A |
 | `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
 | `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16` | Any | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py
index 585ad1d79..3af817a2e 100644
--- a/vllm/v1/attention/backend.py
+++ b/vllm/v1/attention/backend.py
@@ -252,7 +252,7 @@ class AttentionBackend(ABC):
             else:
                 invalid_reasons.append("non-MLA not supported")
         if has_sink and not cls.supports_sink():
-            invalid_reasons.append("sink setting not supported")
+            invalid_reasons.append("attention sinks not supported")
         if use_sparse != cls.is_sparse():
             if use_sparse:
                 invalid_reasons.append("sparse not supported")
diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
index 57a1d32d2..dde1fb3eb 100644
--- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -8,6 +8,7 @@ import torch
 
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.config import VllmConfig
+from vllm.config.cache import CacheDType
 from vllm.model_executor.layers.attention.mla_attention import (
     MLACommonBackend,
     MLACommonDecodeMetadata,
@@ -21,6 +22,15 @@ from vllm.v1.kv_cache_interface import AttentionSpec
 
 
 class AiterMLABackend(MLACommonBackend):
+    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
+    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
+        "auto",
+        "bfloat16",
+        "fp8",
+        "fp8_e4m3",
+        "fp8_e5m2",
+    ]
+
     @staticmethod
     def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
         return [1]
diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
index 47f1c06ea..b1d503ca4 100644
--- a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
@@ -9,6 +9,7 @@ import torch
 
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.config import VllmConfig
+from vllm.config.cache import CacheDType
 from vllm.logger import init_logger
 from vllm.model_executor.layers.attention.mla_attention import (
     get_mla_dims,
@@ -21,6 +22,7 @@ from vllm.v1.attention.backend import (
     AttentionMetadata,
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
+    MultipleOf,
     SparseMLAAttentionImpl,
 )
 from vllm.v1.attention.backends.mla.flashmla_sparse import (
@@ -77,7 +79,15 @@ def fetch_id_to_ragged_triton(
 
 class ROCMAiterMLASparseBackend(AttentionBackend):
     accept_output_buffer: bool = True
-    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]
+    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
+    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
+        "auto",
+        "bfloat16",
+    ]
+
+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [1]
 
     @staticmethod
     def get_name() -> str:
@@ -105,10 +115,6 @@ class ROCMAiterMLASparseBackend(AttentionBackend):
     ) -> tuple[int, ...]:
         return (num_blocks, block_size, head_size)
 
-    @classmethod
-    def get_supported_head_sizes(cls) -> list[int]:
-        return [576]
-
     @classmethod
     def is_mla(cls) -> bool:
         return True
@@ -117,11 +123,6 @@ class ROCMAiterMLASparseBackend(AttentionBackend):
     def is_sparse(cls) -> bool:
         return True
 
-    @classmethod
-    def supports_block_size(cls, block_size: int | None) -> bool:
-        # The only supported block_size is 1
-        return block_size is None or block_size == 1
-
 
 @dataclass
 class ROCMAiterMLASparseMetadata(AttentionMetadata):
diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py
index a950288b6..f6c1790f6 100644
--- a/vllm/v1/attention/backends/mla/triton_mla.py
+++ b/vllm/v1/attention/backends/mla/triton_mla.py
@@ -45,11 +45,6 @@ class TritonMLABackend(MLACommonBackend):
     def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
         return True
 
-    @classmethod
-    def supports_block_size(cls, block_size: int | None) -> bool:
-        # The only unsupported block_size is 1
-        return block_size is None or block_size != 1
-
 
 class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
     can_return_lse_for_decode: bool = True
diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py
index c0269ec68..da385896f 100644
--- a/vllm/v1/attention/backends/rocm_aiter_fa.py
+++ b/vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -9,6 +9,7 @@ import torch
 
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.config import VllmConfig, get_layers_from_vllm_config
+from vllm.config.cache import CacheDType
 from vllm.logger import init_logger
 from vllm.model_executor.layers.attention import Attention
 from vllm.platforms import current_platform
@@ -732,6 +733,13 @@ class AiterFlashAttentionMetadataBuilder(
 class AiterFlashAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
+    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
+        "auto",
+        "bfloat16",
+        "fp8",
+        "fp8_e4m3",
+        "fp8_e5m2",
+    ]
 
     @staticmethod
     def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
diff --git a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py
index 130ccaa2d..dbfb924a8 100644
--- a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py
+++ b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py
@@ -11,7 +11,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey,
     kFp8StaticTensorSym,
 )
-from vllm.v1.attention.backend import AttentionLayer, AttentionType
+from vllm.v1.attention.backend import AttentionLayer, AttentionType, MultipleOf
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.rocm_attn import (
     RocmAttentionBackend,
@@ -25,6 +25,22 @@ logger = init_logger(__name__)
 class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
     accept_output_buffer: bool = True
 
+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [MultipleOf(16)]
+
+    @classmethod
+    def supports_head_size(cls, head_size: int) -> bool:
+        return head_size >= 32
+
+    @classmethod
+    def supports_mm_prefix(cls) -> bool:
+        return True
+
+    @classmethod
+    def supports_sink(cls) -> bool:
+        return True
+
     forward_includes_kv_cache_update: bool = False
 
     @staticmethod
diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py
index b53170c98..e8d34822e 100644
--- a/vllm/v1/attention/backends/rocm_attn.py
+++ b/vllm/v1/attention/backends/rocm_attn.py
@@ -9,6 +9,7 @@ import torch
 
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.config import VllmConfig
+from vllm.config.cache import CacheDType
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey,
@@ -163,6 +164,13 @@ class RocmAttentionBackend(AttentionBackend):
         torch.bfloat16,
         torch.float32,
     ]
+    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
+        "auto",
+        "bfloat16",
+        "fp8",
+        "fp8_e4m3",
+        "fp8_e5m2",
+    ]
 
     @staticmethod
     def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
@@ -185,15 +193,12 @@ class RocmAttentionBackend(AttentionBackend):
         return [32, 64, 80, 96, 128, 160, 192, 224, 256]
 
     @classmethod
-    def validate_head_size(cls, head_size: int) -> None:
-        if not cls.supports_head_size(head_size):
-            attn_type = cls.__name__.removesuffix("Backend")
-            raise ValueError(
-                f"Head size {head_size} is not supported by {attn_type}. "
-                f"Supported head sizes are: {cls.get_supported_head_sizes()}. "
-                "Set --attention-backend=FLEX_ATTENTION to use "
-                "FlexAttention backend which supports all head sizes."
-            )
+    def supports_mm_prefix(cls) -> bool:
+        return True
+
+    @classmethod
+    def supports_sink(cls) -> bool:
+        return True
 
     forward_includes_kv_cache_update: bool = False
 
@@ -275,8 +280,6 @@ class RocmAttentionImpl(AttentionImpl):
 
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
-        RocmAttentionBackend.validate_head_size(head_size)
-
         self.fp8_dtype = current_platform.fp8_dtype()
         self.sinks = sinks
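
Note (illustrative sketch, not part of the patch): after this change, backend capabilities are declared through class-level attributes and small classmethod overrides that the shared validation in `vllm/v1/attention/backend.py` inspects, rather than per-backend `validate_head_size` / `supports_block_size` overrides. The snippet below shows how a hypothetical backend subclass (`ExampleRocmBackend` is an invented name) would advertise its supported dtypes, KV-cache dtypes, block sizes, and sink support, assuming the imports already used in the diff; the abstract methods of `AttentionBackend` (e.g. `get_name`, `get_impl_cls`) are omitted for brevity.

```python
from typing import ClassVar

import torch

from vllm.config.cache import CacheDType
from vllm.v1.attention.backend import AttentionBackend, MultipleOf


class ExampleRocmBackend(AttentionBackend):
    # Model dtypes the attention kernel accepts.
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
    # KV-cache dtypes the kernel accepts; checked against the configured kv_cache_dtype.
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "bfloat16"]

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        # Accept any KV-cache block size that is a multiple of 16.
        return [MultipleOf(16)]

    @classmethod
    def supports_sink(cls) -> bool:
        # Declare that this kernel handles attention sinks.
        return True
```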