[Kernel] Optimize grouped topk kernel (#34206)

Signed-off-by: Xin Yang <xyangx@amazon.com>
This commit is contained in:
Xin Yang
2026-02-20 01:34:45 -08:00
committed by GitHub
parent 8de7c636cc
commit b1c4f0b265
3 changed files with 642 additions and 99 deletions

View File

@@ -1,6 +1,6 @@
/*
* Adapted from
* https://github.com/NVIDIA/TensorRT-LLM/blob/v0.21.0/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu
* https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc2/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu
* Copyright (c) 2025, The vLLM team.
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
@@ -17,8 +17,10 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "moeTopKFuncs.cuh"
#include <c10/cuda/CUDAStream.h>
#include <torch/all.h>
#include <cmath>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda/std/limits>
@@ -30,7 +32,17 @@ namespace vllm {
namespace moe {
constexpr unsigned FULL_WARP_MASK = 0xffffffff;
constexpr int32_t WARP_SIZE = 32;
static constexpr int WARP_SIZE = 32;
static constexpr int NumNemotronExperts = 512;
static constexpr int NumKimiK2Experts = 384;
static constexpr int NumDeepseekExperts = 256;
static constexpr int MaxSupportedExpertCount =
std::max({NumNemotronExperts, NumKimiK2Experts, NumDeepseekExperts});
static constexpr int MaxNumExpertsUnit = 128;
static constexpr int NumTopGroupScores = 2;
static constexpr int DefaultMaxNumTopExperts = 8;
static constexpr int MaxSupportedTopExperts = 22;
static constexpr int MaxNumTopGroups = 4;
namespace warp_topk {
@@ -657,76 +669,335 @@ __global__ void grouped_topk_fused_kernel(
#endif
}
template <typename T, typename BiasT, typename IdxT>
template <typename T, typename BiasT, typename IdxT, ScoringFunc SF,
int MaxNumExperts, bool UseGroups,
int MaxNumTopExperts = DefaultMaxNumTopExperts>
__global__ void grouped_topk_fused_small_expert_count_kernel(
T* scores, float* topkValues, IdxT* topkIndices, BiasT const* routingBias,
int64_t const numTokens, int64_t const numGroup, int64_t const topkGroup,
int64_t const topk, int64_t const numExperts,
int64_t const numExpertsPerGroup, bool const renormalize,
double const routedScalingFactor) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
// declare shared memory structure
// number of experts is bounded by number of threads
__shared__ float __attribute((aligned(128))) smemScoreSigmoid[MaxNumExperts];
__shared__ float __attribute((aligned(128))) smemScoreBias[MaxNumExperts];
// number of expert groups is bounded by number of warps
int constexpr NumWarps = MaxNumExperts / WARP_SIZE;
__shared__ float __attribute((aligned(128))) smemGroupScores[NumWarps];
// needed for warp reduce
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
// for the final reduction of weight norm, only some lanes need to participate
int32_t laneIdx = threadIdx.x % WARP_SIZE;
int32_t warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WARP_SIZE, 0);
if constexpr (UseGroups) {
if (warpIdx >= numGroup) {
return;
}
}
// note that for invalid scores, we simply use a negative value:
// they work well even with the compacted format used in topK, and
// sigmoid / bias activated scores cannot be negative
const float invalidScoreFloat = float{-INFINITY};
// load bias already; each warp represents one expert group
auto threadExpert = threadIdx.x;
bool expertSelected = threadExpert < numExperts;
if constexpr (UseGroups) {
threadExpert = warpIdx * numExpertsPerGroup + laneIdx;
expertSelected = laneIdx < numExpertsPerGroup;
}
auto scoreIdx = int64_t{blockIdx.x} * int64_t{numExperts} + threadExpert;
auto biasVal = expertSelected ? static_cast<float>(routingBias[threadExpert])
: invalidScoreFloat;
topkValues += blockIdx.x * topk;
topkIndices += blockIdx.x * topk;
// get our assigned thread score; each warp represents one expert group
float score =
expertSelected ? static_cast<float>(scores[scoreIdx]) : invalidScoreFloat;
auto scoreSigmoid = apply_scoring<SF>(score);
// write the sigmoid score to shared for later use
if (expertSelected) {
smemScoreSigmoid[threadExpert] = scoreSigmoid;
}
// get the score with bias
// note that with invalid values, because sigmoid is < 1 and bias is -1,
// we must get a negative value, which is smaller than any valid value
auto scoreBias = float{scoreSigmoid + float{biasVal}};
if (expertSelected) {
smemScoreBias[threadExpert] = scoreBias;
}
// registers for top group score reduction
float topExpGroupScores[NumTopGroupScores];
[[maybe_unused]] int32_t topExpGroupIdx[NumTopGroupScores];
float topGroups[MaxNumTopGroups]; // bound of numGroup
int32_t topGroupIdx[MaxNumTopGroups];
float expertScoreGroup[MaxNumTopGroups];
int32_t expertIdxGroup[MaxNumTopGroups];
float topScores[MaxNumTopExperts]; // bound of topk
int32_t topExperts[MaxNumTopExperts];
if constexpr (UseGroups) {
reduce_topk::reduceTopK(warp, topExpGroupScores, topExpGroupIdx, scoreBias,
threadExpert,
/* minValue */ invalidScoreFloat);
// get the final group score and write it to shared
if (warp.thread_rank() == 0) {
auto groupScore = topExpGroupScores[0] + topExpGroupScores[1];
smemGroupScores[warpIdx] = groupScore;
}
}
// make group scores available to all warps
__syncthreads();
if constexpr (UseGroups) {
if (warpIdx == 0) {
// a single warp performs the selection of top groups, and goes on to
// select the final experts
float groupScore =
laneIdx < numGroup ? smemGroupScores[laneIdx] : invalidScoreFloat;
reduce_topk::reduceTopK(warp, topGroups, topGroupIdx, groupScore, laneIdx,
/* minValue */ invalidScoreFloat);
// final expert selection: get relevant indexes and scores from shared
#pragma unroll
for (int ii = 0; ii < MaxNumTopGroups; ++ii) { // bound of numGroup
auto groupIdx = topGroupIdx[ii];
expertIdxGroup[ii] = groupIdx * numExpertsPerGroup + laneIdx;
expertScoreGroup[ii] = (ii < topkGroup) && expertSelected
? smemScoreBias[expertIdxGroup[ii]]
: invalidScoreFloat;
}
reduce_topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup,
expertIdxGroup, /* minValue */ invalidScoreFloat,
topk);
}
} else if constexpr (MaxNumExperts > MaxNumExpertsUnit) {
// without groups, and the expert number is larger than MaxNumExpertsUnit,
// we need to use multiple warps to calculate the intermediate topk results
int constexpr NumExpertWarps = (MaxNumExperts - 1) / MaxNumExpertsUnit + 1;
int constexpr NumInterTopK = NumExpertWarps * MaxNumTopExperts;
__shared__ float
__attribute((aligned(128))) smemInterTopScores[NumInterTopK];
__shared__ int32_t
__attribute((aligned(128))) smemInterTopExperts[NumInterTopK];
if (warpIdx < NumExpertWarps) {
int offset = warpIdx * WARP_SIZE * MaxNumTopGroups;
#pragma unroll
for (int ii = 0; ii < MaxNumTopGroups; ++ii) {
auto expertIdx = ii * WARP_SIZE + laneIdx;
expertIdxGroup[ii] = offset + expertIdx;
expertScoreGroup[ii] = offset + expertIdx < numExperts
? smemScoreBias[offset + expertIdx]
: invalidScoreFloat;
}
reduce_topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup,
expertIdxGroup,
/* minValue */ invalidScoreFloat, topk);
if (laneIdx < topk) {
smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] =
topScores[laneIdx];
smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] =
topExperts[laneIdx];
} else if (laneIdx >= topk && laneIdx < MaxNumTopExperts) {
smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] =
invalidScoreFloat;
smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] =
MaxNumExperts - 1;
}
}
__syncthreads();
if (warpIdx == 0) {
int constexpr NumInterTopKPerThread = (NumInterTopK - 1) / WARP_SIZE + 1;
float intermediateScore[NumInterTopKPerThread];
int32_t intermediateExpert[NumInterTopKPerThread];
for (int i = laneIdx; i < NumInterTopKPerThread * WARP_SIZE;
i += WARP_SIZE) {
int ii = i / WARP_SIZE;
if (i < NumInterTopK) {
intermediateScore[ii] = smemInterTopScores[i];
intermediateExpert[ii] = smemInterTopExperts[i];
} else {
intermediateScore[ii] = invalidScoreFloat;
intermediateExpert[ii] = MaxNumExperts - 1;
}
}
reduce_topk::reduceTopK(warp, topScores, topExperts, intermediateScore,
intermediateExpert,
/* minValue */ invalidScoreFloat, topk);
}
} else {
// without groups, and the expert number is smaller than MaxNumExpertsUnit
// each thread just takes `MaxNumTopGroups` experts
if (warpIdx == 0) {
#pragma unroll
for (int ii = 0; ii < MaxNumTopGroups; ++ii) {
auto expertIdx = ii * WARP_SIZE + laneIdx;
expertIdxGroup[ii] = expertIdx;
expertScoreGroup[ii] = expertIdx < numExperts ? smemScoreBias[expertIdx]
: invalidScoreFloat;
}
reduce_topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup,
expertIdxGroup,
/* minValue */ invalidScoreFloat, topk);
}
}
if (warpIdx == 0) {
// determine our lane's expert index and write to output
int32_t expertIdx =
laneIdx < topk ? topExperts[laneIdx] : MaxNumExperts - 1;
float scoreNorm = laneIdx < topk ? smemScoreSigmoid[expertIdx] : 0.F;
float finalScore = static_cast<float>(scoreNorm * routedScalingFactor);
// norm the value
if (renormalize) {
auto redNorm = cg::reduce(warp, scoreNorm, cg::plus<float>{});
finalScore /= (redNorm + 1e-20);
}
// store the topk scores and experts to output
if (laneIdx < topk) {
topkValues[laneIdx] = finalScore;
topkIndices[laneIdx] = expertIdx;
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
template <typename T, typename BiasT, typename IdxT, ScoringFunc SF>
void invokeNoAuxTc(T* scores, float* topk_values, IdxT* topk_indices,
BiasT const* bias, int64_t const num_tokens,
int64_t const num_experts, int64_t const n_group,
int64_t const topk_group, int64_t const topk,
bool const renormalize, double const routed_scaling_factor,
int const scoring_func, bool enable_pdl = false,
cudaStream_t const stream = 0) {
bool enable_pdl = false, cudaStream_t const stream = 0) {
cudaLaunchConfig_t config;
// One block per token; one warp per group.
config.gridDim = static_cast<uint32_t>(num_tokens);
config.blockDim = static_cast<uint32_t>(n_group) * WARP_SIZE;
// Dynamic shared memory: WarpSelect staging + per-group topk buffers.
int32_t const num_warps = static_cast<int32_t>(n_group);
size_t const val_bytes =
static_cast<size_t>(num_warps) * WARP_SIZE * sizeof(T);
size_t const val_bytes_aligned =
warp_topk::round_up_to_multiple_of<256>(val_bytes);
size_t const idx_bytes =
static_cast<size_t>(num_warps) * WARP_SIZE * sizeof(int32_t);
size_t const internal_bytes = val_bytes_aligned + idx_bytes;
size_t const extra_bytes = 16 + static_cast<size_t>(n_group) * sizeof(T);
config.dynamicSmemBytes = internal_bytes + extra_bytes;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
config.numAttrs = 1;
config.attrs = attrs;
auto const sf = static_cast<ScoringFunc>(scoring_func);
switch (sf) {
case SCORING_NONE: {
auto* kernel_instance =
&grouped_topk_fused_kernel<T, BiasT, IdxT, SCORING_NONE>;
cudaLaunchKernelEx(&config, kernel_instance, scores, topk_values,
topk_indices, bias, num_tokens, num_experts, n_group,
topk_group, topk, renormalize, routed_scaling_factor);
return;
// Check if we can use the optimized
// grouped_topk_fused_small_expert_count_kernel
bool const is_single_group =
(n_group == 1) && (topk_group == 1) &&
(num_experts <= MaxSupportedExpertCount) &&
(topk <= DefaultMaxNumTopExperts || topk == MaxSupportedTopExperts);
int64_t const experts_per_group = num_experts / n_group;
bool const is_multi_group =
(n_group > 1) && (num_experts <= NumDeepseekExperts) &&
(experts_per_group <= WARP_SIZE) &&
(experts_per_group * topk_group <= MaxNumExpertsUnit) &&
(topk <= DefaultMaxNumTopExperts) && (topk_group <= MaxNumTopGroups);
if (is_single_group || is_multi_group) {
auto* kernel_instance =
&grouped_topk_fused_small_expert_count_kernel<T, BiasT, IdxT, SF,
NumDeepseekExperts, true>;
int num_threads = NumDeepseekExperts;
if (is_single_group) {
// Special case for Nemotron, which selects top 22 from 512 experts, and 1
// group only.
if (num_experts == NumNemotronExperts && n_group == 1 &&
topk == MaxSupportedTopExperts) {
kernel_instance = &grouped_topk_fused_small_expert_count_kernel<
T, BiasT, IdxT, SF, NumNemotronExperts, false,
MaxSupportedTopExperts>;
num_threads = NumNemotronExperts;
} else if (num_experts > NumKimiK2Experts &&
num_experts <= MaxSupportedExpertCount) {
kernel_instance = &grouped_topk_fused_small_expert_count_kernel<
T, BiasT, IdxT, SF, MaxSupportedExpertCount, false>;
num_threads = MaxSupportedExpertCount;
} else if (num_experts > MaxNumExpertsUnit &&
num_experts <= NumKimiK2Experts) {
kernel_instance = &grouped_topk_fused_small_expert_count_kernel<
T, BiasT, IdxT, SF, NumKimiK2Experts, false>;
num_threads = NumKimiK2Experts;
} else {
kernel_instance = &grouped_topk_fused_small_expert_count_kernel<
T, BiasT, IdxT, SF, MaxNumExpertsUnit, false>;
num_threads = MaxNumExpertsUnit;
}
}
case SCORING_SIGMOID: {
auto* kernel_instance =
&grouped_topk_fused_kernel<T, BiasT, IdxT, SCORING_SIGMOID>;
cudaLaunchKernelEx(&config, kernel_instance, scores, topk_values,
topk_indices, bias, num_tokens, num_experts, n_group,
topk_group, topk, renormalize, routed_scaling_factor);
return;
}
default:
// should be guarded by higher level checks.
TORCH_CHECK(false, "Unsupported scoring_func in invokeNoAuxTc");
config.gridDim = num_tokens;
config.blockDim = num_threads;
config.dynamicSmemBytes = 0;
cudaLaunchKernelEx(&config, kernel_instance, scores, topk_values,
topk_indices, bias, num_tokens, n_group, topk_group,
topk, num_experts, num_experts / n_group, renormalize,
routed_scaling_factor);
} else {
auto* kernel_instance = &grouped_topk_fused_kernel<T, BiasT, IdxT, SF>;
// One block per token; one warp per group.
config.gridDim = static_cast<uint32_t>(num_tokens);
config.blockDim = static_cast<uint32_t>(n_group) * WARP_SIZE;
// Dynamic shared memory: WarpSelect staging + per-group topk buffers.
int32_t const num_warps = static_cast<int32_t>(n_group);
size_t const val_bytes =
static_cast<size_t>(num_warps) * WARP_SIZE * sizeof(T);
size_t const val_bytes_aligned =
warp_topk::round_up_to_multiple_of<256>(val_bytes);
size_t const idx_bytes =
static_cast<size_t>(num_warps) * WARP_SIZE * sizeof(int32_t);
size_t const internal_bytes = val_bytes_aligned + idx_bytes;
size_t const extra_bytes = 16 + static_cast<size_t>(n_group) * sizeof(T);
config.dynamicSmemBytes = internal_bytes + extra_bytes;
cudaLaunchKernelEx(&config, kernel_instance, scores, topk_values,
topk_indices, bias, num_tokens, num_experts, n_group,
topk_group, topk, renormalize, routed_scaling_factor);
}
}
#define INSTANTIATE_NOAUX_TC(T, BiasT, IdxT) \
template void invokeNoAuxTc<T, BiasT, IdxT>( \
#define INSTANTIATE_NOAUX_TC(T, BiasT, IdxT, SF) \
template void invokeNoAuxTc<T, BiasT, IdxT, SF>( \
T * scores, float* topk_values, IdxT* topk_indices, BiasT const* bias, \
int64_t const num_tokens, int64_t const num_experts, \
int64_t const n_group, int64_t const topk_group, int64_t const topk, \
bool const renormalize, double const routed_scaling_factor, \
int const scoring_func, bool enable_pdl, cudaStream_t const stream);
bool enable_pdl, cudaStream_t const stream);
INSTANTIATE_NOAUX_TC(float, float, int32_t);
INSTANTIATE_NOAUX_TC(float, half, int32_t);
INSTANTIATE_NOAUX_TC(float, __nv_bfloat16, int32_t);
INSTANTIATE_NOAUX_TC(half, float, int32_t);
INSTANTIATE_NOAUX_TC(half, half, int32_t);
INSTANTIATE_NOAUX_TC(half, __nv_bfloat16, int32_t);
INSTANTIATE_NOAUX_TC(__nv_bfloat16, float, int32_t);
INSTANTIATE_NOAUX_TC(__nv_bfloat16, half, int32_t);
INSTANTIATE_NOAUX_TC(__nv_bfloat16, __nv_bfloat16, int32_t);
INSTANTIATE_NOAUX_TC(float, float, int32_t, SCORING_SIGMOID);
INSTANTIATE_NOAUX_TC(float, half, int32_t, SCORING_SIGMOID);
INSTANTIATE_NOAUX_TC(float, __nv_bfloat16, int32_t, SCORING_SIGMOID);
INSTANTIATE_NOAUX_TC(half, float, int32_t, SCORING_SIGMOID);
INSTANTIATE_NOAUX_TC(half, half, int32_t, SCORING_SIGMOID);
INSTANTIATE_NOAUX_TC(half, __nv_bfloat16, int32_t, SCORING_SIGMOID);
INSTANTIATE_NOAUX_TC(__nv_bfloat16, float, int32_t, SCORING_SIGMOID);
INSTANTIATE_NOAUX_TC(__nv_bfloat16, half, int32_t, SCORING_SIGMOID);
INSTANTIATE_NOAUX_TC(__nv_bfloat16, __nv_bfloat16, int32_t, SCORING_SIGMOID);
INSTANTIATE_NOAUX_TC(float, float, int32_t, SCORING_NONE);
INSTANTIATE_NOAUX_TC(float, half, int32_t, SCORING_NONE);
INSTANTIATE_NOAUX_TC(float, __nv_bfloat16, int32_t, SCORING_NONE);
INSTANTIATE_NOAUX_TC(half, float, int32_t, SCORING_NONE);
INSTANTIATE_NOAUX_TC(half, half, int32_t, SCORING_NONE);
INSTANTIATE_NOAUX_TC(half, __nv_bfloat16, int32_t, SCORING_NONE);
INSTANTIATE_NOAUX_TC(__nv_bfloat16, float, int32_t, SCORING_NONE);
INSTANTIATE_NOAUX_TC(__nv_bfloat16, half, int32_t, SCORING_NONE);
INSTANTIATE_NOAUX_TC(__nv_bfloat16, __nv_bfloat16, int32_t, SCORING_NONE);
} // end namespace moe
} // namespace vllm
@@ -762,46 +1033,53 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
{num_tokens, topk}, torch::dtype(torch::kInt32).device(torch::kCUDA));
auto stream = c10::cuda::getCurrentCUDAStream(scores.get_device());
auto const sf = static_cast<vllm::moe::ScoringFunc>(scoring_func);
#define LAUNCH_KERNEL(T, IdxT) \
do { \
switch (bias_type) { \
case torch::kFloat16: \
vllm::moe::invokeNoAuxTc<T, half, IdxT>( \
reinterpret_cast<T*>(scores.mutable_data_ptr()), \
reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \
reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \
reinterpret_cast<half const*>(bias.data_ptr()), num_tokens, \
num_experts, n_group, topk_group, topk, renormalize, \
routed_scaling_factor, static_cast<int>(scoring_func), false, \
stream); \
break; \
case torch::kFloat32: \
vllm::moe::invokeNoAuxTc<T, float, IdxT>( \
reinterpret_cast<T*>(scores.mutable_data_ptr()), \
reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \
reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \
reinterpret_cast<float const*>(bias.data_ptr()), num_tokens, \
num_experts, n_group, topk_group, topk, renormalize, \
routed_scaling_factor, static_cast<int>(scoring_func), false, \
stream); \
break; \
case torch::kBFloat16: \
vllm::moe::invokeNoAuxTc<T, __nv_bfloat16, IdxT>( \
reinterpret_cast<T*>(scores.mutable_data_ptr()), \
reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \
reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \
reinterpret_cast<__nv_bfloat16 const*>(bias.data_ptr()), \
num_tokens, num_experts, n_group, topk_group, topk, renormalize, \
routed_scaling_factor, static_cast<int>(scoring_func), false, \
stream); \
break; \
default: \
throw std::invalid_argument( \
"Invalid bias dtype, only supports float16, float32, and " \
"bfloat16"); \
break; \
} \
#define LAUNCH_KERNEL_SF(T, BiasT, IdxT) \
do { \
switch (sf) { \
case vllm::moe::SCORING_NONE: \
vllm::moe::invokeNoAuxTc<T, BiasT, IdxT, vllm::moe::SCORING_NONE>( \
reinterpret_cast<T*>(scores.mutable_data_ptr()), \
reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \
reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \
reinterpret_cast<BiasT const*>(bias.data_ptr()), num_tokens, \
num_experts, n_group, topk_group, topk, renormalize, \
routed_scaling_factor, false, stream); \
break; \
case vllm::moe::SCORING_SIGMOID: \
vllm::moe::invokeNoAuxTc<T, BiasT, IdxT, vllm::moe::SCORING_SIGMOID>( \
reinterpret_cast<T*>(scores.mutable_data_ptr()), \
reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \
reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \
reinterpret_cast<BiasT const*>(bias.data_ptr()), num_tokens, \
num_experts, n_group, topk_group, topk, renormalize, \
routed_scaling_factor, false, stream); \
break; \
default: \
throw std::invalid_argument("Unsupported scoring_func"); \
break; \
} \
} while (0)
#define LAUNCH_KERNEL(T, IdxT) \
do { \
switch (bias_type) { \
case torch::kFloat16: \
LAUNCH_KERNEL_SF(T, half, IdxT); \
break; \
case torch::kFloat32: \
LAUNCH_KERNEL_SF(T, float, IdxT); \
break; \
case torch::kBFloat16: \
LAUNCH_KERNEL_SF(T, __nv_bfloat16, IdxT); \
break; \
default: \
throw std::invalid_argument( \
"Invalid bias dtype, only supports float16, float32, and " \
"bfloat16"); \
break; \
} \
} while (0)
switch (data_type) {
@@ -824,5 +1102,6 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
break;
}
#undef LAUNCH_KERNEL
#undef LAUNCH_KERNEL_SF
return {topk_values, topk_indices};
}

257
csrc/moe/moeTopKFuncs.cuh Normal file
View File

@@ -0,0 +1,257 @@
/*
* Adapted from
* https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc2/cpp/tensorrt_llm/kernels/moeTopKFuncs.cuh
* Copyright (c) 2026, The vLLM team.
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION. All rights
* reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cub/cub.cuh>
namespace vllm {
namespace moe {
namespace reduce_topk {
namespace cg = cooperative_groups;
static constexpr int kWARP_SIZE = 32;
template <typename T_>
struct TopKRedType {
using T = T_;
static_assert(
std::is_same_v<T, float> || std::is_same_v<T, half> ||
std::is_same_v<T, __nv_bfloat16> || std::is_same_v<T, int>,
"Top K reduction only implemented for int, float, float16 and bfloat16");
using TypeCmp = std::conditional_t<sizeof(T) == 4, uint64_t, uint32_t>;
using IdxT = std::conditional_t<sizeof(T) == 4, int32_t, int16_t>;
static constexpr int kMoveBits = (sizeof(T) == 4) ? 32 : 16;
static constexpr int kMaxIdx = 65535;
TypeCmp compValIdx;
static __host__ __device__ inline TypeCmp makeCmpVal(T val, int32_t idx = 0) {
auto valueBits = cub::Traits<T>::TwiddleIn(
reinterpret_cast<typename cub::Traits<T>::UnsignedBits&>(val));
TypeCmp compactTmp = valueBits;
compactTmp = (compactTmp << kMoveBits) | (0xFFFF & (kMaxIdx - idx));
// Use 65535 minus idx to give higher priority to elements with smaller
// indices.
return compactTmp;
}
static __host__ __device__ void unpack(T& value, int32_t& index,
TypeCmp cmp) {
// Since “65535-idx” is always smaller than 65536 and positive, we can
// directly use it as the lower 16 bits
index = kMaxIdx - static_cast<int32_t>((cmp & 0xFFFF));
auto compactTmp = cmp >> kMoveBits;
auto valueBits = cub::Traits<T>::TwiddleOut(
reinterpret_cast<typename cub::Traits<T>::UnsignedBits&>(compactTmp));
value = reinterpret_cast<T&>(valueBits);
}
__host__ __device__ TopKRedType() = default;
__host__ __device__ TopKRedType(T val, int32_t idx)
: compValIdx(makeCmpVal(val, idx)) {}
__host__ __device__ operator TypeCmp() const noexcept { return compValIdx; }
__device__ inline TypeCmp reduce(
cg::thread_block_tile<kWARP_SIZE> const& warp) {
return cg::reduce(warp, compValIdx, cg::greater<TypeCmp>{});
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int K_, bool Enable_>
struct TopKIdx {
// by default, empty
};
template <int K_>
struct TopKIdx<K_, true> {
static constexpr int K = K_;
int32_t val[K];
};
////////////////////////////////////////////////////////////////////////////////////////////////////
#define TOPK_SWAP(I, J) \
{ \
auto pairMin = min(topK[I].compValIdx, topK[J].compValIdx); \
auto pairMax = max(topK[I].compValIdx, topK[J].compValIdx); \
topK[I].compValIdx = pairMax; \
topK[J].compValIdx = pairMin; \
}
template <int N, typename RedType>
struct Sort;
template <typename RedType>
struct Sort<1, RedType> {
static __device__ void run(RedType* topK) {}
};
template <typename RedType>
struct Sort<2, RedType> {
static __device__ void run(RedType* topK) { TOPK_SWAP(0, 1); }
};
template <typename RedType>
struct Sort<3, RedType> {
static __device__ void run(RedType* topK) {
TOPK_SWAP(0, 1);
TOPK_SWAP(1, 2);
TOPK_SWAP(0, 1);
}
};
template <typename RedType>
struct Sort<4, RedType> {
static __device__ void run(RedType* topK) {
TOPK_SWAP(0, 2);
TOPK_SWAP(1, 3);
TOPK_SWAP(0, 1);
TOPK_SWAP(2, 3);
TOPK_SWAP(1, 2);
}
};
template <int K, typename Type>
__forceinline__ __device__ void reduceTopK(
cg::thread_block_tile<kWARP_SIZE> const& warp, Type (&out)[K],
int32_t (&outIdx)[K], Type value, int32_t idx, Type const minValue,
int actualK = K) {
static_assert(K > 0, "Top K must have K > 0");
static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE");
using RedType = TopKRedType<Type>;
RedType topK{value, idx};
typename RedType::TypeCmp packedMax{};
#pragma unroll
for (int kk = 0; kk < actualK; ++kk) {
topK =
kk > 0 && packedMax == topK.compValIdx ? RedType{minValue, idx} : topK;
// get the next largest value
packedMax = topK.reduce(warp);
RedType::unpack(out[kk], outIdx[kk], packedMax);
}
};
template <int K, typename Type, int N, bool IsSorted = false>
__device__ void reduceTopKFunc(cg::thread_block_tile<kWARP_SIZE> const& warp,
Type (&out)[K], int32_t (&outIdx)[K],
Type (&value)[N], int32_t (&idx)[N],
Type minValue, int actualK = K) {
static_assert(K > 0, "Top K must have K > 0");
static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE");
static_assert(N > 0, "Top K must have N > 0");
static_assert(N < 5,
"Only support candidates number less than or equal to 128");
using RedType = TopKRedType<Type>;
RedType topK[N];
#pragma unroll
for (int nn = 0; nn < N; ++nn) {
topK[nn] = RedType{value[nn], idx[nn]};
}
if constexpr (!IsSorted) {
Sort<N, RedType>::run(topK);
}
typename RedType::TypeCmp packedMax{};
#pragma unroll
for (int kk = 0; kk < actualK; ++kk) {
bool update = kk > 0 && packedMax == topK[0].compValIdx;
#pragma unroll
for (int nn = 0; nn < N; ++nn) {
topK[nn] = update && nn == N - 1 ? RedType{minValue, idx[nn]}
: update ? topK[nn + 1]
: topK[nn];
}
// get the next largest value
packedMax = topK[0].reduce(warp);
RedType::unpack(out[kk], outIdx[kk], packedMax);
}
};
template <int K, typename Type, int N>
__forceinline__ __device__ void reduceTopK(
cg::thread_block_tile<kWARP_SIZE> const& warp, Type (&out)[K],
int32_t (&outIdx)[K], Type (&value)[N], int32_t (&idx)[N],
Type const minValue, int actualK = K) {
static_assert(K > 0, "Top K must have K > 0");
static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE");
static_assert(N > 0, "Top K must have N > 0");
static_assert(
N <= 16,
"Only support candidates number less than or equal to 16*32=512");
static_assert(N <= 4 || N % 4 == 0,
"Only support candidates number is a multiple of 4*32=128 or "
"less than or equal to 4");
using RedType = TopKRedType<Type>;
if constexpr (N <= 4) {
reduceTopKFunc<K, Type, N>(warp, out, outIdx, value, idx, minValue,
actualK);
} else {
constexpr int numLoops = N / 4;
constexpr int numResults = (numLoops * K - 1) / kWARP_SIZE + 1;
Type topKBufferValue[numResults];
int32_t topKBufferIdx[numResults];
int32_t laneIdx = threadIdx.x % kWARP_SIZE;
for (int ii = 0; ii < numResults; ++ii) {
topKBufferValue[ii] = minValue;
topKBufferIdx[ii] = ii * kWARP_SIZE - 1;
}
for (int loop = 0; loop < numLoops; ++loop) {
int start = loop * 4;
Type topKValue[K];
int32_t topKIdx[K];
Type inValue[4];
int32_t inIdx[4];
for (int i = 0; i < 4; ++i) {
inValue[i] = value[start + i];
inIdx[i] = idx[start + i];
}
reduceTopKFunc<K, Type, 4>(warp, topKValue, topKIdx, inValue, inIdx,
minValue, actualK);
int inOffset = laneIdx % K;
if (laneIdx >= loop * K && laneIdx < (loop + 1) * K) {
topKBufferValue[0] = topKValue[inOffset];
topKBufferIdx[0] = topKIdx[inOffset];
}
if (loop == numLoops - 1 && (laneIdx < (numLoops * K - kWARP_SIZE))) {
topKBufferValue[1] = topKValue[inOffset];
topKBufferIdx[1] = topKIdx[inOffset];
}
}
reduceTopKFunc<K, Type, numResults>(warp, out, outIdx, topKBufferValue,
topKBufferIdx, minValue, actualK);
}
};
#undef TOPK_SWAP
} // namespace reduce_topk
} // namespace moe
} // namespace vllm

View File

@@ -8,6 +8,7 @@ Run `pytest tests/kernels/moe/test_grouped_topk.py`.
import pytest
import torch
import vllm.model_executor.layers.batch_invariant as batch_invariant
from vllm.config import (
CompilationConfig,
VllmConfig,
@@ -27,11 +28,17 @@ from vllm.utils.torch_utils import set_random_seed
)
@pytest.mark.parametrize("n_token", [1, 33, 64])
@pytest.mark.parametrize("n_hidden", [1024, 2048])
@pytest.mark.parametrize("n_expert", [16])
@pytest.mark.parametrize("topk", [2])
@pytest.mark.parametrize(
"n_expert,topk,num_expert_group,topk_group",
[
(16, 2, 8, 2),
(128, 2, 8, 2),
(256, 8, 8, 4),
(384, 8, 1, 1),
(512, 22, 1, 1),
],
)
@pytest.mark.parametrize("renormalize", [True, False])
@pytest.mark.parametrize("num_expert_group", [8])
@pytest.mark.parametrize("topk_group", [2])
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5])
@pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float32])
@@ -42,9 +49,9 @@ def test_grouped_topk(
n_hidden: int,
n_expert: int,
topk: int,
renormalize: bool,
num_expert_group: int,
topk_group: int,
renormalize: bool,
scoring_func: str,
routed_scaling_factor: float,
input_dtype: torch.dtype,
@@ -62,6 +69,7 @@ def test_grouped_topk(
with set_current_vllm_config(vllm_config), monkeypatch.context() as m:
m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0")
m.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", True)
grouped_topk = GroupedTopk(
topk=topk,
renormalize=renormalize,
@@ -89,8 +97,7 @@ def test_grouped_topk(
e_score_correction_bias=e_score_correction_bias,
)
if renormalize:
torch.testing.assert_close(
baseline_topk_weights, test_topk_weights, atol=2e-2, rtol=0
)
torch.testing.assert_close(
baseline_topk_weights, test_topk_weights, atol=2e-2, rtol=0
)
torch.testing.assert_close(baseline_topk_ids, test_topk_ids, atol=0, rtol=0)