[Model] Mamba2 preallocate SSM output tensor to avoid d2d copy overhead (#21075)

Signed-off-by: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
This commit is contained in:
Chih-Chieh Yang
2025-08-02 04:59:34 -04:00
committed by GitHub
parent 25373b6c6c
commit b690e34824
9 changed files with 144 additions and 118 deletions

View File

@@ -220,7 +220,8 @@ class MambaMixer(CustomOp):
has_initial_state=attn_metadata.context_lens_tensor > 0,
query_start_loc=attn_metadata.query_start_loc)
else:
scan_outputs = selective_state_update(
scan_outputs = torch.empty_like(hidden_states.transpose(0, 1))
selective_state_update(
mamba_cache_params.ssm_state,
hidden_states.transpose(0, 1),
discrete_time_step.transpose(0, 1),
@@ -231,7 +232,8 @@ class MambaMixer(CustomOp):
gate.transpose(0, 1),
time_proj_bias,
dt_softplus=True,
state_batch_indices=mamba_cache_params.state_indices_tensor)
state_batch_indices=mamba_cache_params.state_indices_tensor,
out=scan_outputs)
scan_outputs = scan_outputs.transpose(0, 1)
# 4. Final linear projection