[ROCm][CI] Fix ROCm attention backend validation for head sizes, block sizes, and compute capability checks (#36292)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
Andreas Karatzas
2026-03-09 12:02:41 -05:00
committed by GitHub
parent 55d27cca55
commit c174d54f86
7 changed files with 55 additions and 9 deletions

View File

@@ -213,4 +213,4 @@ configuration.
| `ROCM_AITER_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | 1 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | 1 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
| `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16` | Any | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any | | `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any |

View File

@@ -29,11 +29,18 @@ def mock_vllm_config():
@pytest.fixture @pytest.fixture
def mock_on_gfx9(): def mock_on_gfx9():
"""Mock the on_gfx9 function to return True.""" """Mock gfx9 arch detection to return True."""
with patch("vllm.platforms.rocm.on_gfx9", return_value=True): with patch("vllm.platforms.rocm.on_gfx9", return_value=True):
yield yield
@pytest.fixture
def mock_on_mi3xx():
"""Mock mi3xx arch detection to return True."""
with patch("vllm.platforms.rocm.on_mi3xx", return_value=True):
yield
@pytest.mark.parametrize( @pytest.mark.parametrize(
"env_vars, selected_backend, expected_backend_path", "env_vars, selected_backend, expected_backend_path",
[ [
@@ -122,6 +129,7 @@ def test_standard_attention_backend_selection(
expected_backend_path, expected_backend_path,
mock_vllm_config, mock_vllm_config,
mock_on_gfx9, mock_on_gfx9,
mock_on_mi3xx,
monkeypatch, monkeypatch,
): ):
"""Test standard attention backend selection with various configurations.""" """Test standard attention backend selection with various configurations."""
@@ -313,16 +321,16 @@ def test_mla_backend_selection(
assert backend_path == expected_backend_path assert backend_path == expected_backend_path
def test_aiter_fa_requires_gfx9(mock_vllm_config): def test_aiter_fa_requires_mi3xx(mock_vllm_config):
"""Test that ROCM_AITER_FA requires gfx9 architecture.""" """Test that ROCM_AITER_FA requires mi3xx architecture."""
from vllm.platforms.rocm import RocmPlatform from vllm.platforms.rocm import RocmPlatform
# Mock on_gfx9 to return False # Mock on_mi3xx to return False (used by supports_compute_capability)
with ( with (
patch("vllm.platforms.rocm.on_gfx9", return_value=False), patch("vllm.platforms.rocm.on_mi3xx", return_value=False),
pytest.raises( pytest.raises(
ValueError, ValueError,
match="only supported on gfx9", match="compute capability not supported",
), ),
): ):
attn_selector_config = AttentionSelectorConfig( attn_selector_config = AttentionSelectorConfig(
@@ -342,11 +350,12 @@ def test_aiter_fa_requires_gfx9(mock_vllm_config):
def test_sparse_not_supported(mock_vllm_config): def test_sparse_not_supported(mock_vllm_config):
"""Test that sparse attention is not supported on ROCm.""" """Test that sparse MLA without use_mla flag raises an error."""
from vllm.platforms.rocm import RocmPlatform from vllm.platforms.rocm import RocmPlatform
with pytest.raises( with pytest.raises(
AssertionError, match="Sparse MLA backend on ROCm only supports block size 1" ValueError,
match="No valid attention backend found",
): ):
attn_selector_config = AttentionSelectorConfig( attn_selector_config = AttentionSelectorConfig(
head_size=128, head_size=128,

View File

@@ -31,6 +31,10 @@ class AiterMLABackend(MLACommonBackend):
"fp8_e5m2", "fp8_e5m2",
] ]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return []
@staticmethod @staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [1] return [1]

View File

@@ -19,6 +19,7 @@ from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionLayer, AttentionLayer,
AttentionType, AttentionType,
MultipleOf,
is_quantized_kv_cache, is_quantized_kv_cache,
) )
from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd
@@ -33,6 +34,20 @@ class TritonMLABackend(MLACommonBackend):
"bfloat16", "bfloat16",
] ]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return []
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [MultipleOf(16)]
@classmethod
def supports_block_size(cls, block_size: int | None) -> bool:
if block_size is None:
return True
return block_size % 16 == 0
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "TRITON_MLA" return "TRITON_MLA"

View File

@@ -29,6 +29,12 @@ class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [MultipleOf(16)] return [MultipleOf(16)]
@classmethod
def supports_block_size(cls, block_size: int | None) -> bool:
if block_size is None:
return True
return block_size % 16 == 0
@classmethod @classmethod
def supports_head_size(cls, head_size: int) -> bool: def supports_head_size(cls, head_size: int) -> bool:
return head_size >= 32 return head_size >= 32

View File

@@ -188,6 +188,12 @@ class RocmAttentionBackend(AttentionBackend):
# uses our optimized kernel logic. # uses our optimized kernel logic.
return [16, 32, 544] return [16, 32, 544]
@classmethod
def supports_block_size(cls, block_size: int | None) -> bool:
if block_size is None:
return True
return block_size in (16, 32, 544)
@classmethod @classmethod
def get_supported_head_sizes(cls) -> list[int]: def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 80, 96, 128, 160, 192, 224, 256] return [32, 64, 80, 96, 128, 160, 192, 224, 256]

View File

@@ -273,6 +273,12 @@ class TritonAttentionBackend(AttentionBackend):
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [MultipleOf(16)] return [MultipleOf(16)]
@classmethod
def supports_block_size(cls, block_size: int | None) -> bool:
if block_size is None:
return True
return block_size % 16 == 0
forward_includes_kv_cache_update: bool = False forward_includes_kv_cache_update: bool = False
@staticmethod @staticmethod