[Bugfix] Fix mamba2 prefill chunking (#23279)
Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Signed-off-by: tomeras91 <57313761+tomeras91@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -115,21 +115,27 @@ def generate_continuous_batched_examples(example_lens_by_batch,
|
||||
n_heads,
|
||||
d_head,
|
||||
itype,
|
||||
device='cuda'):
|
||||
device='cuda',
|
||||
return_naive_ref=True):
|
||||
|
||||
# this function generates a random examples of certain length
|
||||
# and then cut according to "example_lens_by_batch" and feed
|
||||
# them in continuous batches to the kernels
|
||||
# them in continuous batches to the kernels.
|
||||
# If if return_naive_ref=True, the naive torch implementation
|
||||
# ssd_minimal_discrete will be used to compute and return
|
||||
# reference output.
|
||||
|
||||
# generate the full-length example
|
||||
A, dt, X, B, C = generate_random_inputs(num_examples, full_length, n_heads,
|
||||
d_head, itype)
|
||||
|
||||
Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1),
|
||||
A * dt,
|
||||
B,
|
||||
C,
|
||||
block_len=full_length // 4)
|
||||
if return_naive_ref:
|
||||
Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1),
|
||||
A * dt,
|
||||
B,
|
||||
C,
|
||||
block_len=full_length //
|
||||
4)
|
||||
|
||||
# internal function that outputs a cont batch of examples
|
||||
# given a tuple of lengths for each example in the batch
|
||||
@@ -179,7 +185,8 @@ def generate_continuous_batched_examples(example_lens_by_batch,
|
||||
IND_S = [x % full_length for x in IND_E]
|
||||
IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)]
|
||||
|
||||
yield ([Y_min[s, IND_S[s]:IND_E[s]] for s in range(num_examples)],
|
||||
yield ([Y_min[s, IND_S[s]:IND_E[s]]
|
||||
for s in range(num_examples)] if return_naive_ref else None,
|
||||
cu_seqlens, seq_idx.unsqueeze(0), (A, dt2, X2, B2, C2))
|
||||
|
||||
|
||||
@@ -324,3 +331,213 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
|
||||
if clear:
|
||||
states[i].fill_(0.)
|
||||
exhausted[i] = False
|
||||
|
||||
|
||||
@pytest.mark.parametrize("chunk_size", [8, 256])
@pytest.mark.parametrize("seqlens", [
    (16, 2, 8, 13),
    (270, 88, 212, 203),
    (16, 20),
])
def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):

    # This test verifies the correctness of the chunked prefill implementation
    # in the mamba2 ssd kernels, by comparing concatenation (in the sequence
    # dimension) of chunked results with the full sequence result.
    # It is different from test_mamba_chunk_scan_cont_batch by:
    # 1. Not using the naive torch implementation (ssd_minimal_discrete) to get
    #    reference outputs. Instead, it compares chunked kernel outputs to full
    #    sequence kernel outputs. This is the most straightforward way to
    #    assert chunked prefill correctness.
    # 2. It focuses on cases where sequences change in the middle of mamba
    #    chunks, and not necessarily on chunk boundaries.

    max_seqlen = max(seqlens)
    # This test can have larger error for longer sequences
    if max_seqlen > 256:
        atol, rtol = 1e-2, 5e-3
    else:
        atol, rtol = 5e-3, 5e-3

    num_sequences = len(seqlens)
    n_heads = 16
    d_head = 64
    itype = torch.float32

    # hold state during the cutting process so we know if an
    # example has been exhausted and needs to cycle
    last_taken: dict = {}  # map: eg -> pointer to last taken sample
    exhausted: dict = {}  # map: eg -> boolean indicating example is exhausted
    # return_naive_ref=False: the naive reference output is not needed here,
    # so the first element of the yielded tuple is None and is discarded.
    _, cu_seqlens, seq_idx, (A, dt, X, B, C) = next(
        generate_continuous_batched_examples([seqlens],
                                             num_sequences,
                                             max_seqlen,
                                             last_taken,
                                             exhausted,
                                             n_heads,
                                             d_head,
                                             itype,
                                             return_naive_ref=False))
    seqlens = torch.tensor(seqlens, dtype=torch.int32, device=X.device)
    device = X.device

    ## full seqlen computation
    # Reference run: process every sequence in one shot (no prefill chunking).
    chunk_indices, chunk_offsets = \
        _query_start_loc_to_chunk_indices_offsets(
            cu_seqlens, chunk_size, cu_seqlens[-1])
    Y_ref = torch.empty_like(X)
    state_ref = mamba_chunk_scan_combined(
        X,
        dt,
        A,
        B,
        C,
        chunk_size,
        D=None,
        cu_seqlens=cu_seqlens,
        seq_idx=seq_idx,
        chunk_indices=chunk_indices,
        chunk_offsets=chunk_offsets,
        return_varlen_states=True,
        initial_states=None,
        out=Y_ref,
    )

    ## chunked seqlen computation
    # first chunk: the first half (floor division) of every sequence.
    # Halving makes sequence boundaries land mid-chunk, which is exactly the
    # case this test targets.
    chunked_seqlens = seqlens // 2
    chunked_cu_seqlens = torch.cat([
        torch.tensor([0], device=device),
        torch.cumsum(chunked_seqlens, dim=0)
    ],
                                   dim=0)
    chunked_seq_idx = torch.repeat_interleave(
        torch.arange(len(chunked_seqlens), device=device),
        chunked_seqlens,
        output_size=chunked_cu_seqlens[-1]).unsqueeze(0).to(torch.int32)
    chunked_input_seq_len = chunked_cu_seqlens[-1]
    X_chunked = torch.zeros_like(X)[:, :chunked_input_seq_len, ...]
    dt_chunked = torch.zeros_like(dt)[:, :chunked_input_seq_len, ...]
    B_chunked = torch.zeros_like(B)[:, :chunked_input_seq_len, ...]
    C_chunked = torch.zeros_like(C)[:, :chunked_input_seq_len, ...]
    for i in range(num_sequences):
        # fmt: off
        # first half of sequence i, taken from the full continuous batch
        chunk_f = lambda x, i: x[:, cu_seqlens[i]:cu_seqlens[i] + chunked_seqlens[i], ...]  # noqa: E501

        X_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(X, i)  # noqa: E501
        dt_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(dt, i)  # noqa: E501
        B_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(B, i)  # noqa: E501
        C_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(C, i)  # noqa: E501
        # fmt: on

    chunk_indices, chunk_offsets = \
        _query_start_loc_to_chunk_indices_offsets(
            chunked_cu_seqlens, chunk_size, chunked_cu_seqlens[-1])
    Y_partial = torch.empty_like(X_chunked)
    # partial_state carries the per-sequence ssd state after the first halves;
    # it is fed back below as initial_states for the second halves.
    partial_state = mamba_chunk_scan_combined(
        X_chunked,
        dt_chunked,
        A,
        B_chunked,
        C_chunked,
        chunk_size,
        D=None,
        cu_seqlens=chunked_cu_seqlens,
        seq_idx=chunked_seq_idx,
        chunk_indices=chunk_indices,
        chunk_offsets=chunk_offsets,
        return_varlen_states=True,
        initial_states=None,
        out=Y_partial,
    )

    # remaining chunk: the second half of every sequence.
    remaining_chunked_seqlens = seqlens - chunked_seqlens
    remaining_chunked_cu_seqlens = torch.cat([
        torch.tensor([0], device=device),
        torch.cumsum(remaining_chunked_seqlens, dim=0)
    ],
                                             dim=0)
    remaining_chunked_seq_idx = torch.repeat_interleave(
        torch.arange(len(remaining_chunked_seqlens), device=device),
        remaining_chunked_seqlens,
        output_size=remaining_chunked_cu_seqlens[-1]).unsqueeze(0).to(
            torch.int32)
    remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1]
    # fmt: off
    remaining_X_chunked = torch.zeros_like(X)[:, :remaining_chunked_input_seq_len, ...]  # noqa: E501
    remaining_dt_chunked = torch.zeros_like(dt)[:, :remaining_chunked_input_seq_len, ...]  # noqa: E501
    remaining_B_chunked = torch.zeros_like(B)[:, :remaining_chunked_input_seq_len, ...]  # noqa: E501
    remaining_C_chunked = torch.zeros_like(C)[:, :remaining_chunked_input_seq_len, ...]  # noqa: E501
    for i in range(num_sequences):
        # second half of sequence i, taken from the full continuous batch
        remaining_chunk_f = lambda x, i: x[:, cu_seqlens[i] + chunked_seqlens[i]:cu_seqlens[i+1], ...]  # noqa: E501

        remaining_X_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(X, i)  # noqa: E501
        remaining_dt_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(dt, i)  # noqa: E501
        remaining_B_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(B, i)  # noqa: E501
        remaining_C_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(C, i)  # noqa: E501

    # assert input chunking is correct
    # concat_chunk_f re-joins the two halves of sequence i along the sequence
    # dim; concat_batch_f does so for the whole batch, reproducing the
    # continuous-batch layout of the full-length inputs.
    concat_chunk_f = lambda pt1, pt2, i: torch.cat([
        pt1[:,chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1],...],
        pt2[:,remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1],...],
    ],
    dim=1)
    concat_batch_f = lambda pt1, pt2: torch.cat([concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=1)  # noqa: E501
    # fmt: on

    assert concat_batch_f(X_chunked, remaining_X_chunked).equal(X)
    assert concat_batch_f(dt_chunked, remaining_dt_chunked).equal(dt)
    assert concat_batch_f(B_chunked, remaining_B_chunked).equal(B)
    assert concat_batch_f(C_chunked, remaining_C_chunked).equal(C)

    chunk_indices, chunk_offsets = \
        _query_start_loc_to_chunk_indices_offsets(
            remaining_chunked_cu_seqlens,
            chunk_size,
            remaining_chunked_cu_seqlens[-1])

    Y_chunked = torch.empty_like(remaining_X_chunked)
    # Second prefill step: continue each sequence from the state produced by
    # its first half (initial_states=partial_state).
    state_chunked = mamba_chunk_scan_combined(
        remaining_X_chunked,
        remaining_dt_chunked,
        A,
        remaining_B_chunked,
        remaining_C_chunked,
        chunk_size,
        D=None,
        cu_seqlens=remaining_chunked_cu_seqlens,
        seq_idx=remaining_chunked_seq_idx,
        chunk_indices=chunk_indices,
        chunk_offsets=chunk_offsets,
        return_varlen_states=True,
        initial_states=partial_state,
        out=Y_chunked,
    )
    Y = concat_batch_f(Y_partial, Y_chunked)

    # kernel chunked is same as kernel overall
    # Compare per sequence, splitting each comparison at the prefill boundary
    # so a failure message pinpoints which half diverged.
    for i in range(num_sequences):
        Y_seq = Y[:, cu_seqlens[i]:cu_seqlens[i + 1], ...]
        Y_ref_seq = Y_ref[:, cu_seqlens[i]:cu_seqlens[i + 1], ...]
        torch.testing.assert_close(
            Y_seq[:, :chunked_seqlens[i], ...],
            Y_ref_seq[:, :chunked_seqlens[i], ...],
            atol=atol,
            rtol=rtol,
            msg=lambda x: f"seq{i} output part1 " + x)  # noqa: B023
        torch.testing.assert_close(
            Y_seq[:, chunked_seqlens[i]:, ...],
            Y_ref_seq[:, chunked_seqlens[i]:, ...],
            atol=atol,
            rtol=rtol,
            msg=lambda x: f"seq{i} state " + x if False else f"seq{i} output part2 " + x)  # noqa: B023,E501

        # final states must also match the single-shot run
        state_seq = state_chunked[i]
        state_seq_ref = state_ref[i]
        torch.testing.assert_close(
            state_seq,
            state_seq_ref,
            atol=atol,
            rtol=rtol,
            msg=lambda x: f"seq{i} state " + x)  # noqa: B023
Reference in New Issue
Block a user