Reenable features for ROCm attention backends (#36185)

Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
This commit is contained in:
Rohan Potdar
2026-03-05 22:21:06 -06:00
committed by GitHub
parent 0a49676fb0
commit c5362c739f
8 changed files with 66 additions and 33 deletions

View File

@@ -252,7 +252,7 @@ class AttentionBackend(ABC):
else:
invalid_reasons.append("non-MLA not supported")
if has_sink and not cls.supports_sink():
invalid_reasons.append("sink setting not supported")
invalid_reasons.append("attention sinks not supported")
if use_sparse != cls.is_sparse():
if use_sparse:
invalid_reasons.append("sparse not supported")

View File

@@ -8,6 +8,7 @@ import torch
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.model_executor.layers.attention.mla_attention import (
MLACommonBackend,
MLACommonDecodeMetadata,
@@ -21,6 +22,15 @@ from vllm.v1.kv_cache_interface import AttentionSpec
class AiterMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"bfloat16",
"fp8",
"fp8_e4m3",
"fp8_e5m2",
]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [1]

View File

@@ -9,6 +9,7 @@ import torch
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import (
get_mla_dims,
@@ -21,6 +22,7 @@ from vllm.v1.attention.backend import (
AttentionMetadata,
AttentionMetadataBuilder,
CommonAttentionMetadata,
MultipleOf,
SparseMLAAttentionImpl,
)
from vllm.v1.attention.backends.mla.flashmla_sparse import (
@@ -77,7 +79,15 @@ def fetch_id_to_ragged_triton(
class ROCMAiterMLASparseBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"bfloat16",
]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [1]
@staticmethod
def get_name() -> str:
@@ -105,10 +115,6 @@ class ROCMAiterMLASparseBackend(AttentionBackend):
) -> tuple[int, ...]:
return (num_blocks, block_size, head_size)
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [576]
@classmethod
def is_mla(cls) -> bool:
return True
@@ -117,11 +123,6 @@ class ROCMAiterMLASparseBackend(AttentionBackend):
def is_sparse(cls) -> bool:
return True
@classmethod
def supports_block_size(cls, block_size: int | None) -> bool:
# The only supported block_size is 1
return block_size is None or block_size == 1
@dataclass
class ROCMAiterMLASparseMetadata(AttentionMetadata):

View File

@@ -45,11 +45,6 @@ class TritonMLABackend(MLACommonBackend):
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return True
@classmethod
def supports_block_size(cls, block_size: int | None) -> bool:
# The only unsupported block_size is 1
return block_size is None or block_size != 1
class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
can_return_lse_for_decode: bool = True

View File

@@ -9,6 +9,7 @@ import torch
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.platforms import current_platform
@@ -732,6 +733,13 @@ class AiterFlashAttentionMetadataBuilder(
class AiterFlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"bfloat16",
"fp8",
"fp8_e4m3",
"fp8_e5m2",
]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:

View File

@@ -11,7 +11,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8StaticTensorSym,
)
from vllm.v1.attention.backend import AttentionLayer, AttentionType
from vllm.v1.attention.backend import AttentionLayer, AttentionType, MultipleOf
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.rocm_attn import (
RocmAttentionBackend,
@@ -25,6 +25,22 @@ logger = init_logger(__name__)
class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [MultipleOf(16)]
@classmethod
def supports_head_size(cls, head_size: int) -> bool:
return head_size >= 32
@classmethod
def supports_mm_prefix(cls) -> bool:
return True
@classmethod
def supports_sink(cls) -> bool:
return True
forward_includes_kv_cache_update: bool = False
@staticmethod

View File

@@ -9,6 +9,7 @@ import torch
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
@@ -163,6 +164,13 @@ class RocmAttentionBackend(AttentionBackend):
torch.bfloat16,
torch.float32,
]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"bfloat16",
"fp8",
"fp8_e4m3",
"fp8_e5m2",
]
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
@@ -185,15 +193,12 @@ class RocmAttentionBackend(AttentionBackend):
return [32, 64, 80, 96, 128, 160, 192, 224, 256]
@classmethod
def validate_head_size(cls, head_size: int) -> None:
if not cls.supports_head_size(head_size):
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {cls.get_supported_head_sizes()}. "
"Set --attention-backend=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes."
)
def supports_mm_prefix(cls) -> bool:
return True
@classmethod
def supports_sink(cls) -> bool:
return True
forward_includes_kv_cache_update: bool = False
@@ -275,8 +280,6 @@ class RocmAttentionImpl(AttentionImpl):
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
RocmAttentionBackend.validate_head_size(head_size)
self.fp8_dtype = current_platform.fp8_dtype()
self.sinks = sinks