[Model] Mamba2 preallocate SSM output tensor to avoid d2d copy overhead (#21075)

Signed-off-by: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
This commit is contained in:
Chih-Chieh Yang
2025-08-02 04:59:34 -04:00
committed by GitHub
parent 25373b6c6c
commit b690e34824
9 changed files with 144 additions and 118 deletions

View File

@@ -365,6 +365,7 @@ def test_selective_state_update(dim, dstate, has_z, itype):
batch_size = 1
state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
x = torch.randn(batch_size, dim, device=device, dtype=itype)
out = torch.empty_like(x)
dt = torch.randn(batch_size, dim, device=device, dtype=itype)
dt_bias = torch.rand(dim, device=device) - 4.0
A = -torch.rand(dim, dstate, device=device) - 1.0
@@ -373,16 +374,17 @@ def test_selective_state_update(dim, dstate, has_z, itype):
D = torch.randn(dim, device=device)
z = torch.randn_like(x) if has_z else None
state_ref = state.detach().clone()
out = selective_state_update(state,
x,
dt,
A,
B,
C,
D=D,
z=z,
dt_bias=dt_bias,
dt_softplus=True)
selective_state_update(state,
x,
dt,
A,
B,
C,
D=D,
z=z,
dt_bias=dt_bias,
dt_softplus=True,
out=out)
out_ref = selective_state_update_ref(state_ref,
x,
dt,
@@ -581,6 +583,7 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
],
dim=0)
x = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
out = torch.empty_like(x)
dt = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
dt_bias = torch.rand(dim, device=device) - 4.0
A = -torch.rand(dim, dstate, device=device) - 1.0
@@ -590,18 +593,19 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
z = torch.randn_like(x) if has_z else None
state_ref = state[state_indices, :].clone()
state_before = state.clone()
out = selective_state_update(state,
x,
dt,
A,
B,
C,
D=D,
z=z,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=padded_state_indices,
pad_slot_id=PAD_SLOT_ID)
selective_state_update(state,
x,
dt,
A,
B,
C,
D=D,
z=z,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=padded_state_indices,
pad_slot_id=PAD_SLOT_ID,
out=out)
out_ref = selective_state_update_ref(state_ref,
x[:batch_size],
dt[:batch_size],
@@ -665,6 +669,7 @@ def test_selective_state_update_with_heads_with_batch_indices(
dtype=torch.int32, device=device)
x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
out = torch.empty_like(x)
if not tie_hdim:
dt = torch.randn(batch_size,
nheads,
@@ -691,18 +696,19 @@ def test_selective_state_update_with_heads_with_batch_indices(
C = torch.randn(batch_size, ngroups, dstate, device=device)
z = torch.randn_like(x) if has_z else None
state_ref = state[state_indices, :].detach().clone()
out = selective_state_update(state,
x,
dt,
A,
B,
C,
D=D,
z=z,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=state_indices,
pad_slot_id=PAD_SLOT_ID)
selective_state_update(state,
x,
dt,
A,
B,
C,
D=D,
z=z,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=state_indices,
pad_slot_id=PAD_SLOT_ID,
out=out)
out_ref = selective_state_update_ref(state_ref,
x,
dt,