From c4e744dbd41f23ee7fd554a86cc1bf516082552d Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Wed, 28 Jan 2026 13:15:24 -0500 Subject: [PATCH] [Perf] Optimize `moe_permute` for CUTLASS FP8 (#32892) Signed-off-by: yewentao256 --- csrc/moe/moe_permute_unpermute_op.cu | 33 ++++++++--- .../moe_permute_unpermute_kernel.h | 3 +- .../moe_permute_unpermute_kernel.inl | 55 +++++++------------ 3 files changed, 47 insertions(+), 44 deletions(-) diff --git a/csrc/moe/moe_permute_unpermute_op.cu b/csrc/moe/moe_permute_unpermute_op.cu index fd94b46e5..3de64eda6 100644 --- a/csrc/moe/moe_permute_unpermute_op.cu +++ b/csrc/moe/moe_permute_unpermute_op.cu @@ -73,25 +73,40 @@ void moe_permute( get_ptr(expert_first_token_offset), n_token, n_expert, n_local_expert, topk, sorter, get_ptr(sort_workspace), stream); + // DeepGEMM: use getMIndices kernel to compute + // 1) align_expert_first_token_offset (aligned prefix offsets) + // 2) m_indices (expert id for each aligned row) + // e.g.
expert0: 3, expert1: 5, expert2: 2 tokens respectively + // expert_first_token_offset = [0, 3, 8, 10], align_block_size = 4 + // expert0: 3->4, expert1: 5->8, expert2: 2->4 + // align_expert_first_token_offset = [0, 4, 12, 16] + // so m_indices = [0,0,0,0, 1,1,1,1,1,1,1,1, 2,2,2,2] + torch::Tensor align_expert_first_token_offset; + const int64_t* aligned_expert_first_token_offset_ptr = nullptr; + if (align_block_size.has_value()) { + align_expert_first_token_offset = + torch::zeros_like(expert_first_token_offset); + getMIndices(get_ptr(expert_first_token_offset), + get_ptr(align_expert_first_token_offset), + get_ptr(m_indices), n_local_expert, align_block_size_value, + stream); + aligned_expert_first_token_offset_ptr = + get_ptr(align_expert_first_token_offset); + } + // dispatch expandInputRowsKernelLauncher MOE_DISPATCH(input.scalar_type(), [&] { expandInputRowsKernelLauncher( get_ptr(input), get_ptr(permuted_input), get_ptr(permuted_experts_id), get_ptr(sorted_row_idx), get_ptr(inv_permuted_idx), get_ptr(permuted_idx), - get_ptr(expert_first_token_offset), n_token, valid_num_ptr, - n_hidden, topk, n_local_expert, align_block_size_value, stream); + get_ptr(expert_first_token_offset), + aligned_expert_first_token_offset_ptr, n_token, valid_num_ptr, n_hidden, + topk, n_local_expert, align_block_size_value, stream); }); - // get m_indices and update expert_first_token_offset with align block // this is only required for DeepGemm and not required for CUTLASS group gemm if (align_block_size.has_value()) { - auto align_expert_first_token_offset = - torch::zeros_like(expert_first_token_offset); - getMIndices(get_ptr(expert_first_token_offset), - get_ptr(align_expert_first_token_offset), - get_ptr(m_indices), n_local_expert, align_block_size_value, - stream); expert_first_token_offset.copy_(align_expert_first_token_offset); } } diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h 
b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h index 2cdefdb91..09491ab98 100644 --- a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h +++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h @@ -60,7 +60,8 @@ void expandInputRowsKernelLauncher( T const* unpermuted_input, T* permuted_output, int* sorted_experts, int const* expanded_dest_row_to_expanded_source_row, int* expanded_source_row_to_expanded_dest_row, int* permuted_idx, - int64_t* expert_first_token_offset, int64_t const num_rows, + int64_t const* expert_first_token_offset, + int64_t const* aligned_expert_first_token_offset, int64_t const num_rows, int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k, int num_local_experts, const int& align_block_size, cudaStream_t stream); diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl index 449243b92..68f3cc9fa 100644 --- a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl +++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl @@ -5,7 +5,8 @@ __global__ void expandInputRowsKernel( T const* unpermuted_input, T* permuted_output, int* sorted_experts, int const* expanded_dest_row_to_expanded_source_row, int* expanded_source_row_to_expanded_dest_row, int* permuted_idx, - int64_t* expert_first_token_offset, int64_t const num_rows, + int64_t const* expert_first_token_offset, + int64_t const* aligned_expert_first_token_offset, int64_t const num_rows, int64_t const* num_dest_rows, int64_t const cols, int64_t k, int num_local_experts, int align_block_size) { // Reverse permutation map. 
@@ -18,35 +19,22 @@ __global__ void expandInputRowsKernel( expanded_dest_row_to_expanded_source_row[expanded_dest_row]; int expert_id = sorted_experts[expanded_dest_row]; - extern __shared__ int64_t smem_expert_first_token_offset[]; if constexpr (ALIGN_BLOCK_SIZE) { - // load g2s - for (int idx = threadIdx.x; idx < num_local_experts + 1; - idx += blockDim.x) { - smem_expert_first_token_offset[idx] = - __ldg(expert_first_token_offset + idx); + // convert (unaligned) expanded_dest_row -> aligned expanded_dest_row. + // aligned_expert_first_token_offset[e] provides the aligned prefix start + // for expert e. For non-local experts we map to the end (total aligned M). + int64_t aligned_base = 0; + int64_t token_offset_in_expert = 0; + if (expert_id >= num_local_experts) { + aligned_base = + __ldg(aligned_expert_first_token_offset + num_local_experts); + token_offset_in_expert = 0; + } else { + aligned_base = __ldg(aligned_expert_first_token_offset + expert_id); + token_offset_in_expert = + expanded_dest_row - __ldg(expert_first_token_offset + expert_id); } - __syncthreads(); - int lane_idx = threadIdx.x & 31; - - if (lane_idx == 0) { - // set token_offset_in_expert = 0 if this expert is not local expert - int token_offset_in_expert = - expert_id >= num_local_experts - ? 
0 : expanded_dest_row - smem_expert_first_token_offset[expert_id]; - int64_t accumulate_align_offset = 0; -#pragma unroll 1 - for (int eidx = 1; eidx <= min(expert_id, num_local_experts); eidx++) { - auto n_token_in_expert = smem_expert_first_token_offset[eidx] - - smem_expert_first_token_offset[eidx - 1]; - accumulate_align_offset += (n_token_in_expert + align_block_size - 1) / - align_block_size * align_block_size; - } - expanded_dest_row = accumulate_align_offset + token_offset_in_expert; - } - // lane0 shuffle broadcast align_expanded_dest_row - expanded_dest_row = __shfl_sync(0xffffffff, expanded_dest_row, 0); + expanded_dest_row = aligned_base + token_offset_in_expert; } if (threadIdx.x == 0) { @@ -88,7 +76,8 @@ void expandInputRowsKernelLauncher( T const* unpermuted_input, T* permuted_output, int* sorted_experts, int const* expanded_dest_row_to_expanded_source_row, int* expanded_source_row_to_expanded_dest_row, int* permuted_idx, - int64_t* expert_first_token_offset, int64_t const num_rows, + int64_t const* expert_first_token_offset, + int64_t const* aligned_expert_first_token_offset, int64_t const num_rows, int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k, int num_local_experts, const int& align_block_size, cudaStream_t stream) { int64_t const blocks = num_rows * k; @@ -104,14 +93,12 @@ void expandInputRowsKernelLauncher( bool is_align_block_size = align_block_size != -1; auto func = func_map[is_check_skip][is_align_block_size]; - int64_t smem_size = sizeof(int64_t) * (num_local_experts + 1); - - func<<<blocks, threads, smem_size, stream>>>( + func<<<blocks, threads, 0, stream>>>( unpermuted_input, permuted_output, sorted_experts, expanded_dest_row_to_expanded_source_row, expanded_source_row_to_expanded_dest_row, permuted_idx, - expert_first_token_offset, num_rows, num_valid_tokens_ptr, cols, k, - num_local_experts, align_block_size); + expert_first_token_offset, aligned_expert_first_token_offset, num_rows, + num_valid_tokens_ptr, cols, k, num_local_experts, align_block_size); } template