diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index 86d9cc184..2a170249b 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -4,7 +4,7 @@ void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices, torch::Tensor& token_expert_indices, - torch::Tensor& gating_output); + torch::Tensor& gating_output, bool renormalize); void moe_sum(torch::Tensor& input, torch::Tensor& output); diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu index 53573ada8..af6e6fcd4 100644 --- a/csrc/moe/topk_softmax_kernels.cu +++ b/csrc/moe/topk_softmax_kernels.cu @@ -16,12 +16,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include #include #include #include #include "../cuda_compat.h" #include "../cub_helpers.h" +#ifndef USE_ROCM + #include + #include +#else + #include + #include + typedef __hip_bfloat16 __nv_bfloat16; + typedef __hip_bfloat162 __nv_bfloat162; +#endif + #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) @@ -36,16 +47,27 @@ template < /// Alignment requirement in bytes int Alignment = sizeof(T) * N > -class alignas(Alignment) AlignedArray { - float data[N]; +struct alignas(Alignment) AlignedArray { + T data[N]; }; +template +__device__ __forceinline__ float toFloat(T value) { + if constexpr (std::is_same_v) { + return value; + } else if constexpr (std::is_same_v) { + return __bfloat162float(value); + } else if constexpr (std::is_same_v) { + return __half2float(value); + } +} + // ====================== Softmax things =============================== // We have our own implementation of softmax here so we can support transposing the output // in the softmax kernel when we extend this module to support expert-choice routing. -template +template __launch_bounds__(TPB) __global__ - void moeSoftmax(const float* input, const bool* finished, float* output, const int num_cols) + void moeSoftmax(const InputType* input, const bool* finished, float* output, const int num_cols) { using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage tmpStorage; @@ -66,7 +88,8 @@ __launch_bounds__(TPB) __global__ for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { const int idx = thread_row_offset + ii; - threadData = max(static_cast(input[idx]), threadData); + const float val = toFloat(input[idx]); + threadData = max(val, threadData); } const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, CubMaxOp()); @@ -81,7 +104,8 @@ __launch_bounds__(TPB) __global__ for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { const int idx = thread_row_offset + ii; - threadData += exp((static_cast(input[idx]) - float_max)); + const float val = toFloat(input[idx]); + threadData += expf(val - float_max); } const auto Z = BlockReduce(tmpStorage).Reduce(threadData, CubAddOp()); @@ -95,8 +119,9 @@ __launch_bounds__(TPB) __global__ for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { const int idx = thread_row_offset + ii; - const float val = exp((static_cast(input[idx]) - float_max)) * normalizing_factor; - output[idx] = val; + const float val = toFloat(input[idx]); + const float softmax_val = expf(val - float_max) * normalizing_factor; + output[idx] = softmax_val; } } @@ -110,7 +135,8 @@ __launch_bounds__(TPB) __global__ void moeTopK( const int num_experts, const int k, const int start_expert, - const int end_expert) + const int end_expert, + const bool renormalize) { using cub_kvp = cub::KeyValuePair; @@ -125,6 +151,7 @@ __launch_bounds__(TPB) __global__ void moeTopK( const bool row_is_active = finished ? !finished[block_row] : true; const int thread_read_offset = blockIdx.x * num_experts; + float selected_sum = 0.f; for (int k_idx = 0; k_idx < k; ++k_idx) { thread_kvp.key = 0; @@ -163,9 +190,23 @@ __launch_bounds__(TPB) __global__ void moeTopK( indices[idx] = should_process_row ? (expert - start_expert) : num_experts; assert(indices[idx] >= 0); source_rows[idx] = k_idx * num_rows + block_row; + if (renormalize) { + selected_sum += result_kvp.value; + } } __syncthreads(); } + + // Renormalize the k weights for this row to sum to 1, if requested. + if (renormalize) { + if (threadIdx.x == 0) { + const float denom = selected_sum > 0.f ? selected_sum : 1.f; + for (int k_idx = 0; k_idx < k; ++k_idx) { + const int idx = k * block_row + k_idx; + output[idx] = output[idx] / denom; + } + } + } } // ====================== TopK softmax things =============================== @@ -184,21 +225,30 @@ __launch_bounds__(TPB) __global__ void moeTopK( 2) This implementation assumes k is small, but will work for any k. */ -template +template __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__ - void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, IndType* indices, - int* source_rows, const int k, const int start_expert, const int end_expert) + void topkGatingSoftmax(const InputType* input, const bool* finished, float* output, const int num_rows, IndType* indices, + int* source_rows, const int k, const int start_expert, const int end_expert, const bool renormalize) { + static_assert(std::is_same_v || std::is_same_v || + std::is_same_v, + "InputType must be float, __nv_bfloat16, or __half"); + // We begin by enforcing compile time assertions and setting up compile time constants. static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2"); static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16"); // Number of bytes each thread pulls in per load - static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float); + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(InputType); static constexpr int ELTS_PER_ROW = NUM_EXPERTS; static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT; static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG; + if constexpr (std::is_same_v || std::is_same_v) { + static_assert(ELTS_PER_LDG == 1 || ELTS_PER_LDG % 2 == 0, + "ELTS_PER_LDG must be 1 or even for 16-bit conversion"); + } + // Restrictions based on previous section. static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg"); static_assert(WARP_SIZE_PARAM % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp"); @@ -236,27 +286,71 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__ // We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the // row it will read. - const float* thread_row_ptr = input + thread_row * ELTS_PER_ROW; + const InputType* thread_row_ptr = input + thread_row * ELTS_PER_ROW; // Now, we compute the group each thread belong to in order to determine the first column to start loads. const int thread_group_idx = threadIdx.x % THREADS_PER_ROW; const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG; - const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread; - - // Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory, - // this can support all powers of 2 up to 16. - // NOTE(woosuk): The original implementation uses CUTLASS aligned array here. - // We defined our own aligned array and use it here to avoid the dependency on CUTLASS. - using AccessType = AlignedArray; + const InputType* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread; // Finally, we pull in the data from global mem float row_chunk[VPT]; - AccessType* row_chunk_vec_ptr = reinterpret_cast(&row_chunk); - const AccessType* vec_thread_read_ptr = reinterpret_cast(thread_read_ptr); + + // NOTE(zhuhaoran): dispatch different input types loading, BF16/FP16 convert to float + if constexpr (std::is_same_v) { + using VecType = AlignedArray; + VecType* row_chunk_vec_ptr = reinterpret_cast(&row_chunk); + const VecType* vec_thread_read_ptr = reinterpret_cast(thread_read_ptr); #pragma unroll - for (int ii = 0; ii < LDG_PER_THREAD; ++ii) - { - row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW]; + for (int ii = 0; ii < LDG_PER_THREAD; ++ii) { + row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW]; + } + } else if constexpr (std::is_same_v) { + if constexpr (ELTS_PER_LDG >= 2) { + using VecType = AlignedArray<__nv_bfloat16, ELTS_PER_LDG>; + float2* row_chunk_f2 = reinterpret_cast(row_chunk); + const VecType* vec_thread_read_ptr = reinterpret_cast(thread_read_ptr); +#pragma unroll + for (int ii = 0; ii < LDG_PER_THREAD; ++ii) { + VecType vec = vec_thread_read_ptr[ii * THREADS_PER_ROW]; + int base_idx_f2 = ii * ELTS_PER_LDG / 2; +#pragma unroll + for (int jj = 0; jj < ELTS_PER_LDG / 2; ++jj) { + row_chunk_f2[base_idx_f2 + jj] = __bfloat1622float2( + *reinterpret_cast(vec.data + jj * 2) + ); + } + } + } else { // ELTS_PER_LDG == 1 +#pragma unroll + for (int ii = 0; ii < LDG_PER_THREAD; ++ii) { + const __nv_bfloat16* scalar_ptr = thread_read_ptr + ii * THREADS_PER_ROW; + row_chunk[ii] = __bfloat162float(*scalar_ptr); + } + } + } else if constexpr (std::is_same_v) { + if constexpr (ELTS_PER_LDG >= 2) { + using VecType = AlignedArray<__half, ELTS_PER_LDG>; + float2* row_chunk_f2 = reinterpret_cast(row_chunk); + const VecType* vec_thread_read_ptr = reinterpret_cast(thread_read_ptr); +#pragma unroll + for (int ii = 0; ii < LDG_PER_THREAD; ++ii) { + VecType vec = vec_thread_read_ptr[ii * THREADS_PER_ROW]; + int base_idx_f2 = ii * ELTS_PER_LDG / 2; +#pragma unroll + for (int jj = 0; jj < ELTS_PER_LDG / 2; ++jj) { + row_chunk_f2[base_idx_f2 + jj] = __half22float2( + *reinterpret_cast(vec.data + jj * 2) + ); + } + } + } else { // ELTS_PER_LDG == 1 +#pragma unroll + for (int ii = 0; ii < LDG_PER_THREAD; ++ii) { + const __half* scalar_ptr = thread_read_ptr + ii * THREADS_PER_ROW; + row_chunk[ii] = __half2float(*scalar_ptr); + } + } } // First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just @@ -310,6 +404,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__ int start_col = first_elt_read_by_thread; static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW; + float selected_sum = 0.f; for (int k_idx = 0; k_idx < k; ++k_idx) { // First, each thread does the local argmax @@ -363,6 +458,9 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__ output[idx] = max_val; indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS; source_rows[idx] = k_idx * num_rows + thread_row; + if (renormalize) { + selected_sum += max_val; + } } // Finally, we clear the value in the thread with the current max if there is another iteration to run. @@ -380,15 +478,28 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__ } } } + + // Renormalize the k weights for this row to sum to 1, if requested. + if (renormalize) { + if (thread_group_idx == 0) + { + const float denom = selected_sum > 0.f ? selected_sum : 1.f; + for (int k_idx = 0; k_idx < k; ++k_idx) + { + const int idx = k * thread_row + k_idx; + output[idx] = output[idx] / denom; + } + } + } } namespace detail { // Constructs some constants needed to partition the work across threads at compile time. -template +template struct TopkConstants { - static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float); + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(InputType); static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0, ""); static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM)); static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG; @@ -397,20 +508,21 @@ struct TopkConstants }; } // namespace detail -template -void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices, - int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream) +template +void topkGatingSoftmaxLauncherHelper(const InputType* input, const bool* finished, float* output, IndType* indices, + int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, const bool renormalize, + cudaStream_t stream) { - static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS); - using Constants = detail::TopkConstants; + static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(InputType) * EXPERTS); + using Constants = detail::TopkConstants; static constexpr int VPT = Constants::VPT; static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB; dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB); - topkGatingSoftmax<<>>( - input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert); + topkGatingSoftmax<<>>( + input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert, renormalize); } #ifndef USE_ROCM @@ -418,26 +530,26 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f static_assert(WARP_SIZE == 32, \ "Unsupported warp size. Only 32 is supported for CUDA"); \ topkGatingSoftmaxLauncherHelper( \ - gating_output, nullptr, topk_weights, topk_indices, \ - token_expert_indices, num_tokens, topk, 0, num_experts, stream); + gating_output, nullptr, topk_weights, topk_indices, token_expert_indices, \ + num_tokens, topk, 0, num_experts, renormalize, stream); #else #define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \ if (WARP_SIZE == 64) { \ topkGatingSoftmaxLauncherHelper( \ - gating_output, nullptr, topk_weights, topk_indices, \ - token_expert_indices, num_tokens, topk, 0, num_experts, stream); \ + gating_output, nullptr, topk_weights, topk_indices, token_expert_indices, \ + num_tokens, topk, 0, num_experts, renormalize, stream); \ } else if (WARP_SIZE == 32) { \ topkGatingSoftmaxLauncherHelper( \ - gating_output, nullptr, topk_weights, topk_indices, \ - token_expert_indices, num_tokens, topk, 0, num_experts, stream); \ + gating_output, nullptr, topk_weights, topk_indices, token_expert_indices, \ + num_tokens, topk, 0, num_experts, renormalize, stream); \ } else { \ assert(false && "Unsupported warp size. Only 32 and 64 are supported for ROCm"); \ } #endif -template +template void topkGatingSoftmaxKernelLauncher( - const float* gating_output, + const InputType* gating_output, float* topk_weights, IndType* topk_indices, int* token_expert_indices, @@ -445,11 +557,15 @@ void topkGatingSoftmaxKernelLauncher( const int num_tokens, const int num_experts, const int topk, + const bool renormalize, cudaStream_t stream) { static constexpr int WARPS_PER_TB = 4; static constexpr int BYTES_PER_LDG_POWER_OF_2 = 16; #ifndef USE_ROCM - static constexpr int BYTES_PER_LDG_MULTIPLE_64 = 8; + // for bfloat16 dtype, we need 4 bytes loading to make sure num_experts + // elements can be loaded by a warp + static constexpr int BYTES_PER_LDG_MULTIPLE_64 = + (std::is_same_v || std::is_same_v) ? 4 : 8; #endif switch (num_experts) { case 1: @@ -506,11 +622,11 @@ void topkGatingSoftmaxKernelLauncher( TORCH_CHECK(softmax_workspace != nullptr, "softmax_workspace must be provided for num_experts that are not a power of 2 or multiple of 64."); static constexpr int TPB = 256; - moeSoftmax<<>>( + moeSoftmax<<>>( gating_output, nullptr, softmax_workspace, num_experts); moeTopK<<>>( softmax_workspace, nullptr, topk_weights, topk_indices, token_expert_indices, - num_experts, topk, 0, num_experts); + num_experts, topk, 0, num_experts, renormalize); } } } @@ -518,11 +634,50 @@ void topkGatingSoftmaxKernelLauncher( } // namespace moe } // namespace vllm + +template +void dispatch_topk_softmax_launch( + torch::Tensor& gating_output, + torch::Tensor& topk_weights, + torch::Tensor& topk_indices, + torch::Tensor& token_expert_indices, + torch::Tensor& softmax_workspace, + int num_tokens, int num_experts, int topk, bool renormalize, cudaStream_t stream) +{ + if (topk_indices.scalar_type() == at::ScalarType::Int) { + vllm::moe::topkGatingSoftmaxKernelLauncher( + reinterpret_cast(gating_output.data_ptr()), + topk_weights.data_ptr(), + topk_indices.data_ptr(), + token_expert_indices.data_ptr(), + softmax_workspace.data_ptr(), + num_tokens, num_experts, topk, renormalize, stream); + } else if (topk_indices.scalar_type() == at::ScalarType::UInt32) { + vllm::moe::topkGatingSoftmaxKernelLauncher( + reinterpret_cast(gating_output.data_ptr()), + topk_weights.data_ptr(), + topk_indices.data_ptr(), + token_expert_indices.data_ptr(), + softmax_workspace.data_ptr(), + num_tokens, num_experts, topk, renormalize, stream); + } else { + TORCH_CHECK(topk_indices.scalar_type() == at::ScalarType::Long); + vllm::moe::topkGatingSoftmaxKernelLauncher( + reinterpret_cast(gating_output.data_ptr()), + topk_weights.data_ptr(), + topk_indices.data_ptr(), + token_expert_indices.data_ptr(), + softmax_workspace.data_ptr(), + num_tokens, num_experts, topk, renormalize, stream); + } +} + void topk_softmax( torch::Tensor& topk_weights, // [num_tokens, topk] torch::Tensor& topk_indices, // [num_tokens, topk] torch::Tensor& token_expert_indices, // [num_tokens, topk] - torch::Tensor& gating_output) // [num_tokens, num_experts] + torch::Tensor& gating_output, // [num_tokens, num_experts] + bool renormalize) { const int num_experts = gating_output.size(-1); const auto num_tokens = gating_output.numel() / num_experts; @@ -534,45 +689,19 @@ void topk_softmax( const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options()); + const auto workspace_options = gating_output.options().dtype(at::ScalarType::Float); + torch::Tensor softmax_workspace = torch::empty({workspace_size}, workspace_options); - if(topk_indices.scalar_type() == at::ScalarType::Int) - { - vllm::moe::topkGatingSoftmaxKernelLauncher( - gating_output.data_ptr(), - topk_weights.data_ptr(), - topk_indices.data_ptr(), - token_expert_indices.data_ptr(), - softmax_workspace.data_ptr(), - num_tokens, - num_experts, - topk, - stream); - } - else if (topk_indices.scalar_type() == at::ScalarType::UInt32) - { - vllm::moe::topkGatingSoftmaxKernelLauncher( - gating_output.data_ptr(), - topk_weights.data_ptr(), - topk_indices.data_ptr(), - token_expert_indices.data_ptr(), - softmax_workspace.data_ptr(), - num_tokens, - num_experts, - topk, - stream); - } - else { - TORCH_CHECK(topk_indices.scalar_type() == at::ScalarType::Long); - vllm::moe::topkGatingSoftmaxKernelLauncher( - gating_output.data_ptr(), - topk_weights.data_ptr(), - topk_indices.data_ptr(), - token_expert_indices.data_ptr(), - softmax_workspace.data_ptr(), - num_tokens, - num_experts, - topk, - stream); + if (gating_output.scalar_type() == at::ScalarType::Float) { + dispatch_topk_softmax_launch(gating_output, topk_weights, topk_indices, + token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, stream); + } else if (gating_output.scalar_type() == at::ScalarType::Half) { + dispatch_topk_softmax_launch<__half>(gating_output, topk_weights, topk_indices, + token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, stream); + } else if (gating_output.scalar_type() == at::ScalarType::BFloat16) { + dispatch_topk_softmax_launch<__nv_bfloat16>(gating_output, topk_weights, topk_indices, + token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, stream); + } else { + TORCH_CHECK(false, "Unsupported gating_output data type: ", gating_output.scalar_type()); } } diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 2c0a515ef..8377575ea 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -5,7 +5,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { // Apply topk softmax to the gating outputs. m.def( "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! " - "token_expert_indices, Tensor gating_output) -> ()"); + "token_expert_indices, Tensor gating_output, bool renormalize) -> ()"); m.impl("topk_softmax", torch::kCUDA, &topk_softmax); // Calculate the result of moe by summing up the partial results diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index dbbfc01e3..4cfc35448 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1850,9 +1850,10 @@ def topk_softmax( topk_ids: torch.Tensor, token_expert_indices: torch.Tensor, gating_output: torch.Tensor, + renormalize: bool = False, ) -> None: torch.ops._moe_C.topk_softmax( - topk_weights, topk_ids, token_expert_indices, gating_output + topk_weights, topk_ids, token_expert_indices, gating_output, renormalize ) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 69e32438e..d9007d50e 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1074,9 +1074,8 @@ def vllm_topk_softmax( topk_indices, token_expert_indices, gating_output, + renormalize, ) - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) return topk_weights, topk_indices @@ -1113,11 +1112,9 @@ def fused_topk( M, topk, dtype=torch.int32, device=hidden_states.device ) - gating_output_float = gating_output.float() # TODO(woosuk): Optimize this. - topk_func = dispatch_topk_func() topk_weights, topk_ids = topk_func( - topk_weights, topk_ids, token_expert_indices, gating_output_float, renormalize + topk_weights, topk_ids, token_expert_indices, gating_output, renormalize ) return topk_weights, topk_ids, token_expert_indices