From 8c760b6ab6993c6a0d5f639747baefedb4612525 Mon Sep 17 00:00:00 2001
From: Sage Moore
Date: Thu, 5 Mar 2026 08:51:26 -0800
Subject: [PATCH] [ROCm] Refactor ROCm attention backend selection logic (#35246)

Signed-off-by: Sage Moore
---
 docs/design/attention_backends.md            |   2 +-
 .../attention/test_attention_selector.py     |   9 +-
 vllm/platforms/rocm.py                       | 242 ++++++++++--------
 .../backends/mla/rocm_aiter_mla_sparse.py    |  18 +-
 vllm/v1/attention/backends/mla/triton_mla.py |   5 +
 vllm/v1/attention/backends/rocm_aiter_fa.py  |  10 +
 6 files changed, 171 insertions(+), 115 deletions(-)

diff --git a/docs/design/attention_backends.md b/docs/design/attention_backends.md
index e726d9925..7b643a46b 100644
--- a/docs/design/attention_backends.md
+++ b/docs/design/attention_backends.md
@@ -211,6 +211,6 @@ configuration.
 | `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x |
 | `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x |
 | `ROCM_AITER_MLA` | fp16, bf16 | `auto` | 1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
-| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto` | Any | 576 | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
+| `ROCM_AITER_MLA_SPARSE` | bf16 | `auto` | Any | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | N/A |
 | `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
 | `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16` | Any | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py
index 48582f4f6..6b6cae34f 100644
--- a/tests/kernels/attention/test_attention_selector.py
+++ b/tests/kernels/attention/test_attention_selector.py
@@ -103,21 +103,20 @@ def test_backend_selection(
 
         if name == "TRITON_MLA" and block_size == 1:
             # TRITON_MLA doesn't support block_size == 1
-            with pytest.raises(ValueError) as exc_info:
+            with pytest.raises(ValueError):
                 get_attn_backend(
-                    16, torch.float16, None, block_size, use_mla=use_mla
+                    576, torch.float16, None, block_size, use_mla=use_mla
                 )
-            assert f"The selected backend, {name}" in str(exc_info.value)
         else:
             # Valid backend-block_size combination
             backend = get_attn_backend(
-                16, torch.float16, None, block_size, use_mla=use_mla
+                576, torch.float16, None, block_size, use_mla=use_mla
             )
             expected = name
             assert backend.get_name() == expected
     else:
         backend = get_attn_backend(
-            16, torch.float16, None, block_size, use_mla=use_mla
+            32, torch.float16, None, block_size, use_mla=use_mla
         )
         expected = "ROCM_ATTN"
         assert backend.get_name() == expected
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 56d654961..b4925d085 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -306,6 +306,52 @@ def flash_attn_triton_available() -> bool:
         return False
 
 
+def _get_backend_priorities(
+    use_mla: bool,
+    use_sparse: bool,
+) -> list[AttentionBackendEnum]:
+    from vllm._aiter_ops import rocm_aiter_ops
+
+    if use_sparse:
+        return [AttentionBackendEnum.ROCM_AITER_MLA_SPARSE]
+
+    if use_mla:
+        if rocm_aiter_ops.is_mla_enabled():
+            return [
+                AttentionBackendEnum.ROCM_AITER_MLA,
+                AttentionBackendEnum.TRITON_MLA,
+                AttentionBackendEnum.ROCM_AITER_TRITON_MLA,
+            ]
+        else:
+            return [
+                AttentionBackendEnum.TRITON_MLA,
+            ]
+
+    backends = []
+
+    # Priority 1: Check for AITER Unified Attention (must check before MHA)
+    if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION:
+        backends.append(AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN)
+
+    # Priority 2: Check for AITER MHA (Flash Attention)
+    if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA:
+        backends.append(AttentionBackendEnum.ROCM_AITER_FA)
+
+    # Priority 3: Check for ROCM_ATTN (prefill-decode split)
+    from vllm.config import get_current_vllm_config_or_none
+
+    vllm_config = get_current_vllm_config_or_none()
+    if (
+        vllm_config is not None
+        and vllm_config.attention_config.use_prefill_decode_attention
+    ):
+        backends.append(AttentionBackendEnum.ROCM_ATTN)
+
+    # Default: Triton Unified Attention
+    backends.append(AttentionBackendEnum.TRITON_ATTN)
+    return backends
+
+
 class RocmPlatform(Platform):
     _enum = PlatformEnum.ROCM
     device_name: str = "rocm"
@@ -349,6 +395,39 @@ class RocmPlatform(Platform):
         with contextlib.suppress(ImportError):
             import vllm._rocm_C  # noqa: F401
 
+    @classmethod
+    def get_valid_backends(
+        cls,
+        device_capability: DeviceCapability,
+        attn_selector_config: "AttentionSelectorConfig",
+        num_heads: int | None = None,
+    ) -> tuple[
+        list[tuple["AttentionBackendEnum", int]],
+        dict["AttentionBackendEnum", list[str]],
+    ]:
+        valid_backends_priorities = []
+        invalid_reasons = {}
+
+        backend_priorities = _get_backend_priorities(
+            attn_selector_config.use_mla,
+            attn_selector_config.use_sparse,
+        )
+        for priority, backend in enumerate(backend_priorities):
+            try:
+                backend_class = backend.get_class()
+                invalid_reasons_i = backend_class.validate_configuration(
+                    device_capability=device_capability,
+                    **attn_selector_config._asdict(),
+                )
+            except ImportError:
+                invalid_reasons_i = ["ImportError"]
+            if invalid_reasons_i:
+                invalid_reasons[backend] = invalid_reasons_i
+            else:
+                valid_backends_priorities.append((backend, priority))
+
+        return valid_backends_priorities, invalid_reasons
+
     @classmethod
     def get_attn_backend_cls(
         cls,
@@ -356,117 +435,70 @@ class RocmPlatform(Platform):
         attn_selector_config: "AttentionSelectorConfig",
         num_heads: int | None = None,
     ) -> str:
-        from vllm._aiter_ops import rocm_aiter_ops
+        device_capability = cls.get_device_capability()
+        assert device_capability is not None
 
-        block_size = attn_selector_config.block_size
-        kv_cache_dtype = attn_selector_config.kv_cache_dtype
-
-        if attn_selector_config.use_sparse:
-            if kv_cache_dtype and kv_cache_dtype.startswith("fp8"):
+        # First try checking just the selected backend, if there is one.
+        if selected_backend is not None:
+            try:
+                backend_class = selected_backend.get_class()
+                invalid_reasons = backend_class.validate_configuration(
+                    device_capability=device_capability,
+                    **attn_selector_config._asdict(),
+                )
+            except ImportError:
+                invalid_reasons = ["ImportError"]
+            if invalid_reasons:
                 raise ValueError(
-                    "ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
+                    f"Selected backend {selected_backend} is not valid for "
+                    f"this configuration. Reasons: {invalid_reasons}"
                 )
-            assert block_size == 1, (
-                "Sparse MLA backend on ROCm only supports block size 1 for now."
-            )
-            logger.info_once("Using Sparse MLA backend.")
-            return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path()
-
-        if attn_selector_config.use_mla:
-            if selected_backend is None:
-                selected_backend = (
-                    AttentionBackendEnum.ROCM_AITER_MLA
-                    if rocm_aiter_ops.is_mla_enabled() or block_size == 1
-                    else AttentionBackendEnum.TRITON_MLA
-                )
-            if selected_backend == AttentionBackendEnum.TRITON_MLA:
-                if block_size != 1:
-                    logger.info_once("Using Triton MLA backend.")
-                    return AttentionBackendEnum.TRITON_MLA.get_path()
-                raise ValueError(
-                    f" The selected backend, {selected_backend.name},"
-                    f"does not support block size {block_size}."
-                )
-            if selected_backend == AttentionBackendEnum.ROCM_AITER_MLA:
-                logger.info("Using AITER MLA backend.")
-                return AttentionBackendEnum.ROCM_AITER_MLA.get_path()
-            if selected_backend == AttentionBackendEnum.ROCM_AITER_TRITON_MLA:
-                logger.info("Using AITER TRITON MLA backend.")
-                return AttentionBackendEnum.ROCM_AITER_TRITON_MLA.get_path()
-
-            raise ValueError(
-                f" The selected backend, {selected_backend.name},"
-                f"is not MLA type while requested for MLA backend."
-            )
-
-        if selected_backend == AttentionBackendEnum.FLEX_ATTENTION:
-            logger.info("Using FlexAttention backend.")
-            return AttentionBackendEnum.FLEX_ATTENTION.get_path()
-
-        if selected_backend == AttentionBackendEnum.TRITON_ATTN:
-            logger.info("Using Triton Attention backend.")
-            return AttentionBackendEnum.TRITON_ATTN.get_path()
-
-        if selected_backend == AttentionBackendEnum.ROCM_ATTN:
-            logger.info("Using Rocm Attention backend.")
-            return AttentionBackendEnum.ROCM_ATTN.get_path()
-
-        if selected_backend == AttentionBackendEnum.ROCM_AITER_FA:
-            if on_gfx9():
-                logger.info("Using Aiter Flash Attention backend.")
-                return AttentionBackendEnum.ROCM_AITER_FA.get_path()
             else:
-                raise ValueError(
-                    f"The selected backend, {selected_backend.name}, "
-                    "is only supported on gfx9 architectures."
-                )
+            logger.info("Using %s backend.", selected_backend)
+            return selected_backend.get_path()
 
-        if selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
-            logger.info("Using Aiter Unified Attention backend.")
-            return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()
-
-        # Handle automatic backend selection based on environment variables
-        if selected_backend is None:
-            # Priority 1: Check for AITER Unified Attention (must check before MHA)
-            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION:
-                logger.info("Using Aiter Unified Attention backend.")
-                return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()
-
-            # Priority 2: Check for AITER MHA (Flash Attention)
-            # Only use if explicitly enabled (not just VLLM_ROCM_USE_AITER=1)
-            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
-                logger.info("Using Aiter Flash Attention backend.")
-                return AttentionBackendEnum.ROCM_AITER_FA.get_path()
-
-            # Priority 3: Check for ROCM_ATTN (prefill-decode split)
-            from vllm.config import get_current_vllm_config_or_none
-
-            vllm_config = get_current_vllm_config_or_none()
-            if (
-                vllm_config is not None
-                and vllm_config.attention_config.use_prefill_decode_attention
-            ):
-                logger.info("Using Rocm Attention backend.")
-                return AttentionBackendEnum.ROCM_ATTN.get_path()
-
-            # Priority 4: Check for AITER enabled without specific flags
-            # This defaults to AITER FA only if MHA is not explicitly disabled
-            if (
-                envs.VLLM_ROCM_USE_AITER
-                and on_gfx9()
-                and envs.VLLM_ROCM_USE_AITER_MHA is not False
-            ):
-                logger.info("Using Aiter Flash Attention backend.")
-                return AttentionBackendEnum.ROCM_AITER_FA.get_path()
-
-            # Default: Triton Unified Attention
-            logger.info("Using Triton Attention backend.")
-            return AttentionBackendEnum.TRITON_ATTN.get_path()
-
-        raise RuntimeError(
-            f"Attention backend {selected_backend.name} is not supported on "
-            "ROCm. Note that V0 attention backends have been removed."
+        # No selected backend or the selected backend is invalid,
+        # so we try finding a valid backend.
+        valid_backends_priorities, invalid_reasons = cls.get_valid_backends(
+            device_capability=device_capability,
+            attn_selector_config=attn_selector_config,
+            num_heads=num_heads,
         )
+        reasons_str = (
+            "{"
+            + ", ".join(
+                f"{backend.name}: [{', '.join(reasons)}]"
+                for backend, reasons in invalid_reasons.items()
+            )
+            + "}"
+        )
+        config_str = attn_selector_config.__repr__()
+        logger.debug_once(
+            f"Some attention backends are not valid for {cls.device_name} with "
+            f"{config_str}. Reasons: {reasons_str}."
+        )
+        if len(valid_backends_priorities) == 0:
+            raise ValueError(
+                f"No valid attention backend found for {cls.device_name} "
+                f"with {config_str}. Reasons: {reasons_str}."
+            )
+
+        # We have found some valid backends. Select the one with the
+        # highest priority.
+        sorted_indices = sorted(
+            range(len(valid_backends_priorities)),
+            key=lambda i: valid_backends_priorities[i][1],
+        )
+        selected_index = sorted_indices[0]
+        selected_backend = valid_backends_priorities[selected_index][0]
+        logger.info_once(
+            "Using %s attention backend out of potential backends: %s.",
+            selected_backend.name,
+            "[" + ", ".join(f"'{b[0].name}'" for b in valid_backends_priorities) + "]",
+            scope="local",
+        )
+
+        return selected_backend.get_path()
 
     @classmethod
     def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
index c8aafae8d..47f1c06ea 100644
--- a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
@@ -77,6 +77,7 @@ def fetch_id_to_ragged_triton(
 
 class ROCMAiterMLASparseBackend(AttentionBackend):
     accept_output_buffer: bool = True
+    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]
 
     @staticmethod
     def get_name() -> str:
@@ -104,14 +105,23 @@ class ROCMAiterMLASparseBackend(AttentionBackend):
     ) -> tuple[int, ...]:
         return (num_blocks, block_size, head_size)
 
-    @classmethod
-    def get_supported_dtypes(cls) -> list[torch.dtype]:
-        return [torch.bfloat16]
-
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
         return [576]
 
+    @classmethod
+    def is_mla(cls) -> bool:
+        return True
+
+    @classmethod
+    def is_sparse(cls) -> bool:
+        return True
+
+    @classmethod
+    def supports_block_size(cls, block_size: int | None) -> bool:
+        # The only supported block_size is 1
+        return block_size is None or block_size == 1
+
 
 @dataclass
 class ROCMAiterMLASparseMetadata(AttentionMetadata):
diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py
index f6c1790f6..a950288b6 100644
--- a/vllm/v1/attention/backends/mla/triton_mla.py
+++ b/vllm/v1/attention/backends/mla/triton_mla.py
@@ -45,6 +45,11 @@ class TritonMLABackend(MLACommonBackend):
     def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
         return True
 
+    @classmethod
+    def supports_block_size(cls, block_size: int | None) -> bool:
+        # The only unsupported block_size is 1
+        return block_size is None or block_size != 1
+
 
 class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
     can_return_lse_for_decode: bool = True
diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py
index bc547585b..41147ca63 100644
--- a/vllm/v1/attention/backends/rocm_aiter_fa.py
+++ b/vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -12,6 +12,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.model_executor.layers.attention import Attention
 from vllm.platforms import current_platform
+from vllm.platforms.interface import DeviceCapability
 from vllm.utils.math_utils import cdiv
 from vllm.utils.platform_utils import num_compute_units
 from vllm.v1.attention.backend import (
@@ -766,6 +767,15 @@ class AiterFlashAttentionBackend(AttentionBackend):
             raise ValueError("Block size must be a multiple of 16.")
         return (2, num_blocks, block_size, num_kv_heads, head_size)
 
+    @classmethod
+    def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
+        from vllm.platforms.rocm import on_mi3xx
+
+        # DeviceCapability is currently created using
+        # torch.cuda.get_device_capability(), which is known to be buggy on
+        # ROCm systems; on_mi3xx uses amd-smi, which is more reliable.
+        return on_mi3xx()
+
 
 class AiterFlashAttentionImpl(AttentionImpl):
     def __init__(
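
Reviewer note: the refactor replaces the old chain of special-cased if/elif
checks with one pattern: build a priority-ordered candidate list
(_get_backend_priorities), ask each backend's validate_configuration() for its
rejection reasons, and return the highest-priority backend that reported none.
Below is a minimal, self-contained Python sketch of that pattern; Backend, its
fields, and select_backend() are illustrative stand-ins, not vLLM APIs, and
only the shape of the loop mirrors get_valid_backends()/get_attn_backend_cls().

from dataclasses import dataclass


@dataclass(frozen=True)
class Backend:
    # Hypothetical stand-in for an AttentionBackendEnum member.
    name: str
    requires_gfx9: bool = False
    min_block_size: int = 1

    def validate_configuration(self, block_size: int, is_gfx9: bool) -> list[str]:
        # Mirror the contract assumed by the patch: return human-readable
        # reasons the configuration is invalid; an empty list means "usable".
        reasons: list[str] = []
        if self.requires_gfx9 and not is_gfx9:
            reasons.append("only supported on gfx9")
        if block_size < self.min_block_size:
            reasons.append(f"block_size {block_size} < {self.min_block_size}")
        return reasons


def select_backend(priorities: list[Backend], block_size: int, is_gfx9: bool) -> Backend:
    # Keep (backend, priority) pairs that validate cleanly, remember why the
    # rest were rejected, then take the lowest enumerate() index, i.e. the
    # highest-priority survivor.
    valid: list[tuple[Backend, int]] = []
    invalid: dict[str, list[str]] = {}
    for priority, backend in enumerate(priorities):
        reasons = backend.validate_configuration(block_size, is_gfx9)
        if reasons:
            invalid[backend.name] = reasons
        else:
            valid.append((backend, priority))
    if not valid:
        raise ValueError(f"No valid attention backend. Reasons: {invalid}")
    return min(valid, key=lambda pair: pair[1])[0]


if __name__ == "__main__":
    candidates = [Backend("ROCM_AITER_FA", requires_gfx9=True), Backend("TRITON_ATTN")]
    # On a non-gfx9 part, AITER FA is filtered out and we fall back to Triton.
    print(select_backend(candidates, block_size=16, is_gfx9=False).name)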