[Kernel] Integrate CUTLASS MoE kernel with PPLX (#18762)
Signed-off-by: ElizaWszola <ewszola@redhat.com> Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
@@ -7,7 +7,7 @@
|
||||
|
||||
constexpr uint64_t THREADS_PER_EXPERT = 512;
|
||||
|
||||
__global__ void compute_problem_sizes(const int* __restrict__ topk_ids,
|
||||
__global__ void compute_problem_sizes(const uint32_t* __restrict__ topk_ids,
|
||||
int32_t* problem_sizes1,
|
||||
int32_t* problem_sizes2,
|
||||
int32_t* atomic_buffer,
|
||||
@@ -62,7 +62,7 @@ __global__ void compute_expert_blockscale_offsets(
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void compute_arg_sorts(const int* __restrict__ topk_ids,
|
||||
__global__ void compute_arg_sorts(const uint32_t* __restrict__ topk_ids,
|
||||
const int32_t* __restrict__ expert_offsets,
|
||||
int32_t* input_permutation,
|
||||
int32_t* output_permutation,
|
||||
@@ -103,7 +103,7 @@ void get_cutlass_moe_mm_data_caller(
|
||||
|
||||
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
|
||||
compute_problem_sizes<<<num_experts, num_threads, 0, stream>>>(
|
||||
static_cast<const int32_t*>(topk_ids.data_ptr()),
|
||||
static_cast<const uint32_t*>(topk_ids.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes1.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes2.data_ptr()),
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, k);
|
||||
@@ -120,10 +120,44 @@ void get_cutlass_moe_mm_data_caller(
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
|
||||
}
|
||||
compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
|
||||
static_cast<const int32_t*>(topk_ids.data_ptr()),
|
||||
static_cast<const uint32_t*>(topk_ids.data_ptr()),
|
||||
static_cast<const int32_t*>(expert_offsets.data_ptr()),
|
||||
static_cast<int32_t*>(input_permutation.data_ptr()),
|
||||
static_cast<int32_t*>(output_permutation.data_ptr()),
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(),
|
||||
topk_ids.size(1));
|
||||
}
|
||||
|
||||
__global__ void compute_pplx_data(int32_t* expert_offsets,
|
||||
int32_t* problem_sizes1,
|
||||
int32_t* problem_sizes2,
|
||||
const int32_t* __restrict__ expert_num_tokens,
|
||||
const int padded_m, const int n,
|
||||
const int k) {
|
||||
int expert_idx = threadIdx.x;
|
||||
|
||||
expert_offsets[expert_idx] = expert_idx * padded_m;
|
||||
problem_sizes1[expert_idx * 3] = expert_num_tokens[expert_idx];
|
||||
problem_sizes1[expert_idx * 3 + 1] = 2 * n;
|
||||
problem_sizes1[expert_idx * 3 + 2] = k;
|
||||
problem_sizes2[expert_idx * 3] = expert_num_tokens[expert_idx];
|
||||
problem_sizes2[expert_idx * 3 + 1] = k;
|
||||
problem_sizes2[expert_idx * 3 + 2] = n;
|
||||
}
|
||||
|
||||
void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2,
|
||||
const torch::Tensor& expert_num_tokens,
|
||||
const int64_t num_local_experts,
|
||||
const int64_t padded_m,
|
||||
const int64_t n, const int64_t k) {
|
||||
auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());
|
||||
|
||||
compute_pplx_data<<<1, num_local_experts, 0, stream>>>(
|
||||
static_cast<int32_t*>(expert_offsets.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes1.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes2.data_ptr()),
|
||||
static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
|
||||
k);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user