[V1][Hybrid] Mamba Prefix Caching with align mode (#30877)
Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com> Signed-off-by: Chen Zhang <zhangch99@outlook.com> Co-authored-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
@@ -255,7 +255,7 @@ class MambaMixer(MambaBase, CustomOp):
|
||||
|
||||
assert self.cache_config is not None
|
||||
mamba_block_size = self.cache_config.mamba_block_size
|
||||
prefix_caching_enabled = self.cache_config.enable_prefix_caching
|
||||
is_mamba_cache_all = self.cache_config.mamba_cache_mode == "all"
|
||||
|
||||
if attn_metadata is not None:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
@@ -304,7 +304,7 @@ class MambaMixer(MambaBase, CustomOp):
|
||||
state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p
|
||||
state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d
|
||||
|
||||
if prefix_caching_enabled:
|
||||
if is_mamba_cache_all:
|
||||
block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
|
||||
torch.split(
|
||||
attn_metadata.block_idx_last_computed_token,
|
||||
@@ -380,7 +380,7 @@ class MambaMixer(MambaBase, CustomOp):
|
||||
ssm_outputs.append(scan_out_p)
|
||||
|
||||
if has_decode:
|
||||
if prefix_caching_enabled:
|
||||
if is_mamba_cache_all:
|
||||
state_indices_tensor_d_input = state_indices_tensor_d.gather(
|
||||
1, block_idx_last_computed_token_d.unsqueeze(1)
|
||||
).squeeze(1)
|
||||
|
||||
Reference in New Issue
Block a user