[Mamba1] - Kernel Level Chunk Alignment for Prefix Caching (#34798)

Signed-off-by: Josephasafg <ajgard7@gmail.com>
Asaf Gardin authored on 2026-03-01 14:40:23 +02:00, committed by GitHub
parent da543d1abe
commit bbf81f9a92
11 changed files with 251 additions and 146 deletions


@@ -17,7 +17,7 @@
struct SSMParamsBase {
using index_t = size_t;
int batch, dim, seqlen, dstate, n_groups, n_chunks;
int batch, dim, seqlen, dstate, n_groups;
int dim_ngroups_ratio;
bool is_variable_B;
bool is_variable_C;
@@ -72,6 +72,8 @@ struct SSMParamsBase {
void *__restrict__ block_idx_first_scheduled_token_ptr; // (batch,) - first block to write
void *__restrict__ block_idx_last_scheduled_token_ptr; // (batch,) - last block to write
void *__restrict__ initial_state_idx_ptr; // (batch,) - index of the initial state to use
void *__restrict__ cu_chunk_seqlen_ptr; // (nchunks+1,) - cumulative chunk token offsets
void *__restrict__ last_chunk_indices_ptr; // (batch,) - index of last chunk per sequence
};
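To make the layout of the two new buffers concrete, here is a small illustration with hypothetical sizes (a block_size of 4 tokens and two sequences with 6 and 4 scheduled tokens; none of these numbers come from the commit itself):

// Sequence 0 arrives as chunks of 2 and 4 tokens (the 2-token chunk fills the remainder of a
// partially filled block); sequence 1 is a single 4-token chunk.
// cu_chunk_seqlen holds cumulative token offsets over all chunks in the batch:
//   cu_chunk_seqlen    = {0, 2, 6, 10};   // (nchunks+1,) with nchunks = 3
// last_chunk_indices marks the last chunk owned by each sequence:
//   last_chunk_indices = {1, 2};          // (batch,)
// Chunk i therefore covers tokens [cu_chunk_seqlen[i], cu_chunk_seqlen[i+1]) of the packed input.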


@@ -81,7 +81,6 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
constexpr bool kIsVariableC = Ktraits::kIsVariableC;
constexpr bool kHasZ = Ktraits::kHasZ;
constexpr bool kVarlen = Ktraits::kVarlen;
constexpr int kNThreads = Ktraits::kNThreads;
constexpr int kNItems = Ktraits::kNItems;
constexpr int kNRows = Ktraits::kNRows;
constexpr bool kDirectIO = Ktraits::kDirectIO;
@@ -161,17 +160,8 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
}
}
// for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
// smem_a[state_idx] = A[state_idx * params.A_dstate_stride];
// smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride];
// }
constexpr int kChunkSize = kNThreads * kNItems;
// Use block_size for chunking when APC is enabled, otherwise use 2048 for backwards compatibility
const int iteration_chunk_size = params.cache_enabled ? params.block_size : 2048;
const int n_chunks = (seqlen + iteration_chunk_size - 1) / iteration_chunk_size;
const int block_size = params.cache_enabled ? params.block_size : 2048;
const int* batch_cache_indices = cache_indices != nullptr ?
cache_indices + batch_id * params.cache_indices_stride : nullptr;
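When prefix caching is disabled the kernel keeps the previous behaviour of chunking by a fixed 2048 tokens, and the chunk count is recovered later with the usual ceiling division (see the fallback branch below). A quick numeric sketch, with an illustrative seqlen:

// cache disabled: block_size = 2048
// seqlen = 5000  ->  n_chunks = (5000 + 2048 - 1) / 2048 = 3
// chunk_tokens per iteration: min(2048, remaining) = 2048, 2048, 904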
@@ -181,10 +171,44 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
reinterpret_cast<const int*>(params.block_idx_last_scheduled_token_ptr) : nullptr;
const int* initial_state_idx = params.initial_state_idx_ptr != nullptr ?
reinterpret_cast<const int*>(params.initial_state_idx_ptr) : nullptr;
const int* cu_chunk_seqlen = params.cu_chunk_seqlen_ptr != nullptr ?
reinterpret_cast<const int*>(params.cu_chunk_seqlen_ptr) : nullptr;
const int* last_chunk_indices = params.last_chunk_indices_ptr != nullptr ?
reinterpret_cast<const int*>(params.last_chunk_indices_ptr) : nullptr;
const size_t load_cache_slot = params.cache_enabled && batch_cache_indices != nullptr ? batch_cache_indices[initial_state_idx[batch_id]] : cache_index;
const int block_idx_first = (params.cache_enabled && block_idx_first_scheduled != nullptr) ?
block_idx_first_scheduled[batch_id] : 0;
// Determine chunk boundaries from pre-computed metadata (APC mode)
// or fall back to simple block_size chunking.
int first_chunk_idx, n_chunks;
int current_position;
if (cu_chunk_seqlen != nullptr && last_chunk_indices != nullptr) {
const int last_chunk_idx = last_chunk_indices[batch_id];
first_chunk_idx = (batch_id == 0) ? 0 : last_chunk_indices[batch_id - 1] + 1;
n_chunks = last_chunk_idx - first_chunk_idx + 1;
// Derive current_position: if the first chunk is partial (fills remainder
// of a started block), offset into the block accordingly.
const int first_chunk_tokens = cu_chunk_seqlen[first_chunk_idx + 1] - cu_chunk_seqlen[first_chunk_idx];
const int chunk_start_offset = (n_chunks > 1 && first_chunk_tokens < block_size)
? (block_size - first_chunk_tokens) : 0;
current_position = block_idx_first * block_size + chunk_start_offset;
} else {
first_chunk_idx = 0;
n_chunks = (seqlen + block_size - 1) / block_size;
current_position = 0;
}
int tokens_processed = 0;
for (int chunk = 0; chunk < n_chunks; ++chunk) {
const int chunk_tokens = (cu_chunk_seqlen != nullptr)
? cu_chunk_seqlen[first_chunk_idx + chunk + 1] - cu_chunk_seqlen[first_chunk_idx + chunk]
: min(block_size, seqlen - tokens_processed);
if (chunk_tokens <= 0) break;
input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
__syncthreads();
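Working the APC branch above through the same hypothetical metadata used earlier (block_size = 4, cu_chunk_seqlen = {0, 2, 6, 10}, last_chunk_indices = {1, 2}; the values are illustrative only):

// batch_id = 0:
//   first_chunk_idx = 0, last_chunk_idx = last_chunk_indices[0] = 1, n_chunks = 2
//   first_chunk_tokens = cu_chunk_seqlen[1] - cu_chunk_seqlen[0] = 2    (partial: 2 < block_size)
//   chunk_start_offset = block_size - 2 = 2
//   current_position   = block_idx_first * 4 + 2   // the partial chunk ends exactly on a block boundary
// batch_id = 1:
//   first_chunk_idx = last_chunk_indices[0] + 1 = 2, n_chunks = 1
//   first_chunk_tokens = 4 (a full block), so chunk_start_offset = 0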
@@ -193,12 +217,12 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
if constexpr (!kDirectIO) {
if (r > 0) { __syncthreads(); }
}
load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, seqlen - chunk * kChunkSize);
load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, chunk_tokens);
if constexpr (!kDirectIO) { __syncthreads(); }
load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, seqlen - chunk * kChunkSize);
load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, chunk_tokens);
}
u += kChunkSize;
delta += kChunkSize;
u += chunk_tokens;
delta += chunk_tokens;
float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems];
#pragma unroll
@@ -232,7 +256,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
weight_t B_vals[kNItems], C_vals[kNItems];
if constexpr (kIsVariableB) {
load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
smem_load_weight, (seqlen - chunk * kChunkSize) * (1));
smem_load_weight, chunk_tokens);
if constexpr (!kIsVariableC) {
#pragma unroll
for (int r = 0; r < kNRows; ++r) {
@@ -243,7 +267,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
if constexpr (kIsVariableC) {
auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
smem_load_weight_C, (seqlen - chunk * kChunkSize) * (1));
smem_load_weight_C, chunk_tokens);
if constexpr (!kIsVariableB) {
#pragma unroll
for (int r = 0; r < kNRows; ++r) {
@@ -266,10 +290,8 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
for (int i = 0; i < kNItems; ++i) {
thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
!kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);
if (seqlen % (kNItems * kNThreads) != 0) { // So that the last state is correct
if (threadIdx.x * kNItems + i >= seqlen - chunk * kChunkSize) {
thread_data[i] = make_float2(1.f, 0.f);
}
if (threadIdx.x * kNItems + i >= chunk_tokens) {
thread_data[i] = make_float2(1.f, 0.f);
}
}
// Initialize running total
@@ -301,14 +323,14 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
if (threadIdx.x == 0) {
smem_running_prefix[state_idx + r * MAX_DSTATE] = prefix_op.running_prefix;
// Store state at the end of each chunk when cache is enabled
// Store state at the end of each aligned chunk when cache is enabled
if (params.cache_enabled && batch_cache_indices != nullptr) {
size_t cache_slot;
if (chunk == n_chunks - 1) {
cache_slot = batch_cache_indices[block_idx_last_scheduled[batch_id]];
} else {
cache_slot = batch_cache_indices[block_idx_first_scheduled[batch_id] + chunk];
const int block_idx_completed = (current_position + chunk_tokens - 1) / block_size;
cache_slot = batch_cache_indices[block_idx_completed];
}
size_t state_offset = cache_slot * params.ssm_states_batch_stride +
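Continuing the same hypothetical numbers, and assuming block_idx_first = 3 for sequence 0 (so current_position starts at 3 * 4 + 2 = 14), the state writes land as follows:

// chunk 0 (2 tokens, not the last chunk):
//   block_idx_completed = (14 + 2 - 1) / 4 = 3  -> state written to batch_cache_indices[3],
//   i.e. the block that this partial chunk completes
// chunk 1 (4 tokens, the last chunk):
//   state written to batch_cache_indices[block_idx_last_scheduled[0]]
//   (for these numbers that slot would correspond to block 4)
// The previous indexing, block_idx_first_scheduled[batch_id] + chunk, implicitly assumed every
// chunk spanned exactly one block; deriving block_idx_completed from current_position keeps the
// write aligned even when the first chunk is partial.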
@@ -331,38 +353,41 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
}
}
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + sequence_start_index * params.out_batch_stride
+ dim_id * kNRows * params.out_d_stride + chunk * kChunkSize;
+ dim_id * kNRows * params.out_d_stride + tokens_processed;
__syncthreads();
#pragma unroll
for (int r = 0; r < kNRows; ++r) {
if constexpr (!kDirectIO) {
if (r > 0) { __syncthreads(); }
}
store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, seqlen - chunk * kChunkSize);
store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, chunk_tokens);
}
if constexpr (kHasZ) {
input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + sequence_start_index * params.z_batch_stride
+ dim_id * kNRows * params.z_d_stride + chunk * kChunkSize;
+ dim_id * kNRows * params.z_d_stride + tokens_processed;
input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + sequence_start_index * params.out_z_batch_stride
+ dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize;
+ dim_id * kNRows * params.out_z_d_stride + tokens_processed;
#pragma unroll
for (int r = 0; r < kNRows; ++r) {
input_t z_vals[kNItems];
__syncthreads();
load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load, seqlen - chunk * kChunkSize);
load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load, chunk_tokens);
#pragma unroll
for (int i = 0; i < kNItems; ++i) {
float z_val = z_vals[i];
out_vals[r][i] *= z_val / (1 + expf(-z_val));
}
__syncthreads();
store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, seqlen - chunk * kChunkSize);
store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, chunk_tokens);
}
}
Bvar += kChunkSize * 1;
Cvar += kChunkSize * 1;
Bvar += chunk_tokens;
Cvar += chunk_tokens;
tokens_processed += chunk_tokens;
current_position += chunk_tokens;
}
}
@@ -506,7 +531,9 @@ void set_ssm_params_fwd(SSMParamsBase &params,
int64_t block_size,
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
const std::optional<torch::Tensor> &initial_state_idx) {
const std::optional<torch::Tensor> &initial_state_idx,
const std::optional<torch::Tensor> &cu_chunk_seqlen,
const std::optional<torch::Tensor> &last_chunk_indices) {
// Reset the parameters
memset(&params, 0, sizeof(params));
@@ -548,6 +575,8 @@ void set_ssm_params_fwd(SSMParamsBase &params,
params.block_idx_first_scheduled_token_ptr = block_idx_first_scheduled_token.has_value() ? block_idx_first_scheduled_token.value().data_ptr() : nullptr;
params.block_idx_last_scheduled_token_ptr = block_idx_last_scheduled_token.has_value() ? block_idx_last_scheduled_token.value().data_ptr() : nullptr;
params.initial_state_idx_ptr = initial_state_idx.has_value() ? initial_state_idx.value().data_ptr() : nullptr;
params.cu_chunk_seqlen_ptr = cu_chunk_seqlen.has_value() ? cu_chunk_seqlen.value().data_ptr() : nullptr;
params.last_chunk_indices_ptr = last_chunk_indices.has_value() ? last_chunk_indices.value().data_ptr() : nullptr;
// All stride are in elements, not bytes.
params.A_d_stride = A.stride(0);
@@ -633,7 +662,9 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
int64_t block_size,
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
const std::optional<torch::Tensor> &initial_state_idx) {
const std::optional<torch::Tensor> &initial_state_idx,
const std::optional<torch::Tensor> &cu_chunk_seqlen,
const std::optional<torch::Tensor> &last_chunk_indices) {
auto input_type = u.scalar_type();
auto weight_type = A.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -778,7 +809,9 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
block_size,
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
initial_state_idx
initial_state_idx,
cu_chunk_seqlen,
last_chunk_indices
);


@@ -371,7 +371,9 @@ void selective_scan_fwd(
const torch::Tensor& ssm_states, int64_t pad_slot_id, int64_t block_size,
const std::optional<torch::Tensor>& block_idx_first_scheduled_token,
const std::optional<torch::Tensor>& block_idx_last_scheduled_token,
const std::optional<torch::Tensor>& initial_state_idx);
const std::optional<torch::Tensor>& initial_state_idx,
const std::optional<torch::Tensor>& cu_chunk_seqlen,
const std::optional<torch::Tensor>& last_chunk_indices);
torch::Tensor dynamic_4bit_int_moe_cpu(
torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights,


@@ -640,7 +640,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"int block_size,"
"Tensor? block_idx_first_scheduled_token,"
"Tensor? block_idx_last_scheduled_token,"
"Tensor? initial_state_idx) -> ()");
"Tensor? initial_state_idx,"
"Tensor? cu_chunk_seqlen,"
"Tensor? last_chunk_indices) -> ()");
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
// Hadamard transforms