[Model] Mamba2 preallocate SSM output tensor to avoid d2d copy overhead (#21075)
Signed-off-by: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com> Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
This commit is contained in:
@@ -365,6 +365,7 @@ def test_selective_state_update(dim, dstate, has_z, itype):
|
||||
batch_size = 1
|
||||
state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
|
||||
x = torch.randn(batch_size, dim, device=device, dtype=itype)
|
||||
out = torch.empty_like(x)
|
||||
dt = torch.randn(batch_size, dim, device=device, dtype=itype)
|
||||
dt_bias = torch.rand(dim, device=device) - 4.0
|
||||
A = -torch.rand(dim, dstate, device=device) - 1.0
|
||||
@@ -373,16 +374,17 @@ def test_selective_state_update(dim, dstate, has_z, itype):
|
||||
D = torch.randn(dim, device=device)
|
||||
z = torch.randn_like(x) if has_z else None
|
||||
state_ref = state.detach().clone()
|
||||
out = selective_state_update(state,
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=D,
|
||||
z=z,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True)
|
||||
selective_state_update(state,
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=D,
|
||||
z=z,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True,
|
||||
out=out)
|
||||
out_ref = selective_state_update_ref(state_ref,
|
||||
x,
|
||||
dt,
|
||||
@@ -581,6 +583,7 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
|
||||
],
|
||||
dim=0)
|
||||
x = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
|
||||
out = torch.empty_like(x)
|
||||
dt = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
|
||||
dt_bias = torch.rand(dim, device=device) - 4.0
|
||||
A = -torch.rand(dim, dstate, device=device) - 1.0
|
||||
@@ -590,18 +593,19 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
|
||||
z = torch.randn_like(x) if has_z else None
|
||||
state_ref = state[state_indices, :].clone()
|
||||
state_before = state.clone()
|
||||
out = selective_state_update(state,
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=D,
|
||||
z=z,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True,
|
||||
state_batch_indices=padded_state_indices,
|
||||
pad_slot_id=PAD_SLOT_ID)
|
||||
selective_state_update(state,
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=D,
|
||||
z=z,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True,
|
||||
state_batch_indices=padded_state_indices,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
out=out)
|
||||
out_ref = selective_state_update_ref(state_ref,
|
||||
x[:batch_size],
|
||||
dt[:batch_size],
|
||||
@@ -665,6 +669,7 @@ def test_selective_state_update_with_heads_with_batch_indices(
|
||||
dtype=torch.int32, device=device)
|
||||
|
||||
x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
|
||||
out = torch.empty_like(x)
|
||||
if not tie_hdim:
|
||||
dt = torch.randn(batch_size,
|
||||
nheads,
|
||||
@@ -691,18 +696,19 @@ def test_selective_state_update_with_heads_with_batch_indices(
|
||||
C = torch.randn(batch_size, ngroups, dstate, device=device)
|
||||
z = torch.randn_like(x) if has_z else None
|
||||
state_ref = state[state_indices, :].detach().clone()
|
||||
out = selective_state_update(state,
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=D,
|
||||
z=z,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True,
|
||||
state_batch_indices=state_indices,
|
||||
pad_slot_id=PAD_SLOT_ID)
|
||||
selective_state_update(state,
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=D,
|
||||
z=z,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True,
|
||||
state_batch_indices=state_indices,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
out=out)
|
||||
out_ref = selective_state_update_ref(state_ref,
|
||||
x,
|
||||
dt,
|
||||
|
||||
@@ -212,15 +212,16 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
|
||||
|
||||
Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt,
|
||||
B, C, chunk_size)
|
||||
|
||||
Y, final_state = mamba_chunk_scan_combined(X,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
chunk_size,
|
||||
D=None,
|
||||
return_final_states=True)
|
||||
Y = torch.empty_like(X)
|
||||
final_state = mamba_chunk_scan_combined(X,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
chunk_size,
|
||||
D=None,
|
||||
return_final_states=True,
|
||||
out=Y)
|
||||
|
||||
# just test the last in sequence
|
||||
torch.testing.assert_close(Y[:, -1], Y_min[:, -1], atol=atol, rtol=rtol)
|
||||
@@ -292,7 +293,8 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
|
||||
_query_start_loc_to_chunk_indices_offsets(
|
||||
cu_seqlens, chunk_size, cu_seqlens[-1])
|
||||
|
||||
Y, new_states = mamba_chunk_scan_combined(
|
||||
Y = torch.empty_like(X)
|
||||
new_states = mamba_chunk_scan_combined(
|
||||
X,
|
||||
dt,
|
||||
A,
|
||||
@@ -306,6 +308,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
|
||||
chunk_offsets=chunk_offsets,
|
||||
return_varlen_states=True,
|
||||
initial_states=states,
|
||||
out=Y,
|
||||
)
|
||||
|
||||
# just test the last in sequence
|
||||
|
||||
Reference in New Issue
Block a user