[Kernel] Support sliding window in flash attention backend (#9403)
@@ -90,7 +90,6 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
 @lru_cache(maxsize=None)
 def get_attn_backend(
     head_size: int,
-    sliding_window: Optional[int],
     dtype: torch.dtype,
     kv_cache_dtype: Optional[str],
     block_size: int,
@@ -105,8 +104,8 @@ def get_attn_backend(
             BlocksparseFlashAttentionBackend)
         return BlocksparseFlashAttentionBackend

-    backend = which_attn_to_use(head_size, sliding_window, dtype,
-                                kv_cache_dtype, block_size, is_attention_free)
+    backend = which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size,
+                                is_attention_free)
     if backend == _Backend.FLASH_ATTN:
         from vllm.attention.backends.flash_attn import ( # noqa: F401
             FlashAttentionBackend)
@@ -155,7 +154,6 @@ def get_attn_backend(

 def which_attn_to_use(
     head_size: int,
-    sliding_window: Optional[int],
     dtype: torch.dtype,
     kv_cache_dtype: Optional[str],
     block_size: int,
@@ -243,10 +241,6 @@ def which_attn_to_use(
                 "Cannot use FlashAttention-2 backend for block size not "
                 "divisible by 16.")
             selected_backend = _Backend.XFORMERS
-        elif sliding_window is not None:
-            logger.info(
-                "Cannot use FlashAttention-2 backend due to sliding window.")
-            selected_backend = _Backend.XFORMERS

     # FlashAttn is valid for the model, checking if the package is installed.
     if selected_backend == _Backend.FLASH_ATTN:
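
With sliding_window dropped from the selector, callers pass only the remaining arguments and backend selection no longer falls back to XFORMERS for sliding-window models. Below is a minimal, hypothetical sketch of calling the updated get_attn_backend; the import path and the example values are assumptions (the diff does not name the file), while the parameter set (head_size, dtype, kv_cache_dtype, block_size, is_attention_free) comes from the context lines above.

import torch

# Assumed import path for the selector module shown in this diff.
from vllm.attention.selector import get_attn_backend

# Sketch only: example values, not taken from the commit.
backend_cls = get_attn_backend(
    head_size=128,
    dtype=torch.float16,
    kv_cache_dtype=None,    # Optional[str]; None/"auto" keeps the model dtype
    block_size=16,          # FlashAttention-2 needs a block size divisible by 16
    is_attention_free=False,
)

# The backend is now chosen without knowing the model's sliding window;
# the flash attention backend handles sliding windows internally.
print(backend_cls)

Because get_attn_backend is wrapped in @lru_cache(maxsize=None), repeated calls with the same arguments return the cached backend class rather than re-running the selection logic.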