[Kernel][Model] Improve continuous batching for Jamba and Mamba (#9189)

This commit is contained in:
Mor Zusman
2024-10-17 00:12:43 +08:00
committed by GitHub
parent 415f76a9cb
commit fb60ae9b91
15 changed files with 504 additions and 432 deletions

View File

@@ -55,6 +55,7 @@ void set_conv_params_fwd(ConvParamsBase &params,
const at::Tensor out,
const c10::optional<at::Tensor>& bias,
bool silu_activation,
int64_t pad_slot_id,
const c10::optional<at::Tensor>& query_start_loc = std::nullopt,
const c10::optional<at::Tensor>& cache_indices = std::nullopt,
const c10::optional<at::Tensor>& has_initial_state = std::nullopt) {
@@ -66,6 +67,7 @@ void set_conv_params_fwd(ConvParamsBase &params,
params.dim = dim;
params.seqlen = seqlen;
params.width = width;
params.pad_slot_id = pad_slot_id;
params.silu_activation = silu_activation;
@@ -90,14 +92,16 @@ void set_conv_params_fwd(ConvParamsBase &params,
}
at::Tensor
causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
const c10::optional<at::Tensor> &bias_,
const c10::optional<at::Tensor> &conv_states,
const c10::optional<at::Tensor> &query_start_loc,
const c10::optional<at::Tensor> &cache_indices,
const c10::optional<at::Tensor> &has_initial_state,
bool silu_activation) {
bool silu_activation,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
int64_t pad_slot_id) {
auto input_type = x.scalar_type();
auto weight_type = weight.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -153,12 +157,13 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
CHECK_SHAPE(cache_indices_, batch_size);
}
at::Tensor out = torch::empty_like(x);
at::Tensor out = x;
ConvParamsBase params;
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
bias_,
silu_activation,
pad_slot_id,
query_start_loc,
cache_indices,
has_initial_state
@@ -183,18 +188,19 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
});
return out;
}
at::Tensor
causal_conv1d_update(const at::Tensor &x,
void causal_conv1d_update(const at::Tensor &x,
const at::Tensor &conv_state,
const at::Tensor &weight,
const c10::optional<at::Tensor> &bias_,
bool silu_activation,
const c10::optional<at::Tensor> &cache_seqlens_,
const c10::optional<at::Tensor> &conv_state_indices_) {
const c10::optional<at::Tensor> &conv_state_indices_,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
int64_t pad_slot_id) {
auto input_type = x.scalar_type();
auto weight_type = weight.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -227,12 +233,13 @@ causal_conv1d_update(const at::Tensor &x,
CHECK_SHAPE(bias, dim);
}
at::Tensor out = torch::empty_like(x);
at::Tensor out = x;
ConvParamsBase params;
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
bias_,
silu_activation);
silu_activation,
pad_slot_id);
params.conv_state_ptr = conv_state.data_ptr();
params.conv_state_len = conv_state_len;
// All stride are in elements, not bytes.
@@ -274,7 +281,6 @@ causal_conv1d_update(const at::Tensor &x,
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
});
return out;
}
template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
@@ -340,7 +346,10 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
: reinterpret_cast<int *>(params.cache_indices_ptr);
int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
// cache_index == params.pad_slot_id is defined as padding, so we exit early
if (cache_index == params.pad_slot_id){
return;
}
input_t *conv_states = params.conv_states_ptr == nullptr ? nullptr
: reinterpret_cast<input_t *>(params.conv_states_ptr) + cache_index * params.conv_states_batch_stride + channel_id * params.conv_states_c_stride;
@@ -528,6 +537,10 @@ void causal_conv1d_update_kernel(ConvParamsBase params) {
const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr
? batch_id
: params.conv_state_indices_ptr[batch_id];
// conv_state_batch_coord == params.pad_slot_id is defined as padding so we exit early
if (conv_state_batch_coord == params.pad_slot_id){
return;
}
input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr)
+ conv_state_batch_coord * params.conv_state_batch_stride
+ channel_id * params.conv_state_c_stride;