diff --git a/csrc/moe/moe_permute_unpermute_op.cu b/csrc/moe/moe_permute_unpermute_op.cu index ca0c873f4..fd94b46e5 100644 --- a/csrc/moe/moe_permute_unpermute_op.cu +++ b/csrc/moe/moe_permute_unpermute_op.cu @@ -42,7 +42,7 @@ void moe_permute( auto sort_workspace = torch::empty( {sorter_size}, torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false)); - auto copy_topk_ids = topk_ids.clone(); // copy topk_ids for preprocess + torch::Tensor topk_ids_for_sort = topk_ids; auto permuted_experts_id = torch::empty_like(topk_ids); auto sorted_row_idx = torch::empty_like(inv_permuted_idx); @@ -62,12 +62,13 @@ void moe_permute( const int* expert_map_ptr = get_ptr(expert_map.value()); valid_num_ptr = get_ptr(expert_first_token_offset) + n_local_expert; - preprocessTopkIdLauncher(get_ptr(copy_topk_ids), n_token * topk, + topk_ids_for_sort = topk_ids.clone(); + preprocessTopkIdLauncher(get_ptr(topk_ids_for_sort), n_token * topk, expert_map_ptr, n_expert, stream); } // expert sort topk expert id and scan expert id get expert_first_token_offset sortAndScanExpert( - get_ptr(copy_topk_ids), get_ptr(token_expert_indices), + get_ptr(topk_ids_for_sort), get_ptr(token_expert_indices), get_ptr(permuted_experts_id), get_ptr(sorted_row_idx), get_ptr(expert_first_token_offset), n_token, n_expert, n_local_expert, topk, sorter, get_ptr(sort_workspace), stream); diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu index 2271c1bc7..9499b297f 100644 --- a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu +++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu @@ -109,7 +109,7 @@ void computeExpertFirstTokenOffset(int const* sorted_indices, sorted_indices, total_indices, num_experts, expert_first_token_offset); } -void sortAndScanExpert(int* expert_for_source_row, const int* source_rows, +void sortAndScanExpert(const int* expert_for_source_row, const int* source_rows, int* permuted_experts, int* permuted_rows, int64_t* expert_first_token_offset, int num_rows, int num_experts, int num_experts_per_node, int k, diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h index 108091efb..2cdefdb91 100644 --- a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h +++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h @@ -48,7 +48,7 @@ void computeExpertFirstTokenOffset(int const* sorted_indices, int64_t* expert_first_token_offset, cudaStream_t stream); -void sortAndScanExpert(int* expert_for_source_row, const int* source_rows, +void sortAndScanExpert(const int* expert_for_source_row, const int* source_rows, int* permuted_experts, int* permuted_rows, int64_t* expert_first_token_offset, int num_rows, int num_experts, int num_experts_per_node, int k,