[Kernel][Model] Improve continuous batching for Jamba and Mamba (#9189)
This commit is contained in:
@@ -115,6 +115,10 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
||||
const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
|
||||
: reinterpret_cast<int *>(params.cache_indices_ptr);
|
||||
const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
|
||||
// cache_index == params.pad_slot_id is defined as padding, so we exit early
|
||||
if (cache_index == params.pad_slot_id){
|
||||
return;
|
||||
}
|
||||
input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + sequence_start_index * params.u_batch_stride
|
||||
+ dim_id * kNRows * params.u_d_stride;
|
||||
input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + sequence_start_index * params.delta_batch_stride
|
||||
@@ -387,7 +391,6 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
|
||||
const size_t seqlen,
|
||||
const size_t dstate,
|
||||
const size_t n_groups,
|
||||
const size_t n_chunks,
|
||||
const bool is_variable_B,
|
||||
const bool is_variable_C,
|
||||
// device pointers
|
||||
@@ -407,7 +410,8 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
|
||||
const c10::optional<at::Tensor>& query_start_loc,
|
||||
const c10::optional<at::Tensor>& cache_indices,
|
||||
const c10::optional<at::Tensor>& has_initial_state,
|
||||
bool varlen) {
|
||||
bool varlen,
|
||||
int64_t pad_slot_id) {
|
||||
|
||||
// Reset the parameters
|
||||
memset(¶ms, 0, sizeof(params));
|
||||
@@ -417,8 +421,8 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
|
||||
params.seqlen = seqlen;
|
||||
params.dstate = dstate;
|
||||
params.n_groups = n_groups;
|
||||
params.n_chunks = n_chunks;
|
||||
params.dim_ngroups_ratio = dim / n_groups;
|
||||
params.pad_slot_id = pad_slot_id;
|
||||
|
||||
params.delta_softplus = delta_softplus;
|
||||
|
||||
@@ -507,7 +511,10 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
||||
const c10::optional<torch::Tensor> &query_start_loc,
|
||||
const c10::optional<torch::Tensor> &cache_indices,
|
||||
const c10::optional<torch::Tensor> &has_initial_state,
|
||||
const torch::Tensor &ssm_states) {
|
||||
const torch::Tensor &ssm_states,
|
||||
// used to identify padding entries if cache_indices provided
|
||||
// in case of padding, the kernel will return early
|
||||
int64_t pad_slot_id) {
|
||||
auto input_type = u.scalar_type();
|
||||
auto weight_type = A.scalar_type();
|
||||
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
||||
@@ -618,18 +625,14 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
||||
|
||||
out_z = z;
|
||||
|
||||
const int n_chunks = (seqlen + 2048 - 1) / 2048;
|
||||
// const int n_chunks = (seqlen + 1024 - 1) / 1024;
|
||||
// at::Tensor out = torch::empty_like(u);
|
||||
// Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
|
||||
at::Tensor out = delta;
|
||||
TORCH_CHECK(ssm_states.scalar_type() == input_type);
|
||||
TORCH_CHECK(ssm_states.is_cuda());
|
||||
TORCH_CHECK(ssm_states.stride(-1) == 1);
|
||||
CHECK_SHAPE(ssm_states, batch_size, dim, dstate);
|
||||
|
||||
SSMParamsBase params;
|
||||
set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
|
||||
set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, is_variable_B, is_variable_C,
|
||||
u, delta, A, B, C, out, z, out_z,
|
||||
D_,
|
||||
delta_bias_,
|
||||
@@ -639,7 +642,8 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
||||
query_start_loc,
|
||||
cache_indices,
|
||||
has_initial_state,
|
||||
varlen
|
||||
varlen,
|
||||
pad_slot_id
|
||||
);
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user