[V1] Make v1 more testable (#9888)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>

vllm/attention/selector.py
@@ -89,7 +89,6 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
     return forced_attn_backend
 
 
-@lru_cache(maxsize=None)
 def get_attn_backend(
     head_size: int,
     dtype: torch.dtype,
@@ -99,6 +98,31 @@ def get_attn_backend(
     is_blocksparse: bool = False,
 ) -> Type[AttentionBackend]:
     """Selects which attention backend to use and lazily imports it."""
+    # Accessing envs.* behind an @lru_cache decorator can cause the wrong
+    # value to be returned from the cache if the value changes between calls.
+    # To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to the
+    # private function.
+    return _cached_get_attn_backend(
+        head_size=head_size,
+        dtype=dtype,
+        kv_cache_dtype=kv_cache_dtype,
+        block_size=block_size,
+        is_attention_free=is_attention_free,
+        is_blocksparse=is_blocksparse,
+        use_v1=envs.VLLM_USE_V1,
+    )
+
+
+@lru_cache(maxsize=None)
+def _cached_get_attn_backend(
+    head_size: int,
+    dtype: torch.dtype,
+    kv_cache_dtype: Optional[str],
+    block_size: int,
+    is_attention_free: bool,
+    is_blocksparse: bool = False,
+    use_v1: bool = False,
+) -> Type[AttentionBackend]:
     if is_blocksparse:
         logger.info("Using BlocksparseFlashAttention backend.")
         from vllm.attention.backends.blocksparse_attn import (
@@ -106,7 +130,7 @@ def get_attn_backend(
         return BlocksparseFlashAttentionBackend
 
     backend = which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size,
-                                is_attention_free)
+                                is_attention_free, use_v1)
     if backend == _Backend.FLASH_ATTN:
         logger.info("Using Flash Attention backend.")
         from vllm.attention.backends.flash_attn import (  # noqa: F401
@@ -162,13 +186,12 @@ def get_attn_backend(
     raise ValueError("Invalid attention backend.")
 
 
-def which_attn_to_use(
-    head_size: int,
-    dtype: torch.dtype,
-    kv_cache_dtype: Optional[str],
-    block_size: int,
-    is_attention_free: bool,
-) -> _Backend:
+def which_attn_to_use(head_size: int,
+                      dtype: torch.dtype,
+                      kv_cache_dtype: Optional[str],
+                      block_size: int,
+                      is_attention_free: bool,
+                      use_v1: bool = False) -> _Backend:
     """Returns which flash attention backend to use."""
     # Default case.
     selected_backend = _Backend.FLASH_ATTN
@@ -228,7 +251,7 @@ def which_attn_to_use(
     if current_platform.is_hpu():
         return _Backend.HPU_ATTN
 
-    if envs.VLLM_USE_V1:
+    if use_v1:
         return _Backend.FLASH_ATTN_VLLM_V1
 
     # FlashAttn in NVIDIA GPUs.
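
Why the indirection matters: @lru_cache keys the cache only on the function's arguments, so a cached function that reads a global setting internally freezes whatever value it saw on the first call. Moving the envs.VLLM_USE_V1 read into the uncached public wrapper makes the flag part of the cache key. A minimal standalone sketch of the pitfall and the fix (the SETTINGS dict and function names are illustrative stand-ins, not vLLM code):

    from functools import lru_cache

    # Stand-in for an environment-driven setting like envs.VLLM_USE_V1.
    SETTINGS = {"USE_V1": False}


    @lru_cache(maxsize=None)
    def pick_backend_cached_read() -> str:
        # Bug pattern: the global is read *inside* the cached function,
        # so the first call's value is frozen into the cache.
        return "FLASH_ATTN_VLLM_V1" if SETTINGS["USE_V1"] else "FLASH_ATTN"


    @lru_cache(maxsize=None)
    def _pick_backend(use_v1: bool) -> str:
        # Fixed pattern: the setting is an argument, hence part of the key.
        return "FLASH_ATTN_VLLM_V1" if use_v1 else "FLASH_ATTN"


    def pick_backend() -> str:
        # Read the setting outside the cache and pass it in explicitly,
        # mirroring get_attn_backend -> _cached_get_attn_backend above.
        return _pick_backend(SETTINGS["USE_V1"])


    assert pick_backend_cached_read() == "FLASH_ATTN"
    SETTINGS["USE_V1"] = True
    assert pick_backend_cached_read() == "FLASH_ATTN"  # stale: cache ignores the change
    assert pick_backend() == "FLASH_ATTN_VLLM_V1"      # fresh value forms a new cache key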
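
This is the testability win from the title: because get_attn_backend now resolves envs.VLLM_USE_V1 on every call and only the private function is cached, a test can flip the flag per case and reset the cache explicitly. A hypothetical pytest sketch under those assumptions (the head size, dtype, and final assertion are illustrative; it also assumes envs.VLLM_USE_V1 is re-read from the process environment on each access):

    import pytest
    import torch

    from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend


    @pytest.fixture(autouse=True)
    def reset_backend_cache():
        # The private function is still @lru_cache'd, so clear it around each
        # test to keep backend selections from leaking between cases.
        _cached_get_attn_backend.cache_clear()
        yield
        _cached_get_attn_backend.cache_clear()


    def test_selects_v1_flash_attn(monkeypatch):
        monkeypatch.setenv("VLLM_USE_V1", "1")
        backend = get_attn_backend(head_size=64,
                                   dtype=torch.float16,
                                   kv_cache_dtype=None,
                                   block_size=16,
                                   is_attention_free=False)
        # Illustrative assertion: the v1 backend is expected to live under
        # the vllm.v1 package.
        assert backend.__module__.startswith("vllm.v1")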