From 085252764710fdb42ac180983ef4c37da0ed72d3 Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Sat, 8 Nov 2025 10:20:55 +0800
Subject: [PATCH] [Perf][DeepSeek] Add sigmoid+bias fusion to
 fused_grouped_topk from TRTLLM (#28124)

Signed-off-by: mgoin
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
---
 csrc/moe/grouped_topk_kernels.cu      | 153 +++++++++++-------
 csrc/moe/moe_ops.h                    |   6 +-
 csrc/moe/torch_bindings.cpp           |   5 +-
 vllm/_custom_ops.py                   |  19 ++-
 .../layers/fused_moe/fused_moe.py     |  41 +++--
 5 files changed, 149 insertions(+), 75 deletions(-)

diff --git a/csrc/moe/grouped_topk_kernels.cu b/csrc/moe/grouped_topk_kernels.cu
index c93f9d54d..69b4c1fb1 100644
--- a/csrc/moe/grouped_topk_kernels.cu
+++ b/csrc/moe/grouped_topk_kernels.cu
@@ -427,11 +427,29 @@ __device__ inline bool is_finite(const T val) {
 #endif
 }
 
+// Scoring function enums
+enum ScoringFunc {
+  SCORING_NONE = 0,    // no activation function
+  SCORING_SIGMOID = 1  // apply sigmoid
+};
+
+// Efficient sigmoid approximation from TensorRT-LLM
+__device__ inline float sigmoid_accurate(float x) {
+  return 0.5f * tanhf(0.5f * x) + 0.5f;
+}
+
 template <typename T>
-__device__ void topk_with_k2(T* output, T const* input,
+__device__ inline T apply_sigmoid(T val) {
+  float f = cuda_cast<float, T>(val);
+  return cuda_cast<T, float>(sigmoid_accurate(f));
+}
+
+template <typename T>
+__device__ void topk_with_k2(T* output, T const* input, T const* bias,
                              cg::thread_block_tile<32> const& tile,
                              int32_t const lane_id,
-                             int const num_experts_per_group) {
+                             int const num_experts_per_group,
+                             int const scoring_func) {
   // Get the top2 per thread
   T largest = neg_inf<T>();
   T second_largest = neg_inf<T>();
@@ -439,6 +457,12 @@ __device__ void topk_with_k2(T* output, T const* input,
   if (num_experts_per_group > WARP_SIZE) {
     for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
       T value = input[i];
+      // Apply scoring function if needed
+      if (scoring_func == SCORING_SIGMOID) {
+        value = apply_sigmoid(value);
+      }
+      value = value + bias[i];
+
       if (value > largest) {
         second_largest = largest;
         largest = value;
@@ -448,7 +472,13 @@ __device__ void topk_with_k2(T* output, T const* input,
     }
   } else {
     for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
-      largest = input[i];
+      T value = input[i];
+      // Apply scoring function if needed
+      if (scoring_func == SCORING_SIGMOID) {
+        value = apply_sigmoid(value);
+      }
+      value = value + bias[i];
+      largest = value;
     }
   }
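Note (editorial, not part of the patch): `sigmoid_accurate` relies on the exact identity sigmoid(x) = 0.5 * tanh(x / 2) + 0.5, so the tanh-based form matches the naive 1 / (1 + exp(-x)) up to floating-point rounding while avoiding a direct `expf`. A minimal sanity check of the identity:

```python
import torch

# Identity check: 0.5 * tanh(0.5 * x) + 0.5 == sigmoid(x) in real
# arithmetic, so the two forms should agree to float32 rounding error.
x = torch.linspace(-30.0, 30.0, steps=1001, dtype=torch.float32)
tanh_form = 0.5 * torch.tanh(0.5 * x) + 0.5
assert torch.allclose(tanh_form, torch.sigmoid(x), atol=1e-6)
```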
@@ -472,17 +502,21 @@
 template <typename T>
-__global__ void topk_with_k2_kernel(T* output, T* input,
+__global__ void topk_with_k2_kernel(T* output, T* input, T const* bias,
                                     int64_t const num_tokens,
                                     int64_t const num_cases,
                                     int64_t const n_group,
-                                    int64_t const num_experts_per_group) {
+                                    int64_t const num_experts_per_group,
+                                    int const scoring_func) {
   int32_t warp_id = threadIdx.x / WARP_SIZE;
   int32_t lane_id = threadIdx.x % WARP_SIZE;
   int32_t case_id = blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id;
 
   if (case_id < num_cases) {
     input += case_id * num_experts_per_group;
+    // bias is per expert group, offset to current group
+    int32_t group_id = case_id % n_group;
+    T const* group_bias = bias + group_id * num_experts_per_group;
     output += case_id;
 
     cg::thread_block block = cg::this_thread_block();
     cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     asm volatile("griddepcontrol.wait;");
 #endif
-    topk_with_k2(output, input, tile, lane_id, num_experts_per_group);
+    topk_with_k2(output, input, group_bias, tile, lane_id,
+                 num_experts_per_group, scoring_func);
   }
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
   asm volatile("griddepcontrol.launch_dependents;");
@@ -500,16 +535,15 @@
 template <typename T, typename IdxT>
 __global__ void group_idx_and_topk_idx_kernel(
-    T* scores, T const* group_scores, T* topk_values, IdxT* topk_indices,
-    T* scores_with_bias, int64_t const num_tokens, int64_t const n_group,
+    T* scores, T const* group_scores, float* topk_values, IdxT* topk_indices,
+    T const* bias, int64_t const num_tokens, int64_t const n_group,
     int64_t const topk_group, int64_t const topk, int64_t const num_experts,
     int64_t const num_experts_per_group, bool renormalize,
-    double routed_scaling_factor) {
+    double routed_scaling_factor, int scoring_func) {
   int32_t warp_id = threadIdx.x / WARP_SIZE;
   int32_t lane_id = threadIdx.x % WARP_SIZE;
   int32_t case_id = blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id;  // one per token
-  scores_with_bias += case_id * num_experts;
   scores += case_id * num_experts;
   group_scores += case_id * n_group;
   topk_values += case_id * topk;
@@ -577,10 +611,16 @@
       int32_t offset = i_group * num_experts_per_group;
       for (int32_t i = lane_id; i < align_num_experts_per_group;
            i += WARP_SIZE) {
-        T candidates = (i < num_experts_per_group) &&
-                               is_finite(scores_with_bias[offset + i])
-                           ? scores_with_bias[offset + i]
-                           : neg_inf<T>();
+        T candidates = neg_inf<T>();
+        if (i < num_experts_per_group) {
+          // Apply scoring function (if any) and add bias
+          T input = scores[offset + i];
+          if (is_finite(input)) {
+            T score = (scoring_func == SCORING_SIGMOID) ? apply_sigmoid(input)
+                                                        : input;
+            candidates = score + bias[offset + i];
+          }
+        }
         queue.add(candidates, offset + i);
       }
       if (group_scores[i_group] == topk_group_value) {
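Note (editorial, not part of the patch): with this change `topk_with_k2_kernel` consumes raw router logits instead of a precomputed `scores_with_bias` tensor. Assuming the TRT-LLM `noaux_tc` convention that a group's score is the sum of its top-2 biased expert scores, the per-group computation can be sketched in eager PyTorch as:

```python
import torch

def group_scores_reference(logits: torch.Tensor, bias: torch.Tensor,
                           n_group: int) -> torch.Tensor:
    """Sketch of kernel 1: sigmoid + bias fused into the group top-2 sum."""
    num_tokens, num_experts = logits.shape
    biased = torch.sigmoid(logits) + bias  # scoring_func=SIGMOID, then bias
    grouped = biased.view(num_tokens, n_group, num_experts // n_group)
    return grouped.topk(2, dim=-1).values.sum(dim=-1)  # [num_tokens, n_group]
```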
@@ -602,11 +642,12 @@
     for (int i = lane_id;
          i < warp_topk::round_up_to_multiple_of<WARP_SIZE>(topk);
          i += WARP_SIZE) {
-      T value =
-          i < topk
-              ? scores[s_topk_idx[i]]
-              : cuda_cast<T, float>(0.0f);  // Load the valid value of expert
+      T value = cuda_cast<T, float>(0.0f);
       if (i < topk) {
+        // Load the score value (without bias) for normalization
+        T input = scores[s_topk_idx[i]];
+        value =
+            (scoring_func == SCORING_SIGMOID) ? apply_sigmoid(input) : input;
         s_topk_value[i] = value;
       }
       topk_sum +=
@@ -627,12 +668,12 @@
         value = cuda_cast<float, T>(s_topk_value[i]) * routed_scaling_factor;
       }
       topk_indices[i] = s_topk_idx[i];
-      topk_values[i] = cuda_cast<T, float>(value);
+      topk_values[i] = value;
     }
   } else {
     for (int i = lane_id; i < topk; i += WARP_SIZE) {
       topk_indices[i] = i;
-      topk_values[i] = cuda_cast<T, float>(1.0f / topk);
+      topk_values[i] = 1.0f / topk;
     }
   }
   // Note: when if_proceed_next_topk==false, choose the first 8 experts as the
@@ -644,12 +685,12 @@
 }
 
 template <typename T, typename IdxT>
-void invokeNoAuxTc(T* scores, T* group_scores, T* topk_values,
-                   IdxT* topk_indices, T* scores_with_bias,
-                   int64_t const num_tokens, int64_t const num_experts,
-                   int64_t const n_group, int64_t const topk_group,
-                   int64_t const topk, bool const renormalize,
-                   double const routed_scaling_factor, bool enable_pdl = false,
+void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
+                   IdxT* topk_indices, T const* bias, int64_t const num_tokens,
+                   int64_t const num_experts, int64_t const n_group,
+                   int64_t const topk_group, int64_t const topk,
+                   bool const renormalize, double const routed_scaling_factor,
+                   int const scoring_func, bool enable_pdl = false,
                    cudaStream_t const stream = 0) {
   int64_t num_cases = num_tokens * n_group;
   int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
@@ -664,8 +705,9 @@
   attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
   config.numAttrs = 1;
   config.attrs = attrs;
-  cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores_with_bias,
-                     num_tokens, num_cases, n_group, num_experts / n_group);
+  cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores, bias,
+                     num_tokens, num_cases, n_group, num_experts / n_group,
+                     scoring_func);
 
   int64_t topk_with_k_group_num_blocks =
       (num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
@@ -682,19 +724,18 @@
   config.numAttrs = 1;
   config.attrs = attrs;
   cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores,
-                     topk_values, topk_indices, scores_with_bias, num_tokens,
-                     n_group, topk_group, topk, num_experts,
-                     num_experts / n_group, renormalize, routed_scaling_factor);
+                     topk_values, topk_indices, bias, num_tokens, n_group,
+                     topk_group, topk, num_experts, num_experts / n_group,
+                     renormalize, routed_scaling_factor, scoring_func);
 }
 
 #define INSTANTIATE_NOAUX_TC(T, IdxT)                                        \
   template void invokeNoAuxTc<T, IdxT>(                                      \
-      T * scores, T * group_scores, T * topk_values, IdxT * topk_indices,    \
-      T * scores_with_bias, int64_t const num_tokens,                        \
-      int64_t const num_experts, int64_t const n_group,                      \
-      int64_t const topk_group, int64_t const topk, bool const renormalize,  \
-      double const routed_scaling_factor, bool enable_pdl,                   \
-      cudaStream_t const stream);
+      T * scores, T * group_scores, float* topk_values, IdxT* topk_indices,  \
+      T const* bias, int64_t const num_tokens, int64_t const num_experts,    \
+      int64_t const n_group, int64_t const topk_group, int64_t const topk,   \
+      bool const renormalize, double const routed_scaling_factor,            \
+      int const scoring_func, bool enable_pdl, cudaStream_t const stream);
 
 INSTANTIATE_NOAUX_TC(float, int32_t);
 INSTANTIATE_NOAUX_TC(half, int32_t);
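Note (editorial, not part of the patch): `group_idx_and_topk_idx_kernel` keeps the DeepSeek-V3 split between selection and weighting — experts are chosen by sigmoid(logit) + bias, but the returned weights are built from the unbiased sigmoid scores, optionally renormalized, then scaled. Ignoring the group restriction for brevity, a sketch of that semantics:

```python
import torch

def topk_weights_reference(logits, bias, topk, renormalize,
                           routed_scaling_factor):
    scores = torch.sigmoid(logits)            # unbiased scores
    _, idx = (scores + bias).topk(topk, -1)   # selection uses the bias...
    vals = scores.gather(-1, idx)             # ...but the weights do not
    if renormalize:
        vals = vals / vals.sum(dim=-1, keepdim=True)
    return (vals * routed_scaling_factor).float(), idx.to(torch.int32)
```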
@@ -703,28 +744,32 @@ INSTANTIATE_NOAUX_TC(__nv_bfloat16, int32_t);
 
 }  // namespace vllm
 
 std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
-    torch::Tensor const& scores, torch::Tensor const& scores_with_bias,
-    int64_t n_group, int64_t topk_group, int64_t topk, bool renormalize,
-    double routed_scaling_factor) {
-  auto data_type = scores_with_bias.scalar_type();
-  auto input_size = scores_with_bias.sizes();
+    torch::Tensor const& scores, int64_t n_group, int64_t topk_group,
+    int64_t topk, bool renormalize, double routed_scaling_factor,
+    torch::Tensor const& bias, int64_t scoring_func = 0) {
+  auto data_type = scores.scalar_type();
+  auto input_size = scores.sizes();
   int64_t num_tokens = input_size[0];
   int64_t num_experts = input_size[1];
-  TORCH_CHECK(input_size.size() == 2, "scores_with_bias must be a 2D Tensor");
+  TORCH_CHECK(input_size.size() == 2, "scores must be a 2D Tensor");
   TORCH_CHECK(num_experts % n_group == 0,
               "num_experts should be divisible by n_group");
   TORCH_CHECK(n_group <= 32,
               "n_group should be smaller than or equal to 32 for now");
   TORCH_CHECK(topk <= 32, "topk should be smaller than or equal to 32 for now");
+  TORCH_CHECK(scoring_func == vllm::moe::SCORING_NONE ||
+                  scoring_func == vllm::moe::SCORING_SIGMOID,
+              "scoring_func must be SCORING_NONE (0) or SCORING_SIGMOID (1)");
 
   torch::Tensor group_scores = torch::empty(
       {num_tokens, n_group}, torch::dtype(data_type).device(torch::kCUDA));
+  // Always output float32 for topk_values (eliminates Python-side conversion)
   torch::Tensor topk_values = torch::empty(
-      {num_tokens, topk}, torch::dtype(data_type).device(torch::kCUDA));
+      {num_tokens, topk}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
   torch::Tensor topk_indices = torch::empty(
       {num_tokens, topk}, torch::dtype(torch::kInt32).device(torch::kCUDA));
 
-  auto stream = c10::cuda::getCurrentCUDAStream(scores_with_bias.get_device());
+  auto stream = c10::cuda::getCurrentCUDAStream(scores.get_device());
 
   switch (data_type) {
     case torch::kFloat16:
@@ -732,11 +777,11 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
       vllm::moe::invokeNoAuxTc<half, int32_t>(
           reinterpret_cast<half*>(scores.mutable_data_ptr()),
           reinterpret_cast<half*>(group_scores.mutable_data_ptr()),
-          reinterpret_cast<half*>(topk_values.mutable_data_ptr()),
+          reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
           reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
-          reinterpret_cast<half*>(scores_with_bias.data_ptr()), num_tokens,
+          reinterpret_cast<half const*>(bias.data_ptr()), num_tokens,
           num_experts, n_group, topk_group, topk, renormalize,
-          routed_scaling_factor, false, stream);
+          routed_scaling_factor, static_cast<int>(scoring_func), false, stream);
       break;
     case torch::kFloat32:
       // Handle Float32
       vllm::moe::invokeNoAuxTc<float, int32_t>(
           reinterpret_cast<float*>(scores.mutable_data_ptr()),
           reinterpret_cast<float*>(group_scores.mutable_data_ptr()),
           reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
           reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
-          reinterpret_cast<float*>(scores_with_bias.data_ptr()), num_tokens,
+          reinterpret_cast<float const*>(bias.data_ptr()), num_tokens,
           num_experts, n_group, topk_group, topk, renormalize,
-          routed_scaling_factor, false, stream);
+          routed_scaling_factor, static_cast<int>(scoring_func), false, stream);
       break;
     case torch::kBFloat16:
       // Handle BFloat16
       vllm::moe::invokeNoAuxTc<__nv_bfloat16, int32_t>(
           reinterpret_cast<__nv_bfloat16*>(scores.mutable_data_ptr()),
           reinterpret_cast<__nv_bfloat16*>(group_scores.mutable_data_ptr()),
-          reinterpret_cast<__nv_bfloat16*>(topk_values.mutable_data_ptr()),
+          reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
           reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
-          reinterpret_cast<__nv_bfloat16*>(scores_with_bias.data_ptr()),
-          num_tokens, num_experts, n_group, topk_group, topk, renormalize,
-          routed_scaling_factor, false, stream);
+          reinterpret_cast<__nv_bfloat16 const*>(bias.data_ptr()), num_tokens,
+          num_experts, n_group, topk_group, topk, renormalize,
+          routed_scaling_factor, static_cast<int>(scoring_func), false, stream);
       break;
     default:
       // Handle other data types
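Note (editorial, not part of the patch): since `topk_values` is now allocated as float32 inside the op, callers get float32 weights and int32 indices for every input dtype. A hypothetical smoke test against the new schema (shapes and hyperparameters are made up for illustration):

```python
import torch

logits = torch.randn(4, 256, dtype=torch.bfloat16, device="cuda")
bias = torch.zeros(256, dtype=torch.bfloat16, device="cuda")
# Schema: (scores, n_group, topk_group, topk, renormalize,
#          routed_scaling_factor, bias, scoring_func)
values, indices = torch.ops._moe_C.grouped_topk(
    logits, 8, 4, 8, True, 2.5, bias, 1)
assert values.dtype == torch.float32 and indices.dtype == torch.int32
assert values.shape == (4, 8) and indices.shape == (4, 8)
```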
diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h
index 0adf74568..11c6875f7 100644
--- a/csrc/moe/moe_ops.h
+++ b/csrc/moe/moe_ops.h
@@ -39,9 +39,9 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
                              int64_t BLOCK_SIZE_K, int64_t bit);
 
 std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
-    torch::Tensor const& scores, torch::Tensor const& scores_with_bias,
-    int64_t n_group, int64_t topk_group, int64_t topk, bool renormalize,
-    double routed_scaling_factor);
+    torch::Tensor const& scores, int64_t n_group, int64_t topk_group,
+    int64_t topk, bool renormalize, double routed_scaling_factor,
+    torch::Tensor const& bias, int64_t scoring_func);
 #endif
 
 bool moe_permute_unpermute_supported();
diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp
index ace72fad7..bd95ade40 100644
--- a/csrc/moe/torch_bindings.cpp
+++ b/csrc/moe/torch_bindings.cpp
@@ -107,9 +107,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
 
   // Apply grouped topk routing to select experts.
   m.def(
-      "grouped_topk(Tensor scores, Tensor scores_with_bias, int n_group, int "
+      "grouped_topk(Tensor scores, int n_group, int "
      "topk_group, int topk, bool renormalize, float "
-      "routed_scaling_factor) -> (Tensor, Tensor)");
+      "routed_scaling_factor, Tensor bias, int scoring_func) -> (Tensor, "
+      "Tensor)");
   m.impl("grouped_topk", torch::kCUDA, &grouped_topk);
 #endif
 }
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index de68b3418..36aab503d 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -1898,25 +1898,40 @@ def topk_softmax(
 
 def grouped_topk(
     scores: torch.Tensor,
-    scores_with_bias: torch.Tensor,
     num_expert_group: int,
     topk_group: int,
     topk: int,
     renormalize: bool,
     routed_scaling_factor: float,
+    bias: torch.Tensor,
+    scoring_func: int = 0,
 ):
+    """
+    Perform grouped top-k routing for mixture of experts.
+
+    Args:
+        scores: Raw inputs (logits if scoring_func=1, scores if scoring_func=0)
+        num_expert_group: Number of expert groups
+        topk_group: Number of groups to select
+        topk: Number of experts to select per token
+        renormalize: Whether to renormalize the output weights
+        routed_scaling_factor: Scaling factor for routing weights
+        bias: Bias tensor (e_score_correction_bias). Always fused in kernel.
+        scoring_func: 0=none (no activation), 1=sigmoid
+    """
    if not current_platform.is_cuda():
         raise NotImplementedError(
             "The fused grouped_topk kernel is only available on CUDA platforms"
         )
     return torch.ops._moe_C.grouped_topk(
         scores,
-        scores_with_bias,
         num_expert_group,
         topk_group,
         topk,
         renormalize,
         routed_scaling_factor,
+        bias,
+        scoring_func,
     )
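Note (editorial, not part of the patch): because the bias is added after the activation inside the kernel, `scoring_func=1` on raw logits should agree with `scoring_func=0` on precomputed sigmoid scores — which is also what makes the softmax fallback in the next file valid. A hedged equivalence check through the wrapper:

```python
import torch
from vllm import _custom_ops as ops

logits = torch.randn(16, 256, dtype=torch.float32, device="cuda")
bias = torch.randn(256, dtype=torch.float32, device="cuda")
args = (8, 4, 8, True, 2.5)  # n_group, topk_group, topk, renormalize, scale

# Clone defensively: the op takes scores via mutable_data_ptr and may use
# the buffer as scratch space.
fused_v, fused_i = ops.grouped_topk(logits.clone(), *args, bias, 1)
eager_v, eager_i = ops.grouped_topk(torch.sigmoid(logits), *args, bias, 0)
assert torch.equal(fused_i, eager_i)  # assumes no exact score ties
assert torch.allclose(fused_v, eager_v, atol=1e-5)
```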
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index d0f5eb498..b7415148d 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -1330,24 +1330,37 @@ def fused_grouped_topk(
 ) -> tuple[torch.Tensor, torch.Tensor]:
     assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"
 
-    if scoring_func == "softmax":
+    if scoring_func == "sigmoid":
+        # Fully fused kernel path for sigmoid
+        topk_values, topk_indices = ops.grouped_topk(
+            gating_output,  # raw logits
+            num_expert_group,
+            topk_group,
+            topk,
+            renormalize,
+            routed_scaling_factor,
+            e_score_correction_bias.to(gating_output.dtype),
+            1,  # scoring_func=1 for sigmoid
+        )
+    elif scoring_func == "softmax":
+        # Apply softmax in Python, then use fused kernel
+        # TODO: Add support for softmax in kernel
         scores = torch.softmax(gating_output, dim=-1)
-    elif scoring_func == "sigmoid":
-        scores = gating_output.sigmoid()
+        topk_values, topk_indices = ops.grouped_topk(
+            scores,  # pre-computed scores
+            num_expert_group,
+            topk_group,
+            topk,
+            renormalize,
+            routed_scaling_factor,
+            e_score_correction_bias.to(gating_output.dtype),
+            0,  # scoring_func=0 (no activation, scores already computed)
+        )
     else:
         raise ValueError(f"Unsupported scoring function: {scoring_func}")
 
-    scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
-    topk_values, topk_indices = ops.grouped_topk(
-        scores,
-        scores_with_bias.to(scores.dtype),
-        num_expert_group,
-        topk_group,
-        topk,
-        renormalize,
-        routed_scaling_factor,
-    )
-    return topk_values.to(torch.float32), topk_indices.to(torch.int32)
+    # Fused kernel outputs float32 values and int32 indices directly
+    return topk_values, topk_indices
 
 
 def inplace_fused_experts(
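Note (editorial, not part of the patch): combining the per-kernel sketches above with the group restriction gives an end-to-end eager reference for the fused sigmoid path, useful when validating the kernel. It assumes the sum-of-top-2 group score and unbiased-weight conventions noted earlier:

```python
import torch

def fused_grouped_topk_reference(gating_output, e_score_correction_bias,
                                 num_expert_group, topk_group, topk,
                                 renormalize, routed_scaling_factor):
    num_tokens, num_experts = gating_output.shape
    scores = torch.sigmoid(gating_output.float())
    biased = scores + e_score_correction_bias.float()

    # Keep only the topk_group groups with the largest top-2 sums.
    grouped = biased.view(num_tokens, num_expert_group, -1)
    group_idx = grouped.topk(2, -1).values.sum(-1).topk(topk_group, -1).indices
    keep = torch.zeros(num_tokens, num_expert_group, dtype=torch.bool,
                       device=biased.device)
    keep.scatter_(1, group_idx, True)
    keep = keep.unsqueeze(-1).expand_as(grouped).reshape(num_tokens, num_experts)

    # Select by biased score, weight by unbiased score.
    topk_idx = biased.masked_fill(~keep, float("-inf")).topk(topk, -1).indices
    topk_vals = scores.gather(1, topk_idx)
    if renormalize:
        topk_vals = topk_vals / topk_vals.sum(dim=-1, keepdim=True)
    return topk_vals * routed_scaling_factor, topk_idx.to(torch.int32)
```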