[V1][Hybrid] Mamba Prefix Caching with align mode (#30877)

Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Harry Huang
2026-01-24 01:56:48 +08:00
committed by GitHub
parent fec9da0af4
commit 5206e5e28c
42 changed files with 1774 additions and 128 deletions

View File

@@ -31,6 +31,7 @@ CacheDType = Literal[
"fp8_ds_mla",
]
MambaDType = Literal["auto", "float32", "float16"]
MambaCacheMode = Literal["all", "align", "none"]
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor", "xxhash", "xxhash_cbor"]
KVOffloadingBackend = Literal["native", "lmcache"]
@@ -123,6 +124,15 @@ class CacheConfig:
"""The data type to use for the Mamba cache (ssm state only, conv state will
still be controlled by mamba_cache_dtype). If set to 'auto', the data type
for the ssm state will be determined by mamba_cache_dtype."""
mamba_cache_mode: MambaCacheMode = "none"
"""The cache strategy for Mamba layers.
- "none": set when prefix caching is disabled.
- "all": cache the mamba state of all tokens at position i * block_size. This is
the default behavior (for models that support it) when prefix caching is
enabled.
- "align": only cache the mamba state of the last token of each scheduler step and
when the token is at position i * block_size.
"""
# Will be set after profiling.
num_gpu_blocks: int | None = field(default=None, init=False)
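To make the three modes concrete, here is a toy illustration (the numbers are made up and not part of this change) of which prefill positions can leave a reusable Mamba state behind:

block_size = 16
prompt_len = 50

# "all": every position that is a multiple of block_size is a candidate state,
# so the states at 16, 32 and 48 become reusable.
all_mode_cached = [p for p in range(1, prompt_len + 1) if p % block_size == 0]  # [16, 32, 48]

# "align": only the last token of a scheduler step is cached, and only when it
# sits exactly on a block boundary; the scheduler splits the prefill so that a
# step ends at position 48, and that single state is cached.
align_mode_cached = [prompt_len - prompt_len % block_size]  # [48]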

View File

@@ -999,6 +999,17 @@ class VllmConfig:
# Default to enable HMA if not explicitly disabled by user or logic above.
self.scheduler_config.disable_hybrid_kv_cache_manager = False
if self.cache_config.mamba_cache_mode == "align":
if self.scheduler_config.long_prefill_token_threshold > 0:
assert (
self.scheduler_config.long_prefill_token_threshold
>= self.cache_config.block_size
)
assert not self.scheduler_config.disable_chunked_mm_input, (
"Chunked MM input is required because we need the flexibility to "
"schedule a multiple of block_size tokens even if they are in the "
"middle of a mm input"
)
if self.compilation_config.debug_dump_path:
self.compilation_config.debug_dump_path = (
self.compilation_config.debug_dump_path.absolute().expanduser()

View File

@@ -60,6 +60,7 @@ from vllm.config.cache import (
BlockSize,
CacheDType,
KVOffloadingBackend,
MambaCacheMode,
MambaDType,
PrefixCachingHashAlgo,
)
@@ -556,6 +557,7 @@ class EngineArgs:
mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
mamba_block_size: int | None = get_field(CacheConfig, "mamba_block_size")
mamba_cache_mode: MambaCacheMode = CacheConfig.mamba_cache_mode
additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config")
@@ -939,6 +941,9 @@ class EngineArgs:
cache_group.add_argument(
"--mamba-block-size", **cache_kwargs["mamba_block_size"]
)
cache_group.add_argument(
"--mamba-cache-mode", **cache_kwargs["mamba_cache_mode"]
)
cache_group.add_argument(
"--kv-offloading-size", **cache_kwargs["kv_offloading_size"]
)
@@ -1416,6 +1421,7 @@ class EngineArgs:
mamba_cache_dtype=self.mamba_cache_dtype,
mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
mamba_block_size=self.mamba_block_size,
mamba_cache_mode=self.mamba_cache_mode,
kv_offloading_size=self.kv_offloading_size,
kv_offloading_backend=self.kv_offloading_backend,
)
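A minimal usage sketch (the model name and entry point are illustrative, not part of this change) of enabling the new mode programmatically; it has the same effect as passing --mamba-cache-mode align on the command line:

from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="ibm-ai-platform/Bamba-9B",  # hypothetical hybrid model choice
    enable_prefix_caching=True,
    mamba_cache_mode="align",
)
vllm_config = engine_args.create_engine_config()
assert vllm_config.cache_config.mamba_cache_mode == "align"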

View File

@@ -56,6 +56,7 @@ class MambaBase(AttentionLayerBase):
block_size=mamba_block_size,
page_size_padded=page_size_padded,
mamba_type=self.mamba_type,
mamba_cache_mode=vllm_config.cache_config.mamba_cache_mode,
num_speculative_blocks=(
vllm_config.speculative_config.num_speculative_tokens
if vllm_config.speculative_config

View File

@@ -255,7 +255,7 @@ class MambaMixer(MambaBase, CustomOp):
assert self.cache_config is not None
mamba_block_size = self.cache_config.mamba_block_size
prefix_caching_enabled = self.cache_config.enable_prefix_caching
is_mamba_cache_all = self.cache_config.mamba_cache_mode == "all"
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
@@ -304,7 +304,7 @@ class MambaMixer(MambaBase, CustomOp):
state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p
state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d
if prefix_caching_enabled:
if is_mamba_cache_all:
block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
torch.split(
attn_metadata.block_idx_last_computed_token,
@@ -380,7 +380,7 @@ class MambaMixer(MambaBase, CustomOp):
ssm_outputs.append(scan_out_p)
if has_decode:
if prefix_caching_enabled:
if is_mamba_cache_all:
state_indices_tensor_d_input = state_indices_tensor_d.gather(
1, block_idx_last_computed_token_d.unsqueeze(1)
).squeeze(1)

View File

@@ -570,7 +570,7 @@ class MambaMixer2(MambaBase, CustomOp):
assert self.cache_config is not None
mamba_block_size = self.cache_config.mamba_block_size
prefix_caching_enabled = self.cache_config.enable_prefix_caching
is_mamba_cache_all = self.cache_config.mamba_cache_mode == "all"
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
@@ -622,7 +622,7 @@ class MambaMixer2(MambaBase, CustomOp):
dim=0,
)
if prefix_caching_enabled:
if is_mamba_cache_all:
# If prefix caching is enabled, retrieve the relevant variables
# for prefill and decode
block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
@@ -701,7 +701,7 @@ class MambaMixer2(MambaBase, CustomOp):
initial_states = None
if has_initial_states_p is not None and prep_initial_states:
kernel_ssm_indices = state_indices_tensor_p
if prefix_caching_enabled:
if is_mamba_cache_all:
kernel_ssm_indices = state_indices_tensor_p.gather(
1, block_idx_last_computed_token_p.unsqueeze(1)
).squeeze(1)
@@ -729,14 +729,14 @@ class MambaMixer2(MambaBase, CustomOp):
cu_chunk_seqlens=cu_chunk_seqlen_p,
last_chunk_indices=last_chunk_indices_p,
initial_states=initial_states,
return_intermediate_states=prefix_caching_enabled,
return_intermediate_states=is_mamba_cache_all,
dt_softplus=True,
dt_limit=(0.0, float("inf")),
out=preallocated_ssm_out_p.view(num_prefill_tokens, -1, self.head_dim),
state_dtype=ssm_state.dtype,
)
if prefix_caching_enabled:
if is_mamba_cache_all:
# The chunk_stride is the number of chunks per mamba block
# e.g., if mamba_block_size = 512 and chunk_size = 256,
# then chunk_stride = 2
@@ -815,7 +815,7 @@ class MambaMixer2(MambaBase, CustomOp):
# Process decode requests
if has_decode:
if prefix_caching_enabled:
if is_mamba_cache_all:
state_indices_tensor_d_input = state_indices_tensor_d.gather(
1, block_idx_last_computed_token_d.unsqueeze(1)
).squeeze(1)

View File

@@ -1,6 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from dataclasses import dataclass
from typing import TypeAlias
import torch
from vllm.config.cache import MambaDType
@@ -223,3 +227,94 @@ class MambaStateShapeCalculator:
conv_state_k_shape,
recurrent_state_shape,
)
@dataclass
class MambaCopySpec:
"""
Data class specifying the memory-copy parameters for Mamba states used for
prefix caching in align mode.
Attributes:
start_addr (int): Starting address for the memory copy operation.
num_elements (int): Number of elements to copy from the starting address.
"""
start_addr: int
num_elements: int
MambaStateCopyFunc: TypeAlias = Callable[
[torch.Tensor, list[int], int, int], MambaCopySpec
]
"""
Type alias for a function that computes a MambaCopySpec for copying state slices.
Parameters:
state: torch.Tensor - the Mamba state tensor (e.g., conv or temporal states).
block_ids: list[int] - the list of block indices for the state to copy.
cur_block_idx: int - current block index within `block_ids` to copy from.
num_accepted_tokens: int - number of accepted tokens used to compute the copy offset.
Range: 1 .. 1 + num_speculative_tokens (inclusive).
"""
def get_conv_copy_spec(
state: torch.Tensor,
block_ids: list[int],
cur_block_idx: int,
num_accepted_tokens: int,
) -> MambaCopySpec:
"""Return a MambaCopySpec for copying a convolutional state slice."""
src_block_id = block_ids[cur_block_idx]
src_state = state[src_block_id, num_accepted_tokens - 1 :]
return MambaCopySpec(
start_addr=src_state.data_ptr(), num_elements=src_state.numel()
)
def get_temporal_copy_spec(
state: torch.Tensor,
block_ids: list[int],
cur_block_idx: int,
num_accepted_tokens: int,
) -> MambaCopySpec:
"""Return a MambaCopySpec for copying a temporal state slice."""
src_block_id = block_ids[cur_block_idx + num_accepted_tokens - 1]
src_state = state[src_block_id]
return MambaCopySpec(
start_addr=src_state.data_ptr(), num_elements=src_state.numel()
)
get_full_copy_spec = get_temporal_copy_spec
class MambaStateCopyFuncCalculator:
@classmethod
def linear_attention_state_copy_func(cls):
return (get_temporal_copy_spec,)
@classmethod
def mamba1_state_copy_func(cls):
return (get_conv_copy_spec, get_temporal_copy_spec)
@classmethod
def mamba2_state_copy_func(cls):
return (get_conv_copy_spec, get_temporal_copy_spec)
@classmethod
def short_conv_state_copy_func(cls):
return (get_conv_copy_spec,)
@classmethod
def gated_delta_net_state_copy_func(cls):
return (get_conv_copy_spec, get_temporal_copy_spec)
@classmethod
def kda_state_copy_func(cls):
return (
get_conv_copy_spec,
get_conv_copy_spec,
get_conv_copy_spec,
get_temporal_copy_spec,
)
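A small standalone rehearsal (the tensor shape and block IDs are made up) of how a copy spec is derived from a state tensor:

import torch

# Fake conv state with 8 blocks; the per-block layout is irrelevant here.
conv_state = torch.zeros(8, 3, 4)

# The request's running state lives in block 5 (block_ids[cur_block_idx]);
# with num_accepted_tokens = 1 (no spec decode) the whole block is described.
spec = get_conv_copy_spec(conv_state, block_ids=[2, 5], cur_block_idx=1,
                          num_accepted_tokens=1)
assert spec.start_addr == conv_state[5].data_ptr()
assert spec.num_elements == conv_state[5].numel()  # 12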

View File

@@ -24,6 +24,8 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
@@ -455,6 +457,10 @@ class BambaForCausalLM(
conv_kernel=hf_config.mamba_d_conv,
)
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
return MambaStateCopyFuncCalculator.mamba2_state_copy_func()
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
self.vllm_config = vllm_config

View File

@@ -330,26 +330,54 @@ class MambaModelConfig(VerifyAndUpdateConfig):
cache_config = vllm_config.cache_config
if cache_config.enable_prefix_caching:
if model_config.supports_mamba_prefix_caching:
logger.info(
"Warning: Prefix caching is currently enabled. "
"Its support for Mamba layers is experimental. "
"Please report any issues you may observe."
if cache_config.mamba_cache_mode == "none":
cache_config.mamba_cache_mode = (
"all" if model_config.supports_mamba_prefix_caching else "align"
)
# By default, mamba block size will be set to max_model_len (see
# below). When enabling prefix caching, we align mamba block size
# to the block size as the basic granularity for prefix caching.
if cache_config.mamba_block_size is None:
cache_config.mamba_block_size = cache_config.block_size
else:
logger.info(
"Hybrid or mamba-based model detected without "
"support for prefix caching: disabling."
logger.warning(
"Mamba cache mode is set to '%s' for %s by default "
"when prefix caching is enabled",
cache_config.mamba_cache_mode,
model_config.architecture,
)
cache_config.enable_prefix_caching = False
if cache_config.mamba_block_size is None:
cache_config.mamba_block_size = model_config.max_model_len
if (
cache_config.mamba_cache_mode == "all"
and not model_config.supports_mamba_prefix_caching
):
cache_config.mamba_cache_mode = "align"
logger.warning(
"Hybrid or mamba-based model detected without support "
"for prefix caching with Mamba cache 'all' mode: "
"falling back to 'align' mode."
)
if cache_config.mamba_cache_mode == "align":
assert vllm_config.scheduler_config.enable_chunked_prefill, (
"Chunked prefill is required for mamba cache mode 'align'."
)
assert not vllm_config.speculative_config, (
"Mamba cache mode 'align' is currently not compatible "
"with speculative decoding."
)
logger.info(
"Warning: Prefix caching in Mamba cache '%s' "
"mode is currently enabled. "
"Its support for Mamba layers is experimental. "
"Please report any issues you may observe.",
cache_config.mamba_cache_mode,
)
# By default, mamba block size will be set to max_model_len (see
# below). When enabling prefix caching, we align mamba block size
# to the block size as the basic granularity for prefix caching.
if cache_config.mamba_block_size is None:
cache_config.mamba_block_size = cache_config.block_size
else:
if cache_config.mamba_cache_mode != "none":
cache_config.mamba_cache_mode = "none"
logger.warning(
"Mamba cache mode is set to 'none' when prefix caching is disabled"
)
if cache_config.mamba_block_size is None:
cache_config.mamba_block_size = model_config.max_model_len
class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
@@ -426,7 +454,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
mamba_page_size = MambaSpec(
shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
block_size=model_config.max_model_len,
block_size=-1, # block_size doesn't matter for mamba page size
).page_size_bytes
# Model may be marked as is_hybrid
@@ -435,7 +463,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
if mamba_page_size == 0:
return
if cache_config.enable_prefix_caching:
if cache_config.mamba_cache_mode == "all":
# With prefix caching, select attention block size to
# optimize for mamba kernel performance
@@ -479,6 +507,13 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
attn_block_size,
)
# By default, mamba block size will be set to max_model_len.
# When enabling prefix caching and using align mamba cache
# mode, we align mamba block size to the block size as the
# basic granularity for prefix caching.
if cache_config.mamba_cache_mode == "align":
cache_config.mamba_block_size = cache_config.block_size
# compute new attention page size
attn_page_size = cache_config.block_size * attn_page_size_1_token
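A condensed restatement of the mode-resolution logic above (a readability sketch, not a drop-in replacement for the code in this hunk):

def resolve_mamba_cache_mode(enable_prefix_caching: bool,
                             supports_mamba_prefix_caching: bool,
                             requested: str = "none") -> str:
    if not enable_prefix_caching:
        return "none"
    if requested == "none":
        # Default: "all" for models with full support, otherwise "align".
        return "all" if supports_mamba_prefix_caching else "align"
    if requested == "all" and not supports_mamba_prefix_caching:
        # Fall back instead of failing.
        return "align"
    return requested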

View File

@@ -24,6 +24,8 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
@@ -551,6 +553,10 @@ class FalconH1ForCausalLM(
conv_kernel=hf_config.mamba_d_conv,
)
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
return MambaStateCopyFuncCalculator.mamba2_state_copy_func()
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
self.vllm_config = vllm_config

View File

@@ -19,6 +19,8 @@ from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLine
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
@@ -641,6 +643,10 @@ class GraniteMoeHybridForCausalLM(
conv_kernel=hf_config.mamba_d_conv,
)
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
return MambaStateCopyFuncCalculator.mamba2_state_copy_func()
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

View File

@@ -24,6 +24,7 @@ from vllm.config import ModelConfig, SpeechToTextConfig
from vllm.inputs import TokensPrompt
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_utils import MambaStateCopyFunc
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.utils.collection_utils import common_prefix
from vllm.utils.func_utils import supports_kw
@@ -776,6 +777,19 @@ class IsHybrid(Protocol):
"""
...
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, ...]:
"""Calculate copy-function callables for each Mamba state.
Returns:
A tuple of MambaStateCopyFunc callables that correspond, in order,
to the Mamba states produced by the model. Each callable accepts
(state, block_ids, cur_block_idx, num_accepted_tokens) and returns
a MambaCopySpec describing the memory-copy parameters for prefix
caching in align mode.
"""
...
@overload
def is_hybrid(model: object) -> TypeIs[IsHybrid]: ...
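How the worker side is expected to consume this protocol, as a hypothetical helper (the function and its parameter names are illustrative; the real call sites live in the model-runner changes further below):

def build_copy_specs(model, mamba_states, block_ids, cur_block_idx, num_accepted_tokens):
    # One copy function per Mamba state, applied in order; each returns the
    # (start_addr, num_elements) slice to move between blocks in align mode.
    return [
        copy_func(state, block_ids, cur_block_idx, num_accepted_tokens)
        for state, copy_func in zip(mamba_states, model.get_mamba_state_copy_func())
    ]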

View File

@@ -24,6 +24,8 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
@@ -558,6 +560,10 @@ class JambaForCausalLM(
conv_kernel=hf_config.mamba_d_conv,
)
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
return MambaStateCopyFuncCalculator.mamba1_state_copy_func()
def compute_logits(
self,
hidden_states: torch.Tensor,

View File

@@ -26,6 +26,8 @@ from vllm.model_executor.layers.linear import (
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
@@ -544,6 +546,14 @@ class KimiLinearForCausalLM(
num_spec=num_spec,
)
@classmethod
def get_mamba_state_copy_func(
cls,
) -> tuple[
MambaStateCopyFunc, MambaStateCopyFunc, MambaStateCopyFunc, MambaStateCopyFunc
]:
return MambaStateCopyFuncCalculator.kda_state_copy_func()
def compute_logits(
self,
hidden_states: torch.Tensor,

View File

@@ -20,6 +20,8 @@ from vllm.model_executor.layers.linear import (
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
@@ -459,14 +461,19 @@ class Lfm2ForCausalLM(
conv_kernel=hf_config.conv_L_cache,
)
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc]:
return MambaStateCopyFuncCalculator.short_conv_state_copy_func()
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
cache_config = vllm_config.cache_config
assert not cache_config.enable_prefix_caching, (
"Lfm2 currently does not support prefix caching"
)
if cache_config.mamba_cache_mode == "all":
raise NotImplementedError(
"Lfm2 currently does not support 'all' prefix caching, "
"please use '--mamba-cache-mode=align' instead"
)
super().__init__()
self.config = config

View File

@@ -25,6 +25,8 @@ from vllm.model_executor.layers.linear import (
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
@@ -640,6 +642,10 @@ class Lfm2MoeForCausalLM(
conv_kernel=hf_config.conv_L_cache,
)
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc]:
return MambaStateCopyFuncCalculator.short_conv_state_copy_func()
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config

View File

@@ -16,6 +16,8 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
@@ -261,6 +263,10 @@ class MambaForCausalLM(
conv_kernel=hf_config.conv_kernel,
)
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
return MambaStateCopyFuncCalculator.mamba1_state_copy_func()
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.mamba_cache.copy_inputs_before_cuda_graphs(input_buffers, **kwargs)

View File

@@ -15,6 +15,8 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
@@ -228,6 +230,10 @@ class Mamba2ForCausalLM(
conv_kernel=hf_config.conv_kernel,
)
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
return MambaStateCopyFuncCalculator.mamba2_state_copy_func()
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config

View File

@@ -35,6 +35,8 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.linear_attn import MiniMaxText01LinearAttention
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
@@ -1006,3 +1008,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
tp_size=parallel_config.tensor_parallel_size,
head_dim=hf_config.head_dim,
)
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc]:
return MambaStateCopyFuncCalculator.linear_attention_state_copy_func()

View File

@@ -2128,3 +2128,7 @@ class NemotronH_Nano_VL_V2(
temp_vllm_config = copy.deepcopy(vllm_config)
temp_vllm_config.model_config.hf_config = text_config
return NemotronHForCausalLM.get_mamba_state_dtype_from_config(temp_vllm_config)
@classmethod
def get_mamba_state_copy_func(cls):
return NemotronHForCausalLM.get_mamba_state_copy_func()

View File

@@ -45,6 +45,8 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
@@ -809,6 +811,10 @@ class NemotronHForCausalLM(
conv_kernel=hf_config.conv_kernel,
)
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
return MambaStateCopyFuncCalculator.mamba2_state_copy_func()
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
self.vllm_config = vllm_config

View File

@@ -27,6 +27,8 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
@@ -899,6 +901,10 @@ class Plamo2ForCausalLM(
conv_kernel=hf_config.mamba_d_conv,
)
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
return MambaStateCopyFuncCalculator.mamba2_state_copy_func()
def compute_logits(
self,
hidden_states: torch.Tensor,

View File

@@ -48,6 +48,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_mixer2 import mamba_v2_sharded_weight_loader
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
@@ -1205,9 +1207,11 @@ class Qwen3NextForCausalLM(
cache_config = vllm_config.cache_config
scheduler_config = vllm_config.scheduler_config
assert not cache_config.enable_prefix_caching, (
"Qwen3Next currently does not support prefix caching"
)
if cache_config.mamba_cache_mode == "all":
raise NotImplementedError(
"Qwen3Next currently does not support 'all' prefix caching, "
"please use '--mamba-cache-mode=align' instead"
)
self.quant_config = vllm_config.quant_config
super().__init__()
@@ -1278,6 +1282,10 @@ class Qwen3NextForCausalLM(
num_spec,
)
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
return MambaStateCopyFuncCalculator.gated_delta_net_state_copy_func()
def compute_logits(
self,
hidden_states: torch.Tensor,

View File

@@ -234,9 +234,11 @@ class Qwen3NextMTP(nn.Module, QwenNextMixtureOfExperts):
config = vllm_config.model_config.hf_config
self.vllm_config = vllm_config
cache_config = vllm_config.cache_config
assert not cache_config.enable_prefix_caching, (
"Qwen3NextMTP currently does not support prefix caching"
)
if cache_config.mamba_cache_mode == "all":
raise NotImplementedError(
"Qwen3NextMTP currently does not support 'all' prefix caching, "
"please use '--mamba-cache-mode=align' instead"
)
self.quant_config = vllm_config.quant_config

View File

@@ -32,6 +32,8 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
MambaStateCopyFuncCalculator,
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
@@ -891,6 +893,10 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsMambaPrefixC
conv_kernel=hf_config.mamba_d_conv,
)
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
return MambaStateCopyFuncCalculator.mamba2_state_copy_func()
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
"""Initialize the Zamba2 model for causal language modeling.

View File

@@ -16,6 +16,7 @@ from vllm.v1.attention.backend import (
from vllm.v1.attention.backends.utils import (
PAD_SLOT_ID,
compute_causal_conv1d_metadata,
mamba_get_block_table_tensor,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
@@ -158,6 +159,12 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
query_start_loc_cpu = m.query_start_loc_cpu
context_lens_tensor = m.compute_num_computed_tokens()
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
block_table_tensor = mamba_get_block_table_tensor(
m.block_table_tensor,
m.seq_lens,
self.kv_cache_spec,
self.vllm_config.cache_config.mamba_cache_mode,
)
spec_sequence_masks_cpu: torch.Tensor | None = None
if (
@@ -189,7 +196,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
spec_token_indx = None
non_spec_token_indx = None
spec_state_indices_tensor = None
non_spec_state_indices_tensor = m.block_table_tensor[:, 0]
non_spec_state_indices_tensor = block_table_tensor[:, 0]
spec_query_start_loc = None
non_spec_query_start_loc = query_start_loc
non_spec_query_start_loc_cpu = query_start_loc_cpu
@@ -221,7 +228,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
non_spec_token_indx = torch.empty(
0, dtype=torch.int32, device=query_start_loc.device
)
spec_state_indices_tensor = m.block_table_tensor[:, : self.num_spec + 1]
spec_state_indices_tensor = block_table_tensor[:, : self.num_spec + 1]
non_spec_state_indices_tensor = None
spec_query_start_loc = query_start_loc
non_spec_query_start_loc = None
@@ -235,10 +242,10 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
non_spec_token_indx = index[:num_non_spec_tokens]
spec_token_indx = index[num_non_spec_tokens:]
spec_state_indices_tensor = m.block_table_tensor[
spec_state_indices_tensor = block_table_tensor[
spec_sequence_masks, : self.num_spec + 1
]
non_spec_state_indices_tensor = m.block_table_tensor[
non_spec_state_indices_tensor = block_table_tensor[
~spec_sequence_masks, 0
]

View File

@@ -11,7 +11,10 @@ from vllm.v1.attention.backend import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.attention.backends.utils import split_decodes_and_prefills
from vllm.v1.attention.backends.utils import (
mamba_get_block_table_tensor,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
@@ -61,7 +64,12 @@ class LinearAttentionMetadataBuilder(AttentionMetadataBuilder[LinearAttentionMet
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
state_indices_tensor = mamba_get_block_table_tensor(
common_attn_metadata.block_table_tensor,
common_attn_metadata.seq_lens,
self.kv_cache_spec,
self.vllm_config.cache_config.mamba_cache_mode,
)[:, 0]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(

View File

@@ -18,6 +18,7 @@ from vllm.v1.attention.backend import (
from vllm.v1.attention.backends.utils import (
PAD_SLOT_ID,
compute_causal_conv1d_metadata,
mamba_get_block_table_tensor,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
@@ -41,11 +42,15 @@ class BaseMambaAttentionMetadata:
state_indices_tensor: torch.Tensor
# The following tensors are only used for prefix caching and are None if disabled
# The following tensors are only used for prefix caching in all mode and
# are None if disabled
block_idx_last_scheduled_token: torch.Tensor | None
block_idx_first_scheduled_token_p: torch.Tensor | None
block_idx_last_computed_token: torch.Tensor | None
# The following tensor is only used for prefix caching in align mode
seq_lens: torch.Tensor
# The following attributes are for triton implementation of causal_conv1d
nums_dict: dict | None = None
batch_ptr: torch.Tensor | None = None
@@ -78,7 +83,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
self.compilation_config.max_cudagraph_capture_size,
)
if self.vllm_config.cache_config.enable_prefix_caching:
if self.vllm_config.cache_config.mamba_cache_mode == "all":
self.state_indices_tensor = torch.empty(
(
self.decode_cudagraph_max_bs,
@@ -198,7 +203,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
# for causal_conv1d
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
if self.vllm_config.cache_config.enable_prefix_caching:
if self.vllm_config.cache_config.mamba_cache_mode == "all":
num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()
# Return a tensor of shape (#requests, #max blocks)
@@ -214,7 +219,12 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
)
else:
# Always return just a single block per each request:
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
state_indices_tensor = mamba_get_block_table_tensor(
common_attn_metadata.block_table_tensor,
common_attn_metadata.seq_lens,
self.kv_cache_spec,
self.vllm_config.cache_config.mamba_cache_mode,
)[:, 0]
if num_prefills > 0:
if num_computed_tokens is None:
@@ -239,7 +249,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
)
)
if self.vllm_config.cache_config.enable_prefix_caching:
if self.vllm_config.cache_config.mamba_cache_mode == "all":
assert num_computed_tokens is not None
num_computed_tokens_p = num_computed_tokens[
num_reqs - num_prefills : num_reqs
@@ -258,7 +268,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
state_indices_tensor[num_decodes:] = PAD_SLOT_ID
if self.vllm_config.cache_config.enable_prefix_caching:
if self.vllm_config.cache_config.mamba_cache_mode == "all":
self.block_idx_last_scheduled_token[:num_decodes].copy_(
block_idx_last_scheduled_token, non_blocking=True
)
@@ -286,6 +296,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
block_idx_last_computed_token=block_idx_last_computed_token,
num_computed_tokens_p=num_computed_tokens_p,
num_reqs=num_reqs,
seq_lens=common_attn_metadata.seq_lens,
nums_dict=nums_dict,
batch_ptr=batch_ptr,
token_chunk_offset_ptr=token_chunk_offset_ptr,
@@ -298,8 +309,16 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
slot_mapping: torch.Tensor,
) -> M:
new_metadata = copy.copy(metadata)
prefix_caching = self.vllm_config.cache_config.enable_prefix_caching
state_indices_t = blk_table if prefix_caching else blk_table[:, 0]
state_indices_t = mamba_get_block_table_tensor(
blk_table,
metadata.seq_lens,
self.kv_cache_spec,
self.vllm_config.cache_config.mamba_cache_mode,
)
if self.vllm_config.cache_config.mamba_cache_mode in ("none", "align"):
# Only needs the block that saves the running state
state_indices_t = state_indices_t[:, 0]
num_reqs = blk_table.shape[0]
# For CUDA graphs, copy to persistent buffer

View File

@@ -17,6 +17,7 @@ from typing_extensions import runtime_checkable
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.utils.math_utils import cdiv
from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
@@ -854,3 +855,40 @@ def extend_all_queries_by_1(
slot_mapping=new_slot_mapping,
)
return new_cad
def mamba_get_block_table_tensor(
block_table: torch.Tensor,
seq_lens: torch.Tensor,
kv_cache_spec: KVCacheSpec,
mamba_cache_mode: str,
) -> torch.Tensor:
"""
Get the block table tensor for mamba kernels from the input
common_attn_metadata.block_table_tensor, depending on the mamba cache mode.
- "all": input (#requests, cdiv(max_model_len, block_size));
output (#requests, cdiv(max_model_len, block_size)).
- "none": input (#requests, 1 + num_speculative_blocks);
output (#requests, 1 + num_speculative_blocks).
- "align": input (#requests, cdiv(max_model_len, block_size));
output (#requests, 1 + num_speculative_blocks), i.e. the last
1 + num_speculative_blocks blocks of each request.
"""
if mamba_cache_mode in ("all", "none"):
return block_table
else:
assert isinstance(kv_cache_spec, MambaSpec)
# NOTE: For 0-length requests in CUDA graph, use a start_index of 0
# to handle the invalid block table.
start_indices = torch.clamp(
(seq_lens - 1) // kv_cache_spec.block_size,
min=0,
)
offsets = torch.arange(
1 + kv_cache_spec.num_speculative_blocks, device=block_table.device
)
indices_to_gather = start_indices.unsqueeze(1) + offsets
return torch.gather(block_table, 1, indices_to_gather)
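A standalone rehearsal of the gather above with made-up block IDs (block_size = 4, one speculative block):

import torch

block_table = torch.tensor([[10, 11, 12, 13, 14],
                            [20, 21, 22, 23, 24]])
seq_lens = torch.tensor([9, 4])      # request 0 has 9 tokens, request 1 has 4
block_size, num_speculative_blocks = 4, 1

start_indices = torch.clamp((seq_lens - 1) // block_size, min=0)   # [2, 0]
offsets = torch.arange(1 + num_speculative_blocks)                 # [0, 1]
gathered = torch.gather(block_table, 1, start_indices.unsqueeze(1) + offsets)
# gathered == [[12, 13],
#              [20, 21]]: the block holding the running state plus the
# speculative block for each request.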

View File

@@ -255,7 +255,8 @@ class BlockPool:
)
for i, blk in enumerate(new_full_blocks):
# Some blocks may be null blocks when enabling sparse attention like
# sliding window attention. We skip null blocks here.
# sliding window attention, or Mamba models with prefix-caching in
# align mode. We skip null blocks here.
if blk.is_null:
continue
assert blk.block_hash is None

View File

@@ -75,6 +75,7 @@ class KVCacheCoordinator(ABC):
new_computed_blocks: tuple[Sequence[KVCacheBlock], ...],
num_encoder_tokens: int,
total_computed_tokens: int,
num_tokens_main_model: int,
) -> int:
"""
Get the number of blocks needed to be allocated for the request.
@@ -88,6 +89,9 @@ class KVCacheCoordinator(ABC):
num_encoder_tokens: The number of encoder tokens for allocating
blocks for cross-attention.
total_computed_tokens: Include both local and external tokens.
num_tokens_main_model: The number of tokens for the main model (aka target
model in spec decode). w/o spec decode, it is num_tokens;
with spec decode, it is num_tokens - num_lookahead_tokens.
Returns:
The number of blocks to allocate.
@@ -98,7 +102,7 @@ class KVCacheCoordinator(ABC):
# For cross-attention, we issue a single static allocation
# of blocks based on the number of encoder input tokens.
num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
request_id, num_encoder_tokens, [], 0
request_id, num_encoder_tokens, [], 0, num_encoder_tokens
)
else:
num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
@@ -106,6 +110,7 @@ class KVCacheCoordinator(ABC):
num_tokens,
new_computed_blocks[i],
total_computed_tokens,
num_tokens_main_model,
)
return num_blocks_to_allocate
@@ -139,6 +144,7 @@ class KVCacheCoordinator(ABC):
self,
request_id: str,
num_tokens: int,
num_tokens_main_model: int,
num_encoder_tokens: int = 0,
) -> tuple[list[KVCacheBlock], ...]:
"""
@@ -149,6 +155,9 @@ class KVCacheCoordinator(ABC):
request_id: The request ID.
num_tokens: The total number of tokens that need a slot (including
tokens that are already allocated).
num_tokens_main_model: The number of tokens for the main model (aka target
model in spec decode). w/o spec decode, it is num_tokens;
with spec decode, it is num_tokens - num_lookahead_tokens.
num_encoder_tokens: The number of encoder tokens for allocating
blocks for cross-attention.
@@ -161,6 +170,7 @@ class KVCacheCoordinator(ABC):
num_encoder_tokens
if isinstance(manager, CrossAttentionManager)
else num_tokens,
num_tokens_main_model,
)
for manager in self.single_type_managers
)

View File

@@ -307,8 +307,9 @@ class KVCacheManager:
num_local_computed_tokens + num_external_computed_tokens,
self.max_model_len,
)
num_tokens_main_model = total_computed_tokens + num_new_tokens
num_tokens_need_slot = min(
total_computed_tokens + num_new_tokens + num_lookahead_tokens,
num_tokens_main_model + num_lookahead_tokens,
self.max_model_len,
)
@@ -329,6 +330,7 @@ class KVCacheManager:
num_encoder_tokens=num_encoder_tokens,
total_computed_tokens=num_local_computed_tokens
+ num_external_computed_tokens,
num_tokens_main_model=num_tokens_main_model,
)
if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
@@ -349,7 +351,10 @@ class KVCacheManager:
)
new_blocks = self.coordinator.allocate_new_blocks(
request.request_id, num_tokens_need_slot, num_encoder_tokens
request.request_id,
num_tokens_need_slot,
num_tokens_main_model,
num_encoder_tokens,
)
# P/D: delay caching blocks if we have to recv from

View File

@@ -47,7 +47,7 @@ from vllm.v1.core.sched.output import (
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
from vllm.v1.core.sched.utils import check_stop, remove_all
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
from vllm.v1.metrics.perf import ModelMetrics, PerfStats
from vllm.v1.metrics.stats import (
PrefixCacheStats,
@@ -226,6 +226,17 @@ class Scheduler(SchedulerInterface):
)
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
def has_mamba_layers(kv_cache_config: KVCacheConfig) -> bool:
return any(
isinstance(group_spec.kv_cache_spec, MambaSpec)
for group_spec in kv_cache_config.kv_cache_groups
)
self.has_mamba_layers = has_mamba_layers(kv_cache_config)
self.need_mamba_block_aligned_split = (
self.has_mamba_layers and self.cache_config.mamba_cache_mode == "align"
)
self.perf_metrics: ModelMetrics | None = None
if self.log_stats and vllm_config.observability_config.enable_mfu_metrics:
self.perf_metrics = ModelMetrics(vllm_config)
@@ -250,6 +261,53 @@ class Scheduler(SchedulerInterface):
vllm_config=self.vllm_config,
)
def _mamba_block_aligned_split(
self,
request: Request,
num_new_tokens: int,
num_new_local_computed_tokens: int = 0,
num_external_computed_tokens: int = 0,
) -> int:
assert num_external_computed_tokens == 0, (
"External KV connector is not verified yet"
)
# TODO: needs a check for resumed requests
if request.num_output_tokens == 0: # prefill
# To enable block-aligned caching of the Mamba state, `num_new_tokens`
# must be a multiple of `block_size`.
# As an exception, if `num_new_tokens` is less than `block_size`, the
# state is simply not cached, requiring no special handling.
# Additionally, when Eagle mode is enabled, FullAttn prunes the last
# matching block. To prevent this from causing a Mamba cache miss, the
# last chunk must be larger than `block_size`.
block_size = self.cache_config.block_size
last_cache_position = (
request.num_prompt_tokens - request.num_prompt_tokens % block_size
)
# eagle prune
if self.use_eagle:
last_cache_position = max(last_cache_position - block_size, 0)
num_computed_tokens = (
request.num_computed_tokens
+ num_new_local_computed_tokens
+ num_external_computed_tokens
)
num_computed_tokens_after_sched = num_computed_tokens + num_new_tokens
if num_computed_tokens_after_sched < last_cache_position:
# align to block_size
num_new_tokens = num_new_tokens // block_size * block_size
elif (
num_computed_tokens
< last_cache_position
< num_computed_tokens_after_sched
):
# force to cache the last chunk
num_new_tokens = last_cache_position - num_computed_tokens
else:
# prefill the last few tokens
pass
return num_new_tokens
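A standalone rehearsal of the split (simplified assumptions: no Eagle, no resumed request; all sizes are made up):

block_size, prompt_len, budget = 16, 50, 20
last_cache_position = prompt_len - prompt_len % block_size        # 48

computed, chunks = 0, []
while computed < prompt_len:
    num_new = min(budget, prompt_len - computed)
    after = computed + num_new
    if after < last_cache_position:
        num_new = num_new // block_size * block_size               # align down
    elif computed < last_cache_position < after:
        num_new = last_cache_position - computed                   # end exactly at 48
    chunks.append(num_new)
    computed += num_new

print(chunks)  # [16, 16, 16, 2] -- every chunk but the tail ends on a block boundary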
def schedule(self) -> SchedulerOutput:
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
@@ -340,6 +398,11 @@ class Scheduler(SchedulerInterface):
shift_computed_tokens=1 if self.use_eagle else 0,
)
if self.need_mamba_block_aligned_split:
num_new_tokens = self._mamba_block_aligned_split(
request, num_new_tokens
)
if num_new_tokens == 0:
# The request cannot be scheduled because one of the following
# reasons:
@@ -350,6 +413,8 @@ class Scheduler(SchedulerInterface):
# its max_total_tokens or max_model_len.
# 2. The encoder budget is exhausted.
# 3. The encoder cache is exhausted.
# 4. Insufficient budget for a block-aligned chunk in hybrid
# models with mamba cache mode "align".
# NOTE(woosuk): Here, by doing `continue` instead of `break`,
# we do not strictly follow the FCFS scheduling policy and
# allow the lower-priority requests to be scheduled.
@@ -608,6 +673,16 @@ class Scheduler(SchedulerInterface):
# The request cannot be scheduled.
break
if self.need_mamba_block_aligned_split:
num_new_tokens = self._mamba_block_aligned_split(
request,
num_new_tokens,
num_new_local_computed_tokens,
num_external_computed_tokens,
)
if num_new_tokens == 0:
break
# Handles an edge case when P/D Disaggregation
# is used with Spec Decoding where an
# extra block gets allocated which

View File

@@ -66,12 +66,17 @@ class SingleTypeKVCacheManager(ABC):
self.kv_cache_group_id = kv_cache_group_id
self._null_block = block_pool.null_block
@classmethod
def _get_num_evictable_blocks(cls, blocks: Sequence[KVCacheBlock]):
return sum(blk.ref_cnt == 0 and not blk.is_null for blk in blocks)
def get_num_blocks_to_allocate(
self,
request_id: str,
num_tokens: int,
new_computed_blocks: Sequence[KVCacheBlock],
total_computed_tokens: int,
num_tokens_main_model: int,
) -> int:
"""
Get the number of blocks needed to be allocated for the request.
@@ -84,6 +89,9 @@ class SingleTypeKVCacheManager(ABC):
prefix caching.
total_computed_tokens: Include both local and external computed
tokens.
num_tokens_main_model: The number of tokens for the main model (aka target
model in spec decode). w/o spec decode, it is num_tokens;
with spec decode, it is num_tokens - num_lookahead_tokens.
Returns:
The number of blocks to allocate.
@@ -121,9 +129,8 @@ class SingleTypeKVCacheManager(ABC):
# If a computed block is an eviction candidate (in the free queue and
# ref_cnt == 0), it will be removed from the free queue when touched by
# the allocated request, so we must count it in the free-capacity check.
num_evictable_blocks = sum(
blk.ref_cnt == 0 and not blk.is_null
for blk in new_computed_blocks[num_skipped_new_computed_blocks:]
num_evictable_blocks = self._get_num_evictable_blocks(
new_computed_blocks[num_skipped_new_computed_blocks:]
)
return num_new_blocks + num_evictable_blocks
@@ -199,7 +206,7 @@ class SingleTypeKVCacheManager(ABC):
req_blocks.extend(allocated_blocks)
def allocate_new_blocks(
self, request_id: str, num_tokens: int
self, request_id: str, num_tokens: int, num_tokens_main_model: int
) -> list[KVCacheBlock]:
"""
Allocate new blocks for the request to give it at least `num_tokens`
@@ -209,7 +216,9 @@ class SingleTypeKVCacheManager(ABC):
request_id: The request ID.
num_tokens: The total number of tokens that need a slot (including
tokens that are already allocated).
num_tokens_main_model: The number of tokens for the main model (aka target
model in spec decode). w/o spec decode, it is num_tokens;
with spec decode, it is num_tokens - num_lookahead_tokens.
Returns:
The new allocated blocks.
"""
@@ -450,12 +459,9 @@ class FullAttentionManager(SingleTypeKVCacheManager):
class SlidingWindowManager(SingleTypeKVCacheManager):
def __init__(
self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool, **kwargs
) -> None:
super().__init__(kv_cache_spec, block_pool, **kwargs)
def __init__(self, kv_cache_spec: SlidingWindowSpec, **kwargs) -> None:
super().__init__(kv_cache_spec, **kwargs)
self.sliding_window = kv_cache_spec.sliding_window
self._null_block = block_pool.null_block
@classmethod
def find_longest_cache_hit(
@@ -586,12 +592,9 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
def __init__(
self, kv_cache_spec: ChunkedLocalAttentionSpec, block_pool: BlockPool, **kwargs
) -> None:
super().__init__(kv_cache_spec, block_pool, **kwargs)
def __init__(self, kv_cache_spec: ChunkedLocalAttentionSpec, **kwargs) -> None:
super().__init__(kv_cache_spec, **kwargs)
self.attention_chunk_size = kv_cache_spec.attention_chunk_size
self._null_block = block_pool.null_block
@classmethod
def find_longest_cache_hit(
@@ -739,6 +742,17 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
class MambaManager(SingleTypeKVCacheManager):
def __init__(self, kv_cache_spec: MambaSpec, **kwargs) -> None:
super().__init__(kv_cache_spec, **kwargs)
self.mamba_cache_mode = kv_cache_spec.mamba_cache_mode
self.num_speculative_blocks: int = kv_cache_spec.num_speculative_blocks
if self.mamba_cache_mode == "align":
# Mapping from request ID to the index of the block
# allocated in the previous step
self.last_state_block_idx: dict[str, int] = {}
# The set of the requests that have been allocated blocks
self._allocated_block_reqs: set[str] = set()
@classmethod
def find_longest_cache_hit(
cls,
@@ -787,6 +801,28 @@ class MambaManager(SingleTypeKVCacheManager):
return computed_blocks
def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
assert isinstance(self.kv_cache_spec, MambaSpec)
super().remove_skipped_blocks(request_id, num_computed_tokens)
if self.mamba_cache_mode == "align":
# `last_state_block_idx` refers to the index of the block allocated two steps ago.
# The block allocated in the previous step is the source for copying Mamba states
# into the block allocated in the current step; the block from two steps ago is
# no longer needed and should be freed here.
last_state_block_idx = self.last_state_block_idx.get(request_id)
# Blocks allocated during prefill may be non-contiguous. Use
# `last_state_block_idx` to free the appropriate block and replace it
# with a null block.
if (
last_state_block_idx is not None
and last_state_block_idx
< cdiv(num_computed_tokens, self.block_size) - 1
):
blocks = self.req_to_blocks[request_id]
if blocks[last_state_block_idx] != self._null_block:
self.block_pool.free_blocks([blocks[last_state_block_idx]])
blocks[last_state_block_idx] = self._null_block
def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
"""
cascade attention is not supported by mamba
@@ -799,31 +835,134 @@ class MambaManager(SingleTypeKVCacheManager):
num_tokens: int,
new_computed_blocks: Sequence[KVCacheBlock],
total_computed_tokens: int,
num_tokens_main_model: int,
) -> int:
# Allocate extra `num_speculative_blocks` blocks for
# speculative decoding (MTP/EAGLE) with linear attention.
assert isinstance(self.kv_cache_spec, MambaSpec)
if self.kv_cache_spec.num_speculative_blocks > 0:
num_tokens += (
self.kv_cache_spec.block_size
* self.kv_cache_spec.num_speculative_blocks
if self.mamba_cache_mode != "align":
# Allocate extra `num_speculative_blocks` blocks for
# speculative decoding (MTP/EAGLE) with linear attention.
if self.num_speculative_blocks > 0:
num_tokens += (
self.kv_cache_spec.block_size * self.num_speculative_blocks
)
return super().get_num_blocks_to_allocate(
request_id,
num_tokens,
new_computed_blocks,
total_computed_tokens,
num_tokens_main_model,
)
return super().get_num_blocks_to_allocate(
request_id, num_tokens, new_computed_blocks, total_computed_tokens
)
else:
# We don't allocate blocks for lookahead tokens in align mode, because if
# x * block_size tokens are scheduled, num_tokens is
# x * block_size + num_lookahead_tokens and breaks the alignment.
# We can ignore lookahead tokens because current draft models don't have
# mamba layers.
num_tokens = num_tokens_main_model
num_required_blocks = (
cdiv(num_tokens, self.block_size) + self.num_speculative_blocks
)
num_new_blocks = (
num_required_blocks
- len(new_computed_blocks)
- len(self.req_to_blocks[request_id])
)
if num_new_blocks > 0:
if request_id in self._allocated_block_reqs:
# Old request. Needs at most 1 more block as we can reuse the
# speculative blocks from the previous step.
num_new_blocks = 1
else:
# First prefill. Allocate 1 block for the running state plus the
# speculative blocks.
num_new_blocks = 1 + self.num_speculative_blocks
num_evictable_computed_blocks = self._get_num_evictable_blocks(
new_computed_blocks
)
return num_new_blocks + num_evictable_computed_blocks
def allocate_new_blocks(
self, request_id: str, num_tokens: int
self, request_id: str, num_tokens: int, num_tokens_main_model: int
) -> list[KVCacheBlock]:
# Allocate extra `num_speculative_blocks` blocks for
# speculative decoding (MTP/EAGLE) with linear attention.
assert isinstance(self.kv_cache_spec, MambaSpec)
if self.kv_cache_spec.num_speculative_blocks > 0:
num_tokens += (
self.kv_cache_spec.block_size
* self.kv_cache_spec.num_speculative_blocks
if self.mamba_cache_mode != "align":
# Allocate extra `num_speculative_blocks` blocks for
# speculative decoding (MTP/EAGLE) with linear attention.
if self.num_speculative_blocks > 0:
num_tokens += self.block_size * self.num_speculative_blocks
return super().allocate_new_blocks(
request_id, num_tokens, num_tokens_main_model
)
return super().allocate_new_blocks(request_id, num_tokens)
else:
# We don't allocate blocks for lookahead tokens in align mode, because if
# x * block_size tokens are scheduled, num_tokens is
# x * block_size + num_lookahead_tokens and breaks the alignment.
# We can ignore lookahead tokens because current draft models don't have
# mamba layers.
num_tokens = num_tokens_main_model
req_blocks: list[KVCacheBlock] = self.req_to_blocks[request_id]
num_required_blocks = (
cdiv(num_tokens, self.block_size) + self.num_speculative_blocks
)
if num_required_blocks == len(req_blocks):
return []
else:
assert num_required_blocks > len(req_blocks), (
"num_required_blocks "
f"{num_required_blocks} < len(req_blocks) {len(req_blocks)}"
)
prev_block_len = len(req_blocks)
blocks_allocated = request_id in self._allocated_block_reqs
# Record the last state block
if blocks_allocated:
# We always save the running state at the last
# (1 + num_speculative_blocks) block
self.last_state_block_idx[request_id] = (
prev_block_len - 1 - self.num_speculative_blocks
)
elif prev_block_len > 0:
# When a new request hits the prefix cache, the last block
# saves the hit state.
self.last_state_block_idx[request_id] = prev_block_len - 1
num_skipped_blocks = (
num_required_blocks - self.num_speculative_blocks - 1
)
# Pad the positions before the running-state block with null blocks
if prev_block_len < num_skipped_blocks:
req_blocks.extend(
[
self._null_block
for _ in range(prev_block_len, num_skipped_blocks)
]
)
if blocks_allocated:
# reuse previous speculative blocks in this step
for block_idx in range(
prev_block_len - self.num_speculative_blocks, prev_block_len
):
if block_idx < num_skipped_blocks:
req_blocks.append(req_blocks[block_idx])
req_blocks[block_idx] = self._null_block
else:
break
num_new_blocks = num_required_blocks - len(req_blocks)
if blocks_allocated:
assert num_new_blocks <= 1
else:
assert num_new_blocks <= self.num_speculative_blocks + 1
new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
req_blocks.extend(new_blocks)
self._allocated_block_reqs.add(request_id)
return req_blocks[prev_block_len:]
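Rough picture of how a request's block list evolves under this allocation scheme (assuming num_speculative_blocks = 0 and block_size = 16; simplified, no prefix hit):

# step 1 (16 tokens scheduled):  req_blocks = [B0]            running state written to B0
# step 2 (total 32 tokens):      req_blocks = [B0, B1]        B0's state is copied into B1
# step 3 (total 48 tokens):      req_blocks = [null, B1, B2]  B0 (allocated two steps ago) is
#                                                             freed by remove_skipped_blocks and
#                                                             replaced with the null block
#
# Only the last (1 + num_speculative_blocks) positions ever hold a live running
# state; freed blocks that were hashed as full blocks stay reusable through the
# prefix cache until they are evicted.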
def free(self, request_id: str) -> None:
if self.mamba_cache_mode == "align":
self._allocated_block_reqs.discard(request_id)
self.last_state_block_idx.pop(request_id, None)
super().free(request_id)
def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
"""

View File

@@ -276,6 +276,7 @@ class MambaSpec(KVCacheSpec):
dtypes: tuple[torch.dtype]
page_size_padded: int | None = None
mamba_type: str = "mamba2"
mamba_cache_mode: str = "none"
num_speculative_blocks: int = 0
@property
@@ -290,8 +291,13 @@ class MambaSpec(KVCacheSpec):
return page_size
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
max_model_len = vllm_config.model_config.max_model_len
return cdiv(max_model_len, self.block_size) * self.page_size_bytes
if vllm_config.cache_config.mamba_cache_mode == "all":
max_model_len = vllm_config.model_config.max_model_len
return cdiv(max_model_len, self.block_size) * self.page_size_bytes
elif vllm_config.cache_config.mamba_cache_mode == "align":
return self.page_size_bytes * (2 + self.num_speculative_blocks)
else:
return self.page_size_bytes * (1 + self.num_speculative_blocks)
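Back-of-the-envelope comparison of the per-layer bound above (page size and lengths are illustrative only):

# max_model_len = 4096, block_size = 16, num_speculative_blocks = 0,
# page_size_bytes = 256 KiB:
#   "all":   cdiv(4096, 16) * 256 KiB = 256 * 256 KiB = 64 MiB
#   "align": (2 + 0) * 256 KiB = 512 KiB
#   "none":  (1 + 0) * 256 KiB = 256 KiB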
@dataclass(frozen=True)

View File

@@ -8,6 +8,7 @@ from vllm.distributed import get_dcp_group, get_pcp_group
from vllm.logger import init_logger
from vllm.utils.math_utils import cdiv
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.cp_utils import get_total_cp_world_size
logger = init_logger(__name__)
@@ -261,47 +262,45 @@ class MultiGroupBlockTable:
device: torch.device,
block_sizes: list[int],
kernel_block_sizes: list[int],
num_speculative_tokens: int = 0,
max_num_blocks: list[int] | None = None,
cp_kv_cache_interleave_size: int = 1,
) -> None:
# Note(hc): each dcp rank only store
# (max_model_len//dcp_world_size) tokens in kvcache,
# so the block_size which used for calc max_num_blocks_per_req
# must be multiplied by dcp_world_size.
try:
pcp_world_size = get_pcp_group().world_size
except AssertionError:
# PCP might not be initialized in testing
pcp_world_size = 1
try:
dcp_world_size = get_dcp_group().world_size
except AssertionError:
# DCP might not be initialized in testing
dcp_world_size = 1
if len(kernel_block_sizes) != len(block_sizes):
raise ValueError(
f"kernel_block_sizes length ({len(kernel_block_sizes)}) "
f"must match block_sizes length ({len(block_sizes)})"
)
if max_num_blocks is None:
# Note(hc): each dcp rank only store
# (max_model_len//dcp_world_size) tokens in kvcache,
# so the block_size which used for calc max_num_blocks_per_req
# must be multiplied by dcp_world_size.
total_cp_world_size = get_total_cp_world_size()
max_num_blocks = [
cdiv(max_model_len, block_size * total_cp_world_size)
for block_size in block_sizes
]
total_cp_world_size = dcp_world_size * pcp_world_size
if len(max_num_blocks) != len(block_sizes):
raise ValueError(
f"max_num_blocks length ({len(max_num_blocks)}) "
f"must match block_sizes length ({len(block_sizes)})"
)
self.block_tables = [
BlockTable(
block_size,
max_num_reqs,
max(
cdiv(max_model_len, block_size * total_cp_world_size),
1 + num_speculative_tokens,
),
max_num_blocks_per_req,
max_num_batched_tokens,
pin_memory,
device,
kernel_block_size,
cp_kv_cache_interleave_size,
)
for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes)
for block_size, kernel_block_size, max_num_blocks_per_req in zip(
block_sizes, kernel_block_sizes, max_num_blocks
)
]
def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:

View File

@@ -3,6 +3,7 @@
from typing import TYPE_CHECKING, Any, cast
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.distributed import get_dcp_group, get_pcp_group
if TYPE_CHECKING:
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
@@ -40,3 +41,17 @@ def check_attention_cp_compatibility(vllm_config: VllmConfig) -> None:
f"but the impl {layer_impl.__class__.__name__} "
"does not support PCP."
)
def get_total_cp_world_size():
try:
pcp_world_size = get_pcp_group().world_size
except AssertionError:
# PCP might not be initialized in testing
pcp_world_size = 1
try:
dcp_world_size = get_dcp_group().world_size
except AssertionError:
# DCP might not be initialized in testing
dcp_world_size = 1
return dcp_world_size * pcp_world_size
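Worked example of where this helper is consumed (the sizes are made up):

# In MultiGroupBlockTable above, with dcp_world_size = 2 and pcp_world_size = 1,
# get_total_cp_world_size() returns 2, so max_model_len = 4096 with
# block_size = 16 defaults to cdiv(4096, 16 * 2) = 128 block-table slots
# per request.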

View File

@@ -89,11 +89,11 @@ class InputBatch:
vocab_size: int,
block_sizes: list[int], # The block_size of each kv cache group
kernel_block_sizes: list[int],
max_num_blocks_per_req: list[int] | None = None,
logitsprocs: LogitsProcessors | None = None,
logitsprocs_need_output_token_ids: bool = False,
is_spec_decode: bool = False,
is_pooling_model: bool = False,
cp_kv_cache_interleave_size: int = 1,
):
self.is_pooling_model = is_pooling_model
@@ -146,7 +146,7 @@ class InputBatch:
device=device,
block_sizes=block_sizes,
kernel_block_sizes=kernel_block_sizes,
max_num_blocks=max_num_blocks_per_req,
cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
)

View File

@@ -152,7 +152,11 @@ from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
from vllm.v1.structured_output.utils import apply_grammar_bitmask
from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
from vllm.v1.worker import mamba_utils
from vllm.v1.worker.cp_utils import (
check_attention_cp_compatibility,
get_total_cp_world_size,
)
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.ec_connector_model_runner_mixin import ECConnectorModelRunnerMixin
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@@ -688,6 +692,7 @@ class GPUModelRunner(
# Ephemeral state transferred between execute_model() and sample_tokens().
self.execute_model_state: ExecuteModelState | None = None
self.kv_connector_output: KVConnectorOutput | None = None
self.mamba_state_idx: dict[str, int] = {}
self.layerwise_nvtx_hooks_registered = False
def update_max_model_len(self, max_model_len: int) -> None:
@@ -1075,7 +1080,7 @@ class GPUModelRunner(
self.input_batch.refresh_metadata()
def _update_states_after_model_execute(
self, output_token_ids: torch.Tensor, scheduler_output: "SchedulerOutput"
) -> None:
"""Update the cached states after model execution.
@@ -1111,6 +1116,16 @@ class GPUModelRunner(
)
for i, num_tokens in enumerate(num_accepted_tokens):
self.input_batch.num_accepted_tokens_cpu[i] = num_tokens
if self.cache_config.mamba_cache_mode == "align":
mamba_utils.postprocess_mamba(
scheduler_output,
self.kv_cache_config,
self.input_batch,
self.requests,
self.mamba_state_idx,
self.compilation_config.static_forward_context,
self.model.get_mamba_state_copy_func(),
)
def _init_mrope_positions(self, req_state: CachedRequestState):
model = self.get_model()
@@ -2751,7 +2766,6 @@ class GPUModelRunner(
logits,
sampling_metadata,
)
return sampler_output
def _bookkeeping_sync(
@@ -3237,6 +3251,18 @@ class GPUModelRunner(
pad_attn = cudagraph_mode == CUDAGraphMode.FULL
if self.cache_config.mamba_cache_mode == "align":
mamba_utils.preprocess_mamba(
scheduler_output,
self.kv_cache_config,
self.cache_config,
self.mamba_state_idx,
self.input_batch,
self.requests,
self.compilation_config.static_forward_context,
self.model.get_mamba_state_copy_func(),
)
use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices
@@ -3423,6 +3449,10 @@ class GPUModelRunner(
with record_function_or_nullcontext("gpu_model_runner: sample"):
sampler_output = self._sample(logits, spec_decode_metadata)
self._update_states_after_model_execute(
sampler_output.sampled_token_ids, scheduler_output
)
self._draft_token_ids = None
self._draft_token_req_ids = None
self.input_batch.prev_sampled_token_ids = None
@@ -5322,6 +5352,24 @@ class GPUModelRunner(
for kv_cache_group in kv_cache_config.kv_cache_groups
if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec)
]
max_num_blocks = []
max_model_len = max(self.max_model_len, self.max_encoder_len)
for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec):
continue
max_num_blocks_per_req = cdiv(
max_model_len, block_sizes[i] * get_total_cp_world_size()
)
if isinstance(kv_cache_group.kv_cache_spec, MambaSpec):
mamba_blocks_per_req = (
max_num_blocks_per_req
if self.cache_config.enable_prefix_caching
else 1
) + kv_cache_group.kv_cache_spec.num_speculative_blocks
max_num_blocks_per_req = max(
max_num_blocks_per_req, mamba_blocks_per_req
)
max_num_blocks.append(max_num_blocks_per_req)
if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [
self.cache_config.block_size
@@ -5333,18 +5381,18 @@ class GPUModelRunner(
)
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=max_model_len,
max_num_batched_tokens=self.max_num_tokens,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
block_sizes=block_sizes,
kernel_block_sizes=kernel_block_sizes,
max_num_blocks_per_req=max_num_blocks,
is_spec_decode=bool(self.vllm_config.speculative_config),
logitsprocs=self.input_batch.logitsprocs,
logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids,
is_pooling_model=self.is_pooling_model,
)
def _allocate_kv_cache_tensors(

View File

@@ -0,0 +1,232 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from typing import Any
import torch
import triton
import triton.language as tl
from vllm.config import CacheConfig
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
from vllm.v1.worker.gpu_input_batch import CachedRequestState
from vllm.v1.worker.lora_model_runner_mixin import GPUInputBatch
@triton.jit
def batch_memcpy_kernel(src_ptrs, dst_ptrs, sizes, BLOCK_SIZE: tl.constexpr):
    # One program per copy: program `pid` copies `sizes[pid]` bytes from the raw
    # address `src_ptrs[pid]` to the raw address `dst_ptrs[pid]`, BLOCK_SIZE
    # bytes at a time.
    pid = tl.program_id(0)
    src_ptr = tl.load(src_ptrs + pid)
    dst_ptr = tl.load(dst_ptrs + pid)
    size = tl.load(sizes + pid)
    offsets = tl.arange(0, BLOCK_SIZE)
    for i in range(0, size, BLOCK_SIZE):
        mask = (i + offsets) < size
        curr_src_ptr = (src_ptr + i + offsets).to(tl.pointer_type(tl.uint8))
        curr_dst_ptr = (dst_ptr + i + offsets).to(tl.pointer_type(tl.uint8))
        data = tl.load(curr_src_ptr, mask=mask)
        tl.store(curr_dst_ptr, data, mask=mask)


def batch_memcpy(src_ptrs, dst_ptrs, sizes):
    # Launch one program per (src, dst, size) triple. `src_ptrs` and `dst_ptrs`
    # are int64 device addresses and `sizes` is the number of bytes to copy.
    batch = src_ptrs.shape[0]
    assert dst_ptrs.shape[0] == batch
    assert sizes.shape[0] == batch
    grid = (batch,)
    BLOCK_SIZE = 1024
    batch_memcpy_kernel[grid](src_ptrs, dst_ptrs, sizes, BLOCK_SIZE=BLOCK_SIZE)
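# A minimal usage sketch for batch_memcpy, assuming two contiguous CUDA tensors.
# It mirrors how do_mamba_copy_block below builds its arguments: int64 device
# addresses from data_ptr() and per-copy sizes in bytes as int32. The helper
# name _batch_memcpy_example is illustrative only.
def _batch_memcpy_example():
    src_a = torch.randn(16, 64, device="cuda")
    src_b = torch.randn(8, 128, device="cuda")
    dst_a = torch.empty_like(src_a)
    dst_b = torch.empty_like(src_b)
    src_ptrs = torch.tensor(
        [src_a.data_ptr(), src_b.data_ptr()], device="cuda", dtype=torch.int64
    )
    dst_ptrs = torch.tensor(
        [dst_a.data_ptr(), dst_b.data_ptr()], device="cuda", dtype=torch.int64
    )
    sizes = torch.tensor(
        [src_a.numel() * src_a.element_size(), src_b.numel() * src_b.element_size()],
        device="cuda",
        dtype=torch.int32,
    )
    batch_memcpy(src_ptrs, dst_ptrs, sizes)
    # Each destination now holds a byte-for-byte copy of its source.
    assert torch.equal(src_a, dst_a) and torch.equal(src_b, dst_b)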
def get_mamba_groups(kv_cache_config: KVCacheConfig) -> tuple[list[int], MambaSpec]:
mamba_group_ids: list[int] = []
mamba_specs: list[MambaSpec] = []
for i in range(len(kv_cache_config.kv_cache_groups)):
kv_cache_spec = kv_cache_config.kv_cache_groups[i].kv_cache_spec
if isinstance(kv_cache_spec, MambaSpec):
mamba_group_ids.append(i)
mamba_specs.append(kv_cache_spec)
assert len(mamba_group_ids) > 0, "no mamba layers in the model"
assert all(mamba_specs[0] == spec for spec in mamba_specs)
return mamba_group_ids, mamba_specs[0]
def collect_mamba_copy_meta(
src_state_list: list[int],
dest_state_list: list[int],
num_elements_list: list[int],
kv_cache_config: KVCacheConfig,
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
mamba_group_ids: list[int],
src_block_idx: int,
dest_block_idx: int,
accept_token_bias: int,
req_state: CachedRequestState,
forward_context: dict[str, Any],
):
if src_block_idx == dest_block_idx and accept_token_bias == 0:
return
for mamba_group_id in mamba_group_ids:
block_ids = req_state.block_ids[mamba_group_id]
dest_block_id = block_ids[dest_block_idx]
layer_names = kv_cache_config.kv_cache_groups[mamba_group_id].layer_names
for layer_name in layer_names:
attention = forward_context[layer_name]
kv_caches: list[torch.Tensor] = attention.kv_cache[0]
for state, state_copy_func in zip(kv_caches, mamba_state_copy_funcs):
copy_spec = state_copy_func(
state, block_ids, src_block_idx, accept_token_bias + 1
)
src_state_list.append(copy_spec.start_addr)
dest_state_list.append(state[dest_block_id].data_ptr())
num_elements_list.append(copy_spec.num_elements * state.element_size())
def do_mamba_copy_block(
src_state_list: list[int],
dest_state_list: list[int],
num_elements_list: list[int],
):
if len(src_state_list) == 0:
return
assert len(src_state_list) == len(dest_state_list)
assert len(src_state_list) == len(num_elements_list)
src_state_ptrs = torch.tensor(src_state_list, device="cuda", dtype=torch.int64)
dst_state_ptrs = torch.tensor(dest_state_list, device="cuda", dtype=torch.int64)
num_elements = torch.tensor(num_elements_list, device="cuda", dtype=torch.int32)
batch_memcpy(src_state_ptrs, dst_state_ptrs, num_elements)
def preprocess_mamba(
scheduler_output: SchedulerOutput,
kv_cache_config: KVCacheConfig,
cache_config: CacheConfig,
mamba_state_idx: dict[str, int],
input_batch: GPUInputBatch,
requests: dict[str, CachedRequestState],
forward_context: dict[str, Any],
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
):
"""
Copy the mamba state of previous step to the last
(1 + num_speculative_blocks) block.
"""
mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config)
num_speculative_blocks = mamba_spec.num_speculative_blocks
# TODO(Chen): we need to optimize this function a lot
assert cache_config.enable_prefix_caching
block_size = mamba_spec.block_size
finished_req_ids = scheduler_output.finished_req_ids
preempted_req_ids = scheduler_output.preempted_req_ids or set()
for req_id in itertools.chain(finished_req_ids, preempted_req_ids):
mamba_state_idx.pop(req_id, None)
src_state_list: list[int] = []
dest_state_list: list[int] = []
num_elements_list: list[int] = []
for i, req_id in enumerate(input_batch.req_ids):
req_state = requests[req_id]
prev_state_idx = mamba_state_idx.get(req_id)
if prev_state_idx is None:
# new / resumed request, no previous state
# if num_computed_tokens is 0, prev_state_idx will be -1
prev_state_idx = (req_state.num_computed_tokens - 1) // block_size
num_blocks = len(req_state.block_ids[mamba_group_ids[0]])
        # We always save the current running state in the
        # (1 + num_speculative_blocks)-th block from the end.
        # A corner case worth mentioning here: assume we have block_size = 4 and
# num_speculative_tokens = 2. The request is [A, B, C] and contains 2 draft
# tokens [draft 1, draft 2]. Then we will have:
# Block 0: [A, B, C, draft 1]
# Block 1: [draft 2, TOFILL, TOFILL, TOFILL]
# Block 2: speculative block
# Block 3: speculative block
# And use block 1 to save the running state.
curr_state_idx = num_blocks - 1 - num_speculative_blocks
mamba_state_idx[req_id] = curr_state_idx
if prev_state_idx != -1 and prev_state_idx != curr_state_idx:
collect_mamba_copy_meta(
src_state_list,
dest_state_list,
num_elements_list,
kv_cache_config,
mamba_state_copy_funcs,
mamba_group_ids,
prev_state_idx,
curr_state_idx,
input_batch.num_accepted_tokens_cpu[i] - 1,
req_state,
forward_context,
)
input_batch.num_accepted_tokens_cpu[i] = 1
do_mamba_copy_block(src_state_list, dest_state_list, num_elements_list)
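# A worked instance of the index arithmetic above, with assumed numbers:
# block_size = 4, num_speculative_blocks = 0, and a request resuming from the
# prefix cache with 4 computed tokens while holding 3 mamba blocks.
#   prev_state_idx = (4 - 1) // 4 = 0  -> cached state lives in block 0
#   curr_state_idx = 3 - 1 - 0    = 2  -> running state lives in the last block
# prev_state_idx is neither -1 nor equal to curr_state_idx, so the state in
# block 0 is copied into block 2 before this step's mamba kernels run.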
def postprocess_mamba(
scheduler_output: SchedulerOutput,
kv_cache_config: KVCacheConfig,
input_batch: GPUInputBatch,
requests: dict[str, CachedRequestState],
mamba_state_idx: dict[str, int],
forward_context: dict[str, Any],
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
):
"""
If a blocks is converted from partial block to full block in this step, copy the
state from the block for running state to the new full block.
"""
num_scheduled_tokens_dict = scheduler_output.num_scheduled_tokens
scheduled_spec_decode_tokens_dict = scheduler_output.scheduled_spec_decode_tokens
num_accepted_tokens_cpu = input_batch.num_accepted_tokens_cpu
# NOTE: can be optimized as this function always returns the same result
mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config)
src_state_list: list[int] = []
dest_state_list: list[int] = []
num_elements_list: list[int] = []
for i, req_id in enumerate(input_batch.req_ids):
req_state = requests[req_id]
num_computed_tokens = req_state.num_computed_tokens
num_draft_tokens = len(scheduled_spec_decode_tokens_dict.get(req_id, []))
num_scheduled_tokens = num_scheduled_tokens_dict[req_id]
num_accepted_tokens = num_accepted_tokens_cpu[i]
num_tokens_running_state = (
num_computed_tokens + num_scheduled_tokens - num_draft_tokens
)
new_num_computed_tokens = num_tokens_running_state + num_accepted_tokens - 1
aligned_new_computed_tokens = (
new_num_computed_tokens // mamba_spec.block_size * mamba_spec.block_size
)
        # TODO: how to ensure that all blocks passed to cache_blocks are cached here?
if aligned_new_computed_tokens >= num_tokens_running_state:
accept_token_bias = aligned_new_computed_tokens - num_tokens_running_state
src_block_idx = mamba_state_idx[req_id]
dest_block_idx = aligned_new_computed_tokens // mamba_spec.block_size - 1
collect_mamba_copy_meta(
src_state_list,
dest_state_list,
num_elements_list,
kv_cache_config,
mamba_state_copy_funcs,
mamba_group_ids,
src_block_idx,
dest_block_idx,
accept_token_bias,
req_state,
forward_context,
)
if src_block_idx == dest_block_idx:
num_accepted_tokens_cpu[i] = 1
do_mamba_copy_block(src_state_list, dest_state_list, num_elements_list)
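# A worked instance of the alignment arithmetic above, with assumed numbers:
# mamba block_size = 16, a request with num_computed_tokens = 30 is scheduled
# for 3 tokens (2 of them draft tokens) and 2 tokens end up accepted.
#   num_tokens_running_state    = 30 + 3 - 2 = 31
#   new_num_computed_tokens     = 31 + 2 - 1 = 32
#   aligned_new_computed_tokens = 32 // 16 * 16 = 32 >= 31
# Block 32 // 16 - 1 = 1 has just become a full block, so its state is copied
# from the running-state block with accept_token_bias = 32 - 31 = 1.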