[Encoder Decoder] Add flash_attn kernel support for encoder-decoder models (#9559)

This commit is contained in:
sroy745
2024-11-01 23:22:49 -07:00
committed by GitHub
parent d522034c85
commit a78dd3303e
11 changed files with 715 additions and 316 deletions

View File

@@ -98,7 +98,6 @@ def get_attn_backend(
is_blocksparse: bool = False,
) -> Type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""
if is_blocksparse:
logger.info("Using BlocksparseFlashAttention backend.")
from vllm.attention.backends.blocksparse_attn import (
@@ -108,6 +107,7 @@ def get_attn_backend(
backend = which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size,
is_attention_free)
if backend == _Backend.FLASH_ATTN:
logger.info("Using Flash Attention backend.")
from vllm.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend)
return FlashAttentionBackend