[Perf][DeepSeek] Add sigmoid+bias fusion to fused_grouped_topk from TRTLLM (#28124)

Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Michael Goin
2025-11-08 10:20:55 +08:00
committed by GitHub
parent 61d25dc44b
commit 0852527647
5 changed files with 149 additions and 75 deletions

View File

@@ -427,11 +427,29 @@ __device__ inline bool is_finite(const T val) {
 #endif
 }
 
+// Scoring function enums
+enum ScoringFunc {
+  SCORING_NONE = 0,    // no activation function
+  SCORING_SIGMOID = 1  // apply sigmoid
+};
+
+// Efficient sigmoid approximation from TensorRT-LLM
+__device__ inline float sigmoid_accurate(float x) {
+  return 0.5f * tanhf(0.5f * x) + 0.5f;
+}
+
 template <typename T>
-__device__ void topk_with_k2(T* output, T const* input,
+__device__ inline T apply_sigmoid(T val) {
+  float f = cuda_cast<float, T>(val);
+  return cuda_cast<T, float>(sigmoid_accurate(f));
+}
+
+template <typename T>
+__device__ void topk_with_k2(T* output, T const* input, T const* bias,
                              cg::thread_block_tile<32> const& tile,
                              int32_t const lane_id,
-                             int const num_experts_per_group) {
+                             int const num_experts_per_group,
+                             int const scoring_func) {
   // Get the top2 per thread
   T largest = neg_inf<T>();
   T second_largest = neg_inf<T>();
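Note on `sigmoid_accurate`: the identity sigmoid(x) = 0.5·tanh(x/2) + 0.5 is exact in real arithmetic (since tanh(x/2) = (eˣ − 1)/(eˣ + 1)); the "approximation" in the kernel comes from the fast `tanhf` intrinsic, not from the algebra. A quick numerical check, as an illustrative sketch (not part of the PR):

```python
import numpy as np

# 0.5 * tanh(0.5 * x) + 0.5 vs. the textbook logistic function.
x = np.linspace(-20.0, 20.0, 10001, dtype=np.float64)
lhs = 0.5 * np.tanh(0.5 * x) + 0.5
rhs = 1.0 / (1.0 + np.exp(-x))
print(np.max(np.abs(lhs - rhs)))  # ~1e-16: differences are pure round-off
```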
@@ -439,6 +457,12 @@ __device__ void topk_with_k2(T* output, T const* input,
   if (num_experts_per_group > WARP_SIZE) {
     for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
       T value = input[i];
+      // Apply scoring function if needed
+      if (scoring_func == SCORING_SIGMOID) {
+        value = apply_sigmoid(value);
+      }
+      value = value + bias[i];
+
       if (value > largest) {
         second_largest = largest;
         largest = value;
@@ -448,7 +472,13 @@ __device__ void topk_with_k2(T* output, T const* input,
     }
   } else {
     for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
-      largest = input[i];
+      T value = input[i];
+      // Apply scoring function if needed
+      if (scoring_func == SCORING_SIGMOID) {
+        value = apply_sigmoid(value);
+      }
+      value = value + bias[i];
+      largest = value;
     }
   }
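With this fusion, `topk_with_k2` applies the optional sigmoid and the expert bias inline before its warp-level top-2 reduction, instead of reading a precomputed `scores_with_bias` tensor from global memory. A rough NumPy sketch of the per-group score it produces (assuming, as in the surrounding TRTLLM-derived code, that the group score is the sum of the top-2 biased values; the helper name is illustrative):

```python
import numpy as np

def group_score(logits_group: np.ndarray, bias_group: np.ndarray,
                scoring_func: int) -> float:
    """Sum of the top-2 biased scores within one expert group (sketch)."""
    scores = 1.0 / (1.0 + np.exp(-logits_group)) if scoring_func == 1 else logits_group
    biased = scores + bias_group
    return float(np.sort(biased)[-2:].sum())
```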
@@ -472,17 +502,21 @@ __device__ void topk_with_k2(T* output, T const* input,
 }
 
 template <typename T>
-__global__ void topk_with_k2_kernel(T* output, T* input,
+__global__ void topk_with_k2_kernel(T* output, T* input, T const* bias,
                                     int64_t const num_tokens,
                                     int64_t const num_cases,
                                     int64_t const n_group,
-                                    int64_t const num_experts_per_group) {
+                                    int64_t const num_experts_per_group,
+                                    int const scoring_func) {
   int32_t warp_id = threadIdx.x / WARP_SIZE;
   int32_t lane_id = threadIdx.x % WARP_SIZE;
   int32_t case_id = blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id;
   if (case_id < num_cases) {
     input += case_id * num_experts_per_group;
+    // bias is per expert group, offset to current group
+    int32_t group_id = case_id % n_group;
+    T const* group_bias = bias + group_id * num_experts_per_group;
     output += case_id;
 
     cg::thread_block block = cg::this_thread_block();
@@ -491,7 +525,8 @@ __global__ void topk_with_k2_kernel(T* output, T* input,
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     asm volatile("griddepcontrol.wait;");
 #endif
-    topk_with_k2(output, input, tile, lane_id, num_experts_per_group);
+    topk_with_k2(output, input, group_bias, tile, lane_id,
+                 num_experts_per_group, scoring_func);
   }
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
   asm volatile("griddepcontrol.launch_dependents;");
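Each warp handles one "case", i.e. one expert group of one token, so the per-expert bias (which has no token dimension) is indexed by group alone via `case_id % n_group`. A sketch of that indexing (function name is illustrative):

```python
def case_to_bias_slice(case_id: int, n_group: int, experts_per_group: int):
    """Map a (token, group) case to its half-open slice into the bias vector."""
    group_id = case_id % n_group          # bias repeats across tokens
    start = group_id * experts_per_group
    return start, start + experts_per_group
```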
@@ -500,16 +535,15 @@ __global__ void topk_with_k2_kernel(T* output, T* input,
 
 template <typename T, typename IdxT>
 __global__ void group_idx_and_topk_idx_kernel(
-    T* scores, T const* group_scores, T* topk_values, IdxT* topk_indices,
-    T* scores_with_bias, int64_t const num_tokens, int64_t const n_group,
+    T* scores, T const* group_scores, float* topk_values, IdxT* topk_indices,
+    T const* bias, int64_t const num_tokens, int64_t const n_group,
     int64_t const topk_group, int64_t const topk, int64_t const num_experts,
     int64_t const num_experts_per_group, bool renormalize,
-    double routed_scaling_factor) {
+    double routed_scaling_factor, int scoring_func) {
   int32_t warp_id = threadIdx.x / WARP_SIZE;
   int32_t lane_id = threadIdx.x % WARP_SIZE;
   int32_t case_id =
       blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id;  // one per token
-  scores_with_bias += case_id * num_experts;
   scores += case_id * num_experts;
   group_scores += case_id * n_group;
   topk_values += case_id * topk;
@@ -577,10 +611,16 @@ __global__ void group_idx_and_topk_idx_kernel(
       int32_t offset = i_group * num_experts_per_group;
       for (int32_t i = lane_id; i < align_num_experts_per_group;
            i += WARP_SIZE) {
-        T candidates = (i < num_experts_per_group) &&
-                               is_finite(scores_with_bias[offset + i])
-                           ? scores_with_bias[offset + i]
-                           : neg_inf<T>();
+        T candidates = neg_inf<T>();
+        if (i < num_experts_per_group) {
+          // Apply scoring function (if any) and add bias
+          T input = scores[offset + i];
+          if (is_finite(input)) {
+            T score = (scoring_func == SCORING_SIGMOID) ? apply_sigmoid(input)
+                                                        : input;
+            candidates = score + bias[offset + i];
+          }
+        }
         queue.add(candidates, offset + i);
       }
       if (group_scores[i_group] == topk_group_value) {
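The guarded load above replaces the old `scores_with_bias` read: non-finite raw scores map to `neg_inf<T>()` so they can never be selected, and only finite scores get the activation and bias. An equivalent scalar sketch:

```python
import math

SCORING_SIGMOID = 1  # mirrors the CUDA enum

def candidate(raw: float, bias: float, scoring_func: int) -> float:
    """Value fed to the warp top-k queue for one expert (reference sketch)."""
    if not math.isfinite(raw):
        return float("-inf")  # plays the role of neg_inf<T>()
    score = 1.0 / (1.0 + math.exp(-raw)) if scoring_func == SCORING_SIGMOID else raw
    return score + bias
```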
@@ -602,11 +642,12 @@ __global__ void group_idx_and_topk_idx_kernel(
     for (int i = lane_id;
          i < warp_topk::round_up_to_multiple_of<WARP_SIZE>(topk);
          i += WARP_SIZE) {
-      T value =
-          i < topk
-              ? scores[s_topk_idx[i]]
-              : cuda_cast<T, float>(0.0f);  // Load the valid value of expert
+      T value = cuda_cast<T, float>(0.0f);
       if (i < topk) {
+        // Load the score value (without bias) for normalization
+        T input = scores[s_topk_idx[i]];
+        value =
+            (scoring_func == SCORING_SIGMOID) ? apply_sigmoid(input) : input;
         s_topk_value[i] = value;
       }
       topk_sum +=
@@ -627,12 +668,12 @@ __global__ void group_idx_and_topk_idx_kernel(
         value = cuda_cast<float, T>(s_topk_value[i]) * routed_scaling_factor;
       }
       topk_indices[i] = s_topk_idx[i];
-      topk_values[i] = cuda_cast<T, float>(value);
+      topk_values[i] = value;
     }
   } else {
     for (int i = lane_id; i < topk; i += WARP_SIZE) {
       topk_indices[i] = i;
-      topk_values[i] = cuda_cast<T, float>(1.0f / topk);
+      topk_values[i] = 1.0f / topk;
     }
   }
   // Note: when if_proceed_next_topk==false, choose the first 8 experts as the
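Because `topk_values` is now `float*`, the kernel writes the scaled weights directly in float32, which is what lets the Python side drop its `.to(torch.float32)` cast. The renormalization path is only partially visible in this hunk; a hedged PyTorch sketch of the final weights, assuming the bias-free selected scores are normalized by their sum when `renormalize` is set:

```python
import torch

def finalize_weights(topk_scores: torch.Tensor, renormalize: bool,
                     routed_scaling_factor: float) -> torch.Tensor:
    """Reference sketch; the kernel fuses this into the top-k epilogue."""
    if renormalize:
        topk_scores = topk_scores / topk_scores.sum(dim=-1, keepdim=True)
    return (topk_scores * routed_scaling_factor).float()
```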
@@ -644,12 +685,12 @@ __global__ void group_idx_and_topk_idx_kernel(
 }
 
 template <typename T, typename IdxT>
-void invokeNoAuxTc(T* scores, T* group_scores, T* topk_values,
-                   IdxT* topk_indices, T* scores_with_bias,
-                   int64_t const num_tokens, int64_t const num_experts,
-                   int64_t const n_group, int64_t const topk_group,
-                   int64_t const topk, bool const renormalize,
-                   double const routed_scaling_factor, bool enable_pdl = false,
+void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
+                   IdxT* topk_indices, T const* bias, int64_t const num_tokens,
+                   int64_t const num_experts, int64_t const n_group,
+                   int64_t const topk_group, int64_t const topk,
+                   bool const renormalize, double const routed_scaling_factor,
+                   int const scoring_func, bool enable_pdl = false,
                    cudaStream_t const stream = 0) {
   int64_t num_cases = num_tokens * n_group;
   int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
@@ -664,8 +705,9 @@ void invokeNoAuxTc(T* scores, T* group_scores, T* topk_values,
   attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
   config.numAttrs = 1;
   config.attrs = attrs;
-  cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores_with_bias,
-                     num_tokens, num_cases, n_group, num_experts / n_group);
+  cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores, bias,
+                     num_tokens, num_cases, n_group, num_experts / n_group,
+                     scoring_func);
 
   int64_t topk_with_k_group_num_blocks =
       (num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
@@ -682,19 +724,18 @@ void invokeNoAuxTc(T* scores, T* group_scores, T* topk_values,
   config.numAttrs = 1;
   config.attrs = attrs;
   cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores,
-                     topk_values, topk_indices, scores_with_bias, num_tokens,
-                     n_group, topk_group, topk, num_experts,
-                     num_experts / n_group, renormalize, routed_scaling_factor);
+                     topk_values, topk_indices, bias, num_tokens, n_group,
+                     topk_group, topk, num_experts, num_experts / n_group,
+                     renormalize, routed_scaling_factor, scoring_func);
 }
 
 #define INSTANTIATE_NOAUX_TC(T, IdxT)                                        \
   template void invokeNoAuxTc<T, IdxT>(                                      \
-      T * scores, T * group_scores, T * topk_values, IdxT * topk_indices,    \
-      T * scores_with_bias, int64_t const num_tokens,                        \
-      int64_t const num_experts, int64_t const n_group,                      \
-      int64_t const topk_group, int64_t const topk, bool const renormalize,  \
-      double const routed_scaling_factor, bool enable_pdl,                   \
-      cudaStream_t const stream);
+      T * scores, T * group_scores, float* topk_values, IdxT* topk_indices,  \
+      T const* bias, int64_t const num_tokens, int64_t const num_experts,    \
+      int64_t const n_group, int64_t const topk_group, int64_t const topk,   \
+      bool const renormalize, double const routed_scaling_factor,            \
+      int const scoring_func, bool enable_pdl, cudaStream_t const stream);
 
 INSTANTIATE_NOAUX_TC(float, int32_t);
 INSTANTIATE_NOAUX_TC(half, int32_t);
@@ -703,28 +744,32 @@ INSTANTIATE_NOAUX_TC(__nv_bfloat16, int32_t);
 }  // namespace vllm
 
 std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
-    torch::Tensor const& scores, torch::Tensor const& scores_with_bias,
-    int64_t n_group, int64_t topk_group, int64_t topk, bool renormalize,
-    double routed_scaling_factor) {
-  auto data_type = scores_with_bias.scalar_type();
-  auto input_size = scores_with_bias.sizes();
+    torch::Tensor const& scores, int64_t n_group, int64_t topk_group,
+    int64_t topk, bool renormalize, double routed_scaling_factor,
+    torch::Tensor const& bias, int64_t scoring_func = 0) {
+  auto data_type = scores.scalar_type();
+  auto input_size = scores.sizes();
   int64_t num_tokens = input_size[0];
   int64_t num_experts = input_size[1];
-  TORCH_CHECK(input_size.size() == 2, "scores_with_bias must be a 2D Tensor");
+  TORCH_CHECK(input_size.size() == 2, "scores must be a 2D Tensor");
   TORCH_CHECK(num_experts % n_group == 0,
               "num_experts should be divisible by n_group");
   TORCH_CHECK(n_group <= 32,
               "n_group should be smaller than or equal to 32 for now");
   TORCH_CHECK(topk <= 32, "topk should be smaller than or equal to 32 for now");
+  TORCH_CHECK(scoring_func == vllm::moe::SCORING_NONE ||
+                  scoring_func == vllm::moe::SCORING_SIGMOID,
+              "scoring_func must be SCORING_NONE (0) or SCORING_SIGMOID (1)");
 
   torch::Tensor group_scores = torch::empty(
       {num_tokens, n_group}, torch::dtype(data_type).device(torch::kCUDA));
+  // Always output float32 for topk_values (eliminates Python-side conversion)
   torch::Tensor topk_values = torch::empty(
-      {num_tokens, topk}, torch::dtype(data_type).device(torch::kCUDA));
+      {num_tokens, topk}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
   torch::Tensor topk_indices = torch::empty(
       {num_tokens, topk}, torch::dtype(torch::kInt32).device(torch::kCUDA));
-  auto stream = c10::cuda::getCurrentCUDAStream(scores_with_bias.get_device());
+  auto stream = c10::cuda::getCurrentCUDAStream(scores.get_device());
 
   switch (data_type) {
     case torch::kFloat16:
@@ -732,11 +777,11 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
       vllm::moe::invokeNoAuxTc<half, int32_t>(
           reinterpret_cast<half*>(scores.mutable_data_ptr()),
           reinterpret_cast<half*>(group_scores.mutable_data_ptr()),
-          reinterpret_cast<half*>(topk_values.mutable_data_ptr()),
+          reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
           reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
-          reinterpret_cast<half*>(scores_with_bias.data_ptr()), num_tokens,
+          reinterpret_cast<half const*>(bias.data_ptr()), num_tokens,
           num_experts, n_group, topk_group, topk, renormalize,
-          routed_scaling_factor, false, stream);
+          routed_scaling_factor, static_cast<int>(scoring_func), false, stream);
       break;
     case torch::kFloat32:
       // Handle Float32
@@ -745,20 +790,20 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
           reinterpret_cast<float*>(group_scores.mutable_data_ptr()),
           reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
           reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
-          reinterpret_cast<float*>(scores_with_bias.data_ptr()), num_tokens,
+          reinterpret_cast<float const*>(bias.data_ptr()), num_tokens,
           num_experts, n_group, topk_group, topk, renormalize,
-          routed_scaling_factor, false, stream);
+          routed_scaling_factor, static_cast<int>(scoring_func), false, stream);
       break;
     case torch::kBFloat16:
       // Handle BFloat16
       vllm::moe::invokeNoAuxTc<__nv_bfloat16, int32_t>(
           reinterpret_cast<__nv_bfloat16*>(scores.mutable_data_ptr()),
           reinterpret_cast<__nv_bfloat16*>(group_scores.mutable_data_ptr()),
-          reinterpret_cast<__nv_bfloat16*>(topk_values.mutable_data_ptr()),
+          reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
           reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
-          reinterpret_cast<__nv_bfloat16*>(scores_with_bias.data_ptr()),
-          num_tokens, num_experts, n_group, topk_group, topk, renormalize,
-          routed_scaling_factor, false, stream);
+          reinterpret_cast<__nv_bfloat16 const*>(bias.data_ptr()), num_tokens,
+          num_experts, n_group, topk_group, topk, renormalize,
+          routed_scaling_factor, static_cast<int>(scoring_func), false, stream);
       break;
     default:
       // Handle other data types

View File

@@ -39,9 +39,9 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
                              int64_t BLOCK_SIZE_K, int64_t bit);
 
 std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
-    torch::Tensor const& scores, torch::Tensor const& scores_with_bias,
-    int64_t n_group, int64_t topk_group, int64_t topk, bool renormalize,
-    double routed_scaling_factor);
+    torch::Tensor const& scores, int64_t n_group, int64_t topk_group,
+    int64_t topk, bool renormalize, double routed_scaling_factor,
+    torch::Tensor const& bias, int64_t scoring_func);
 #endif
 
 bool moe_permute_unpermute_supported();

View File

@@ -107,9 +107,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
   // Apply grouped topk routing to select experts.
   m.def(
-      "grouped_topk(Tensor scores, Tensor scores_with_bias, int n_group, int "
+      "grouped_topk(Tensor scores, int n_group, int "
       "topk_group, int topk, bool renormalize, float "
-      "routed_scaling_factor) -> (Tensor, Tensor)");
+      "routed_scaling_factor, Tensor bias, int scoring_func) -> (Tensor, "
+      "Tensor)");
   m.impl("grouped_topk", torch::kCUDA, &grouped_topk);
 #endif
 }

View File

@@ -1898,25 +1898,40 @@ def topk_softmax(
 def grouped_topk(
     scores: torch.Tensor,
-    scores_with_bias: torch.Tensor,
     num_expert_group: int,
     topk_group: int,
     topk: int,
     renormalize: bool,
     routed_scaling_factor: float,
+    bias: torch.Tensor,
+    scoring_func: int = 0,
 ):
+    """
+    Perform grouped top-k routing for mixture of experts.
+
+    Args:
+        scores: Raw inputs (logits if scoring_func=1, scores if scoring_func=0)
+        num_expert_group: Number of expert groups
+        topk_group: Number of groups to select
+        topk: Number of experts to select per token
+        renormalize: Whether to renormalize the output weights
+        routed_scaling_factor: Scaling factor for routing weights
+        bias: Bias tensor (e_score_correction_bias). Always fused in kernel.
+        scoring_func: 0=none (no activation), 1=sigmoid
+    """
     if not current_platform.is_cuda():
         raise NotImplementedError(
             "The fused grouped_topk kernel is only available on CUDA platforms"
         )
     return torch.ops._moe_C.grouped_topk(
         scores,
-        scores_with_bias,
         num_expert_group,
         topk_group,
         topk,
         renormalize,
         routed_scaling_factor,
+        bias,
+        scoring_func,
     )
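An illustrative call against the new wrapper (shapes and values are made up for the example; requires a CUDA build with the compiled `_moe_C` extension):

```python
import torch
from vllm import _custom_ops as ops

# 8 tokens, 64 experts in 8 groups of 8; select 4 groups, then 8 experts.
logits = torch.randn(8, 64, dtype=torch.bfloat16, device="cuda")
bias = torch.zeros(64, dtype=torch.bfloat16, device="cuda")

values, indices = ops.grouped_topk(
    logits, 8, 4, 8,
    renormalize=True,
    routed_scaling_factor=2.5,
    bias=bias,
    scoring_func=1,  # fused sigmoid + bias on the raw logits
)
assert values.dtype == torch.float32 and indices.dtype == torch.int32
```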

View File

@@ -1330,24 +1330,37 @@ def fused_grouped_topk(
 ) -> tuple[torch.Tensor, torch.Tensor]:
     assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"
 
-    if scoring_func == "softmax":
+    if scoring_func == "sigmoid":
+        # Fully fused kernel path for sigmoid
+        topk_values, topk_indices = ops.grouped_topk(
+            gating_output,  # raw logits
+            num_expert_group,
+            topk_group,
+            topk,
+            renormalize,
+            routed_scaling_factor,
+            e_score_correction_bias.to(gating_output.dtype),
+            1,  # scoring_func=1 for sigmoid
+        )
+    elif scoring_func == "softmax":
+        # Apply softmax in Python, then use fused kernel
+        # TODO: Add support for softmax in kernel
         scores = torch.softmax(gating_output, dim=-1)
-    elif scoring_func == "sigmoid":
-        scores = gating_output.sigmoid()
+        topk_values, topk_indices = ops.grouped_topk(
+            scores,  # pre-computed scores
+            num_expert_group,
+            topk_group,
+            topk,
+            renormalize,
+            routed_scaling_factor,
+            e_score_correction_bias.to(gating_output.dtype),
+            0,  # scoring_func=0 (no activation, scores already computed)
+        )
     else:
         raise ValueError(f"Unsupported scoring function: {scoring_func}")
 
-    scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
-    topk_values, topk_indices = ops.grouped_topk(
-        scores,
-        scores_with_bias.to(scores.dtype),
-        num_expert_group,
-        topk_group,
-        topk,
-        renormalize,
-        routed_scaling_factor,
-    )
-    return topk_values.to(torch.float32), topk_indices.to(torch.int32)
+    # Fused kernel outputs float32 values and int32 indices directly
+    return topk_values, topk_indices
 
 
 def inplace_fused_experts(
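For reference, the fused sigmoid path should match applying sigmoid in Python and routing with `scoring_func=0`, up to kernel rounding. A hedged sketch of the check one might run (the helper name and tolerances are illustrative):

```python
import torch
from vllm import _custom_ops as ops

def check_fused_matches_unfused(logits, bias, n_group, topk_group, topk,
                                renormalize, scale):
    # Fused: kernel applies sigmoid + bias itself. Clone because the kernel
    # receives a mutable scores pointer.
    fused_v, fused_i = ops.grouped_topk(
        logits.clone(), n_group, topk_group, topk, renormalize, scale,
        bias.to(logits.dtype), 1)
    # Unfused activation: sigmoid in Python, kernel only adds bias.
    scores = logits.sigmoid()
    ref_v, ref_i = ops.grouped_topk(
        scores, n_group, topk_group, topk, renormalize, scale,
        bias.to(logits.dtype), 0)
    # Loose tolerances: the kernel uses the tanhf-based sigmoid, and
    # tie-breaking between near-equal experts may legitimately differ.
    torch.testing.assert_close(fused_v, ref_v, rtol=1e-2, atol=1e-3)
    assert torch.equal(fused_i, ref_i)
```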