[Kernel][Model] Varlen prefill + Prefill chunking support for mamba kernels and Jamba model (#8533)
This commit is contained in:
@@ -39,8 +39,6 @@
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template <typename input_t, typename weight_t>
|
||||
void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
@@ -55,8 +53,11 @@ void set_conv_params_fwd(ConvParamsBase ¶ms,
|
||||
const at::Tensor x,
|
||||
const at::Tensor weight,
|
||||
const at::Tensor out,
|
||||
void* bias_ptr,
|
||||
bool silu_activation) {
|
||||
const c10::optional<at::Tensor>& bias,
|
||||
bool silu_activation,
|
||||
const c10::optional<at::Tensor>& query_start_loc = std::nullopt,
|
||||
const c10::optional<at::Tensor>& cache_indices = std::nullopt,
|
||||
const c10::optional<at::Tensor>& has_initial_state = std::nullopt) {
|
||||
|
||||
// Reset the parameters
|
||||
memset(¶ms, 0, sizeof(params));
|
||||
@@ -71,26 +72,31 @@ void set_conv_params_fwd(ConvParamsBase ¶ms,
|
||||
// Set the pointers and strides.
|
||||
params.x_ptr = x.data_ptr();
|
||||
params.weight_ptr = weight.data_ptr();
|
||||
params.bias_ptr = bias_ptr;
|
||||
params.bias_ptr = bias.has_value() ? bias.value().data_ptr() : nullptr;
|
||||
params.out_ptr = out.data_ptr();
|
||||
// All stride are in elements, not bytes.
|
||||
params.x_batch_stride = x.stride(0);
|
||||
params.x_c_stride = x.stride(1);
|
||||
params.x_l_stride = x.stride(-1);
|
||||
params.query_start_loc_ptr = query_start_loc.has_value() ? query_start_loc.value().data_ptr() : nullptr;
|
||||
params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr;
|
||||
params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr;
|
||||
const bool varlen = params.query_start_loc_ptr != nullptr;
|
||||
params.x_batch_stride = x.stride(varlen ? 1 : 0);
|
||||
params.x_c_stride = x.stride(varlen ? 0 : 1);
|
||||
params.x_l_stride = x.stride(varlen ? 1 : -1);
|
||||
params.weight_c_stride = weight.stride(0);
|
||||
params.weight_width_stride = weight.stride(1);
|
||||
params.out_batch_stride = out.stride(0);
|
||||
params.out_c_stride = out.stride(1);
|
||||
params.out_l_stride = out.stride(-1);
|
||||
params.out_batch_stride = out.stride(varlen ? 1 : 0);
|
||||
params.out_c_stride = out.stride(varlen ? 0 : 1);
|
||||
params.out_l_stride = out.stride(varlen ? 1 : -1);
|
||||
}
|
||||
|
||||
|
||||
at::Tensor
|
||||
causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
|
||||
const c10::optional<at::Tensor> &bias_,
|
||||
const c10::optional<at::Tensor> &seq_idx_,
|
||||
const c10::optional<at::Tensor> &initial_states_,
|
||||
const c10::optional<at::Tensor> &final_states_out_,
|
||||
const c10::optional<at::Tensor> &conv_states,
|
||||
const c10::optional<at::Tensor> &query_start_loc,
|
||||
const c10::optional<at::Tensor> &cache_indices,
|
||||
const c10::optional<at::Tensor> &has_initial_state,
|
||||
bool silu_activation) {
|
||||
auto input_type = x.scalar_type();
|
||||
auto weight_type = weight.scalar_type();
|
||||
@@ -99,24 +105,22 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
|
||||
|
||||
TORCH_CHECK(x.is_cuda());
|
||||
TORCH_CHECK(weight.is_cuda());
|
||||
|
||||
|
||||
const bool varlen = query_start_loc.has_value() ? true : false;
|
||||
const auto sizes = x.sizes();
|
||||
const int batch_size = sizes[0];
|
||||
const int dim = sizes[1];
|
||||
const int seqlen = sizes[2];
|
||||
const int batch_size = varlen ? query_start_loc.value().sizes()[0] - 1 : sizes[0];
|
||||
const int dim = varlen ? sizes[0] : sizes[1];
|
||||
const int seqlen = varlen ? sizes[1] : sizes[2];
|
||||
const int width = weight.size(-1);
|
||||
|
||||
CHECK_SHAPE(x, batch_size, dim, seqlen);
|
||||
if (varlen){
|
||||
CHECK_SHAPE(x, dim, seqlen);
|
||||
}
|
||||
else {
|
||||
CHECK_SHAPE(x, batch_size, dim, seqlen);
|
||||
}
|
||||
CHECK_SHAPE(weight, dim, width);
|
||||
|
||||
TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1);
|
||||
const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1;
|
||||
|
||||
if (is_channel_last) {
|
||||
TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now");
|
||||
TORCH_CHECK(x.stride(2) % 8 == 0 and x.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8");
|
||||
}
|
||||
TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
|
||||
|
||||
if (bias_.has_value()) {
|
||||
auto bias = bias_.value();
|
||||
@@ -126,56 +130,50 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
|
||||
CHECK_SHAPE(bias, dim);
|
||||
}
|
||||
|
||||
if (seq_idx_.has_value()) {
|
||||
TORCH_CHECK(is_channel_last, "seq_idx is only supported for channel last layout");
|
||||
auto seq_idx = seq_idx_.value();
|
||||
TORCH_CHECK(seq_idx.scalar_type() == torch::kInt32);
|
||||
TORCH_CHECK(seq_idx.is_cuda());
|
||||
TORCH_CHECK(seq_idx.is_contiguous());
|
||||
CHECK_SHAPE(seq_idx, batch_size, seqlen);
|
||||
|
||||
if (has_initial_state.has_value()) {
|
||||
auto has_initial_state_ = has_initial_state.value();
|
||||
TORCH_CHECK(has_initial_state_.scalar_type() == at::ScalarType::Bool);
|
||||
TORCH_CHECK(has_initial_state_.is_cuda());
|
||||
CHECK_SHAPE(has_initial_state_, batch_size);
|
||||
}
|
||||
|
||||
|
||||
if (query_start_loc.has_value()) {
|
||||
auto query_start_loc_ = query_start_loc.value();
|
||||
TORCH_CHECK(query_start_loc_.scalar_type() == at::ScalarType::Int);
|
||||
TORCH_CHECK(query_start_loc_.is_cuda());
|
||||
}
|
||||
|
||||
|
||||
if (cache_indices.has_value()) {
|
||||
auto cache_indices_ = cache_indices.value();
|
||||
TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int);
|
||||
TORCH_CHECK(cache_indices_.is_cuda());
|
||||
CHECK_SHAPE(cache_indices_, batch_size);
|
||||
}
|
||||
|
||||
at::Tensor out = torch::empty_like(x);
|
||||
|
||||
ConvParamsBase params;
|
||||
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
|
||||
bias_.has_value() ? bias_.value().data_ptr() : nullptr,
|
||||
silu_activation);
|
||||
bias_,
|
||||
silu_activation,
|
||||
query_start_loc,
|
||||
cache_indices,
|
||||
has_initial_state
|
||||
);
|
||||
|
||||
if (seq_idx_.has_value()) {
|
||||
params.seq_idx_ptr = seq_idx_.value().data_ptr();
|
||||
if (conv_states.has_value()) {
|
||||
auto conv_states_ = conv_states.value();
|
||||
TORCH_CHECK(conv_states_.scalar_type() == input_type);
|
||||
TORCH_CHECK(conv_states_.is_cuda());
|
||||
params.conv_states_ptr = conv_states_.data_ptr();
|
||||
params.conv_states_batch_stride = conv_states_.stride(0);
|
||||
params.conv_states_c_stride = conv_states_.stride(1);
|
||||
params.conv_states_l_stride = conv_states_.stride(2);
|
||||
} else {
|
||||
params.seq_idx_ptr = nullptr;
|
||||
}
|
||||
|
||||
if (initial_states_.has_value()) {
|
||||
TORCH_CHECK(is_channel_last, "initial_states is only supported for channel last layout");
|
||||
auto initial_states = initial_states_.value();
|
||||
TORCH_CHECK(initial_states.scalar_type() == input_type);
|
||||
TORCH_CHECK(initial_states.is_cuda());
|
||||
CHECK_SHAPE(initial_states, batch_size, dim, width - 1);
|
||||
TORCH_CHECK(initial_states.stride(1) == 1);
|
||||
params.initial_states_ptr = initial_states.data_ptr();
|
||||
params.initial_states_batch_stride = initial_states.stride(0);
|
||||
params.initial_states_c_stride = initial_states.stride(1);
|
||||
params.initial_states_l_stride = initial_states.stride(2);
|
||||
} else {
|
||||
params.initial_states_ptr = nullptr;
|
||||
}
|
||||
|
||||
if (final_states_out_.has_value()) {
|
||||
TORCH_CHECK(is_channel_last, "final_states is only supported for channel last layout");
|
||||
auto final_states = final_states_out_.value();
|
||||
TORCH_CHECK(final_states.scalar_type() == input_type);
|
||||
TORCH_CHECK(final_states.is_cuda());
|
||||
CHECK_SHAPE(final_states, batch_size, dim, width - 1);
|
||||
TORCH_CHECK(final_states.stride(1) == 1);
|
||||
params.final_states_ptr = final_states.data_ptr();
|
||||
params.final_states_batch_stride = final_states.stride(0);
|
||||
params.final_states_c_stride = final_states.stride(1);
|
||||
params.final_states_l_stride = final_states.stride(2);
|
||||
} else {
|
||||
params.final_states_ptr = nullptr;
|
||||
params.conv_states_ptr = nullptr;
|
||||
}
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
@@ -183,11 +181,7 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
|
||||
at::cuda::CUDAGuard device_guard{(char)x.get_device()};
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
|
||||
if (!is_channel_last) {
|
||||
causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
|
||||
} else {
|
||||
causal_conv1d_channellast_fwd_cuda<input_t, weight_t>(params, stream);
|
||||
}
|
||||
causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
|
||||
});
|
||||
return out;
|
||||
}
|
||||
@@ -199,6 +193,7 @@ causal_conv1d_update(const at::Tensor &x,
|
||||
const at::Tensor &weight,
|
||||
const c10::optional<at::Tensor> &bias_,
|
||||
bool silu_activation,
|
||||
const c10::optional<at::Tensor> &cache_seqlens_,
|
||||
const c10::optional<at::Tensor> &conv_state_indices_) {
|
||||
auto input_type = x.scalar_type();
|
||||
auto weight_type = weight.scalar_type();
|
||||
@@ -214,9 +209,12 @@ causal_conv1d_update(const at::Tensor &x,
|
||||
const auto sizes = x.sizes();
|
||||
const int batch_size = sizes[0];
|
||||
const int dim = sizes[1];
|
||||
const int seqlen = sizes[2];
|
||||
const int width = weight.size(-1);
|
||||
const int conv_state_len = conv_state.size(2);
|
||||
TORCH_CHECK(conv_state_len >= width - 1);
|
||||
|
||||
CHECK_SHAPE(x, batch_size, dim);
|
||||
CHECK_SHAPE(x, batch_size, dim, seqlen);
|
||||
CHECK_SHAPE(weight, dim, width);
|
||||
|
||||
TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
|
||||
@@ -232,15 +230,27 @@ causal_conv1d_update(const at::Tensor &x,
|
||||
at::Tensor out = torch::empty_like(x);
|
||||
|
||||
ConvParamsBase params;
|
||||
set_conv_params_fwd(params, batch_size, dim, /*seqlen=*/1, width, x, weight, out,
|
||||
bias_.has_value() ? bias_.value().data_ptr() : nullptr,
|
||||
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
|
||||
bias_,
|
||||
silu_activation);
|
||||
params.conv_state_ptr = conv_state.data_ptr();
|
||||
params.conv_state_len = conv_state_len;
|
||||
// All stride are in elements, not bytes.
|
||||
params.conv_state_batch_stride = conv_state.stride(0);
|
||||
params.conv_state_c_stride = conv_state.stride(1);
|
||||
params.conv_state_l_stride = conv_state.stride(2);
|
||||
|
||||
if (cache_seqlens_.has_value()) {
|
||||
auto cache_seqlens = cache_seqlens_.value();
|
||||
TORCH_CHECK(cache_seqlens.scalar_type() == torch::kInt32);
|
||||
TORCH_CHECK(cache_seqlens.is_cuda());
|
||||
TORCH_CHECK(cache_seqlens.stride(-1) == 1);
|
||||
CHECK_SHAPE(cache_seqlens, batch_size);
|
||||
params.cache_seqlens = cache_seqlens.data_ptr<int32_t>();
|
||||
} else {
|
||||
params.cache_seqlens = nullptr;
|
||||
}
|
||||
|
||||
if (conv_state_indices_.has_value()) {
|
||||
auto conv_state_indices = conv_state_indices_.value();
|
||||
TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32)
|
||||
@@ -249,11 +259,11 @@ causal_conv1d_update(const at::Tensor &x,
|
||||
CHECK_SHAPE(conv_state_indices, batch_size);
|
||||
|
||||
int conv_state_entries = conv_state.size(0);
|
||||
CHECK_SHAPE(conv_state, conv_state_entries, dim, width);
|
||||
CHECK_SHAPE(conv_state, conv_state_entries, dim, conv_state_len);
|
||||
|
||||
params.conv_state_indices_ptr = conv_state_indices.data_ptr<int32_t>();
|
||||
} else {
|
||||
CHECK_SHAPE(conv_state, batch_size, dim, width);
|
||||
CHECK_SHAPE(conv_state, batch_size, dim, conv_state_len);
|
||||
params.conv_state_indices_ptr = nullptr;
|
||||
}
|
||||
|
||||
@@ -296,7 +306,7 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
|
||||
constexpr int kWidth = Ktraits::kWidth;
|
||||
constexpr int kNThreads = Ktraits::kNThreads;
|
||||
constexpr int kNElts = Ktraits::kNElts;
|
||||
static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
|
||||
constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
|
||||
using input_t = typename Ktraits::input_t;
|
||||
using vec_t = typename Ktraits::vec_t;
|
||||
using weight_t = typename Ktraits::weight_t;
|
||||
@@ -309,20 +319,39 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
|
||||
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
|
||||
vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
|
||||
|
||||
const bool kVarlen = params.query_start_loc_ptr != nullptr;
|
||||
const int tidx = threadIdx.x;
|
||||
const int batch_id = blockIdx.x;
|
||||
const int channel_id = blockIdx.y;
|
||||
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
||||
const int *query_start_loc = kVarlen ? reinterpret_cast<int *>(params.query_start_loc_ptr) : nullptr;
|
||||
const int sequence_start_index = kVarlen ? query_start_loc[batch_id] : batch_id;
|
||||
const int seqlen = kVarlen ? query_start_loc[batch_id + 1] - sequence_start_index : params.seqlen;
|
||||
|
||||
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + sequence_start_index * params.x_batch_stride
|
||||
+ channel_id * params.x_c_stride;
|
||||
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
|
||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + sequence_start_index * params.out_batch_stride
|
||||
+ channel_id * params.out_c_stride;
|
||||
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
|
||||
|
||||
bool has_initial_state = params.has_initial_state_ptr == nullptr ? false
|
||||
: reinterpret_cast<bool *>(params.has_initial_state_ptr)[batch_id];
|
||||
|
||||
int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
|
||||
: reinterpret_cast<int *>(params.cache_indices_ptr);
|
||||
int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
|
||||
|
||||
input_t *conv_states = params.conv_states_ptr == nullptr ? nullptr
|
||||
: reinterpret_cast<input_t *>(params.conv_states_ptr) + cache_index * params.conv_states_batch_stride + channel_id * params.conv_states_c_stride;
|
||||
|
||||
// Thread 0 will load the last elements of the previous chunk, so we initialize those to 0.
|
||||
if (tidx == 0) {
|
||||
input_t zeros[kNElts] = {0};
|
||||
smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(zeros)[0];
|
||||
input_t initial_state[kNElts] = {0};
|
||||
if (has_initial_state) {
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth - 1; ++w){ initial_state[kNElts - 1 - (kWidth - 2) + w ] = conv_states[w]; }
|
||||
}
|
||||
smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(initial_state)[0];
|
||||
}
|
||||
|
||||
float weight_vals[kWidth];
|
||||
@@ -330,14 +359,14 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
|
||||
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
|
||||
|
||||
constexpr int kChunkSize = kNThreads * kNElts;
|
||||
const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize;
|
||||
const int n_chunks = (seqlen + kChunkSize - 1) / kChunkSize;
|
||||
for (int chunk = 0; chunk < n_chunks; ++chunk) {
|
||||
input_t x_vals_load[2 * kNElts] = {0};
|
||||
if constexpr(kIsVecLoad) {
|
||||
typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts);
|
||||
typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (seqlen - chunk * kChunkSize) / kNElts);
|
||||
} else {
|
||||
__syncthreads();
|
||||
typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize);
|
||||
typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), seqlen - chunk * kChunkSize);
|
||||
}
|
||||
x += kChunkSize;
|
||||
__syncthreads();
|
||||
@@ -375,19 +404,57 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; }
|
||||
if constexpr(kIsVecLoad) {
|
||||
typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t (&)[1]>(out_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts);
|
||||
typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t (&)[1]>(out_vals_store), (seqlen - chunk * kChunkSize) / kNElts);
|
||||
} else {
|
||||
typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, params.seqlen - chunk * kChunkSize);
|
||||
typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, seqlen - chunk * kChunkSize);
|
||||
}
|
||||
out += kChunkSize;
|
||||
}
|
||||
// Final state is stored in the smem_exchange last token slot,
|
||||
// in case seqlen < kWidth, we would need to take the final state from the
|
||||
// initial state which is stored in conv_states
|
||||
// in case seqlen > kWidth, we would need to load the last kWidth - 1 data
|
||||
// and load it into conv_state accordingly
|
||||
int last_thread = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize) / kNElts;
|
||||
if (conv_states != nullptr && tidx == last_thread) {
|
||||
input_t x_vals_load[kNElts * 2] = {0};
|
||||
// in case we are on the first kWidth tokens
|
||||
if (last_thread == 0 && seqlen < kWidth){
|
||||
// Need to take the initial state
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[0];
|
||||
const int offset = seqlen - (kWidth - 1);
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth - 1; ++w){
|
||||
// pad the existing state
|
||||
if ((w - seqlen) >= 0 && has_initial_state) { conv_states[w - seqlen] = conv_states[w]; }
|
||||
else if ((w - seqlen) >= 0 && !has_initial_state) { conv_states[w - seqlen] = input_t(0.0f); }
|
||||
}
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth - 1; ++w){
|
||||
if (offset + w >= 0)
|
||||
conv_states[w] = x_vals_load[offset + w ];
|
||||
}
|
||||
}
|
||||
else {
|
||||
// in case the final state is in between the threads data
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[1] = smem_exchange[last_thread + 1];
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[last_thread];
|
||||
const int offset = ((seqlen - (kWidth - 1)) % (kNElts));
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth - 1; ++w){
|
||||
conv_states[w] = x_vals_load[offset + w ];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
||||
void causal_conv1d_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
|
||||
BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
|
||||
const bool kVarlen = params.query_start_loc_ptr != nullptr;
|
||||
BOOL_SWITCH(params.seqlen % kNElts == 0 && !kVarlen, kIsVecLoad, [&] {
|
||||
using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
|
||||
constexpr int kSmemSize = Ktraits::kSmemSize;
|
||||
dim3 grid(params.batch, params.dim);
|
||||
@@ -422,220 +489,11 @@ void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
}
|
||||
}
|
||||
|
||||
template<int kNThreads_, int kWidth_, int kChunkSizeL_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
|
||||
struct Causal_conv1d_channellast_fwd_kernel_traits {
|
||||
// The cache line is 128 bytes, and we try to read 16 bytes per thread.
|
||||
// So we have 8 threads per "row", so 32 or 64 elements in the channel dimension.
|
||||
// That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128
|
||||
// threads). Each each load is 16 x 32|64 elements in the L x C dimensions.
|
||||
using input_t = input_t_;
|
||||
using weight_t = weight_t_;
|
||||
static constexpr int kNThreads = kNThreads_;
|
||||
static_assert(kNThreads % 32 == 0);
|
||||
static constexpr int kNWarps = kNThreads / 32;
|
||||
static constexpr int kWidth = kWidth_;
|
||||
static constexpr int kChunkSizeL = kChunkSizeL_;
|
||||
static constexpr int kNBytes = sizeof(input_t);
|
||||
static_assert(kNBytes == 2 || kNBytes == 4);
|
||||
static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
|
||||
static constexpr int kNEltsPerRow = 128 / kNBytes;
|
||||
static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now
|
||||
static_assert(kNThreadsPerRow * kNBytes * kNElts == 128);
|
||||
static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now
|
||||
static_assert(kNColsPerWarp * kNThreadsPerRow == 32);
|
||||
static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps;
|
||||
static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad;
|
||||
static_assert(kNLoads * kNColsPerLoad == kChunkSizeL);
|
||||
static constexpr bool kIsVecLoad = kIsVecLoad_;
|
||||
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
||||
// using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
||||
// using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
||||
// static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage),
|
||||
// sizeof(typename BlockStoreT::TempStorage)});
|
||||
// static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes;
|
||||
};
|
||||
|
||||
template<typename Ktraits, bool kHasSeqIdx>
|
||||
__global__ __launch_bounds__(Ktraits::kNThreads)
|
||||
void causal_conv1d_channellast_fwd_kernel(ConvParamsBase params) {
|
||||
constexpr int kWidth = Ktraits::kWidth;
|
||||
constexpr int kNThreads = Ktraits::kNThreads;
|
||||
constexpr int kNElts = Ktraits::kNElts;
|
||||
constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow;
|
||||
constexpr int kLPerLoad = Ktraits::kNColsPerLoad;
|
||||
constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
|
||||
constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
|
||||
using input_t = typename Ktraits::input_t;
|
||||
using vec_t = typename Ktraits::vec_t;
|
||||
using weight_t = typename Ktraits::weight_t;
|
||||
|
||||
// Shared memory.
|
||||
__shared__ input_t x_smem[kWidth - 1 + kChunkSizeL][kChunkSizeC + kNElts];
|
||||
|
||||
const int batch_id = blockIdx.x;
|
||||
const int chunk_l_id = blockIdx.y;
|
||||
const int chunk_c_id = blockIdx.z;
|
||||
const int tid = threadIdx.x;
|
||||
const int l_idx = tid / kNThreadsPerC;
|
||||
const int c_idx = tid % kNThreadsPerC;
|
||||
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
||||
+ (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
||||
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr)
|
||||
+ chunk_c_id * kChunkSizeC * params.weight_c_stride;
|
||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
||||
+ (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
||||
int *seq_idx = !kHasSeqIdx ? nullptr : reinterpret_cast<int *>(params.seq_idx_ptr)
|
||||
+ batch_id * params.seqlen + chunk_l_id * kChunkSizeL;
|
||||
input_t *initial_states = params.initial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr
|
||||
: reinterpret_cast<input_t *>(params.initial_states_ptr) + batch_id * params.initial_states_batch_stride + l_idx * params.initial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
||||
// The last L-chunk will also have enough info to write to final states, since it also contain a few x values
|
||||
// from the previous L-chunk.
|
||||
input_t *final_states = params.final_states_ptr == nullptr || chunk_l_id < gridDim.y - 1 ? nullptr
|
||||
: reinterpret_cast<input_t *>(params.final_states_ptr) + batch_id * params.final_states_batch_stride + l_idx * params.final_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < Ktraits::kNLoads; ++l) {
|
||||
input_t x_vals_load[kNElts] = {0};
|
||||
if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
|
||||
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + l * kLPerLoad * params.x_l_stride);
|
||||
}
|
||||
reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
|
||||
}
|
||||
// Load the elements from the previous chunk that are needed for convolution.
|
||||
if (l_idx < kWidth - 1) {
|
||||
input_t x_vals_load[kNElts] = {0};
|
||||
if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0
|
||||
&& chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen
|
||||
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x - (kWidth - 1) * params.x_l_stride);
|
||||
} else if (initial_states != nullptr
|
||||
&& chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < 0
|
||||
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(initial_states);
|
||||
}
|
||||
reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (final_states != nullptr
|
||||
&& l_idx < kWidth - 1
|
||||
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
||||
// x_smem[0] contains element at index chunk_l_id * kChunkSizeL - (kWidth - 1)
|
||||
// So last few elements (index params.seqlen - kWidth + 1 + l_idx) are stored in x_smem[params.seqlen - kWidth + 1 + l_idx - (chunk_l_id * kChunkSizeL - kWidth + 1)][c_idx]
|
||||
*reinterpret_cast<vec_t *>(final_states) = reinterpret_cast<vec_t *>(x_smem[params.seqlen + l_idx - chunk_l_id * kChunkSizeL])[c_idx];
|
||||
}
|
||||
|
||||
constexpr int kLPerThread = constexpr_min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL);
|
||||
static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC);
|
||||
constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread;
|
||||
static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL);
|
||||
// kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity
|
||||
static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0);
|
||||
static_assert((kLPerThread & (kLPerThread - 1)) == 0);
|
||||
static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0);
|
||||
static_assert(kNThreadsPerRow <= 32);
|
||||
|
||||
const int row_idx = tid / kNThreadsPerRow;
|
||||
const int col_idx = tid % kNThreadsPerRow;
|
||||
|
||||
float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]);
|
||||
float weight_vals[kWidth] = {0};
|
||||
if (chunk_c_id * kChunkSizeC + row_idx < params.dim) {
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth; ++w) {
|
||||
weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride];
|
||||
}
|
||||
}
|
||||
float x_vals[kWidth - 1 + kLPerThread];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
|
||||
x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
|
||||
}
|
||||
int seq_idx_thread[kWidth - 1 + kLPerThread];
|
||||
if constexpr (kHasSeqIdx) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
|
||||
seq_idx_thread[i] = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1) >= 0 ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1;
|
||||
}
|
||||
}
|
||||
|
||||
float out_vals[kLPerThread];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kLPerThread; ++i) {
|
||||
out_vals[i] = bias_val;
|
||||
const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1];
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth; ++w) {
|
||||
if constexpr (!kHasSeqIdx) {
|
||||
out_vals[i] += weight_vals[w] * x_vals[i + w];
|
||||
} else {
|
||||
out_vals[i] += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f;
|
||||
}
|
||||
}
|
||||
if (params.silu_activation) {out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); }
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; }
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < Ktraits::kNLoads; ++l) {
|
||||
input_t out_vals_store[kNElts];
|
||||
reinterpret_cast<vec_t *>(out_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[l * kLPerLoad + l_idx])[c_idx];
|
||||
if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
|
||||
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
||||
*reinterpret_cast<vec_t *>(out + l * kLPerLoad * params.out_l_stride) = reinterpret_cast<vec_t *>(out_vals_store)[0];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
||||
void causal_conv1d_channellast_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] {
|
||||
using Ktraits = Causal_conv1d_channellast_fwd_kernel_traits<kNThreads, kWidth, 64, true, input_t, weight_t>;
|
||||
// constexpr int kSmemSize = Ktraits::kSmemSize;
|
||||
constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
|
||||
constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
|
||||
const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL;
|
||||
const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC;
|
||||
dim3 grid(params.batch, n_chunks_L, n_chunks_C);
|
||||
dim3 block(Ktraits::kNThreads);
|
||||
auto kernel = &causal_conv1d_channellast_fwd_kernel<Ktraits, kHasSeqIdx>;
|
||||
// if (kSmemSize >= 48 * 1024) {
|
||||
// C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
// kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
// }
|
||||
// kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
||||
kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
}
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
if (params.width == 2) {
|
||||
causal_conv1d_channellast_fwd_launch<128, 2, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 3) {
|
||||
causal_conv1d_channellast_fwd_launch<128, 3, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 4) {
|
||||
causal_conv1d_channellast_fwd_launch<128, 4, input_t, weight_t>(params, stream);
|
||||
}
|
||||
}
|
||||
|
||||
template void causal_conv1d_fwd_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_fwd_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
|
||||
template void causal_conv1d_channellast_fwd_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_channellast_fwd_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
///////
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -649,7 +507,7 @@ struct Causal_conv1d_update_kernel_traits {
|
||||
static_assert(kNBytes == 2 || kNBytes == 4);
|
||||
};
|
||||
|
||||
template<typename Ktraits>
|
||||
template<typename Ktraits, bool kIsCircularBuffer>
|
||||
__global__ __launch_bounds__(Ktraits::kNThreads)
|
||||
void causal_conv1d_update_kernel(ConvParamsBase params) {
|
||||
constexpr int kWidth = Ktraits::kWidth;
|
||||
@@ -660,6 +518,8 @@ void causal_conv1d_update_kernel(ConvParamsBase params) {
|
||||
const int tidx = threadIdx.x;
|
||||
const int batch_id = blockIdx.x;
|
||||
const int channel_id = blockIdx.y * kNThreads + tidx;
|
||||
if (channel_id >= params.dim) return;
|
||||
|
||||
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
||||
+ channel_id * params.x_c_stride;
|
||||
|
||||
@@ -675,35 +535,70 @@ void causal_conv1d_update_kernel(ConvParamsBase params) {
|
||||
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
|
||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
||||
+ channel_id * params.out_c_stride;
|
||||
float bias_val = params.bias_ptr == nullptr || channel_id >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
|
||||
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
|
||||
|
||||
int state_len = params.conv_state_len;
|
||||
int advance_len = params.seqlen;
|
||||
int cache_seqlen = kIsCircularBuffer ? params.cache_seqlens[batch_id] % state_len : 0;
|
||||
int update_idx = cache_seqlen - (kWidth - 1);
|
||||
update_idx = update_idx < 0 ? update_idx + state_len : update_idx;
|
||||
|
||||
float weight_vals[kWidth] = {0};
|
||||
if (channel_id < params.dim) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
|
||||
|
||||
float x_vals[kWidth] = {0};
|
||||
if (channel_id < params.dim) {
|
||||
if constexpr (!kIsCircularBuffer) {
|
||||
#pragma unroll 2
|
||||
for (int i = 0; i < state_len - advance_len - (kWidth - 1); ++i) {
|
||||
conv_state[i * params.conv_state_l_stride] = conv_state[(i + advance_len) * params.conv_state_l_stride];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = float(conv_state[(i + 1) * params.conv_state_l_stride]); }
|
||||
x_vals[kWidth - 1] = float(x[0]);
|
||||
for (int i = 0; i < kWidth - 1; ++i) {
|
||||
input_t state_val = conv_state[(state_len - (kWidth - 1) + i) * params.conv_state_l_stride];
|
||||
if (i < advance_len + (kWidth - 1) && state_len - advance_len - (kWidth - 1) + i >= 0) {
|
||||
conv_state[(state_len - advance_len - (kWidth - 1) + i) * params.conv_state_l_stride] = state_val;
|
||||
}
|
||||
x_vals[i] = float(state_val);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth; ++i) { conv_state[i * params.conv_state_l_stride] = input_t(x_vals[i]); }
|
||||
for (int i = 0; i < kWidth - 1; ++i, update_idx = update_idx + 1 >= state_len ? update_idx + 1 - state_len : update_idx + 1) {
|
||||
input_t state_val = conv_state[update_idx * params.conv_state_l_stride];
|
||||
x_vals[i] = float(state_val);
|
||||
}
|
||||
}
|
||||
#pragma unroll 2
|
||||
for (int i = 0; i < params.seqlen; ++i) {
|
||||
input_t x_val = x[i * params.x_l_stride];
|
||||
if constexpr (!kIsCircularBuffer) {
|
||||
if (i < advance_len && state_len - advance_len + i >= 0) {
|
||||
conv_state[(state_len - advance_len + i) * params.conv_state_l_stride] = x_val;
|
||||
}
|
||||
} else {
|
||||
conv_state[update_idx * params.conv_state_l_stride] = x_val;
|
||||
++update_idx;
|
||||
update_idx = update_idx >= state_len ? update_idx - state_len : update_idx;
|
||||
}
|
||||
x_vals[kWidth - 1] = float(x_val);
|
||||
float out_val = bias_val;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < kWidth; ++j) { out_val += weight_vals[j] * x_vals[j]; }
|
||||
if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
|
||||
out[i * params.out_l_stride] = input_t(out_val);
|
||||
// Shift the input buffer by 1
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = x_vals[i + 1]; }
|
||||
}
|
||||
|
||||
float out_val = bias_val;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth; ++i) { out_val += weight_vals[i] * x_vals[i]; }
|
||||
if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
|
||||
if (channel_id < params.dim) { out[0] = input_t(out_val); }
|
||||
}
|
||||
|
||||
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
||||
void causal_conv1d_update_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
|
||||
dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
|
||||
auto kernel = &causal_conv1d_update_kernel<Ktraits>;
|
||||
auto kernel = params.cache_seqlens == nullptr
|
||||
? &causal_conv1d_update_kernel<Ktraits, false>
|
||||
: &causal_conv1d_update_kernel<Ktraits, true>;
|
||||
kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user