[Bugfix] [ROCm] [UX] Reorganize ROCm Backend Selection Logic (#26980)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>

Author: vllmellm
Date: 2025-11-24 22:24:49 +07:00
Committed by: GitHub
Parent: 7a228b5305
Commit: e48b2e6848

2 changed files with 394 additions and 23 deletions
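In outline, this change splits RocmPlatform's attention-backend selection into two phases: a backend the user explicitly requested is honored first (and ROCM_AITER_FA is now rejected with a ValueError on non-gfx9 GPUs, where the old logic accepted it without an architecture check), and only an unset selection falls through to environment-driven auto-selection with a documented priority order. A minimal runnable sketch of the explicit phase only, with hypothetical names (resolve_explicit, GFX9_ONLY) that are not vLLM API:

# Sketch only, not vLLM code: the real method returns backend import
# paths via AttentionBackendEnum.get_path(); plain strings stand in here.
GFX9_ONLY = {"ROCM_AITER_FA"}

def resolve_explicit(selected: str, is_gfx9: bool) -> str:
    # Mirrors the new ValueError branch in the diff below: an explicit
    # choice of a gfx9-only backend fails loudly off gfx9 instead of
    # being silently accepted.
    if selected in GFX9_ONLY and not is_gfx9:
        raise ValueError(
            f"The selected backend, {selected}, "
            "is only supported on gfx9 architectures."
        )
    return selected

assert resolve_explicit("TRITON_ATTN", is_gfx9=False) == "TRITON_ATTN"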


@@ -262,30 +262,64 @@ class RocmPlatform(Platform):
                 f"is not MLA type while requested for MLA backend."
             )
         if selected_backend == AttentionBackendEnum.FLEX_ATTENTION:
             logger.info("Using FlexAttention backend.")
             return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
-        if (
-            rocm_aiter_ops.is_mha_enabled()
-        ) or selected_backend == AttentionBackendEnum.ROCM_AITER_FA:
-            logger.info("Using Aiter Flash Attention backend.")
-            return AttentionBackendEnum.ROCM_AITER_FA.get_path()
-        if (
-            rocm_aiter_ops.is_triton_unified_attn_enabled()
-        ) or selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
-            logger.info("Using Aiter Unified Attention backend.")
-            return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()
-        if (
-            envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
-            or selected_backend == AttentionBackendEnum.ROCM_ATTN
-        ):
-            # rocm specific backend, with aiter and/or
-            # triton prefix-prefill
-            logger.info("Using Rocm Attention backend.")
-            return AttentionBackendEnum.ROCM_ATTN.get_path()
-        # default case, using triton unified attention
-        logger.info("Using Triton Attention backend.")
-        return AttentionBackendEnum.TRITON_ATTN.get_path()
+        if selected_backend == AttentionBackendEnum.TRITON_ATTN:
+            logger.info("Using Triton Attention backend on V1 engine.")
+            return AttentionBackendEnum.TRITON_ATTN.get_path()
+        if selected_backend == AttentionBackendEnum.ROCM_ATTN:
+            logger.info("Using Rocm Attention backend on V1 engine.")
+            return AttentionBackendEnum.ROCM_ATTN.get_path()
+        if selected_backend == AttentionBackendEnum.ROCM_AITER_FA:
+            if on_gfx9():
+                logger.info("Using Aiter Flash Attention backend on V1 engine.")
+                return AttentionBackendEnum.ROCM_AITER_FA.get_path()
+            else:
+                raise ValueError(
+                    f"The selected backend, {selected_backend.name}, "
+                    "is only supported on gfx9 architectures."
+                )
+        if selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
+            logger.info("Using Aiter Unified Attention backend on V1 engine.")
+            return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()
+
+        # Handle automatic backend selection based on environment variables
+        if selected_backend is None:
+            # Priority 1: AITER Unified Attention (must be checked before MHA)
+            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION:
+                logger.info("Using Aiter Unified Attention backend on V1 engine.")
+                return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()
+            # Priority 2: AITER MHA (Flash Attention); only used when
+            # explicitly enabled (not just VLLM_ROCM_USE_AITER=1)
+            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
+                logger.info("Using Aiter Flash Attention backend on V1 engine.")
+                return AttentionBackendEnum.ROCM_AITER_FA.get_path()
+            # Priority 3: ROCM_ATTN (prefill-decode split)
+            if envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION:
+                logger.info("Using Rocm Attention backend on V1 engine.")
+                return AttentionBackendEnum.ROCM_ATTN.get_path()
+            # Priority 4: AITER enabled without more specific flags; defaults
+            # to AITER FA only if MHA is not explicitly disabled
+            if (
+                envs.VLLM_ROCM_USE_AITER
+                and on_gfx9()
+                and envs.VLLM_ROCM_USE_AITER_MHA is not False
+            ):
+                logger.info("Using Aiter Flash Attention backend on V1 engine.")
+                return AttentionBackendEnum.ROCM_AITER_FA.get_path()
+            # Default: Triton Unified Attention
+            logger.info("Using Triton Attention backend on V1 engine.")
+            return AttentionBackendEnum.TRITON_ATTN.get_path()
+
+        raise RuntimeError(
+            "V0 attention backends have been removed. Set VLLM_USE_V1=1 "
+            "to select a supported backend."
+        )

     @classmethod
     def set_device(cls, device: torch.device) -> None:
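For reference, the auto-selection branch above encodes a small decision table. Below is a runnable sketch of that table under one assumption implied by the diff itself: VLLM_ROCM_USE_AITER_MHA behaves as tri-state (None when unset), which is what the `is not False` check in Priority 4 relies on. All names here (choose_auto_backend and its parameters) are illustrative stand-ins, not vLLM API:

# Illustrative decision table for the auto-selection branch
# (selected_backend is None). Parameter names are hypothetical stand-ins:
#   use_aiter      -> VLLM_ROCM_USE_AITER
#   aiter_unified  -> VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
#   aiter_mha      -> VLLM_ROCM_USE_AITER_MHA (tri-state: None = unset)
#   prefill_decode -> VLLM_V1_USE_PREFILL_DECODE_ATTENTION
def choose_auto_backend(
    use_aiter: bool,
    aiter_unified: bool,
    aiter_mha: bool | None,
    prefill_decode: bool,
    is_gfx9: bool,
) -> str:
    # Priority 1: AITER unified attention (checked before MHA).
    if use_aiter and aiter_unified:
        return "ROCM_AITER_UNIFIED_ATTN"
    # Priority 2: AITER MHA, only when explicitly enabled, gfx9 only.
    if use_aiter and aiter_mha and is_gfx9:
        return "ROCM_AITER_FA"
    # Priority 3: split prefill/decode backend.
    if prefill_decode:
        return "ROCM_ATTN"
    # Priority 4: AITER on gfx9 defaults to FA unless MHA was explicitly
    # disabled (aiter_mha is False).
    if use_aiter and is_gfx9 and aiter_mha is not False:
        return "ROCM_AITER_FA"
    # Default: Triton unified attention.
    return "TRITON_ATTN"

# Spot checks of the table:
assert choose_auto_backend(True, True, None, False, True) == "ROCM_AITER_UNIFIED_ATTN"
assert choose_auto_backend(True, False, False, False, True) == "TRITON_ATTN"
assert choose_auto_backend(True, False, None, False, True) == "ROCM_AITER_FA"
assert choose_auto_backend(False, False, None, True, True) == "ROCM_ATTN"

Note the subtlety the Priority 2 comment calls out: an unset VLLM_ROCM_USE_AITER_MHA does not satisfy Priority 2 but still satisfies Priority 4, so plain VLLM_ROCM_USE_AITER=1 on gfx9 lands on AITER flash attention unless MHA is explicitly disabled.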