[Kernel] Support sliding window in flash attention backend (#9403)

commit 4fa3e33349 (parent 962d2c6349)
Author: Chen Zhang
Date: 2024-10-20 10:57:52 -07:00
Committed via GitHub
13 changed files with 41 additions and 61 deletions
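What this commit does: `sliding_window` is removed from attention backend selection, because the FlashAttention backend now applies the window itself through flash-attn's `window_size` argument instead of falling back to XFORMERS. Below is a minimal sketch of that convention, assuming flash-attn's public `flash_attn_varlen_func` API; the wrapper name and argument values are illustrative, not vLLM's exact code.

# Sketch only: shows flash-attn's (left, right) window convention.
# run_windowed_attention is a hypothetical wrapper, not vLLM code.
from flash_attn import flash_attn_varlen_func

def run_windowed_attention(q, k, v, cu_seqlens_q, cu_seqlens_k,
                           max_seqlen_q, max_seqlen_k, scale,
                           sliding_window=None):
    # flash-attn takes window_size=(left, right); -1 means unbounded.
    # For a causal model a window of W tokens keeps W - 1 tokens of
    # lookback plus the current token, hence the pair (W - 1, 0).
    window_size = ((sliding_window - 1, 0)
                   if sliding_window is not None else (-1, -1))
    return flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        softmax_scale=scale,
        causal=True,
        window_size=window_size,
    )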

@@ -90,7 +90,6 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
 @lru_cache(maxsize=None)
 def get_attn_backend(
     head_size: int,
-    sliding_window: Optional[int],
     dtype: torch.dtype,
     kv_cache_dtype: Optional[str],
     block_size: int,
@@ -105,8 +104,8 @@ def get_attn_backend(
BlocksparseFlashAttentionBackend)
return BlocksparseFlashAttentionBackend
backend = which_attn_to_use(head_size, sliding_window, dtype,
kv_cache_dtype, block_size, is_attention_free)
backend = which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size,
is_attention_free)
if backend == _Backend.FLASH_ATTN:
from vllm.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend)
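With this hunk applied, call sites drop the `sliding_window` argument. A hedged example of the new call, assuming the selector lives in vllm/attention/selector.py (the file this hunk patches) and using illustrative argument values; any trailing parameters not visible in the hunk are left at their defaults.

import torch
from vllm.attention.selector import get_attn_backend

# Illustrative values; real call sites derive these from the model config.
backend = get_attn_backend(
    head_size=128,
    dtype=torch.float16,
    kv_cache_dtype=None,   # None selects the default KV cache dtype
    block_size=16,
    is_attention_free=False,
)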
@@ -155,7 +154,6 @@ def get_attn_backend(
 
 def which_attn_to_use(
     head_size: int,
-    sliding_window: Optional[int],
     dtype: torch.dtype,
     kv_cache_dtype: Optional[str],
     block_size: int,
@@ -243,10 +241,6 @@ def which_attn_to_use(
                 "Cannot use FlashAttention-2 backend for block size not "
                 "divisible by 16.")
             selected_backend = _Backend.XFORMERS
-        elif sliding_window is not None:
-            logger.info(
-                "Cannot use FlashAttention-2 backend due to sliding window.")
-            selected_backend = _Backend.XFORMERS
 
     # FlashAttn is valid for the model, checking if the package is installed.
     if selected_backend == _Backend.FLASH_ATTN:
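The last hunk deletes the branch that forced XFORMERS whenever a sliding window was set. A stripped-down sketch of the selection flow after the change (an illustration of the logic in the hunk, not the full `which_attn_to_use`):

from typing import Optional

# Simplified illustration; the real function checks many more conditions.
def choose_backend(block_size: int, sliding_window: Optional[int]) -> str:
    if block_size % 16 != 0:
        # Still a hard FlashAttention-2 constraint.
        return "XFORMERS"
    # No sliding-window bailout anymore: FLASH_ATTN stays selected and
    # handles the window natively via flash-attn's window_size.
    return "FLASH_ATTN"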