[Mamba1] - Kernel Level Chunk Alignment for Prefix Caching (#34798)

Signed-off-by: Josephasafg <ajgard7@gmail.com>
This commit is contained in:
Asaf Gardin
2026-03-01 14:40:23 +02:00
committed by GitHub
parent da543d1abe
commit bbf81f9a92
11 changed files with 251 additions and 146 deletions

View File

@@ -271,6 +271,8 @@ class MambaMixer(MambaBase, PluggableLayer):
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
has_initial_states_p = attn_metadata.has_initial_states_p
cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p
last_chunk_indices_p = attn_metadata.last_chunk_indices_p
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
@@ -376,6 +378,8 @@ class MambaMixer(MambaBase, PluggableLayer):
block_idx_first_scheduled_token=block_idx_first_scheduled_token_p,
block_idx_last_scheduled_token=block_idx_last_scheduled_token_p,
initial_state_idx=block_idx_last_computed_token_p,
cu_chunk_seqlen=cu_chunk_seqlen_p,
last_chunk_indices=last_chunk_indices_p,
)
ssm_outputs.append(scan_out_p)