[V1] [Hybrid] Mamba2 Automatic Prefix Caching (#25752)
Signed-off-by: Stanislaw Wozniak <stw@zurich.ibm.com> Signed-off-by: Thomas Ortner <boh@zurich.ibm.com> Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com> Co-authored-by: Thomas Ortner <boh@zurich.ibm.com> Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
@@ -489,6 +489,9 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
# stay the same and reused for all mamba layers in the same iteration
|
||||
attn_metadata: AttentionMetadata = forward_context.attn_metadata
|
||||
|
||||
assert self.cache_config is not None
|
||||
mamba_block_size = self.cache_config.mamba_block_size
|
||||
prefix_caching_enabled = self.cache_config.enable_prefix_caching
|
||||
if attn_metadata is not None:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
@@ -573,6 +576,25 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
dim=0,
|
||||
)
|
||||
|
||||
if prefix_caching_enabled:
|
||||
# If prefix caching is enabled, retrieve the relevant variables
|
||||
# for prefill and decode
|
||||
last_state_idx_d, last_state_idx_p = torch.split(
|
||||
attn_metadata.last_state_idx, [num_decodes, num_prefills],
|
||||
dim=0)
|
||||
current_last_idx_d, current_last_idx_p = torch.split(
|
||||
attn_metadata.current_last_idx, [num_decodes, num_prefills],
|
||||
dim=0)
|
||||
# Prefill-only variables:
|
||||
current_first_idx_p = attn_metadata.current_first_idx_p
|
||||
context_lens_p = attn_metadata.context_lens_p
|
||||
last_computed_offset_p = attn_metadata.last_computed_offset_p
|
||||
else:
|
||||
last_state_idx_d, last_state_idx_p = None, None
|
||||
current_last_idx_d, current_last_idx_p = None, None
|
||||
current_first_idx_p = None
|
||||
context_lens_p = None
|
||||
|
||||
# Preallocate output tensor to avoid memcpy cost for merging prefill
|
||||
# and decode outputs
|
||||
preallocated_ssm_out = torch.empty(
|
||||
@@ -592,8 +614,17 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
# Process prefill requests
|
||||
if has_prefill:
|
||||
# 2. Convolution sequence transformation
|
||||
# - "cache_indices" updates the conv_state cache in positions
|
||||
# pointed to by "state_indices_tensor"
|
||||
# - It will read the initial states for every sequence,
|
||||
# that has "has_initial_states_p" == True,
|
||||
# from "cache_indices", using "state_indices_tensor_p".
|
||||
# - It updates the "conv_state" cache in positions pointed
|
||||
# to by "state_indices_tensor_p".
|
||||
# In particular, it will always write the state at the
|
||||
# sequence end.
|
||||
# In addition, "current_first_idx_p" and "current_last_idx_p"
|
||||
# are provided (which are pointers into
|
||||
# "state_indices_tensor_p"), it will write additional cache
|
||||
# states aligned at "block_size_to_align".
|
||||
x = hidden_states_B_C_p.transpose(
|
||||
0, 1) # this is the form that causal-conv see
|
||||
hidden_states_B_C_p = causal_conv1d_fn(
|
||||
@@ -604,6 +635,11 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
conv_states=conv_state,
|
||||
has_initial_state=has_initial_states_p,
|
||||
cache_indices=state_indices_tensor_p,
|
||||
current_first_idx=current_first_idx_p,
|
||||
current_last_idx=current_last_idx_p,
|
||||
initial_state_idx=last_state_idx_p,
|
||||
context_lens=context_lens_p,
|
||||
block_size_to_align=mamba_block_size,
|
||||
metadata=attn_metadata,
|
||||
query_start_loc=query_start_loc_p).transpose(
|
||||
0, 1)[:num_prefill_tokens]
|
||||
@@ -614,9 +650,13 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
# 3. State Space Model sequence transformation
|
||||
initial_states = None
|
||||
if (has_initial_states_p is not None and prep_initial_states):
|
||||
kernel_ssm_indices = state_indices_tensor_p
|
||||
if prefix_caching_enabled:
|
||||
kernel_ssm_indices = state_indices_tensor_p.gather(
|
||||
1, last_state_idx_p.unsqueeze(1)).squeeze(1)
|
||||
initial_states = torch.where(
|
||||
has_initial_states_p[:, None, None, None],
|
||||
ssm_state[state_indices_tensor_p], 0)
|
||||
ssm_state[kernel_ssm_indices], 0)
|
||||
|
||||
# NOTE: final output is an in-place update of out tensor
|
||||
varlen_states = mamba_chunk_scan_combined_varlen(
|
||||
@@ -638,18 +678,71 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
cu_chunk_seqlens=cu_chunk_seqlen_p,
|
||||
last_chunk_indices=last_chunk_indices_p,
|
||||
initial_states=initial_states,
|
||||
return_intermediate_states=prefix_caching_enabled,
|
||||
dt_softplus=True,
|
||||
dt_limit=(0.0, float("inf")),
|
||||
out=preallocated_ssm_out_p.view(num_prefill_tokens, -1,
|
||||
self.head_dim),
|
||||
state_dtype=ssm_state.dtype)
|
||||
|
||||
# update ssm states
|
||||
# - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
|
||||
ssm_state[state_indices_tensor_p] = varlen_states
|
||||
if prefix_caching_enabled:
|
||||
# Save states for sequences with more than just the final state:
|
||||
n_blocks_to_fill = current_last_idx_p - current_first_idx_p
|
||||
for seq_idx in (n_blocks_to_fill > 0).nonzero().squeeze(1):
|
||||
cache_blocks_to_fill = state_indices_tensor_p[
|
||||
seq_idx, current_first_idx_p[seq_idx]:
|
||||
current_first_idx_p[seq_idx] +
|
||||
n_blocks_to_fill[seq_idx]]
|
||||
# chunks = [0 1 2 3 4 5 6 ...]
|
||||
# First aligned chunk would typically be:
|
||||
# mamba_block_size = 1024, chunk_size = 256
|
||||
# 1024 // 256 - 1 --> chunks[3]
|
||||
# But when last chunk wasn't block aligned:
|
||||
# - last_computed_offset_p[seq_idx] // chunk_size
|
||||
# e.g. 1000 // 256 -> 3 completed --> store chunk[0]
|
||||
# e.g. 513 // 256 -> 2 completed --> store chunk[1] (skip 1)
|
||||
# e.g. 256 // 256 -> 1 completed --> store chunk[2] (skip 2)
|
||||
# e.g. 10 // 256 -> 0 completed --> store chunk[3] (skip 3)
|
||||
chunk_stride = mamba_block_size // chunk_size
|
||||
first_aligned_chunk = \
|
||||
torch.concat([torch.zeros(1, \
|
||||
dtype=last_chunk_indices_p.dtype, \
|
||||
device=last_chunk_indices_p.device), \
|
||||
last_chunk_indices_p + 1])[seq_idx] \
|
||||
+ chunk_stride - 1 \
|
||||
- last_computed_offset_p[seq_idx] // chunk_size
|
||||
from_where = varlen_states[
|
||||
first_aligned_chunk:first_aligned_chunk +
|
||||
n_blocks_to_fill[seq_idx] * chunk_stride:chunk_stride]
|
||||
ssm_state[cache_blocks_to_fill] = from_where
|
||||
|
||||
#For all seqs, store the last state (Note: might be partial):
|
||||
ssm_state[state_indices_tensor_p.gather(1,
|
||||
current_last_idx_p.unsqueeze(1)).squeeze(1)] = \
|
||||
varlen_states[last_chunk_indices_p]
|
||||
else:
|
||||
# update ssm states
|
||||
# - varlen state is a (num_prefills, nheads, headdim, dstate)
|
||||
# tensor
|
||||
ssm_state[state_indices_tensor_p] = varlen_states
|
||||
|
||||
# Process decode requests
|
||||
if has_decode:
|
||||
if prefix_caching_enabled:
|
||||
state_indices_tensor_d_input = \
|
||||
state_indices_tensor_d.gather(1,
|
||||
last_state_idx_d.unsqueeze(1)).squeeze(1)
|
||||
state_indices_tensor_d_output = \
|
||||
state_indices_tensor_d.gather(1,
|
||||
current_last_idx_d.unsqueeze(1)).squeeze(1)
|
||||
#Note:
|
||||
# for decode always: current_first_idx_d == current_last_idx_d
|
||||
# at block boundaries: current_first_idx_d > last_state_idx_d
|
||||
else:
|
||||
# Without caching, read and write in-place to the same blocks:
|
||||
state_indices_tensor_d_input = state_indices_tensor_d
|
||||
state_indices_tensor_d_output = state_indices_tensor_d
|
||||
|
||||
# 2. Convolution sequence transformation
|
||||
hidden_states_B_C_d = causal_conv1d_update(
|
||||
hidden_states_B_C_d,
|
||||
@@ -657,7 +750,10 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
self.activation,
|
||||
conv_state_indices=state_indices_tensor_d)
|
||||
conv_state_indices=state_indices_tensor_d,
|
||||
current_last_idx=current_last_idx_d,
|
||||
initial_state_idx=last_state_idx_d,
|
||||
)
|
||||
|
||||
hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(
|
||||
hidden_states_B_C_d)
|
||||
@@ -689,7 +785,8 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
z=None,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True,
|
||||
state_batch_indices=state_indices_tensor_d,
|
||||
state_batch_indices=state_indices_tensor_d_input,
|
||||
dst_state_batch_indices=state_indices_tensor_d_output,
|
||||
out=preallocated_ssm_out_d.view(num_decodes, -1,
|
||||
self.head_dim),
|
||||
)
|
||||
|
||||
@@ -20,19 +20,23 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
w_ptr, # (dim, width)
|
||||
bias_ptr,
|
||||
initial_states_ptr, # conv_states_ptr
|
||||
cache_indices_ptr, # conv_state_indices_ptr
|
||||
cache_indices_ptr, # (batch, n_blocks + padding) The second dimension contains
|
||||
# the block indices relevant for each sequence
|
||||
# plus potential 0-padding at the beginning and at the end
|
||||
has_initial_states_ptr,
|
||||
query_start_loc_ptr,
|
||||
batch_ptr,
|
||||
token_chunk_offset_ptr,
|
||||
current_first_idx, # (batch,)
|
||||
current_last_idx, # (batch,)
|
||||
initial_state_idx, # (batch,)
|
||||
context_lens, # (batch,)
|
||||
o_ptr, # (dim, seqlen) - actually pointing to x_ptr
|
||||
# Matrix dimensions
|
||||
batch: tl.int32, # actually padded_batch
|
||||
dim: tl.constexpr,
|
||||
seqlen: tl.int32, # cu_seqlen
|
||||
num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines
|
||||
# Strides
|
||||
stride_x_seq: tl.constexpr, # stride to get to next sequence,
|
||||
stride_x_dim: tl.constexpr, # stride to get to next feature-value,
|
||||
stride_x_token: tl.
|
||||
constexpr, # stride to get to next token (same feature-index, same sequence-index)
|
||||
@@ -42,18 +46,16 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
stride_istate_dim: tl.constexpr,
|
||||
stride_istate_token: tl.constexpr,
|
||||
stride_cache_indices: tl.constexpr,
|
||||
stride_o_seq: tl.constexpr,
|
||||
stride_o_dim: tl.constexpr,
|
||||
stride_o_token: tl.constexpr,
|
||||
stride_block_m: tl.constexpr, # Stride block to align divided by BLOCK_M
|
||||
# others
|
||||
pad_slot_id: tl.constexpr,
|
||||
# Meta-parameters
|
||||
HAS_BIAS: tl.constexpr,
|
||||
KERNEL_WIDTH: tl.constexpr,
|
||||
SILU_ACTIVATION: tl.constexpr,
|
||||
HAS_INITIAL_STATES: tl.constexpr,
|
||||
HAS_CACHE: tl.constexpr,
|
||||
IS_CONTINUOUS_BATCHING: tl.constexpr,
|
||||
IS_APC_ENABLED: tl.constexpr,
|
||||
USE_PAD_SLOT: tl.constexpr,
|
||||
NP2_STATELEN: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
@@ -84,26 +86,57 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
# find the actual sequence length
|
||||
seqlen = sequence_end_index - sequence_start_index
|
||||
|
||||
B_size: tl.constexpr = (stride_block_m * BLOCK_M)
|
||||
|
||||
if IS_APC_ENABLED:
|
||||
# Handle the case if prefix caching is enabled.
|
||||
# In particular, if prefix caching is enabled, the program write additional cache states to "cache_indices_ptr"
|
||||
|
||||
# Get the length of the completed sequence so far and compute the offset.
|
||||
current_first_index = tl.load(current_first_idx + idx_seq)
|
||||
current_last_index = tl.load(current_last_idx + idx_seq)
|
||||
sequence_completed_index = tl.load(context_lens + idx_seq)
|
||||
|
||||
# Compute the offset where the first stride_block_m-aligned first full block is
|
||||
# Value in "token-space"
|
||||
sequence_completed_offset_token = sequence_completed_index % B_size
|
||||
seq_completed_offset = B_size - sequence_completed_offset_token
|
||||
seq_end_offset = (seqlen - seq_completed_offset) % B_size
|
||||
last_full_block_token_index = sequence_end_index - seq_end_offset
|
||||
# If the sequence without the sequence_offset_index is stride_cache_chunk-aligned, then the last full chunk is the second-to-last one
|
||||
if seq_end_offset == 0:
|
||||
last_full_block_token_index = last_full_block_token_index - B_size
|
||||
|
||||
# Get the number of blocks to be filled for the current sequence
|
||||
# If n_block_to_fill = 0, then only the state at the sequence end is stored
|
||||
n_block_to_fill = current_last_index - current_first_index
|
||||
|
||||
# Get the index of the init block
|
||||
conv_state_init_index = tl.load(initial_state_idx + idx_seq)
|
||||
else:
|
||||
n_block_to_fill = 0
|
||||
current_last_index = 0
|
||||
conv_state_init_index = 0
|
||||
current_first_index = 0
|
||||
last_full_block_token_index = 0
|
||||
|
||||
token_offset = BLOCK_M * chunk_offset
|
||||
segment_len = min(BLOCK_M, seqlen - token_offset)
|
||||
|
||||
# base of the sequence
|
||||
x_base = x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim # [BLOCK_N,]
|
||||
|
||||
if IS_CONTINUOUS_BATCHING:
|
||||
# cache_idx
|
||||
conv_state_batch_coord = tl.load(conv_state_indices_ptr +
|
||||
idx_seq * stride_cache_indices).to(
|
||||
tl.int64)
|
||||
else:
|
||||
# cache_idx
|
||||
conv_state_batch_coord = idx_seq
|
||||
# cache_idx
|
||||
conv_states_input_coord = tl.load(conv_state_indices_ptr +
|
||||
idx_seq * stride_cache_indices +
|
||||
conv_state_init_index).to(tl.int64)
|
||||
|
||||
if USE_PAD_SLOT: # noqa
|
||||
if conv_state_batch_coord == pad_slot_id:
|
||||
if conv_states_input_coord == pad_slot_id:
|
||||
# not processing as this is not the actual sequence
|
||||
return
|
||||
conv_states_base = (conv_states_ptr +
|
||||
(conv_state_batch_coord * stride_conv_state_seq) +
|
||||
(conv_states_input_coord * stride_conv_state_seq) +
|
||||
(idx_feats * stride_conv_state_dim)) # [BLOCK_N,]
|
||||
|
||||
w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,]
|
||||
@@ -113,10 +146,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
# 2. update conv_state with new data [only by the Triton program handles chunk_offset=0]
|
||||
if chunk_offset == 0:
|
||||
# read from conv_states
|
||||
load_init_state = False
|
||||
if HAS_INITIAL_STATES: # the new HAS_INITIAL_STATES
|
||||
load_init_state = tl.load(has_initial_states_ptr + idx_seq).to(
|
||||
tl.int1)
|
||||
load_init_state = tl.load(has_initial_states_ptr + idx_seq).to(tl.int1)
|
||||
if load_init_state:
|
||||
# load from conv_states
|
||||
prior_tokens = conv_states_base + (state_len -
|
||||
@@ -175,15 +205,23 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
(idx_feats < dim)[None, :]
|
||||
) # token-index # token-index # feature-index
|
||||
loaded_x = tl.load(x_ptrs, mask_x, 0.0)
|
||||
new_conv_state = tl.load(x_ptrs, mask_x, 0.0)
|
||||
idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
|
||||
conv_states_ptrs_target = conv_states_base[None, :] + (
|
||||
idx_tokens_conv * stride_conv_state_tok)[:, None]
|
||||
|
||||
# Compute the offset where the last block should be written in the conv_states
|
||||
conv_states_output_coord = tl.load(conv_state_indices_ptr +
|
||||
idx_seq * stride_cache_indices +
|
||||
current_last_index).to(tl.int64)
|
||||
|
||||
conv_states_ptrs_target = (
|
||||
conv_states_ptr + (conv_states_output_coord *
|
||||
stride_conv_state_seq) + # Offset from seq
|
||||
(idx_feats * stride_conv_state_dim))[None, :] + ( # [BLOCK_N,]
|
||||
idx_tokens_conv * stride_conv_state_tok)[:, None]
|
||||
|
||||
mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats
|
||||
< dim)[None, :]
|
||||
tl.debug_barrier() # NOTE: use this due to bug in Triton compiler
|
||||
tl.store(conv_states_ptrs_target, new_conv_state, mask)
|
||||
tl.store(conv_states_ptrs_target, loaded_x, mask)
|
||||
|
||||
else:
|
||||
if load_init_state:
|
||||
@@ -192,12 +230,12 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
|
||||
conv_states_ptrs_source = (
|
||||
conv_states_ptr +
|
||||
(conv_state_batch_coord * stride_conv_state_seq) +
|
||||
(conv_states_input_coord * stride_conv_state_seq) +
|
||||
(idx_feats * stride_conv_state_dim)[None, :] +
|
||||
((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:,
|
||||
None]
|
||||
) # [BLOCK_M, BLOCK_N]
|
||||
mask = ((conv_state_batch_coord < num_cache_lines)
|
||||
mask = ((conv_states_input_coord < num_cache_lines)
|
||||
& ((idx_tokens_conv + seqlen) < state_len)[:, None]
|
||||
& (idx_feats < dim)[None, :])
|
||||
conv_state = tl.load(conv_states_ptrs_source, mask, other=0.0)
|
||||
@@ -280,6 +318,45 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
conv_states_ptrs = prior_tokens - 3 * stride_x_token # [BLOCK_N]
|
||||
col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
|
||||
|
||||
# Store intermediate states aligned with stride_block_m
|
||||
# The additional states are cached starting from the last stride_block_m.
|
||||
# For example:
|
||||
# If n_block_to_fill = 0, then only the state at the sequence end is cached and the process below is not involved.
|
||||
# If n_block_to_fill > 0, then the states at the sequence end and at the n_block_to_fill-last
|
||||
# stride_block_m are cached.
|
||||
# For example chunk_offset = n_block_to_fill stores the state at last_full_block
|
||||
if (chunk_offset - 1) < n_block_to_fill:
|
||||
# Store the states at the chunk boundaries from the start of the sequence
|
||||
idx_tokens_last = (last_full_block_token_index -
|
||||
(n_block_to_fill - chunk_offset) * B_size -
|
||||
state_len) + tl.arange(
|
||||
0, NP2_STATELEN) # [BLOCK_M]
|
||||
x_ptrs = x_ptr + (idx_tokens_last * stride_x_token)[:, None] + (
|
||||
idx_feats * stride_x_dim)[None, :] # [BLOCK_M,BLOCK_N,]
|
||||
|
||||
mask_x = (
|
||||
(idx_tokens_last >= 0)[:, None] & (idx_feats < dim)[None, :]
|
||||
) # token-index # token-index # feature-index
|
||||
loaded_x = tl.load(x_ptrs, mask_x, 0.0)
|
||||
idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
|
||||
|
||||
# cache_idx
|
||||
conv_states_output_coord = tl.load(conv_state_indices_ptr +
|
||||
idx_seq * stride_cache_indices +
|
||||
current_first_index +
|
||||
(chunk_offset - 1)).to(tl.int64)
|
||||
|
||||
conv_states_ptrs_target = (
|
||||
conv_states_ptr + (conv_states_output_coord *
|
||||
stride_conv_state_seq) + # Offset from seq
|
||||
(idx_feats * stride_conv_state_dim))[None, :] + ( # [BLOCK_N,]
|
||||
idx_tokens_conv * stride_conv_state_tok)[:, None]
|
||||
|
||||
mask = (idx_tokens_conv < state_len)[:, None] & \
|
||||
(idx_feats < dim)[None, :]
|
||||
tl.debug_barrier() # NOTE: use this due to bug in Triton compiler
|
||||
tl.store(conv_states_ptrs_target, loaded_x, mask)
|
||||
|
||||
if HAS_BIAS:
|
||||
bias = bias_ptr + idx_feats
|
||||
mask_bias = idx_feats < dim
|
||||
@@ -368,6 +445,11 @@ def causal_conv1d_fn(
|
||||
has_initial_state: Optional[torch.Tensor] = None,
|
||||
activation: Optional[str] = "silu",
|
||||
pad_slot_id: int = PAD_SLOT_ID,
|
||||
current_first_idx: Optional[torch.Tensor] = None,
|
||||
current_last_idx: Optional[torch.Tensor] = None,
|
||||
initial_state_idx: Optional[torch.Tensor] = None,
|
||||
context_lens: Optional[torch.Tensor] = None,
|
||||
block_size_to_align=0,
|
||||
metadata=None,
|
||||
validate_data=False,
|
||||
):
|
||||
@@ -378,7 +460,7 @@ def causal_conv1d_fn(
|
||||
sequences are concatenated from left to right for varlen
|
||||
weight: (dim, width)
|
||||
conv_states: (...,dim,width - 1) itype
|
||||
updated inplace if provided
|
||||
updated inplace if cache_indices are not provided
|
||||
[it use `cache_indices` to get the index to the cache of conv_state for that sequence
|
||||
|
||||
conv_state[cache_indices[i]] for seq-i - to be used as initial_state when has_initial_state[i] = True
|
||||
@@ -410,7 +492,16 @@ def causal_conv1d_fn(
|
||||
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
|
||||
in this case, the kernel will not process entries at
|
||||
indices 0 and 3
|
||||
|
||||
current_first_idx: (batch,), dtype int32
|
||||
The pointer into cache_indices, where the first cache block to be filled is located.
|
||||
current_last_idx: (batch,), dtype int32
|
||||
The pointer into cache_indices, where the last cache block to be filled is located.
|
||||
initial_state_idx: (batch,), dtype int32
|
||||
The pointer into cache_indices, where the cache block containing the initial state is located.
|
||||
context_lens: (batch,), dtype int32
|
||||
The number of tokens already completed for each sequence
|
||||
block_size_to_align: int
|
||||
The block size to align the cached states to
|
||||
out: same shape as `x`
|
||||
"""
|
||||
if isinstance(activation, bool) and activation:
|
||||
@@ -451,7 +542,6 @@ def causal_conv1d_fn(
|
||||
np2_statelen = triton.next_power_of_2(state_len)
|
||||
|
||||
padded_batch = query_start_loc.size(0) - 1
|
||||
stride_x_seq = 0
|
||||
stride_x_dim = x.stride(0)
|
||||
stride_x_token = x.stride(1)
|
||||
stride_w_dim = weight.stride(0)
|
||||
@@ -460,6 +550,7 @@ def causal_conv1d_fn(
|
||||
stride_istate_dim = 0
|
||||
stride_istate_token = 0
|
||||
num_cache_lines = 0
|
||||
BLOCK_M = 8
|
||||
if conv_states is not None:
|
||||
# extensions to support vLLM:
|
||||
# 1. conv_states is used to replaced initial_states
|
||||
@@ -475,11 +566,9 @@ def causal_conv1d_fn(
|
||||
stride_istate_token = conv_states.stride(2)
|
||||
assert stride_istate_dim == 1
|
||||
if out.dim() == 2:
|
||||
stride_o_seq = 0
|
||||
stride_o_dim = out.stride(0)
|
||||
stride_o_token = out.stride(1)
|
||||
else:
|
||||
stride_o_seq = out.stride(0)
|
||||
stride_o_dim = out.stride(1)
|
||||
stride_o_token = out.stride(2)
|
||||
stride_cache_indices = cache_indices.stride(
|
||||
@@ -502,6 +591,12 @@ def causal_conv1d_fn(
|
||||
assert weight.stride(1) == 1
|
||||
assert (dim, width) == weight.shape
|
||||
assert is_channel_last, "Need to run in channel-last layout"
|
||||
if block_size_to_align is not None and block_size_to_align > 0:
|
||||
assert (
|
||||
block_size_to_align % BLOCK_M
|
||||
) == 0, "The mamba block size needs to be divisible by the BLOCK_M"
|
||||
else:
|
||||
block_size_to_align = BLOCK_M
|
||||
|
||||
if metadata is None:
|
||||
|
||||
@@ -584,14 +679,16 @@ def causal_conv1d_fn(
|
||||
query_start_loc,
|
||||
batch_ptr,
|
||||
token_chunk_offset_ptr,
|
||||
current_first_idx,
|
||||
current_last_idx,
|
||||
initial_state_idx,
|
||||
context_lens,
|
||||
out,
|
||||
# Matrix dimensions
|
||||
padded_batch,
|
||||
dim,
|
||||
cu_seqlen,
|
||||
num_cache_lines,
|
||||
# stride
|
||||
stride_x_seq,
|
||||
stride_x_dim,
|
||||
stride_x_token,
|
||||
stride_w_dim,
|
||||
@@ -600,22 +697,20 @@ def causal_conv1d_fn(
|
||||
stride_istate_dim,
|
||||
stride_istate_token,
|
||||
stride_cache_indices,
|
||||
stride_o_seq,
|
||||
stride_o_dim,
|
||||
stride_o_token,
|
||||
block_size_to_align // BLOCK_M,
|
||||
# others
|
||||
pad_slot_id,
|
||||
# META
|
||||
HAS_BIAS=bias is not None,
|
||||
KERNEL_WIDTH=width,
|
||||
SILU_ACTIVATION=activation in ["silu", "swish"],
|
||||
HAS_INITIAL_STATES=has_initial_state is not None,
|
||||
HAS_CACHE=conv_states is not None,
|
||||
IS_CONTINUOUS_BATCHING=cache_indices is not None,
|
||||
IS_APC_ENABLED=current_last_idx is not None,
|
||||
USE_PAD_SLOT=pad_slot_id is not None,
|
||||
NP2_STATELEN=np2_statelen,
|
||||
#launch_cooperative_grid=True
|
||||
BLOCK_M=8,
|
||||
BLOCK_M=BLOCK_M,
|
||||
BLOCK_N=256,
|
||||
num_stages=2,
|
||||
)
|
||||
@@ -629,10 +724,11 @@ def _causal_conv1d_update_kernel(
|
||||
w_ptr, # (dim, width)
|
||||
bias_ptr,
|
||||
conv_state_ptr,
|
||||
cache_seqlens_ptr, # circular buffer
|
||||
conv_state_indices_ptr,
|
||||
num_accepted_tokens_ptr,
|
||||
query_start_loc_ptr, # (batch + 1)
|
||||
current_last_idx, # (batch,)
|
||||
initial_state_idx, #(batch,)
|
||||
o_ptr, # (batch, dim, seqlen)
|
||||
# Matrix dimensions
|
||||
batch: int,
|
||||
@@ -660,7 +756,7 @@ def _causal_conv1d_update_kernel(
|
||||
KERNEL_WIDTH: tl.constexpr,
|
||||
SILU_ACTIVATION: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
IS_CONTINUOUS_BATCHING: tl.constexpr,
|
||||
IS_APC_ENABLED: tl.constexpr,
|
||||
IS_SPEC_DECODING: tl.constexpr,
|
||||
NP2_STATELEN: tl.constexpr,
|
||||
USE_PAD_SLOT: tl.constexpr,
|
||||
@@ -674,15 +770,21 @@ def _causal_conv1d_update_kernel(
|
||||
# [BLOCK_N,] elements along the feature-dimension (channel)
|
||||
idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
|
||||
if IS_CONTINUOUS_BATCHING:
|
||||
# mask = idx_seq < batch
|
||||
conv_state_batch_coord = tl.load(conv_state_indices_ptr +
|
||||
idx_seq * stride_state_indices).to(
|
||||
tl.int64)
|
||||
if IS_APC_ENABLED:
|
||||
# Get the state from the initial_state_idx
|
||||
conv_state_init = tl.load(initial_state_idx + idx_seq)
|
||||
current_last_index = tl.load(current_last_idx + idx_seq)
|
||||
else:
|
||||
conv_state_batch_coord = idx_seq
|
||||
conv_state_init = 0
|
||||
current_last_index = 0
|
||||
|
||||
# cache_idx
|
||||
conv_states_input_coord = tl.load(conv_state_indices_ptr +
|
||||
idx_seq * stride_state_indices +
|
||||
conv_state_init).to(tl.int64)
|
||||
|
||||
if USE_PAD_SLOT: # noqa
|
||||
if conv_state_batch_coord == pad_slot_id:
|
||||
if conv_states_input_coord == pad_slot_id:
|
||||
# not processing as this is not the actual sequence
|
||||
return
|
||||
|
||||
@@ -726,7 +828,7 @@ def _causal_conv1d_update_kernel(
|
||||
|
||||
# STEP 1: READ init_state data
|
||||
conv_states_base = (conv_state_ptr +
|
||||
(conv_state_batch_coord * stride_conv_state_seq) +
|
||||
(conv_states_input_coord * stride_conv_state_seq) +
|
||||
(idx_feats * stride_conv_state_dim))
|
||||
mask_w = idx_feats < dim
|
||||
|
||||
@@ -754,12 +856,12 @@ def _causal_conv1d_update_kernel(
|
||||
# window manner, at each forward pass, the tokens are shift by 1, so we
|
||||
# load since idx_tokens + 1.
|
||||
conv_state_ptrs_source = (
|
||||
conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) +
|
||||
conv_state_ptr + (conv_states_input_coord * stride_conv_state_seq) +
|
||||
conv_state_token_offset * stride_conv_state_tok +
|
||||
(idx_feats * stride_conv_state_dim)[None, :] +
|
||||
((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) *
|
||||
stride_conv_state_tok)[:, None]) # [BLOCK_M, BLOCK_N]
|
||||
mask = ((conv_state_batch_coord < num_cache_lines)
|
||||
mask = ((conv_states_input_coord < num_cache_lines)
|
||||
& ((idx_tokens + seqlen) < state_len)[:, None]
|
||||
& (idx_feats < dim)[None, :])
|
||||
conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0)
|
||||
@@ -778,11 +880,16 @@ def _causal_conv1d_update_kernel(
|
||||
|
||||
new_conv_state = tl.where(mask, conv_state, loaded_x)
|
||||
|
||||
conv_state_base = (conv_state_ptr +
|
||||
(conv_state_batch_coord * stride_conv_state_seq) +
|
||||
(idx_feats * stride_conv_state_dim)) # [BLOCK_N,]
|
||||
conv_state_ptrs_target = conv_state_base + (
|
||||
idx_tokens * stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N]
|
||||
# Get the state from the initial_state_idx
|
||||
# cache_idx
|
||||
conv_states_offset = tl.load(conv_state_indices_ptr +
|
||||
idx_seq * stride_state_indices +
|
||||
current_last_index).to(tl.int64)
|
||||
conv_state_ptrs_target = (
|
||||
conv_state_ptr +
|
||||
(conv_states_offset * stride_conv_state_seq) + # Offset from seq
|
||||
(idx_feats * stride_conv_state_dim))[None, :] + ( # [BLOCK_N,]
|
||||
idx_tokens * stride_conv_state_tok)[:, None]
|
||||
mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :]
|
||||
tl.store(conv_state_ptrs_target, new_conv_state, mask)
|
||||
|
||||
@@ -923,12 +1030,13 @@ def causal_conv1d_update(
|
||||
weight: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
activation: Union[bool, str, None] = None,
|
||||
cache_seqlens: Optional[torch.Tensor] = None,
|
||||
conv_state_indices: Optional[torch.Tensor] = None,
|
||||
num_accepted_tokens: Optional[torch.Tensor] = None,
|
||||
query_start_loc: Optional[torch.Tensor] = None,
|
||||
max_query_len: int = -1,
|
||||
pad_slot_id: int = PAD_SLOT_ID,
|
||||
current_last_idx: Optional[torch.Tensor] = None,
|
||||
initial_state_idx: Optional[torch.Tensor] = None,
|
||||
validate_data=False,
|
||||
):
|
||||
"""
|
||||
@@ -942,15 +1050,14 @@ def causal_conv1d_update(
|
||||
conv_state: (..., dim, state_len), where state_len >= width - 1
|
||||
weight: (dim, width)
|
||||
bias: (dim,)
|
||||
cache_seqlens: (batch,), dtype int32.
|
||||
If not None, the conv_state is treated as a circular buffer.
|
||||
The conv_state will be updated by copying x to the conv_state
|
||||
starting at the index
|
||||
@cache_seqlens % state_len.
|
||||
conv_state_indices: (batch,), dtype int32
|
||||
If not None, the conv_state is a larger tensor along the batch dim,
|
||||
and we are selecting the batch coords specified by conv_state_indices.
|
||||
Useful for a continuous batching scenario.
|
||||
current_last_idx: (batch,), dtype int32
|
||||
The pointer into conv_state_indices, where the last cache block to be filled is located.
|
||||
initial_state_idx: (batch,), dtype int32
|
||||
The pointer into conv_state_indices, where the cache block containing the initial state is located.
|
||||
num_accepted_tokens: (batch,), dtype int32
|
||||
If not None, it indicates the number of accepted tokens for each
|
||||
sequence in the batch.
|
||||
@@ -963,15 +1070,14 @@ def causal_conv1d_update(
|
||||
If query_start_loc is not None, this indicates the maximum query
|
||||
length in the batch.
|
||||
pad_slot_id: int
|
||||
if cache_indices is passed, lets the kernel identify padded
|
||||
if conv_state_indices is passed, lets the kernel identify padded
|
||||
entries that will not be processed,
|
||||
for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
|
||||
for example: conv_state_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
|
||||
in this case, the kernel will not process entries at
|
||||
indices 0 and 3
|
||||
out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x`
|
||||
"""
|
||||
if validate_data:
|
||||
assert cache_seqlens is None # not implemented yet - ok for vLLM
|
||||
assert pad_slot_id is not None
|
||||
assert x.stride(1) == 1
|
||||
if isinstance(activation, bool):
|
||||
@@ -1011,7 +1117,6 @@ def causal_conv1d_update(
|
||||
|
||||
assert num_cache_lines >= batch
|
||||
assert weight.stride(1) == 1 # Need this
|
||||
assert cache_seqlens is None # not needed for vLLM - circular buffer
|
||||
|
||||
# adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o'
|
||||
out = x
|
||||
@@ -1050,10 +1155,11 @@ def causal_conv1d_update(
|
||||
weight,
|
||||
bias,
|
||||
conv_state,
|
||||
cache_seqlens,
|
||||
conv_state_indices,
|
||||
num_accepted_tokens,
|
||||
query_start_loc,
|
||||
current_last_idx,
|
||||
initial_state_idx,
|
||||
out,
|
||||
# Matrix dimensions
|
||||
batch,
|
||||
@@ -1081,7 +1187,7 @@ def causal_conv1d_update(
|
||||
KERNEL_WIDTH=width,
|
||||
SILU_ACTIVATION=activation in ["silu", "swish"],
|
||||
IS_VARLEN=query_start_loc is not None,
|
||||
IS_CONTINUOUS_BATCHING=conv_state_indices is not None,
|
||||
IS_APC_ENABLED=current_last_idx is not None,
|
||||
IS_SPEC_DECODING=num_accepted_tokens is not None,
|
||||
NP2_STATELEN=np2_statelen,
|
||||
USE_PAD_SLOT=pad_slot_id is not None,
|
||||
|
||||
@@ -52,6 +52,7 @@ def _selective_scan_update_kernel(
|
||||
z_ptr,
|
||||
out_ptr,
|
||||
state_batch_indices_ptr,
|
||||
dst_state_batch_indices_ptr,
|
||||
pad_slot_id,
|
||||
# Matrix dimensions
|
||||
batch,
|
||||
@@ -107,11 +108,17 @@ def _selective_scan_update_kernel(
|
||||
# is taken from the state_batch_indices_ptr Otherwise, the state coordinate
|
||||
# is the same as the batch id.
|
||||
if HAS_STATE_BATCH_INDICES:
|
||||
dst_state_batch_indices_ptr += pid_b
|
||||
dst_state_batch_idx = tl.load(dst_state_batch_indices_ptr).to(tl.int64)
|
||||
dst_state_ptr = state_ptr + (dst_state_batch_idx * stride_state_batch +
|
||||
pid_h * stride_state_head)
|
||||
state_batch_indices_ptr += pid_b
|
||||
state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64)
|
||||
state_ptr += (state_batch_idx * stride_state_batch +
|
||||
pid_h * stride_state_head)
|
||||
else:
|
||||
dst_state_ptr = state_ptr + pid_b * stride_state_batch + \
|
||||
pid_h * stride_state_head
|
||||
state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
|
||||
|
||||
x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
|
||||
@@ -131,6 +138,8 @@ def _selective_scan_update_kernel(
|
||||
offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
|
||||
state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim +
|
||||
offs_n[None, :] * stride_state_dstate)
|
||||
dst_state_ptrs = dst_state_ptr + (offs_m[:, None] * stride_state_dim +
|
||||
offs_n[None, :] * stride_state_dstate)
|
||||
x_ptrs = x_ptr + offs_m * stride_x_dim
|
||||
dt_ptrs = dt_ptr + offs_m * stride_dt_dim
|
||||
if HAS_DT_BIAS:
|
||||
@@ -185,7 +194,7 @@ def _selective_scan_update_kernel(
|
||||
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
|
||||
if HAS_STATE_BATCH_INDICES:
|
||||
mask &= (state_batch_idx != pad_slot_id)
|
||||
tl.store(state_ptrs, state, mask=mask)
|
||||
tl.store(dst_state_ptrs, state, mask=mask)
|
||||
out = tl.sum(state * C[None, :], axis=1)
|
||||
if HAS_D:
|
||||
out += x * D
|
||||
@@ -205,6 +214,7 @@ def selective_state_update(state,
|
||||
dt_bias=None,
|
||||
dt_softplus=False,
|
||||
state_batch_indices=None,
|
||||
dst_state_batch_indices=None,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
out=None):
|
||||
"""
|
||||
@@ -266,6 +276,11 @@ def selective_state_update(state,
|
||||
assert dt_bias.shape == (nheads, dim)
|
||||
if state_batch_indices is not None:
|
||||
assert state_batch_indices.shape == (batch, )
|
||||
if dst_state_batch_indices is not None:
|
||||
assert dst_state_batch_indices.shape == (batch, )
|
||||
else:
|
||||
# revert to the default behavior of in-place state updates
|
||||
dst_state_batch_indices = state_batch_indices
|
||||
assert out.shape == x.shape
|
||||
|
||||
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
|
||||
@@ -292,6 +307,7 @@ def selective_state_update(state,
|
||||
z,
|
||||
out,
|
||||
state_batch_indices,
|
||||
dst_state_batch_indices,
|
||||
pad_slot_id,
|
||||
batch,
|
||||
nheads,
|
||||
|
||||
@@ -35,6 +35,7 @@ def _mamba_chunk_scan_combined_fwd(x,
|
||||
z=None,
|
||||
dt_bias=None,
|
||||
initial_states=None,
|
||||
return_intermediate_states=False,
|
||||
seq_idx=None,
|
||||
cu_seqlens=None,
|
||||
cu_chunk_seqlens=None,
|
||||
@@ -151,28 +152,32 @@ def _mamba_chunk_scan_combined_fwd(x,
|
||||
initial_states=initial_states,
|
||||
)
|
||||
|
||||
return states[last_chunk_indices]
|
||||
if return_intermediate_states:
|
||||
return states
|
||||
else:
|
||||
return states[last_chunk_indices]
|
||||
|
||||
|
||||
def mamba_chunk_scan_combined_varlen(
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
chunk_size,
|
||||
cu_seqlens,
|
||||
cu_chunk_seqlens,
|
||||
last_chunk_indices,
|
||||
seq_idx,
|
||||
out,
|
||||
D=None,
|
||||
z=None,
|
||||
dt_bias=None,
|
||||
initial_states=None,
|
||||
dt_softplus=False,
|
||||
dt_limit=(0.0, float("inf")),
|
||||
state_dtype=None,
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
chunk_size,
|
||||
cu_seqlens,
|
||||
cu_chunk_seqlens,
|
||||
last_chunk_indices,
|
||||
seq_idx,
|
||||
out,
|
||||
D=None,
|
||||
z=None,
|
||||
dt_bias=None,
|
||||
initial_states=None,
|
||||
dt_softplus=False,
|
||||
dt_limit=(0.0, float("inf")),
|
||||
return_intermediate_states=False,
|
||||
state_dtype=None,
|
||||
):
|
||||
"""
|
||||
Argument:
|
||||
@@ -213,6 +218,7 @@ def mamba_chunk_scan_combined_varlen(
|
||||
z=z,
|
||||
dt_bias=dt_bias,
|
||||
initial_states=initial_states,
|
||||
return_intermediate_states=return_intermediate_states,
|
||||
seq_idx=seq_idx,
|
||||
cu_seqlens=cu_seqlens,
|
||||
cu_chunk_seqlens=cu_chunk_seqlens,
|
||||
|
||||
Reference in New Issue
Block a user