Fix test_mamba_ssm_ssd.py due to missing _query_start_loc_to_chunk_indices_offsets (#25995)
Signed-off-by: Huamin Li <3ericli@gmail.com>
This commit is contained in:
@@ -10,7 +10,7 @@ from vllm.model_executor.layers.mamba.ops.ssd_combined import (
|
||||
mamba_chunk_scan_combined_varlen)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.mamba2_attn import (
|
||||
_query_start_loc_to_chunk_indices_offsets)
|
||||
compute_varlen_chunk_metadata)
|
||||
|
||||
# Added by the IBM Team, 2024
|
||||
|
||||
@@ -225,13 +225,9 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
|
||||
Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt,
|
||||
B, C, chunk_size)
|
||||
|
||||
cu_seqlens = torch.tensor((0, seqlen), device='cuda').cumsum(dim=0)
|
||||
seq_idx = torch.zeros(seqlen, dtype=torch.int32, device=cu_seqlens.device)
|
||||
|
||||
chunk_indices, chunk_offsets = \
|
||||
_query_start_loc_to_chunk_indices_offsets(
|
||||
cu_seqlens, chunk_size, cu_seqlens[-1])
|
||||
|
||||
cu_seqlens = torch.tensor((0, seqlen), device="cuda").cumsum(dim=0)
|
||||
cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
|
||||
compute_varlen_chunk_metadata(cu_seqlens, chunk_size))
|
||||
# varlen has implicit batch=1
|
||||
X = X.squeeze(0)
|
||||
dt = dt.squeeze(0)
|
||||
@@ -239,18 +235,20 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
|
||||
B = B.squeeze(0)
|
||||
C = C.squeeze(0)
|
||||
Y = torch.empty_like(X)
|
||||
final_state = mamba_chunk_scan_combined_varlen(X,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
chunk_size,
|
||||
D=None,
|
||||
cu_seqlens=cu_seqlens,
|
||||
seq_idx=seq_idx,
|
||||
chunk_indices=chunk_indices,
|
||||
chunk_offsets=chunk_offsets,
|
||||
out=Y)
|
||||
final_state = mamba_chunk_scan_combined_varlen(
|
||||
X,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
chunk_size,
|
||||
cu_seqlens=cu_seqlens.to(torch.int32),
|
||||
cu_chunk_seqlens=cu_chunk_seqlens,
|
||||
last_chunk_indices=last_chunk_indices,
|
||||
seq_idx=seq_idx_chunks,
|
||||
out=Y,
|
||||
D=None,
|
||||
)
|
||||
|
||||
# just test the last in sequence
|
||||
torch.testing.assert_close(Y[-1], Y_min[0, -1], atol=atol, rtol=rtol)
|
||||
@@ -312,14 +310,13 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
|
||||
exhausted: dict = {} # map: eg -> boolean indicating example is exhausted
|
||||
|
||||
states = None
|
||||
for Y_min, cu_seqlens, seq_idx, (
|
||||
for Y_min, cu_seqlens, _token_seq_idx, (
|
||||
A, dt, X, B, C) in generate_continuous_batched_examples(
|
||||
cases, num_examples, seqlen, last_taken, exhausted, n_heads,
|
||||
d_head, itype):
|
||||
|
||||
chunk_indices, chunk_offsets = \
|
||||
_query_start_loc_to_chunk_indices_offsets(
|
||||
cu_seqlens, chunk_size, cu_seqlens[-1])
|
||||
cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
|
||||
compute_varlen_chunk_metadata(cu_seqlens, chunk_size))
|
||||
|
||||
Y = torch.empty_like(X)
|
||||
new_states = mamba_chunk_scan_combined_varlen(
|
||||
@@ -329,13 +326,13 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
|
||||
B,
|
||||
C,
|
||||
chunk_size,
|
||||
D=None,
|
||||
cu_seqlens=cu_seqlens,
|
||||
seq_idx=seq_idx,
|
||||
chunk_indices=chunk_indices,
|
||||
chunk_offsets=chunk_offsets,
|
||||
initial_states=states,
|
||||
cu_seqlens=cu_seqlens.to(torch.int32),
|
||||
cu_chunk_seqlens=cu_chunk_seqlens,
|
||||
last_chunk_indices=last_chunk_indices,
|
||||
seq_idx=seq_idx_chunks,
|
||||
out=Y,
|
||||
D=None,
|
||||
initial_states=states,
|
||||
)
|
||||
|
||||
# just test the last in sequence
|
||||
@@ -403,9 +400,8 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
|
||||
device = X.device
|
||||
|
||||
## full seqlen computation
|
||||
chunk_indices, chunk_offsets = \
|
||||
_query_start_loc_to_chunk_indices_offsets(
|
||||
cu_seqlens, chunk_size, cu_seqlens[-1])
|
||||
cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
|
||||
compute_varlen_chunk_metadata(cu_seqlens, chunk_size))
|
||||
Y_ref = torch.empty_like(X)
|
||||
state_ref = mamba_chunk_scan_combined_varlen(
|
||||
X,
|
||||
@@ -414,13 +410,13 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
|
||||
B,
|
||||
C,
|
||||
chunk_size,
|
||||
D=None,
|
||||
cu_seqlens=cu_seqlens,
|
||||
seq_idx=seq_idx,
|
||||
chunk_indices=chunk_indices,
|
||||
chunk_offsets=chunk_offsets,
|
||||
initial_states=None,
|
||||
cu_seqlens=cu_seqlens.to(torch.int32),
|
||||
cu_chunk_seqlens=cu_chunk_seqlens,
|
||||
last_chunk_indices=last_chunk_indices,
|
||||
seq_idx=seq_idx_chunks,
|
||||
out=Y_ref,
|
||||
D=None,
|
||||
initial_states=None,
|
||||
)
|
||||
|
||||
## chunked seqlen computation
|
||||
@@ -431,10 +427,6 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
|
||||
torch.cumsum(chunked_seqlens, dim=0)
|
||||
],
|
||||
dim=0)
|
||||
chunked_seq_idx = torch.repeat_interleave(
|
||||
torch.arange(len(chunked_seqlens), device=device),
|
||||
chunked_seqlens,
|
||||
output_size=chunked_cu_seqlens[-1]).to(torch.int32)
|
||||
chunked_input_seq_len = chunked_cu_seqlens[-1]
|
||||
X_chunked = torch.zeros_like(X)[:chunked_input_seq_len, ...]
|
||||
dt_chunked = torch.zeros_like(dt)[:chunked_input_seq_len, ...]
|
||||
@@ -450,9 +442,8 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
|
||||
C_chunked[chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(C, i) # noqa: E501
|
||||
# fmt: on
|
||||
|
||||
chunk_indices, chunk_offsets = \
|
||||
_query_start_loc_to_chunk_indices_offsets(
|
||||
chunked_cu_seqlens, chunk_size, chunked_cu_seqlens[-1])
|
||||
cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
|
||||
compute_varlen_chunk_metadata(chunked_cu_seqlens, chunk_size))
|
||||
Y_partial = torch.empty_like(X_chunked)
|
||||
partial_state = mamba_chunk_scan_combined_varlen(
|
||||
X_chunked,
|
||||
@@ -461,13 +452,13 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
|
||||
B_chunked,
|
||||
C_chunked,
|
||||
chunk_size,
|
||||
D=None,
|
||||
cu_seqlens=chunked_cu_seqlens,
|
||||
seq_idx=chunked_seq_idx,
|
||||
chunk_indices=chunk_indices,
|
||||
chunk_offsets=chunk_offsets,
|
||||
initial_states=None,
|
||||
cu_seqlens=chunked_cu_seqlens.to(torch.int32),
|
||||
cu_chunk_seqlens=cu_chunk_seqlens,
|
||||
last_chunk_indices=last_chunk_indices,
|
||||
seq_idx=seq_idx_chunks,
|
||||
out=Y_partial,
|
||||
D=None,
|
||||
initial_states=None,
|
||||
)
|
||||
|
||||
# remaining chunk
|
||||
@@ -477,10 +468,6 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
|
||||
torch.cumsum(remaining_chunked_seqlens, dim=0)
|
||||
],
|
||||
dim=0)
|
||||
remaining_chunked_seq_idx = torch.repeat_interleave(
|
||||
torch.arange(len(remaining_chunked_seqlens), device=device),
|
||||
remaining_chunked_seqlens,
|
||||
output_size=remaining_chunked_cu_seqlens[-1]).to(torch.int32)
|
||||
remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1]
|
||||
# fmt: off
|
||||
remaining_X_chunked = torch.zeros_like(X)[:remaining_chunked_input_seq_len, ...] # noqa: E501
|
||||
@@ -509,11 +496,9 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
|
||||
assert concat_batch_f(B_chunked, remaining_B_chunked).equal(B)
|
||||
assert concat_batch_f(C_chunked, remaining_C_chunked).equal(C)
|
||||
|
||||
chunk_indices, chunk_offsets = \
|
||||
_query_start_loc_to_chunk_indices_offsets(
|
||||
remaining_chunked_cu_seqlens,
|
||||
chunk_size,
|
||||
remaining_chunked_cu_seqlens[-1])
|
||||
cu_chunk_seqlens, last_chunk_indices, seq_idx_chunks = (
|
||||
compute_varlen_chunk_metadata(remaining_chunked_cu_seqlens,
|
||||
chunk_size))
|
||||
|
||||
Y_chunked = torch.empty_like(remaining_X_chunked)
|
||||
state_chunked = mamba_chunk_scan_combined_varlen(
|
||||
@@ -523,13 +508,13 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
|
||||
remaining_B_chunked,
|
||||
remaining_C_chunked,
|
||||
chunk_size,
|
||||
D=None,
|
||||
cu_seqlens=remaining_chunked_cu_seqlens,
|
||||
seq_idx=remaining_chunked_seq_idx,
|
||||
chunk_indices=chunk_indices,
|
||||
chunk_offsets=chunk_offsets,
|
||||
initial_states=partial_state,
|
||||
cu_seqlens=remaining_chunked_cu_seqlens.to(torch.int32),
|
||||
cu_chunk_seqlens=cu_chunk_seqlens,
|
||||
last_chunk_indices=last_chunk_indices,
|
||||
seq_idx=seq_idx_chunks,
|
||||
out=Y_chunked,
|
||||
D=None,
|
||||
initial_states=partial_state,
|
||||
)
|
||||
Y = concat_batch_f(Y_partial, Y_chunked)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user