[V1][Hybrid] Mamba Prefix Caching with align mode (#30877)
Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com> Signed-off-by: Chen Zhang <zhangch99@outlook.com> Co-authored-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
@@ -570,7 +570,7 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
|
||||
assert self.cache_config is not None
|
||||
mamba_block_size = self.cache_config.mamba_block_size
|
||||
prefix_caching_enabled = self.cache_config.enable_prefix_caching
|
||||
is_mamba_cache_all = self.cache_config.mamba_cache_mode == "all"
|
||||
if attn_metadata is not None:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
@@ -622,7 +622,7 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
dim=0,
|
||||
)
|
||||
|
||||
if prefix_caching_enabled:
|
||||
if is_mamba_cache_all:
|
||||
# If prefix caching is enabled, retrieve the relevant variables
|
||||
# for prefill and decode
|
||||
block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
|
||||
@@ -701,7 +701,7 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
initial_states = None
|
||||
if has_initial_states_p is not None and prep_initial_states:
|
||||
kernel_ssm_indices = state_indices_tensor_p
|
||||
if prefix_caching_enabled:
|
||||
if is_mamba_cache_all:
|
||||
kernel_ssm_indices = state_indices_tensor_p.gather(
|
||||
1, block_idx_last_computed_token_p.unsqueeze(1)
|
||||
).squeeze(1)
|
||||
@@ -729,14 +729,14 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
cu_chunk_seqlens=cu_chunk_seqlen_p,
|
||||
last_chunk_indices=last_chunk_indices_p,
|
||||
initial_states=initial_states,
|
||||
return_intermediate_states=prefix_caching_enabled,
|
||||
return_intermediate_states=is_mamba_cache_all,
|
||||
dt_softplus=True,
|
||||
dt_limit=(0.0, float("inf")),
|
||||
out=preallocated_ssm_out_p.view(num_prefill_tokens, -1, self.head_dim),
|
||||
state_dtype=ssm_state.dtype,
|
||||
)
|
||||
|
||||
if prefix_caching_enabled:
|
||||
if is_mamba_cache_all:
|
||||
# The chunk_stride is the number of chunks per mamba block
|
||||
# e.g., if mamba_block_size = 512 and chunk_size = 256,
|
||||
# then chunk_stride = 2
|
||||
@@ -815,7 +815,7 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
|
||||
# Process decode requests
|
||||
if has_decode:
|
||||
if prefix_caching_enabled:
|
||||
if is_mamba_cache_all:
|
||||
state_indices_tensor_d_input = state_indices_tensor_d.gather(
|
||||
1, block_idx_last_computed_token_d.unsqueeze(1)
|
||||
).squeeze(1)
|
||||
|
||||
Reference in New Issue
Block a user