[Compile] Fix compile warning in moe_permute (#36529)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -73,10 +73,9 @@ void moe_permute(
|
|||||||
MOE_DISPATCH(input.scalar_type(), [&] {
|
MOE_DISPATCH(input.scalar_type(), [&] {
|
||||||
expandInputRowsKernelLauncher<scalar_t>(
|
expandInputRowsKernelLauncher<scalar_t>(
|
||||||
get_ptr<scalar_t>(input), get_ptr<scalar_t>(permuted_input),
|
get_ptr<scalar_t>(input), get_ptr<scalar_t>(permuted_input),
|
||||||
get_ptr<int>(permuted_experts_id), get_ptr<int>(sorted_row_idx),
|
get_ptr<int>(sorted_row_idx), get_ptr<int>(inv_permuted_idx),
|
||||||
get_ptr<int>(inv_permuted_idx), get_ptr<int>(permuted_idx),
|
get_ptr<int>(permuted_idx), get_ptr<int64_t>(expert_first_token_offset),
|
||||||
get_ptr<int64_t>(expert_first_token_offset), n_token, valid_num_ptr,
|
n_token, valid_num_ptr, n_hidden, topk, n_local_expert, stream);
|
||||||
n_hidden, topk, n_local_expert, stream);
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ void sortAndScanExpert(const int* expert_for_source_row, const int* source_rows,
|
|||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void expandInputRowsKernelLauncher(
|
void expandInputRowsKernelLauncher(
|
||||||
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
|
T const* unpermuted_input, T* permuted_output,
|
||||||
int const* expanded_dest_row_to_expanded_source_row,
|
int const* expanded_dest_row_to_expanded_source_row,
|
||||||
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
|
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
|
||||||
int64_t const* expert_first_token_offset, int64_t const num_rows,
|
int64_t const* expert_first_token_offset, int64_t const num_rows,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
template <typename T, bool CHECK_SKIPPED>
|
template <typename T, bool CHECK_SKIPPED>
|
||||||
__global__ void expandInputRowsKernel(
|
__global__ void expandInputRowsKernel(
|
||||||
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
|
T const* unpermuted_input, T* permuted_output,
|
||||||
int const* expanded_dest_row_to_expanded_source_row,
|
int const* expanded_dest_row_to_expanded_source_row,
|
||||||
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
|
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
|
||||||
int64_t const* expert_first_token_offset, int64_t const num_rows,
|
int64_t const* expert_first_token_offset, int64_t const num_rows,
|
||||||
@@ -16,7 +16,6 @@ __global__ void expandInputRowsKernel(
|
|||||||
int64_t expanded_dest_row = blockIdx.x;
|
int64_t expanded_dest_row = blockIdx.x;
|
||||||
int64_t const expanded_source_row =
|
int64_t const expanded_source_row =
|
||||||
expanded_dest_row_to_expanded_source_row[expanded_dest_row];
|
expanded_dest_row_to_expanded_source_row[expanded_dest_row];
|
||||||
int expert_id = sorted_experts[expanded_dest_row];
|
|
||||||
|
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
assert(expanded_dest_row <= INT32_MAX);
|
assert(expanded_dest_row <= INT32_MAX);
|
||||||
@@ -54,7 +53,7 @@ __global__ void expandInputRowsKernel(
|
|||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void expandInputRowsKernelLauncher(
|
void expandInputRowsKernelLauncher(
|
||||||
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
|
T const* unpermuted_input, T* permuted_output,
|
||||||
int const* expanded_dest_row_to_expanded_source_row,
|
int const* expanded_dest_row_to_expanded_source_row,
|
||||||
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
|
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
|
||||||
int64_t const* expert_first_token_offset, int64_t const num_rows,
|
int64_t const* expert_first_token_offset, int64_t const num_rows,
|
||||||
@@ -70,12 +69,12 @@ void expandInputRowsKernelLauncher(
|
|||||||
bool is_check_skip = num_valid_tokens_ptr != nullptr;
|
bool is_check_skip = num_valid_tokens_ptr != nullptr;
|
||||||
auto func = func_map[is_check_skip];
|
auto func = func_map[is_check_skip];
|
||||||
|
|
||||||
func<<<blocks, threads, 0, stream>>>(
|
func<<<blocks, threads, 0, stream>>>(unpermuted_input, permuted_output,
|
||||||
unpermuted_input, permuted_output, sorted_experts,
|
expanded_dest_row_to_expanded_source_row,
|
||||||
expanded_dest_row_to_expanded_source_row,
|
expanded_source_row_to_expanded_dest_row,
|
||||||
expanded_source_row_to_expanded_dest_row, permuted_idx,
|
permuted_idx, expert_first_token_offset,
|
||||||
expert_first_token_offset, num_rows, num_valid_tokens_ptr, cols, k,
|
num_rows, num_valid_tokens_ptr, cols, k,
|
||||||
num_local_experts);
|
num_local_experts);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class T, class U>
|
template <class T, class U>
|
||||||
|
|||||||
Reference in New Issue
Block a user