[Kernel][Model] Improve continuous batching for Jamba and Mamba (#9189)

Mor Zusman authored on 2024-10-17 00:12:43 +08:00 · committed by GitHub
parent 415f76a9cb · commit fb60ae9b91
15 changed files with 504 additions and 432 deletions


@@ -115,6 +115,10 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
     const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
                                : reinterpret_cast<int *>(params.cache_indices_ptr);
     const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
+    // cache_index == params.pad_slot_id is defined as padding, so we exit early
+    if (cache_index == params.pad_slot_id){
+        return;
+    }
     input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + sequence_start_index * params.u_batch_stride
         + dim_id * kNRows * params.u_d_stride;
     input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + sequence_start_index * params.delta_batch_stride
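
The guard makes padding a no-op at the kernel level: a batch entry whose cache index equals the sentinel performs no reads or writes at all. A minimal host-side C++ sketch of the same semantics, assuming the sentinel is whatever value the caller passes as pad_slot_id (the function and names here are illustrative, not part of the kernel):

#include <cstdint>
#include <vector>

// Host-side illustration (not the kernel itself): each batch entry maps to a
// cache slot, and entries whose slot equals the sentinel are skipped entirely.
void scan_batch(const std::vector<int>& cache_indices, int64_t pad_slot_id) {
    for (size_t batch_id = 0; batch_id < cache_indices.size(); ++batch_id) {
        const int cache_index = cache_indices[batch_id];
        if (cache_index == pad_slot_id) {
            continue;  // padding entry: mirrors the kernel's early return
        }
        // ... a real entry would run the selective scan here ...
    }
}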
@@ -387,7 +391,6 @@ void set_ssm_params_fwd(SSMParamsBase &params,
                         const size_t seqlen,
                         const size_t dstate,
                         const size_t n_groups,
-                        const size_t n_chunks,
                         const bool is_variable_B,
                         const bool is_variable_C,
                         // device pointers
@@ -407,7 +410,8 @@ void set_ssm_params_fwd(SSMParamsBase &params,
                         const c10::optional<at::Tensor>& query_start_loc,
                         const c10::optional<at::Tensor>& cache_indices,
                         const c10::optional<at::Tensor>& has_initial_state,
-                        bool varlen) {
+                        bool varlen,
+                        int64_t pad_slot_id) {

     // Reset the parameters
     memset(&params, 0, sizeof(params));
@@ -417,8 +421,8 @@ void set_ssm_params_fwd(SSMParamsBase &params,
     params.seqlen = seqlen;
     params.dstate = dstate;
     params.n_groups = n_groups;
-    params.n_chunks = n_chunks;
     params.dim_ngroups_ratio = dim / n_groups;
+    params.pad_slot_id = pad_slot_id;

     params.delta_softplus = delta_softplus;
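
The assignment above implies SSMParamsBase gained a pad_slot_id member. A sketch of the relevant slice of the struct; only pad_slot_id is established by this diff, and the neighboring fields are inferred from the assignments in this hunk:

// Sketch only: fields inferred from the assignments in set_ssm_params_fwd.
struct SSMParamsBase {
    int batch, dim, seqlen, dstate, n_groups;
    int dim_ngroups_ratio;
    int64_t pad_slot_id;   // new: sentinel value marking padded cache slots
    bool delta_softplus;
    // ... device pointers and strides elided ...
};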
@@ -507,7 +511,10 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
                         const c10::optional<torch::Tensor> &query_start_loc,
                         const c10::optional<torch::Tensor> &cache_indices,
                         const c10::optional<torch::Tensor> &has_initial_state,
-                        const torch::Tensor &ssm_states) {
+                        const torch::Tensor &ssm_states,
+                        // used to identify padding entries if cache_indices provided
+                        // in case of padding, the kernel will return early
+                        int64_t pad_slot_id) {
     auto input_type = u.scalar_type();
     auto weight_type = A.scalar_type();
     TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
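
The new comment fixes the contract: if cache_indices is provided, any entry equal to pad_slot_id marks padding, and that batch entry is skipped. A small libtorch sketch of a caller constructing such an index tensor; the -1 sentinel is an assumption for illustration, not a value fixed by this diff:

#include <torch/torch.h>

int main() {
    const int64_t pad_slot_id = -1;  // assumed sentinel; pass the same value to the kernel
    // Batch of three entries: slots 0 and 2 are live, the middle one is padding.
    auto cache_indices = torch::tensor(
        {0, static_cast<int>(pad_slot_id), 2},
        torch::dtype(torch::kInt32).device(torch::kCUDA));
    // selective_scan_fwd(u, delta, A, B, C, ..., cache_indices, ..., ssm_states, pad_slot_id);
    return 0;
}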
@@ -618,18 +625,14 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
         out_z = z;

-    const int n_chunks = (seqlen + 2048 - 1) / 2048;
-    // const int n_chunks = (seqlen + 1024 - 1) / 1024;
-    // at::Tensor out = torch::empty_like(u);
     // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
     at::Tensor out = delta;
     TORCH_CHECK(ssm_states.scalar_type() == input_type);
     TORCH_CHECK(ssm_states.is_cuda());
     TORCH_CHECK(ssm_states.stride(-1) == 1);
     CHECK_SHAPE(ssm_states, batch_size, dim, dstate);

     SSMParamsBase params;
-    set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
+    set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, is_variable_B, is_variable_C,
                        u, delta, A, B, C, out, z, out_z,
                        D_,
                        delta_bias_,
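
The deleted n_chunks line was integer ceiling division: (seqlen + 2048 - 1) / 2048 is the number of 2048-element chunks needed to cover seqlen. A self-checking one-liner of the same arithmetic:

// Ceiling division, as the removed line computed it.
constexpr int n_chunks_for(int seqlen) { return (seqlen + 2048 - 1) / 2048; }
static_assert(n_chunks_for(2048) == 1, "exactly one full chunk");
static_assert(n_chunks_for(4097) == 3, "one element past two chunks needs a third");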
@@ -639,7 +642,8 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
                        query_start_loc,
                        cache_indices,
                        has_initial_state,
-                       varlen
+                       varlen,
+                       pad_slot_id
                        );
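
For completeness, a hypothetical caller-side helper (not part of this commit) showing how a batch of cache slot indices could be padded to a fixed size with the sentinel, so the kernel skips the tail entries:

#include <cstdint>
#include <vector>

// Hypothetical helper: pad the slot-index list to padded_size so the batch
// shape stays fixed; the kernel ignores every sentinel entry.
std::vector<int> pad_cache_indices(std::vector<int> indices,
                                   size_t padded_size,
                                   int64_t pad_slot_id) {
    indices.resize(padded_size, static_cast<int>(pad_slot_id));
    return indices;
}

// Usage: pad_cache_indices({0, 2}, 4, -1) yields {0, 2, -1, -1}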