[Kernel] Use flash-attn for decoding (#3648)

Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
This commit is contained in:
Stephen Krider
2024-05-13 15:50:33 -07:00
committed by GitHub
parent ce532ff45c
commit 1356df53bd
6 changed files with 313 additions and 65 deletions

View File

@@ -93,6 +93,20 @@ def _which_attn_to_use(
"torch.float16 or torch.bfloat16.")
return _Backend.XFORMERS
if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
logger.info("Cannot use FlashAttention-2 backend for FP8 KV cache.")
return _Backend.XFORMERS
if block_size % 16 != 0:
logger.info("Cannot use FlashAttention-2 backend for block size not "
"divisible by 16.")
return _Backend.XFORMERS
if sliding_window is not None:
logger.info(
"Cannot use FlashAttention-2 backend due to sliding window.")
return _Backend.XFORMERS
try:
import vllm_flash_attn # noqa: F401
except ImportError: