[ROCm] Refactor ROCm attention backend selection logic (#35246)
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
@@ -211,6 +211,6 @@ configuration.
|
||||
| `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x |
|
||||
| `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x |
|
||||
| `ROCM_AITER_MLA` | fp16, bf16 | `auto` | 1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
|
||||
| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto` | Any | 576 | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
|
||||
| `ROCM_AITER_MLA_SPARSE` | bf16 | `auto` | Any | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | N/A |
|
||||
| `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
|
||||
| `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16` | Any | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
|
||||
|
||||
@@ -103,21 +103,20 @@ def test_backend_selection(
|
||||
|
||||
if name == "TRITON_MLA" and block_size == 1:
|
||||
# TRITON_MLA doesn't support block_size == 1
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
with pytest.raises(ValueError):
|
||||
get_attn_backend(
|
||||
16, torch.float16, None, block_size, use_mla=use_mla
|
||||
576, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
assert f"The selected backend, {name}" in str(exc_info.value)
|
||||
else:
|
||||
# Valid backend-block_size combination
|
||||
backend = get_attn_backend(
|
||||
16, torch.float16, None, block_size, use_mla=use_mla
|
||||
576, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
expected = name
|
||||
assert backend.get_name() == expected
|
||||
else:
|
||||
backend = get_attn_backend(
|
||||
16, torch.float16, None, block_size, use_mla=use_mla
|
||||
32, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
expected = "ROCM_ATTN"
|
||||
assert backend.get_name() == expected
|
||||
|
||||
@@ -306,6 +306,52 @@ def flash_attn_triton_available() -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _get_backend_priorities(
|
||||
use_mla: bool,
|
||||
use_sparse: bool,
|
||||
) -> list[AttentionBackendEnum]:
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
|
||||
if use_sparse:
|
||||
return [AttentionBackendEnum.ROCM_AITER_MLA_SPARSE]
|
||||
|
||||
if use_mla:
|
||||
if rocm_aiter_ops.is_mla_enabled():
|
||||
return [
|
||||
AttentionBackendEnum.ROCM_AITER_MLA,
|
||||
AttentionBackendEnum.TRITON_MLA,
|
||||
AttentionBackendEnum.ROCM_AITER_TRITON_MLA,
|
||||
]
|
||||
else:
|
||||
return [
|
||||
AttentionBackendEnum.TRITON_MLA,
|
||||
]
|
||||
|
||||
backends = []
|
||||
|
||||
# Priority 1: Check for AITER Unified Attention (must check before MHA)
|
||||
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION:
|
||||
backends.append(AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN)
|
||||
|
||||
# Priority 2: Check for AITER MHA (Flash Attention)
|
||||
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA:
|
||||
backends.append(AttentionBackendEnum.ROCM_AITER_FA)
|
||||
|
||||
# Priority 3: Check for ROCM_ATTN (prefill-decode split)
|
||||
from vllm.config import get_current_vllm_config_or_none
|
||||
|
||||
vllm_config = get_current_vllm_config_or_none()
|
||||
if (
|
||||
vllm_config is not None
|
||||
and vllm_config.attention_config.use_prefill_decode_attention
|
||||
):
|
||||
backends.append(AttentionBackendEnum.ROCM_ATTN)
|
||||
|
||||
# Default: Triton Unified Attention
|
||||
backends.append(AttentionBackendEnum.TRITON_ATTN)
|
||||
return backends
|
||||
|
||||
|
||||
class RocmPlatform(Platform):
|
||||
_enum = PlatformEnum.ROCM
|
||||
device_name: str = "rocm"
|
||||
@@ -349,6 +395,39 @@ class RocmPlatform(Platform):
|
||||
with contextlib.suppress(ImportError):
|
||||
import vllm._rocm_C # noqa: F401
|
||||
|
||||
@classmethod
|
||||
def get_valid_backends(
|
||||
cls,
|
||||
device_capability: DeviceCapability,
|
||||
attn_selector_config: "AttentionSelectorConfig",
|
||||
num_heads: int | None = None,
|
||||
) -> tuple[
|
||||
list[tuple["AttentionBackendEnum", int]],
|
||||
dict["AttentionBackendEnum", list[str]],
|
||||
]:
|
||||
valid_backends_priorities = []
|
||||
invalid_reasons = {}
|
||||
|
||||
backend_priorities = _get_backend_priorities(
|
||||
attn_selector_config.use_mla,
|
||||
attn_selector_config.use_sparse,
|
||||
)
|
||||
for priority, backend in enumerate(backend_priorities):
|
||||
try:
|
||||
backend_class = backend.get_class()
|
||||
invalid_reasons_i = backend_class.validate_configuration(
|
||||
device_capability=device_capability,
|
||||
**attn_selector_config._asdict(),
|
||||
)
|
||||
except ImportError:
|
||||
invalid_reasons_i = ["ImportError"]
|
||||
if invalid_reasons_i:
|
||||
invalid_reasons[backend] = invalid_reasons_i
|
||||
else:
|
||||
valid_backends_priorities.append((backend, priority))
|
||||
|
||||
return valid_backends_priorities, invalid_reasons
|
||||
|
||||
@classmethod
|
||||
def get_attn_backend_cls(
|
||||
cls,
|
||||
@@ -356,117 +435,70 @@ class RocmPlatform(Platform):
|
||||
attn_selector_config: "AttentionSelectorConfig",
|
||||
num_heads: int | None = None,
|
||||
) -> str:
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
device_capability = cls.get_device_capability()
|
||||
assert device_capability is not None
|
||||
|
||||
block_size = attn_selector_config.block_size
|
||||
kv_cache_dtype = attn_selector_config.kv_cache_dtype
|
||||
|
||||
if attn_selector_config.use_sparse:
|
||||
if kv_cache_dtype and kv_cache_dtype.startswith("fp8"):
|
||||
# First try checking just the selected backend, if there is one.
|
||||
if selected_backend is not None:
|
||||
try:
|
||||
backend_class = selected_backend.get_class()
|
||||
invalid_reasons = backend_class.validate_configuration(
|
||||
device_capability=device_capability,
|
||||
**attn_selector_config._asdict(),
|
||||
)
|
||||
except ImportError:
|
||||
invalid_reasons = ["ImportError"]
|
||||
if invalid_reasons:
|
||||
raise ValueError(
|
||||
"ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
|
||||
f"Selected backend {selected_backend} is not valid for "
|
||||
f"this configuration. Reason: {invalid_reasons}"
|
||||
)
|
||||
assert block_size == 1, (
|
||||
"Sparse MLA backend on ROCm only supports block size 1 for now."
|
||||
)
|
||||
logger.info_once("Using Sparse MLA backend.")
|
||||
return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path()
|
||||
|
||||
if attn_selector_config.use_mla:
|
||||
if selected_backend is None:
|
||||
selected_backend = (
|
||||
AttentionBackendEnum.ROCM_AITER_MLA
|
||||
if rocm_aiter_ops.is_mla_enabled() or block_size == 1
|
||||
else AttentionBackendEnum.TRITON_MLA
|
||||
)
|
||||
if selected_backend == AttentionBackendEnum.TRITON_MLA:
|
||||
if block_size != 1:
|
||||
logger.info_once("Using Triton MLA backend.")
|
||||
return AttentionBackendEnum.TRITON_MLA.get_path()
|
||||
raise ValueError(
|
||||
f" The selected backend, {selected_backend.name},"
|
||||
f"does not support block size {block_size}."
|
||||
)
|
||||
if selected_backend == AttentionBackendEnum.ROCM_AITER_MLA:
|
||||
logger.info("Using AITER MLA backend.")
|
||||
return AttentionBackendEnum.ROCM_AITER_MLA.get_path()
|
||||
if selected_backend == AttentionBackendEnum.ROCM_AITER_TRITON_MLA:
|
||||
logger.info("Using AITER TRITON MLA backend.")
|
||||
return AttentionBackendEnum.ROCM_AITER_TRITON_MLA.get_path()
|
||||
|
||||
raise ValueError(
|
||||
f" The selected backend, {selected_backend.name},"
|
||||
f"is not MLA type while requested for MLA backend."
|
||||
)
|
||||
|
||||
if selected_backend == AttentionBackendEnum.FLEX_ATTENTION:
|
||||
logger.info("Using FlexAttention backend.")
|
||||
return AttentionBackendEnum.FLEX_ATTENTION.get_path()
|
||||
|
||||
if selected_backend == AttentionBackendEnum.TRITON_ATTN:
|
||||
logger.info("Using Triton Attention backend.")
|
||||
return AttentionBackendEnum.TRITON_ATTN.get_path()
|
||||
|
||||
if selected_backend == AttentionBackendEnum.ROCM_ATTN:
|
||||
logger.info("Using Rocm Attention backend.")
|
||||
return AttentionBackendEnum.ROCM_ATTN.get_path()
|
||||
|
||||
if selected_backend == AttentionBackendEnum.ROCM_AITER_FA:
|
||||
if on_gfx9():
|
||||
logger.info("Using Aiter Flash Attention backend.")
|
||||
return AttentionBackendEnum.ROCM_AITER_FA.get_path()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"The selected backend, {selected_backend.name}, "
|
||||
"is only supported on gfx9 architectures."
|
||||
)
|
||||
logger.info("Using %s backend.", selected_backend)
|
||||
return selected_backend.get_path()
|
||||
|
||||
if selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
|
||||
logger.info("Using Aiter Unified Attention backend.")
|
||||
return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()
|
||||
|
||||
# Handle automatic backend selection based on environment variables
|
||||
if selected_backend is None:
|
||||
# Priority 1: Check for AITER Unified Attention (must check before MHA)
|
||||
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION:
|
||||
logger.info("Using Aiter Unified Attention backend.")
|
||||
return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()
|
||||
|
||||
# Priority 2: Check for AITER MHA (Flash Attention)
|
||||
# Only use if explicitly enabled (not just VLLM_ROCM_USE_AITER=1)
|
||||
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
|
||||
logger.info("Using Aiter Flash Attention backend.")
|
||||
return AttentionBackendEnum.ROCM_AITER_FA.get_path()
|
||||
|
||||
# Priority 3: Check for ROCM_ATTN (prefill-decode split)
|
||||
from vllm.config import get_current_vllm_config_or_none
|
||||
|
||||
vllm_config = get_current_vllm_config_or_none()
|
||||
if (
|
||||
vllm_config is not None
|
||||
and vllm_config.attention_config.use_prefill_decode_attention
|
||||
):
|
||||
logger.info("Using Rocm Attention backend.")
|
||||
return AttentionBackendEnum.ROCM_ATTN.get_path()
|
||||
|
||||
# Priority 4: Check for AITER enabled without specific flags
|
||||
# This defaults to AITER FA only if MHA is not explicitly disabled
|
||||
if (
|
||||
envs.VLLM_ROCM_USE_AITER
|
||||
and on_gfx9()
|
||||
and envs.VLLM_ROCM_USE_AITER_MHA is not False
|
||||
):
|
||||
logger.info("Using Aiter Flash Attention backend.")
|
||||
return AttentionBackendEnum.ROCM_AITER_FA.get_path()
|
||||
|
||||
# Default: Triton Unified Attention
|
||||
logger.info("Using Triton Attention backend.")
|
||||
return AttentionBackendEnum.TRITON_ATTN.get_path()
|
||||
|
||||
raise RuntimeError(
|
||||
f"Attention backend {selected_backend.name} is not supported on "
|
||||
"ROCm. Note that V0 attention backends have been removed."
|
||||
# No selected backend or the selected backend is invalid,
|
||||
# so we try finding a valid backend.
|
||||
valid_backends_priorities, invalid_reasons = cls.get_valid_backends(
|
||||
device_capability=device_capability,
|
||||
attn_selector_config=attn_selector_config,
|
||||
num_heads=num_heads,
|
||||
)
|
||||
reasons_str = (
|
||||
"{"
|
||||
+ ", ".join(
|
||||
f"{backend.name}: [{', '.join(reasons)}]"
|
||||
for backend, reasons in invalid_reasons.items()
|
||||
)
|
||||
+ "}"
|
||||
)
|
||||
config_str = attn_selector_config.__repr__()
|
||||
logger.debug_once(
|
||||
f"Some attention backends are not valid for {cls.device_name} with "
|
||||
f"{config_str}. Reasons: {reasons_str}."
|
||||
)
|
||||
if len(valid_backends_priorities) == 0:
|
||||
raise ValueError(
|
||||
f"No valid attention backend found for {cls.device_name} "
|
||||
f"with {config_str}. Reasons: {reasons_str}."
|
||||
)
|
||||
|
||||
# We have found some valid backends. Select the one with the
|
||||
# highest priority.
|
||||
sorted_indices = sorted(
|
||||
range(len(valid_backends_priorities)),
|
||||
key=lambda i: valid_backends_priorities[i][1],
|
||||
)
|
||||
selected_index = sorted_indices[0]
|
||||
selected_backend = valid_backends_priorities[selected_index][0]
|
||||
logger.info_once(
|
||||
"Using %s attention backend out of potential backends: %s.",
|
||||
selected_backend.name,
|
||||
"[" + ", ".join(f"'{b[0].name}'" for b in valid_backends_priorities) + "]",
|
||||
scope="local",
|
||||
)
|
||||
|
||||
return selected_backend.get_path()
|
||||
|
||||
@classmethod
|
||||
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
|
||||
|
||||
@@ -77,6 +77,7 @@ def fetch_id_to_ragged_triton(
|
||||
|
||||
class ROCMAiterMLASparseBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
@@ -104,14 +105,23 @@ class ROCMAiterMLASparseBackend(AttentionBackend):
|
||||
) -> tuple[int, ...]:
|
||||
return (num_blocks, block_size, head_size)
|
||||
|
||||
@classmethod
|
||||
def get_supported_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [576]
|
||||
|
||||
@classmethod
|
||||
def is_mla(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def is_sparse(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def supports_block_size(cls, block_size: int | None) -> bool:
|
||||
# The only supported block_size is 1
|
||||
return block_size is None or block_size == 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class ROCMAiterMLASparseMetadata(AttentionMetadata):
|
||||
|
||||
@@ -45,6 +45,11 @@ class TritonMLABackend(MLACommonBackend):
|
||||
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def supports_block_size(cls, block_size: int | None) -> bool:
|
||||
# The only unsupported block_size is 1
|
||||
return block_size is None or block_size != 1
|
||||
|
||||
|
||||
class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
can_return_lse_for_decode: bool = True
|
||||
|
||||
@@ -12,6 +12,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.platform_utils import num_compute_units
|
||||
from vllm.v1.attention.backend import (
|
||||
@@ -766,6 +767,15 @@ class AiterFlashAttentionBackend(AttentionBackend):
|
||||
raise ValueError("Block size must be a multiple of 16.")
|
||||
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
@classmethod
|
||||
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
|
||||
from vllm.platforms.rocm import on_mi3xx
|
||||
|
||||
# DeviceCapability is currently created using torch.cuda.get_device_capability()
|
||||
# which is known to be buggy on rocm systems. on_mi3xx uses amd-smi which is
|
||||
# more reliable.
|
||||
return on_mi3xx()
|
||||
|
||||
|
||||
class AiterFlashAttentionImpl(AttentionImpl):
|
||||
def __init__(
|
||||
|
||||
Reference in New Issue
Block a user