[Model] Mamba2 preallocate SSM output tensor to avoid d2d copy overhead (#21075)

Signed-off-by: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
This commit is contained in:
Chih-Chieh Yang
2025-08-02 04:59:34 -04:00
committed by GitHub
parent 25373b6c6c
commit b690e34824
9 changed files with 144 additions and 118 deletions

View File

@@ -541,7 +541,6 @@ class MambaMixer2(MambaBase, CustomOp):
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
# Separate prefill and decode by splitting varlen input
# Split along token dimension
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
if envs.VLLM_USE_V1:
hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
hidden_states_B_C[:num_actual_tokens],
@@ -583,7 +582,28 @@ class MambaMixer2(MambaBase, CustomOp):
1]
if has_prefill else None)
ssd_output_list = []
# Preallocate output tensor to avoid memcpy cost for merging prefill
# and decode outputs
preallocated_ssm_out = torch.empty(
[
num_prefill_tokens + num_decodes,
(self.num_heads // self.tp_size) * self.head_dim
],
dtype=hidden_states.dtype,
device=hidden_states.device,
)
if envs.VLLM_USE_V1:
preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
preallocated_ssm_out,
[num_decodes, num_prefill_tokens],
dim=0,
)
else:
preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
preallocated_ssm_out,
[num_prefill_tokens, num_decodes],
dim=0,
)
# Process prefill requests
if has_prefill:
@@ -623,7 +643,8 @@ class MambaMixer2(MambaBase, CustomOp):
has_initial_states_p[:num_prefills, None, None, None],
ssm_state[state_indices_tensor_p], 0)
scan_output, varlen_state = mamba_chunk_scan_combined(
# NOTE: final output is an in-place update of out tensor
varlen_state = mamba_chunk_scan_combined(
hidden_states_p.view(1, num_prefill_tokens,
self.num_heads // self.tp_size,
self.head_dim),
@@ -646,15 +667,14 @@ class MambaMixer2(MambaBase, CustomOp):
return_final_states=False,
dt_softplus=True,
dt_limit=(0.0, float("inf")),
out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1,
self.head_dim),
)
# update ssm states
# - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
ssm_state[state_indices_tensor_p] = varlen_state
# - reshape
ssd_output_list.append(scan_output.view(num_prefill_tokens, -1))
# Process decode requests
if has_decode:
# 2. Convolution sequence transformation
@@ -684,8 +704,8 @@ class MambaMixer2(MambaBase, CustomOp):
# - the hidden is reshaped into (bs, num_heads, head_dim)
# - mamba_cache_params.ssm_state's slots will be selected
# using state_indices_tensor_d
hidden_states_d = selective_state_update(
# NOTE: final output is an in-place update of out tensor
selective_state_update(
ssm_state,
hidden_states_d,
dt_d,
@@ -697,26 +717,16 @@ class MambaMixer2(MambaBase, CustomOp):
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=state_indices_tensor_d,
out=preallocated_ssm_out_d.view(num_decodes, -1,
self.head_dim),
)
if envs.VLLM_USE_V1:
ssd_output_list.insert(
0,
hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
self.head_dim))
else:
ssd_output_list.append(
hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
self.head_dim))
# Merge prefill and decode outputs before passing to gated MLP
hidden_states = torch.vstack(ssd_output_list)
# 4. gated MLP
# GatedRMSNorm internally applying SiLU to the gate
# SiLU is applied internally before normalization, unlike standard
# norm usage
hidden_states = self.norm(hidden_states, gate[:num_actual_tokens])
hidden_states = self.norm(preallocated_ssm_out,
gate[:num_actual_tokens])
# 5. Final linear projection
output[:num_actual_tokens], _ = self.out_proj(hidden_states)