[V0 Deprecation] Remove placeholder attn (#25510)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Author: Thomas Parnell
Date: 2025-09-24 00:12:14 +02:00
Committed by: GitHub
parent 4f8c4b890a
commit 969b4da3a6
5 changed files with 10 additions and 354 deletions


@@ -142,7 +142,6 @@ def get_attn_backend(
     dtype: torch.dtype,
     kv_cache_dtype: Optional[str],
     block_size: int,
-    is_attention_free: bool = False,
     use_mla: bool = False,
     has_sink: bool = False,
 ) -> type[AttentionBackend]:
@@ -156,7 +155,6 @@ def get_attn_backend(
         dtype=dtype,
         kv_cache_dtype=kv_cache_dtype,
         block_size=block_size,
-        is_attention_free=is_attention_free,
         use_v1=envs.VLLM_USE_V1,
         use_mla=use_mla,
         has_sink=has_sink,
@@ -169,17 +167,10 @@ def _cached_get_attn_backend(
     dtype: torch.dtype,
     kv_cache_dtype: Optional[str],
     block_size: int,
-    is_attention_free: bool,
     use_v1: bool = False,
     use_mla: bool = False,
     has_sink: bool = False,
 ) -> type[AttentionBackend]:
-    # If there are no attention layers (e.g. we are running Mamba),
-    # use the placeholder NO_ATTENTION
-    if is_attention_free:
-        from vllm.attention.backends.placeholder_attn import (
-            PlaceholderAttentionBackend)
-        return PlaceholderAttentionBackend
     # Check whether a particular choice of backend was
     # previously forced.
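After this change, backend selection no longer takes an is_attention_free flag, and attention-free models (e.g. pure Mamba) are no longer routed to PlaceholderAttentionBackend from this path. Below is a minimal sketch, not taken from the diff, of calling get_attn_backend with the trimmed signature; the head_size value, dtype, and other arguments are illustrative assumptions, and the leading head_size parameter itself sits just above the lines shown in the hunk.

```python
# Sketch only: resolve an attention backend class after the removal of
# is_attention_free. All argument values here are assumed for illustration.
import torch

from vllm.attention.selector import get_attn_backend

backend_cls = get_attn_backend(
    head_size=128,        # assumed model head size
    dtype=torch.float16,
    kv_cache_dtype=None,  # None: KV cache uses the model dtype
    block_size=16,
    use_mla=False,
    has_sink=False,
)
print(backend_cls)  # selected backend class for the current platform
```

The result depends on the platform and environment (e.g. VLLM_USE_V1, available GPU kernels), so the selected class will vary between machines.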