[V1][Hybrid] Mamba Prefix Caching with align mode (#30877)
Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
@@ -56,6 +56,7 @@ class MambaBase(AttentionLayerBase):
|
||||
block_size=mamba_block_size,
|
||||
page_size_padded=page_size_padded,
|
||||
mamba_type=self.mamba_type,
|
||||
mamba_cache_mode=vllm_config.cache_config.mamba_cache_mode,
|
||||
num_speculative_blocks=(
|
||||
vllm_config.speculative_config.num_speculative_tokens
|
||||
if vllm_config.speculative_config
|
||||
|
||||
@@ -255,7 +255,7 @@ class MambaMixer(MambaBase, CustomOp):
|
||||
|
||||
assert self.cache_config is not None
|
||||
mamba_block_size = self.cache_config.mamba_block_size
|
||||
prefix_caching_enabled = self.cache_config.enable_prefix_caching
|
||||
is_mamba_cache_all = self.cache_config.mamba_cache_mode == "all"
|
||||
|
||||
if attn_metadata is not None:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
@@ -304,7 +304,7 @@ class MambaMixer(MambaBase, CustomOp):
|
||||
state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p
|
||||
state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d
|
||||
|
||||
if prefix_caching_enabled:
|
||||
if is_mamba_cache_all:
|
||||
block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
|
||||
torch.split(
|
||||
attn_metadata.block_idx_last_computed_token,
|
||||
@@ -380,7 +380,7 @@ class MambaMixer(MambaBase, CustomOp):
|
||||
ssm_outputs.append(scan_out_p)
|
||||
|
||||
if has_decode:
|
||||
if prefix_caching_enabled:
|
||||
if is_mamba_cache_all:
|
||||
state_indices_tensor_d_input = state_indices_tensor_d.gather(
|
||||
1, block_idx_last_computed_token_d.unsqueeze(1)
|
||||
).squeeze(1)
|
||||
|
||||
@@ -570,7 +570,7 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
|
||||
assert self.cache_config is not None
|
||||
mamba_block_size = self.cache_config.mamba_block_size
|
||||
prefix_caching_enabled = self.cache_config.enable_prefix_caching
|
||||
is_mamba_cache_all = self.cache_config.mamba_cache_mode == "all"
|
||||
if attn_metadata is not None:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
@@ -622,7 +622,7 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
dim=0,
|
||||
)
|
||||
|
||||
if prefix_caching_enabled:
|
||||
if is_mamba_cache_all:
|
||||
# If prefix caching is enabled, retrieve the relevant variables
|
||||
# for prefill and decode
|
||||
block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
|
||||
@@ -701,7 +701,7 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
initial_states = None
|
||||
if has_initial_states_p is not None and prep_initial_states:
|
||||
kernel_ssm_indices = state_indices_tensor_p
|
||||
if prefix_caching_enabled:
|
||||
if is_mamba_cache_all:
|
||||
kernel_ssm_indices = state_indices_tensor_p.gather(
|
||||
1, block_idx_last_computed_token_p.unsqueeze(1)
|
||||
).squeeze(1)
|
||||
@@ -729,14 +729,14 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
cu_chunk_seqlens=cu_chunk_seqlen_p,
|
||||
last_chunk_indices=last_chunk_indices_p,
|
||||
initial_states=initial_states,
|
||||
return_intermediate_states=prefix_caching_enabled,
|
||||
return_intermediate_states=is_mamba_cache_all,
|
||||
dt_softplus=True,
|
||||
dt_limit=(0.0, float("inf")),
|
||||
out=preallocated_ssm_out_p.view(num_prefill_tokens, -1, self.head_dim),
|
||||
state_dtype=ssm_state.dtype,
|
||||
)
|
||||
|
||||
if prefix_caching_enabled:
|
||||
if is_mamba_cache_all:
|
||||
# The chunk_stride is the number of chunks per mamba block
|
||||
# e.g., if mamba_block_size = 512 and chunk_size = 256,
|
||||
# then chunk_stride = 2
|
||||
@@ -815,7 +815,7 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
|
||||
# Process decode requests
|
||||
if has_decode:
|
||||
if prefix_caching_enabled:
|
||||
if is_mamba_cache_all:
|
||||
state_indices_tensor_d_input = state_indices_tensor_d.gather(
|
||||
1, block_idx_last_computed_token_d.unsqueeze(1)
|
||||
).squeeze(1)
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import TypeAlias
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config.cache import MambaDType
|
||||
@@ -223,3 +227,94 @@ class MambaStateShapeCalculator:
|
||||
conv_state_k_shape,
|
||||
recurrent_state_shape,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class MambaCopySpec:
    """Parameters of one raw memory copy of a Mamba state.

    Used by prefix caching in align mode: a copy is fully described by where
    the source data starts and how many elements it spans.

    Attributes:
        start_addr: Source address (as returned by ``Tensor.data_ptr()``)
            where the copy begins.
        num_elements: Count of elements to copy starting at ``start_addr``.
    """

    start_addr: int
    num_elements: int
|
||||
|
||||
|
||||
MambaStateCopyFunc: TypeAlias = Callable[[torch.Tensor, list[int], int, int], MambaCopySpec]
"""
Signature of a helper that derives a MambaCopySpec for one state slice.

Arguments, in order:
    state: torch.Tensor - the Mamba state tensor (e.g., conv or temporal states).
    block_ids: list[int] - the list of block indices for the state to copy.
    cur_block_idx: int - current block index within `block_ids` to copy from.
    num_accepted_tokens: int - number of accepted tokens used to compute the copy offset.
        Range: 1 .. 1 + num_speculative_tokens (inclusive).
"""
|
||||
|
||||
|
||||
def get_conv_copy_spec(
    state: torch.Tensor,
    block_ids: list[int],
    cur_block_idx: int,
    num_accepted_tokens: int,
) -> MambaCopySpec:
    """Compute the MambaCopySpec for a convolutional state slice.

    Selects the block at ``block_ids[cur_block_idx]`` and copies from row
    ``num_accepted_tokens - 1`` through the end of that block.
    """
    block_view = state[block_ids[cur_block_idx]]
    tail = block_view[num_accepted_tokens - 1 :]
    return MambaCopySpec(start_addr=tail.data_ptr(), num_elements=tail.numel())
|
||||
|
||||
|
||||
def get_temporal_copy_spec(
    state: torch.Tensor,
    block_ids: list[int],
    cur_block_idx: int,
    num_accepted_tokens: int,
) -> MambaCopySpec:
    """Compute the MambaCopySpec for a temporal state slice.

    Unlike the conv variant, the token offset moves the *block* selection
    (index ``cur_block_idx + num_accepted_tokens - 1`` into ``block_ids``)
    and the whole selected block is copied.
    """
    selected = state[block_ids[cur_block_idx + num_accepted_tokens - 1]]
    return MambaCopySpec(
        start_addr=selected.data_ptr(),
        num_elements=selected.numel(),
    )
|
||||
|
||||
|
||||
# A "full" state is copied exactly like a temporal state: the entire block
# selected by the last accepted token, with no intra-block row offset.
get_full_copy_spec = get_temporal_copy_spec
|
||||
|
||||
|
||||
class MambaStateCopyFuncCalculator:
    """Registry mapping each mamba-like layer type to its state-copy functions.

    Each classmethod returns a tuple of copy-spec builders, one per cached
    state tensor of that layer type.

    NOTE(review): the tuple ordering presumably mirrors the layer's state
    tensor ordering (conv states before the temporal/recurrent state) —
    confirm against the matching MambaStateShapeCalculator entries.
    """

    @classmethod
    def linear_attention_state_copy_func(cls):
        # Only a temporal (recurrent) state.
        return (get_temporal_copy_spec,)

    @classmethod
    def mamba1_state_copy_func(cls):
        return (get_conv_copy_spec, get_temporal_copy_spec)

    @classmethod
    def mamba2_state_copy_func(cls):
        # Parenthesized for consistency with the sibling methods
        # (the bare `a, b` form built the identical tuple).
        return (get_conv_copy_spec, get_temporal_copy_spec)

    @classmethod
    def short_conv_state_copy_func(cls):
        # Only a conv state.
        return (get_conv_copy_spec,)

    @classmethod
    def gated_delta_net_state_copy_func(cls):
        return (get_conv_copy_spec, get_temporal_copy_spec)

    @classmethod
    def kda_state_copy_func(cls):
        # Three conv states followed by one temporal state.
        return (
            get_conv_copy_spec,
            get_conv_copy_spec,
            get_conv_copy_spec,
            get_temporal_copy_spec,
        )
|
||||
|
||||
Reference in New Issue
Block a user