[Kernel] Use flash-attn for decoding (#3648)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
@@ -93,6 +93,20 @@ def _which_attn_to_use(
                     "torch.float16 or torch.bfloat16.")
         return _Backend.XFORMERS
 
+    if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
+        logger.info("Cannot use FlashAttention-2 backend for FP8 KV cache.")
+        return _Backend.XFORMERS
+
+    if block_size % 16 != 0:
+        logger.info("Cannot use FlashAttention-2 backend for block size not "
+                    "divisible by 16.")
+        return _Backend.XFORMERS
+
+    if sliding_window is not None:
+        logger.info(
+            "Cannot use FlashAttention-2 backend due to sliding window.")
+        return _Backend.XFORMERS
+
     try:
         import vllm_flash_attn  # noqa: F401
     except ImportError:
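For context, the checks added above gate the new FlashAttention-2 decode path: the backend is only chosen when the KV cache is not FP8, the paged-KV block size is a multiple of 16, no sliding window is configured, and the vllm_flash_attn package imports cleanly; in every other case selection falls back to the xFormers backend. Below is a minimal, self-contained sketch of that fallback pattern. The function name pick_decode_backend, the stub _Backend enum, and the trimmed argument list are illustrative stand-ins, not vLLM's actual selector.

# Sketch of the FlashAttention-2 vs. xFormers fallback logic shown in the diff.
# This is a standalone illustration; it is not vLLM's real _which_attn_to_use.
import enum
import logging
from typing import Optional

logger = logging.getLogger(__name__)


class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    XFORMERS = enum.auto()


def pick_decode_backend(kv_cache_dtype: Optional[str],
                        block_size: int,
                        sliding_window: Optional[int]) -> _Backend:
    # FlashAttention-2 does not support an FP8 KV cache.
    if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
        logger.info("Cannot use FlashAttention-2 backend for FP8 KV cache.")
        return _Backend.XFORMERS
    # FlashAttention-2 requires the paged-KV block size to be a multiple of 16.
    if block_size % 16 != 0:
        logger.info("Cannot use FlashAttention-2 backend for block size not "
                    "divisible by 16.")
        return _Backend.XFORMERS
    # Sliding-window attention is not handled by this backend.
    if sliding_window is not None:
        logger.info("Cannot use FlashAttention-2 backend due to sliding "
                    "window.")
        return _Backend.XFORMERS
    # Finally, the package itself must be importable.
    try:
        import vllm_flash_attn  # noqa: F401
    except ImportError:
        logger.info("vllm_flash_attn is not installed; falling back.")
        return _Backend.XFORMERS
    return _Backend.FLASH_ATTN


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    # A 16-token block, non-FP8 KV cache, and no sliding window select
    # FLASH_ATTN (assuming vllm_flash_attn is installed); otherwise XFORMERS.
    print(pick_decode_backend(kv_cache_dtype="auto", block_size=16,
                              sliding_window=None))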