[Model] Mamba2 preallocate SSM output tensor to avoid d2d copy overhead (#21075)
Signed-off-by: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com> Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
This commit is contained in:
@@ -220,7 +220,8 @@ class MambaMixer(CustomOp):
|
||||
has_initial_state=attn_metadata.context_lens_tensor > 0,
|
||||
query_start_loc=attn_metadata.query_start_loc)
|
||||
else:
|
||||
scan_outputs = selective_state_update(
|
||||
scan_outputs = torch.empty_like(hidden_states.transpose(0, 1))
|
||||
selective_state_update(
|
||||
mamba_cache_params.ssm_state,
|
||||
hidden_states.transpose(0, 1),
|
||||
discrete_time_step.transpose(0, 1),
|
||||
@@ -231,7 +232,8 @@ class MambaMixer(CustomOp):
|
||||
gate.transpose(0, 1),
|
||||
time_proj_bias,
|
||||
dt_softplus=True,
|
||||
state_batch_indices=mamba_cache_params.state_indices_tensor)
|
||||
state_batch_indices=mamba_cache_params.state_indices_tensor,
|
||||
out=scan_outputs)
|
||||
scan_outputs = scan_outputs.transpose(0, 1)
|
||||
|
||||
# 4. Final linear projection
|
||||
|
||||
Reference in New Issue
Block a user