[Bugfix] Fix mamba2 prefill chunking (#23279)

Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
Signed-off-by: tomeras91 <57313761+tomeras91@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Authored by tomeras91 on 2025-09-08 14:42:41 +03:00; committed by GitHub
parent 5e537f45b4
commit e041314184
5 changed files with 349 additions and 35 deletions


@@ -289,6 +289,9 @@ def _chunk_scan_fwd_kernel(
             # get the cs at the offset boundary
             # - c_off == 0 is a passthrough
+            # - We need dA_cs at the boundary, defined by c_off - no need
+            #   to increase pointer by pid_m (it is a constant offset,
+            #   i.e. the same for all blocks)
             dA_cs_m_boundary = tl.load(
                 dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize,
                 mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)),
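
The boundary load above rests on the prefix-sum identity: with an inclusive cumulative sum dA_cs, the decay accumulated over positions c_off..t of a chunk is dA_cs[t] - dA_cs[c_off - 1], so c_off == 0 needs no correction. A minimal NumPy sketch of that identity (illustrative only, not vLLM code):

# Illustrative sketch: the boundary trick used by the kernel above.
import numpy as np

dA = np.random.rand(8)           # per-position dA within one chunk
dA_cs = np.cumsum(dA)            # inclusive prefix sum, as in the kernel
c_off, t = 3, 6
boundary = dA_cs[c_off - 1] if c_off > 0 else 0.0
# cumsum difference recovers the segment sum from c_off through t
assert np.isclose(dA_cs[t] - boundary, dA[c_off:t + 1].sum())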


@@ -106,21 +106,24 @@ def _mamba_chunk_scan_combined_fwd(x,
     # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
     # (middle term of factorization of off-diag blocks; A terms)
     # - for handling chunked prefill, this requires i) initial_states
-    #   ii) seq_idx and iii) is_cont_batched to be all specified.
+    #   ii) seq_idx iii) is_cont_batched and (iv) chunk_offsets to be all specified.
     # - When a new seq_idx is detected, we will stop passing the prev_state
     #   and switch accordingly to the init_state corresponding to the new seq_idx.
+    # - We will also make sure that the dA_cumsum is taken only from the start of the
+    #   sequence (hence we need the full dA_cumsum tensor and not just the values at chunk boundaries)
     # - this will ensure that states will be updated with the rightmost flushed seq_idx
     #   of the previous chunk. This implies that the first chunk of states is either 0
     #   or equal to init_states of the first example.
     states, final_states = _state_passing_fwd(
         rearrange(states, "... p n -> ... (p n)"),
-        dA_cumsum[:, :, :, -1],
+        dA_cumsum,
         initial_states=rearrange(initial_states, "... p n -> ... (p n)")
         if initial_states is not None else None,
         seq_idx=seq_idx,
         chunk_size=chunk_size,
         out_dtype=state_dtype if state_dtype is not None else C.dtype,
-        is_cont_batched=cu_seqlens is not None)
+        is_cont_batched=cu_seqlens is not None,
+        chunk_offsets=chunk_offsets)
     states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate)
                             for t in [states, final_states])
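
For context: chunk_offsets records, for each logical chunk that begins a new sequence, how far into its physical chunk that sequence starts; offset 0 means the start is chunk-aligned. vLLM derives this metadata from cu_seqlens elsewhere; the helper below is a hypothetical, per-sequence simplification to illustrate the idea, not the actual API:

# Hypothetical sketch of the metadata the kernel consumes (illustrative only).
def compute_chunk_offsets(cu_seqlens, chunk_size):
    # a sequence starting at global token position s begins s % chunk_size
    # tokens into its physical chunk; 0 means no cumsum correction is needed
    return [start % chunk_size for start in cu_seqlens[:-1]]

# two packed prompts of lengths 5 and 7 with chunk_size=4: the second prompt
# starts one token into physical chunk 1, so its cumsum must discount dA_cs[:1]
print(compute_chunk_offsets([0, 5, 12], 4))  # -> [0, 1]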


@@ -31,6 +31,8 @@ def _state_passing_fwd_kernel(
     dA_cs_ptr,
     initstates_ptr,
     seq_idx_ptr,
+    chunk_offsets_ptr,
+    chunk_meta_num,
     # Matrix dimensions
     dim,
     nchunks,
@@ -51,6 +53,7 @@ def _state_passing_fwd_kernel(
     stride_dA_cs_batch,
     stride_dA_cs_chunk,
     stride_dA_cs_head,
+    stride_dA_cs_csize,
     stride_initstates_batch,
     stride_initstates_head,
     stride_initstates_dim,
@@ -66,7 +69,8 @@ def _state_passing_fwd_kernel(
     pid_h = tl.program_id(axis=2)
     pid_m = tl.program_id(axis=0)
     states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
-    dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head
+    dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + (
+        chunk_size - 1) * stride_dA_cs_csize
     out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
     final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head
     if HAS_INITSTATES:
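
Because the kernel now receives the full dA_cumsum tensor, the pointer is pre-advanced by (chunk_size - 1) * stride_dA_cs_csize so that a plain per-chunk load still yields the chunk-final cumsum (what the old dA_chunk_cumsum argument held), while earlier intra-chunk positions stay reachable by stepping back with the new stride. A quick equivalence check with assumed toy shapes (not vLLM code):

# Illustrative check: offsetting a flat index by (chunk_size - 1) * stride(3)
# addresses the last intra-chunk element, i.e. the old chunk-boundary values.
import torch

batch, nheads, nchunks, chunk_size = 2, 3, 4, 8
dA_cumsum = torch.randn(batch, nheads, nchunks, chunk_size)
flat = dA_cumsum.flatten()
offset = (chunk_size - 1) * dA_cumsum.stride(3)
# element 0 of flat is (b=0, h=0, c=0, pos=0); the offset lands on pos=-1
assert flat[offset] == dA_cumsum[0, 0, 0, -1]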
@@ -95,35 +99,62 @@ def _state_passing_fwd_kernel(
     tl.store(out_ptrs, states, mask=offs_m < dim)
     out_ptrs += stride_out_chunk
-    seq_idx = 0
+    prev_seq_idx_chunk_end = 0
+    logical_chunk_idx = 0
     for c in range(nchunks):
         new_states = tl.load(states_ptrs, mask=offs_m < dim,
                              other=0.0).to(tl.float32)
         dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
-        scale = tl.exp(dA_cs)
+        scale_mask = True
         if HAS_SEQ_IDX:
             # - the seq to pass forward is the one that is flushed to the right
             #   boundary.
-            # - that is given by seq_idx_new below.
-            seq_idx_new = tl.load(seq_idx_ptr +
-                                  (min((c + 1) * chunk_size, seqlen) - 1) *
-                                  stride_seq_idx_seqlen)
+            # - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk.
+            seq_idx_chunk_end = tl.load(seq_idx_ptr + (min(
+                (c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen)
             if HAS_INITSTATES:
-                if IS_CONT_BATCHED and seq_idx != seq_idx_new:
+                if IS_CONT_BATCHED and prev_seq_idx_chunk_end != seq_idx_chunk_end:
                     # this means in the current chunk the rightmost flushed seq
                     # has changed.
                     # - so we do not propagate the state from previous chunk
                     # - but rather we load that sequence's init state
-                    initstates_ptrs = initstates_ptr + seq_idx_new * stride_initstates_batch
+                    initstates_ptrs = initstates_ptr + seq_idx_chunk_end * stride_initstates_batch
                     # - update state with seq_idx_new's init state
                     states = tl.load(initstates_ptrs,
                                      mask=offs_m < dim,
                                      other=0.0).to(tl.float32)
-            else:
-                scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)
-            seq_idx = seq_idx_new
+                    # - we need to consider the cumsum only of the last sequence in the chunk
+                    # - find its starting position (given by c_off of the logical chunk index)
+                    # - and subtract the cumsum just before that position from the total cumsum
+                    # - first, update the logical chunk index (add the number of sequences in the current physical chunk):
+                    #   sequence index at the start of the current chunk
+                    seq_idx_chunk_start = tl.load(seq_idx_ptr +
+                                                  min(c * chunk_size, seqlen) *
+                                                  stride_seq_idx_seqlen)
+                    logical_chunk_idx += seq_idx_chunk_end - seq_idx_chunk_start
+                    # - load the chunk offset:
+                    c_off = tl.load(chunk_offsets_ptr + logical_chunk_idx,
+                                    mask=logical_chunk_idx < chunk_meta_num,
+                                    other=0)
+                    # - if offset is 0, then the sequence starts at the beginning of the chunk, and we don't need to subtract anything
+                    if c_off > 0:
+                        # - dA_cs_ptr currently points to the cumsum at the end of the chunk - subtract the chunk size and add the offset
+                        dA_cs_boundary = tl.load(
+                            dA_cs_ptr - (chunk_size - 1) * stride_dA_cs_csize +
+                            (c_off - 1) * stride_dA_cs_csize,
+                            mask=(c_off - 1) > -1 and c_off < chunk_size,
+                            other=0.0)
+                        dA_cs -= dA_cs_boundary
+                    # - increment logical chunk index for every physical chunk
+                    logical_chunk_idx += 1
+            else:
+                scale_mask = seq_idx_chunk_end == prev_seq_idx_chunk_end
+            prev_seq_idx_chunk_end = seq_idx_chunk_end
+        scale = tl.where(scale_mask, tl.exp(dA_cs), 0.0)
         states = scale * states + new_states
         if c < nchunks - 1:
             tl.store(out_ptrs, states, mask=offs_m < dim)
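
The loop realizes the inter-chunk recurrence state_{c+1} = exp(dA_cs_c) * state_c + chunk_states_c, where dA_cs_c is the chunk-final cumsum, now reduced by the cumsum just before the last sequence's start whenever that sequence begins mid-chunk. A NumPy reference sketch for a single (batch, head) and a single sequence starting c_off tokens into chunk 0 (a simplification of the continuous-batching case the kernel handles):

# Reference sketch of the corrected recurrence, simplified assumptions as above.
import numpy as np

def state_passing_ref(chunk_states, dA_cumsum, init_state, c_off):
    """chunk_states: (nchunks, dim); dA_cumsum: (nchunks, chunk_size)."""
    nchunks, _ = dA_cumsum.shape
    out = np.empty_like(chunk_states)
    state = init_state
    for c in range(nchunks):
        out[c] = state                       # state entering chunk c
        dA_cs = dA_cumsum[c, -1]             # chunk-final cumsum
        if c == 0 and c_off > 0:
            # sequence starts mid-chunk: remove decay accrued before its start
            dA_cs = dA_cs - dA_cumsum[c, c_off - 1]
        state = np.exp(dA_cs) * state + chunk_states[c]
    return out, state                        # per-chunk states, final state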
@@ -136,28 +167,36 @@ def _state_passing_fwd_kernel(
 def _state_passing_fwd(
     states,
-    dA_chunk_cumsum,
+    dA_cumsum,
     initial_states=None,
     seq_idx=None,
     chunk_size=None,
     out_dtype=None,
     is_cont_batched=False,
+    chunk_offsets=None,
 ):
     batch, nchunks, nheads, dim = states.shape
-    assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
+    if chunk_size is None:
+        chunk_size = dA_cumsum.shape[-1]
+    else:
+        assert chunk_size == dA_cumsum.shape[-1]
+    assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
     if initial_states is not None:
         if is_cont_batched:
             # - if cu_seqlens is provided, then the initial states
             #   are used for continuous batching. In which case we
             #   require seq_idx to be provided
-            assert seq_idx is not None, ""
+            assert seq_idx is not None, "seq_idx must be provided for continuous batching"
+            # - we also need chunk_offsets to be provided, to account
+            #   for computation of dA_cumsum from the start of the
+            #   sequence
+            assert chunk_offsets is not None, "chunk_offsets must be provided for continuous batching"
         else:
             # - this is the regular batching case, where initial
             #   states are used for each example of the batch.
             assert initial_states.shape == (batch, nheads, dim)
     if seq_idx is not None:
-        assert chunk_size is not None
         seqlen = seq_idx.shape[-1]
         assert seq_idx.shape == (batch, seqlen)
     out_dtype = states.dtype if out_dtype is None else out_dtype
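
With the new signature the caller passes the full (batch, nheads, nchunks, chunk_size) cumsum rather than pre-sliced chunk-final values, and chunk_size can be inferred from its last dimension. A minimal invocation sketch, assuming the import path below and a CUDA device (this launches a Triton kernel):

# Invocation sketch with toy shapes; the import path is an assumption.
import torch
from vllm.model_executor.layers.mamba.ops.ssd_state_passing import (
    _state_passing_fwd)

batch, nchunks, nheads, dim, chunk_size = 1, 4, 2, 16, 8
states = torch.randn(batch, nchunks, nheads, dim, device="cuda")
dA_cumsum = torch.randn(batch, nheads, nchunks, chunk_size, device="cuda")
# full per-position cumsum now passed, not dA_cumsum[:, :, :, -1];
# chunk_size is inferred from dA_cumsum.shape[-1]
out, final_states = _state_passing_fwd(states, dA_cumsum)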
@@ -173,13 +212,15 @@ def _state_passing_fwd(
             states,
             out,
             final_states,
-            dA_chunk_cumsum,
+            dA_cumsum,
             initial_states,
             seq_idx,
+            chunk_offsets,
+            len(chunk_offsets) if chunk_offsets is not None else 0,
             dim,
             nchunks,
             seqlen if seq_idx is not None else 0,
-            chunk_size if seq_idx is not None else 0,
+            chunk_size,
             states.stride(0),
             states.stride(1),
             states.stride(2),
@@ -191,9 +232,10 @@ def _state_passing_fwd(
             final_states.stride(0),
             final_states.stride(1),
             final_states.stride(2),
-            dA_chunk_cumsum.stride(0),
-            dA_chunk_cumsum.stride(2),
-            dA_chunk_cumsum.stride(1),
+            dA_cumsum.stride(0),
+            dA_cumsum.stride(2),
+            dA_cumsum.stride(1),
+            dA_cumsum.stride(3),
             *((initial_states.stride(0), initial_states.stride(1),
                initial_states.stride(2)) if initial_states is not None else
               (0, 0, 0)),