[V1][Hybrid] Mamba Prefix Caching with align mode (#30877)

Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Harry Huang
2026-01-24 01:56:48 +08:00
committed by GitHub
parent fec9da0af4
commit 5206e5e28c
42 changed files with 1774 additions and 128 deletions

View File

@@ -255,7 +255,7 @@ class MambaMixer(MambaBase, CustomOp):
assert self.cache_config is not None
mamba_block_size = self.cache_config.mamba_block_size
prefix_caching_enabled = self.cache_config.enable_prefix_caching
is_mamba_cache_all = self.cache_config.mamba_cache_mode == "all"
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
@@ -304,7 +304,7 @@ class MambaMixer(MambaBase, CustomOp):
state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p
state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d
if prefix_caching_enabled:
if is_mamba_cache_all:
block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
torch.split(
attn_metadata.block_idx_last_computed_token,
@@ -380,7 +380,7 @@ class MambaMixer(MambaBase, CustomOp):
ssm_outputs.append(scan_out_p)
if has_decode:
if prefix_caching_enabled:
if is_mamba_cache_all:
state_indices_tensor_d_input = state_indices_tensor_d.gather(
1, block_idx_last_computed_token_d.unsqueeze(1)
).squeeze(1)