[ROCm][CI] Fix ROCm attention backend validation for head sizes, block sizes, and compute capability checks (#36292)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
@@ -31,6 +31,10 @@ class AiterMLABackend(MLACommonBackend):
        "fp8_e5m2",
    ]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return []

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        return [1]
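The empty head-size list and the single-entry block-size list above feed the backend validation checks this PR tightens. A minimal sketch of one plausible interpretation follows; the helper name and the "empty list means no restriction" convention are assumptions for illustration, not taken from this diff.

    # Hypothetical helper, not part of this PR: check a head size against
    # a backend's get_supported_head_sizes() result.
    def head_size_allowed(head_size: int, supported: list[int]) -> bool:
        # Assumption: an empty list is treated as "no backend-specific
        # head-size restriction"; a non-empty list is an allow-list.
        return not supported or head_size in supported

    # With AiterMLABackend's values above:
    # head_size_allowed(576, []) -> True (no restriction declared),
    # while its kernel block-size list admits only block_size == 1.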
@@ -19,6 +19,7 @@ from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backend import (
    AttentionLayer,
    AttentionType,
    MultipleOf,
    is_quantized_kv_cache,
)
from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd
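The newly imported MultipleOf marker is what the hunks below return from get_supported_kernel_block_sizes(). The sketch here is a stand-in showing how a mixed list[int | MultipleOf] could be checked; the real vLLM class and validation code may differ.

    from dataclasses import dataclass

    # Stand-in for the imported MultipleOf marker (sketch only).
    @dataclass(frozen=True)
    class MultipleOfSketch:
        base: int

    def block_size_allowed(block_size: int, supported: list) -> bool:
        # Accept an exact match, or any positive multiple of a
        # MultipleOf entry's base.
        for entry in supported:
            if isinstance(entry, MultipleOfSketch):
                if block_size % entry.base == 0:
                    return True
            elif block_size == entry:
                return True
        return False

    # block_size_allowed(32, [MultipleOfSketch(16)]) -> True
    # block_size_allowed(24, [MultipleOfSketch(16)]) -> False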
@@ -33,6 +34,20 @@ class TritonMLABackend(MLACommonBackend):
        "bfloat16",
    ]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return []

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        return [MultipleOf(16)]

    @classmethod
    def supports_block_size(cls, block_size: int | None) -> bool:
        if block_size is None:
            return True
        return block_size % 16 == 0

    @staticmethod
    def get_name() -> str:
        return "TRITON_MLA"
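The block-size rule added to TritonMLABackend above can be exercised standalone; the mirror below avoids importing vLLM and only restates the logic from the hunk.

    # Standalone mirror of TritonMLABackend.supports_block_size above.
    def triton_mla_supports_block_size(block_size: int | None) -> bool:
        if block_size is None:
            return True       # nothing configured yet, defer the check
        return block_size % 16 == 0

    assert triton_mla_supports_block_size(None)
    assert triton_mla_supports_block_size(48)
    assert not triton_mla_supports_block_size(24)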
@@ -29,6 +29,12 @@ class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        return [MultipleOf(16)]

    @classmethod
    def supports_block_size(cls, block_size: int | None) -> bool:
        if block_size is None:
            return True
        return block_size % 16 == 0

    @classmethod
    def supports_head_size(cls, head_size: int) -> bool:
        return head_size >= 32
@@ -188,6 +188,12 @@ class RocmAttentionBackend(AttentionBackend):
        # uses our optimized kernel logic.
        return [16, 32, 544]

    @classmethod
    def supports_block_size(cls, block_size: int | None) -> bool:
        if block_size is None:
            return True
        return block_size in (16, 32, 544)

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [32, 64, 80, 96, 128, 160, 192, 224, 256]
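Note the contrast with the Triton-based backends: RocmAttentionBackend enumerates exact block sizes rather than accepting any multiple of 16. The standalone sketch below restates both rules side by side for illustration only.

    # Enumerated allow-list, as in RocmAttentionBackend above.
    def rocm_supports(block_size: int | None) -> bool:
        return block_size is None or block_size in (16, 32, 544)

    # Multiple-of rule, as in the Triton backends in this PR.
    def triton_supports(block_size: int | None) -> bool:
        return block_size is None or block_size % 16 == 0

    assert rocm_supports(544) and not rocm_supports(48)
    assert triton_supports(48) and triton_supports(544)
    assert not triton_supports(24)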
@@ -273,6 +273,12 @@ class TritonAttentionBackend(AttentionBackend):
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        return [MultipleOf(16)]

    @classmethod
    def supports_block_size(cls, block_size: int | None) -> bool:
        if block_size is None:
            return True
        return block_size % 16 == 0

    forward_includes_kv_cache_update: bool = False

    @staticmethod
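Given the CI focus of this change, a small parametrized test like the sketch below could exercise the new supports_block_size hooks. The test name is hypothetical and the backend import is left abstract because the module path is not shown in this diff.

    import pytest

    # from <backend module> import TritonAttentionBackend  # path not shown here

    @pytest.mark.parametrize("block_size,expected", [
        (None, True),
        (16, True),
        (32, True),
        (24, False),
    ])
    def test_triton_block_size_validation(block_size, expected):
        # Stand-in for TritonAttentionBackend.supports_block_size.
        supports = block_size is None or block_size % 16 == 0
        assert supports == expected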