[Perf] Only clone when needed for moe_permute (#32273)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2026-01-20 11:34:39 -05:00
committed by GitHub
parent 4ca62a0dbd
commit 6c97b9b9b6
3 changed files with 6 additions and 5 deletions

View File

@@ -42,7 +42,7 @@ void moe_permute(
auto sort_workspace = torch::empty(
{sorter_size},
torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
auto copy_topk_ids = topk_ids.clone(); // copy topk_ids for preprocess
torch::Tensor topk_ids_for_sort = topk_ids;
auto permuted_experts_id = torch::empty_like(topk_ids);
auto sorted_row_idx = torch::empty_like(inv_permuted_idx);
@@ -62,12 +62,13 @@ void moe_permute(
const int* expert_map_ptr = get_ptr<int>(expert_map.value());
valid_num_ptr =
get_ptr<int64_t>(expert_first_token_offset) + n_local_expert;
preprocessTopkIdLauncher(get_ptr<int>(copy_topk_ids), n_token * topk,
topk_ids_for_sort = topk_ids.clone();
preprocessTopkIdLauncher(get_ptr<int>(topk_ids_for_sort), n_token * topk,
expert_map_ptr, n_expert, stream);
}
// expert sort topk expert id and scan expert id get expert_first_token_offset
sortAndScanExpert(
get_ptr<int>(copy_topk_ids), get_ptr<int>(token_expert_indices),
get_ptr<const int>(topk_ids_for_sort), get_ptr<int>(token_expert_indices),
get_ptr<int>(permuted_experts_id), get_ptr<int>(sorted_row_idx),
get_ptr<int64_t>(expert_first_token_offset), n_token, n_expert,
n_local_expert, topk, sorter, get_ptr<int>(sort_workspace), stream);

View File

@@ -109,7 +109,7 @@ void computeExpertFirstTokenOffset(int const* sorted_indices,
sorted_indices, total_indices, num_experts, expert_first_token_offset);
}
void sortAndScanExpert(int* expert_for_source_row, const int* source_rows,
void sortAndScanExpert(const int* expert_for_source_row, const int* source_rows,
int* permuted_experts, int* permuted_rows,
int64_t* expert_first_token_offset, int num_rows,
int num_experts, int num_experts_per_node, int k,

View File

@@ -48,7 +48,7 @@ void computeExpertFirstTokenOffset(int const* sorted_indices,
int64_t* expert_first_token_offset,
cudaStream_t stream);
void sortAndScanExpert(int* expert_for_source_row, const int* source_rows,
void sortAndScanExpert(const int* expert_for_source_row, const int* source_rows,
int* permuted_experts, int* permuted_rows,
int64_t* expert_first_token_offset, int num_rows,
int num_experts, int num_experts_per_node, int k,