[ROCm] Refactor ROCm attention backend selection logic (#35246)
Signed-off-by: Sage Moore <sage@neuralmagic.com>
@@ -77,6 +77,7 @@ def fetch_id_to_ragged_triton(


class ROCMAiterMLASparseBackend(AttentionBackend):
    accept_output_buffer: bool = True
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]

    @staticmethod
    def get_name() -> str:
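The class-level attributes above follow the pattern of declaring backend capabilities statically, so selection logic can inspect them before any backend object is constructed. A minimal sketch of the idea (_ExampleBackend is hypothetical, not vLLM's actual code):

    # Hedged sketch: capability attributes declared at class level can be
    # checked on the class itself, with no instance needed.
    from typing import ClassVar

    import torch


    class _ExampleBackend:
        accept_output_buffer: bool = True
        supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]


    print(torch.bfloat16 in _ExampleBackend.supported_dtypes)  # True
    print(_ExampleBackend.accept_output_buffer)                # True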
@@ -104,14 +105,23 @@ class ROCMAiterMLASparseBackend(AttentionBackend):
    ) -> tuple[int, ...]:
        return (num_blocks, block_size, head_size)

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [576]

    @classmethod
    def is_mla(cls) -> bool:
        return True

    @classmethod
    def is_sparse(cls) -> bool:
        return True

    @classmethod
    def supports_block_size(cls, block_size: int | None) -> bool:
        # The only supported block_size is 1.
        return block_size is None or block_size == 1


@dataclass
class ROCMAiterMLASparseMetadata(AttentionMetadata):
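Taken together, these classmethods let selection logic filter candidate backends without instantiating them. A hedged sketch of what such a filter could look like (the eligible() helper and its argument names are assumptions for illustration, not vLLM's actual selector):

    # Hypothetical filter over backend classes using the classmethod
    # predicates shown above; not vLLM's actual selection code.
    import torch


    def eligible(backend_cls, *, dtype: torch.dtype, head_size: int,
                 block_size: int | None, want_mla: bool,
                 want_sparse: bool) -> bool:
        return (
            dtype in backend_cls.get_supported_dtypes()
            and head_size in backend_cls.get_supported_head_sizes()
            and backend_cls.supports_block_size(block_size)
            and backend_cls.is_mla() == want_mla
            and backend_cls.is_sparse() == want_sparse
        )


    # For ROCMAiterMLASparseBackend this passes only with bfloat16,
    # head_size 576, and block_size None or 1:
    # eligible(ROCMAiterMLASparseBackend, dtype=torch.bfloat16,
    #          head_size=576, block_size=1, want_mla=True,
    #          want_sparse=True)  # -> True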
@@ -45,6 +45,11 @@ class TritonMLABackend(MLACommonBackend):
    def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
        return True

    @classmethod
    def supports_block_size(cls, block_size: int | None) -> bool:
        # The only unsupported block_size is 1.
        return block_size is None or block_size != 1


class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
    can_return_lse_for_decode: bool = True
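The two supports_block_size overrides are deliberately complementary: the AITER MLA sparse backend accepts only block_size 1, while TritonMLA accepts anything but 1, so any concrete block size routes to exactly one of them. A self-contained restatement for quick comparison (the logic is copied from the hunks above rather than imported, since the file paths are elided in this diff):

    # Side-by-side behavior of the two predicates (illustration only).
    def sparse_supports(block_size: int | None) -> bool:
        return block_size is None or block_size == 1

    def triton_mla_supports(block_size: int | None) -> bool:
        return block_size is None or block_size != 1

    for bs in (None, 1, 16):
        print(bs, sparse_supports(bs), triton_mla_supports(bs))
    # None True True   <- None means "no constraint yet": both stay candidates
    # 1    True False
    # 16   False True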
@@ -12,6 +12,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import num_compute_units
from vllm.v1.attention.backend import (
@@ -766,6 +767,15 @@ class AiterFlashAttentionBackend(AttentionBackend):
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @classmethod
    def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
        from vllm.platforms.rocm import on_mi3xx

        # DeviceCapability is currently created using
        # torch.cuda.get_device_capability(), which is known to be buggy on
        # ROCm systems. on_mi3xx uses amd-smi, which is more reliable.
        return on_mi3xx()


class AiterFlashAttentionImpl(AttentionImpl):
    def __init__(
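The override above sidesteps the capability tuple entirely on ROCm. As a hedged illustration of the same idea (torch.version.hip and torch.cuda.get_device_capability are real PyTorch APIs; the surrounding structure is an assumption, not vLLM's actual code):

    # Hedged illustration: why an amd-smi-backed probe beats the capability
    # tuple on ROCm. Not vLLM's actual selection code.
    import torch

    if torch.version.hip is not None:
        # ROCm build of PyTorch: the (major, minor) tuple below is derived
        # from the GCN arch and is unreliable for feature gating, which is
        # why the backend defers to on_mi3xx() / amd-smi instead.
        unreliable_cap = torch.cuda.get_device_capability()
        from vllm.platforms.rocm import on_mi3xx  # shown in the diff above
        usable = on_mi3xx()
    else:
        # CUDA build: this AITER backend does not apply.
        usable = False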