[Bugfix][Kernel] Add head size check for attention backend selection (#4944)
@@ -34,11 +34,21 @@ def get_attn_backend(
                                  sliding_window, dtype, kv_cache_dtype,
                                  block_size)
     if backend == _Backend.FLASH_ATTN:
-        logger.info("Using FlashAttention-2 backend.")
         from vllm.attention.backends.flash_attn import (  # noqa: F401
             FlashAttentionBackend)
-        return FlashAttentionBackend
-    elif backend == _Backend.XFORMERS:
+
+        # We check it here not in _which_attn_to_use because we cannot know
+        # the head size until we import FlashAttentionBackend.
+        supported_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
+        if head_size in supported_head_sizes:
+            logger.info("Using FlashAttention-2 backend.")
+            return FlashAttentionBackend
+        logger.info(
+            "Cannot use FlashAttention-2 backend for head size %d. "
+            "Using XFormers backend instead.", head_size)
+        backend = _Backend.XFORMERS
+
+    if backend == _Backend.XFORMERS:
         logger.info("Using XFormers backend.")
         from vllm.attention.backends.xformers import (  # noqa: F401
             XFormersBackend)
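The core of the change is the fallback path: if the selector initially picks FlashAttention but the model's head size is not in FlashAttentionBackend.get_supported_head_sizes(), it logs a message and switches to the XFormers backend instead. Below is a minimal standalone sketch of that logic. The head-size list and the select_backend helper are illustrative assumptions for the sketch, not vllm's actual API; the real entry point is get_attn_backend in vllm/attention/selector.py as shown in the diff above.

```python
# Standalone sketch of the head-size check and fallback introduced above.
# Names mirror the diff (_Backend, get_supported_head_sizes), but nothing
# here imports vllm; it only illustrates the selection pattern.
import enum
import logging

logger = logging.getLogger(__name__)


class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    XFORMERS = enum.auto()


class FlashAttentionBackend:
    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        # Illustrative list; the real set is whatever vllm's
        # FlashAttentionBackend.get_supported_head_sizes() returns.
        return [32, 64, 96, 128, 160, 192, 224, 256]


def select_backend(head_size: int, backend: _Backend) -> _Backend:
    if backend == _Backend.FLASH_ATTN:
        # The check happens only after the backend class is available,
        # because the supported head sizes come from the class itself.
        if head_size in FlashAttentionBackend.get_supported_head_sizes():
            logger.info("Using FlashAttention-2 backend.")
            return _Backend.FLASH_ATTN
        logger.info(
            "Cannot use FlashAttention-2 backend for head size %d. "
            "Using XFormers backend instead.", head_size)
        backend = _Backend.XFORMERS

    if backend == _Backend.XFORMERS:
        logger.info("Using XFormers backend.")
    return backend


# A supported head size keeps FlashAttention; an unsupported one falls back.
assert select_backend(128, _Backend.FLASH_ATTN) == _Backend.FLASH_ATTN
assert select_backend(80, _Backend.FLASH_ATTN) == _Backend.XFORMERS
```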