diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index e7ec8380e..f4f40baba 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -557,9 +557,21 @@ def test_causal_backend_correctness( if is_torch_equal_or_newer("2.9.0.dev0") else [] ) - SMALL_BLOCK_BACKENDS = [ - x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS - ] + + if current_platform.is_rocm(): + SMALL_BLOCK_BACKENDS = [ + x + for x in BACKENDS_TO_TEST + if ( + x not in LARGE_BLOCK_BACKENDS + and x is not AttentionBackendEnum.FLASH_ATTN + ) + ] + else: + SMALL_BLOCK_BACKENDS = [ + x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS + ] + _test_backend_correctness( batch_spec, model, @@ -580,12 +592,20 @@ def test_causal_backend_correctness( ) -SLIDING_WINDOW_BACKENDS_TO_TEST = [ - AttentionBackendEnum.FLASH_ATTN, - AttentionBackendEnum.FLEX_ATTENTION, - AttentionBackendEnum.TRITON_ATTN, - "FLEX_ATTENTION_SLOW", -] +if current_platform.is_rocm(): + # FLASH_ATTN is not supported on ROCm + SLIDING_WINDOW_BACKENDS_TO_TEST = [ + AttentionBackendEnum.FLEX_ATTENTION, + AttentionBackendEnum.TRITON_ATTN, + "FLEX_ATTENTION_SLOW", + ] +else: + SLIDING_WINDOW_BACKENDS_TO_TEST = [ + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.FLEX_ATTENTION, + AttentionBackendEnum.TRITON_ATTN, + "FLEX_ATTENTION_SLOW", + ] @pytest.mark.parametrize( diff --git a/tests/v1/attention/test_rocm_attention_backends_selection.py b/tests/v1/attention/test_rocm_attention_backends_selection.py index d8c747056..77faeb93d 100644 --- a/tests/v1/attention/test_rocm_attention_backends_selection.py +++ b/tests/v1/attention/test_rocm_attention_backends_selection.py @@ -8,6 +8,7 @@ import pytest import torch from vllm.attention.backends.registry import AttentionBackendEnum +from vllm.attention.selector import AttentionSelectorConfig from vllm.platforms import current_platform # ROCm-specific attention backend 
selection tests @@ -144,8 +145,7 @@ def test_standard_attention_backend_selection( # Get the backend class path from vllm.platforms.rocm import RocmPlatform - backend_path = RocmPlatform.get_attn_backend_cls( - selected_backend=backend_enum, + attn_selector_config = AttentionSelectorConfig( head_size=128, dtype=torch.float16, kv_cache_dtype="auto", @@ -154,6 +154,11 @@ def test_standard_attention_backend_selection( has_sink=False, use_sparse=False, ) + + backend_path = RocmPlatform.get_attn_backend_cls( + selected_backend=backend_enum, attn_selector_config=attn_selector_config ) + + assert backend_path == expected_backend_path @@ -267,8 +272,7 @@ def test_mla_backend_selection( if should_raise: with pytest.raises(ValueError): - RocmPlatform.get_attn_backend_cls( - selected_backend=backend_enum, + attn_selector_config = AttentionSelectorConfig( head_size=128, dtype=torch.float16, kv_cache_dtype="auto", @@ -277,9 +281,13 @@ def test_mla_backend_selection( has_sink=False, use_sparse=False, ) + backend_path = RocmPlatform.get_attn_backend_cls( + selected_backend=backend_enum, + attn_selector_config=attn_selector_config, + ) + else: - backend_path = RocmPlatform.get_attn_backend_cls( - selected_backend=backend_enum, + attn_selector_config = AttentionSelectorConfig( head_size=128, dtype=torch.float16, kv_cache_dtype="auto", @@ -288,6 +296,11 @@ def test_mla_backend_selection( has_sink=False, use_sparse=False, ) + + backend_path = RocmPlatform.get_attn_backend_cls( + selected_backend=backend_enum, attn_selector_config=attn_selector_config ) + + assert backend_path == expected_backend_path @@ -303,8 +316,7 @@ def test_aiter_fa_requires_gfx9(mock_vllm_config): match="only supported on gfx9", ), ): - RocmPlatform.get_attn_backend_cls( - 
selected_backend=AttentionBackendEnum.ROCM_AITER_FA, + attn_selector_config = AttentionSelectorConfig( head_size=128, dtype=torch.float16, kv_cache_dtype="auto", @@ -314,6 +326,11 @@ def test_aiter_fa_requires_gfx9(mock_vllm_config): use_sparse=False, ) + RocmPlatform.get_attn_backend_cls( + selected_backend=AttentionBackendEnum.ROCM_AITER_FA, + attn_selector_config=attn_selector_config, + ) + def test_sparse_not_supported(mock_vllm_config): """Test that sparse attention is not supported on ROCm.""" @@ -322,8 +339,7 @@ def test_sparse_not_supported(mock_vllm_config): with pytest.raises( AssertionError, match="Sparse MLA backend on ROCm only supports block size 1" ): - RocmPlatform.get_attn_backend_cls( - selected_backend=None, + attn_selector_config = AttentionSelectorConfig( head_size=128, dtype=torch.float16, kv_cache_dtype="auto", @@ -332,3 +348,7 @@ def test_sparse_not_supported(mock_vllm_config): has_sink=False, use_sparse=True, ) + + RocmPlatform.get_attn_backend_cls( + selected_backend=None, attn_selector_config=attn_selector_config ) diff --git a/tests/v1/attention/test_sparse_mla_backends.py b/tests/v1/attention/test_sparse_mla_backends.py index 804934728..9b7c5822d 100644 --- a/tests/v1/attention/test_sparse_mla_backends.py +++ b/tests/v1/attention/test_sparse_mla_backends.py @@ -24,6 +24,7 @@ from vllm import _custom_ops as ops from vllm.attention.ops import flashmla from vllm.config import set_current_vllm_config from vllm.model_executor.layers.linear import ColumnParallelLinear +from vllm.platforms import current_platform from vllm.utils.math_utils import cdiv from vllm.v1.attention.backends.mla.flashmla_sparse import ( FlashMLASparseBackend, @@ -125,6 +126,9 @@ def test_sparse_backend_decode_correctness( dist_init, batch_name, kv_cache_dtype, tensor_parallel_size, workspace_init ): + if current_platform.is_rocm(): + pytest.skip("ROCm does not support fp8_ds_mla data type for kv cache.") + if not
torch.cuda.is_available(): pytest.skip("CUDA is required for sparse MLA decode test")