[ROCm][CI] Fix ROCm attention backend validation for head sizes, block sizes, and compute capability checks (#36292)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>

Author: Andreas Karatzas
Date: 2026-03-09 12:02:41 -05:00 (committed by GitHub)
Commit: c174d54f86 (parent: 55d27cca55)
7 changed files with 55 additions and 9 deletions


@@ -31,6 +31,10 @@ class AiterMLABackend(MLACommonBackend):
        "fp8_e5m2",
    ]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return []

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        return [1]
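
Illustrative sketch (not part of this commit): one plausible reading of the list[int | MultipleOf] return type used above is that a plain integer entry matches exactly that block size, while MultipleOf(n) matches any multiple of n. The MultipleOf below is a local stand-in dataclass, not the vLLM class, and the reading itself is an assumption rather than documented vLLM behaviour.

from dataclasses import dataclass

@dataclass(frozen=True)
class MultipleOf:
    base: int

def matches(block_size: int, supported: list) -> bool:
    # Plain integers are treated as exact matches, MultipleOf as a
    # divisibility constraint (assumed semantics, see note above).
    for entry in supported:
        if isinstance(entry, MultipleOf):
            if block_size % entry.base == 0:
                return True
        elif entry == block_size:
            return True
    return False

assert matches(1, [1])                # AiterMLA above would accept only block size 1
assert not matches(16, [1])
assert matches(48, [MultipleOf(16)])  # divisibility-style declaration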


@@ -19,6 +19,7 @@ from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backend import (
    AttentionLayer,
    AttentionType,
    MultipleOf,
    is_quantized_kv_cache,
)
from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd

@@ -33,6 +34,20 @@ class TritonMLABackend(MLACommonBackend):
        "bfloat16",
    ]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return []

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        return [MultipleOf(16)]

    @classmethod
    def supports_block_size(cls, block_size: int | None) -> bool:
        if block_size is None:
            return True
        return block_size % 16 == 0

    @staticmethod
    def get_name() -> str:
        return "TRITON_MLA"


@@ -29,6 +29,12 @@ class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        return [MultipleOf(16)]

    @classmethod
    def supports_block_size(cls, block_size: int | None) -> bool:
        if block_size is None:
            return True
        return block_size % 16 == 0

    @classmethod
    def supports_head_size(cls, head_size: int) -> bool:
        return head_size >= 32
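
Illustrative sketch (not part of this commit): a hypothetical pre-flight check that combines the two predicates added above. _AiterUnifiedStub copies the rules from the hunk (head_size >= 32, block_size a multiple of 16) so the snippet runs on its own; it is not the vLLM class, and backend_accepts is an invented helper name.

class _AiterUnifiedStub:
    @classmethod
    def supports_head_size(cls, head_size: int) -> bool:
        return head_size >= 32

    @classmethod
    def supports_block_size(cls, block_size: int | None) -> bool:
        return block_size is None or block_size % 16 == 0

def backend_accepts(backend, head_size: int, block_size: int | None) -> bool:
    # Reject the backend unless both the head size and the block size pass.
    return backend.supports_head_size(head_size) and backend.supports_block_size(block_size)

assert backend_accepts(_AiterUnifiedStub, 128, 16)
assert not backend_accepts(_AiterUnifiedStub, 16, 16)   # head size below 32
assert not backend_accepts(_AiterUnifiedStub, 128, 24)  # not a multiple of 16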


@@ -188,6 +188,12 @@ class RocmAttentionBackend(AttentionBackend):
        # uses our optimized kernel logic.
        return [16, 32, 544]

    @classmethod
    def supports_block_size(cls, block_size: int | None) -> bool:
        if block_size is None:
            return True
        return block_size in (16, 32, 544)

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [32, 64, 80, 96, 128, 160, 192, 224, 256]
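
Illustrative sketch (not part of this commit): unlike the multiple-of-16 backends, RocmAttentionBackend enumerates its block sizes, so a deployment could filter its candidate KV-cache block sizes through the same membership rule. The predicate below copies the rule from the hunk; usable_block_sizes is an invented helper.

def rocm_supports_block_size(block_size: int | None) -> bool:
    if block_size is None:
        return True
    return block_size in (16, 32, 544)

def usable_block_sizes(candidates: list[int]) -> list[int]:
    # Keep only the candidate block sizes this backend will accept.
    return [b for b in candidates if rocm_supports_block_size(b)]

print(usable_block_sizes([8, 16, 32, 64, 544]))  # -> [16, 32, 544]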


@@ -273,6 +273,12 @@ class TritonAttentionBackend(AttentionBackend):
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        return [MultipleOf(16)]

    @classmethod
    def supports_block_size(cls, block_size: int | None) -> bool:
        if block_size is None:
            return True
        return block_size % 16 == 0

    forward_includes_kv_cache_update: bool = False

    @staticmethod