From 6c97b9b9b6337d021615ed74f9cf836b65c446d7 Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Tue, 20 Jan 2026 11:34:39 -0500 Subject: [PATCH] [Perf] Only clone when needed for `moe_permute` (#32273) Signed-off-by: yewentao256 --- csrc/moe/moe_permute_unpermute_op.cu | 7 ++++--- .../moe_permute_unpermute_kernel.cu | 2 +- .../moe_permute_unpermute_kernel.h | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/csrc/moe/moe_permute_unpermute_op.cu b/csrc/moe/moe_permute_unpermute_op.cu index ca0c873f4..fd94b46e5 100644 --- a/csrc/moe/moe_permute_unpermute_op.cu +++ b/csrc/moe/moe_permute_unpermute_op.cu @@ -42,7 +42,7 @@ void moe_permute( auto sort_workspace = torch::empty( {sorter_size}, torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false)); - auto copy_topk_ids = topk_ids.clone(); // copy topk_ids for preprocess + torch::Tensor topk_ids_for_sort = topk_ids; auto permuted_experts_id = torch::empty_like(topk_ids); auto sorted_row_idx = torch::empty_like(inv_permuted_idx); @@ -62,12 +62,13 @@ void moe_permute( const int* expert_map_ptr = get_ptr(expert_map.value()); valid_num_ptr = get_ptr(expert_first_token_offset) + n_local_expert; - preprocessTopkIdLauncher(get_ptr(copy_topk_ids), n_token * topk, + topk_ids_for_sort = topk_ids.clone(); + preprocessTopkIdLauncher(get_ptr(topk_ids_for_sort), n_token * topk, expert_map_ptr, n_expert, stream); } // expert sort topk expert id and scan expert id get expert_first_token_offset sortAndScanExpert( - get_ptr(copy_topk_ids), get_ptr(token_expert_indices), + get_ptr(topk_ids_for_sort), get_ptr(token_expert_indices), get_ptr(permuted_experts_id), get_ptr(sorted_row_idx), get_ptr(expert_first_token_offset), n_token, n_expert, n_local_expert, topk, sorter, get_ptr(sort_workspace), stream); diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu index 2271c1bc7..9499b297f 100644 --- a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu +++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu @@ -109,7 +109,7 @@ void computeExpertFirstTokenOffset(int const* sorted_indices, sorted_indices, total_indices, num_experts, expert_first_token_offset); } -void sortAndScanExpert(int* expert_for_source_row, const int* source_rows, +void sortAndScanExpert(const int* expert_for_source_row, const int* source_rows, int* permuted_experts, int* permuted_rows, int64_t* expert_first_token_offset, int num_rows, int num_experts, int num_experts_per_node, int k, diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h index 108091efb..2cdefdb91 100644 --- a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h +++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h @@ -48,7 +48,7 @@ void computeExpertFirstTokenOffset(int const* sorted_indices, int64_t* expert_first_token_offset, cudaStream_t stream); -void sortAndScanExpert(int* expert_for_source_row, const int* source_rows, +void sortAndScanExpert(const int* expert_for_source_row, const int* source_rows, int* permuted_experts, int* permuted_rows, int64_t* expert_first_token_offset, int num_rows, int num_experts, int num_experts_per_node, int k,