[RFC][ROCm][AITER] Keep all AITER kernels in _aiter_ops class like _custom_ops and _ipex_ops (#24490)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
vllmellm
2025-11-10 17:20:53 +01:00
committed by GitHub
parent 40e2eeeb92
commit f080a83511
25 changed files with 1193 additions and 924 deletions

View File

@@ -142,6 +142,8 @@ def use_rocm_custom_paged_attention(
alibi_slopes: torch.Tensor | None = None,
sinks: torch.Tensor | None = None,
) -> bool:
from vllm._aiter_ops import rocm_aiter_ops
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
@@ -157,7 +159,7 @@ def use_rocm_custom_paged_attention(
and (gqa_ratio >= 1 and gqa_ratio <= 16)
and max_seq_len <= 128 * 1024
and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN and envs.VLLM_ROCM_USE_AITER)
and not (rocm_aiter_ops.is_pa_attn_enabled())
and sinks is None
)
@@ -202,12 +204,15 @@ class RocmPlatform(Platform):
]
@classmethod
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend:
from importlib.util import find_spec
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.registry import _Backend
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
if rocm_aiter_ops.is_mha_enabled():
# Note: AITER FA is only supported for Qwen-VL models.
# TODO: Add support for other VL models in their model class.
return _Backend.ROCM_AITER_FA
if on_gfx9() and find_spec("flash_attn") is not None:
@@ -228,19 +233,23 @@ class RocmPlatform(Platform):
has_sink,
use_sparse,
) -> str:
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.registry import _Backend
if use_sparse:
raise NotImplementedError("Sparse Attention is not supported on ROCm.")
if use_mla:
from vllm.v1.attention.backends.mla.rocm_aiter_mla import (
is_aiter_mla_enabled,
if not use_v1:
raise RuntimeError(
"V0 attention backends have been removed. Set VLLM_USE_V1=1 "
"to select a supported backend."
)
if use_mla:
if selected_backend is None:
selected_backend = (
_Backend.ROCM_AITER_MLA
if is_aiter_mla_enabled() or block_size == 1
if rocm_aiter_ops.is_mla_enabled() or block_size == 1
else _Backend.TRITON_MLA
)
@@ -265,12 +274,12 @@ class RocmPlatform(Platform):
logger.info("Using FlexAttention backend.")
return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
if (
envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9()
rocm_aiter_ops.is_mha_enabled()
) or selected_backend == _Backend.ROCM_AITER_FA:
logger.info("Using Aiter Flash Attention backend.")
return "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
if (
envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
rocm_aiter_ops.is_triton_unified_attn_enabled()
) or selected_backend == _Backend.ROCM_AITER_UNIFIED_ATTN:
logger.info("Using Aiter Unified Attention backend.")
return (