[Misc] Refactor get_kv_cache_spec into AttentionLayerBase (#26587)

Signed-off-by: NickLucche <nlucches@redhat.com>
Author: Nicolò Lucchesi
Date: 2025-10-18 15:51:21 +02:00
Committed by: GitHub
Parent: ab4be40fc5
Commit: b26b70bec4
10 changed files with 151 additions and 118 deletions
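The diff below shows the Mamba side of the change. For orientation, here is a minimal sketch of what moving `get_kv_cache_spec` into `AttentionLayerBase` implies for the base interface; the method signature is taken from the diff, but the abstract base-class body is an assumption, not the verbatim vLLM source:

```python
# Sketch only: the refactor hangs get_kv_cache_spec() off AttentionLayerBase,
# so every attention-like layer (full attention, Mamba, etc.) reports its own
# KV-cache requirements instead of the model runner special-casing layer types.
from abc import ABC, abstractmethod

from vllm.config import VllmConfig
from vllm.v1.kv_cache_interface import KVCacheSpec


class AttentionLayerBase(ABC):
    @abstractmethod
    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
        """Return this layer's KV-cache spec, or None if it needs no cache."""
        ...
```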


```diff
@@ -6,7 +6,9 @@ from typing import TYPE_CHECKING

 import torch

+from vllm.config import VllmConfig
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
+from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec

 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionBackend
@@ -40,3 +42,30 @@ class MambaBase(AttentionLayerBase):
     def get_attn_backend(self) -> type["AttentionBackend"]:
         """Get the attention backend class for this Mamba layer."""
         pass
+
+    @abstractmethod
+    def get_state_dtype(self) -> tuple[torch.dtype, ...]:
+        pass
+
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
+        if (
+            vllm_config.speculative_config is not None
+            and vllm_config.model_config.hf_config.model_type not in ["qwen3_next"]
+        ):
+            raise NotImplementedError(
+                "Mamba with speculative decoding is not supported yet."
+            )
+        mamba_block_size = vllm_config.cache_config.mamba_block_size
+        page_size_padded = vllm_config.cache_config.mamba_page_size_padded
+        return MambaSpec(
+            shapes=self.get_state_shape(),
+            dtypes=self.get_state_dtype(),
+            block_size=mamba_block_size,
+            page_size_padded=page_size_padded,
+            mamba_type=self.mamba_type,
+            num_speculative_blocks=(
+                vllm_config.speculative_config.num_speculative_tokens
+                if vllm_config.speculative_config
+                else 0
+            ),
+        )
```
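Not shown in this diff, but the payoff is on the consumer side: with the method on `AttentionLayerBase`, a model runner can build the per-layer spec mapping in one uniform loop. A hedged sketch follows; the helper name `collect_kv_cache_specs` and the `layers` dict are illustrative, while the import paths match those in the diff above:

```python
# Illustrative consumer sketch: gather KV-cache specs uniformly from all
# attention-like layers, skipping layers that report no cache needs.
from vllm.config import VllmConfig
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.kv_cache_interface import KVCacheSpec


def collect_kv_cache_specs(
    layers: dict[str, AttentionLayerBase],
    vllm_config: VllmConfig,
) -> dict[str, KVCacheSpec]:
    specs: dict[str, KVCacheSpec] = {}
    for name, layer in layers.items():
        spec = layer.get_kv_cache_spec(vllm_config)
        if spec is not None:  # e.g. a layer that allocates no KV cache
            specs[name] = spec
    return specs
```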