[V1][Hybrid] Mamba Prefix Caching with align mode (#30877)

Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Harry Huang
2026-01-24 01:56:48 +08:00
committed by GitHub
parent fec9da0af4
commit 5206e5e28c
42 changed files with 1774 additions and 128 deletions

View File

@@ -56,6 +56,7 @@ class MambaBase(AttentionLayerBase):
block_size=mamba_block_size,
page_size_padded=page_size_padded,
mamba_type=self.mamba_type,
mamba_cache_mode=vllm_config.cache_config.mamba_cache_mode,
num_speculative_blocks=(
vllm_config.speculative_config.num_speculative_tokens
if vllm_config.speculative_config

View File

@@ -255,7 +255,7 @@ class MambaMixer(MambaBase, CustomOp):
assert self.cache_config is not None
mamba_block_size = self.cache_config.mamba_block_size
prefix_caching_enabled = self.cache_config.enable_prefix_caching
is_mamba_cache_all = self.cache_config.mamba_cache_mode == "all"
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
@@ -304,7 +304,7 @@ class MambaMixer(MambaBase, CustomOp):
state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p
state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d
if prefix_caching_enabled:
if is_mamba_cache_all:
block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
torch.split(
attn_metadata.block_idx_last_computed_token,
@@ -380,7 +380,7 @@ class MambaMixer(MambaBase, CustomOp):
ssm_outputs.append(scan_out_p)
if has_decode:
if prefix_caching_enabled:
if is_mamba_cache_all:
state_indices_tensor_d_input = state_indices_tensor_d.gather(
1, block_idx_last_computed_token_d.unsqueeze(1)
).squeeze(1)

View File

@@ -570,7 +570,7 @@ class MambaMixer2(MambaBase, CustomOp):
assert self.cache_config is not None
mamba_block_size = self.cache_config.mamba_block_size
prefix_caching_enabled = self.cache_config.enable_prefix_caching
is_mamba_cache_all = self.cache_config.mamba_cache_mode == "all"
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
@@ -622,7 +622,7 @@ class MambaMixer2(MambaBase, CustomOp):
dim=0,
)
if prefix_caching_enabled:
if is_mamba_cache_all:
# If prefix caching is enabled, retrieve the relevant variables
# for prefill and decode
block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
@@ -701,7 +701,7 @@ class MambaMixer2(MambaBase, CustomOp):
initial_states = None
if has_initial_states_p is not None and prep_initial_states:
kernel_ssm_indices = state_indices_tensor_p
if prefix_caching_enabled:
if is_mamba_cache_all:
kernel_ssm_indices = state_indices_tensor_p.gather(
1, block_idx_last_computed_token_p.unsqueeze(1)
).squeeze(1)
@@ -729,14 +729,14 @@ class MambaMixer2(MambaBase, CustomOp):
cu_chunk_seqlens=cu_chunk_seqlen_p,
last_chunk_indices=last_chunk_indices_p,
initial_states=initial_states,
return_intermediate_states=prefix_caching_enabled,
return_intermediate_states=is_mamba_cache_all,
dt_softplus=True,
dt_limit=(0.0, float("inf")),
out=preallocated_ssm_out_p.view(num_prefill_tokens, -1, self.head_dim),
state_dtype=ssm_state.dtype,
)
if prefix_caching_enabled:
if is_mamba_cache_all:
# The chunk_stride is the number of chunks per mamba block
# e.g., if mamba_block_size = 512 and chunk_size = 256,
# then chunk_stride = 2
@@ -815,7 +815,7 @@ class MambaMixer2(MambaBase, CustomOp):
# Process decode requests
if has_decode:
if prefix_caching_enabled:
if is_mamba_cache_all:
state_indices_tensor_d_input = state_indices_tensor_d.gather(
1, block_idx_last_computed_token_d.unsqueeze(1)
).squeeze(1)

View File

@@ -1,6 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from dataclasses import dataclass
from typing import TypeAlias
import torch
from vllm.config.cache import MambaDType
@@ -223,3 +227,94 @@ class MambaStateShapeCalculator:
conv_state_k_shape,
recurrent_state_shape,
)
@dataclass
class MambaCopySpec:
"""
Data class specifying the memory-copy parameters for Mamba states used for
prefix caching in align mode.
Attributes:
start_addr (int): Starting address for the memory copy operation.
num_elements (int): Number of elements to copy from the starting address.
"""
start_addr: int
num_elements: int
MambaStateCopyFunc: TypeAlias = Callable[
[torch.Tensor, list[int], int, int], MambaCopySpec
]
"""
Type alias for a function that computes a MambaCopySpec for copying state slices.
Parameters:
state: torch.Tensor - the Mamba state tensor (e.g., conv or temporal states).
block_ids: list[int] - the list of block indices for the state to copy.
cur_block_idx: int - current block index within `block_ids` to copy from.
num_accepted_tokens: int - number of accepted tokens used to compute the copy offset.
Range: 1 .. 1 + num_speculative_tokens (inclusive).
"""
def get_conv_copy_spec(
state: torch.Tensor,
block_ids: list[int],
cur_block_idx: int,
num_accepted_tokens: int,
) -> MambaCopySpec:
"""Return a MambaCopySpec for copying a convolutional state slice."""
src_block_id = block_ids[cur_block_idx]
src_state = state[src_block_id, num_accepted_tokens - 1 :]
return MambaCopySpec(
start_addr=src_state.data_ptr(), num_elements=src_state.numel()
)
def get_temporal_copy_spec(
state: torch.Tensor,
block_ids: list[int],
cur_block_idx: int,
num_accepted_tokens: int,
) -> MambaCopySpec:
"""Return a MambaCopySpec for copying a temporal state slice."""
src_block_id = block_ids[cur_block_idx + num_accepted_tokens - 1]
src_state = state[src_block_id]
return MambaCopySpec(
start_addr=src_state.data_ptr(), num_elements=src_state.numel()
)
get_full_copy_spec = get_temporal_copy_spec
class MambaStateCopyFuncCalculator:
@classmethod
def linear_attention_state_copy_func(cls):
return (get_temporal_copy_spec,)
@classmethod
def mamba1_state_copy_func(cls):
return (get_conv_copy_spec, get_temporal_copy_spec)
@classmethod
def mamba2_state_copy_func(cls):
return get_conv_copy_spec, get_temporal_copy_spec
@classmethod
def short_conv_state_copy_func(cls):
return (get_conv_copy_spec,)
@classmethod
def gated_delta_net_state_copy_func(cls):
return (get_conv_copy_spec, get_temporal_copy_spec)
@classmethod
def kda_state_copy_func(cls):
return (
get_conv_copy_spec,
get_conv_copy_spec,
get_conv_copy_spec,
get_temporal_copy_spec,
)