[Kernel] Add flash-attn back (#4907)

commit b57e6c5949
parent 27ce85476e
Author: Woosuk Kwon
Date:   2024-05-19 18:11:30 -07:00
Committed by: GitHub

6 changed files with 304 additions and 61 deletions


@@ -93,6 +93,20 @@ def _which_attn_to_use(
"torch.float16 or torch.bfloat16.")
return _Backend.XFORMERS
if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
logger.info("Cannot use FlashAttention-2 backend for FP8 KV cache.")
return _Backend.XFORMERS
if block_size % 16 != 0:
logger.info("Cannot use FlashAttention-2 backend for block size not "
"divisible by 16.")
return _Backend.XFORMERS
if sliding_window is not None:
logger.info(
"Cannot use FlashAttention-2 backend due to sliding window.")
return _Backend.XFORMERS
try:
import vllm_flash_attn # noqa: F401
except ImportError:
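
Editor's note: the hunk above is truncated after the import check. For readers skimming the diff, below is a minimal self-contained sketch of the fallback pattern it adds: each capability FlashAttention-2 cannot satisfy logs a reason and falls back to the xFormers backend, and the optional vllm_flash_attn package is probed with a guarded import. The enum, logger setup, reduced function signature, the except-branch message, and the final FLASH_ATTN return are assumptions reconstructed from the visible lines, not the exact vLLM code.

import logging
from enum import Enum, auto
from typing import Optional

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class _Backend(Enum):
    FLASH_ATTN = auto()
    XFORMERS = auto()


def which_attn_to_use(
    kv_cache_dtype: Optional[str],
    block_size: int,
    sliding_window: Optional[int],
) -> _Backend:
    # Each check mirrors a condition from the diff: any unsupported
    # configuration logs why and selects xFormers instead.
    if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
        logger.info("Cannot use FlashAttention-2 backend for FP8 KV cache.")
        return _Backend.XFORMERS
    if block_size % 16 != 0:
        logger.info("Cannot use FlashAttention-2 backend for block size not "
                    "divisible by 16.")
        return _Backend.XFORMERS
    if sliding_window is not None:
        logger.info(
            "Cannot use FlashAttention-2 backend due to sliding window.")
        return _Backend.XFORMERS
    try:
        import vllm_flash_attn  # noqa: F401
    except ImportError:
        # Assumed message; the real diff's except body is cut off above.
        logger.info("vllm_flash_attn not installed; falling back to xFormers.")
        return _Backend.XFORMERS
    return _Backend.FLASH_ATTN


# Example: block_size 32 passes the divisibility check, but a sliding
# window still forces the xFormers fallback.
print(which_attn_to_use("auto", 32, sliding_window=4096))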