[V1] [Hybrid] Some additional clean-up in Mamba2 prefix caching (#26222)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
Thomas Parnell
2025-10-06 04:40:30 +02:00
committed by GitHub
parent d3c84297c3
commit 778f554157
4 changed files with 171 additions and 136 deletions

View File

@@ -595,21 +595,32 @@ class MambaMixer2(MambaBase, CustomOp):
if prefix_caching_enabled:
# If prefix caching is enabled, retrieve the relevant variables
# for prefill and decode
last_state_idx_d, last_state_idx_p = torch.split(
attn_metadata.last_state_idx, [num_decodes, num_prefills], dim=0
block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
torch.split(
attn_metadata.block_idx_last_computed_token,
[num_decodes, num_prefills],
dim=0,
)
)
current_last_idx_d, current_last_idx_p = torch.split(
attn_metadata.current_last_idx, [num_decodes, num_prefills], dim=0
block_idx_last_scheduled_token_d, block_idx_last_scheduled_token_p = (
torch.split(
attn_metadata.block_idx_last_scheduled_token,
[num_decodes, num_prefills],
dim=0,
)
)
# Prefill-only variables:
current_first_idx_p = attn_metadata.current_first_idx_p
context_lens_p = attn_metadata.context_lens_p
last_computed_offset_p = attn_metadata.last_computed_offset_p
block_idx_first_scheduled_token_p = (
attn_metadata.block_idx_first_scheduled_token_p
)
num_computed_tokens_p = attn_metadata.num_computed_tokens_p
else:
last_state_idx_d, last_state_idx_p = None, None
current_last_idx_d, current_last_idx_p = None, None
current_first_idx_p = None
context_lens_p = None
block_idx_last_computed_token_d = None
block_idx_last_computed_token_p = None
block_idx_last_scheduled_token_d = None
block_idx_last_scheduled_token_p = None
block_idx_first_scheduled_token_p = None
num_computed_tokens_p = None
# Preallocate output tensor to avoid memcpy cost for merging prefill
# and decode outputs
@@ -637,7 +648,8 @@ class MambaMixer2(MambaBase, CustomOp):
# to by "state_indices_tensor_p".
# In particular, it will always write the state at the
# sequence end.
# In addition, "current_first_idx_p" and "current_last_idx_p"
# In addition, "block_idx_first_scheduled_token_p" and
# "block_idx_last_scheduled_token_p"
# are provided (which are pointers into
# "state_indices_tensor_p"), it will write additional cache
# states aligned at "block_size_to_align".
@@ -652,10 +664,10 @@ class MambaMixer2(MambaBase, CustomOp):
conv_states=conv_state,
has_initial_state=has_initial_states_p,
cache_indices=state_indices_tensor_p,
current_first_idx=current_first_idx_p,
current_last_idx=current_last_idx_p,
initial_state_idx=last_state_idx_p,
context_lens=context_lens_p,
block_idx_first_scheduled_token=block_idx_first_scheduled_token_p,
block_idx_last_scheduled_token=block_idx_last_scheduled_token_p,
initial_state_idx=block_idx_last_computed_token_p,
num_computed_tokens=num_computed_tokens_p,
block_size_to_align=mamba_block_size,
metadata=attn_metadata,
query_start_loc=query_start_loc_p,
@@ -669,7 +681,7 @@ class MambaMixer2(MambaBase, CustomOp):
kernel_ssm_indices = state_indices_tensor_p
if prefix_caching_enabled:
kernel_ssm_indices = state_indices_tensor_p.gather(
1, last_state_idx_p.unsqueeze(1)
1, block_idx_last_computed_token_p.unsqueeze(1)
).squeeze(1)
initial_states = torch.where(
has_initial_states_p[:, None, None, None],
@@ -703,52 +715,76 @@ class MambaMixer2(MambaBase, CustomOp):
)
if prefix_caching_enabled:
# Save states for sequences with more than just the final state:
n_blocks_to_fill = current_last_idx_p - current_first_idx_p
for seq_idx in (n_blocks_to_fill > 0).nonzero().squeeze(1):
# The chunk_stride is the number of chunks per mamba block
# e.g., if mamba_block_size = 512 and chunk_size = 256,
# then chunk_stride = 2
chunk_stride = mamba_block_size // chunk_size
# Save state for sequences with more than just final state
for seq_idx in range(num_prefills):
# Block index for the first scheduled token
block_idx_first_scheduled_token = block_idx_first_scheduled_token_p[
seq_idx
]
# Block index for the last scheduled token
block_idx_last_scheduled_token = block_idx_last_scheduled_token_p[
seq_idx
]
# Number of blocks that need to be written
n_blocks_to_fill = (
block_idx_last_scheduled_token - block_idx_first_scheduled_token
)
# Skip sequences that don't have any blocks to fill
if n_blocks_to_fill == 0:
continue
# Look up the state indices
cache_blocks_to_fill = state_indices_tensor_p[
seq_idx,
current_first_idx_p[seq_idx] : current_first_idx_p[seq_idx]
+ n_blocks_to_fill[seq_idx],
block_idx_first_scheduled_token:block_idx_last_scheduled_token,
]
# chunks = [0 1 2 3 4 5 6 ...]
# First aligned chunk would typically be:
# mamba_block_size = 1024, chunk_size = 256
# 1024 // 256 - 1 --> chunks[3]
# But when last chunk wasn't block aligned:
# - last_computed_offset_p[seq_idx] // chunk_size
# e.g. 1000 // 256 -> 3 completed --> store chunk[0]
# e.g. 513 // 256 -> 2 completed --> store chunk[1] (skip 1)
# e.g. 256 // 256 -> 1 completed --> store chunk[2] (skip 2)
# e.g. 10 // 256 -> 0 completed --> store chunk[3] (skip 3)
chunk_stride = mamba_block_size // chunk_size
first_aligned_chunk = (
torch.concat(
[
torch.zeros(
1,
dtype=last_chunk_indices_p.dtype,
device=last_chunk_indices_p.device,
),
last_chunk_indices_p + 1,
]
)[seq_idx]
+ chunk_stride
- 1
- last_computed_offset_p[seq_idx] // chunk_size
# First chunk index for this sequence
if seq_idx == 0:
first_chunk = 0
else:
first_chunk = 1 + last_chunk_indices_p[seq_idx - 1]
# First chunk that is aligned on the mamba block boundary
first_aligned_chunk = first_chunk + chunk_stride - 1
# Calculate the number of computed tokens that were not
# already cached
num_unaligned_computed_tokens = (
num_computed_tokens_p[seq_idx] % mamba_block_size
)
if num_unaligned_computed_tokens > 0:
# If the number of computed tokens is not block aligned,
# then we need to shift the index accordingly
first_aligned_chunk -= (
num_unaligned_computed_tokens // chunk_size
)
# Get states to write
from_where = varlen_states[
first_aligned_chunk : first_aligned_chunk
+ n_blocks_to_fill[seq_idx] * chunk_stride : chunk_stride
+ n_blocks_to_fill * chunk_stride : chunk_stride
]
# Write the states
ssm_state[cache_blocks_to_fill] = from_where
# For all seqs, store the last state (Note: might be partial):
# For all seqs, store the last state (note: might be partial):
ssm_state[
state_indices_tensor_p.gather(
1, current_last_idx_p.unsqueeze(1)
1, block_idx_last_scheduled_token_p.unsqueeze(1)
).squeeze(1)
] = varlen_states[last_chunk_indices_p]
else:
# update ssm states
# - varlen state is a (num_prefills, nheads, headdim, dstate)
@@ -759,14 +795,17 @@ class MambaMixer2(MambaBase, CustomOp):
if has_decode:
if prefix_caching_enabled:
state_indices_tensor_d_input = state_indices_tensor_d.gather(
1, last_state_idx_d.unsqueeze(1)
1, block_idx_last_computed_token_d.unsqueeze(1)
).squeeze(1)
state_indices_tensor_d_output = state_indices_tensor_d.gather(
1, current_last_idx_d.unsqueeze(1)
1, block_idx_last_scheduled_token_d.unsqueeze(1)
).squeeze(1)
# Note:
# for decode always: current_first_idx_d == current_last_idx_d
# at block boundaries: current_first_idx_d > last_state_idx_d
# for decode:
# block_idx_first_scheduled_token_d ==
# block_idx_last_scheduled_token_d
# at block boundaries:
# block_idx_first_scheduled_token_d >
# block_idx_last_computed_token_d
else:
# Without caching, read and write in-place to the same blocks:
state_indices_tensor_d_input = state_indices_tensor_d
@@ -780,8 +819,8 @@ class MambaMixer2(MambaBase, CustomOp):
self.conv1d.bias,
self.activation,
conv_state_indices=state_indices_tensor_d,
current_last_idx=current_last_idx_d,
initial_state_idx=last_state_idx_d,
block_idx_last_scheduled_token=block_idx_last_scheduled_token_d,
initial_state_idx=block_idx_last_computed_token_d,
)
hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(hidden_states_B_C_d)

View File

@@ -27,10 +27,10 @@ def _causal_conv1d_fwd_kernel( # continuous batching
query_start_loc_ptr,
batch_ptr,
token_chunk_offset_ptr,
current_first_idx, # (batch,)
current_last_idx, # (batch,)
block_idx_first_scheduled_token, # (batch,)
block_idx_last_scheduled_token, # (batch,)
initial_state_idx, # (batch,)
context_lens, # (batch,)
num_computed_tokens, # (batch,)
o_ptr, # (dim, seqlen) - actually pointing to x_ptr
# Matrix dimensions
dim: tl.constexpr,
@@ -94,9 +94,9 @@ def _causal_conv1d_fwd_kernel( # continuous batching
# In particular, if prefix caching is enabled, the program write additional cache states to "cache_indices_ptr"
# Get the length of the completed sequence so far and compute the offset.
current_first_index = tl.load(current_first_idx + idx_seq)
current_last_index = tl.load(current_last_idx + idx_seq)
sequence_completed_index = tl.load(context_lens + idx_seq)
current_first_index = tl.load(block_idx_first_scheduled_token + idx_seq)
current_last_index = tl.load(block_idx_last_scheduled_token + idx_seq)
sequence_completed_index = tl.load(num_computed_tokens + idx_seq)
# Compute the offset where the first stride_block_m-aligned first full block is
# Value in "token-space"
@@ -476,10 +476,10 @@ def causal_conv1d_fn(
has_initial_state: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
pad_slot_id: int = PAD_SLOT_ID,
current_first_idx: Optional[torch.Tensor] = None,
current_last_idx: Optional[torch.Tensor] = None,
block_idx_first_scheduled_token: Optional[torch.Tensor] = None,
block_idx_last_scheduled_token: Optional[torch.Tensor] = None,
initial_state_idx: Optional[torch.Tensor] = None,
context_lens: Optional[torch.Tensor] = None,
num_computed_tokens: Optional[torch.Tensor] = None,
block_size_to_align=0,
metadata=None,
validate_data=False,
@@ -523,13 +523,13 @@ def causal_conv1d_fn(
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
current_first_idx: (batch,), dtype int32
block_idx_first_scheduled_token: (batch,), dtype int32
The pointer into cache_indices, where the first cache block to be filled is located.
current_last_idx: (batch,), dtype int32
block_idx_last_scheduled_token: (batch,), dtype int32
The pointer into cache_indices, where the last cache block to be filled is located.
initial_state_idx: (batch,), dtype int32
The pointer into cache_indices, where the cache block containing the initial state is located.
context_lens: (batch,), dtype int32
num_computed_tokens: (batch,), dtype int32
The number of tokens already completed for each sequence
block_size_to_align: int
The block size to align the cached states to
@@ -708,10 +708,10 @@ def causal_conv1d_fn(
query_start_loc,
batch_ptr,
token_chunk_offset_ptr,
current_first_idx,
current_last_idx,
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
initial_state_idx,
context_lens,
num_computed_tokens,
out,
# Matrix dimensions
dim,
@@ -735,7 +735,7 @@ def causal_conv1d_fn(
HAS_BIAS=bias is not None,
KERNEL_WIDTH=width,
SILU_ACTIVATION=activation in ["silu", "swish"],
IS_APC_ENABLED=current_last_idx is not None,
IS_APC_ENABLED=block_idx_last_scheduled_token is not None,
USE_PAD_SLOT=pad_slot_id is not None,
NP2_STATELEN=np2_statelen,
# launch_cooperative_grid=True
@@ -756,7 +756,7 @@ def _causal_conv1d_update_kernel(
conv_state_indices_ptr,
num_accepted_tokens_ptr,
query_start_loc_ptr, # (batch + 1)
current_last_idx, # (batch,)
block_idx_last_scheduled_token, # (batch,)
initial_state_idx, # (batch,)
o_ptr, # (batch, dim, seqlen)
# Matrix dimensions
@@ -802,7 +802,7 @@ def _causal_conv1d_update_kernel(
if IS_APC_ENABLED:
# Get the state from the initial_state_idx
conv_state_init = tl.load(initial_state_idx + idx_seq)
current_last_index = tl.load(current_last_idx + idx_seq)
current_last_index = tl.load(block_idx_last_scheduled_token + idx_seq)
else:
conv_state_init = 0
current_last_index = 0
@@ -1078,7 +1078,7 @@ def causal_conv1d_update(
query_start_loc: Optional[torch.Tensor] = None,
max_query_len: int = -1,
pad_slot_id: int = PAD_SLOT_ID,
current_last_idx: Optional[torch.Tensor] = None,
block_idx_last_scheduled_token: Optional[torch.Tensor] = None,
initial_state_idx: Optional[torch.Tensor] = None,
validate_data=False,
):
@@ -1097,7 +1097,7 @@ def causal_conv1d_update(
If not None, the conv_state is a larger tensor along the batch dim,
and we are selecting the batch coords specified by conv_state_indices.
Useful for a continuous batching scenario.
current_last_idx: (batch,), dtype int32
block_idx_last_scheduled_token: (batch,), dtype int32
The pointer into conv_state_indices, where the last cache block to be filled is located.
initial_state_idx: (batch,), dtype int32
The pointer into conv_state_indices, where the cache block containing the initial state is located.
@@ -1201,7 +1201,7 @@ def causal_conv1d_update(
conv_state_indices,
num_accepted_tokens,
query_start_loc,
current_last_idx,
block_idx_last_scheduled_token,
initial_state_idx,
out,
# Matrix dimensions
@@ -1230,7 +1230,7 @@ def causal_conv1d_update(
KERNEL_WIDTH=width,
SILU_ACTIVATION=activation in ["silu", "swish"],
IS_VARLEN=query_start_loc is not None,
IS_APC_ENABLED=current_last_idx is not None,
IS_APC_ENABLED=block_idx_last_scheduled_token is not None,
IS_SPEC_DECODING=num_accepted_tokens is not None,
NP2_STATELEN=np2_statelen,
USE_PAD_SLOT=pad_slot_id is not None,