[V1] [Hybrid] Mamba2 Automatic Prefix Caching (#25752)
Signed-off-by: Stanislaw Wozniak <stw@zurich.ibm.com>
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Thomas Ortner <boh@zurich.ibm.com>
Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com>
@@ -489,6 +489,9 @@ class MambaMixer2(MambaBase, CustomOp):
         # stay the same and reused for all mamba layers in the same iteration
         attn_metadata: AttentionMetadata = forward_context.attn_metadata
 
+        assert self.cache_config is not None
+        mamba_block_size = self.cache_config.mamba_block_size
+        prefix_caching_enabled = self.cache_config.enable_prefix_caching
         if attn_metadata is not None:
             assert isinstance(attn_metadata, dict)
             attn_metadata = attn_metadata[self.prefix]
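A toy illustration (not vLLM code; the example value is arbitrary) of what the two cache-config fields read above imply: with `enable_prefix_caching`, a full SSM state is kept for every `mamba_block_size`-token boundary that prefill crosses, plus one final, possibly partial, state at the sequence end.

```python
# Illustrative only: how many block-aligned state snapshots a prompt produces
# when prefix caching is enabled (plus one final, possibly partial, state).
mamba_block_size = 1024  # example value; the real one comes from cache_config
for prompt_len in (100, 1024, 2500):
    full_blocks = prompt_len // mamba_block_size
    print(f"{prompt_len} tokens -> {full_blocks} aligned state(s) + 1 final state")
```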
@@ -573,6 +576,25 @@ class MambaMixer2(MambaBase, CustomOp):
             dim=0,
         )
 
+        if prefix_caching_enabled:
+            # If prefix caching is enabled, retrieve the relevant variables
+            # for prefill and decode
+            last_state_idx_d, last_state_idx_p = torch.split(
+                attn_metadata.last_state_idx, [num_decodes, num_prefills],
+                dim=0)
+            current_last_idx_d, current_last_idx_p = torch.split(
+                attn_metadata.current_last_idx, [num_decodes, num_prefills],
+                dim=0)
+            # Prefill-only variables:
+            current_first_idx_p = attn_metadata.current_first_idx_p
+            context_lens_p = attn_metadata.context_lens_p
+            last_computed_offset_p = attn_metadata.last_computed_offset_p
+        else:
+            last_state_idx_d, last_state_idx_p = None, None
+            current_last_idx_d, current_last_idx_p = None, None
+            current_first_idx_p = None
+            context_lens_p = None
+
         # Preallocate output tensor to avoid memcpy cost for merging prefill
         # and decode outputs
         preallocated_ssm_out = torch.empty(
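The `torch.split` calls above rely on the batch layout in which decode requests precede prefill requests; a minimal standalone sketch with toy values (not vLLM metadata):

```python
import torch

# Per-request metadata is ordered [decodes..., prefills...], so one split
# recovers both groups as views, without copying.
num_decodes, num_prefills = 3, 2
current_last_idx = torch.tensor([0, 1, 2, 5, 7])  # toy block indices
idx_d, idx_p = torch.split(current_last_idx, [num_decodes, num_prefills], dim=0)
print(idx_d.tolist(), idx_p.tolist())  # [0, 1, 2] [5, 7]
```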
@@ -592,8 +614,17 @@ class MambaMixer2(MambaBase, CustomOp):
         # Process prefill requests
         if has_prefill:
             # 2. Convolution sequence transformation
-            # - "cache_indices" updates the conv_state cache in positions
-            #   pointed to by "state_indices_tensor"
+            # - It will read the initial states for every sequence
+            #   that has "has_initial_states_p" == True
+            #   from "cache_indices", using "state_indices_tensor_p".
+            # - It updates the "conv_state" cache in positions pointed
+            #   to by "state_indices_tensor_p".
+            #   In particular, it will always write the state at the
+            #   sequence end.
+            # - In addition, if "current_first_idx_p" and "current_last_idx_p"
+            #   are provided (which are pointers into
+            #   "state_indices_tensor_p"), it will write additional cache
+            #   states aligned at "block_size_to_align".
             x = hidden_states_B_C_p.transpose(
                 0, 1)  # this is the form that causal-conv sees
             hidden_states_B_C_p = causal_conv1d_fn(
@@ -604,6 +635,11 @@ class MambaMixer2(MambaBase, CustomOp):
                 conv_states=conv_state,
                 has_initial_state=has_initial_states_p,
                 cache_indices=state_indices_tensor_p,
+                current_first_idx=current_first_idx_p,
+                current_last_idx=current_last_idx_p,
+                initial_state_idx=last_state_idx_p,
+                context_lens=context_lens_p,
+                block_size_to_align=mamba_block_size,
                 metadata=attn_metadata,
                 query_start_loc=query_start_loc_p).transpose(
                     0, 1)[:num_prefill_tokens]
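A toy model of the write points the comment block above describes (illustrative names and values, not the kernel's internals): a conv state is always written at the sequence end, and, when the index tensors are provided, additionally at every `block_size_to_align` boundary the prefill crosses.

```python
mamba_block_size = 8           # stands in for block_size_to_align
context_len, seq_len = 5, 30   # 5 tokens already cached, 25 new tokens
aligned = [pos for pos in range(context_len + 1, seq_len + 1)
           if pos % mamba_block_size == 0]
print(aligned + [seq_len])     # [8, 16, 24, 30]: aligned states + final state
```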
@@ -614,9 +650,13 @@ class MambaMixer2(MambaBase, CustomOp):
             # 3. State Space Model sequence transformation
             initial_states = None
             if (has_initial_states_p is not None and prep_initial_states):
+                kernel_ssm_indices = state_indices_tensor_p
+                if prefix_caching_enabled:
+                    kernel_ssm_indices = state_indices_tensor_p.gather(
+                        1, last_state_idx_p.unsqueeze(1)).squeeze(1)
                 initial_states = torch.where(
                     has_initial_states_p[:, None, None, None],
-                    ssm_state[state_indices_tensor_p], 0)
+                    ssm_state[kernel_ssm_indices], 0)
 
             # NOTE: final output is an in-place update of out tensor
             varlen_states = mamba_chunk_scan_combined_varlen(
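The `gather` in the hunk above picks, per prefill sequence, the cache block holding its most recently computed state; a standalone sketch with toy block ids:

```python
import torch

# Each row holds the cache-block ids allocated to one prefill sequence;
# last_state_idx_p says which column holds its last cached SSM state.
state_indices_tensor_p = torch.tensor([[10, 11, 12],
                                       [20, 21, 22]])
last_state_idx_p = torch.tensor([2, 0])
kernel_ssm_indices = state_indices_tensor_p.gather(
    1, last_state_idx_p.unsqueeze(1)).squeeze(1)
print(kernel_ssm_indices.tolist())  # [12, 20]
```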
@@ -638,18 +678,71 @@ class MambaMixer2(MambaBase, CustomOp):
                 cu_chunk_seqlens=cu_chunk_seqlen_p,
                 last_chunk_indices=last_chunk_indices_p,
                 initial_states=initial_states,
+                return_intermediate_states=prefix_caching_enabled,
                 dt_softplus=True,
                 dt_limit=(0.0, float("inf")),
                 out=preallocated_ssm_out_p.view(num_prefill_tokens, -1,
                                                 self.head_dim),
                 state_dtype=ssm_state.dtype)
 
-            # update ssm states
-            # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
-            ssm_state[state_indices_tensor_p] = varlen_states
+            if prefix_caching_enabled:
+                # Save states for sequences with more than just the final state:
+                n_blocks_to_fill = current_last_idx_p - current_first_idx_p
+                for seq_idx in (n_blocks_to_fill > 0).nonzero().squeeze(1):
+                    cache_blocks_to_fill = state_indices_tensor_p[
+                        seq_idx, current_first_idx_p[seq_idx]:
+                        current_first_idx_p[seq_idx] +
+                        n_blocks_to_fill[seq_idx]]
+                    # chunks = [0 1 2 3 4 5 6 ...]
+                    # First aligned chunk would typically be:
+                    #   mamba_block_size = 1024, chunk_size = 256
+                    #   1024 // 256 - 1 --> chunks[3]
+                    # But when the last chunk wasn't block-aligned:
+                    # - last_computed_offset_p[seq_idx] // chunk_size
+                    #   e.g. 1000 // 256 -> 3 completed --> store chunk[0]
+                    #   e.g.  513 // 256 -> 2 completed --> store chunk[1] (skip 1)
+                    #   e.g.  256 // 256 -> 1 completed --> store chunk[2] (skip 2)
+                    #   e.g.   10 // 256 -> 0 completed --> store chunk[3] (skip 3)
+                    chunk_stride = mamba_block_size // chunk_size
+                    first_aligned_chunk = \
+                        torch.concat([torch.zeros(1,
+                                      dtype=last_chunk_indices_p.dtype,
+                                      device=last_chunk_indices_p.device),
+                                      last_chunk_indices_p + 1])[seq_idx] \
+                        + chunk_stride - 1 \
+                        - last_computed_offset_p[seq_idx] // chunk_size
+                    from_where = varlen_states[
+                        first_aligned_chunk:first_aligned_chunk +
+                        n_blocks_to_fill[seq_idx] * chunk_stride:chunk_stride]
+                    ssm_state[cache_blocks_to_fill] = from_where
+
+                # For all seqs, store the last state (Note: might be partial):
+                ssm_state[state_indices_tensor_p.gather(
+                    1, current_last_idx_p.unsqueeze(1)).squeeze(1)] = \
+                    varlen_states[last_chunk_indices_p]
+            else:
+                # update ssm states
+                # - varlen state is a (num_prefills, nheads, headdim, dstate)
+                #   tensor
+                ssm_state[state_indices_tensor_p] = varlen_states
 
         # Process decode requests
         if has_decode:
+            if prefix_caching_enabled:
+                state_indices_tensor_d_input = state_indices_tensor_d.gather(
+                    1, last_state_idx_d.unsqueeze(1)).squeeze(1)
+                state_indices_tensor_d_output = state_indices_tensor_d.gather(
+                    1, current_last_idx_d.unsqueeze(1)).squeeze(1)
+                # Note:
+                # for decode always: current_first_idx_d == current_last_idx_d
+                # at block boundaries: current_first_idx_d > last_state_idx_d
+            else:
+                # Without caching, read and write in-place to the same blocks:
+                state_indices_tensor_d_input = state_indices_tensor_d
+                state_indices_tensor_d_output = state_indices_tensor_d
 
             # 2. Convolution sequence transformation
             hidden_states_B_C_d = causal_conv1d_update(
                 hidden_states_B_C_d,
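The chunk-selection arithmetic in the prefill branch above is easiest to check with the values from its own comments (chunk_size = 256, mamba_block_size = 1024, so every 4th intermediate chunk state is block-aligned); a toy verification:

```python
chunk_size, mamba_block_size = 256, 1024
chunk_stride = mamba_block_size // chunk_size  # 4: every 4th chunk state aligns

# seq_start = 0 stands in here for the concat-of-last_chunk_indices_p term,
# which locates the first chunk belonging to this sequence.
seq_start = 0
for last_computed_offset in (1000, 513, 256, 10):
    completed = last_computed_offset // chunk_size
    first_aligned_chunk = seq_start + chunk_stride - 1 - completed
    print(f"offset {last_computed_offset:4d} -> chunk[{first_aligned_chunk}]")
# offset 1000 -> chunk[0], 513 -> chunk[1], 256 -> chunk[2], 10 -> chunk[3]
# (matches the comment table above)
```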
@@ -657,7 +750,10 @@ class MambaMixer2(MambaBase, CustomOp):
                 conv_weights,
                 self.conv1d.bias,
                 self.activation,
-                conv_state_indices=state_indices_tensor_d)
+                conv_state_indices=state_indices_tensor_d,
+                current_last_idx=current_last_idx_d,
+                initial_state_idx=last_state_idx_d,
+            )
 
             hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(
                 hidden_states_B_C_d)
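In the decode path, the read and write slots are gathered separately so that, at a block boundary, the updated state goes into a fresh block while the cached one stays reusable; a toy sketch of the index math (values illustrative):

```python
import torch

state_indices_tensor_d = torch.tensor([[10, 11],
                                       [20, 21]])
last_state_idx_d = torch.tensor([0, 0])    # both sequences read block-0 slot
current_last_idx_d = torch.tensor([0, 1])  # seq 1 just crossed a boundary
read_idx = state_indices_tensor_d.gather(
    1, last_state_idx_d.unsqueeze(1)).squeeze(1)
write_idx = state_indices_tensor_d.gather(
    1, current_last_idx_d.unsqueeze(1)).squeeze(1)
print(read_idx.tolist(), write_idx.tolist())  # [10, 20] [10, 21]
```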
@@ -689,7 +785,8 @@ class MambaMixer2(MambaBase, CustomOp):
                 z=None,
                 dt_bias=dt_bias,
                 dt_softplus=True,
-                state_batch_indices=state_indices_tensor_d,
+                state_batch_indices=state_indices_tensor_d_input,
+                dst_state_batch_indices=state_indices_tensor_d_output,
                 out=preallocated_ssm_out_d.view(num_decodes, -1,
                                                 self.head_dim),
             )
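A pure-PyTorch sketch (not the Triton kernel) of the src/dst split that `dst_state_batch_indices` introduces above: read each sequence's state from one cache slot, write the updated state to another, and leave the source intact when the two differ.

```python
import torch

ssm_state = torch.zeros(4, 2)            # toy cache: 4 slots, tiny states
ssm_state[0] = torch.tensor([1.0, 1.0])  # a previously cached state
src = torch.tensor([0])                  # state_batch_indices (read)
dst = torch.tensor([1])                  # dst_state_batch_indices (write)
new_state = ssm_state[src] + 1.0         # stand-in for the real SSM update
ssm_state[dst] = new_state
print(ssm_state[:2])                     # slot 0 unchanged, slot 1 updated
```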