[Model] Mamba2 preallocate SSM output tensor to avoid d2d copy overhead (#21075)

Signed-off-by: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
This commit is contained in:
Chih-Chieh Yang
2025-08-02 04:59:34 -04:00
committed by GitHub
parent 25373b6c6c
commit b690e34824
9 changed files with 144 additions and 118 deletions

View File

@@ -387,7 +387,8 @@ class Phi4Mamba(nn.Module):
has_initial_state=attn_metadata.context_lens_tensor > 0,
query_start_loc=attn_metadata.query_start_loc)
else:
scan_outputs = selective_state_update(
scan_outputs = torch.empty_like(hidden_states.transpose(0, 1))
selective_state_update(
mamba_cache_params.ssm_state,
hidden_states.transpose(0, 1),
discrete_time_step.transpose(0, 1),
@@ -400,7 +401,8 @@ class Phi4Mamba(nn.Module):
None if self.yoco_kv else gate.transpose(0, 1),
time_proj_bias,
dt_softplus=True,
state_batch_indices=mamba_cache_params.state_indices_tensor)
state_batch_indices=mamba_cache_params.state_indices_tensor,
out=scan_outputs)
scan_outputs = scan_outputs.transpose(0, 1)
# 4. Final linear projection

View File

@@ -257,7 +257,21 @@ class Plamo2MambaMixer(nn.Module):
query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + 1]
if has_prefill else None)
ssd_output_list = []
# Preallocate output tensor to avoid memcpy cost for merging prefill
# and decode outputs
preallocated_ssm_out = torch.empty(
[
num_prefill_tokens + num_decodes,
(self.num_heads // self.tp_size) * self.head_dim
],
dtype=hidden_states.dtype,
device=hidden_states.device,
)
preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
preallocated_ssm_out,
[num_prefill_tokens, num_decodes],
dim=0,
)
# Process prefill requests
if has_prefill:
@@ -290,7 +304,7 @@ class Plamo2MambaMixer(nn.Module):
initial_states = torch.where(
mamba2_metadata.has_initial_states[:, None, None, None],
mamba_cache_params.ssm_state[state_indices_tensor_p], 0)
scan_output, varlen_state = mamba_chunk_scan_combined(
varlen_state = mamba_chunk_scan_combined(
hidden_states_p.view(1, num_prefill_tokens,
self.num_heads // self.tp_size,
self.head_dim),
@@ -312,15 +326,14 @@ class Plamo2MambaMixer(nn.Module):
return_final_states=False,
dt_softplus=True,
dt_limit=(0.0, float("inf")),
out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1,
self.head_dim),
)
# update ssm states
# - varlen state is a (batch, nheads, headdim, dstate) tensor
mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state
# - reshape
ssd_output_list.append(scan_output.view(num_prefill_tokens, -1))
# Process decode requests
if has_decode:
# 2. Convolution sequence transformation
@@ -349,8 +362,7 @@ class Plamo2MambaMixer(nn.Module):
# - the hidden is reshaped into (bs, num_heads, head_dim)
# - mamba_cache_params.ssm_state's slots will be selected
# using state_indices_tensor_d
hidden_states_d = selective_state_update(
selective_state_update(
mamba_cache_params.ssm_state,
hidden_states_d,
dt,
@@ -362,17 +374,13 @@ class Plamo2MambaMixer(nn.Module):
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=state_indices_tensor_d,
out=preallocated_ssm_out_d.view(num_decodes, -1,
self.head_dim),
)
assert self.num_heads % self.tp_size == 0
ssd_output_list.append(
hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
self.head_dim))
# Merge prefill and decode outputs before passing to MLP
hidden_states = torch.vstack(ssd_output_list)
# 4. Final linear projection
out = self.out_proj(hidden_states)
out = self.out_proj(preallocated_ssm_out)
return out