[Kernel] Integrate CUTLASS MoE kernel with PPLX (#18762)

Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
ElizaWszola
2025-06-07 03:26:11 +02:00
committed by GitHub
parent 6e0cd10f72
commit 84166fee97
26 changed files with 918 additions and 409 deletions

View File

@@ -84,7 +84,8 @@ void run_cutlass_moe_mm_sm90(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
@@ -113,19 +114,23 @@ void run_cutlass_moe_mm_sm90(
if (n >= 8192) {
cutlass_group_gemm_caller<Cutlass3xGemmN8192>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides);
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else if (k >= 8192) {
cutlass_group_gemm_caller<Cutlass3xGemmK8192>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides);
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else if (m <= 16) {
cutlass_group_gemm_caller<Cutlass3xGemmM16>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides);
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else {
cutlass_group_gemm_caller<Cutlass3xGemmDefault>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides);
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
}
}
@@ -134,15 +139,18 @@ void dispatch_moe_mm_sm90(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
if (out_tensors.dtype() == torch::kBFloat16) {
run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides);
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else {
run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::half_t>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides);
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
}
}
@@ -153,8 +161,9 @@ void cutlass_moe_mm_sm90(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
dispatch_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides,
c_strides);
c_strides, per_act_token, per_out_ch);
}

View File

@@ -76,7 +76,8 @@ void cutlass_group_gemm_caller(
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD;
@@ -84,9 +85,6 @@ void cutlass_group_gemm_caller(
int k_size = a_tensors.size(1);
int n_size = out_tensors.size(1);
bool per_act_token = a_scales.numel() != 1;
bool per_out_ch = b_scales.numel() != num_experts;
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
auto options_int =

View File

@@ -7,7 +7,7 @@
constexpr uint64_t THREADS_PER_EXPERT = 512;
__global__ void compute_problem_sizes(const int* __restrict__ topk_ids,
__global__ void compute_problem_sizes(const uint32_t* __restrict__ topk_ids,
int32_t* problem_sizes1,
int32_t* problem_sizes2,
int32_t* atomic_buffer,
@@ -62,7 +62,7 @@ __global__ void compute_expert_blockscale_offsets(
}
}
__global__ void compute_arg_sorts(const int* __restrict__ topk_ids,
__global__ void compute_arg_sorts(const uint32_t* __restrict__ topk_ids,
const int32_t* __restrict__ expert_offsets,
int32_t* input_permutation,
int32_t* output_permutation,
@@ -103,7 +103,7 @@ void get_cutlass_moe_mm_data_caller(
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
compute_problem_sizes<<<num_experts, num_threads, 0, stream>>>(
static_cast<const int32_t*>(topk_ids.data_ptr()),
static_cast<const uint32_t*>(topk_ids.data_ptr()),
static_cast<int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, k);
@@ -120,10 +120,44 @@ void get_cutlass_moe_mm_data_caller(
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
}
compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
static_cast<const int32_t*>(topk_ids.data_ptr()),
static_cast<const uint32_t*>(topk_ids.data_ptr()),
static_cast<const int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(input_permutation.data_ptr()),
static_cast<int32_t*>(output_permutation.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(),
topk_ids.size(1));
}
__global__ void compute_pplx_data(int32_t* expert_offsets,
int32_t* problem_sizes1,
int32_t* problem_sizes2,
const int32_t* __restrict__ expert_num_tokens,
const int padded_m, const int n,
const int k) {
int expert_idx = threadIdx.x;
expert_offsets[expert_idx] = expert_idx * padded_m;
problem_sizes1[expert_idx * 3] = expert_num_tokens[expert_idx];
problem_sizes1[expert_idx * 3 + 1] = 2 * n;
problem_sizes1[expert_idx * 3 + 2] = k;
problem_sizes2[expert_idx * 3] = expert_num_tokens[expert_idx];
problem_sizes2[expert_idx * 3 + 1] = k;
problem_sizes2[expert_idx * 3 + 2] = n;
}
void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
const torch::Tensor& expert_num_tokens,
const int64_t num_local_experts,
const int64_t padded_m,
const int64_t n, const int64_t k) {
auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());
compute_pplx_data<<<1, num_local_experts, 0, stream>>>(
static_cast<int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
k);
}