[Kernel] Change interface to Mamba causal_conv1d_update for continuous batching (#8012)

This commit is contained in:
Tyler Michael Smith
2024-09-17 19:44:27 -04:00
committed by GitHub
parent 09deb4721f
commit 8110e44529
7 changed files with 114 additions and 16 deletions

View File

@@ -198,7 +198,8 @@ causal_conv1d_update(const at::Tensor &x,
const at::Tensor &conv_state,
const at::Tensor &weight,
const c10::optional<at::Tensor> &bias_,
bool silu_activation) {
bool silu_activation,
const c10::optional<at::Tensor> &conv_state_indices_) {
auto input_type = x.scalar_type();
auto weight_type = weight.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -216,7 +217,6 @@ causal_conv1d_update(const at::Tensor &x,
const int width = weight.size(-1);
CHECK_SHAPE(x, batch_size, dim);
CHECK_SHAPE(conv_state, batch_size, dim, width);
CHECK_SHAPE(weight, dim, width);
TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
@@ -241,6 +241,22 @@ causal_conv1d_update(const at::Tensor &x,
params.conv_state_c_stride = conv_state.stride(1);
params.conv_state_l_stride = conv_state.stride(2);
if (conv_state_indices_.has_value()) {
auto conv_state_indices = conv_state_indices_.value();
TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32)
TORCH_CHECK(conv_state_indices.is_cuda());
TORCH_CHECK(conv_state_indices.stride(0) == 1)
CHECK_SHAPE(conv_state_indices, batch_size);
int conv_state_entries = conv_state.size(0);
CHECK_SHAPE(conv_state, conv_state_entries, dim, width);
params.conv_state_indices_ptr = conv_state_indices.data_ptr<int32_t>();
} else {
CHECK_SHAPE(conv_state, batch_size, dim, width);
params.conv_state_indices_ptr = nullptr;
}
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)x.get_device()};
@@ -646,8 +662,16 @@ void causal_conv1d_update_kernel(ConvParamsBase params) {
const int channel_id = blockIdx.y * kNThreads + tidx;
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
+ channel_id * params.x_c_stride;
input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride
// If params.conv_state_batch_indices is set, then the conv state is gathered from the conv state tensor
// along the batch axis. Otherwise, the conv state coordinate is the same as the batch id.
const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr
? batch_id
: params.conv_state_indices_ptr[batch_id];
input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr)
+ conv_state_batch_coord * params.conv_state_batch_stride
+ channel_id * params.conv_state_c_stride;
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
+ channel_id * params.out_c_stride;