From e855d380fa59614167362a94e87a21a91f3ab470 Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Mon, 16 Mar 2026 10:16:14 -0400 Subject: [PATCH] [Compile] Fix compile warning in `moe_permute` (#36529) Signed-off-by: yewentao256 --- csrc/moe/moe_permute_unpermute_op.cu | 7 +++---- .../moe_permute_unpermute_kernel.h | 2 +- .../moe_permute_unpermute_kernel.inl | 17 ++++++++--------- 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/csrc/moe/moe_permute_unpermute_op.cu b/csrc/moe/moe_permute_unpermute_op.cu index eec8f9854..c7fcb3ecf 100644 --- a/csrc/moe/moe_permute_unpermute_op.cu +++ b/csrc/moe/moe_permute_unpermute_op.cu @@ -73,10 +73,9 @@ void moe_permute( MOE_DISPATCH(input.scalar_type(), [&] { expandInputRowsKernelLauncher( get_ptr(input), get_ptr(permuted_input), - get_ptr(permuted_experts_id), get_ptr(sorted_row_idx), - get_ptr(inv_permuted_idx), get_ptr(permuted_idx), - get_ptr(expert_first_token_offset), n_token, valid_num_ptr, - n_hidden, topk, n_local_expert, stream); + get_ptr(sorted_row_idx), get_ptr(inv_permuted_idx), + get_ptr(permuted_idx), get_ptr(expert_first_token_offset), + n_token, valid_num_ptr, n_hidden, topk, n_local_expert, stream); }); } diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h index 840b47546..fe44d3015 100644 --- a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h +++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h @@ -57,7 +57,7 @@ void sortAndScanExpert(const int* expert_for_source_row, const int* source_rows, template void expandInputRowsKernelLauncher( - T const* unpermuted_input, T* permuted_output, int* sorted_experts, + T const* unpermuted_input, T* permuted_output, int const* expanded_dest_row_to_expanded_source_row, int* expanded_source_row_to_expanded_dest_row, int* permuted_idx, int64_t const* expert_first_token_offset, int64_t const num_rows, diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl index bcb2f9ca5..45d96a270 100644 --- a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl +++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl @@ -2,7 +2,7 @@ template __global__ void expandInputRowsKernel( - T const* unpermuted_input, T* permuted_output, int* sorted_experts, + T const* unpermuted_input, T* permuted_output, int const* expanded_dest_row_to_expanded_source_row, int* expanded_source_row_to_expanded_dest_row, int* permuted_idx, int64_t const* expert_first_token_offset, int64_t const num_rows, @@ -16,7 +16,6 @@ __global__ void expandInputRowsKernel( int64_t expanded_dest_row = blockIdx.x; int64_t const expanded_source_row = expanded_dest_row_to_expanded_source_row[expanded_dest_row]; - int expert_id = sorted_experts[expanded_dest_row]; if (threadIdx.x == 0) { assert(expanded_dest_row <= INT32_MAX); @@ -54,7 +53,7 @@ __global__ void expandInputRowsKernel( template void expandInputRowsKernelLauncher( - T const* unpermuted_input, T* permuted_output, int* sorted_experts, + T const* unpermuted_input, T* permuted_output, int const* expanded_dest_row_to_expanded_source_row, int* expanded_source_row_to_expanded_dest_row, int* permuted_idx, int64_t const* expert_first_token_offset, int64_t const num_rows, @@ -70,12 +69,12 @@ void expandInputRowsKernelLauncher( bool is_check_skip = num_valid_tokens_ptr != nullptr; auto func = func_map[is_check_skip]; - func<<>>( - unpermuted_input, permuted_output, sorted_experts, - expanded_dest_row_to_expanded_source_row, - expanded_source_row_to_expanded_dest_row, permuted_idx, - expert_first_token_offset, num_rows, num_valid_tokens_ptr, cols, k, - num_local_experts); + func<<>>(unpermuted_input, permuted_output, + expanded_dest_row_to_expanded_source_row, + expanded_source_row_to_expanded_dest_row, + permuted_idx, expert_first_token_offset, + num_rows, num_valid_tokens_ptr, cols, k, + num_local_experts); } template