[V1] [Hybrid] Some additional clean-up in Mamba2 prefix caching (#26222)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
@@ -595,21 +595,32 @@ class MambaMixer2(MambaBase, CustomOp):
         if prefix_caching_enabled:
             # If prefix caching is enabled, retrieve the relevant variables
             # for prefill and decode
-            last_state_idx_d, last_state_idx_p = torch.split(
-                attn_metadata.last_state_idx, [num_decodes, num_prefills], dim=0
-            )
+            block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
+                torch.split(
+                    attn_metadata.block_idx_last_computed_token,
+                    [num_decodes, num_prefills],
+                    dim=0,
+                )
+            )
-            current_last_idx_d, current_last_idx_p = torch.split(
-                attn_metadata.current_last_idx, [num_decodes, num_prefills], dim=0
-            )
+            block_idx_last_scheduled_token_d, block_idx_last_scheduled_token_p = (
+                torch.split(
+                    attn_metadata.block_idx_last_scheduled_token,
+                    [num_decodes, num_prefills],
+                    dim=0,
+                )
+            )
             # Prefill-only variables:
-            current_first_idx_p = attn_metadata.current_first_idx_p
             context_lens_p = attn_metadata.context_lens_p
             last_computed_offset_p = attn_metadata.last_computed_offset_p
+            block_idx_first_scheduled_token_p = (
+                attn_metadata.block_idx_first_scheduled_token_p
+            )
+            num_computed_tokens_p = attn_metadata.num_computed_tokens_p
         else:
-            last_state_idx_d, last_state_idx_p = None, None
-            current_last_idx_d, current_last_idx_p = None, None
-            current_first_idx_p = None
             context_lens_p = None
+            block_idx_last_computed_token_d = None
+            block_idx_last_computed_token_p = None
+            block_idx_last_scheduled_token_d = None
+            block_idx_last_scheduled_token_p = None
+            block_idx_first_scheduled_token_p = None
+            num_computed_tokens_p = None

         # Preallocate output tensor to avoid memcpy cost for merging prefill
         # and decode outputs

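The split above assumes the batch is laid out with the num_decodes decode requests first, followed by the num_prefills prefill requests, so each per-request metadata tensor can be cut into a decode half (*_d) and a prefill half (*_p). A minimal standalone sketch of that pattern (not part of the commit, values made up):

    import torch

    num_decodes, num_prefills = 2, 3
    # One block index per request; decode requests assumed to come first.
    block_idx_last_computed_token = torch.tensor([4, 7, 0, 3, 5])

    d, p = torch.split(
        block_idx_last_computed_token, [num_decodes, num_prefills], dim=0
    )
    print(d.tolist())  # [4, 7]    -> the *_d half, used on the decode path
    print(p.tolist())  # [0, 3, 5] -> the *_p half, used on the prefill path
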
@@ -637,7 +648,8 @@ class MambaMixer2(MambaBase, CustomOp):
             # to by "state_indices_tensor_p".
             # In particular, it will always write the state at the
             # sequence end.
-            # In addition, "current_first_idx_p" and "current_last_idx_p"
+            # In addition, "block_idx_first_scheduled_token_p" and
+            # "block_idx_last_scheduled_token_p"
             # are provided (which are pointers into
             # "state_indices_tensor_p"), it will write additional cache
             # states aligned at "block_size_to_align".

@@ -652,10 +664,10 @@ class MambaMixer2(MambaBase, CustomOp):
                 conv_states=conv_state,
                 has_initial_state=has_initial_states_p,
                 cache_indices=state_indices_tensor_p,
-                current_first_idx=current_first_idx_p,
-                current_last_idx=current_last_idx_p,
-                initial_state_idx=last_state_idx_p,
-                context_lens=context_lens_p,
+                block_idx_first_scheduled_token=block_idx_first_scheduled_token_p,
+                block_idx_last_scheduled_token=block_idx_last_scheduled_token_p,
+                initial_state_idx=block_idx_last_computed_token_p,
+                num_computed_tokens=num_computed_tokens_p,
                 block_size_to_align=mamba_block_size,
                 metadata=attn_metadata,
                 query_start_loc=query_start_loc_p,

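With the renamed arguments, the conv kernel receives, per prefill sequence, the row of state_indices_tensor_p to write into and the logical slots between block_idx_first_scheduled_token_p and block_idx_last_scheduled_token_p that should receive block-aligned states, with the sequence-end state always going to the last scheduled slot. A minimal sketch of that slot selection (not part of the commit, values made up), mirroring what the Python-side SSM path further below does explicitly:

    import torch

    # One prefill sequence, five logical block slots mapped to physical cache blocks.
    state_indices_tensor_p = torch.tensor([[10, 11, 12, 13, 14]])
    block_idx_first_scheduled_token_p = torch.tensor([1])
    block_idx_last_scheduled_token_p = torch.tensor([3])

    seq = 0
    aligned_slots = state_indices_tensor_p[
        seq,
        block_idx_first_scheduled_token_p[seq] : block_idx_last_scheduled_token_p[seq],
    ]
    print(aligned_slots.tolist())  # [11, 12] -> blocks receiving block-aligned states
    # The sequence-end state always lands in the last scheduled slot:
    print(state_indices_tensor_p[seq, block_idx_last_scheduled_token_p[seq]].item())  # 13
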
@@ -669,7 +681,7 @@ class MambaMixer2(MambaBase, CustomOp):
             kernel_ssm_indices = state_indices_tensor_p
             if prefix_caching_enabled:
                 kernel_ssm_indices = state_indices_tensor_p.gather(
-                    1, last_state_idx_p.unsqueeze(1)
+                    1, block_idx_last_computed_token_p.unsqueeze(1)
                 ).squeeze(1)
             initial_states = torch.where(
                 has_initial_states_p[:, None, None, None],

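The gather above looks up, for every prefill sequence, the physical cache block that holds the last already-computed state, which is then used as the kernel's initial SSM state. A minimal standalone sketch of the indexing (not part of the commit, shapes and values made up):

    import torch

    # (num_prefills=2, max_blocks=4): physical cache block id per logical slot
    state_indices_tensor_p = torch.tensor([[10, 11, 12, 13],
                                           [20, 21, 22, 23]])
    # Logical slot holding the last computed (cached) state, per sequence
    block_idx_last_computed_token_p = torch.tensor([1, 3])

    kernel_ssm_indices = state_indices_tensor_p.gather(
        1, block_idx_last_computed_token_p.unsqueeze(1)
    ).squeeze(1)
    print(kernel_ssm_indices.tolist())  # [11, 23]
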
@@ -703,52 +715,76 @@ class MambaMixer2(MambaBase, CustomOp):
             )

             if prefix_caching_enabled:
-                # Save states for sequences with more than just the final state:
-                n_blocks_to_fill = current_last_idx_p - current_first_idx_p
-                for seq_idx in (n_blocks_to_fill > 0).nonzero().squeeze(1):
+                # The chunk_stride is the number of chunks per mamba block
+                # e.g., if mamba_block_size = 512 and chunk_size = 256,
+                # then chunk_stride = 2
+                chunk_stride = mamba_block_size // chunk_size
+
+                # Save state for sequences with more than just final state
+                for seq_idx in range(num_prefills):
+                    # Block index for the first scheduled token
+                    block_idx_first_scheduled_token = block_idx_first_scheduled_token_p[
+                        seq_idx
+                    ]
+
+                    # Block index for the last scheduled token
+                    block_idx_last_scheduled_token = block_idx_last_scheduled_token_p[
+                        seq_idx
+                    ]
+
+                    # Number of blocks that need to be written
+                    n_blocks_to_fill = (
+                        block_idx_last_scheduled_token - block_idx_first_scheduled_token
+                    )
+
+                    # Skip sequences that don't have any blocks to fill
+                    if n_blocks_to_fill == 0:
+                        continue
+
+                    # Look up the state indices
                     cache_blocks_to_fill = state_indices_tensor_p[
                         seq_idx,
-                        current_first_idx_p[seq_idx] : current_first_idx_p[seq_idx]
-                        + n_blocks_to_fill[seq_idx],
+                        block_idx_first_scheduled_token:block_idx_last_scheduled_token,
                     ]
-                    # chunks = [0 1 2 3 4 5 6 ...]
-                    # First aligned chunk would typically be:
-                    # mamba_block_size = 1024, chunk_size = 256
-                    # 1024 // 256 - 1 --> chunks[3]
-                    # But when last chunk wasn't block aligned:
-                    # - last_computed_offset_p[seq_idx] // chunk_size
-                    # e.g. 1000 // 256 -> 3 completed --> store chunk[0]
-                    # e.g. 513 // 256 -> 2 completed --> store chunk[1] (skip 1)
-                    # e.g. 256 // 256 -> 1 completed --> store chunk[2] (skip 2)
-                    # e.g. 10 // 256 -> 0 completed --> store chunk[3] (skip 3)
-                    chunk_stride = mamba_block_size // chunk_size
-                    first_aligned_chunk = (
-                        torch.concat(
-                            [
-                                torch.zeros(
-                                    1,
-                                    dtype=last_chunk_indices_p.dtype,
-                                    device=last_chunk_indices_p.device,
-                                ),
-                                last_chunk_indices_p + 1,
-                            ]
-                        )[seq_idx]
-                        + chunk_stride
-                        - 1
-                        - last_computed_offset_p[seq_idx] // chunk_size
-                    )
+
+                    # First chunk index for this sequence
+                    if seq_idx == 0:
+                        first_chunk = 0
+                    else:
+                        first_chunk = 1 + last_chunk_indices_p[seq_idx - 1]
+
+                    # First chunk that is aligned on the mamba block boundary
+                    first_aligned_chunk = first_chunk + chunk_stride - 1
+
+                    # Calculate the number of computed tokens that were not
+                    # already cached
+                    num_unaligned_computed_tokens = (
+                        num_computed_tokens_p[seq_idx] % mamba_block_size
+                    )
+
+                    if num_unaligned_computed_tokens > 0:
+                        # If the number of computed tokens is not block aligned,
+                        # then we need to shift the index accordingly
+                        first_aligned_chunk -= (
+                            num_unaligned_computed_tokens // chunk_size
+                        )
+
                     # Get states to write
                     from_where = varlen_states[
                         first_aligned_chunk : first_aligned_chunk
-                        + n_blocks_to_fill[seq_idx] * chunk_stride : chunk_stride
+                        + n_blocks_to_fill * chunk_stride : chunk_stride
                     ]
+
                     # Write the states
                     ssm_state[cache_blocks_to_fill] = from_where

-                # For all seqs, store the last state (Note: might be partial):
+                # For all seqs, store the last state (note: might be partial):
                 ssm_state[
                     state_indices_tensor_p.gather(
-                        1, current_last_idx_p.unsqueeze(1)
+                        1, block_idx_last_scheduled_token_p.unsqueeze(1)
                     ).squeeze(1)
                 ] = varlen_states[last_chunk_indices_p]

             else:
                 # update ssm states
                 # - varlen state is a (num_prefills, nheads, headdim, dstate)

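The rewritten loop replaces the old torch.concat based index computation with per-sequence scalar arithmetic: the chunk landing on the first mamba block boundary is first_chunk + chunk_stride - 1, shifted left by however many full chunks of the current block were already computed. A minimal sketch of that arithmetic (not part of the commit), reusing the example numbers from the removed comment (mamba_block_size = 1024, chunk_size = 256):

    mamba_block_size, chunk_size = 1024, 256
    chunk_stride = mamba_block_size // chunk_size        # 4 chunks per mamba block

    def first_aligned_chunk(first_chunk, num_computed_tokens):
        idx = first_chunk + chunk_stride - 1             # fully aligned history
        num_unaligned = num_computed_tokens % mamba_block_size
        if num_unaligned > 0:
            idx -= num_unaligned // chunk_size           # shift for a partial block
        return idx

    # first_chunk = 0 for the first prefill sequence in the batch
    print(first_aligned_chunk(0, 0))     # 3 -> chunk[3]
    print(first_aligned_chunk(0, 1000))  # 0 -> chunk[0]  (1000 // 256 = 3 completed)
    print(first_aligned_chunk(0, 513))   # 1 -> chunk[1]  (513 // 256 = 2 completed)
    print(first_aligned_chunk(0, 10))    # 3 -> chunk[3]  (10 // 256 = 0 completed)
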
@@ -759,14 +795,17 @@ class MambaMixer2(MambaBase, CustomOp):
         if has_decode:
             if prefix_caching_enabled:
                 state_indices_tensor_d_input = state_indices_tensor_d.gather(
-                    1, last_state_idx_d.unsqueeze(1)
+                    1, block_idx_last_computed_token_d.unsqueeze(1)
                 ).squeeze(1)
                 state_indices_tensor_d_output = state_indices_tensor_d.gather(
-                    1, current_last_idx_d.unsqueeze(1)
+                    1, block_idx_last_scheduled_token_d.unsqueeze(1)
                 ).squeeze(1)
                 # Note:
-                # for decode always: current_first_idx_d == current_last_idx_d
-                # at block boundaries: current_first_idx_d > last_state_idx_d
+                # for decode:
+                # block_idx_first_scheduled_token_d ==
+                # block_idx_last_scheduled_token_d
+                # at block boundaries:
+                # block_idx_first_scheduled_token_d >
+                # block_idx_last_computed_token_d
             else:
                 # Without caching, read and write in-place to the same blocks:
                 state_indices_tensor_d_input = state_indices_tensor_d

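The note covers the decode boundary case: a decode step schedules exactly one token, so the first and last scheduled block indices coincide, and only when that token starts a new mamba block does the scheduled block differ from the last computed one (state is then read from one block and written to the next). A minimal sketch of that case (not part of the commit), under the assumption that a token at 0-based position pos maps to logical block pos // mamba_block_size:

    mamba_block_size = 1024

    def block_idx(pos):
        return pos // mamba_block_size

    num_computed_tokens = 1024                       # decode token starts a new block
    block_idx_last_computed_token_d = block_idx(num_computed_tokens - 1)  # 0
    block_idx_last_scheduled_token_d = block_idx(num_computed_tokens)     # 1
    # Read the initial state from block 0, write the updated state to block 1.
    print(block_idx_last_computed_token_d, block_idx_last_scheduled_token_d)  # 0 1
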
@@ -780,8 +819,8 @@ class MambaMixer2(MambaBase, CustomOp):
                 self.conv1d.bias,
                 self.activation,
                 conv_state_indices=state_indices_tensor_d,
-                current_last_idx=current_last_idx_d,
-                initial_state_idx=last_state_idx_d,
+                block_idx_last_scheduled_token=block_idx_last_scheduled_token_d,
+                initial_state_idx=block_idx_last_computed_token_d,
             )

             hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(hidden_states_B_C_d)