diff --git a/benchmarks/kernels/benchmark_fused_topk.py b/benchmarks/kernels/benchmark_fused_topk.py new file mode 100644 index 000000000..72bf2d97c --- /dev/null +++ b/benchmarks/kernels/benchmark_fused_topk.py @@ -0,0 +1,99 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import itertools + +import torch + +from vllm.model_executor.layers.fused_moe.router.fused_topk_router import fused_topk +from vllm.triton_utils import triton +from vllm.utils.argparse_utils import FlexibleArgumentParser + +num_tokens_range = [2**i for i in range(0, 8, 2)] +num_experts_range = [16, 32, 64, 128, 256, 512] +topk_range = [3, 4] +configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range)) + + +def torch_topk( + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + scoring_func: str = "softmax", +): + if scoring_func == "softmax": + scores = torch.softmax(gating_output.float(), dim=-1) + else: + scores = torch.sigmoid(gating_output.float()) + topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + return topk_weights, topk_ids + + +def get_benchmark(scoring_func): + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["num_tokens", "num_experts", "topk"], + x_vals=[list(_) for _ in configs], + line_arg="provider", + line_vals=["torch", "vllm"], + line_names=["Torch", "vLLM"], + styles=[("blue", "-"), ("red", "-")], + ylabel="us", + plot_name=f"fused-topk-perf-{scoring_func}", + args={}, + ) + ) + def benchmark(num_tokens, num_experts, topk, provider): + dtype = torch.bfloat16 + hidden_size = 1024 + renormalize = True + hidden_states = torch.randn( + (num_tokens, hidden_size), dtype=dtype, device="cuda" + ) + gating_output = torch.randn( + (num_tokens, num_experts), dtype=dtype, device="cuda" + ) + + quantiles = [0.5, 0.2, 0.8] + + if provider == "torch": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: torch_topk( + gating_output=gating_output, + topk=topk, + renormalize=renormalize, + scoring_func=scoring_func, + ), + quantiles=quantiles, + ) + else: + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: fused_topk( + hidden_states=hidden_states, + gating_output=gating_output, + topk=topk, + renormalize=renormalize, + scoring_func=scoring_func, + ), + quantiles=quantiles, + ) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + return benchmark + + +if __name__ == "__main__": + parser = FlexibleArgumentParser(description="Benchmark the MoE topk kernel.") + parser.add_argument("--scoring-func", type=str, default="softmax") + parser.add_argument("--save-path", type=str, default="./configs/fused_topk/") + args = parser.parse_args() + + # Get the benchmark function + benchmark = get_benchmark(args.scoring_func) + # Run performance benchmark + benchmark.run(print_data=True, save_path=args.save_path) diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index 337dcc50b..89d54c47d 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -4,7 +4,13 @@ void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices, torch::Tensor& token_expert_indices, - torch::Tensor& gating_output, bool renormalize); + torch::Tensor& gating_output, bool renormalize, + std::optional bias); + +void topk_sigmoid(torch::Tensor& topk_weights, torch::Tensor& topk_indices, + torch::Tensor& token_expert_indices, + torch::Tensor& gating_output, bool renormalize, + std::optional bias); void moe_sum(torch::Tensor& input, torch::Tensor& output); diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu index af6e6fcd4..833036da5 100644 --- a/csrc/moe/topk_softmax_kernels.cu +++ b/csrc/moe/topk_softmax_kernels.cu @@ -62,6 +62,12 @@ __device__ __forceinline__ float toFloat(T value) { } } +// Scoring function enums +enum ScoringFunc { + SCORING_SOFTMAX = 0, // apply softmax + SCORING_SIGMOID = 1 // apply sigmoid +}; + // ====================== Softmax things =============================== // We have our own implementation of softmax here so we can support transposing the output // in the softmax kernel when we extend this module to support expert-choice routing. @@ -125,6 +131,27 @@ __launch_bounds__(TPB) __global__ } } +template +__launch_bounds__(TPB) __global__ + void moeSigmoid(const InputType* input, const bool* finished, float* output, const int num_cols) +{ + const int thread_row_offset = blockIdx.x * num_cols; + + // Don't touch finished rows. + if ((finished != nullptr) && finished[blockIdx.x]) + { + return; + } + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) + { + const int idx = thread_row_offset + ii; + const float val = toFloat(input[idx]); + const float sigmoid_val = 1.0f / (1.0f + __expf(-val)); + output[idx] = sigmoid_val; + } +} + template __launch_bounds__(TPB) __global__ void moeTopK( const float* inputs_after_softmax, @@ -136,7 +163,8 @@ __launch_bounds__(TPB) __global__ void moeTopK( const int k, const int start_expert, const int end_expert, - const bool renormalize) + const bool renormalize, + const float* bias) { using cub_kvp = cub::KeyValuePair; @@ -162,7 +190,13 @@ __launch_bounds__(TPB) __global__ void moeTopK( { const int idx = thread_read_offset + expert; inp_kvp.key = expert; - inp_kvp.value = inputs_after_softmax[idx]; + + // Apply correction bias if provided + if (bias != nullptr) { + inp_kvp.value = inputs_after_softmax[idx] + bias[expert]; + } else { + inp_kvp.value = inputs_after_softmax[idx]; + } for (int prior_k = 0; prior_k < k_idx; ++prior_k) { @@ -186,12 +220,13 @@ __launch_bounds__(TPB) __global__ void moeTopK( const bool should_process_row = row_is_active && node_uses_expert; const int idx = k * block_row + k_idx; - output[idx] = result_kvp.value; + // Return the unbiased scores for output weights + output[idx] = inputs_after_softmax[thread_read_offset + expert]; indices[idx] = should_process_row ? (expert - start_expert) : num_experts; assert(indices[idx] >= 0); source_rows[idx] = k_idx * num_rows + block_row; if (renormalize) { - selected_sum += result_kvp.value; + selected_sum += inputs_after_softmax[thread_read_offset + expert]; } } __syncthreads(); @@ -225,10 +260,12 @@ __launch_bounds__(TPB) __global__ void moeTopK( 2) This implementation assumes k is small, but will work for any k. */ -template +template __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__ - void topkGatingSoftmax(const InputType* input, const bool* finished, float* output, const int num_rows, IndType* indices, - int* source_rows, const int k, const int start_expert, const int end_expert, const bool renormalize) + void topkGating(const InputType* input, const bool* finished, float* output, const int num_rows, IndType* indices, + int* source_rows, const int k, const int start_expert, const int end_expert, const bool renormalize, + const float* bias) { static_assert(std::is_same_v || std::is_same_v || std::is_same_v, @@ -353,61 +390,89 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__ } } - // First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just - // convert to float afterwards for the exp + sum reduction. - float thread_max = row_chunk[0]; + if constexpr (SF == SCORING_SOFTMAX) { + // First, we perform a max reduce within the thread. + float thread_max = row_chunk[0]; #pragma unroll - for (int ii = 1; ii < VPT; ++ii) - { + for (int ii = 1; ii < VPT; ++ii) { thread_max = max(thread_max, row_chunk[ii]); - } + } // Now, we find the max within the thread group and distribute among the threads. We use a butterfly reduce. #pragma unroll - for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) - { + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) + { thread_max = max(thread_max, VLLM_SHFL_XOR_SYNC_WIDTH(thread_max, mask, THREADS_PER_ROW)); - } + } - // From this point, thread max in all the threads have the max within the row. - // Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum. - float row_sum = 0; + // From this point, thread max in all the threads have the max within the row. + // Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum. + float row_sum = 0; #pragma unroll - for (int ii = 0; ii < VPT; ++ii) - { + for (int ii = 0; ii < VPT; ++ii) + { row_chunk[ii] = expf(row_chunk[ii] - thread_max); row_sum += row_chunk[ii]; - } + } // Now, we perform the sum reduce within each thread group. Similar to the max reduce, we use a bufferfly pattern. #pragma unroll - for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) - { + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) + { row_sum += VLLM_SHFL_XOR_SYNC_WIDTH(row_sum, mask, THREADS_PER_ROW); - } + } - // From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables - // respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to - // compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row. - // However, this kernel will likely not be a bottle neck and it seems better to closer match torch and find the - // argmax after computing the softmax. - const float reciprocal_row_sum = 1.f / row_sum; + // From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables + // respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to + // compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row. + // However, this kernel will likely not be a bottle neck and it seems better to closer match torch and find the + // argmax after computing the softmax. + const float reciprocal_row_sum = 1.f / row_sum; #pragma unroll - for (int ii = 0; ii < VPT; ++ii) - { + for (int ii = 0; ii < VPT; ++ii) + { row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum; + } + } else if constexpr (SF == SCORING_SIGMOID) { +#pragma unroll + for (int ii = 0; ii < VPT; ++ii) + { + row_chunk[ii] = 1.0f / (1.0f + __expf(-row_chunk[ii])); + } } - // Now, softmax_res contains the softmax of the row chunk. Now, I want to find the topk elements in each row, along + static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW; + + // If bias is not null, use biased value for selection + float row_chunk_for_choice[VPT]; + // Apply correction bias + if (bias != nullptr) { +#pragma unroll + for (int ldg = 0; ldg < LDG_PER_THREAD; ++ldg) { +#pragma unroll + for (int ii = 0; ii < ELTS_PER_LDG; ++ii) { + const int expert = first_elt_read_by_thread + ldg * COLS_PER_GROUP_LDG + ii; + float bias_val = expert < NUM_EXPERTS ? bias[expert] : 0.0f; + row_chunk_for_choice[ldg * ELTS_PER_LDG + ii] = row_chunk[ldg * ELTS_PER_LDG + ii] + bias_val; + } + } + } else { +#pragma unroll + for (int ii = 0; ii < VPT; ++ii) { + row_chunk_for_choice[ii] = row_chunk[ii]; + } + } + + // Now, row_chunk contains the softmax / sigmoid of the row chunk. Now, I want to find the topk elements in each row, along // with the max index. int start_col = first_elt_read_by_thread; - static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW; float selected_sum = 0.f; for (int k_idx = 0; k_idx < k; ++k_idx) { // First, each thread does the local argmax + float max_val_for_choice = row_chunk_for_choice[0]; float max_val = row_chunk[0]; int expert = start_col; #pragma unroll @@ -416,12 +481,14 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__ #pragma unroll for (int ii = 0; ii < ELTS_PER_LDG; ++ii) { + float val_for_choice = row_chunk_for_choice[ldg * ELTS_PER_LDG + ii]; float val = row_chunk[ldg * ELTS_PER_LDG + ii]; // No check on the experts here since columns with the smallest index are processed first and only // updated if > (not >=) - if (val > max_val) + if (val_for_choice > max_val_for_choice) { + max_val_for_choice = val_for_choice; max_val = val; expert = col + ii; } @@ -434,12 +501,14 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__ #pragma unroll for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { + float other_max_for_choice = VLLM_SHFL_XOR_SYNC_WIDTH(max_val_for_choice, mask, THREADS_PER_ROW); float other_max = VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW); int other_expert = VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW); // We want lower indices to "win" in every thread so we break ties this way - if (other_max > max_val || (other_max == max_val && other_expert < expert)) + if (other_max_for_choice > max_val_for_choice || (other_max_for_choice == max_val_for_choice && other_expert < expert)) { + max_val_for_choice = other_max_for_choice; max_val = other_max; expert = other_expert; } @@ -474,7 +543,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__ { const int offset_for_expert = expert % ELTS_PER_LDG; // Safe to set to any negative value since row_chunk values must be between 0 and 1. - row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f; + row_chunk_for_choice[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f; } } } @@ -508,10 +577,10 @@ struct TopkConstants }; } // namespace detail -template -void topkGatingSoftmaxLauncherHelper(const InputType* input, const bool* finished, float* output, IndType* indices, +template +void topkGatingLauncherHelper(const InputType* input, const bool* finished, float* output, IndType* indices, int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, const bool renormalize, - cudaStream_t stream) + const float* bias, cudaStream_t stream) { static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(InputType) * EXPERTS); using Constants = detail::TopkConstants; @@ -521,43 +590,51 @@ void topkGatingSoftmaxLauncherHelper(const InputType* input, const bool* finishe const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB; dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB); - topkGatingSoftmax<<>>( - input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert, renormalize); + topkGating<<>>( + input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert, renormalize, bias); } #ifndef USE_ROCM -#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \ - static_assert(WARP_SIZE == 32, \ - "Unsupported warp size. Only 32 is supported for CUDA"); \ - topkGatingSoftmaxLauncherHelper( \ - gating_output, nullptr, topk_weights, topk_indices, token_expert_indices, \ - num_tokens, topk, 0, num_experts, renormalize, stream); + #define LAUNCH_TOPK(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \ + static_assert(WARP_SIZE == 32, \ + "Unsupported warp size. Only 32 is supported for CUDA"); \ + topkGatingLauncherHelper( \ + gating_output, nullptr, topk_weights, topk_indices, \ + token_expert_indices, num_tokens, topk, 0, num_experts, renormalize, \ + bias, stream); #else -#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \ - if (WARP_SIZE == 64) { \ - topkGatingSoftmaxLauncherHelper( \ - gating_output, nullptr, topk_weights, topk_indices, token_expert_indices, \ - num_tokens, topk, 0, num_experts, renormalize, stream); \ - } else if (WARP_SIZE == 32) { \ - topkGatingSoftmaxLauncherHelper( \ - gating_output, nullptr, topk_weights, topk_indices, token_expert_indices, \ - num_tokens, topk, 0, num_experts, renormalize, stream); \ - } else { \ - assert(false && "Unsupported warp size. Only 32 and 64 are supported for ROCm"); \ + #define LAUNCH_TOPK(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \ + if (WARP_SIZE == 64) { \ + topkGatingLauncherHelper( \ + gating_output, nullptr, topk_weights, topk_indices, \ + token_expert_indices, num_tokens, topk, 0, num_experts, renormalize, \ + bias, stream); \ + } else if (WARP_SIZE == 32) { \ + topkGatingLauncherHelper( \ + gating_output, nullptr, topk_weights, topk_indices, \ + token_expert_indices, num_tokens, topk, 0, num_experts, renormalize, \ + bias, stream); \ + } else { \ + assert(false && \ + "Unsupported warp size. Only 32 and 64 are supported for ROCm"); \ } #endif -template -void topkGatingSoftmaxKernelLauncher( +template +void topkGatingKernelLauncher( const InputType* gating_output, float* topk_weights, IndType* topk_indices, int* token_expert_indices, - float* softmax_workspace, + float* workspace, const int num_tokens, const int num_experts, const int topk, const bool renormalize, + const float* bias, cudaStream_t stream) { static constexpr int WARPS_PER_TB = 4; static constexpr int BYTES_PER_LDG_POWER_OF_2 = 16; @@ -569,64 +646,71 @@ void topkGatingSoftmaxKernelLauncher( #endif switch (num_experts) { case 1: - LAUNCH_SOFTMAX(1, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); + LAUNCH_TOPK(1, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); break; case 2: - LAUNCH_SOFTMAX(2, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); + LAUNCH_TOPK(2, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); break; case 4: - LAUNCH_SOFTMAX(4, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); + LAUNCH_TOPK(4, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); break; case 8: - LAUNCH_SOFTMAX(8, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); + LAUNCH_TOPK(8, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); break; case 16: - LAUNCH_SOFTMAX(16, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); + LAUNCH_TOPK(16, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); break; case 32: - LAUNCH_SOFTMAX(32, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); + LAUNCH_TOPK(32, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); break; case 64: - LAUNCH_SOFTMAX(64, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); + LAUNCH_TOPK(64, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); break; case 128: - LAUNCH_SOFTMAX(128, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); + LAUNCH_TOPK(128, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); break; case 256: - LAUNCH_SOFTMAX(256, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); + LAUNCH_TOPK(256, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); break; case 512: - LAUNCH_SOFTMAX(512, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); + LAUNCH_TOPK(512, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2); break; // (CUDA only) support multiples of 64 when num_experts is not power of 2. // ROCm uses WARP_SIZE 64 so 8 bytes loading won't fit for some of num_experts, // alternatively we can test 4 bytes loading and enable it in future. #ifndef USE_ROCM case 192: - LAUNCH_SOFTMAX(192, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64); + LAUNCH_TOPK(192, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64); break; case 320: - LAUNCH_SOFTMAX(320, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64); + LAUNCH_TOPK(320, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64); break; case 384: - LAUNCH_SOFTMAX(384, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64); + LAUNCH_TOPK(384, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64); break; case 448: - LAUNCH_SOFTMAX(448, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64); + LAUNCH_TOPK(448, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64); break; case 576: - LAUNCH_SOFTMAX(576, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64); + LAUNCH_TOPK(576, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64); break; #endif default: { - TORCH_CHECK(softmax_workspace != nullptr, - "softmax_workspace must be provided for num_experts that are not a power of 2 or multiple of 64."); + TORCH_CHECK(workspace != nullptr, + "workspace must be provided for num_experts that are not a power of 2 or multiple of 64."); static constexpr int TPB = 256; - moeSoftmax<<>>( - gating_output, nullptr, softmax_workspace, num_experts); + if constexpr (SF == SCORING_SOFTMAX) { + moeSoftmax<<>>( + gating_output, nullptr, workspace, num_experts); + } else if constexpr (SF == SCORING_SIGMOID) { + moeSigmoid<<>>( + gating_output, nullptr, workspace, num_experts); + } else { + TORCH_CHECK(false, "Unsupported scoring func"); + } moeTopK<<>>( - softmax_workspace, nullptr, topk_weights, topk_indices, token_expert_indices, - num_experts, topk, 0, num_experts, renormalize); + workspace, nullptr, topk_weights, topk_indices, token_expert_indices, + num_experts, topk, 0, num_experts, renormalize, bias); } } } @@ -635,40 +719,55 @@ void topkGatingSoftmaxKernelLauncher( } // namespace vllm -template -void dispatch_topk_softmax_launch( +template +void dispatch_topk_launch( torch::Tensor& gating_output, torch::Tensor& topk_weights, torch::Tensor& topk_indices, torch::Tensor& token_expert_indices, torch::Tensor& softmax_workspace, - int num_tokens, int num_experts, int topk, bool renormalize, cudaStream_t stream) -{ + int num_tokens, int num_experts, int topk, bool renormalize, + std::optional bias, + cudaStream_t stream) + { + const float* bias_ptr = nullptr; + if (bias.has_value()) { + const torch::Tensor& bias_tensor = bias.value(); + TORCH_CHECK(bias_tensor.scalar_type() == at::ScalarType::Float, "bias tensor must be float32"); + TORCH_CHECK(bias_tensor.dim() == 1, "bias tensor must be 1D"); + TORCH_CHECK(bias_tensor.size(0) == num_experts, "bias size mismatch, expected: ", num_experts); + TORCH_CHECK(bias_tensor.is_contiguous(), "bias tensor must be contiguous"); + bias_ptr = bias_tensor.data_ptr(); + } + if (topk_indices.scalar_type() == at::ScalarType::Int) { - vllm::moe::topkGatingSoftmaxKernelLauncher( + vllm::moe::topkGatingKernelLauncher( reinterpret_cast(gating_output.data_ptr()), topk_weights.data_ptr(), topk_indices.data_ptr(), token_expert_indices.data_ptr(), softmax_workspace.data_ptr(), - num_tokens, num_experts, topk, renormalize, stream); + num_tokens, num_experts, topk, renormalize, + bias_ptr, stream); } else if (topk_indices.scalar_type() == at::ScalarType::UInt32) { - vllm::moe::topkGatingSoftmaxKernelLauncher( + vllm::moe::topkGatingKernelLauncher( reinterpret_cast(gating_output.data_ptr()), topk_weights.data_ptr(), topk_indices.data_ptr(), token_expert_indices.data_ptr(), softmax_workspace.data_ptr(), - num_tokens, num_experts, topk, renormalize, stream); + num_tokens, num_experts, topk, renormalize, + bias_ptr, stream); } else { TORCH_CHECK(topk_indices.scalar_type() == at::ScalarType::Long); - vllm::moe::topkGatingSoftmaxKernelLauncher( + vllm::moe::topkGatingKernelLauncher( reinterpret_cast(gating_output.data_ptr()), topk_weights.data_ptr(), topk_indices.data_ptr(), token_expert_indices.data_ptr(), softmax_workspace.data_ptr(), - num_tokens, num_experts, topk, renormalize, stream); + num_tokens, num_experts, topk, renormalize, + bias_ptr, stream); } } @@ -677,7 +776,8 @@ void topk_softmax( torch::Tensor& topk_indices, // [num_tokens, topk] torch::Tensor& token_expert_indices, // [num_tokens, topk] torch::Tensor& gating_output, // [num_tokens, num_experts] - bool renormalize) + bool renormalize, + std::optional bias) { const int num_experts = gating_output.size(-1); const auto num_tokens = gating_output.numel() / num_experts; @@ -693,14 +793,55 @@ void topk_softmax( torch::Tensor softmax_workspace = torch::empty({workspace_size}, workspace_options); if (gating_output.scalar_type() == at::ScalarType::Float) { - dispatch_topk_softmax_launch(gating_output, topk_weights, topk_indices, - token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, stream); + dispatch_topk_launch(gating_output, topk_weights, topk_indices, + token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, + bias, stream); } else if (gating_output.scalar_type() == at::ScalarType::Half) { - dispatch_topk_softmax_launch<__half>(gating_output, topk_weights, topk_indices, - token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, stream); + dispatch_topk_launch<__half, vllm::moe::SCORING_SOFTMAX>(gating_output, topk_weights, topk_indices, + token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, + bias, stream); } else if (gating_output.scalar_type() == at::ScalarType::BFloat16) { - dispatch_topk_softmax_launch<__nv_bfloat16>(gating_output, topk_weights, topk_indices, - token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, stream); + dispatch_topk_launch<__nv_bfloat16, vllm::moe::SCORING_SOFTMAX>(gating_output, topk_weights, topk_indices, + token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, + bias, stream); + } else { + TORCH_CHECK(false, "Unsupported gating_output data type: ", gating_output.scalar_type()); + } +} + +void topk_sigmoid( + torch::Tensor& topk_weights, // [num_tokens, topk] + torch::Tensor& topk_indices, // [num_tokens, topk] + torch::Tensor& token_expert_indices, // [num_tokens, topk] + torch::Tensor& gating_output, // [num_tokens, num_experts] + bool renormalize, + std::optional bias) +{ + const int num_experts = gating_output.size(-1); + const auto num_tokens = gating_output.numel() / num_experts; + const int topk = topk_weights.size(-1); + + const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); + const bool needs_workspace = !is_pow_2 || num_experts > 256; + const int64_t workspace_size = needs_workspace ? num_tokens * num_experts : 0; + + const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + const auto workspace_options = gating_output.options().dtype(at::ScalarType::Float); + torch::Tensor workspace = torch::empty({workspace_size}, workspace_options); + + if (gating_output.scalar_type() == at::ScalarType::Float) { + dispatch_topk_launch(gating_output, topk_weights, topk_indices, + token_expert_indices, workspace, num_tokens, num_experts, topk, renormalize, + bias, stream); + } else if (gating_output.scalar_type() == at::ScalarType::Half) { + dispatch_topk_launch<__half, vllm::moe::SCORING_SIGMOID>(gating_output, topk_weights, topk_indices, + token_expert_indices, workspace, num_tokens, num_experts, topk, renormalize, + bias, stream); + } else if (gating_output.scalar_type() == at::ScalarType::BFloat16) { + dispatch_topk_launch<__nv_bfloat16, vllm::moe::SCORING_SIGMOID>(gating_output, topk_weights, topk_indices, + token_expert_indices, workspace, num_tokens, num_experts, topk, renormalize, + bias, stream); } else { TORCH_CHECK(false, "Unsupported gating_output data type: ", gating_output.scalar_type()); } diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 7d44db21d..f8cfe058f 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -5,9 +5,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { // Apply topk softmax to the gating outputs. m.def( "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! " - "token_expert_indices, Tensor gating_output, bool renormalize) -> ()"); + "token_expert_indices, Tensor gating_output, bool renormalize, Tensor? " + "bias) -> ()"); m.impl("topk_softmax", torch::kCUDA, &topk_softmax); + // Apply topk sigmoid to the gating outputs. + m.def( + "topk_sigmoid(Tensor! topk_weights, Tensor! topk_indices, Tensor! " + "token_expert_indices, Tensor gating_output, bool renormalize, Tensor? " + "bias) -> ()"); + m.impl("topk_sigmoid", torch::kCUDA, &topk_sigmoid); + // Calculate the result of moe by summing up the partial results // from all selected experts. m.def("moe_sum(Tensor input, Tensor! output) -> ()"); diff --git a/tests/kernels/moe/test_fused_topk.py b/tests/kernels/moe/test_fused_topk.py new file mode 100644 index 000000000..5384d8964 --- /dev/null +++ b/tests/kernels/moe/test_fused_topk.py @@ -0,0 +1,137 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for the MoE fused topk kernel + +Run `pytest tests/kernels/moe/test_fused_topk.py`. +""" + +import pytest +import torch + +from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import ( + fused_topk_bias, +) +from vllm.model_executor.layers.fused_moe.router.fused_topk_router import fused_topk +from vllm.platforms import current_platform + + +def torch_topk( + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + e_score_correction_bias: torch.Tensor = None, + scoring_func: str = "softmax", +): + if scoring_func == "softmax": + scores = torch.softmax(gating_output.float(), dim=-1) + else: + assert scoring_func == "sigmoid" + scores = torch.sigmoid(gating_output.float()) + + if e_score_correction_bias is not None: + num_experts = gating_output.shape[-1] + scores_for_choice = scores.view( + -1, num_experts + ) + e_score_correction_bias.unsqueeze(0) + _, topk_ids = torch.topk(scores_for_choice, k=topk, dim=-1) + topk_weights = scores.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + return topk_weights, topk_ids + + +@pytest.mark.skipif( + not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform." +) +@pytest.mark.parametrize("num_tokens", [1, 33, 56]) +@pytest.mark.parametrize("hidden_size", [1024, 2048]) +@pytest.mark.parametrize("num_experts", [6, 16]) +@pytest.mark.parametrize("topk", [3, 4]) +@pytest.mark.parametrize("renormalize", [True, False]) +@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half, torch.float32]) +def test_fused_topk( + num_tokens: int, + hidden_size: int, + num_experts: int, + topk: int, + renormalize: bool, + scoring_func: str, + dtype: torch.dtype, +): + torch.manual_seed(0) + hidden_states = torch.randn((num_tokens, hidden_size), dtype=dtype, device="cuda") + gating_output = torch.randn((num_tokens, num_experts), dtype=dtype, device="cuda") + + topk_weights_ref, topk_ids_ref = torch_topk( + gating_output=gating_output, + topk=topk, + renormalize=renormalize, + scoring_func=scoring_func, + ) + + topk_weights, topk_ids, _ = fused_topk( + hidden_states=hidden_states, + gating_output=gating_output, + topk=topk, + renormalize=renormalize, + scoring_func=scoring_func, + ) + + torch.testing.assert_close( + topk_weights_ref.to(torch.float32), topk_weights, atol=1e-2, rtol=1e-2 + ) + torch.testing.assert_close(topk_ids_ref.to(torch.int32), topk_ids, atol=0, rtol=0) + + +@pytest.mark.skipif( + not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform." +) +@pytest.mark.parametrize("num_tokens", [1, 33, 56]) +@pytest.mark.parametrize("hidden_size", [1024, 2048]) +@pytest.mark.parametrize("num_experts", [6, 16]) +@pytest.mark.parametrize("topk", [3, 4]) +@pytest.mark.parametrize("renormalize", [True, False]) +@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half, torch.float32]) +def test_fused_topk_bias( + num_tokens: int, + hidden_size: int, + num_experts: int, + topk: int, + renormalize: bool, + scoring_func: str, + dtype: torch.dtype, +): + torch.manual_seed(0) + hidden_states = torch.randn((num_tokens, hidden_size), dtype=dtype, device="cuda") + gating_output = torch.randn((num_tokens, num_experts), dtype=dtype, device="cuda") + e_score_correction_bias = torch.randn( + (num_experts,), dtype=torch.float32, device="cuda" + ) + + topk_weights_ref, topk_ids_ref = torch_topk( + gating_output=gating_output, + topk=topk, + renormalize=renormalize, + e_score_correction_bias=e_score_correction_bias, + scoring_func=scoring_func, + ) + + topk_weights, topk_ids = fused_topk_bias( + hidden_states=hidden_states, + gating_output=gating_output, + e_score_correction_bias=e_score_correction_bias, + topk=topk, + renormalize=renormalize, + scoring_func=scoring_func, + ) + + torch.testing.assert_close( + topk_weights_ref.to(torch.float32), topk_weights, atol=1e-2, rtol=1e-2 + ) + torch.testing.assert_close(topk_ids_ref.to(torch.int32), topk_ids, atol=0, rtol=0) diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index 316caf06b..36d7f5cc4 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -18,7 +18,9 @@ from vllm.model_executor.layers.activation import ( SiluAndMul, ) from vllm.model_executor.layers.fused_moe.router.fused_topk_router import ( - dispatch_topk_func, + dispatch_topk_sigmoid_func, + dispatch_topk_softmax_func, + vllm_topk_sigmoid, vllm_topk_softmax, ) from vllm.model_executor.layers.layernorm import ( @@ -133,8 +135,8 @@ def test_enabled_ops_invalid(env: str): @pytest.mark.parametrize( "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False] ) -def test_topk_dispatch(use_rocm_aiter: bool): - topk_func = dispatch_topk_func(use_rocm_aiter) +def test_topk_softmax_dispatch(use_rocm_aiter: bool): + topk_func = dispatch_topk_softmax_func(use_rocm_aiter) if current_platform.is_rocm() and use_rocm_aiter: assert topk_func == rocm_aiter_ops.topk_softmax @@ -142,6 +144,18 @@ def test_topk_dispatch(use_rocm_aiter: bool): assert topk_func == vllm_topk_softmax +@pytest.mark.parametrize( + "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False] +) +def test_topk_sigmoid_dispatch(use_rocm_aiter: bool): + topk_func = dispatch_topk_sigmoid_func(use_rocm_aiter) + + if current_platform.is_rocm() and use_rocm_aiter: + assert topk_func == rocm_aiter_ops.topk_sigmoid + else: + assert topk_func == vllm_topk_sigmoid + + @pytest.mark.parametrize("add_residual", [True, False]) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("use_rocm_aiter", [True, False]) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 3e232d619..bad6c739c 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -200,6 +200,24 @@ def _rocm_aiter_topk_softmax_fake( pass +def _rocm_aiter_topk_sigmoid_impl( + topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + gating_output: torch.Tensor, +) -> None: + from aiter import topk_sigmoid + + topk_sigmoid(topk_weights, topk_indices, gating_output) + + +def _rocm_aiter_topk_sigmoid_fake( + topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + gating_output: torch.Tensor, +) -> None: + pass + + def _rocm_aiter_biased_grouped_topk_impl( gating_output: torch.Tensor, correction_bias: torch.Tensor, @@ -985,6 +1003,14 @@ class rocm_aiter_ops: dispatch_key=current_platform.dispatch_key, ) + direct_register_custom_op( + op_name="rocm_aiter_topk_sigmoid", + op_func=_rocm_aiter_topk_sigmoid_impl, + mutates_args=["topk_weights", "topk_indices"], + fake_impl=_rocm_aiter_topk_sigmoid_fake, + dispatch_key=current_platform.dispatch_key, + ) + direct_register_custom_op( op_name="rocm_aiter_biased_grouped_topk", op_func=_rocm_aiter_biased_grouped_topk_impl, @@ -1272,6 +1298,19 @@ class rocm_aiter_ops: ) return topk_weights, topk_indices + @staticmethod + def topk_sigmoid( + topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool, + ) -> tuple[torch.Tensor, ...]: + torch.ops.vllm.rocm_aiter_topk_sigmoid( + topk_weights, topk_indices, gating_output + ) + return topk_weights, topk_indices + @staticmethod def biased_grouped_topk( gating_output: torch.Tensor, diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 26931ba29..50f0faf4e 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -2177,9 +2177,33 @@ def topk_softmax( token_expert_indices: torch.Tensor, gating_output: torch.Tensor, renormalize: bool = False, + e_score_correction_bias: torch.Tensor | None = None, ) -> None: torch.ops._moe_C.topk_softmax( - topk_weights, topk_ids, token_expert_indices, gating_output, renormalize + topk_weights, + topk_ids, + token_expert_indices, + gating_output, + renormalize, + e_score_correction_bias, + ) + + +def topk_sigmoid( + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool = False, + e_score_correction_bias: torch.Tensor | None = None, +) -> None: + torch.ops._moe_C.topk_sigmoid( + topk_weights, + topk_ids, + token_expert_indices, + gating_output, + renormalize, + e_score_correction_bias, ) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 12e9918e0..8bfcf17ca 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -106,14 +106,14 @@ def _quant_flags_to_group_shape( class RoutingMethodType(IntEnum): # Default: Softmax -> TopK Default = (0,) - # Renormalize: TopK -> Softmax + # Renormalize: TopK -> Softmax/Sigmoid Renormalize = (1,) # DeepSeekV3: Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups # -> Top8 experts from the Top4 groups DeepSeekV3 = (2,) # Llama4: Top1 -> Sigmoid Llama4 = (3,) - # RenormalizeNaive: Softmax -> TopK -> Renormalize + # RenormalizeNaive: Softmax/Sigmoid -> TopK -> Renormalize RenormalizeNaive = (4,) # TopK: TopK (no softmax) TopK = (5,) diff --git a/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py b/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py index 460385ace..44b586650 100644 --- a/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py +++ b/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py @@ -4,6 +4,8 @@ from collections.abc import Callable import torch +import vllm._custom_ops as ops +from vllm._aiter_ops import rocm_aiter_ops from vllm.distributed.eplb.eplb_state import EplbLayerState from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, @@ -12,15 +14,106 @@ from vllm.model_executor.layers.fused_moe.config import RoutingMethodType from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter +def vllm_topk_softmax( + topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool = False, + e_score_correction_bias: torch.Tensor | None = None, +) -> tuple[torch.Tensor, ...]: + ops.topk_softmax( + topk_weights, + topk_indices, + token_expert_indices, + gating_output, + renormalize, + e_score_correction_bias, + ) + + return topk_weights, topk_indices + + +def vllm_topk_sigmoid( + topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool = False, + e_score_correction_bias: torch.Tensor | None = None, +) -> tuple[torch.Tensor, ...]: + ops.topk_sigmoid( + topk_weights, + topk_indices, + token_expert_indices, + gating_output, + renormalize, + e_score_correction_bias, + ) + + return topk_weights, topk_indices + + def fused_topk_bias( hidden_states: torch.Tensor, gating_output: torch.Tensor, e_score_correction_bias: torch.Tensor, topk: int, renormalize: bool, + scoring_func: str = "softmax", + indices_type: torch.dtype | None = None, ): + if not rocm_aiter_ops.is_fused_moe_enabled(): + assert hidden_states.size(0) == gating_output.size(0), ( + "Number of tokens mismatch" + ) + + M, _ = hidden_states.size() + + topk_weights = torch.empty( + M, topk, dtype=torch.float32, device=hidden_states.device + ) + topk_ids = torch.empty( + M, + topk, + dtype=torch.int32 if indices_type is None else indices_type, + device=hidden_states.device, + ) + token_expert_indices = torch.empty( + M, topk, dtype=torch.int32, device=hidden_states.device + ) + + if scoring_func == "softmax": + topk_weights, topk_ids = vllm_topk_softmax( + topk_weights, + topk_ids, + token_expert_indices, + gating_output, + renormalize, + e_score_correction_bias, + ) + return topk_weights, topk_ids + elif scoring_func == "sigmoid": + topk_weights, topk_ids = vllm_topk_sigmoid( + topk_weights, + topk_ids, + token_expert_indices, + gating_output, + renormalize, + e_score_correction_bias, + ) + return topk_weights, topk_ids + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + n_routed_experts = gating_output.shape[-1] - scores = gating_output.softmax(dim=-1) + if scoring_func == "softmax": + scores = gating_output.softmax(dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + scores_for_choice = scores.view( -1, n_routed_experts ) + e_score_correction_bias.unsqueeze(0) @@ -43,6 +136,7 @@ class FusedTopKBiasRouter(BaseRouter): global_num_experts: int, eplb_state: EplbLayerState, e_score_correction_bias: torch.Tensor, + scoring_func: str, renormalize: bool = True, routed_scaling_factor: float = 1.0, enable_eplb: bool = False, @@ -57,6 +151,7 @@ class FusedTopKBiasRouter(BaseRouter): ) self.e_score_correction_bias = e_score_correction_bias self.renormalize = renormalize + self.scoring_func = scoring_func self.routed_scaling_factor = routed_scaling_factor @property @@ -80,6 +175,7 @@ class FusedTopKBiasRouter(BaseRouter): e_score_correction_bias=self.e_score_correction_bias.data, topk=self.top_k, renormalize=self.renormalize, + scoring_func=self.scoring_func, ) if self.routed_scaling_factor != 1.0: diff --git a/vllm/model_executor/layers/fused_moe/router/fused_topk_router.py b/vllm/model_executor/layers/fused_moe/router/fused_topk_router.py index 25b360c52..cec9240ef 100644 --- a/vllm/model_executor/layers/fused_moe/router/fused_topk_router.py +++ b/vllm/model_executor/layers/fused_moe/router/fused_topk_router.py @@ -16,7 +16,7 @@ def vllm_topk_softmax( topk_indices: torch.Tensor, token_expert_indices: torch.Tensor, gating_output: torch.Tensor, - renormalize: bool, + renormalize: bool = False, ) -> tuple[torch.Tensor, ...]: ops.topk_softmax( topk_weights, @@ -29,7 +29,25 @@ def vllm_topk_softmax( return topk_weights, topk_indices -def dispatch_topk_func( +def vllm_topk_sigmoid( + topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool = False, +) -> tuple[torch.Tensor, ...]: + ops.topk_sigmoid( + topk_weights, + topk_indices, + token_expert_indices, + gating_output, + renormalize, + ) + + return topk_weights, topk_indices + + +def dispatch_topk_softmax_func( use_rocm_aiter: bool = False, ) -> Callable[..., tuple[torch.Tensor, ...]]: if use_rocm_aiter: @@ -37,12 +55,21 @@ def dispatch_topk_func( return vllm_topk_softmax +def dispatch_topk_sigmoid_func( + use_rocm_aiter: bool = False, +) -> Callable[..., tuple[torch.Tensor, ...]]: + if use_rocm_aiter: + return rocm_aiter_ops.topk_sigmoid + return vllm_topk_sigmoid + + def fused_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, topk: int, renormalize: bool, indices_type: torch.dtype | None = None, + scoring_func: str = "softmax", ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch" @@ -61,12 +88,26 @@ def fused_topk( M, topk, dtype=torch.int32, device=hidden_states.device ) - topk_func = dispatch_topk_func(use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled()) - topk_weights, topk_ids = topk_func( - topk_weights, topk_ids, token_expert_indices, gating_output, renormalize - ) + if scoring_func == "softmax": + topk_func = dispatch_topk_softmax_func( + use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled() + ) + topk_weights, topk_ids = topk_func( + topk_weights, topk_ids, token_expert_indices, gating_output, renormalize + ) - return topk_weights, topk_ids, token_expert_indices + return topk_weights, topk_ids, token_expert_indices + elif scoring_func == "sigmoid": + topk_func = dispatch_topk_sigmoid_func( + use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled() + ) + topk_weights, topk_ids = topk_func( + topk_weights, topk_ids, token_expert_indices, gating_output, renormalize + ) + + return topk_weights, topk_ids, token_expert_indices + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") class FusedTopKRouter(BaseRouter): @@ -82,7 +123,6 @@ class FusedTopKRouter(BaseRouter): enable_eplb: bool = False, indices_type_getter: Callable[[], torch.dtype | None] | None = None, ): - assert scoring_func == "softmax", "FusedTopKRouter only supports softmax." super().__init__( top_k=top_k, global_num_experts=global_num_experts, @@ -91,6 +131,7 @@ class FusedTopKRouter(BaseRouter): indices_type_getter=indices_type_getter, ) self.renormalize = renormalize + self.scoring_func = scoring_func @property def routing_method_type(self) -> RoutingMethodType: @@ -113,6 +154,7 @@ class FusedTopKRouter(BaseRouter): topk=self.top_k, renormalize=self.renormalize, indices_type=indices_type, + scoring_func=self.scoring_func, ) return topk_weights, topk_ids diff --git a/vllm/model_executor/layers/fused_moe/router/router_factory.py b/vllm/model_executor/layers/fused_moe/router/router_factory.py index cbe294e6b..d28f07558 100644 --- a/vllm/model_executor/layers/fused_moe/router/router_factory.py +++ b/vllm/model_executor/layers/fused_moe/router/router_factory.py @@ -143,17 +143,13 @@ def create_fused_moe_router( router.capture = capture return router - if scoring_func != "softmax": - raise ValueError( - "Only softmax scoring function is supported for non-grouped topk." - ) - if e_score_correction_bias is not None: router = FusedTopKBiasRouter( top_k=top_k, global_num_experts=global_num_experts, eplb_state=eplb_state, e_score_correction_bias=e_score_correction_bias, + scoring_func=scoring_func, renormalize=renormalize, routed_scaling_factor=routed_scaling_factor, enable_eplb=enable_eplb, diff --git a/vllm/model_executor/models/minimax_m2.py b/vllm/model_executor/models/minimax_m2.py index 292969db6..bcd5d4d1a 100644 --- a/vllm/model_executor/models/minimax_m2.py +++ b/vllm/model_executor/models/minimax_m2.py @@ -100,9 +100,6 @@ class MiniMaxM2MoE(nn.Module): num_experts=config.num_local_experts, top_k=config.num_experts_per_tok, scoring_func=config.scoring_func, - use_grouped_topk=True, - num_expert_group=1, - topk_group=1, e_score_correction_bias=self.e_score_correction_bias, hidden_size=config.hidden_size, intermediate_size=config.intermediate_size,