[ROCm][CI] Fix ROCm attention backend validation for head sizes, block sizes, and compute capability checks (#36292)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -213,4 +213,4 @@ configuration.
|
|||||||
| `ROCM_AITER_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
|
| `ROCM_AITER_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
|
||||||
| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | 1 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A |
|
| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | 1 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A |
|
||||||
| `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
|
| `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
|
||||||
| `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16` | Any | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
|
| `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
|
||||||
|
|||||||
@@ -29,11 +29,18 @@ def mock_vllm_config():
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_on_gfx9():
|
def mock_on_gfx9():
|
||||||
"""Mock the on_gfx9 function to return True."""
|
"""Mock gfx9 arch detection to return True."""
|
||||||
with patch("vllm.platforms.rocm.on_gfx9", return_value=True):
|
with patch("vllm.platforms.rocm.on_gfx9", return_value=True):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_on_mi3xx():
|
||||||
|
"""Mock mi3xx arch detection to return True."""
|
||||||
|
with patch("vllm.platforms.rocm.on_mi3xx", return_value=True):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"env_vars, selected_backend, expected_backend_path",
|
"env_vars, selected_backend, expected_backend_path",
|
||||||
[
|
[
|
||||||
@@ -122,6 +129,7 @@ def test_standard_attention_backend_selection(
|
|||||||
expected_backend_path,
|
expected_backend_path,
|
||||||
mock_vllm_config,
|
mock_vllm_config,
|
||||||
mock_on_gfx9,
|
mock_on_gfx9,
|
||||||
|
mock_on_mi3xx,
|
||||||
monkeypatch,
|
monkeypatch,
|
||||||
):
|
):
|
||||||
"""Test standard attention backend selection with various configurations."""
|
"""Test standard attention backend selection with various configurations."""
|
||||||
@@ -313,16 +321,16 @@ def test_mla_backend_selection(
|
|||||||
assert backend_path == expected_backend_path
|
assert backend_path == expected_backend_path
|
||||||
|
|
||||||
|
|
||||||
def test_aiter_fa_requires_gfx9(mock_vllm_config):
|
def test_aiter_fa_requires_mi3xx(mock_vllm_config):
|
||||||
"""Test that ROCM_AITER_FA requires gfx9 architecture."""
|
"""Test that ROCM_AITER_FA requires mi3xx architecture."""
|
||||||
from vllm.platforms.rocm import RocmPlatform
|
from vllm.platforms.rocm import RocmPlatform
|
||||||
|
|
||||||
# Mock on_gfx9 to return False
|
# Mock on_mi3xx to return False (used by supports_compute_capability)
|
||||||
with (
|
with (
|
||||||
patch("vllm.platforms.rocm.on_gfx9", return_value=False),
|
patch("vllm.platforms.rocm.on_mi3xx", return_value=False),
|
||||||
pytest.raises(
|
pytest.raises(
|
||||||
ValueError,
|
ValueError,
|
||||||
match="only supported on gfx9",
|
match="compute capability not supported",
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
attn_selector_config = AttentionSelectorConfig(
|
attn_selector_config = AttentionSelectorConfig(
|
||||||
@@ -342,11 +350,12 @@ def test_aiter_fa_requires_gfx9(mock_vllm_config):
|
|||||||
|
|
||||||
|
|
||||||
def test_sparse_not_supported(mock_vllm_config):
|
def test_sparse_not_supported(mock_vllm_config):
|
||||||
"""Test that sparse attention is not supported on ROCm."""
|
"""Test that sparse MLA without use_mla flag raises an error."""
|
||||||
from vllm.platforms.rocm import RocmPlatform
|
from vllm.platforms.rocm import RocmPlatform
|
||||||
|
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
AssertionError, match="Sparse MLA backend on ROCm only supports block size 1"
|
ValueError,
|
||||||
|
match="No valid attention backend found",
|
||||||
):
|
):
|
||||||
attn_selector_config = AttentionSelectorConfig(
|
attn_selector_config = AttentionSelectorConfig(
|
||||||
head_size=128,
|
head_size=128,
|
||||||
|
|||||||
@@ -31,6 +31,10 @@ class AiterMLABackend(MLACommonBackend):
|
|||||||
"fp8_e5m2",
|
"fp8_e5m2",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_supported_head_sizes(cls) -> list[int]:
|
||||||
|
return []
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
||||||
return [1]
|
return [1]
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from vllm.platforms.interface import DeviceCapability
|
|||||||
from vllm.v1.attention.backend import (
|
from vllm.v1.attention.backend import (
|
||||||
AttentionLayer,
|
AttentionLayer,
|
||||||
AttentionType,
|
AttentionType,
|
||||||
|
MultipleOf,
|
||||||
is_quantized_kv_cache,
|
is_quantized_kv_cache,
|
||||||
)
|
)
|
||||||
from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd
|
from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd
|
||||||
@@ -33,6 +34,20 @@ class TritonMLABackend(MLACommonBackend):
|
|||||||
"bfloat16",
|
"bfloat16",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_supported_head_sizes(cls) -> list[int]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
||||||
|
return [MultipleOf(16)]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def supports_block_size(cls, block_size: int | None) -> bool:
|
||||||
|
if block_size is None:
|
||||||
|
return True
|
||||||
|
return block_size % 16 == 0
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_name() -> str:
|
def get_name() -> str:
|
||||||
return "TRITON_MLA"
|
return "TRITON_MLA"
|
||||||
|
|||||||
@@ -29,6 +29,12 @@ class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
|
|||||||
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
||||||
return [MultipleOf(16)]
|
return [MultipleOf(16)]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def supports_block_size(cls, block_size: int | None) -> bool:
|
||||||
|
if block_size is None:
|
||||||
|
return True
|
||||||
|
return block_size % 16 == 0
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def supports_head_size(cls, head_size: int) -> bool:
|
def supports_head_size(cls, head_size: int) -> bool:
|
||||||
return head_size >= 32
|
return head_size >= 32
|
||||||
|
|||||||
@@ -188,6 +188,12 @@ class RocmAttentionBackend(AttentionBackend):
|
|||||||
# uses our optimized kernel logic.
|
# uses our optimized kernel logic.
|
||||||
return [16, 32, 544]
|
return [16, 32, 544]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def supports_block_size(cls, block_size: int | None) -> bool:
|
||||||
|
if block_size is None:
|
||||||
|
return True
|
||||||
|
return block_size in (16, 32, 544)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_supported_head_sizes(cls) -> list[int]:
|
def get_supported_head_sizes(cls) -> list[int]:
|
||||||
return [32, 64, 80, 96, 128, 160, 192, 224, 256]
|
return [32, 64, 80, 96, 128, 160, 192, 224, 256]
|
||||||
|
|||||||
@@ -273,6 +273,12 @@ class TritonAttentionBackend(AttentionBackend):
|
|||||||
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
||||||
return [MultipleOf(16)]
|
return [MultipleOf(16)]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def supports_block_size(cls, block_size: int | None) -> bool:
|
||||||
|
if block_size is None:
|
||||||
|
return True
|
||||||
|
return block_size % 16 == 0
|
||||||
|
|
||||||
forward_includes_kv_cache_update: bool = False
|
forward_includes_kv_cache_update: bool = False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
Reference in New Issue
Block a user