[Bugfix] Fix mamba2 prefill chunking (#23279)
Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
Signed-off-by: tomeras91 <57313761+tomeras91@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@@ -289,6 +289,9 @@ def _chunk_scan_fwd_kernel(
         # get the cs at the offset boundary
         # - c_off == 0 is a passthrough
+        # - We need dA_cs at the boundary, defined by c_off - no need
+        #   to increase pointer by pid_m (it is a constant offset,
+        #   i.e. the same for all blocks)
         dA_cs_m_boundary = tl.load(
             dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize,
             mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)),
             other=0.0).to(tl.float32)
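The added comment is the crux of this hunk: the boundary cumsum is a constant per-chunk offset, identical for all `pid_m` blocks, so the pointer must not be advanced by `pid_m`. A minimal sketch (plain PyTorch, not vLLM code; a 1-D per-chunk cumsum assumed) of how such a boundary value re-bases the decay of a sequence that starts mid-chunk:

```python
import torch

chunk_size = 8
dA = torch.randn(chunk_size)        # per-position dA values (illustrative)
dA_cs = torch.cumsum(dA, dim=0)     # running cumsum within one physical chunk

c_off = 3                           # hypothetical in-chunk start of a sequence
# decay accumulated from the sequence start up to position t inside the chunk
# is exp(dA_cs[t] - dA_cs[c_off - 1]); c_off == 0 subtracts nothing
# (the "passthrough" case noted in the kernel comment).
boundary = dA_cs[c_off - 1] if c_off > 0 else torch.tensor(0.0)
decay_to_chunk_end = torch.exp(dA_cs[chunk_size - 1] - boundary)
```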
@@ -106,21 +106,24 @@ def _mamba_chunk_scan_combined_fwd(x,

     # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
     #    (middle term of factorization of off-diag blocks; A terms)
     # - for handling chunked prefill, this requires i) initial_states,
-    #   ii) seq_idx and iii) is_cont_batched to all be specified.
+    #   ii) seq_idx, iii) is_cont_batched and iv) chunk_offsets to all be specified.
     # - When a new seq_idx is detected, we will stop passing the prev_state
     #   and switch accordingly to the init_state corresponding to the new seq_idx.
+    # - We will also make sure that the dA_cumsum is taken only from the start of the
+    #   sequence (hence we need the full dA_cumsum tensor and not just the values at chunk boundaries)
     # - this will ensure that states will be updated with the rightmost flushed seq_idx
     #   of the previous chunk. This implies that the first chunk of states is either 0
     #   or equal to init_states of the first example.
     states, final_states = _state_passing_fwd(
         rearrange(states, "... p n -> ... (p n)"),
-        dA_cumsum[:, :, :, -1],
+        dA_cumsum,
         initial_states=rearrange(initial_states, "... p n -> ... (p n)")
         if initial_states is not None else None,
         seq_idx=seq_idx,
         chunk_size=chunk_size,
         out_dtype=state_dtype if state_dtype is not None else C.dtype,
-        is_cont_batched=cu_seqlens is not None)
+        is_cont_batched=cu_seqlens is not None,
+        chunk_offsets=chunk_offsets)
     states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate)
                             for t in [states, final_states])
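Before the fix, `_state_passing_fwd` received only `dA_cumsum[:, :, :, -1]`, i.e. the per-chunk decay totals, which suffices when every chunk belongs to a single sequence. A reference sketch of that simple recurrence (illustrative PyTorch, one (batch, head) slice, no sequence switching):

```python
import torch

def state_passing_ref(chunk_states, dA_cs_last):
    # chunk_states: (nchunks, dim) local per-chunk states
    # dA_cs_last:   (nchunks,)     dA cumsum at each chunk's right boundary
    out = torch.zeros_like(chunk_states)
    state = torch.zeros(chunk_states.shape[-1])
    for c in range(chunk_states.shape[0]):
        out[c] = state                               # state entering chunk c
        state = torch.exp(dA_cs_last[c]) * state + chunk_states[c]
    return out, state                                # per-chunk and final states
```

With chunked prefill a sequence may begin mid-chunk, so the decay for that chunk must be re-based to the sequence start; that is why the full `dA_cumsum` tensor and the new `chunk_offsets` argument are now threaded through.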
@@ -31,6 +31,8 @@ def _state_passing_fwd_kernel(
     dA_cs_ptr,
     initstates_ptr,
     seq_idx_ptr,
+    chunk_offsets_ptr,
+    chunk_meta_num,
     # Matrix dimensions
     dim,
     nchunks,
@@ -51,6 +53,7 @@ def _state_passing_fwd_kernel(
     stride_dA_cs_batch,
     stride_dA_cs_chunk,
     stride_dA_cs_head,
+    stride_dA_cs_csize,
     stride_initstates_batch,
     stride_initstates_head,
     stride_initstates_dim,
@@ -66,7 +69,8 @@ def _state_passing_fwd_kernel(
     pid_h = tl.program_id(axis=2)
     pid_m = tl.program_id(axis=0)
     states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
-    dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head
+    dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + (
+        chunk_size - 1) * stride_dA_cs_csize
     out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
     final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head
     if HAS_INITSTATES:
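The pre-advanced pointer reads the same values the old code obtained from the precomputed last-column tensor. A small demonstration of the equivalent index arithmetic (illustrative, contiguous tensor assumed):

```python
import torch

dA_cumsum = torch.randn(1, 2, 3, 8)    # (batch, nheads, nchunks, chunk_size)
b, h, c = 0, 1, 2
flat = dA_cumsum.flatten()
idx = (b * dA_cumsum.stride(0) + h * dA_cumsum.stride(1) +
       c * dA_cumsum.stride(2) + (dA_cumsum.shape[-1] - 1) * dA_cumsum.stride(3))
assert flat[idx] == dA_cumsum[b, h, c, -1]   # chunk-end cumsum, as before
```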
@@ -95,35 +99,62 @@ def _state_passing_fwd_kernel(

     tl.store(out_ptrs, states, mask=offs_m < dim)
     out_ptrs += stride_out_chunk
-    seq_idx = 0
+    prev_seq_idx_chunk_end = 0
+    logical_chunk_idx = 0
     for c in range(nchunks):
         new_states = tl.load(states_ptrs, mask=offs_m < dim,
                              other=0.0).to(tl.float32)
         dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
-        scale = tl.exp(dA_cs)
+        scale_mask = True
         if HAS_SEQ_IDX:
             # - the seq to pass forward is the one that is flushed to the right
             #   boundary.
-            # - that is given by seq_idx_new below.
-            seq_idx_new = tl.load(seq_idx_ptr +
-                                  (min((c + 1) * chunk_size, seqlen) - 1) *
-                                  stride_seq_idx_seqlen)
+            # - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk.
+            seq_idx_chunk_end = tl.load(seq_idx_ptr + (min(
+                (c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen)
             if HAS_INITSTATES:
-                if IS_CONT_BATCHED and seq_idx != seq_idx_new:
+                if IS_CONT_BATCHED and prev_seq_idx_chunk_end != seq_idx_chunk_end:
                     # this means in the current chunk the rightmost flushed seq
                     # has changed.
                     # - so we do not propagate the state from previous chunk
                     # - but rather we load that sequence's init state
-                    initstates_ptrs = initstates_ptr + seq_idx_new * stride_initstates_batch
+                    initstates_ptrs = initstates_ptr + seq_idx_chunk_end * stride_initstates_batch

                     # - update state with seq_idx_chunk_end's init state
                     states = tl.load(initstates_ptrs,
                                      mask=offs_m < dim,
                                      other=0.0).to(tl.float32)
-            else:
-                scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)
-
-            seq_idx = seq_idx_new
+
+                    # - we need to consider the cumsum only of the last sequence in the chunk
+                    # - find its starting position (given by c_off of the logical chunk index)
+                    # - and subtract the cumsum just before that position from the total cumsum
+                    # - first, update the logical chunk index (add the number of sequences in the current physical chunk):
+                    #   sequence index at the start of the current chunk
+                    seq_idx_chunk_start = tl.load(seq_idx_ptr +
+                                                  min(c * chunk_size, seqlen) *
+                                                  stride_seq_idx_seqlen)
+                    logical_chunk_idx += seq_idx_chunk_end - seq_idx_chunk_start
+                    # - load the chunk offset:
+                    c_off = tl.load(chunk_offsets_ptr + logical_chunk_idx,
+                                    mask=logical_chunk_idx < chunk_meta_num,
+                                    other=0)
+                    # - if offset is 0, then the sequence starts at the beginning of the chunk, and we don't need to subtract anything
+                    if c_off > 0:
+                        # - dA_cs_ptr currently points to the cumsum at the end of the chunk - subtract the chunk size and add the offset
+                        dA_cs_boundary = tl.load(
+                            dA_cs_ptr - (chunk_size - 1) * stride_dA_cs_csize +
+                            (c_off - 1) * stride_dA_cs_csize,
+                            mask=(c_off - 1) > -1 and c_off < chunk_size,
+                            other=0.0)
+                        dA_cs -= dA_cs_boundary
+
+                # - increment logical chunk index for every physical chunk
+                logical_chunk_idx += 1
+            else:
+                scale_mask = seq_idx_chunk_end == prev_seq_idx_chunk_end
+            prev_seq_idx_chunk_end = seq_idx_chunk_end
+
+        scale = tl.where(scale_mask, tl.exp(dA_cs), 0.0)
         states = scale * states + new_states
         if c < nchunks - 1:
             tl.store(out_ptrs, states, mask=offs_m < dim)
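A plain-Python mirror of the fixed loop (illustrative only, not vLLM code): one (batch, head) scalar-state slice, continuous batching with initial states, and `seqlen == nchunks * chunk_size` assumed. `chunk_offsets[l]` is taken to hold the in-chunk start offset of logical chunk `l`, as the kernel expects:

```python
import math

def state_passing_mirror(chunk_states, dA_cs, seq_idx, init_states,
                         chunk_offsets, chunk_size):
    # chunk_states: per-chunk local states; dA_cs[c][i]: cumsum within chunk c
    # seq_idx: per-token sequence index; init_states[s]: init state of seq s
    nchunks = len(chunk_states)
    out = [0.0] * nchunks
    state = 0.0
    prev_end, logical = 0, 0
    for c in range(nchunks):
        out[c] = state                           # state entering chunk c
        dA = dA_cs[c][chunk_size - 1]            # cumsum at the chunk end
        end = seq_idx[(c + 1) * chunk_size - 1]  # rightmost flushed sequence
        if prev_end != end:                      # a new sequence was flushed
            state = init_states[end]             # restart from its init state
            start = seq_idx[c * chunk_size]
            logical += end - start               # skip past finished sequences
            c_off = chunk_offsets[logical]
            if c_off > 0:                        # sequence starts mid-chunk:
                dA -= dA_cs[c][c_off - 1]        # re-base decay to its start
        logical += 1                             # one logical chunk per physical
        prev_end = end
        state = math.exp(dA) * state + chunk_states[c]
    return out, state
```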
@@ -136,28 +167,36 @@ def _state_passing_fwd_kernel(

 def _state_passing_fwd(
     states,
-    dA_chunk_cumsum,
+    dA_cumsum,
     initial_states=None,
     seq_idx=None,
     chunk_size=None,
     out_dtype=None,
     is_cont_batched=False,
+    chunk_offsets=None,
 ):
     batch, nchunks, nheads, dim = states.shape
-    assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
+    if chunk_size is None:
+        chunk_size = dA_cumsum.shape[-1]
+    else:
+        assert chunk_size == dA_cumsum.shape[-1]
+    assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
     if initial_states is not None:
         if is_cont_batched:
             # - if cu_seqlens is provided, then the initial states
             #   are used for continuous batching, in which case we
             #   require seq_idx to be provided
-            assert seq_idx is not None, ""
+            assert seq_idx is not None, "seq_idx must be provided for continuous batching"
+            # - we also need chunk_offsets to be provided, to account
+            #   for computation of dA_cumsum from the start of the
+            #   sequence
+            assert chunk_offsets is not None, "chunk_offsets must be provided for continuous batching"
         else:
             # - this is the regular batching case, where initial
             #   states are used for each example of the batch.
             assert initial_states.shape == (batch, nheads, dim)

     if seq_idx is not None:
-        assert chunk_size is not None
         seqlen = seq_idx.shape[-1]
         assert seq_idx.shape == (batch, seqlen)
     out_dtype = states.dtype if out_dtype is None else out_dtype
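The diff does not show how `chunk_offsets` is constructed. For illustration only, a hypothetical, unoptimized way to build metadata with the shape the kernel consumes (one entry per logical chunk, where logical chunks open at every physical chunk boundary and at every sequence start):

```python
def chunk_offsets_from_cu_seqlens(cu_seqlens, chunk_size, total_len):
    # hypothetical helper (not from this PR) - each physical chunk boundary
    # and each sequence start opens a logical chunk; record each logical
    # chunk's offset within its physical chunk
    nchunks = (total_len + chunk_size - 1) // chunk_size
    starts = sorted(set([c * chunk_size for c in range(nchunks)] +
                        [int(s) for s in cu_seqlens[:-1]]))
    return [s % chunk_size for s in starts]

# two sequences of lengths 6 and 4, chunk_size 4 (cu_seqlens = [0, 6, 10]):
# logical chunks start at tokens 0, 4, 6, 8 -> offsets [0, 0, 2, 0]
print(chunk_offsets_from_cu_seqlens([0, 6, 10], 4, 10))  # [0, 0, 2, 0]
```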
@@ -173,13 +212,15 @@ def _state_passing_fwd(
         states,
         out,
         final_states,
-        dA_chunk_cumsum,
+        dA_cumsum,
         initial_states,
         seq_idx,
+        chunk_offsets,
+        len(chunk_offsets) if chunk_offsets is not None else 0,
         dim,
         nchunks,
         seqlen if seq_idx is not None else 0,
-        chunk_size if seq_idx is not None else 0,
+        chunk_size,
         states.stride(0),
         states.stride(1),
         states.stride(2),
@@ -191,9 +232,10 @@ def _state_passing_fwd(
         final_states.stride(0),
         final_states.stride(1),
         final_states.stride(2),
-        dA_chunk_cumsum.stride(0),
-        dA_chunk_cumsum.stride(2),
-        dA_chunk_cumsum.stride(1),
+        dA_cumsum.stride(0),
+        dA_cumsum.stride(2),
+        dA_cumsum.stride(1),
+        dA_cumsum.stride(3),
         *((initial_states.stride(0), initial_states.stride(1),
            initial_states.stride(2)) if initial_states is not None else
           (0, 0, 0)),
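A note on the stride ordering above: `dA_cumsum` has shape `(batch, nheads, nchunks, chunk_size)` (see the assert in `_state_passing_fwd`), while the kernel's parameters are ordered batch, chunk, head, csize - hence `stride(0), stride(2), stride(1), stride(3)`:

```python
import torch

dA_cumsum = torch.empty(2, 4, 8, 16)    # (batch, nheads, nchunks, chunk_size)
dA_cs_strides = (dA_cumsum.stride(0),   # stride_dA_cs_batch
                 dA_cumsum.stride(2),   # stride_dA_cs_chunk (nchunks dim)
                 dA_cumsum.stride(1),   # stride_dA_cs_head  (nheads dim)
                 dA_cumsum.stride(3))   # stride_dA_cs_csize
```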