[ROCm][ViT] Enable Flash Attention Triton backend on RDNA3/RDNA4 (#32944)
Signed-off-by: mohammad najafi <mohammad.najafi@amd.com>
This commit is contained in:
@@ -163,6 +163,28 @@ def use_rocm_custom_paged_attention(
|
||||
)
|
||||
|
||||
|
||||
@cache
|
||||
def flash_attn_triton_available() -> bool:
|
||||
if not on_gfx1x():
|
||||
return False
|
||||
try:
|
||||
from importlib.util import find_spec
|
||||
|
||||
if find_spec("flash_attn") is None:
|
||||
return False
|
||||
if find_spec("flash_attn.flash_attn_triton_amd") is None:
|
||||
return False
|
||||
if os.environ.get("FLASH_ATTENTION_TRITON_AMD_ENABLE") != "TRUE":
|
||||
logger.info_once(
|
||||
"Set FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE to enable "
|
||||
"Flash Attention Triton backend on RDNA."
|
||||
)
|
||||
return False
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
class RocmPlatform(Platform):
|
||||
_enum = PlatformEnum.ROCM
|
||||
device_name: str = "rocm"
|
||||
@@ -348,7 +370,7 @@ class RocmPlatform(Platform):
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
|
||||
if rocm_aiter_ops.is_enabled():
|
||||
if rocm_aiter_ops.is_enabled() and on_gfx9():
|
||||
logger.info_once("Using AITER Flash Attention backend for ViT model.")
|
||||
return AttentionBackendEnum.ROCM_AITER_FA
|
||||
|
||||
@@ -360,6 +382,17 @@ class RocmPlatform(Platform):
|
||||
logger.info_once("Using Flash Attention backend for ViT model.")
|
||||
return AttentionBackendEnum.FLASH_ATTN
|
||||
|
||||
# RDNA3/RDNA4 (gfx11xx/gfx12xx): Use Flash Attention Triton backend
|
||||
if (
|
||||
on_gfx1x()
|
||||
and flash_attn_triton_available()
|
||||
and (dtype == torch.float16 or dtype == torch.bfloat16)
|
||||
):
|
||||
logger.info_once(
|
||||
"Using Flash Attention (Triton backend) for ViT model on RDNA."
|
||||
)
|
||||
return AttentionBackendEnum.FLASH_ATTN
|
||||
|
||||
logger.info_once("Using Torch SDPA backend for ViT model.")
|
||||
return AttentionBackendEnum.TORCH_SDPA
|
||||
|
||||
|
||||
Reference in New Issue
Block a user