[ROCm][Hardware][AMD] Use Triton Kernel for default FA on ROCm (#3643)

Co-authored-by: jpvillam <jpvillam@amd.com>
Co-authored-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Juan Villamizar
2024-04-09 17:10:47 -05:00
committed by GitHub
parent e23a43aef8
commit 6c0b04515f
5 changed files with 1213 additions and 93 deletions
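In practice the new selection flow reads as follows: get_attn_backend asks _which_attn_to_use for a _Backend enum member and dispatches on it, so ROCm builds now default to the Triton-based ROCmFlashAttentionBackend instead of being rejected by the FlashAttention check. A minimal usage sketch, assuming the selector module shown in the diff below is importable as vllm.attention.selector (the file path is not part of this excerpt):

    import torch
    from vllm.attention.selector import get_attn_backend

    # Selection is cached (lru_cache) and keyed only on dtype. On a ROCm build
    # this resolves to ROCmFlashAttentionBackend, on CPU to TorchSDPABackend,
    # and on CUDA it prefers FlashAttentionBackend when the GPU, dtype, and
    # flash_attn package allow it, falling back to XFormersBackend otherwise.
    backend_cls = get_attn_backend(torch.float16)
    print(backend_cls.__name__)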


@@ -1,3 +1,4 @@
+import enum
 from functools import lru_cache
 from typing import Type
@@ -10,46 +11,68 @@ from vllm.utils import is_cpu, is_hip
 logger = init_logger(__name__)
 
 
+class _Backend(enum.Enum):
+    FLASH_ATTN = enum.auto()
+    XFORMERS = enum.auto()
+    ROCM_FLASH = enum.auto()
+    TORCH_SDPA = enum.auto()
+
+
 @lru_cache(maxsize=None)
 def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
-    if _can_use_flash_attn(dtype):
+    backend = _which_attn_to_use(dtype)
+    if backend == _Backend.FLASH_ATTN:
         logger.info("Using FlashAttention backend.")
         from vllm.attention.backends.flash_attn import (  # noqa: F401
             FlashAttentionBackend)
         return FlashAttentionBackend
-    elif is_cpu():
-        logger.info("Using Torch SDPA backend.")
-        from vllm.attention.backends.torch_sdpa import TorchSDPABackend
-        return TorchSDPABackend
-    else:
+    elif backend == _Backend.XFORMERS:
         logger.info("Using XFormers backend.")
         from vllm.attention.backends.xformers import (  # noqa: F401
             XFormersBackend)
         return XFormersBackend
+    elif backend == _Backend.ROCM_FLASH:
+        logger.info("Using ROCmFlashAttention backend.")
+        from vllm.attention.backends.rocm_flash_attn import (  # noqa: F401
+            ROCmFlashAttentionBackend)
+        return ROCmFlashAttentionBackend
+    elif backend == _Backend.TORCH_SDPA:
+        logger.info("Using Torch SDPA backend.")
+        from vllm.attention.backends.torch_sdpa import TorchSDPABackend
+        return TorchSDPABackend
+    else:
+        raise ValueError("Invalid attention backend.")
-def _can_use_flash_attn(dtype: torch.dtype) -> bool:
+def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
+    """Returns which flash attention backend to use."""
+    if is_cpu():
+        return _Backend.TORCH_SDPA
+
     if is_hip():
         # AMD GPUs.
-        logger.info("Cannot use FlashAttention backend for AMD GPUs.")
-        return False
-    if is_cpu():
-        return False
+        if torch.cuda.get_device_capability()[0] != 9:
+            # not Instinct series GPUs.
+            logger.info("flash_atten is not supported on NAVI GPUs.")
+        return _Backend.ROCM_FLASH
+
+    # NVIDIA GPUs.
     if torch.cuda.get_device_capability()[0] < 8:
         # Volta and Turing NVIDIA GPUs.
         logger.info("Cannot use FlashAttention backend for Volta and Turing "
                     "GPUs.")
-        return False
+        return _Backend.XFORMERS
 
     if dtype not in (torch.float16, torch.bfloat16):
         logger.info("Cannot use FlashAttention backend for dtype other than "
                     "torch.float16 or torch.bfloat16.")
-        return False
+        return _Backend.XFORMERS
 
     try:
         import flash_attn  # noqa: F401
     except ImportError:
         logger.info(
-            "Cannot use FlashAttention because the package is not found. "
-            "Please install it for better performance.")
-        return False
-    return True
+            "Cannot use FlashAttention backend because the flash_attn package "
+            "is not found. Please install it for better performance.")
+        return _Backend.XFORMERS
+    return _Backend.FLASH_ATTN
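The branch order in _which_attn_to_use is what makes the Triton kernel the default on ROCm: the CPU check comes first, any HIP build then returns ROCM_FLASH regardless of device (Navi parts only trigger the warning), and the compute-capability, dtype, and flash_attn checks apply only to NVIDIA GPUs. A pytest-style sketch of that routing, again assuming the module is importable as vllm.attention.selector; the test name and patched attributes are illustrative, not part of the change:

    import torch

    import vllm.attention.selector as selector
    from vllm.attention.selector import _Backend, _which_attn_to_use

    def test_hip_build_routes_to_rocm_flash(monkeypatch):
        # Pretend this is an AMD (HIP) build rather than CPU or CUDA.
        monkeypatch.setattr(selector, "is_cpu", lambda: False)
        monkeypatch.setattr(selector, "is_hip", lambda: True)
        # Instinct-class accelerators report major compute capability 9;
        # any other value only adds the NAVI log line, not a different result.
        monkeypatch.setattr(torch.cuda, "get_device_capability",
                            lambda *args: (9, 0))
        assert _which_attn_to_use(torch.float16) is _Backend.ROCM_FLASH

Patching selector.is_cpu and selector.is_hip works because the module imports those helpers by name, as the hunk header above shows.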