[V1] [Hybrid] Mamba2 Automatic Prefix Caching (#25752)

Signed-off-by: Stanislaw Wozniak <stw@zurich.ibm.com>
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Thomas Ortner <boh@zurich.ibm.com>
Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
Stan Wozniak
2025-10-04 06:34:22 +02:00
committed by GitHub
parent 9705fba7b7
commit ea507c3a93
18 changed files with 917 additions and 147 deletions

View File

@@ -489,6 +489,9 @@ class MambaMixer2(MambaBase, CustomOp):
# stay the same and reused for all mamba layers in the same iteration
attn_metadata: AttentionMetadata = forward_context.attn_metadata
assert self.cache_config is not None
mamba_block_size = self.cache_config.mamba_block_size
prefix_caching_enabled = self.cache_config.enable_prefix_caching
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
@@ -573,6 +576,25 @@ class MambaMixer2(MambaBase, CustomOp):
dim=0,
)
if prefix_caching_enabled:
# If prefix caching is enabled, retrieve the relevant variables
# for prefill and decode
last_state_idx_d, last_state_idx_p = torch.split(
attn_metadata.last_state_idx, [num_decodes, num_prefills],
dim=0)
current_last_idx_d, current_last_idx_p = torch.split(
attn_metadata.current_last_idx, [num_decodes, num_prefills],
dim=0)
# Prefill-only variables:
current_first_idx_p = attn_metadata.current_first_idx_p
context_lens_p = attn_metadata.context_lens_p
last_computed_offset_p = attn_metadata.last_computed_offset_p
else:
last_state_idx_d, last_state_idx_p = None, None
current_last_idx_d, current_last_idx_p = None, None
current_first_idx_p = None
context_lens_p = None
# Preallocate output tensor to avoid memcpy cost for merging prefill
# and decode outputs
preallocated_ssm_out = torch.empty(
@@ -592,8 +614,17 @@ class MambaMixer2(MambaBase, CustomOp):
# Process prefill requests
if has_prefill:
# 2. Convolution sequence transformation
# - "cache_indices" updates the conv_state cache in positions
# pointed to by "state_indices_tensor"
# - It will read the initial states for every sequence,
# that has "has_initial_states_p" == True,
# from "cache_indices", using "state_indices_tensor_p".
# - It updates the "conv_state" cache in positions pointed
# to by "state_indices_tensor_p".
# In particular, it will always write the state at the
# sequence end.
# In addition, "current_first_idx_p" and "current_last_idx_p"
# are provided (which are pointers into
# "state_indices_tensor_p"), it will write additional cache
# states aligned at "block_size_to_align".
x = hidden_states_B_C_p.transpose(
0, 1) # this is the form that causal-conv see
hidden_states_B_C_p = causal_conv1d_fn(
@@ -604,6 +635,11 @@ class MambaMixer2(MambaBase, CustomOp):
conv_states=conv_state,
has_initial_state=has_initial_states_p,
cache_indices=state_indices_tensor_p,
current_first_idx=current_first_idx_p,
current_last_idx=current_last_idx_p,
initial_state_idx=last_state_idx_p,
context_lens=context_lens_p,
block_size_to_align=mamba_block_size,
metadata=attn_metadata,
query_start_loc=query_start_loc_p).transpose(
0, 1)[:num_prefill_tokens]
@@ -614,9 +650,13 @@ class MambaMixer2(MambaBase, CustomOp):
# 3. State Space Model sequence transformation
initial_states = None
if (has_initial_states_p is not None and prep_initial_states):
kernel_ssm_indices = state_indices_tensor_p
if prefix_caching_enabled:
kernel_ssm_indices = state_indices_tensor_p.gather(
1, last_state_idx_p.unsqueeze(1)).squeeze(1)
initial_states = torch.where(
has_initial_states_p[:, None, None, None],
ssm_state[state_indices_tensor_p], 0)
ssm_state[kernel_ssm_indices], 0)
# NOTE: final output is an in-place update of out tensor
varlen_states = mamba_chunk_scan_combined_varlen(
@@ -638,18 +678,71 @@ class MambaMixer2(MambaBase, CustomOp):
cu_chunk_seqlens=cu_chunk_seqlen_p,
last_chunk_indices=last_chunk_indices_p,
initial_states=initial_states,
return_intermediate_states=prefix_caching_enabled,
dt_softplus=True,
dt_limit=(0.0, float("inf")),
out=preallocated_ssm_out_p.view(num_prefill_tokens, -1,
self.head_dim),
state_dtype=ssm_state.dtype)
# update ssm states
# - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
ssm_state[state_indices_tensor_p] = varlen_states
if prefix_caching_enabled:
# Save states for sequences with more than just the final state:
n_blocks_to_fill = current_last_idx_p - current_first_idx_p
for seq_idx in (n_blocks_to_fill > 0).nonzero().squeeze(1):
cache_blocks_to_fill = state_indices_tensor_p[
seq_idx, current_first_idx_p[seq_idx]:
current_first_idx_p[seq_idx] +
n_blocks_to_fill[seq_idx]]
# chunks = [0 1 2 3 4 5 6 ...]
# First aligned chunk would typically be:
# mamba_block_size = 1024, chunk_size = 256
# 1024 // 256 - 1 --> chunks[3]
# But when last chunk wasn't block aligned:
# - last_computed_offset_p[seq_idx] // chunk_size
# e.g. 1000 // 256 -> 3 completed --> store chunk[0]
# e.g. 513 // 256 -> 2 completed --> store chunk[1] (skip 1)
# e.g. 256 // 256 -> 1 completed --> store chunk[2] (skip 2)
# e.g. 10 // 256 -> 0 completed --> store chunk[3] (skip 3)
chunk_stride = mamba_block_size // chunk_size
first_aligned_chunk = \
torch.concat([torch.zeros(1, \
dtype=last_chunk_indices_p.dtype, \
device=last_chunk_indices_p.device), \
last_chunk_indices_p + 1])[seq_idx] \
+ chunk_stride - 1 \
- last_computed_offset_p[seq_idx] // chunk_size
from_where = varlen_states[
first_aligned_chunk:first_aligned_chunk +
n_blocks_to_fill[seq_idx] * chunk_stride:chunk_stride]
ssm_state[cache_blocks_to_fill] = from_where
#For all seqs, store the last state (Note: might be partial):
ssm_state[state_indices_tensor_p.gather(1,
current_last_idx_p.unsqueeze(1)).squeeze(1)] = \
varlen_states[last_chunk_indices_p]
else:
# update ssm states
# - varlen state is a (num_prefills, nheads, headdim, dstate)
# tensor
ssm_state[state_indices_tensor_p] = varlen_states
# Process decode requests
if has_decode:
if prefix_caching_enabled:
state_indices_tensor_d_input = \
state_indices_tensor_d.gather(1,
last_state_idx_d.unsqueeze(1)).squeeze(1)
state_indices_tensor_d_output = \
state_indices_tensor_d.gather(1,
current_last_idx_d.unsqueeze(1)).squeeze(1)
#Note:
# for decode always: current_first_idx_d == current_last_idx_d
# at block boundaries: current_first_idx_d > last_state_idx_d
else:
# Without caching, read and write in-place to the same blocks:
state_indices_tensor_d_input = state_indices_tensor_d
state_indices_tensor_d_output = state_indices_tensor_d
# 2. Convolution sequence transformation
hidden_states_B_C_d = causal_conv1d_update(
hidden_states_B_C_d,
@@ -657,7 +750,10 @@ class MambaMixer2(MambaBase, CustomOp):
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=state_indices_tensor_d)
conv_state_indices=state_indices_tensor_d,
current_last_idx=current_last_idx_d,
initial_state_idx=last_state_idx_d,
)
hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(
hidden_states_B_C_d)
@@ -689,7 +785,8 @@ class MambaMixer2(MambaBase, CustomOp):
z=None,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=state_indices_tensor_d,
state_batch_indices=state_indices_tensor_d_input,
dst_state_batch_indices=state_indices_tensor_d_output,
out=preallocated_ssm_out_d.view(num_decodes, -1,
self.head_dim),
)

View File

@@ -20,19 +20,23 @@ def _causal_conv1d_fwd_kernel( # continuous batching
w_ptr, # (dim, width)
bias_ptr,
initial_states_ptr, # conv_states_ptr
cache_indices_ptr, # conv_state_indices_ptr
cache_indices_ptr, # (batch, n_blocks + padding) The second dimension contains
# the block indices relevant for each sequence
# plus potential 0-padding at the beginning and at the end
has_initial_states_ptr,
query_start_loc_ptr,
batch_ptr,
token_chunk_offset_ptr,
current_first_idx, # (batch,)
current_last_idx, # (batch,)
initial_state_idx, # (batch,)
context_lens, # (batch,)
o_ptr, # (dim, seqlen) - actually pointing to x_ptr
# Matrix dimensions
batch: tl.int32, # actually padded_batch
dim: tl.constexpr,
seqlen: tl.int32, # cu_seqlen
num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines
# Strides
stride_x_seq: tl.constexpr, # stride to get to next sequence,
stride_x_dim: tl.constexpr, # stride to get to next feature-value,
stride_x_token: tl.
constexpr, # stride to get to next token (same feature-index, same sequence-index)
@@ -42,18 +46,16 @@ def _causal_conv1d_fwd_kernel( # continuous batching
stride_istate_dim: tl.constexpr,
stride_istate_token: tl.constexpr,
stride_cache_indices: tl.constexpr,
stride_o_seq: tl.constexpr,
stride_o_dim: tl.constexpr,
stride_o_token: tl.constexpr,
stride_block_m: tl.constexpr, # Stride block to align divided by BLOCK_M
# others
pad_slot_id: tl.constexpr,
# Meta-parameters
HAS_BIAS: tl.constexpr,
KERNEL_WIDTH: tl.constexpr,
SILU_ACTIVATION: tl.constexpr,
HAS_INITIAL_STATES: tl.constexpr,
HAS_CACHE: tl.constexpr,
IS_CONTINUOUS_BATCHING: tl.constexpr,
IS_APC_ENABLED: tl.constexpr,
USE_PAD_SLOT: tl.constexpr,
NP2_STATELEN: tl.constexpr,
BLOCK_M: tl.constexpr,
@@ -84,26 +86,57 @@ def _causal_conv1d_fwd_kernel( # continuous batching
# find the actual sequence length
seqlen = sequence_end_index - sequence_start_index
B_size: tl.constexpr = (stride_block_m * BLOCK_M)
if IS_APC_ENABLED:
# Handle the case if prefix caching is enabled.
# In particular, if prefix caching is enabled, the program write additional cache states to "cache_indices_ptr"
# Get the length of the completed sequence so far and compute the offset.
current_first_index = tl.load(current_first_idx + idx_seq)
current_last_index = tl.load(current_last_idx + idx_seq)
sequence_completed_index = tl.load(context_lens + idx_seq)
# Compute the offset where the first stride_block_m-aligned first full block is
# Value in "token-space"
sequence_completed_offset_token = sequence_completed_index % B_size
seq_completed_offset = B_size - sequence_completed_offset_token
seq_end_offset = (seqlen - seq_completed_offset) % B_size
last_full_block_token_index = sequence_end_index - seq_end_offset
# If the sequence without the sequence_offset_index is stride_cache_chunk-aligned, then the last full chunk is the second-to-last one
if seq_end_offset == 0:
last_full_block_token_index = last_full_block_token_index - B_size
# Get the number of blocks to be filled for the current sequence
# If n_block_to_fill = 0, then only the state at the sequence end is stored
n_block_to_fill = current_last_index - current_first_index
# Get the index of the init block
conv_state_init_index = tl.load(initial_state_idx + idx_seq)
else:
n_block_to_fill = 0
current_last_index = 0
conv_state_init_index = 0
current_first_index = 0
last_full_block_token_index = 0
token_offset = BLOCK_M * chunk_offset
segment_len = min(BLOCK_M, seqlen - token_offset)
# base of the sequence
x_base = x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim # [BLOCK_N,]
if IS_CONTINUOUS_BATCHING:
# cache_idx
conv_state_batch_coord = tl.load(conv_state_indices_ptr +
idx_seq * stride_cache_indices).to(
tl.int64)
else:
# cache_idx
conv_state_batch_coord = idx_seq
# cache_idx
conv_states_input_coord = tl.load(conv_state_indices_ptr +
idx_seq * stride_cache_indices +
conv_state_init_index).to(tl.int64)
if USE_PAD_SLOT: # noqa
if conv_state_batch_coord == pad_slot_id:
if conv_states_input_coord == pad_slot_id:
# not processing as this is not the actual sequence
return
conv_states_base = (conv_states_ptr +
(conv_state_batch_coord * stride_conv_state_seq) +
(conv_states_input_coord * stride_conv_state_seq) +
(idx_feats * stride_conv_state_dim)) # [BLOCK_N,]
w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,]
@@ -113,10 +146,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching
# 2. update conv_state with new data [only by the Triton program handles chunk_offset=0]
if chunk_offset == 0:
# read from conv_states
load_init_state = False
if HAS_INITIAL_STATES: # the new HAS_INITIAL_STATES
load_init_state = tl.load(has_initial_states_ptr + idx_seq).to(
tl.int1)
load_init_state = tl.load(has_initial_states_ptr + idx_seq).to(tl.int1)
if load_init_state:
# load from conv_states
prior_tokens = conv_states_base + (state_len -
@@ -175,15 +205,23 @@ def _causal_conv1d_fwd_kernel( # continuous batching
(idx_feats < dim)[None, :]
) # token-index # token-index # feature-index
loaded_x = tl.load(x_ptrs, mask_x, 0.0)
new_conv_state = tl.load(x_ptrs, mask_x, 0.0)
idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
conv_states_ptrs_target = conv_states_base[None, :] + (
idx_tokens_conv * stride_conv_state_tok)[:, None]
# Compute the offset where the last block should be written in the conv_states
conv_states_output_coord = tl.load(conv_state_indices_ptr +
idx_seq * stride_cache_indices +
current_last_index).to(tl.int64)
conv_states_ptrs_target = (
conv_states_ptr + (conv_states_output_coord *
stride_conv_state_seq) + # Offset from seq
(idx_feats * stride_conv_state_dim))[None, :] + ( # [BLOCK_N,]
idx_tokens_conv * stride_conv_state_tok)[:, None]
mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats
< dim)[None, :]
tl.debug_barrier() # NOTE: use this due to bug in Triton compiler
tl.store(conv_states_ptrs_target, new_conv_state, mask)
tl.store(conv_states_ptrs_target, loaded_x, mask)
else:
if load_init_state:
@@ -192,12 +230,12 @@ def _causal_conv1d_fwd_kernel( # continuous batching
conv_states_ptrs_source = (
conv_states_ptr +
(conv_state_batch_coord * stride_conv_state_seq) +
(conv_states_input_coord * stride_conv_state_seq) +
(idx_feats * stride_conv_state_dim)[None, :] +
((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:,
None]
) # [BLOCK_M, BLOCK_N]
mask = ((conv_state_batch_coord < num_cache_lines)
mask = ((conv_states_input_coord < num_cache_lines)
& ((idx_tokens_conv + seqlen) < state_len)[:, None]
& (idx_feats < dim)[None, :])
conv_state = tl.load(conv_states_ptrs_source, mask, other=0.0)
@@ -280,6 +318,45 @@ def _causal_conv1d_fwd_kernel( # continuous batching
conv_states_ptrs = prior_tokens - 3 * stride_x_token # [BLOCK_N]
col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
# Store intermediate states aligned with stride_block_m
# The additional states are cached starting from the last stride_block_m.
# For example:
# If n_block_to_fill = 0, then only the state at the sequence end is cached and the process below is not involved.
# If n_block_to_fill > 0, then the states at the sequence end and at the n_block_to_fill-last
# stride_block_m are cached.
# For example chunk_offset = n_block_to_fill stores the state at last_full_block
if (chunk_offset - 1) < n_block_to_fill:
# Store the states at the chunk boundaries from the start of the sequence
idx_tokens_last = (last_full_block_token_index -
(n_block_to_fill - chunk_offset) * B_size -
state_len) + tl.arange(
0, NP2_STATELEN) # [BLOCK_M]
x_ptrs = x_ptr + (idx_tokens_last * stride_x_token)[:, None] + (
idx_feats * stride_x_dim)[None, :] # [BLOCK_M,BLOCK_N,]
mask_x = (
(idx_tokens_last >= 0)[:, None] & (idx_feats < dim)[None, :]
) # token-index # token-index # feature-index
loaded_x = tl.load(x_ptrs, mask_x, 0.0)
idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
# cache_idx
conv_states_output_coord = tl.load(conv_state_indices_ptr +
idx_seq * stride_cache_indices +
current_first_index +
(chunk_offset - 1)).to(tl.int64)
conv_states_ptrs_target = (
conv_states_ptr + (conv_states_output_coord *
stride_conv_state_seq) + # Offset from seq
(idx_feats * stride_conv_state_dim))[None, :] + ( # [BLOCK_N,]
idx_tokens_conv * stride_conv_state_tok)[:, None]
mask = (idx_tokens_conv < state_len)[:, None] & \
(idx_feats < dim)[None, :]
tl.debug_barrier() # NOTE: use this due to bug in Triton compiler
tl.store(conv_states_ptrs_target, loaded_x, mask)
if HAS_BIAS:
bias = bias_ptr + idx_feats
mask_bias = idx_feats < dim
@@ -368,6 +445,11 @@ def causal_conv1d_fn(
has_initial_state: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
pad_slot_id: int = PAD_SLOT_ID,
current_first_idx: Optional[torch.Tensor] = None,
current_last_idx: Optional[torch.Tensor] = None,
initial_state_idx: Optional[torch.Tensor] = None,
context_lens: Optional[torch.Tensor] = None,
block_size_to_align=0,
metadata=None,
validate_data=False,
):
@@ -378,7 +460,7 @@ def causal_conv1d_fn(
sequences are concatenated from left to right for varlen
weight: (dim, width)
conv_states: (...,dim,width - 1) itype
updated inplace if provided
updated inplace if cache_indices are not provided
[it use `cache_indices` to get the index to the cache of conv_state for that sequence
conv_state[cache_indices[i]] for seq-i - to be used as initial_state when has_initial_state[i] = True
@@ -410,7 +492,16 @@ def causal_conv1d_fn(
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
current_first_idx: (batch,), dtype int32
The pointer into cache_indices, where the first cache block to be filled is located.
current_last_idx: (batch,), dtype int32
The pointer into cache_indices, where the last cache block to be filled is located.
initial_state_idx: (batch,), dtype int32
The pointer into cache_indices, where the cache block containing the initial state is located.
context_lens: (batch,), dtype int32
The number of tokens already completed for each sequence
block_size_to_align: int
The block size to align the cached states to
out: same shape as `x`
"""
if isinstance(activation, bool) and activation:
@@ -451,7 +542,6 @@ def causal_conv1d_fn(
np2_statelen = triton.next_power_of_2(state_len)
padded_batch = query_start_loc.size(0) - 1
stride_x_seq = 0
stride_x_dim = x.stride(0)
stride_x_token = x.stride(1)
stride_w_dim = weight.stride(0)
@@ -460,6 +550,7 @@ def causal_conv1d_fn(
stride_istate_dim = 0
stride_istate_token = 0
num_cache_lines = 0
BLOCK_M = 8
if conv_states is not None:
# extensions to support vLLM:
# 1. conv_states is used to replaced initial_states
@@ -475,11 +566,9 @@ def causal_conv1d_fn(
stride_istate_token = conv_states.stride(2)
assert stride_istate_dim == 1
if out.dim() == 2:
stride_o_seq = 0
stride_o_dim = out.stride(0)
stride_o_token = out.stride(1)
else:
stride_o_seq = out.stride(0)
stride_o_dim = out.stride(1)
stride_o_token = out.stride(2)
stride_cache_indices = cache_indices.stride(
@@ -502,6 +591,12 @@ def causal_conv1d_fn(
assert weight.stride(1) == 1
assert (dim, width) == weight.shape
assert is_channel_last, "Need to run in channel-last layout"
if block_size_to_align is not None and block_size_to_align > 0:
assert (
block_size_to_align % BLOCK_M
) == 0, "The mamba block size needs to be divisible by the BLOCK_M"
else:
block_size_to_align = BLOCK_M
if metadata is None:
@@ -584,14 +679,16 @@ def causal_conv1d_fn(
query_start_loc,
batch_ptr,
token_chunk_offset_ptr,
current_first_idx,
current_last_idx,
initial_state_idx,
context_lens,
out,
# Matrix dimensions
padded_batch,
dim,
cu_seqlen,
num_cache_lines,
# stride
stride_x_seq,
stride_x_dim,
stride_x_token,
stride_w_dim,
@@ -600,22 +697,20 @@ def causal_conv1d_fn(
stride_istate_dim,
stride_istate_token,
stride_cache_indices,
stride_o_seq,
stride_o_dim,
stride_o_token,
block_size_to_align // BLOCK_M,
# others
pad_slot_id,
# META
HAS_BIAS=bias is not None,
KERNEL_WIDTH=width,
SILU_ACTIVATION=activation in ["silu", "swish"],
HAS_INITIAL_STATES=has_initial_state is not None,
HAS_CACHE=conv_states is not None,
IS_CONTINUOUS_BATCHING=cache_indices is not None,
IS_APC_ENABLED=current_last_idx is not None,
USE_PAD_SLOT=pad_slot_id is not None,
NP2_STATELEN=np2_statelen,
#launch_cooperative_grid=True
BLOCK_M=8,
BLOCK_M=BLOCK_M,
BLOCK_N=256,
num_stages=2,
)
@@ -629,10 +724,11 @@ def _causal_conv1d_update_kernel(
w_ptr, # (dim, width)
bias_ptr,
conv_state_ptr,
cache_seqlens_ptr, # circular buffer
conv_state_indices_ptr,
num_accepted_tokens_ptr,
query_start_loc_ptr, # (batch + 1)
current_last_idx, # (batch,)
initial_state_idx, #(batch,)
o_ptr, # (batch, dim, seqlen)
# Matrix dimensions
batch: int,
@@ -660,7 +756,7 @@ def _causal_conv1d_update_kernel(
KERNEL_WIDTH: tl.constexpr,
SILU_ACTIVATION: tl.constexpr,
IS_VARLEN: tl.constexpr,
IS_CONTINUOUS_BATCHING: tl.constexpr,
IS_APC_ENABLED: tl.constexpr,
IS_SPEC_DECODING: tl.constexpr,
NP2_STATELEN: tl.constexpr,
USE_PAD_SLOT: tl.constexpr,
@@ -674,15 +770,21 @@ def _causal_conv1d_update_kernel(
# [BLOCK_N,] elements along the feature-dimension (channel)
idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
if IS_CONTINUOUS_BATCHING:
# mask = idx_seq < batch
conv_state_batch_coord = tl.load(conv_state_indices_ptr +
idx_seq * stride_state_indices).to(
tl.int64)
if IS_APC_ENABLED:
# Get the state from the initial_state_idx
conv_state_init = tl.load(initial_state_idx + idx_seq)
current_last_index = tl.load(current_last_idx + idx_seq)
else:
conv_state_batch_coord = idx_seq
conv_state_init = 0
current_last_index = 0
# cache_idx
conv_states_input_coord = tl.load(conv_state_indices_ptr +
idx_seq * stride_state_indices +
conv_state_init).to(tl.int64)
if USE_PAD_SLOT: # noqa
if conv_state_batch_coord == pad_slot_id:
if conv_states_input_coord == pad_slot_id:
# not processing as this is not the actual sequence
return
@@ -726,7 +828,7 @@ def _causal_conv1d_update_kernel(
# STEP 1: READ init_state data
conv_states_base = (conv_state_ptr +
(conv_state_batch_coord * stride_conv_state_seq) +
(conv_states_input_coord * stride_conv_state_seq) +
(idx_feats * stride_conv_state_dim))
mask_w = idx_feats < dim
@@ -754,12 +856,12 @@ def _causal_conv1d_update_kernel(
# window manner, at each forward pass, the tokens are shift by 1, so we
# load since idx_tokens + 1.
conv_state_ptrs_source = (
conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) +
conv_state_ptr + (conv_states_input_coord * stride_conv_state_seq) +
conv_state_token_offset * stride_conv_state_tok +
(idx_feats * stride_conv_state_dim)[None, :] +
((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) *
stride_conv_state_tok)[:, None]) # [BLOCK_M, BLOCK_N]
mask = ((conv_state_batch_coord < num_cache_lines)
mask = ((conv_states_input_coord < num_cache_lines)
& ((idx_tokens + seqlen) < state_len)[:, None]
& (idx_feats < dim)[None, :])
conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0)
@@ -778,11 +880,16 @@ def _causal_conv1d_update_kernel(
new_conv_state = tl.where(mask, conv_state, loaded_x)
conv_state_base = (conv_state_ptr +
(conv_state_batch_coord * stride_conv_state_seq) +
(idx_feats * stride_conv_state_dim)) # [BLOCK_N,]
conv_state_ptrs_target = conv_state_base + (
idx_tokens * stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N]
# Get the state from the initial_state_idx
# cache_idx
conv_states_offset = tl.load(conv_state_indices_ptr +
idx_seq * stride_state_indices +
current_last_index).to(tl.int64)
conv_state_ptrs_target = (
conv_state_ptr +
(conv_states_offset * stride_conv_state_seq) + # Offset from seq
(idx_feats * stride_conv_state_dim))[None, :] + ( # [BLOCK_N,]
idx_tokens * stride_conv_state_tok)[:, None]
mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :]
tl.store(conv_state_ptrs_target, new_conv_state, mask)
@@ -923,12 +1030,13 @@ def causal_conv1d_update(
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
activation: Union[bool, str, None] = None,
cache_seqlens: Optional[torch.Tensor] = None,
conv_state_indices: Optional[torch.Tensor] = None,
num_accepted_tokens: Optional[torch.Tensor] = None,
query_start_loc: Optional[torch.Tensor] = None,
max_query_len: int = -1,
pad_slot_id: int = PAD_SLOT_ID,
current_last_idx: Optional[torch.Tensor] = None,
initial_state_idx: Optional[torch.Tensor] = None,
validate_data=False,
):
"""
@@ -942,15 +1050,14 @@ def causal_conv1d_update(
conv_state: (..., dim, state_len), where state_len >= width - 1
weight: (dim, width)
bias: (dim,)
cache_seqlens: (batch,), dtype int32.
If not None, the conv_state is treated as a circular buffer.
The conv_state will be updated by copying x to the conv_state
starting at the index
@cache_seqlens % state_len.
conv_state_indices: (batch,), dtype int32
If not None, the conv_state is a larger tensor along the batch dim,
and we are selecting the batch coords specified by conv_state_indices.
Useful for a continuous batching scenario.
current_last_idx: (batch,), dtype int32
The pointer into conv_state_indices, where the last cache block to be filled is located.
initial_state_idx: (batch,), dtype int32
The pointer into conv_state_indices, where the cache block containing the initial state is located.
num_accepted_tokens: (batch,), dtype int32
If not None, it indicates the number of accepted tokens for each
sequence in the batch.
@@ -963,15 +1070,14 @@ def causal_conv1d_update(
If query_start_loc is not None, this indicates the maximum query
length in the batch.
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padded
if conv_state_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
for example: conv_state_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x`
"""
if validate_data:
assert cache_seqlens is None # not implemented yet - ok for vLLM
assert pad_slot_id is not None
assert x.stride(1) == 1
if isinstance(activation, bool):
@@ -1011,7 +1117,6 @@ def causal_conv1d_update(
assert num_cache_lines >= batch
assert weight.stride(1) == 1 # Need this
assert cache_seqlens is None # not needed for vLLM - circular buffer
# adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o'
out = x
@@ -1050,10 +1155,11 @@ def causal_conv1d_update(
weight,
bias,
conv_state,
cache_seqlens,
conv_state_indices,
num_accepted_tokens,
query_start_loc,
current_last_idx,
initial_state_idx,
out,
# Matrix dimensions
batch,
@@ -1081,7 +1187,7 @@ def causal_conv1d_update(
KERNEL_WIDTH=width,
SILU_ACTIVATION=activation in ["silu", "swish"],
IS_VARLEN=query_start_loc is not None,
IS_CONTINUOUS_BATCHING=conv_state_indices is not None,
IS_APC_ENABLED=current_last_idx is not None,
IS_SPEC_DECODING=num_accepted_tokens is not None,
NP2_STATELEN=np2_statelen,
USE_PAD_SLOT=pad_slot_id is not None,

View File

@@ -52,6 +52,7 @@ def _selective_scan_update_kernel(
z_ptr,
out_ptr,
state_batch_indices_ptr,
dst_state_batch_indices_ptr,
pad_slot_id,
# Matrix dimensions
batch,
@@ -107,11 +108,17 @@ def _selective_scan_update_kernel(
# is taken from the state_batch_indices_ptr Otherwise, the state coordinate
# is the same as the batch id.
if HAS_STATE_BATCH_INDICES:
dst_state_batch_indices_ptr += pid_b
dst_state_batch_idx = tl.load(dst_state_batch_indices_ptr).to(tl.int64)
dst_state_ptr = state_ptr + (dst_state_batch_idx * stride_state_batch +
pid_h * stride_state_head)
state_batch_indices_ptr += pid_b
state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64)
state_ptr += (state_batch_idx * stride_state_batch +
pid_h * stride_state_head)
else:
dst_state_ptr = state_ptr + pid_b * stride_state_batch + \
pid_h * stride_state_head
state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
@@ -131,6 +138,8 @@ def _selective_scan_update_kernel(
offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim +
offs_n[None, :] * stride_state_dstate)
dst_state_ptrs = dst_state_ptr + (offs_m[:, None] * stride_state_dim +
offs_n[None, :] * stride_state_dstate)
x_ptrs = x_ptr + offs_m * stride_x_dim
dt_ptrs = dt_ptr + offs_m * stride_dt_dim
if HAS_DT_BIAS:
@@ -185,7 +194,7 @@ def _selective_scan_update_kernel(
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
if HAS_STATE_BATCH_INDICES:
mask &= (state_batch_idx != pad_slot_id)
tl.store(state_ptrs, state, mask=mask)
tl.store(dst_state_ptrs, state, mask=mask)
out = tl.sum(state * C[None, :], axis=1)
if HAS_D:
out += x * D
@@ -205,6 +214,7 @@ def selective_state_update(state,
dt_bias=None,
dt_softplus=False,
state_batch_indices=None,
dst_state_batch_indices=None,
pad_slot_id=PAD_SLOT_ID,
out=None):
"""
@@ -266,6 +276,11 @@ def selective_state_update(state,
assert dt_bias.shape == (nheads, dim)
if state_batch_indices is not None:
assert state_batch_indices.shape == (batch, )
if dst_state_batch_indices is not None:
assert dst_state_batch_indices.shape == (batch, )
else:
# revert to the default behavior of in-place state updates
dst_state_batch_indices = state_batch_indices
assert out.shape == x.shape
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
@@ -292,6 +307,7 @@ def selective_state_update(state,
z,
out,
state_batch_indices,
dst_state_batch_indices,
pad_slot_id,
batch,
nheads,

View File

@@ -35,6 +35,7 @@ def _mamba_chunk_scan_combined_fwd(x,
z=None,
dt_bias=None,
initial_states=None,
return_intermediate_states=False,
seq_idx=None,
cu_seqlens=None,
cu_chunk_seqlens=None,
@@ -151,28 +152,32 @@ def _mamba_chunk_scan_combined_fwd(x,
initial_states=initial_states,
)
return states[last_chunk_indices]
if return_intermediate_states:
return states
else:
return states[last_chunk_indices]
def mamba_chunk_scan_combined_varlen(
x,
dt,
A,
B,
C,
chunk_size,
cu_seqlens,
cu_chunk_seqlens,
last_chunk_indices,
seq_idx,
out,
D=None,
z=None,
dt_bias=None,
initial_states=None,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
state_dtype=None,
x,
dt,
A,
B,
C,
chunk_size,
cu_seqlens,
cu_chunk_seqlens,
last_chunk_indices,
seq_idx,
out,
D=None,
z=None,
dt_bias=None,
initial_states=None,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
return_intermediate_states=False,
state_dtype=None,
):
"""
Argument:
@@ -213,6 +218,7 @@ def mamba_chunk_scan_combined_varlen(
z=z,
dt_bias=dt_bias,
initial_states=initial_states,
return_intermediate_states=return_intermediate_states,
seq_idx=seq_idx,
cu_seqlens=cu_seqlens,
cu_chunk_seqlens=cu_chunk_seqlens,