[Bugfix] [ROCm] [UX] Reorganize ROCm Backend Selection Logic (#26980)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
@@ -262,30 +262,64 @@ class RocmPlatform(Platform):
                 f"is not MLA type while requested for MLA backend."
             )

         if selected_backend == AttentionBackendEnum.FLEX_ATTENTION:
             logger.info("Using FlexAttention backend.")
             return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
-        if (
-            rocm_aiter_ops.is_mha_enabled()
-        ) or selected_backend == AttentionBackendEnum.ROCM_AITER_FA:
-            logger.info("Using Aiter Flash Attention backend.")
-            return AttentionBackendEnum.ROCM_AITER_FA.get_path()
-        if (
-            rocm_aiter_ops.is_triton_unified_attn_enabled()
-        ) or selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
-            logger.info("Using Aiter Unified Attention backend.")
-            return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()
-        if (
-            envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
-            or selected_backend == AttentionBackendEnum.ROCM_ATTN
-        ):
-            # rocm specific backend, with aiter and/or
-            # triton prefix-prefill
-            logger.info("Using Rocm Attention backend.")
+        if selected_backend == AttentionBackendEnum.TRITON_ATTN:
+            logger.info("Using Triton Attention backend on V1 engine.")
+            return AttentionBackendEnum.TRITON_ATTN.get_path()
+
+        if selected_backend == AttentionBackendEnum.ROCM_ATTN:
+            logger.info("Using Rocm Attention backend on V1 engine.")
             return AttentionBackendEnum.ROCM_ATTN.get_path()
-        # default case, using triton unified attention
-        logger.info("Using Triton Attention backend.")
-        return AttentionBackendEnum.TRITON_ATTN.get_path()
+
+        if selected_backend == AttentionBackendEnum.ROCM_AITER_FA:
+            if on_gfx9():
+                logger.info("Using Aiter Flash Attention backend on V1 engine.")
+                return AttentionBackendEnum.ROCM_AITER_FA.get_path()
+            else:
+                raise ValueError(
+                    f"The selected backend, {selected_backend.name}, "
+                    "is only supported on gfx9 architectures."
+                )
+
+        if selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
+            logger.info("Using Aiter Unified Attention backend on V1 engine.")
+            return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()
+
+        # Handle automatic backend selection based on environment variables
+        if selected_backend is None:
+            # Priority 1: Check for AITER Unified Attention (must check before MHA)
+            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION:
+                logger.info("Using Aiter Unified Attention backend on V1 engine.")
+                return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()
+
+            # Priority 2: Check for AITER MHA (Flash Attention)
+            # Only use if explicitly enabled (not just VLLM_ROCM_USE_AITER=1)
+            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
+                logger.info("Using Aiter Flash Attention backend on V1 engine.")
+                return AttentionBackendEnum.ROCM_AITER_FA.get_path()
+
+            # Priority 3: Check for ROCM_ATTN (prefill-decode split)
+            if envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION:
+                logger.info("Using Rocm Attention backend on V1 engine.")
+                return AttentionBackendEnum.ROCM_ATTN.get_path()
+
+            # Priority 4: Check for AITER enabled without specific flags
+            # This defaults to AITER FA only if MHA is not explicitly disabled
+            if (
+                envs.VLLM_ROCM_USE_AITER
+                and on_gfx9()
+                and envs.VLLM_ROCM_USE_AITER_MHA is not False
+            ):
+                logger.info("Using Aiter Flash Attention backend on V1 engine.")
+                return AttentionBackendEnum.ROCM_AITER_FA.get_path()
+
+            # Default: Triton Unified Attention
+            logger.info("Using Triton Attention backend on V1 engine.")
+            return AttentionBackendEnum.TRITON_ATTN.get_path()

         raise RuntimeError(
             "V0 attention backends have been removed. Set VLLM_USE_V1=1 "
             "to select a supported backend."
         )

     @classmethod
     def set_device(cls, device: torch.device) -> None:
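
Note: the sketch below is illustrative only and is not part of the commit. It restates the new automatic-selection priority (the `selected_backend is None` path) as a standalone function, with plain booleans standing in for the `vllm.envs` flags; the names `RocmEnv` and `resolve_rocm_backend` are hypothetical.

from dataclasses import dataclass
from typing import Optional


@dataclass
class RocmEnv:
    # Stand-ins for the vllm.envs flags read by the real code.
    use_aiter: bool = False                 # VLLM_ROCM_USE_AITER
    use_aiter_mha: Optional[bool] = None    # VLLM_ROCM_USE_AITER_MHA (None = unset)
    use_aiter_unified_attn: bool = False    # VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
    use_prefill_decode_attn: bool = False   # VLLM_V1_USE_PREFILL_DECODE_ATTENTION
    is_gfx9: bool = True                    # on_gfx9()


def resolve_rocm_backend(env: RocmEnv) -> str:
    """Backend chosen when no attention backend is explicitly selected."""
    # Priority 1: AITER unified attention (deliberately checked before MHA,
    # so setting both AITER flags picks unified attention).
    if env.use_aiter and env.use_aiter_unified_attn:
        return "ROCM_AITER_UNIFIED_ATTN"
    # Priority 2: AITER MHA, only when explicitly enabled and on gfx9.
    if env.use_aiter and env.use_aiter_mha and env.is_gfx9:
        return "ROCM_AITER_FA"
    # Priority 3: prefill/decode split attention.
    if env.use_prefill_decode_attn:
        return "ROCM_ATTN"
    # Priority 4: AITER on gfx9 with MHA not *explicitly* disabled
    # (an unset tri-state flag still falls through to AITER FA).
    if env.use_aiter and env.is_gfx9 and env.use_aiter_mha is not False:
        return "ROCM_AITER_FA"
    # Default: Triton unified attention.
    return "TRITON_ATTN"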
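
Under the same assumptions, the priority ordering resolves a few representative flag combinations as follows; note in particular the tri-state handling of the MHA flag (unset versus explicitly disabled):

# Both AITER flags set: unified attention wins via priority 1.
env = RocmEnv(use_aiter=True, use_aiter_mha=True, use_aiter_unified_attn=True)
assert resolve_rocm_backend(env) == "ROCM_AITER_UNIFIED_ATTN"

# AITER enabled with the MHA flag left unset: priority 2 is skipped
# (flag not truthy) but priority 4 still selects AITER FA.
assert resolve_rocm_backend(RocmEnv(use_aiter=True)) == "ROCM_AITER_FA"

# AITER enabled but MHA explicitly disabled: falls through to the default.
assert resolve_rocm_backend(RocmEnv(use_aiter=True, use_aiter_mha=False)) == "TRITON_ATTN"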