[Model] Mamba2 causal conv1d Refactor to Split Prefill and Decode Requests for Corresponding Kernels (#17146)

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
This commit is contained in:
Chih-Chieh Yang
2025-05-06 20:59:30 -04:00
committed by GitHub
parent 6de3e13413
commit 18dd5e01f2
8 changed files with 151 additions and 123 deletions

View File

@@ -388,10 +388,15 @@ class MambaMixer2(CustomOp):
# mamba2_metadata contains metadata necessary for the mamba2 triton
# kernels to operate in continuous batching and in chunked prefill
# modes; they are computed at top-level model forward since they
# are the same and reused for all mamba layers in the same iteration
# stay the same and are reused for all mamba layers in the same iteration
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
seq_len, _ = hidden_states.shape
num_prefills = attn_metadata.num_prefills # request count
num_decodes = attn_metadata.num_decode_tokens # token count (=request)
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
has_prefill = num_prefills > 0
has_decode = num_decodes > 0
groups_time_state_size = self.n_groups * self.ssm_state_size
# 1. Gated MLP's linear projection
@@ -406,44 +411,32 @@ class MambaMixer2(CustomOp):
dim=-1,
)
# 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))
if mamba2_metadata.has_prefill:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
# - "cache_indices" updates the conv_state cache in positions
# pointed to by "mamba_cache_params.state_indices_tensor"
hidden_states_B_C = causal_conv1d_fn(
hidden_states_B_C.transpose(0, 1),
conv_weights,
self.conv1d.bias,
activation=self.activation,
conv_states=mamba_cache_params.conv_state,
has_initial_state=mamba2_metadata.has_initial_states,
cache_indices=mamba_cache_params.state_indices_tensor,
query_start_loc=attn_metadata.query_start_loc).transpose(
0, 1)[:seq_len]
# TODO: Why is this needed?
hidden_states_B_C = hidden_states_B_C.contiguous()
else:
hidden_states_B_C = causal_conv1d_update(
hidden_states_B_C,
mamba_cache_params.conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=mamba_cache_params.state_indices_tensor)
# Separate prefill and decode by splitting varlen input
# Split along token dimension
hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
hidden_states_B_C,
[num_prefill_tokens, num_decodes],
dim=0,
)
dt_p, dt_d = torch.split(
dt,
[num_prefill_tokens, num_decodes],
dim=0,
)
# Split along batch dimension
state_indices_tensor_p, state_indices_tensor_d = torch.split(
mamba_cache_params.state_indices_tensor,
[num_prefills, num_decodes],
dim=0,
)
query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + 1]
if has_prefill else None)
# - get hidden_states, B and C after depthwise convolution.
hidden_states, B, C = torch.split(
split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
hidden_states_B_C,
[
self.intermediate_size // self.tp_size,
@@ -453,24 +446,48 @@ class MambaMixer2(CustomOp):
dim=-1,
)
# 3. State Space Model sequence transformation
if mamba2_metadata.has_prefill:
ssd_output_list = []
# Process prefill requests
if has_prefill:
# 2. Convolution sequence transformation
# - "cache_indices" updates the conv_state cache in positions
# pointed to by "state_indices_tensor_p" (the prefill slice of
# "mamba_cache_params.state_indices_tensor")
hidden_states_B_C_p = causal_conv1d_fn(
hidden_states_B_C_p.transpose(0, 1),
conv_weights,
self.conv1d.bias,
activation=self.activation,
conv_states=mamba_cache_params.conv_state,
has_initial_state=mamba2_metadata.has_initial_states,
cache_indices=state_indices_tensor_p,
query_start_loc=query_start_loc_p).transpose(
0, 1)[:num_prefill_tokens]
# TODO: Why is this needed?
hidden_states_B_C_p = hidden_states_B_C_p.contiguous()
hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(
hidden_states_B_C_p)
# 3. State Space Model sequence transformation
initial_states = None
if (mamba2_metadata.has_initial_states is not None
and mamba2_metadata.prep_initial_states):
# making a copy of the states
initial_states = torch.where(
mamba2_metadata.has_initial_states[:, None, None, None],
mamba_cache_params.ssm_state[
mamba_cache_params.state_indices_tensor], 0)
mamba_cache_params.ssm_state[state_indices_tensor_p], 0)
scan_output, varlen_state = mamba_chunk_scan_combined(
hidden_states.view(1, seq_len, self.num_heads // self.tp_size,
self.head_dim),
dt.unsqueeze(0),
hidden_states_p.view(1, num_prefill_tokens,
self.num_heads // self.tp_size,
self.head_dim),
dt_p.unsqueeze(0),
self.A,
B.view(1, seq_len, self.n_groups // self.tp_size, -1),
C.view(1, seq_len, self.n_groups // self.tp_size, -1),
B_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
-1),
C_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
-1),
chunk_size=mamba2_metadata.chunk_size,
D=self.D,
z=None,
@@ -478,7 +495,7 @@ class MambaMixer2(CustomOp):
seq_idx=mamba2_metadata.seq_idx,
chunk_indices=mamba2_metadata.chunk_indices,
chunk_offsets=mamba2_metadata.chunk_offsets,
cu_seqlens=attn_metadata.query_start_loc,
cu_seqlens=attn_metadata.query_start_loc[:num_prefills + 1],
initial_states=initial_states,
return_varlen_states=True,
return_final_states=False,
@@ -487,52 +504,65 @@ class MambaMixer2(CustomOp):
)
# update ssm states
# - varlen state is a (batch, nheads, headdim, dstate) tensor
mamba_cache_params.ssm_state[
mamba_cache_params.state_indices_tensor] = varlen_state
# - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state
# - reshape
hidden_states = scan_output.view(seq_len, -1)
else:
ssd_output_list.append(scan_output.view(num_prefill_tokens, -1))
# Process decode requests
if has_decode:
# 2. Convolution sequence transformation
hidden_states_B_C_d = causal_conv1d_update(
hidden_states_B_C_d,
mamba_cache_params.conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=state_indices_tensor_d)
hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(
hidden_states_B_C_d)
# 3. State Space Model sequence transformation
n_groups = self.n_groups // self.tp_size
A = self.A[:, None, ...][:, :, None].expand(
A_d = self.A[:, None, ...][:, :, None].expand(
-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
dt = dt[:, :, None].expand(-1, -1, self.head_dim)
dt_d = dt_d[:, :, None].expand(-1, -1, self.head_dim)
dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
D = self.D[:, None, ...].expand(-1, self.head_dim)
B = B.view(-1, n_groups, B.shape[1] // n_groups)
C = C.view(-1, n_groups, C.shape[1] // n_groups)
hidden_states_reshaped = hidden_states.view(
D_d = self.D[:, None, ...].expand(-1, self.head_dim)
B_d = B_d.view(-1, n_groups, B_d.shape[1] // n_groups)
C_d = C_d.view(-1, n_groups, C_d.shape[1] // n_groups)
hidden_states_d = hidden_states_d.view(
-1, self.num_heads // self.tp_size, self.head_dim)
# - the hidden is reshaped into number of current batches
# - in this case there is no more prefill, so the batches gen
# 1 token at a time
# - thus hidden will be (bs, num_heads, head_dim)
# - the hidden is reshaped into (bs, num_heads, head_dim)
# - mamba_cache_params.ssm_state's slots will be selected
# using "mamba_cache_params.state_indices_tensor", just as
# above in the prefill case
# using state_indices_tensor_d
hidden_states = selective_state_update(
hidden_states_d = selective_state_update(
mamba_cache_params.ssm_state,
hidden_states_reshaped,
dt,
A,
B,
C,
D,
hidden_states_d,
dt_d,
A_d,
B_d,
C_d,
D_d,
z=None,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=mamba_cache_params.state_indices_tensor,
state_batch_indices=state_indices_tensor_d,
)
hidden_states = hidden_states.view(
-1, (self.num_heads // self.tp_size) * self.head_dim)
ssd_output_list.append(
hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
self.head_dim))
# # 4. gated MLP
# Merge prefill and decode outputs before passing to gated MLP
hidden_states = torch.vstack(ssd_output_list)
# 4. gated MLP
hidden_states = self.norm(hidden_states, gate)
# # 5. Final linear projection
# 5. Final linear projection
out, _ = self.out_proj(hidden_states)
return out