[Model] Mamba2 preallocate SSM output tensor to avoid d2d copy overhead (#21075)
Signed-off-by: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
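The diff below applies one pattern in two models: instead of letting each SSM kernel allocate its own output tensor and then merging the pieces on device, the caller preallocates a single output buffer and passes views of it to the kernels through their out= argument. A minimal, self-contained sketch of the idea (illustrative PyTorch only, not vLLM code; torch.exp stands in for the SSM kernels):

import torch

num_prefill_tokens, num_decodes, hidden = 6, 2, 4
x = torch.randn(num_prefill_tokens + num_decodes, hidden)

# Before: each phase returns its own tensor and the results are merged
# with torch.vstack, which allocates and does a device-to-device copy.
out_p = torch.exp(x[:num_prefill_tokens])
out_d = torch.exp(x[num_prefill_tokens:])
merged = torch.vstack([out_p, out_d])

# After: preallocate once; torch.split returns views, so writing through
# them via out= fills the final buffer in place and no merge copy remains.
preallocated = torch.empty_like(x)
view_p, view_d = torch.split(
    preallocated, [num_prefill_tokens, num_decodes], dim=0)
torch.exp(x[:num_prefill_tokens], out=view_p)
torch.exp(x[num_prefill_tokens:], out=view_d)
assert torch.equal(preallocated, merged)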
@@ -387,7 +387,8 @@ class Phi4Mamba(nn.Module):
                 has_initial_state=attn_metadata.context_lens_tensor > 0,
                 query_start_loc=attn_metadata.query_start_loc)
         else:
-            scan_outputs = selective_state_update(
+            scan_outputs = torch.empty_like(hidden_states.transpose(0, 1))
+            selective_state_update(
                 mamba_cache_params.ssm_state,
                 hidden_states.transpose(0, 1),
                 discrete_time_step.transpose(0, 1),
@@ -400,7 +401,8 @@ class Phi4Mamba(nn.Module):
                 None if self.yoco_kv else gate.transpose(0, 1),
                 time_proj_bias,
                 dt_softplus=True,
-                state_batch_indices=mamba_cache_params.state_indices_tensor)
+                state_batch_indices=mamba_cache_params.state_indices_tensor,
+                out=scan_outputs)
             scan_outputs = scan_outputs.transpose(0, 1)

         # 4. Final linear projection
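Hedged sketch of the decode-path shape above (illustrative only; torch.exp stands in for selective_state_update, which the diff passes an out= tensor to): the buffer is allocated once with torch.empty_like, the kernel writes into it, and the trailing transpose is a zero-copy view, so no device-to-device copy remains.

import torch

hidden_states = torch.randn(8, 16)  # (seq_len, hidden)
scan_outputs = torch.empty_like(hidden_states.transpose(0, 1))
torch.exp(hidden_states.transpose(0, 1), out=scan_outputs)
scan_outputs = scan_outputs.transpose(0, 1)  # view only, no copy
assert scan_outputs.shape == hidden_states.shape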
@@ -257,7 +257,21 @@ class Plamo2MambaMixer(nn.Module):
         query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + 1]
                              if has_prefill else None)

-        ssd_output_list = []
+        # Preallocate output tensor to avoid memcpy cost for merging prefill
+        # and decode outputs
+        preallocated_ssm_out = torch.empty(
+            [
+                num_prefill_tokens + num_decodes,
+                (self.num_heads // self.tp_size) * self.head_dim
+            ],
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+        preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
+            preallocated_ssm_out,
+            [num_prefill_tokens, num_decodes],
+            dim=0,
+        )

         # Process prefill requests
         if has_prefill:
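A quick check (illustrative) of the property the preallocation above relies on: torch.split returns views over the same storage, so the prefill and decode kernels fill disjoint slices of one buffer and no merge step is needed afterwards.

import torch

buf = torch.empty(5, 3)
p, d = torch.split(buf, [3, 2], dim=0)
assert p.data_ptr() == buf.data_ptr()  # p starts at buf's storage
p.fill_(1.0)
d.fill_(2.0)
assert torch.equal(buf, torch.vstack([p, d]))  # already merged in place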
@@ -290,7 +304,7 @@ class Plamo2MambaMixer(nn.Module):
             initial_states = torch.where(
                 mamba2_metadata.has_initial_states[:, None, None, None],
                 mamba_cache_params.ssm_state[state_indices_tensor_p], 0)
-            scan_output, varlen_state = mamba_chunk_scan_combined(
+            varlen_state = mamba_chunk_scan_combined(
                 hidden_states_p.view(1, num_prefill_tokens,
                                      self.num_heads // self.tp_size,
                                      self.head_dim),
@@ -312,15 +326,14 @@ class Plamo2MambaMixer(nn.Module):
                 return_final_states=False,
                 dt_softplus=True,
                 dt_limit=(0.0, float("inf")),
+                out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1,
+                                                self.head_dim),
             )

             # update ssm states
             # - varlen state is a (batch, nheads, headdim, dstate) tensor
             mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state

-            # - reshape
-            ssd_output_list.append(scan_output.view(num_prefill_tokens, -1))
-
         # Process decode requests
         if has_decode:
             # 2. Convolution sequence transformation
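One more illustrative detail for the out= targets above: preallocated_ssm_out_p.view(...) is legal because a dim-0 slice of a contiguous buffer is itself contiguous, so .view reinterprets it in place with no copy (hypothetical sizes below; torch.exp again stands in for the kernel).

import torch

num_prefill_tokens, num_heads, head_dim = 6, 2, 4
buf = torch.empty(num_prefill_tokens + 2, num_heads * head_dim)
p, _ = torch.split(buf, [num_prefill_tokens, 2], dim=0)
target = p.view(1, num_prefill_tokens, num_heads, head_dim)  # zero-copy
torch.exp(torch.randn(1, num_prefill_tokens, num_heads, head_dim),
          out=target)
assert target.data_ptr() == buf.data_ptr()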
@@ -349,8 +362,7 @@ class Plamo2MambaMixer(nn.Module):
             # - the hidden is reshaped into (bs, num_heads, head_dim)
             # - mamba_cache_params.ssm_state's slots will be selected
             #   using state_indices_tensor_d
-
-            hidden_states_d = selective_state_update(
+            selective_state_update(
                 mamba_cache_params.ssm_state,
                 hidden_states_d,
                 dt,
@@ -362,17 +374,13 @@ class Plamo2MambaMixer(nn.Module):
                 dt_bias=dt_bias,
                 dt_softplus=True,
                 state_batch_indices=state_indices_tensor_d,
+                out=preallocated_ssm_out_d.view(num_decodes, -1,
+                                                self.head_dim),
             )

             assert self.num_heads % self.tp_size == 0
-            ssd_output_list.append(
-                hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
-                                     self.head_dim))
-
-        # Merge prefill and decode outputs before passing to MLP
-        hidden_states = torch.vstack(ssd_output_list)

         # 4. Final linear projection
-        out = self.out_proj(hidden_states)
+        out = self.out_proj(preallocated_ssm_out)
         return out
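End-to-end shape of the rewritten flow, as a hedged sketch (torch.exp stands in for the two kernels; out_proj is a stand-in Linear): both kernel outputs land in one buffer, which feeds the final projection directly, with no torch.vstack in between.

import torch

out_proj = torch.nn.Linear(8, 8)  # stand-in for self.out_proj
preallocated_ssm_out = torch.empty(5, 8)
p, d = torch.split(preallocated_ssm_out, [3, 2], dim=0)
torch.exp(torch.randn(3, 8), out=p)  # stands in for the prefill kernel
torch.exp(torch.randn(2, 8), out=d)  # stands in for the decode kernel
out = out_proj(preallocated_ssm_out)  # merged buffer consumed directly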