[Compile] Fix compile warning in moe_permute (#36529)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -73,10 +73,9 @@ void moe_permute(
|
||||
MOE_DISPATCH(input.scalar_type(), [&] {
|
||||
expandInputRowsKernelLauncher<scalar_t>(
|
||||
get_ptr<scalar_t>(input), get_ptr<scalar_t>(permuted_input),
|
||||
get_ptr<int>(permuted_experts_id), get_ptr<int>(sorted_row_idx),
|
||||
get_ptr<int>(inv_permuted_idx), get_ptr<int>(permuted_idx),
|
||||
get_ptr<int64_t>(expert_first_token_offset), n_token, valid_num_ptr,
|
||||
n_hidden, topk, n_local_expert, stream);
|
||||
get_ptr<int>(sorted_row_idx), get_ptr<int>(inv_permuted_idx),
|
||||
get_ptr<int>(permuted_idx), get_ptr<int64_t>(expert_first_token_offset),
|
||||
n_token, valid_num_ptr, n_hidden, topk, n_local_expert, stream);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ void sortAndScanExpert(const int* expert_for_source_row, const int* source_rows,
|
||||
|
||||
template <typename T>
|
||||
void expandInputRowsKernelLauncher(
|
||||
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
|
||||
T const* unpermuted_input, T* permuted_output,
|
||||
int const* expanded_dest_row_to_expanded_source_row,
|
||||
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
|
||||
int64_t const* expert_first_token_offset, int64_t const num_rows,
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
template <typename T, bool CHECK_SKIPPED>
|
||||
__global__ void expandInputRowsKernel(
|
||||
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
|
||||
T const* unpermuted_input, T* permuted_output,
|
||||
int const* expanded_dest_row_to_expanded_source_row,
|
||||
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
|
||||
int64_t const* expert_first_token_offset, int64_t const num_rows,
|
||||
@@ -16,7 +16,6 @@ __global__ void expandInputRowsKernel(
|
||||
int64_t expanded_dest_row = blockIdx.x;
|
||||
int64_t const expanded_source_row =
|
||||
expanded_dest_row_to_expanded_source_row[expanded_dest_row];
|
||||
int expert_id = sorted_experts[expanded_dest_row];
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
assert(expanded_dest_row <= INT32_MAX);
|
||||
@@ -54,7 +53,7 @@ __global__ void expandInputRowsKernel(
|
||||
|
||||
template <typename T>
|
||||
void expandInputRowsKernelLauncher(
|
||||
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
|
||||
T const* unpermuted_input, T* permuted_output,
|
||||
int const* expanded_dest_row_to_expanded_source_row,
|
||||
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
|
||||
int64_t const* expert_first_token_offset, int64_t const num_rows,
|
||||
@@ -70,12 +69,12 @@ void expandInputRowsKernelLauncher(
|
||||
bool is_check_skip = num_valid_tokens_ptr != nullptr;
|
||||
auto func = func_map[is_check_skip];
|
||||
|
||||
func<<<blocks, threads, 0, stream>>>(
|
||||
unpermuted_input, permuted_output, sorted_experts,
|
||||
expanded_dest_row_to_expanded_source_row,
|
||||
expanded_source_row_to_expanded_dest_row, permuted_idx,
|
||||
expert_first_token_offset, num_rows, num_valid_tokens_ptr, cols, k,
|
||||
num_local_experts);
|
||||
func<<<blocks, threads, 0, stream>>>(unpermuted_input, permuted_output,
|
||||
expanded_dest_row_to_expanded_source_row,
|
||||
expanded_source_row_to_expanded_dest_row,
|
||||
permuted_idx, expert_first_token_offset,
|
||||
num_rows, num_valid_tokens_ptr, cols, k,
|
||||
num_local_experts);
|
||||
}
|
||||
|
||||
template <class T, class U>
|
||||
|
||||
Reference in New Issue
Block a user