[V0 Deprecation] Deprecate BlockSparse Attention & Phi3-Small (#21217)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon
2025-07-19 13:53:17 -07:00
committed by GitHub
parent 881e3cbe3b
commit 752c6ade2e
38 changed files with 65 additions and 2435 deletions

View File

@@ -143,7 +143,6 @@ def get_attn_backend(
kv_cache_dtype: Optional[str],
block_size: int,
is_attention_free: bool,
is_blocksparse: bool = False,
use_mla: bool = False,
) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""
@@ -157,7 +156,6 @@ def get_attn_backend(
kv_cache_dtype=kv_cache_dtype,
block_size=block_size,
is_attention_free=is_attention_free,
is_blocksparse=is_blocksparse,
use_v1=envs.VLLM_USE_V1,
use_mla=use_mla,
)
@@ -170,16 +168,9 @@ def _cached_get_attn_backend(
kv_cache_dtype: Optional[str],
block_size: int,
is_attention_free: bool,
is_blocksparse: bool = False,
use_v1: bool = False,
use_mla: bool = False,
) -> type[AttentionBackend]:
if is_blocksparse:
logger.info("Using BlocksparseFlashAttention backend.")
from vllm.attention.backends.blocksparse_attn import (
BlocksparseFlashAttentionBackend)
return BlocksparseFlashAttentionBackend
# If there are no attention layers (e.g. we are running Mamba),
# use the placeholder NO_ATTENTION
if is_attention_free: