[ROCm] Refactor ROCm attention backend selection logic (#35246)

Signed-off-by: Sage Moore <sage@neuralmagic.com>
Sage Moore
2026-03-05 08:51:26 -08:00
committed by GitHub
parent 3ee68590c7
commit 8c760b6ab6
6 changed files with 171 additions and 115 deletions

View File

@@ -211,6 +211,6 @@ configuration.
| `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x |
| `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `bfloat16` | % 16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x |
| `ROCM_AITER_MLA` | fp16, bf16 | `auto` | 1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto` | Any | 576 | ❌ | | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_MLA_SPARSE` | bf16 | `auto` | Any | 576 | ❌ | | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
| `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16` | Any | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any |

View File

@@ -103,21 +103,20 @@ def test_backend_selection(
if name == "TRITON_MLA" and block_size == 1:
# TRITON_MLA doesn't support block_size == 1
with pytest.raises(ValueError) as exc_info:
with pytest.raises(ValueError):
get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla
576, torch.float16, None, block_size, use_mla=use_mla
)
assert f"The selected backend, {name}" in str(exc_info.value)
else:
# Valid backend-block_size combination
backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla
576, torch.float16, None, block_size, use_mla=use_mla
)
expected = name
assert backend.get_name() == expected
else:
backend = get_attn_backend(
16, torch.float16, None, block_size, use_mla=use_mla
32, torch.float16, None, block_size, use_mla=use_mla
)
expected = "ROCM_ATTN"
assert backend.get_name() == expected

View File

@@ -306,6 +306,52 @@ def flash_attn_triton_available() -> bool:
return False
def _get_backend_priorities(
use_mla: bool,
use_sparse: bool,
) -> list[AttentionBackendEnum]:
from vllm._aiter_ops import rocm_aiter_ops
if use_sparse:
return [AttentionBackendEnum.ROCM_AITER_MLA_SPARSE]
if use_mla:
if rocm_aiter_ops.is_mla_enabled():
return [
AttentionBackendEnum.ROCM_AITER_MLA,
AttentionBackendEnum.TRITON_MLA,
AttentionBackendEnum.ROCM_AITER_TRITON_MLA,
]
else:
return [
AttentionBackendEnum.TRITON_MLA,
]
backends = []
# Priority 1: Check for AITER Unified Attention (must check before MHA)
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION:
backends.append(AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN)
# Priority 2: Check for AITER MHA (Flash Attention)
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA:
backends.append(AttentionBackendEnum.ROCM_AITER_FA)
# Priority 3: Check for ROCM_ATTN (prefill-decode split)
from vllm.config import get_current_vllm_config_or_none
vllm_config = get_current_vllm_config_or_none()
if (
vllm_config is not None
and vllm_config.attention_config.use_prefill_decode_attention
):
backends.append(AttentionBackendEnum.ROCM_ATTN)
# Default: Triton Unified Attention
backends.append(AttentionBackendEnum.TRITON_ATTN)
return backends
class RocmPlatform(Platform):
_enum = PlatformEnum.ROCM
device_name: str = "rocm"
@@ -349,6 +395,39 @@ class RocmPlatform(Platform):
with contextlib.suppress(ImportError):
import vllm._rocm_C # noqa: F401
@classmethod
def get_valid_backends(
cls,
device_capability: DeviceCapability,
attn_selector_config: "AttentionSelectorConfig",
num_heads: int | None = None,
) -> tuple[
list[tuple["AttentionBackendEnum", int]],
dict["AttentionBackendEnum", list[str]],
]:
valid_backends_priorities = []
invalid_reasons = {}
backend_priorities = _get_backend_priorities(
attn_selector_config.use_mla,
attn_selector_config.use_sparse,
)
for priority, backend in enumerate(backend_priorities):
try:
backend_class = backend.get_class()
invalid_reasons_i = backend_class.validate_configuration(
device_capability=device_capability,
**attn_selector_config._asdict(),
)
except ImportError:
invalid_reasons_i = ["ImportError"]
if invalid_reasons_i:
invalid_reasons[backend] = invalid_reasons_i
else:
valid_backends_priorities.append((backend, priority))
return valid_backends_priorities, invalid_reasons
@classmethod
def get_attn_backend_cls(
cls,
@@ -356,117 +435,70 @@ class RocmPlatform(Platform):
attn_selector_config: "AttentionSelectorConfig",
num_heads: int | None = None,
) -> str:
from vllm._aiter_ops import rocm_aiter_ops
device_capability = cls.get_device_capability()
assert device_capability is not None
block_size = attn_selector_config.block_size
kv_cache_dtype = attn_selector_config.kv_cache_dtype
if attn_selector_config.use_sparse:
if kv_cache_dtype and kv_cache_dtype.startswith("fp8"):
# First try checking just the selected backend, if there is one.
if selected_backend is not None:
try:
backend_class = selected_backend.get_class()
invalid_reasons = backend_class.validate_configuration(
device_capability=device_capability,
**attn_selector_config._asdict(),
)
except ImportError:
invalid_reasons = ["ImportError"]
if invalid_reasons:
raise ValueError(
"ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
f"Selected backend {selected_backend} is not valid for "
f"this configuration. Reason: {invalid_reasons}"
)
assert block_size == 1, (
"Sparse MLA backend on ROCm only supports block size 1 for now."
)
logger.info_once("Using Sparse MLA backend.")
return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path()
if attn_selector_config.use_mla:
if selected_backend is None:
selected_backend = (
AttentionBackendEnum.ROCM_AITER_MLA
if rocm_aiter_ops.is_mla_enabled() or block_size == 1
else AttentionBackendEnum.TRITON_MLA
)
if selected_backend == AttentionBackendEnum.TRITON_MLA:
if block_size != 1:
logger.info_once("Using Triton MLA backend.")
return AttentionBackendEnum.TRITON_MLA.get_path()
raise ValueError(
f" The selected backend, {selected_backend.name},"
f"does not support block size {block_size}."
)
if selected_backend == AttentionBackendEnum.ROCM_AITER_MLA:
logger.info("Using AITER MLA backend.")
return AttentionBackendEnum.ROCM_AITER_MLA.get_path()
if selected_backend == AttentionBackendEnum.ROCM_AITER_TRITON_MLA:
logger.info("Using AITER TRITON MLA backend.")
return AttentionBackendEnum.ROCM_AITER_TRITON_MLA.get_path()
raise ValueError(
f" The selected backend, {selected_backend.name},"
f"is not MLA type while requested for MLA backend."
)
if selected_backend == AttentionBackendEnum.FLEX_ATTENTION:
logger.info("Using FlexAttention backend.")
return AttentionBackendEnum.FLEX_ATTENTION.get_path()
if selected_backend == AttentionBackendEnum.TRITON_ATTN:
logger.info("Using Triton Attention backend.")
return AttentionBackendEnum.TRITON_ATTN.get_path()
if selected_backend == AttentionBackendEnum.ROCM_ATTN:
logger.info("Using Rocm Attention backend.")
return AttentionBackendEnum.ROCM_ATTN.get_path()
if selected_backend == AttentionBackendEnum.ROCM_AITER_FA:
if on_gfx9():
logger.info("Using Aiter Flash Attention backend.")
return AttentionBackendEnum.ROCM_AITER_FA.get_path()
else:
raise ValueError(
f"The selected backend, {selected_backend.name}, "
"is only supported on gfx9 architectures."
)
logger.info("Using %s backend.", selected_backend)
return selected_backend.get_path()
if selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
logger.info("Using Aiter Unified Attention backend.")
return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()
# Handle automatic backend selection based on environment variables
if selected_backend is None:
# Priority 1: Check for AITER Unified Attention (must check before MHA)
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION:
logger.info("Using Aiter Unified Attention backend.")
return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()
# Priority 2: Check for AITER MHA (Flash Attention)
# Only use if explicitly enabled (not just VLLM_ROCM_USE_AITER=1)
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
logger.info("Using Aiter Flash Attention backend.")
return AttentionBackendEnum.ROCM_AITER_FA.get_path()
# Priority 3: Check for ROCM_ATTN (prefill-decode split)
from vllm.config import get_current_vllm_config_or_none
vllm_config = get_current_vllm_config_or_none()
if (
vllm_config is not None
and vllm_config.attention_config.use_prefill_decode_attention
):
logger.info("Using Rocm Attention backend.")
return AttentionBackendEnum.ROCM_ATTN.get_path()
# Priority 4: Check for AITER enabled without specific flags
# This defaults to AITER FA only if MHA is not explicitly disabled
if (
envs.VLLM_ROCM_USE_AITER
and on_gfx9()
and envs.VLLM_ROCM_USE_AITER_MHA is not False
):
logger.info("Using Aiter Flash Attention backend.")
return AttentionBackendEnum.ROCM_AITER_FA.get_path()
# Default: Triton Unified Attention
logger.info("Using Triton Attention backend.")
return AttentionBackendEnum.TRITON_ATTN.get_path()
raise RuntimeError(
f"Attention backend {selected_backend.name} is not supported on "
"ROCm. Note that V0 attention backends have been removed."
# No selected backend or the selected backend is invalid,
# so we try finding a valid backend.
valid_backends_priorities, invalid_reasons = cls.get_valid_backends(
device_capability=device_capability,
attn_selector_config=attn_selector_config,
num_heads=num_heads,
)
reasons_str = (
"{"
+ ", ".join(
f"{backend.name}: [{', '.join(reasons)}]"
for backend, reasons in invalid_reasons.items()
)
+ "}"
)
config_str = attn_selector_config.__repr__()
logger.debug_once(
f"Some attention backends are not valid for {cls.device_name} with "
f"{config_str}. Reasons: {reasons_str}."
)
if len(valid_backends_priorities) == 0:
raise ValueError(
f"No valid attention backend found for {cls.device_name} "
f"with {config_str}. Reasons: {reasons_str}."
)
# We have found some valid backends. Select the one with the
# highest priority.
sorted_indices = sorted(
range(len(valid_backends_priorities)),
key=lambda i: valid_backends_priorities[i][1],
)
selected_index = sorted_indices[0]
selected_backend = valid_backends_priorities[selected_index][0]
logger.info_once(
"Using %s attention backend out of potential backends: %s.",
selected_backend.name,
"[" + ", ".join(f"'{b[0].name}'" for b in valid_backends_priorities) + "]",
scope="local",
)
return selected_backend.get_path()
@classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
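The refactored `get_attn_backend_cls` above replaces the long if/elif chain with two steps: build an ordered priority list via `_get_backend_priorities`, then keep only the backends whose `validate_configuration` returns no reasons and pick the lowest-index (highest-priority) survivor. A minimal, self-contained sketch of that pattern follows; the enum, the `validate` rule, and the `select_backend` helper are illustrative stand-ins, not vLLM's actual classes or signatures:

```python
from enum import Enum, auto


class Backend(Enum):
    # Illustrative stand-ins for AttentionBackendEnum members.
    ROCM_AITER_FA = auto()
    TRITON_ATTN = auto()


def validate(backend: Backend, block_size: int) -> list[str]:
    """Toy stand-in for backend_class.validate_configuration()."""
    reasons = []
    if backend is Backend.ROCM_AITER_FA and block_size % 16 != 0:
        reasons.append("block size must be a multiple of 16")
    return reasons


def select_backend(priorities: list[Backend], block_size: int) -> Backend:
    valid: list[tuple[Backend, int]] = []
    invalid: dict[Backend, list[str]] = {}
    for priority, backend in enumerate(priorities):
        reasons = validate(backend, block_size)
        if reasons:
            invalid[backend] = reasons
        else:
            valid.append((backend, priority))
    if not valid:
        raise ValueError(f"No valid attention backend found. Reasons: {invalid}")
    # Lowest enumeration index == highest priority, mirroring the refactor.
    valid.sort(key=lambda item: item[1])
    return valid[0][0]


if __name__ == "__main__":
    order = [Backend.ROCM_AITER_FA, Backend.TRITON_ATTN]
    print(select_backend(order, block_size=24))  # -> Backend.TRITON_ATTN
    print(select_backend(order, block_size=32))  # -> Backend.ROCM_AITER_FA
```

With block size 24 the AITER flash-attention stand-in fails its divisibility check and the Triton stand-in is chosen; with block size 32 the higher-priority backend wins. This fall-through is the same behavior the new ROCm selector relies on instead of explicit environment-variable branches.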

View File

@@ -77,6 +77,7 @@ def fetch_id_to_ragged_triton(
class ROCMAiterMLASparseBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]
@staticmethod
def get_name() -> str:
@@ -104,14 +105,23 @@ class ROCMAiterMLASparseBackend(AttentionBackend):
) -> tuple[int, ...]:
return (num_blocks, block_size, head_size)
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16]
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [576]
@classmethod
def is_mla(cls) -> bool:
return True
@classmethod
def is_sparse(cls) -> bool:
return True
@classmethod
def supports_block_size(cls, block_size: int | None) -> bool:
# The only supported block_size is 1
return block_size is None or block_size == 1
@dataclass
class ROCMAiterMLASparseMetadata(AttentionMetadata):
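The new classmethods above let the generic selector interrogate the sparse MLA backend instead of hard-coding its constraints in `rocm.py`. A rough sketch of how such per-backend predicates can be folded into a reasons-list check, in the spirit of the `validate_configuration` call consumed by `get_valid_backends`; the `ToySparseMLABackend` class and `check_configuration` helper are hypothetical, written only for illustration:

```python
import torch


class ToySparseMLABackend:
    """Toy stand-in mirroring the classmethods added in this commit;
    not the real ROCMAiterMLASparseBackend."""

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [576]

    @classmethod
    def supports_block_size(cls, block_size: int | None) -> bool:
        return block_size is None or block_size == 1


def check_configuration(
    backend, dtype: torch.dtype, head_size: int, block_size: int | None
) -> list[str]:
    """Collect human-readable reasons why a configuration is invalid."""
    reasons = []
    if dtype not in backend.get_supported_dtypes():
        reasons.append(f"unsupported dtype {dtype}")
    if head_size not in backend.get_supported_head_sizes():
        reasons.append(f"unsupported head size {head_size}")
    if not backend.supports_block_size(block_size):
        reasons.append(f"unsupported block size {block_size}")
    return reasons


print(check_configuration(ToySparseMLABackend, torch.bfloat16, 576, 1))   # []
print(check_configuration(ToySparseMLABackend, torch.float16, 128, 16))   # 3 reasons
```

An empty list signals a usable configuration; otherwise the reasons end up in the `invalid_reasons` mapping that the platform logs and, when no backend is valid, raises with.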

View File

@@ -45,6 +45,11 @@ class TritonMLABackend(MLACommonBackend):
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return True
@classmethod
def supports_block_size(cls, block_size: int | None) -> bool:
# The only unsupported block_size is 1
return block_size is None or block_size != 1
class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
can_return_lse_for_decode: bool = True
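Because `TritonMLABackend` now rejects exactly the block size that the AITER MLA backends list as supported (block size 1, per the documentation table above), the platform code no longer needs the old hand-rolled `block_size == 1` branch: walking the MLA priority list from `_get_backend_priorities` and asking each backend is enough. A small sketch under that assumption, with toy classes standing in for the real backends:

```python
class ToyAiterMLA:
    @classmethod
    def supports_block_size(cls, block_size: int | None) -> bool:
        # Supports only block size 1, mirroring the AITER MLA entries above.
        return block_size is None or block_size == 1


class ToyTritonMLA:
    @classmethod
    def supports_block_size(cls, block_size: int | None) -> bool:
        # Supports everything except block size 1, mirroring TritonMLABackend.
        return block_size is None or block_size != 1


def pick_mla_backend(block_size: int):
    # Priority order mirrors _get_backend_priorities() when AITER MLA is enabled.
    for backend in (ToyAiterMLA, ToyTritonMLA):
        if backend.supports_block_size(block_size):
            return backend
    raise ValueError(f"no MLA backend supports block size {block_size}")


print(pick_mla_backend(1).__name__)   # ToyAiterMLA
print(pick_mla_backend(16).__name__)  # ToyTritonMLA
```

In the real selector the same effect comes from `validate_configuration` rejecting the backend rather than from an explicit loop over `supports_block_size`.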

View File

@@ -12,6 +12,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import num_compute_units
from vllm.v1.attention.backend import (
@@ -766,6 +767,15 @@ class AiterFlashAttentionBackend(AttentionBackend):
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)
@classmethod
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
from vllm.platforms.rocm import on_mi3xx
# DeviceCapability is currently created using torch.cuda.get_device_capability(),
# which is known to be buggy on ROCm systems. on_mi3xx uses amd-smi, which is
# more reliable.
return on_mi3xx()
class AiterFlashAttentionImpl(AttentionImpl):
def __init__(