[Mamba1] - Kernel Level Chunk Alignment for Prefix Caching (#34798)

Signed-off-by: Josephasafg <ajgard7@gmail.com>
Asaf Gardin
2026-03-01 14:40:23 +02:00
committed by GitHub
parent da543d1abe
commit bbf81f9a92
11 changed files with 251 additions and 146 deletions

View File

@@ -17,7 +17,7 @@
struct SSMParamsBase {
using index_t = size_t;
int batch, dim, seqlen, dstate, n_groups, n_chunks;
int batch, dim, seqlen, dstate, n_groups;
int dim_ngroups_ratio;
bool is_variable_B;
bool is_variable_C;
@@ -72,6 +72,8 @@ struct SSMParamsBase {
void *__restrict__ block_idx_first_scheduled_token_ptr; // (batch,) - first block to write
void *__restrict__ block_idx_last_scheduled_token_ptr; // (batch,) - last block to write
void *__restrict__ initial_state_idx_ptr; // (batch,) - index of the initial state to use
void *__restrict__ cu_chunk_seqlen_ptr; // (nchunks+1,) - cumulative chunk token offsets
void *__restrict__ last_chunk_indices_ptr; // (batch,) - index of last chunk per sequence
};
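For orientation, a minimal sketch (not part of the diff; the values and block_size = 256 are made up) of the layout the two new buffers describe: chunk i covers tokens cu_chunk_seqlen[i] to cu_chunk_seqlen[i+1] - 1 of the flattened prefill batch, and last_chunk_indices[b] names the final chunk of sequence b.

# Hypothetical metadata for two prefill sequences with block_size = 256:
# seq 0 contributes chunks 0-1, seq 1 contributes chunks 2-4.
cu_chunk_seqlen = [0, 96, 352, 608, 864, 1000]    # (nchunks + 1,)
last_chunk_indices = [1, 4]                        # (batch,)

# Per-chunk token counts follow from adjacent differences.
chunk_tokens = [cu_chunk_seqlen[i + 1] - cu_chunk_seqlen[i]
                for i in range(len(cu_chunk_seqlen) - 1)]
assert chunk_tokens == [96, 256, 256, 256, 136]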

View File

@@ -81,7 +81,6 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
constexpr bool kIsVariableC = Ktraits::kIsVariableC;
constexpr bool kHasZ = Ktraits::kHasZ;
constexpr bool kVarlen = Ktraits::kVarlen;
constexpr int kNThreads = Ktraits::kNThreads;
constexpr int kNItems = Ktraits::kNItems;
constexpr int kNRows = Ktraits::kNRows;
constexpr bool kDirectIO = Ktraits::kDirectIO;
@@ -161,17 +160,8 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
}
}
// for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
// smem_a[state_idx] = A[state_idx * params.A_dstate_stride];
// smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride];
// }
constexpr int kChunkSize = kNThreads * kNItems;
// Use block_size for chunking when APC is enabled, otherwise use 2048 for backwards compatibility
const int iteration_chunk_size = params.cache_enabled ? params.block_size : 2048;
const int n_chunks = (seqlen + iteration_chunk_size - 1) / iteration_chunk_size;
const int block_size = params.cache_enabled ? params.block_size : 2048;
const int* batch_cache_indices = cache_indices != nullptr ?
cache_indices + batch_id * params.cache_indices_stride : nullptr;
@@ -181,10 +171,44 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
reinterpret_cast<const int*>(params.block_idx_last_scheduled_token_ptr) : nullptr;
const int* initial_state_idx = params.initial_state_idx_ptr != nullptr ?
reinterpret_cast<const int*>(params.initial_state_idx_ptr) : nullptr;
const int* cu_chunk_seqlen = params.cu_chunk_seqlen_ptr != nullptr ?
reinterpret_cast<const int*>(params.cu_chunk_seqlen_ptr) : nullptr;
const int* last_chunk_indices = params.last_chunk_indices_ptr != nullptr ?
reinterpret_cast<const int*>(params.last_chunk_indices_ptr) : nullptr;
const size_t load_cache_slot = params.cache_enabled && batch_cache_indices != nullptr ? batch_cache_indices[initial_state_idx[batch_id]] : cache_index;
const int block_idx_first = (params.cache_enabled && block_idx_first_scheduled != nullptr) ?
block_idx_first_scheduled[batch_id] : 0;
// Determine chunk boundaries from pre-computed metadata (APC mode)
// or fall back to simple block_size chunking.
int first_chunk_idx, n_chunks;
int current_position;
if (cu_chunk_seqlen != nullptr && last_chunk_indices != nullptr) {
const int last_chunk_idx = last_chunk_indices[batch_id];
first_chunk_idx = (batch_id == 0) ? 0 : last_chunk_indices[batch_id - 1] + 1;
n_chunks = last_chunk_idx - first_chunk_idx + 1;
// Derive current_position: if the first chunk is partial (fills remainder
// of a started block), offset into the block accordingly.
const int first_chunk_tokens = cu_chunk_seqlen[first_chunk_idx + 1] - cu_chunk_seqlen[first_chunk_idx];
const int chunk_start_offset = (n_chunks > 1 && first_chunk_tokens < block_size)
? (block_size - first_chunk_tokens) : 0;
current_position = block_idx_first * block_size + chunk_start_offset;
} else {
first_chunk_idx = 0;
n_chunks = (seqlen + block_size - 1) / block_size;
current_position = 0;
}
int tokens_processed = 0;
for (int chunk = 0; chunk < n_chunks; ++chunk) {
const int chunk_tokens = (cu_chunk_seqlen != nullptr)
? cu_chunk_seqlen[first_chunk_idx + chunk + 1] - cu_chunk_seqlen[first_chunk_idx + chunk]
: min(block_size, seqlen - tokens_processed);
if (chunk_tokens <= 0) break;
input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
__syncthreads();
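A small Python sketch of the per-sequence bookkeeping this hunk adds. chunk_window and its arguments are hypothetical stand-ins that mirror the kernel variables; the numbers continue the made-up metadata sketched earlier.

def chunk_window(batch_id, cu_chunk_seqlen, last_chunk_indices,
                 block_idx_first, block_size):
    """Mirror of the kernel's per-sequence chunk bookkeeping (sketch only)."""
    first_chunk_idx = 0 if batch_id == 0 else last_chunk_indices[batch_id - 1] + 1
    last_chunk_idx = last_chunk_indices[batch_id]
    n_chunks = last_chunk_idx - first_chunk_idx + 1
    # A partial first chunk only tops up a block that earlier steps started.
    first_chunk_tokens = (cu_chunk_seqlen[first_chunk_idx + 1]
                          - cu_chunk_seqlen[first_chunk_idx])
    chunk_start_offset = (block_size - first_chunk_tokens
                          if n_chunks > 1 and first_chunk_tokens < block_size else 0)
    current_position = block_idx_first * block_size + chunk_start_offset
    return first_chunk_idx, n_chunks, current_position

# With 928 tokens already computed for seq 0 (block_idx_first = 3), the
# 96-token first chunk starts at position 928 and finishes block 3 exactly
# at the 1024-token boundary.
assert chunk_window(0, [0, 96, 352, 608, 864, 1000], [1, 4], 3, 256) == (0, 2, 928)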
@@ -193,12 +217,12 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
if constexpr (!kDirectIO) {
if (r > 0) { __syncthreads(); }
}
load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, seqlen - chunk * kChunkSize);
load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, chunk_tokens);
if constexpr (!kDirectIO) { __syncthreads(); }
load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, seqlen - chunk * kChunkSize);
load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, chunk_tokens);
}
u += kChunkSize;
delta += kChunkSize;
u += chunk_tokens;
delta += chunk_tokens;
float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems];
#pragma unroll
@@ -232,7 +256,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
weight_t B_vals[kNItems], C_vals[kNItems];
if constexpr (kIsVariableB) {
load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
smem_load_weight, (seqlen - chunk * kChunkSize) * (1));
smem_load_weight, chunk_tokens);
if constexpr (!kIsVariableC) {
#pragma unroll
for (int r = 0; r < kNRows; ++r) {
@@ -243,7 +267,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
if constexpr (kIsVariableC) {
auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
smem_load_weight_C, (seqlen - chunk * kChunkSize) * (1));
smem_load_weight_C, chunk_tokens);
if constexpr (!kIsVariableB) {
#pragma unroll
for (int r = 0; r < kNRows; ++r) {
@@ -266,10 +290,8 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
for (int i = 0; i < kNItems; ++i) {
thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
!kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);
if (seqlen % (kNItems * kNThreads) != 0) { // So that the last state is correct
if (threadIdx.x * kNItems + i >= seqlen - chunk * kChunkSize) {
thread_data[i] = make_float2(1.f, 0.f);
}
if (threadIdx.x * kNItems + i >= chunk_tokens) {
thread_data[i] = make_float2(1.f, 0.f);
}
}
// Initialize running total
@@ -301,14 +323,14 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
if (threadIdx.x == 0) {
smem_running_prefix[state_idx + r * MAX_DSTATE] = prefix_op.running_prefix;
// Store state at the end of each chunk when cache is enabled
// Store state at the end of each aligned chunk when cache is enabled
if (params.cache_enabled && batch_cache_indices != nullptr) {
size_t cache_slot;
if (chunk == n_chunks - 1) {
cache_slot = batch_cache_indices[block_idx_last_scheduled[batch_id]];
} else {
cache_slot = batch_cache_indices[block_idx_first_scheduled[batch_id] + chunk];
const int block_idx_completed = (current_position + chunk_tokens - 1) / block_size;
cache_slot = batch_cache_indices[block_idx_completed];
}
size_t state_offset = cache_slot * params.ssm_states_batch_stride +
@@ -331,38 +353,41 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
}
}
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + sequence_start_index * params.out_batch_stride
+ dim_id * kNRows * params.out_d_stride + chunk * kChunkSize;
+ dim_id * kNRows * params.out_d_stride + tokens_processed;
__syncthreads();
#pragma unroll
for (int r = 0; r < kNRows; ++r) {
if constexpr (!kDirectIO) {
if (r > 0) { __syncthreads(); }
}
store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, seqlen - chunk * kChunkSize);
store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, chunk_tokens);
}
if constexpr (kHasZ) {
input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + sequence_start_index * params.z_batch_stride
+ dim_id * kNRows * params.z_d_stride + chunk * kChunkSize;
+ dim_id * kNRows * params.z_d_stride + tokens_processed;
input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + sequence_start_index * params.out_z_batch_stride
+ dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize;
+ dim_id * kNRows * params.out_z_d_stride + tokens_processed;
#pragma unroll
for (int r = 0; r < kNRows; ++r) {
input_t z_vals[kNItems];
__syncthreads();
load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load, seqlen - chunk * kChunkSize);
load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load, chunk_tokens);
#pragma unroll
for (int i = 0; i < kNItems; ++i) {
float z_val = z_vals[i];
out_vals[r][i] *= z_val / (1 + expf(-z_val));
}
__syncthreads();
store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, seqlen - chunk * kChunkSize);
store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, chunk_tokens);
}
}
Bvar += kChunkSize * 1;
Cvar += kChunkSize * 1;
Bvar += chunk_tokens;
Cvar += chunk_tokens;
tokens_processed += chunk_tokens;
current_position += chunk_tokens;
}
}
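Sketch (same hypothetical values as above) of which logical block index each chunk's final SSM state is routed to by the state-store branch in the loop above; the kernel then maps that index through batch_cache_indices to a physical cache slot. state_write_blocks is an illustrative helper, not part of the diff.

def state_write_blocks(cu_chunk_seqlen, first_chunk_idx, n_chunks,
                       current_position, block_size, block_idx_last):
    """Logical block index receiving each chunk's final SSM state (sketch)."""
    blocks = []
    for chunk in range(n_chunks):
        chunk_tokens = (cu_chunk_seqlen[first_chunk_idx + chunk + 1]
                        - cu_chunk_seqlen[first_chunk_idx + chunk])
        if chunk == n_chunks - 1:
            blocks.append(block_idx_last)  # last scheduled block
        else:
            # index of the block this chunk just finished filling
            blocks.append((current_position + chunk_tokens - 1) // block_size)
        current_position += chunk_tokens
    return blocks

# Continuing the example: the 96-token chunk completes block 3, and the
# final 256-token chunk is routed to block_idx_last_scheduled (here 4).
assert state_write_blocks([0, 96, 352, 608, 864, 1000], 0, 2, 928, 256, 4) == [3, 4]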
@@ -506,7 +531,9 @@ void set_ssm_params_fwd(SSMParamsBase &params,
int64_t block_size,
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
const std::optional<torch::Tensor> &initial_state_idx) {
const std::optional<torch::Tensor> &initial_state_idx,
const std::optional<torch::Tensor> &cu_chunk_seqlen,
const std::optional<torch::Tensor> &last_chunk_indices) {
// Reset the parameters
memset(&params, 0, sizeof(params));
@@ -548,6 +575,8 @@ void set_ssm_params_fwd(SSMParamsBase &params,
params.block_idx_first_scheduled_token_ptr = block_idx_first_scheduled_token.has_value() ? block_idx_first_scheduled_token.value().data_ptr() : nullptr;
params.block_idx_last_scheduled_token_ptr = block_idx_last_scheduled_token.has_value() ? block_idx_last_scheduled_token.value().data_ptr() : nullptr;
params.initial_state_idx_ptr = initial_state_idx.has_value() ? initial_state_idx.value().data_ptr() : nullptr;
params.cu_chunk_seqlen_ptr = cu_chunk_seqlen.has_value() ? cu_chunk_seqlen.value().data_ptr() : nullptr;
params.last_chunk_indices_ptr = last_chunk_indices.has_value() ? last_chunk_indices.value().data_ptr() : nullptr;
// All stride are in elements, not bytes.
params.A_d_stride = A.stride(0);
@@ -633,7 +662,9 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
int64_t block_size,
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
const std::optional<torch::Tensor> &initial_state_idx) {
const std::optional<torch::Tensor> &initial_state_idx,
const std::optional<torch::Tensor> &cu_chunk_seqlen,
const std::optional<torch::Tensor> &last_chunk_indices) {
auto input_type = u.scalar_type();
auto weight_type = A.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -778,7 +809,9 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
block_size,
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
initial_state_idx
initial_state_idx,
cu_chunk_seqlen,
last_chunk_indices
);

View File

@@ -371,7 +371,9 @@ void selective_scan_fwd(
const torch::Tensor& ssm_states, int64_t pad_slot_id, int64_t block_size,
const std::optional<torch::Tensor>& block_idx_first_scheduled_token,
const std::optional<torch::Tensor>& block_idx_last_scheduled_token,
const std::optional<torch::Tensor>& initial_state_idx);
const std::optional<torch::Tensor>& initial_state_idx,
const std::optional<torch::Tensor>& cu_chunk_seqlen,
const std::optional<torch::Tensor>& last_chunk_indices);
torch::Tensor dynamic_4bit_int_moe_cpu(
torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights,

View File

@@ -640,7 +640,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"int block_size,"
"Tensor? block_idx_first_scheduled_token,"
"Tensor? block_idx_last_scheduled_token,"
"Tensor? initial_state_idx) -> ()");
"Tensor? initial_state_idx,"
"Tensor? cu_chunk_seqlen,"
"Tensor? last_chunk_indices) -> ()");
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
// Hadamard transforms

View File

@@ -183,6 +183,8 @@ def selective_scan_opcheck_fn(
block_idx_first_scheduled_token=None,
block_idx_last_scheduled_token=None,
initial_state_idx=None,
cu_chunk_seqlen=None,
last_chunk_indices=None,
):
"""if return_last_state is True, returns (out, last_state)
last_state has shape (batch, dim, dstate).
@@ -231,6 +233,8 @@ def selective_scan_opcheck_fn(
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
initial_state_idx,
cu_chunk_seqlen,
last_chunk_indices,
),
test_utils=["test_schema", "test_faketensor"],
)

View File

@@ -2021,6 +2021,8 @@ def selective_scan_fwd(
block_idx_first_scheduled_token: torch.Tensor | None = None,
block_idx_last_scheduled_token: torch.Tensor | None = None,
initial_state_idx: torch.Tensor | None = None,
cu_chunk_seqlen: torch.Tensor | None = None,
last_chunk_indices: torch.Tensor | None = None,
):
torch.ops._C.selective_scan_fwd(
u,
@@ -2041,6 +2043,8 @@ def selective_scan_fwd(
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
initial_state_idx,
cu_chunk_seqlen,
last_chunk_indices,
)

View File

@@ -271,6 +271,8 @@ class MambaMixer(MambaBase, PluggableLayer):
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
has_initial_states_p = attn_metadata.has_initial_states_p
cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p
last_chunk_indices_p = attn_metadata.last_chunk_indices_p
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
@@ -376,6 +378,8 @@ class MambaMixer(MambaBase, PluggableLayer):
block_idx_first_scheduled_token=block_idx_first_scheduled_token_p,
block_idx_last_scheduled_token=block_idx_last_scheduled_token_p,
initial_state_idx=block_idx_last_computed_token_p,
cu_chunk_seqlen=cu_chunk_seqlen_p,
last_chunk_indices=last_chunk_indices_p,
)
ssm_outputs.append(scan_out_p)

View File

@@ -497,6 +497,8 @@ def selective_scan_fn(
block_idx_first_scheduled_token=None,
block_idx_last_scheduled_token=None,
initial_state_idx=None,
cu_chunk_seqlen=None,
last_chunk_indices=None,
) -> torch.Tensor:
"""
u: (dim, total_length) for varlen or (batch, dim, seqlen)
@@ -588,6 +590,8 @@ def selective_scan_fn(
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
initial_state_idx,
cu_chunk_seqlen,
last_chunk_indices,
)
if z is None:

View File

@@ -1,9 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from dataclasses import dataclass, replace
from typing import Any
from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.attention.backend import AttentionBackend, CommonAttentionMetadata
from vllm.v1.attention.backends.mamba_attn import (
BaseMambaAttentionMetadata,
BaseMambaAttentionMetadataBuilder,
@@ -29,3 +30,31 @@ class Mamba1AttentionMetadataBuilder(
BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]
):
metadata_cls = Mamba1AttentionMetadata
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
**kwargs: Any,
) -> Mamba1AttentionMetadata:
common = self._compute_common_metadata(common_attn_metadata)
if (
common.num_prefills > 0
and self.vllm_config.cache_config.mamba_cache_mode == "all"
):
cu_chunk_seqlen_p, _, last_chunk_indices_p = (
self._build_chunk_metadata_tensors(
self.kv_cache_spec.block_size,
common,
common_attn_metadata,
)
)
return replace(
common,
cu_chunk_seqlen_p=cu_chunk_seqlen_p,
last_chunk_indices_p=last_chunk_indices_p,
)
return common

View File

@@ -7,7 +7,6 @@ from typing import Any
import torch
from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import (
AttentionBackend,
CommonAttentionMetadata,
@@ -105,14 +104,6 @@ class Mamba2AttentionMetadata(BaseMambaAttentionMetadata):
# Chunk-related metadata (only for prefill)
seq_idx_p: torch.Tensor | None = None
# cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
# each chunk, its offsets into the varlen sequence dimension. It is defined
# such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
# cu_chunk_seqlen_p[i+1].
cu_chunk_seqlen_p: torch.Tensor | None = None
# last_chunk_indices_p is a tensor of shape (batch,) that contains the
# index of the last chunk for every sequence in the (prefill) batch.
last_chunk_indices_p: torch.Tensor | None = None
class Mamba2AttentionMetadataBuilder(
@@ -134,68 +125,6 @@ class Mamba2AttentionMetadataBuilder(
)
self.chunk_size: int = chunk_size
def _compute_chunk_metadata(
self,
num_prefills: int,
num_computed_tokens_p_cpu: torch.Tensor,
query_start_loc_p_cpu: torch.Tensor,
) -> tuple[list[int], list[int], list[int]]:
"""
Compute chunk-specific metadata for Mamba2.
The code below carefully constructs the chunks such that:
1. Chunks contain tokens from a *single* sequence only.
2. For every sequence, we are guaranteed that we can
retrieve the mamba state *every* chunk_size tokens.
Constraint (1) dramatically simplifies the mamba2 kernels.
Constraint (2) dramatically simplifies the implementation
of prefix caching for mamba2 (wip). We need to take care
of the interaction with chunked prefill in order to
satisfy constraint (2).
"""
# TODO (tdoublep): This code could probably be optimized.
cu_chunk_seqlen = []
seq_idx = []
last_chunk_indices = []
seqlen_pos = 0
for req_idx in range(num_prefills):
this_num_computed = num_computed_tokens_p_cpu[req_idx].item()
this_new_tokens = (
query_start_loc_p_cpu[req_idx + 1].item()
- query_start_loc_p_cpu[req_idx].item()
)
# if computed tokens are not chunk-aligned, use the first
# chunk to finish it off
if this_num_computed % self.chunk_size != 0:
seq_idx.append(req_idx)
cu_chunk_seqlen.append(seqlen_pos)
# how many tokens to finish the chunk?
chunk_len = (
cdiv(this_num_computed, self.chunk_size) * self.chunk_size
- this_num_computed
)
# we can only use at most this_new_tokens
chunk_len = min(chunk_len, this_new_tokens)
seqlen_pos += chunk_len
this_new_tokens -= chunk_len
n_chunks = cdiv(this_new_tokens, self.chunk_size)
for chunk in range(n_chunks):
seq_idx.append(req_idx)
cu_chunk_seqlen.append(seqlen_pos)
chunk_len = min(self.chunk_size, this_new_tokens)
seqlen_pos += chunk_len
this_new_tokens -= chunk_len
assert this_new_tokens == 0
last_chunk_indices.append(len(cu_chunk_seqlen) - 1)
cu_chunk_seqlen.append(seqlen_pos)
return cu_chunk_seqlen, seq_idx, last_chunk_indices
def build(
self,
common_prefix_len: int,
@@ -220,41 +149,12 @@ class Mamba2AttentionMetadataBuilder(
else False
)
num_reqs = common.num_reqs
num_prefills = common.num_prefills
num_decode_tokens = common.num_decode_tokens
num_computed_tokens_cpu = (
common_attn_metadata.compute_num_computed_tokens().cpu()
)
num_computed_tokens_p_cpu = num_computed_tokens_cpu[
num_reqs - num_prefills : num_reqs
]
query_start_loc_p_cpu = (
common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :]
- num_decode_tokens
)
cu_chunk_seqlen, seq_idx, last_chunk_indices = self._compute_chunk_metadata(
num_prefills,
num_computed_tokens_p_cpu,
query_start_loc_p_cpu,
)
seq_idx_p = torch.as_tensor(
seq_idx,
device=common_attn_metadata.query_start_loc.device,
dtype=torch.int32,
)
cu_chunk_seqlen_p = torch.as_tensor(
cu_chunk_seqlen,
device=common_attn_metadata.query_start_loc.device,
dtype=torch.int32,
)
last_chunk_indices_p = torch.as_tensor(
last_chunk_indices,
device=common_attn_metadata.query_start_loc.device,
dtype=torch.int32,
cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p = (
self._build_chunk_metadata_tensors(
self.chunk_size,
common,
common_attn_metadata,
)
)
return replace(

View File

@@ -59,6 +59,15 @@ class BaseMambaAttentionMetadata:
# The following tensor is only used for prefix caching in align mode
seq_lens: torch.Tensor
# cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
# each chunk, its offsets into the varlen sequence dimension. It is defined
# such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
# cu_chunk_seqlen_p[i+1].
cu_chunk_seqlen_p: torch.Tensor | None = None
# last_chunk_indices_p is a tensor of shape (batch,) that contains the
# index of the last chunk for every sequence in the (prefill) batch.
last_chunk_indices_p: torch.Tensor | None = None
# The following attributes are for triton implementation of causal_conv1d
nums_dict: dict | None = None
batch_ptr: torch.Tensor | None = None
@@ -185,6 +194,118 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
common_attn_metadata, num_accepted_tokens=num_accepted_tokens
)
def _compute_chunk_metadata(
self,
chunk_size: int,
num_prefills: int,
num_computed_tokens_p_cpu: torch.Tensor,
query_start_loc_p_cpu: torch.Tensor,
) -> tuple[list[int], list[int], list[int]]:
"""
Compute chunk-specific metadata for Mamba models.
The code below carefully constructs the chunks such that:
1. Chunks contain tokens from a *single* sequence only.
2. For every sequence, we are guaranteed that we can
retrieve the mamba state *every* chunk_size tokens.
Constraint (1) dramatically simplifies the mamba kernels.
Constraint (2) dramatically simplifies the implementation
of prefix caching for mamba (wip). We need to take care
of the interaction with chunked prefill in order to
satisfy constraint (2).
"""
# TODO (tdoublep): This code could probably be optimized.
cu_chunk_seqlen = []
seq_idx = []
last_chunk_indices = []
seqlen_pos = 0
for req_idx in range(num_prefills):
this_num_computed = num_computed_tokens_p_cpu[req_idx].item()
this_new_tokens = (
query_start_loc_p_cpu[req_idx + 1].item()
- query_start_loc_p_cpu[req_idx].item()
)
# if computed tokens are not chunk-aligned, use the first
# chunk to finish it off
if this_num_computed % chunk_size != 0:
seq_idx.append(req_idx)
cu_chunk_seqlen.append(seqlen_pos)
# how many tokens to finish the chunk?
chunk_len = (
cdiv(this_num_computed, chunk_size) * chunk_size - this_num_computed
)
# we can only use at most this_new_tokens
chunk_len = min(chunk_len, this_new_tokens)
seqlen_pos += chunk_len
this_new_tokens -= chunk_len
n_chunks = cdiv(this_new_tokens, chunk_size)
for chunk in range(n_chunks):
seq_idx.append(req_idx)
cu_chunk_seqlen.append(seqlen_pos)
chunk_len = min(chunk_size, this_new_tokens)
seqlen_pos += chunk_len
this_new_tokens -= chunk_len
assert this_new_tokens == 0
last_chunk_indices.append(len(cu_chunk_seqlen) - 1)
cu_chunk_seqlen.append(seqlen_pos)
return cu_chunk_seqlen, seq_idx, last_chunk_indices
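A worked example of the chunk construction above, under assumed inputs: chunk_size = 256, one prefill request with 928 tokens already computed and 352 new tokens this step, and a second with 0 computed and 648 new tokens.

from math import ceil

chunk_size = 256
# Request 0: 928 tokens already computed, 352 new this step.
# Its first chunk only needs to top up to the next chunk boundary:
top_up = ceil(928 / chunk_size) * chunk_size - 928
assert top_up == 96      # -> chunk boundaries 0, 96, then 96 + 256 = 352
# Request 1: 0 computed, 648 new -> three chunks of 256, 256, 136 tokens.
# Metadata the loop above yields for these inputs:
cu_chunk_seqlen    = [0, 96, 352, 608, 864, 1000]
seq_idx            = [0, 0, 1, 1, 1]
last_chunk_indices = [1, 4]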
def _build_chunk_metadata_tensors(
self,
chunk_size: int,
common: M,
common_attn_metadata: CommonAttentionMetadata,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute chunk metadata and return as device tensors.
Returns (cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p).
"""
num_reqs = common.num_reqs
num_prefills = common.num_prefills
num_decode_tokens = common.num_decode_tokens
num_computed_tokens_cpu = (
common_attn_metadata.compute_num_computed_tokens().cpu()
)
num_computed_tokens_p_cpu = num_computed_tokens_cpu[
num_reqs - num_prefills : num_reqs
]
query_start_loc_p_cpu = (
common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :]
- num_decode_tokens
)
cu_chunk_seqlen, seq_idx, last_chunk_indices = self._compute_chunk_metadata(
chunk_size,
num_prefills,
num_computed_tokens_p_cpu,
query_start_loc_p_cpu,
)
device = common_attn_metadata.query_start_loc.device
cu_chunk_seqlen_p = torch.as_tensor(
cu_chunk_seqlen,
device=device,
dtype=torch.int32,
)
seq_idx_p = torch.as_tensor(
seq_idx,
device=device,
dtype=torch.int32,
)
last_chunk_indices_p = torch.as_tensor(
last_chunk_indices,
device=device,
dtype=torch.int32,
)
return cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p
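A minimal sketch (made-up tensors) of the prefill slicing performed above. It assumes the batch is ordered decode-first, prefill-last, with one single-token decode request ahead of the two prefill requests from the previous example.

import torch

num_reqs, num_prefills, num_decode_tokens = 3, 2, 1
num_computed_tokens_cpu = torch.tensor([500, 928, 0])
query_start_loc_cpu = torch.tensor([0, 1, 353, 1001])

# Prefills occupy the tail of the batch, so slice from the back and rebase
# the cumulative starts so the first prefill begins at 0.
num_computed_tokens_p = num_computed_tokens_cpu[num_reqs - num_prefills : num_reqs]
query_start_loc_p = query_start_loc_cpu[-num_prefills - 1 :] - num_decode_tokens

assert num_computed_tokens_p.tolist() == [928, 0]
assert query_start_loc_p.tolist() == [0, 352, 1000]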
def _compute_prefix_caching_block_indices(
self,
common_attn_metadata: CommonAttentionMetadata,