[Core] Refactor Attention Take 2 (#3462)

This commit is contained in:
Woosuk Kwon
2024-03-24 21:39:33 -07:00
committed by GitHub
parent b0dfa91dd7
commit 925f3332ca
47 changed files with 1268 additions and 1117 deletions

View File

@@ -0,0 +1,44 @@
from functools import lru_cache
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.utils import is_hip
logger = init_logger(__name__)
@lru_cache(maxsize=None)
def get_attn_backend(dtype: torch.dtype) -> AttentionBackend:
if _can_use_flash_attn(dtype):
logger.info("Using FlashAttention backend.")
from vllm.attention.backends.flash_attn import FlashAttentionBackend # noqa: F401
return FlashAttentionBackend
else:
logger.info("Using XFormers backend.")
from vllm.attention.backends.xformers import XFormersBackend # noqa: F401
return XFormersBackend
def _can_use_flash_attn(dtype: torch.dtype) -> bool:
if is_hip():
# AMD GPUs.
logger.info("Cannot use FlashAttention backend for AMD GPUs.")
return False
if torch.cuda.get_device_capability()[0] < 8:
# Volta and Turing NVIDIA GPUs.
logger.info("Cannot use FlashAttention backend for Volta and Turing "
"GPUs.")
return False
if dtype not in (torch.float16, torch.bfloat16):
logger.info("Cannot use FlashAttention backend for dtype other than "
"torch.float16 or torch.bfloat16.")
return False
try:
import flash_attn # noqa: F401
except ImportError:
logger.info("flash_attn is not found.")
return False
return True