|
|
|
|
@@ -81,7 +81,6 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
|
|
|
|
constexpr bool kIsVariableC = Ktraits::kIsVariableC;
|
|
|
|
|
constexpr bool kHasZ = Ktraits::kHasZ;
|
|
|
|
|
constexpr bool kVarlen = Ktraits::kVarlen;
|
|
|
|
|
constexpr int kNThreads = Ktraits::kNThreads;
|
|
|
|
|
constexpr int kNItems = Ktraits::kNItems;
|
|
|
|
|
constexpr int kNRows = Ktraits::kNRows;
|
|
|
|
|
constexpr bool kDirectIO = Ktraits::kDirectIO;
|
|
|
|
|
@@ -161,17 +160,8 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
|
|
|
|
|
// smem_a[state_idx] = A[state_idx * params.A_dstate_stride];
|
|
|
|
|
// smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride];
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
constexpr int kChunkSize = kNThreads * kNItems;
|
|
|
|
|
|
|
|
|
|
// Use block_size for chunking when APC is enabled, otherwise use 2048 for backwards compatibility
|
|
|
|
|
const int iteration_chunk_size = params.cache_enabled ? params.block_size : 2048;
|
|
|
|
|
const int n_chunks = (seqlen + iteration_chunk_size - 1) / iteration_chunk_size;
|
|
|
|
|
const int block_size = params.cache_enabled ? params.block_size : 2048;
|
|
|
|
|
|
|
|
|
|
const int* batch_cache_indices = cache_indices != nullptr ?
|
|
|
|
|
cache_indices + batch_id * params.cache_indices_stride : nullptr;
|
|
|
|
|
@@ -181,10 +171,44 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
|
|
|
|
reinterpret_cast<const int*>(params.block_idx_last_scheduled_token_ptr) : nullptr;
|
|
|
|
|
const int* initial_state_idx = params.initial_state_idx_ptr != nullptr ?
|
|
|
|
|
reinterpret_cast<const int*>(params.initial_state_idx_ptr) : nullptr;
|
|
|
|
|
const int* cu_chunk_seqlen = params.cu_chunk_seqlen_ptr != nullptr ?
|
|
|
|
|
reinterpret_cast<const int*>(params.cu_chunk_seqlen_ptr) : nullptr;
|
|
|
|
|
const int* last_chunk_indices = params.last_chunk_indices_ptr != nullptr ?
|
|
|
|
|
reinterpret_cast<const int*>(params.last_chunk_indices_ptr) : nullptr;
|
|
|
|
|
|
|
|
|
|
const size_t load_cache_slot = params.cache_enabled && batch_cache_indices != nullptr ? batch_cache_indices[initial_state_idx[batch_id]] : cache_index;
|
|
|
|
|
|
|
|
|
|
const int block_idx_first = (params.cache_enabled && block_idx_first_scheduled != nullptr) ?
|
|
|
|
|
block_idx_first_scheduled[batch_id] : 0;
|
|
|
|
|
|
|
|
|
|
// Determine chunk boundaries from pre-computed metadata (APC mode)
|
|
|
|
|
// or fall back to simple block_size chunking.
|
|
|
|
|
int first_chunk_idx, n_chunks;
|
|
|
|
|
int current_position;
|
|
|
|
|
|
|
|
|
|
if (cu_chunk_seqlen != nullptr && last_chunk_indices != nullptr) {
|
|
|
|
|
const int last_chunk_idx = last_chunk_indices[batch_id];
|
|
|
|
|
first_chunk_idx = (batch_id == 0) ? 0 : last_chunk_indices[batch_id - 1] + 1;
|
|
|
|
|
n_chunks = last_chunk_idx - first_chunk_idx + 1;
|
|
|
|
|
// Derive current_position: if the first chunk is partial (fills remainder
|
|
|
|
|
// of a started block), offset into the block accordingly.
|
|
|
|
|
const int first_chunk_tokens = cu_chunk_seqlen[first_chunk_idx + 1] - cu_chunk_seqlen[first_chunk_idx];
|
|
|
|
|
const int chunk_start_offset = (n_chunks > 1 && first_chunk_tokens < block_size)
|
|
|
|
|
? (block_size - first_chunk_tokens) : 0;
|
|
|
|
|
current_position = block_idx_first * block_size + chunk_start_offset;
|
|
|
|
|
} else {
|
|
|
|
|
first_chunk_idx = 0;
|
|
|
|
|
n_chunks = (seqlen + block_size - 1) / block_size;
|
|
|
|
|
current_position = 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int tokens_processed = 0;
|
|
|
|
|
|
|
|
|
|
for (int chunk = 0; chunk < n_chunks; ++chunk) {
|
|
|
|
|
const int chunk_tokens = (cu_chunk_seqlen != nullptr)
|
|
|
|
|
? cu_chunk_seqlen[first_chunk_idx + chunk + 1] - cu_chunk_seqlen[first_chunk_idx + chunk]
|
|
|
|
|
: min(block_size, seqlen - tokens_processed);
|
|
|
|
|
if (chunk_tokens <= 0) break;
|
|
|
|
|
input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
|
|
|
|
|
|
|
|
|
|
__syncthreads();
|
|
|
|
|
@@ -193,12 +217,12 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
|
|
|
|
if constexpr (!kDirectIO) {
|
|
|
|
|
if (r > 0) { __syncthreads(); }
|
|
|
|
|
}
|
|
|
|
|
load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, seqlen - chunk * kChunkSize);
|
|
|
|
|
load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, chunk_tokens);
|
|
|
|
|
if constexpr (!kDirectIO) { __syncthreads(); }
|
|
|
|
|
load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, seqlen - chunk * kChunkSize);
|
|
|
|
|
load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, chunk_tokens);
|
|
|
|
|
}
|
|
|
|
|
u += kChunkSize;
|
|
|
|
|
delta += kChunkSize;
|
|
|
|
|
u += chunk_tokens;
|
|
|
|
|
delta += chunk_tokens;
|
|
|
|
|
|
|
|
|
|
float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems];
|
|
|
|
|
#pragma unroll
|
|
|
|
|
@@ -232,7 +256,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
|
|
|
|
weight_t B_vals[kNItems], C_vals[kNItems];
|
|
|
|
|
if constexpr (kIsVariableB) {
|
|
|
|
|
load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
|
|
|
|
|
smem_load_weight, (seqlen - chunk * kChunkSize) * (1));
|
|
|
|
|
smem_load_weight, chunk_tokens);
|
|
|
|
|
if constexpr (!kIsVariableC) {
|
|
|
|
|
#pragma unroll
|
|
|
|
|
for (int r = 0; r < kNRows; ++r) {
|
|
|
|
|
@@ -243,7 +267,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
|
|
|
|
if constexpr (kIsVariableC) {
|
|
|
|
|
auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
|
|
|
|
|
load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
|
|
|
|
|
smem_load_weight_C, (seqlen - chunk * kChunkSize) * (1));
|
|
|
|
|
smem_load_weight_C, chunk_tokens);
|
|
|
|
|
if constexpr (!kIsVariableB) {
|
|
|
|
|
#pragma unroll
|
|
|
|
|
for (int r = 0; r < kNRows; ++r) {
|
|
|
|
|
@@ -266,10 +290,8 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
|
|
|
|
for (int i = 0; i < kNItems; ++i) {
|
|
|
|
|
thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
|
|
|
|
|
!kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);
|
|
|
|
|
if (seqlen % (kNItems * kNThreads) != 0) { // So that the last state is correct
|
|
|
|
|
if (threadIdx.x * kNItems + i >= seqlen - chunk * kChunkSize) {
|
|
|
|
|
thread_data[i] = make_float2(1.f, 0.f);
|
|
|
|
|
}
|
|
|
|
|
if (threadIdx.x * kNItems + i >= chunk_tokens) {
|
|
|
|
|
thread_data[i] = make_float2(1.f, 0.f);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// Initialize running total
|
|
|
|
|
@@ -301,14 +323,14 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
|
|
|
|
if (threadIdx.x == 0) {
|
|
|
|
|
smem_running_prefix[state_idx + r * MAX_DSTATE] = prefix_op.running_prefix;
|
|
|
|
|
|
|
|
|
|
// Store state at the end of each chunk when cache is enabled
|
|
|
|
|
// Store state at the end of each aligned chunk when cache is enabled
|
|
|
|
|
if (params.cache_enabled && batch_cache_indices != nullptr) {
|
|
|
|
|
|
|
|
|
|
size_t cache_slot;
|
|
|
|
|
if (chunk == n_chunks - 1) {
|
|
|
|
|
cache_slot = batch_cache_indices[block_idx_last_scheduled[batch_id]];
|
|
|
|
|
} else {
|
|
|
|
|
cache_slot = batch_cache_indices[block_idx_first_scheduled[batch_id] + chunk];
|
|
|
|
|
const int block_idx_completed = (current_position + chunk_tokens - 1) / block_size;
|
|
|
|
|
cache_slot = batch_cache_indices[block_idx_completed];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t state_offset = cache_slot * params.ssm_states_batch_stride +
|
|
|
|
|
@@ -331,38 +353,41 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + sequence_start_index * params.out_batch_stride
|
|
|
|
|
+ dim_id * kNRows * params.out_d_stride + chunk * kChunkSize;
|
|
|
|
|
+ dim_id * kNRows * params.out_d_stride + tokens_processed;
|
|
|
|
|
__syncthreads();
|
|
|
|
|
#pragma unroll
|
|
|
|
|
for (int r = 0; r < kNRows; ++r) {
|
|
|
|
|
if constexpr (!kDirectIO) {
|
|
|
|
|
if (r > 0) { __syncthreads(); }
|
|
|
|
|
}
|
|
|
|
|
store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, seqlen - chunk * kChunkSize);
|
|
|
|
|
store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, chunk_tokens);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if constexpr (kHasZ) {
|
|
|
|
|
input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + sequence_start_index * params.z_batch_stride
|
|
|
|
|
+ dim_id * kNRows * params.z_d_stride + chunk * kChunkSize;
|
|
|
|
|
+ dim_id * kNRows * params.z_d_stride + tokens_processed;
|
|
|
|
|
input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + sequence_start_index * params.out_z_batch_stride
|
|
|
|
|
+ dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize;
|
|
|
|
|
+ dim_id * kNRows * params.out_z_d_stride + tokens_processed;
|
|
|
|
|
#pragma unroll
|
|
|
|
|
for (int r = 0; r < kNRows; ++r) {
|
|
|
|
|
input_t z_vals[kNItems];
|
|
|
|
|
__syncthreads();
|
|
|
|
|
load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load, seqlen - chunk * kChunkSize);
|
|
|
|
|
load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load, chunk_tokens);
|
|
|
|
|
#pragma unroll
|
|
|
|
|
for (int i = 0; i < kNItems; ++i) {
|
|
|
|
|
float z_val = z_vals[i];
|
|
|
|
|
out_vals[r][i] *= z_val / (1 + expf(-z_val));
|
|
|
|
|
}
|
|
|
|
|
__syncthreads();
|
|
|
|
|
store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, seqlen - chunk * kChunkSize);
|
|
|
|
|
store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, chunk_tokens);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Bvar += kChunkSize * 1;
|
|
|
|
|
Cvar += kChunkSize * 1;
|
|
|
|
|
Bvar += chunk_tokens;
|
|
|
|
|
Cvar += chunk_tokens;
|
|
|
|
|
|
|
|
|
|
tokens_processed += chunk_tokens;
|
|
|
|
|
current_position += chunk_tokens;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -506,7 +531,9 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
|
|
|
|
|
int64_t block_size,
|
|
|
|
|
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
|
|
|
|
|
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
|
|
|
|
|
const std::optional<torch::Tensor> &initial_state_idx) {
|
|
|
|
|
const std::optional<torch::Tensor> &initial_state_idx,
|
|
|
|
|
const std::optional<torch::Tensor> &cu_chunk_seqlen,
|
|
|
|
|
const std::optional<torch::Tensor> &last_chunk_indices) {
|
|
|
|
|
|
|
|
|
|
// Reset the parameters
|
|
|
|
|
memset(¶ms, 0, sizeof(params));
|
|
|
|
|
@@ -548,6 +575,8 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
|
|
|
|
|
params.block_idx_first_scheduled_token_ptr = block_idx_first_scheduled_token.has_value() ? block_idx_first_scheduled_token.value().data_ptr() : nullptr;
|
|
|
|
|
params.block_idx_last_scheduled_token_ptr = block_idx_last_scheduled_token.has_value() ? block_idx_last_scheduled_token.value().data_ptr() : nullptr;
|
|
|
|
|
params.initial_state_idx_ptr = initial_state_idx.has_value() ? initial_state_idx.value().data_ptr() : nullptr;
|
|
|
|
|
params.cu_chunk_seqlen_ptr = cu_chunk_seqlen.has_value() ? cu_chunk_seqlen.value().data_ptr() : nullptr;
|
|
|
|
|
params.last_chunk_indices_ptr = last_chunk_indices.has_value() ? last_chunk_indices.value().data_ptr() : nullptr;
|
|
|
|
|
|
|
|
|
|
// All stride are in elements, not bytes.
|
|
|
|
|
params.A_d_stride = A.stride(0);
|
|
|
|
|
@@ -633,7 +662,9 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
|
|
|
|
int64_t block_size,
|
|
|
|
|
const std::optional<torch::Tensor> &block_idx_first_scheduled_token,
|
|
|
|
|
const std::optional<torch::Tensor> &block_idx_last_scheduled_token,
|
|
|
|
|
const std::optional<torch::Tensor> &initial_state_idx) {
|
|
|
|
|
const std::optional<torch::Tensor> &initial_state_idx,
|
|
|
|
|
const std::optional<torch::Tensor> &cu_chunk_seqlen,
|
|
|
|
|
const std::optional<torch::Tensor> &last_chunk_indices) {
|
|
|
|
|
auto input_type = u.scalar_type();
|
|
|
|
|
auto weight_type = A.scalar_type();
|
|
|
|
|
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
|
|
|
|
@@ -778,7 +809,9 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
|
|
|
|
block_size,
|
|
|
|
|
block_idx_first_scheduled_token,
|
|
|
|
|
block_idx_last_scheduled_token,
|
|
|
|
|
initial_state_idx
|
|
|
|
|
initial_state_idx,
|
|
|
|
|
cu_chunk_seqlen,
|
|
|
|
|
last_chunk_indices
|
|
|
|
|
);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|