[ROCm][ViT] Enable Flash Attention Triton backend on RDNA3/RDNA4 (#32944)

Signed-off-by: mohammad najafi <mohammad.najafi@amd.com>
Author: monajafi-amd
Date: 2026-01-23 19:03:07 -07:00
Committed by: GitHub
Parent: ecc3dd66cc
Commit: 97ef11dd34
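The new path is strictly opt-in: besides an importable flash_attn package, the availability check added below requires the FLASH_ATTENTION_TRITON_AMD_ENABLE environment variable to be set to the literal string "TRUE". A minimal usage sketch, assuming the standard vLLM entry point (not part of this diff); the variable can equally be exported in the shell, as long as it is set before backend selection runs:

import os

# The check added in this commit compares against the literal "TRUE",
# so the variable must be set before vLLM picks an attention backend.
os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"

from vllm import LLM  # assumption: standard entry point, not shown in this diff

llm = LLM(model="Qwen/Qwen2-VL-2B-Instruct")  # hypothetical ViT-bearing model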


@@ -163,6 +163,28 @@ def use_rocm_custom_paged_attention(
     )
 
 
+@cache
+def flash_attn_triton_available() -> bool:
+    if not on_gfx1x():
+        return False
+    try:
+        from importlib.util import find_spec
+
+        if find_spec("flash_attn") is None:
+            return False
+        if find_spec("flash_attn.flash_attn_triton_amd") is None:
+            return False
+        if os.environ.get("FLASH_ATTENTION_TRITON_AMD_ENABLE") != "TRUE":
+            logger.info_once(
+                "Set FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE to enable "
+                "Flash Attention Triton backend on RDNA."
+            )
+            return False
+        return True
+    except ImportError:
+        return False
+
+
 class RocmPlatform(Platform):
     _enum = PlatformEnum.ROCM
     device_name: str = "rocm"
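The helper probes optional modules with importlib.util.find_spec, which locates a module without executing its import; note that resolving a submodule spec does import the parent package, which is why the probes are wrapped in try/except ImportError. A standalone sketch of the same pattern, runnable outside vLLM (the module and variable names are the ones probed above; the function name is illustrative):

import os
from importlib.util import find_spec

def triton_fa_available() -> bool:
    try:
        # find_spec returns None when a module cannot be located.
        if find_spec("flash_attn") is None:
            return False
        # Probing a submodule imports its parent package, so a broken
        # flash_attn install can still raise ImportError here.
        if find_spec("flash_attn.flash_attn_triton_amd") is None:
            return False
    except ImportError:
        return False
    # Opt-in gate: must be the literal string "TRUE", matching the diff.
    return os.environ.get("FLASH_ATTENTION_TRITON_AMD_ENABLE") == "TRUE"

print(triton_fa_available())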
@@ -348,7 +370,7 @@ class RocmPlatform(Platform):
         from vllm._aiter_ops import rocm_aiter_ops
 
-        if rocm_aiter_ops.is_enabled():
+        if rocm_aiter_ops.is_enabled() and on_gfx9():
             logger.info_once("Using AITER Flash Attention backend for ViT model.")
             return AttentionBackendEnum.ROCM_AITER_FA
@@ -360,6 +382,17 @@ class RocmPlatform(Platform):
             logger.info_once("Using Flash Attention backend for ViT model.")
             return AttentionBackendEnum.FLASH_ATTN
 
+        # RDNA3/RDNA4 (gfx11xx/gfx12xx): Use Flash Attention Triton backend
+        if (
+            on_gfx1x()
+            and flash_attn_triton_available()
+            and (dtype == torch.float16 or dtype == torch.bfloat16)
+        ):
+            logger.info_once(
+                "Using Flash Attention (Triton backend) for ViT model on RDNA."
+            )
+            return AttentionBackendEnum.FLASH_ATTN
+
         logger.info_once("Using Torch SDPA backend for ViT model.")
         return AttentionBackendEnum.TORCH_SDPA
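After this change, ViT backend resolution on ROCm falls through four cases. A condensed sketch of that order (the boolean flags stand in for on_gfx9(), on_gfx1x(), and flash_attn_triton_available(); the guard on the upstream flash-attn case is not shown in this hunk's context and is elided here):

import torch

def vit_backend(dtype: torch.dtype, aiter_enabled: bool, gfx9: bool,
                gfx1x: bool, triton_fa: bool, upstream_fa: bool) -> str:
    # 1. AITER flash attention, now restricted to gfx9 by this commit.
    if aiter_enabled and gfx9:
        return "ROCM_AITER_FA"
    # 2. Upstream flash-attn path (its guard is elided in the diff context).
    if upstream_fa:
        return "FLASH_ATTN"
    # 3. New in this commit: flash_attn's Triton backend on RDNA3/RDNA4,
    #    fp16/bf16 only.
    if gfx1x and triton_fa and dtype in (torch.float16, torch.bfloat16):
        return "FLASH_ATTN"
    # 4. Fallback for everything else.
    return "TORCH_SDPA"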