[Mamba1] - Kernel Level Chunk Alignment for Prefix Caching (#34798)
Signed-off-by: Josephasafg <ajgard7@gmail.com>
This commit is contained in:
@@ -271,6 +271,8 @@ class MambaMixer(MambaBase, PluggableLayer):
|
||||
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||
ssm_state = self_kv_cache[1]
|
||||
has_initial_states_p = attn_metadata.has_initial_states_p
|
||||
cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p
|
||||
last_chunk_indices_p = attn_metadata.last_chunk_indices_p
|
||||
|
||||
# 1. Gated MLP's linear projection
|
||||
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
|
||||
@@ -376,6 +378,8 @@ class MambaMixer(MambaBase, PluggableLayer):
|
||||
block_idx_first_scheduled_token=block_idx_first_scheduled_token_p,
|
||||
block_idx_last_scheduled_token=block_idx_last_scheduled_token_p,
|
||||
initial_state_idx=block_idx_last_computed_token_p,
|
||||
cu_chunk_seqlen=cu_chunk_seqlen_p,
|
||||
last_chunk_indices=last_chunk_indices_p,
|
||||
)
|
||||
ssm_outputs.append(scan_out_p)
|
||||
|
||||
|
||||
@@ -497,6 +497,8 @@ def selective_scan_fn(
|
||||
block_idx_first_scheduled_token=None,
|
||||
block_idx_last_scheduled_token=None,
|
||||
initial_state_idx=None,
|
||||
cu_chunk_seqlen=None,
|
||||
last_chunk_indices=None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
u: (dim, total_length) for varlen or (batch, dim, seqlen)
|
||||
@@ -588,6 +590,8 @@ def selective_scan_fn(
|
||||
block_idx_first_scheduled_token,
|
||||
block_idx_last_scheduled_token,
|
||||
initial_state_idx,
|
||||
cu_chunk_seqlen,
|
||||
last_chunk_indices,
|
||||
)
|
||||
|
||||
if z is None:
|
||||
|
||||
Reference in New Issue
Block a user