[ROCm] Refactor ROCm attention backend selection logic (#35246)

Signed-off-by: Sage Moore <sage@neuralmagic.com>
Author: Sage Moore
Date: 2026-03-05 08:51:26 -08:00 (committed by GitHub)
parent 3ee68590c7
commit 8c760b6ab6
6 changed files with 171 additions and 115 deletions

View File

@@ -77,6 +77,7 @@ def fetch_id_to_ragged_triton(
class ROCMAiterMLASparseBackend(AttentionBackend):
    accept_output_buffer: bool = True
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]

    @staticmethod
    def get_name() -> str:
@@ -104,14 +105,23 @@ class ROCMAiterMLASparseBackend(AttentionBackend):
    ) -> tuple[int, ...]:
        return (num_blocks, block_size, head_size)

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [576]

    @classmethod
    def is_mla(cls) -> bool:
        return True

    @classmethod
    def is_sparse(cls) -> bool:
        return True

    @classmethod
    def supports_block_size(cls, block_size: int | None) -> bool:
        # The only supported block_size is 1
        return block_size is None or block_size == 1


@dataclass
class ROCMAiterMLASparseMetadata(AttentionMetadata):
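These classmethods give each backend a declarative capability surface that the selection logic can query uniformly instead of branching on backend names. As a rough illustration only (not part of this diff; `_pick_backend` and `candidates` are hypothetical names, not vLLM APIs), a selector could filter candidate backend classes like this:

    def _pick_backend(candidates, dtype, head_size, block_size, use_mla, use_sparse):
        # Hypothetical sketch: walk the candidate backend classes and return the
        # first one whose declared capabilities match the requested configuration.
        for backend in candidates:
            if dtype not in backend.get_supported_dtypes():
                continue
            if head_size not in backend.get_supported_head_sizes():
                continue
            if backend.is_mla() != use_mla or backend.is_sparse() != use_sparse:
                continue
            if not backend.supports_block_size(block_size):
                continue
            return backend
        raise ValueError("No ROCm attention backend supports this configuration.")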

View File

@@ -45,6 +45,11 @@ class TritonMLABackend(MLACommonBackend):
    def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
        return True

    @classmethod
    def supports_block_size(cls, block_size: int | None) -> bool:
        # The only unsupported block_size is 1
        return block_size is None or block_size != 1


class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
    can_return_lse_for_decode: bool = True
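Taken together with the sparse AITER MLA backend above, the two supports_block_size predicates partition the block sizes: the sparse backend accepts only block_size 1, TritonMLABackend accepts any block_size except 1, and both treat None as "no constraint yet". A quick illustration (not part of this diff, and assuming both backend classes are in scope):

    ROCMAiterMLASparseBackend.supports_block_size(1)      # True
    TritonMLABackend.supports_block_size(1)               # False
    TritonMLABackend.supports_block_size(16)              # True
    ROCMAiterMLASparseBackend.supports_block_size(None)   # True
    TritonMLABackend.supports_block_size(None)            # True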

View File

@@ -12,6 +12,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import num_compute_units
from vllm.v1.attention.backend import (
@@ -766,6 +767,15 @@ class AiterFlashAttentionBackend(AttentionBackend):
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @classmethod
    def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
        from vllm.platforms.rocm import on_mi3xx

        # DeviceCapability is currently created from torch.cuda.get_device_capability(),
        # which is known to be buggy on ROCm systems. on_mi3xx() uses amd-smi, which is
        # more reliable.
        return on_mi3xx()


class AiterFlashAttentionImpl(AttentionImpl):
    def __init__(
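As a final illustration (not part of this diff; `_capability_ok` is a hypothetical helper, and the use of current_platform.get_device_capability() here is an assumption about how the selector obtains the capability), the selection logic can hand the platform-reported capability to each backend and let ROCm-specific backends such as AiterFlashAttentionBackend ignore it in favor of the amd-smi-based on_mi3xx() check:

    from vllm.platforms import current_platform

    def _capability_ok(backend_cls) -> bool:
        # Hypothetical sketch: each backend decides for itself whether the device
        # is supported. AiterFlashAttentionBackend ignores `capability` and defers
        # to on_mi3xx(), since torch.cuda.get_device_capability() is unreliable on
        # ROCm.
        capability = current_platform.get_device_capability()
        return backend_cls.supports_compute_capability(capability)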