[Kernel] Chunk-aligned mamba2 (#24683)

This commit is contained in:
Thomas Parnell
2025-09-29 23:18:25 +02:00
committed by GitHub
parent 61a3431613
commit fea3e476aa
8 changed files with 247 additions and 431 deletions

View File

@@ -502,9 +502,9 @@ class MambaMixer2(MambaBase, CustomOp):
prep_initial_states = attn_metadata.prep_initial_states
chunk_size = attn_metadata.chunk_size
seq_idx_p = attn_metadata.seq_idx_p
chunk_indices_p = attn_metadata.chunk_indices_p
chunk_offsets_p = attn_metadata.chunk_offsets_p
query_start_loc_p = attn_metadata.query_start_loc_p
cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p
last_chunk_indices_p = attn_metadata.last_chunk_indices_p
# 1. Gated MLP's linear projection
projected_states, _ = self.in_proj(hidden_states)
@@ -634,9 +634,9 @@ class MambaMixer2(MambaBase, CustomOp):
z=None,
dt_bias=self.dt_bias,
seq_idx=seq_idx_p,
chunk_indices=chunk_indices_p,
chunk_offsets=chunk_offsets_p,
cu_seqlens=query_start_loc_p,
cu_chunk_seqlens=cu_chunk_seqlen_p,
last_chunk_indices=last_chunk_indices_p,
initial_states=initial_states,
dt_softplus=True,
dt_limit=(0.0, float("inf")),