[Kernel] Optimize grouped topk kernel (#34206)
Signed-off-by: Xin Yang <xyangx@amazon.com>
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
/*
|
||||
* Adapted from
|
||||
* https://github.com/NVIDIA/TensorRT-LLM/blob/v0.21.0/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu
|
||||
* https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc2/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu
|
||||
* Copyright (c) 2025, The vLLM team.
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
|
||||
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
|
||||
@@ -17,8 +17,10 @@
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "moeTopKFuncs.cuh"
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/all.h>
|
||||
#include <cmath>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda/std/limits>
|
||||
@@ -30,7 +32,17 @@ namespace vllm {
|
||||
namespace moe {
|
||||
|
||||
constexpr unsigned FULL_WARP_MASK = 0xffffffff;
|
||||
constexpr int32_t WARP_SIZE = 32;
|
||||
static constexpr int WARP_SIZE = 32;
|
||||
static constexpr int NumNemotronExperts = 512;
|
||||
static constexpr int NumKimiK2Experts = 384;
|
||||
static constexpr int NumDeepseekExperts = 256;
|
||||
static constexpr int MaxSupportedExpertCount =
|
||||
std::max({NumNemotronExperts, NumKimiK2Experts, NumDeepseekExperts});
|
||||
static constexpr int MaxNumExpertsUnit = 128;
|
||||
static constexpr int NumTopGroupScores = 2;
|
||||
static constexpr int DefaultMaxNumTopExperts = 8;
|
||||
static constexpr int MaxSupportedTopExperts = 22;
|
||||
static constexpr int MaxNumTopGroups = 4;
|
||||
|
||||
namespace warp_topk {
|
||||
|
||||
@@ -657,76 +669,335 @@ __global__ void grouped_topk_fused_kernel(
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, typename BiasT, typename IdxT>
|
||||
template <typename T, typename BiasT, typename IdxT, ScoringFunc SF,
|
||||
int MaxNumExperts, bool UseGroups,
|
||||
int MaxNumTopExperts = DefaultMaxNumTopExperts>
|
||||
__global__ void grouped_topk_fused_small_expert_count_kernel(
|
||||
T* scores, float* topkValues, IdxT* topkIndices, BiasT const* routingBias,
|
||||
int64_t const numTokens, int64_t const numGroup, int64_t const topkGroup,
|
||||
int64_t const topk, int64_t const numExperts,
|
||||
int64_t const numExpertsPerGroup, bool const renormalize,
|
||||
double const routedScalingFactor) {
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
cudaGridDependencySynchronize();
|
||||
#endif
|
||||
// declare shared memory structure
|
||||
// number of experts is bounded by number of threads
|
||||
__shared__ float __attribute((aligned(128))) smemScoreSigmoid[MaxNumExperts];
|
||||
__shared__ float __attribute((aligned(128))) smemScoreBias[MaxNumExperts];
|
||||
// number of expert groups is bounded by number of warps
|
||||
int constexpr NumWarps = MaxNumExperts / WARP_SIZE;
|
||||
__shared__ float __attribute((aligned(128))) smemGroupScores[NumWarps];
|
||||
|
||||
// needed for warp reduce
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
|
||||
// for the final reduction of weight norm, only some lanes need to participate
|
||||
int32_t laneIdx = threadIdx.x % WARP_SIZE;
|
||||
int32_t warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WARP_SIZE, 0);
|
||||
|
||||
if constexpr (UseGroups) {
|
||||
if (warpIdx >= numGroup) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
// note that for invalid scores, we simply use a negative value:
|
||||
// they work well even with the compacted format used in topK, and
|
||||
// sigmoid / bias activated scores cannot be negative
|
||||
const float invalidScoreFloat = float{-INFINITY};
|
||||
|
||||
// load bias already; each warp represents one expert group
|
||||
auto threadExpert = threadIdx.x;
|
||||
bool expertSelected = threadExpert < numExperts;
|
||||
if constexpr (UseGroups) {
|
||||
threadExpert = warpIdx * numExpertsPerGroup + laneIdx;
|
||||
expertSelected = laneIdx < numExpertsPerGroup;
|
||||
}
|
||||
|
||||
auto scoreIdx = int64_t{blockIdx.x} * int64_t{numExperts} + threadExpert;
|
||||
auto biasVal = expertSelected ? static_cast<float>(routingBias[threadExpert])
|
||||
: invalidScoreFloat;
|
||||
topkValues += blockIdx.x * topk;
|
||||
topkIndices += blockIdx.x * topk;
|
||||
|
||||
// get our assigned thread score; each warp represents one expert group
|
||||
float score =
|
||||
expertSelected ? static_cast<float>(scores[scoreIdx]) : invalidScoreFloat;
|
||||
auto scoreSigmoid = apply_scoring<SF>(score);
|
||||
// write the sigmoid score to shared for later use
|
||||
if (expertSelected) {
|
||||
smemScoreSigmoid[threadExpert] = scoreSigmoid;
|
||||
}
|
||||
|
||||
// get the score with bias
|
||||
// note that with invalid values, because sigmoid is < 1 and bias is -1,
|
||||
// we must get a negative value, which is smaller than any valid value
|
||||
auto scoreBias = float{scoreSigmoid + float{biasVal}};
|
||||
|
||||
if (expertSelected) {
|
||||
smemScoreBias[threadExpert] = scoreBias;
|
||||
}
|
||||
|
||||
// registers for top group score reduction
|
||||
float topExpGroupScores[NumTopGroupScores];
|
||||
[[maybe_unused]] int32_t topExpGroupIdx[NumTopGroupScores];
|
||||
float topGroups[MaxNumTopGroups]; // bound of numGroup
|
||||
int32_t topGroupIdx[MaxNumTopGroups];
|
||||
float expertScoreGroup[MaxNumTopGroups];
|
||||
int32_t expertIdxGroup[MaxNumTopGroups];
|
||||
float topScores[MaxNumTopExperts]; // bound of topk
|
||||
int32_t topExperts[MaxNumTopExperts];
|
||||
|
||||
if constexpr (UseGroups) {
|
||||
reduce_topk::reduceTopK(warp, topExpGroupScores, topExpGroupIdx, scoreBias,
|
||||
threadExpert,
|
||||
/* minValue */ invalidScoreFloat);
|
||||
|
||||
// get the final group score and write it to shared
|
||||
if (warp.thread_rank() == 0) {
|
||||
auto groupScore = topExpGroupScores[0] + topExpGroupScores[1];
|
||||
smemGroupScores[warpIdx] = groupScore;
|
||||
}
|
||||
}
|
||||
|
||||
// make group scores available to all warps
|
||||
__syncthreads();
|
||||
|
||||
if constexpr (UseGroups) {
|
||||
if (warpIdx == 0) {
|
||||
// a single warp performs the selection of top groups, and goes on to
|
||||
// select the final experts
|
||||
float groupScore =
|
||||
laneIdx < numGroup ? smemGroupScores[laneIdx] : invalidScoreFloat;
|
||||
|
||||
reduce_topk::reduceTopK(warp, topGroups, topGroupIdx, groupScore, laneIdx,
|
||||
/* minValue */ invalidScoreFloat);
|
||||
// final expert selection: get relevant indexes and scores from shared
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < MaxNumTopGroups; ++ii) { // bound of numGroup
|
||||
auto groupIdx = topGroupIdx[ii];
|
||||
expertIdxGroup[ii] = groupIdx * numExpertsPerGroup + laneIdx;
|
||||
|
||||
expertScoreGroup[ii] = (ii < topkGroup) && expertSelected
|
||||
? smemScoreBias[expertIdxGroup[ii]]
|
||||
: invalidScoreFloat;
|
||||
}
|
||||
|
||||
reduce_topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup,
|
||||
expertIdxGroup, /* minValue */ invalidScoreFloat,
|
||||
topk);
|
||||
}
|
||||
} else if constexpr (MaxNumExperts > MaxNumExpertsUnit) {
|
||||
// without groups, and the expert number is larger than MaxNumExpertsUnit,
|
||||
// we need to use multiple warps to calculate the intermediate topk results
|
||||
|
||||
int constexpr NumExpertWarps = (MaxNumExperts - 1) / MaxNumExpertsUnit + 1;
|
||||
int constexpr NumInterTopK = NumExpertWarps * MaxNumTopExperts;
|
||||
__shared__ float
|
||||
__attribute((aligned(128))) smemInterTopScores[NumInterTopK];
|
||||
__shared__ int32_t
|
||||
__attribute((aligned(128))) smemInterTopExperts[NumInterTopK];
|
||||
if (warpIdx < NumExpertWarps) {
|
||||
int offset = warpIdx * WARP_SIZE * MaxNumTopGroups;
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < MaxNumTopGroups; ++ii) {
|
||||
auto expertIdx = ii * WARP_SIZE + laneIdx;
|
||||
expertIdxGroup[ii] = offset + expertIdx;
|
||||
expertScoreGroup[ii] = offset + expertIdx < numExperts
|
||||
? smemScoreBias[offset + expertIdx]
|
||||
: invalidScoreFloat;
|
||||
}
|
||||
reduce_topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup,
|
||||
expertIdxGroup,
|
||||
/* minValue */ invalidScoreFloat, topk);
|
||||
|
||||
if (laneIdx < topk) {
|
||||
smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] =
|
||||
topScores[laneIdx];
|
||||
smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] =
|
||||
topExperts[laneIdx];
|
||||
} else if (laneIdx >= topk && laneIdx < MaxNumTopExperts) {
|
||||
smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] =
|
||||
invalidScoreFloat;
|
||||
smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] =
|
||||
MaxNumExperts - 1;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
if (warpIdx == 0) {
|
||||
int constexpr NumInterTopKPerThread = (NumInterTopK - 1) / WARP_SIZE + 1;
|
||||
float intermediateScore[NumInterTopKPerThread];
|
||||
int32_t intermediateExpert[NumInterTopKPerThread];
|
||||
for (int i = laneIdx; i < NumInterTopKPerThread * WARP_SIZE;
|
||||
i += WARP_SIZE) {
|
||||
int ii = i / WARP_SIZE;
|
||||
if (i < NumInterTopK) {
|
||||
intermediateScore[ii] = smemInterTopScores[i];
|
||||
intermediateExpert[ii] = smemInterTopExperts[i];
|
||||
} else {
|
||||
intermediateScore[ii] = invalidScoreFloat;
|
||||
intermediateExpert[ii] = MaxNumExperts - 1;
|
||||
}
|
||||
}
|
||||
reduce_topk::reduceTopK(warp, topScores, topExperts, intermediateScore,
|
||||
intermediateExpert,
|
||||
/* minValue */ invalidScoreFloat, topk);
|
||||
}
|
||||
} else {
|
||||
// without groups, and the expert number is smaller than MaxNumExpertsUnit
|
||||
// each thread just takes `MaxNumTopGroups` experts
|
||||
if (warpIdx == 0) {
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < MaxNumTopGroups; ++ii) {
|
||||
auto expertIdx = ii * WARP_SIZE + laneIdx;
|
||||
expertIdxGroup[ii] = expertIdx;
|
||||
expertScoreGroup[ii] = expertIdx < numExperts ? smemScoreBias[expertIdx]
|
||||
: invalidScoreFloat;
|
||||
}
|
||||
reduce_topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup,
|
||||
expertIdxGroup,
|
||||
/* minValue */ invalidScoreFloat, topk);
|
||||
}
|
||||
}
|
||||
|
||||
if (warpIdx == 0) {
|
||||
// determine our lane's expert index and write to output
|
||||
int32_t expertIdx =
|
||||
laneIdx < topk ? topExperts[laneIdx] : MaxNumExperts - 1;
|
||||
float scoreNorm = laneIdx < topk ? smemScoreSigmoid[expertIdx] : 0.F;
|
||||
float finalScore = static_cast<float>(scoreNorm * routedScalingFactor);
|
||||
// norm the value
|
||||
if (renormalize) {
|
||||
auto redNorm = cg::reduce(warp, scoreNorm, cg::plus<float>{});
|
||||
finalScore /= (redNorm + 1e-20);
|
||||
}
|
||||
// store the topk scores and experts to output
|
||||
if (laneIdx < topk) {
|
||||
topkValues[laneIdx] = finalScore;
|
||||
topkIndices[laneIdx] = expertIdx;
|
||||
}
|
||||
}
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, typename BiasT, typename IdxT, ScoringFunc SF>
|
||||
void invokeNoAuxTc(T* scores, float* topk_values, IdxT* topk_indices,
|
||||
BiasT const* bias, int64_t const num_tokens,
|
||||
int64_t const num_experts, int64_t const n_group,
|
||||
int64_t const topk_group, int64_t const topk,
|
||||
bool const renormalize, double const routed_scaling_factor,
|
||||
int const scoring_func, bool enable_pdl = false,
|
||||
cudaStream_t const stream = 0) {
|
||||
bool enable_pdl = false, cudaStream_t const stream = 0) {
|
||||
cudaLaunchConfig_t config;
|
||||
// One block per token; one warp per group.
|
||||
config.gridDim = static_cast<uint32_t>(num_tokens);
|
||||
config.blockDim = static_cast<uint32_t>(n_group) * WARP_SIZE;
|
||||
// Dynamic shared memory: WarpSelect staging + per-group topk buffers.
|
||||
int32_t const num_warps = static_cast<int32_t>(n_group);
|
||||
size_t const val_bytes =
|
||||
static_cast<size_t>(num_warps) * WARP_SIZE * sizeof(T);
|
||||
size_t const val_bytes_aligned =
|
||||
warp_topk::round_up_to_multiple_of<256>(val_bytes);
|
||||
size_t const idx_bytes =
|
||||
static_cast<size_t>(num_warps) * WARP_SIZE * sizeof(int32_t);
|
||||
size_t const internal_bytes = val_bytes_aligned + idx_bytes;
|
||||
size_t const extra_bytes = 16 + static_cast<size_t>(n_group) * sizeof(T);
|
||||
config.dynamicSmemBytes = internal_bytes + extra_bytes;
|
||||
config.stream = stream;
|
||||
cudaLaunchAttribute attrs[1];
|
||||
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
|
||||
attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
|
||||
config.numAttrs = 1;
|
||||
config.attrs = attrs;
|
||||
auto const sf = static_cast<ScoringFunc>(scoring_func);
|
||||
switch (sf) {
|
||||
case SCORING_NONE: {
|
||||
auto* kernel_instance =
|
||||
&grouped_topk_fused_kernel<T, BiasT, IdxT, SCORING_NONE>;
|
||||
cudaLaunchKernelEx(&config, kernel_instance, scores, topk_values,
|
||||
topk_indices, bias, num_tokens, num_experts, n_group,
|
||||
topk_group, topk, renormalize, routed_scaling_factor);
|
||||
return;
|
||||
|
||||
// Check if we can use the optimized
|
||||
// grouped_topk_fused_small_expert_count_kernel
|
||||
bool const is_single_group =
|
||||
(n_group == 1) && (topk_group == 1) &&
|
||||
(num_experts <= MaxSupportedExpertCount) &&
|
||||
(topk <= DefaultMaxNumTopExperts || topk == MaxSupportedTopExperts);
|
||||
|
||||
int64_t const experts_per_group = num_experts / n_group;
|
||||
bool const is_multi_group =
|
||||
(n_group > 1) && (num_experts <= NumDeepseekExperts) &&
|
||||
(experts_per_group <= WARP_SIZE) &&
|
||||
(experts_per_group * topk_group <= MaxNumExpertsUnit) &&
|
||||
(topk <= DefaultMaxNumTopExperts) && (topk_group <= MaxNumTopGroups);
|
||||
|
||||
if (is_single_group || is_multi_group) {
|
||||
auto* kernel_instance =
|
||||
&grouped_topk_fused_small_expert_count_kernel<T, BiasT, IdxT, SF,
|
||||
NumDeepseekExperts, true>;
|
||||
int num_threads = NumDeepseekExperts;
|
||||
if (is_single_group) {
|
||||
// Special case for Nemotron, which selects top 22 from 512 experts, and 1
|
||||
// group only.
|
||||
if (num_experts == NumNemotronExperts && n_group == 1 &&
|
||||
topk == MaxSupportedTopExperts) {
|
||||
kernel_instance = &grouped_topk_fused_small_expert_count_kernel<
|
||||
T, BiasT, IdxT, SF, NumNemotronExperts, false,
|
||||
MaxSupportedTopExperts>;
|
||||
num_threads = NumNemotronExperts;
|
||||
} else if (num_experts > NumKimiK2Experts &&
|
||||
num_experts <= MaxSupportedExpertCount) {
|
||||
kernel_instance = &grouped_topk_fused_small_expert_count_kernel<
|
||||
T, BiasT, IdxT, SF, MaxSupportedExpertCount, false>;
|
||||
num_threads = MaxSupportedExpertCount;
|
||||
} else if (num_experts > MaxNumExpertsUnit &&
|
||||
num_experts <= NumKimiK2Experts) {
|
||||
kernel_instance = &grouped_topk_fused_small_expert_count_kernel<
|
||||
T, BiasT, IdxT, SF, NumKimiK2Experts, false>;
|
||||
num_threads = NumKimiK2Experts;
|
||||
} else {
|
||||
kernel_instance = &grouped_topk_fused_small_expert_count_kernel<
|
||||
T, BiasT, IdxT, SF, MaxNumExpertsUnit, false>;
|
||||
num_threads = MaxNumExpertsUnit;
|
||||
}
|
||||
}
|
||||
case SCORING_SIGMOID: {
|
||||
auto* kernel_instance =
|
||||
&grouped_topk_fused_kernel<T, BiasT, IdxT, SCORING_SIGMOID>;
|
||||
cudaLaunchKernelEx(&config, kernel_instance, scores, topk_values,
|
||||
topk_indices, bias, num_tokens, num_experts, n_group,
|
||||
topk_group, topk, renormalize, routed_scaling_factor);
|
||||
return;
|
||||
}
|
||||
default:
|
||||
// should be guarded by higher level checks.
|
||||
TORCH_CHECK(false, "Unsupported scoring_func in invokeNoAuxTc");
|
||||
config.gridDim = num_tokens;
|
||||
config.blockDim = num_threads;
|
||||
config.dynamicSmemBytes = 0;
|
||||
cudaLaunchKernelEx(&config, kernel_instance, scores, topk_values,
|
||||
topk_indices, bias, num_tokens, n_group, topk_group,
|
||||
topk, num_experts, num_experts / n_group, renormalize,
|
||||
routed_scaling_factor);
|
||||
} else {
|
||||
auto* kernel_instance = &grouped_topk_fused_kernel<T, BiasT, IdxT, SF>;
|
||||
// One block per token; one warp per group.
|
||||
config.gridDim = static_cast<uint32_t>(num_tokens);
|
||||
config.blockDim = static_cast<uint32_t>(n_group) * WARP_SIZE;
|
||||
// Dynamic shared memory: WarpSelect staging + per-group topk buffers.
|
||||
int32_t const num_warps = static_cast<int32_t>(n_group);
|
||||
size_t const val_bytes =
|
||||
static_cast<size_t>(num_warps) * WARP_SIZE * sizeof(T);
|
||||
size_t const val_bytes_aligned =
|
||||
warp_topk::round_up_to_multiple_of<256>(val_bytes);
|
||||
size_t const idx_bytes =
|
||||
static_cast<size_t>(num_warps) * WARP_SIZE * sizeof(int32_t);
|
||||
size_t const internal_bytes = val_bytes_aligned + idx_bytes;
|
||||
size_t const extra_bytes = 16 + static_cast<size_t>(n_group) * sizeof(T);
|
||||
config.dynamicSmemBytes = internal_bytes + extra_bytes;
|
||||
cudaLaunchKernelEx(&config, kernel_instance, scores, topk_values,
|
||||
topk_indices, bias, num_tokens, num_experts, n_group,
|
||||
topk_group, topk, renormalize, routed_scaling_factor);
|
||||
}
|
||||
}
|
||||
|
||||
#define INSTANTIATE_NOAUX_TC(T, BiasT, IdxT) \
|
||||
template void invokeNoAuxTc<T, BiasT, IdxT>( \
|
||||
#define INSTANTIATE_NOAUX_TC(T, BiasT, IdxT, SF) \
|
||||
template void invokeNoAuxTc<T, BiasT, IdxT, SF>( \
|
||||
T * scores, float* topk_values, IdxT* topk_indices, BiasT const* bias, \
|
||||
int64_t const num_tokens, int64_t const num_experts, \
|
||||
int64_t const n_group, int64_t const topk_group, int64_t const topk, \
|
||||
bool const renormalize, double const routed_scaling_factor, \
|
||||
int const scoring_func, bool enable_pdl, cudaStream_t const stream);
|
||||
bool enable_pdl, cudaStream_t const stream);
|
||||
|
||||
INSTANTIATE_NOAUX_TC(float, float, int32_t);
|
||||
INSTANTIATE_NOAUX_TC(float, half, int32_t);
|
||||
INSTANTIATE_NOAUX_TC(float, __nv_bfloat16, int32_t);
|
||||
INSTANTIATE_NOAUX_TC(half, float, int32_t);
|
||||
INSTANTIATE_NOAUX_TC(half, half, int32_t);
|
||||
INSTANTIATE_NOAUX_TC(half, __nv_bfloat16, int32_t);
|
||||
INSTANTIATE_NOAUX_TC(__nv_bfloat16, float, int32_t);
|
||||
INSTANTIATE_NOAUX_TC(__nv_bfloat16, half, int32_t);
|
||||
INSTANTIATE_NOAUX_TC(__nv_bfloat16, __nv_bfloat16, int32_t);
|
||||
INSTANTIATE_NOAUX_TC(float, float, int32_t, SCORING_SIGMOID);
|
||||
INSTANTIATE_NOAUX_TC(float, half, int32_t, SCORING_SIGMOID);
|
||||
INSTANTIATE_NOAUX_TC(float, __nv_bfloat16, int32_t, SCORING_SIGMOID);
|
||||
INSTANTIATE_NOAUX_TC(half, float, int32_t, SCORING_SIGMOID);
|
||||
INSTANTIATE_NOAUX_TC(half, half, int32_t, SCORING_SIGMOID);
|
||||
INSTANTIATE_NOAUX_TC(half, __nv_bfloat16, int32_t, SCORING_SIGMOID);
|
||||
INSTANTIATE_NOAUX_TC(__nv_bfloat16, float, int32_t, SCORING_SIGMOID);
|
||||
INSTANTIATE_NOAUX_TC(__nv_bfloat16, half, int32_t, SCORING_SIGMOID);
|
||||
INSTANTIATE_NOAUX_TC(__nv_bfloat16, __nv_bfloat16, int32_t, SCORING_SIGMOID);
|
||||
INSTANTIATE_NOAUX_TC(float, float, int32_t, SCORING_NONE);
|
||||
INSTANTIATE_NOAUX_TC(float, half, int32_t, SCORING_NONE);
|
||||
INSTANTIATE_NOAUX_TC(float, __nv_bfloat16, int32_t, SCORING_NONE);
|
||||
INSTANTIATE_NOAUX_TC(half, float, int32_t, SCORING_NONE);
|
||||
INSTANTIATE_NOAUX_TC(half, half, int32_t, SCORING_NONE);
|
||||
INSTANTIATE_NOAUX_TC(half, __nv_bfloat16, int32_t, SCORING_NONE);
|
||||
INSTANTIATE_NOAUX_TC(__nv_bfloat16, float, int32_t, SCORING_NONE);
|
||||
INSTANTIATE_NOAUX_TC(__nv_bfloat16, half, int32_t, SCORING_NONE);
|
||||
INSTANTIATE_NOAUX_TC(__nv_bfloat16, __nv_bfloat16, int32_t, SCORING_NONE);
|
||||
} // end namespace moe
|
||||
} // namespace vllm
|
||||
|
||||
@@ -762,46 +1033,53 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
|
||||
{num_tokens, topk}, torch::dtype(torch::kInt32).device(torch::kCUDA));
|
||||
|
||||
auto stream = c10::cuda::getCurrentCUDAStream(scores.get_device());
|
||||
auto const sf = static_cast<vllm::moe::ScoringFunc>(scoring_func);
|
||||
|
||||
#define LAUNCH_KERNEL(T, IdxT) \
|
||||
do { \
|
||||
switch (bias_type) { \
|
||||
case torch::kFloat16: \
|
||||
vllm::moe::invokeNoAuxTc<T, half, IdxT>( \
|
||||
reinterpret_cast<T*>(scores.mutable_data_ptr()), \
|
||||
reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \
|
||||
reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \
|
||||
reinterpret_cast<half const*>(bias.data_ptr()), num_tokens, \
|
||||
num_experts, n_group, topk_group, topk, renormalize, \
|
||||
routed_scaling_factor, static_cast<int>(scoring_func), false, \
|
||||
stream); \
|
||||
break; \
|
||||
case torch::kFloat32: \
|
||||
vllm::moe::invokeNoAuxTc<T, float, IdxT>( \
|
||||
reinterpret_cast<T*>(scores.mutable_data_ptr()), \
|
||||
reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \
|
||||
reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \
|
||||
reinterpret_cast<float const*>(bias.data_ptr()), num_tokens, \
|
||||
num_experts, n_group, topk_group, topk, renormalize, \
|
||||
routed_scaling_factor, static_cast<int>(scoring_func), false, \
|
||||
stream); \
|
||||
break; \
|
||||
case torch::kBFloat16: \
|
||||
vllm::moe::invokeNoAuxTc<T, __nv_bfloat16, IdxT>( \
|
||||
reinterpret_cast<T*>(scores.mutable_data_ptr()), \
|
||||
reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \
|
||||
reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \
|
||||
reinterpret_cast<__nv_bfloat16 const*>(bias.data_ptr()), \
|
||||
num_tokens, num_experts, n_group, topk_group, topk, renormalize, \
|
||||
routed_scaling_factor, static_cast<int>(scoring_func), false, \
|
||||
stream); \
|
||||
break; \
|
||||
default: \
|
||||
throw std::invalid_argument( \
|
||||
"Invalid bias dtype, only supports float16, float32, and " \
|
||||
"bfloat16"); \
|
||||
break; \
|
||||
} \
|
||||
#define LAUNCH_KERNEL_SF(T, BiasT, IdxT) \
|
||||
do { \
|
||||
switch (sf) { \
|
||||
case vllm::moe::SCORING_NONE: \
|
||||
vllm::moe::invokeNoAuxTc<T, BiasT, IdxT, vllm::moe::SCORING_NONE>( \
|
||||
reinterpret_cast<T*>(scores.mutable_data_ptr()), \
|
||||
reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \
|
||||
reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \
|
||||
reinterpret_cast<BiasT const*>(bias.data_ptr()), num_tokens, \
|
||||
num_experts, n_group, topk_group, topk, renormalize, \
|
||||
routed_scaling_factor, false, stream); \
|
||||
break; \
|
||||
case vllm::moe::SCORING_SIGMOID: \
|
||||
vllm::moe::invokeNoAuxTc<T, BiasT, IdxT, vllm::moe::SCORING_SIGMOID>( \
|
||||
reinterpret_cast<T*>(scores.mutable_data_ptr()), \
|
||||
reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \
|
||||
reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \
|
||||
reinterpret_cast<BiasT const*>(bias.data_ptr()), num_tokens, \
|
||||
num_experts, n_group, topk_group, topk, renormalize, \
|
||||
routed_scaling_factor, false, stream); \
|
||||
break; \
|
||||
default: \
|
||||
throw std::invalid_argument("Unsupported scoring_func"); \
|
||||
break; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define LAUNCH_KERNEL(T, IdxT) \
|
||||
do { \
|
||||
switch (bias_type) { \
|
||||
case torch::kFloat16: \
|
||||
LAUNCH_KERNEL_SF(T, half, IdxT); \
|
||||
break; \
|
||||
case torch::kFloat32: \
|
||||
LAUNCH_KERNEL_SF(T, float, IdxT); \
|
||||
break; \
|
||||
case torch::kBFloat16: \
|
||||
LAUNCH_KERNEL_SF(T, __nv_bfloat16, IdxT); \
|
||||
break; \
|
||||
default: \
|
||||
throw std::invalid_argument( \
|
||||
"Invalid bias dtype, only supports float16, float32, and " \
|
||||
"bfloat16"); \
|
||||
break; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
switch (data_type) {
|
||||
@@ -824,5 +1102,6 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
|
||||
break;
|
||||
}
|
||||
#undef LAUNCH_KERNEL
|
||||
#undef LAUNCH_KERNEL_SF
|
||||
return {topk_values, topk_indices};
|
||||
}
|
||||
|
||||
257
csrc/moe/moeTopKFuncs.cuh
Normal file
257
csrc/moe/moeTopKFuncs.cuh
Normal file
@@ -0,0 +1,257 @@
|
||||
/*
|
||||
* Adapted from
|
||||
* https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc2/cpp/tensorrt_llm/kernels/moeTopKFuncs.cuh
|
||||
* Copyright (c) 2026, The vLLM team.
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION. All rights
|
||||
* reserved. SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <cub/cub.cuh>
|
||||
|
||||
namespace vllm {
|
||||
namespace moe {
|
||||
namespace reduce_topk {
|
||||
namespace cg = cooperative_groups;
|
||||
static constexpr int kWARP_SIZE = 32;
|
||||
|
||||
template <typename T_>
|
||||
struct TopKRedType {
|
||||
using T = T_;
|
||||
static_assert(
|
||||
std::is_same_v<T, float> || std::is_same_v<T, half> ||
|
||||
std::is_same_v<T, __nv_bfloat16> || std::is_same_v<T, int>,
|
||||
"Top K reduction only implemented for int, float, float16 and bfloat16");
|
||||
|
||||
using TypeCmp = std::conditional_t<sizeof(T) == 4, uint64_t, uint32_t>;
|
||||
using IdxT = std::conditional_t<sizeof(T) == 4, int32_t, int16_t>;
|
||||
|
||||
static constexpr int kMoveBits = (sizeof(T) == 4) ? 32 : 16;
|
||||
static constexpr int kMaxIdx = 65535;
|
||||
TypeCmp compValIdx;
|
||||
|
||||
static __host__ __device__ inline TypeCmp makeCmpVal(T val, int32_t idx = 0) {
|
||||
auto valueBits = cub::Traits<T>::TwiddleIn(
|
||||
reinterpret_cast<typename cub::Traits<T>::UnsignedBits&>(val));
|
||||
TypeCmp compactTmp = valueBits;
|
||||
compactTmp = (compactTmp << kMoveBits) | (0xFFFF & (kMaxIdx - idx));
|
||||
// Use 65535 minus idx to give higher priority to elements with smaller
|
||||
// indices.
|
||||
return compactTmp;
|
||||
}
|
||||
|
||||
static __host__ __device__ void unpack(T& value, int32_t& index,
|
||||
TypeCmp cmp) {
|
||||
// Since “65535-idx” is always smaller than 65536 and positive, we can
|
||||
// directly use it as the lower 16 bits
|
||||
index = kMaxIdx - static_cast<int32_t>((cmp & 0xFFFF));
|
||||
|
||||
auto compactTmp = cmp >> kMoveBits;
|
||||
auto valueBits = cub::Traits<T>::TwiddleOut(
|
||||
reinterpret_cast<typename cub::Traits<T>::UnsignedBits&>(compactTmp));
|
||||
value = reinterpret_cast<T&>(valueBits);
|
||||
}
|
||||
|
||||
__host__ __device__ TopKRedType() = default;
|
||||
|
||||
__host__ __device__ TopKRedType(T val, int32_t idx)
|
||||
: compValIdx(makeCmpVal(val, idx)) {}
|
||||
|
||||
__host__ __device__ operator TypeCmp() const noexcept { return compValIdx; }
|
||||
|
||||
__device__ inline TypeCmp reduce(
|
||||
cg::thread_block_tile<kWARP_SIZE> const& warp) {
|
||||
return cg::reduce(warp, compValIdx, cg::greater<TypeCmp>{});
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int K_, bool Enable_>
|
||||
struct TopKIdx {
|
||||
// by default, empty
|
||||
};
|
||||
|
||||
template <int K_>
|
||||
struct TopKIdx<K_, true> {
|
||||
static constexpr int K = K_;
|
||||
int32_t val[K];
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define TOPK_SWAP(I, J) \
|
||||
{ \
|
||||
auto pairMin = min(topK[I].compValIdx, topK[J].compValIdx); \
|
||||
auto pairMax = max(topK[I].compValIdx, topK[J].compValIdx); \
|
||||
topK[I].compValIdx = pairMax; \
|
||||
topK[J].compValIdx = pairMin; \
|
||||
}
|
||||
|
||||
template <int N, typename RedType>
|
||||
struct Sort;
|
||||
|
||||
template <typename RedType>
|
||||
struct Sort<1, RedType> {
|
||||
static __device__ void run(RedType* topK) {}
|
||||
};
|
||||
|
||||
template <typename RedType>
|
||||
struct Sort<2, RedType> {
|
||||
static __device__ void run(RedType* topK) { TOPK_SWAP(0, 1); }
|
||||
};
|
||||
|
||||
template <typename RedType>
|
||||
struct Sort<3, RedType> {
|
||||
static __device__ void run(RedType* topK) {
|
||||
TOPK_SWAP(0, 1);
|
||||
TOPK_SWAP(1, 2);
|
||||
TOPK_SWAP(0, 1);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename RedType>
|
||||
struct Sort<4, RedType> {
|
||||
static __device__ void run(RedType* topK) {
|
||||
TOPK_SWAP(0, 2);
|
||||
TOPK_SWAP(1, 3);
|
||||
TOPK_SWAP(0, 1);
|
||||
TOPK_SWAP(2, 3);
|
||||
TOPK_SWAP(1, 2);
|
||||
}
|
||||
};
|
||||
|
||||
template <int K, typename Type>
|
||||
__forceinline__ __device__ void reduceTopK(
|
||||
cg::thread_block_tile<kWARP_SIZE> const& warp, Type (&out)[K],
|
||||
int32_t (&outIdx)[K], Type value, int32_t idx, Type const minValue,
|
||||
int actualK = K) {
|
||||
static_assert(K > 0, "Top K must have K > 0");
|
||||
static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE");
|
||||
using RedType = TopKRedType<Type>;
|
||||
RedType topK{value, idx};
|
||||
typename RedType::TypeCmp packedMax{};
|
||||
#pragma unroll
|
||||
for (int kk = 0; kk < actualK; ++kk) {
|
||||
topK =
|
||||
kk > 0 && packedMax == topK.compValIdx ? RedType{minValue, idx} : topK;
|
||||
// get the next largest value
|
||||
packedMax = topK.reduce(warp);
|
||||
RedType::unpack(out[kk], outIdx[kk], packedMax);
|
||||
}
|
||||
};
|
||||
|
||||
template <int K, typename Type, int N, bool IsSorted = false>
|
||||
__device__ void reduceTopKFunc(cg::thread_block_tile<kWARP_SIZE> const& warp,
|
||||
Type (&out)[K], int32_t (&outIdx)[K],
|
||||
Type (&value)[N], int32_t (&idx)[N],
|
||||
Type minValue, int actualK = K) {
|
||||
static_assert(K > 0, "Top K must have K > 0");
|
||||
static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE");
|
||||
static_assert(N > 0, "Top K must have N > 0");
|
||||
static_assert(N < 5,
|
||||
"Only support candidates number less than or equal to 128");
|
||||
using RedType = TopKRedType<Type>;
|
||||
RedType topK[N];
|
||||
#pragma unroll
|
||||
for (int nn = 0; nn < N; ++nn) {
|
||||
topK[nn] = RedType{value[nn], idx[nn]};
|
||||
}
|
||||
|
||||
if constexpr (!IsSorted) {
|
||||
Sort<N, RedType>::run(topK);
|
||||
}
|
||||
typename RedType::TypeCmp packedMax{};
|
||||
#pragma unroll
|
||||
for (int kk = 0; kk < actualK; ++kk) {
|
||||
bool update = kk > 0 && packedMax == topK[0].compValIdx;
|
||||
#pragma unroll
|
||||
for (int nn = 0; nn < N; ++nn) {
|
||||
topK[nn] = update && nn == N - 1 ? RedType{minValue, idx[nn]}
|
||||
: update ? topK[nn + 1]
|
||||
: topK[nn];
|
||||
}
|
||||
// get the next largest value
|
||||
packedMax = topK[0].reduce(warp);
|
||||
RedType::unpack(out[kk], outIdx[kk], packedMax);
|
||||
}
|
||||
};
|
||||
|
||||
template <int K, typename Type, int N>
|
||||
__forceinline__ __device__ void reduceTopK(
|
||||
cg::thread_block_tile<kWARP_SIZE> const& warp, Type (&out)[K],
|
||||
int32_t (&outIdx)[K], Type (&value)[N], int32_t (&idx)[N],
|
||||
Type const minValue, int actualK = K) {
|
||||
static_assert(K > 0, "Top K must have K > 0");
|
||||
static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE");
|
||||
static_assert(N > 0, "Top K must have N > 0");
|
||||
static_assert(
|
||||
N <= 16,
|
||||
"Only support candidates number less than or equal to 16*32=512");
|
||||
static_assert(N <= 4 || N % 4 == 0,
|
||||
"Only support candidates number is a multiple of 4*32=128 or "
|
||||
"less than or equal to 4");
|
||||
using RedType = TopKRedType<Type>;
|
||||
|
||||
if constexpr (N <= 4) {
|
||||
reduceTopKFunc<K, Type, N>(warp, out, outIdx, value, idx, minValue,
|
||||
actualK);
|
||||
} else {
|
||||
constexpr int numLoops = N / 4;
|
||||
constexpr int numResults = (numLoops * K - 1) / kWARP_SIZE + 1;
|
||||
|
||||
Type topKBufferValue[numResults];
|
||||
int32_t topKBufferIdx[numResults];
|
||||
int32_t laneIdx = threadIdx.x % kWARP_SIZE;
|
||||
|
||||
for (int ii = 0; ii < numResults; ++ii) {
|
||||
topKBufferValue[ii] = minValue;
|
||||
topKBufferIdx[ii] = ii * kWARP_SIZE - 1;
|
||||
}
|
||||
for (int loop = 0; loop < numLoops; ++loop) {
|
||||
int start = loop * 4;
|
||||
Type topKValue[K];
|
||||
int32_t topKIdx[K];
|
||||
Type inValue[4];
|
||||
int32_t inIdx[4];
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
inValue[i] = value[start + i];
|
||||
inIdx[i] = idx[start + i];
|
||||
}
|
||||
reduceTopKFunc<K, Type, 4>(warp, topKValue, topKIdx, inValue, inIdx,
|
||||
minValue, actualK);
|
||||
int inOffset = laneIdx % K;
|
||||
if (laneIdx >= loop * K && laneIdx < (loop + 1) * K) {
|
||||
topKBufferValue[0] = topKValue[inOffset];
|
||||
topKBufferIdx[0] = topKIdx[inOffset];
|
||||
}
|
||||
if (loop == numLoops - 1 && (laneIdx < (numLoops * K - kWARP_SIZE))) {
|
||||
topKBufferValue[1] = topKValue[inOffset];
|
||||
topKBufferIdx[1] = topKIdx[inOffset];
|
||||
}
|
||||
}
|
||||
|
||||
reduceTopKFunc<K, Type, numResults>(warp, out, outIdx, topKBufferValue,
|
||||
topKBufferIdx, minValue, actualK);
|
||||
}
|
||||
};
|
||||
|
||||
#undef TOPK_SWAP
|
||||
|
||||
} // namespace reduce_topk
|
||||
} // namespace moe
|
||||
} // namespace vllm
|
||||
@@ -8,6 +8,7 @@ Run `pytest tests/kernels/moe/test_grouped_topk.py`.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.batch_invariant as batch_invariant
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
VllmConfig,
|
||||
@@ -27,11 +28,17 @@ from vllm.utils.torch_utils import set_random_seed
|
||||
)
|
||||
@pytest.mark.parametrize("n_token", [1, 33, 64])
|
||||
@pytest.mark.parametrize("n_hidden", [1024, 2048])
|
||||
@pytest.mark.parametrize("n_expert", [16])
|
||||
@pytest.mark.parametrize("topk", [2])
|
||||
@pytest.mark.parametrize(
|
||||
"n_expert,topk,num_expert_group,topk_group",
|
||||
[
|
||||
(16, 2, 8, 2),
|
||||
(128, 2, 8, 2),
|
||||
(256, 8, 8, 4),
|
||||
(384, 8, 1, 1),
|
||||
(512, 22, 1, 1),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("renormalize", [True, False])
|
||||
@pytest.mark.parametrize("num_expert_group", [8])
|
||||
@pytest.mark.parametrize("topk_group", [2])
|
||||
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
|
||||
@pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5])
|
||||
@pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float32])
|
||||
@@ -42,9 +49,9 @@ def test_grouped_topk(
|
||||
n_hidden: int,
|
||||
n_expert: int,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
num_expert_group: int,
|
||||
topk_group: int,
|
||||
renormalize: bool,
|
||||
scoring_func: str,
|
||||
routed_scaling_factor: float,
|
||||
input_dtype: torch.dtype,
|
||||
@@ -62,6 +69,7 @@ def test_grouped_topk(
|
||||
|
||||
with set_current_vllm_config(vllm_config), monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0")
|
||||
m.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", True)
|
||||
grouped_topk = GroupedTopk(
|
||||
topk=topk,
|
||||
renormalize=renormalize,
|
||||
@@ -89,8 +97,7 @@ def test_grouped_topk(
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
|
||||
if renormalize:
|
||||
torch.testing.assert_close(
|
||||
baseline_topk_weights, test_topk_weights, atol=2e-2, rtol=0
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
baseline_topk_weights, test_topk_weights, atol=2e-2, rtol=0
|
||||
)
|
||||
torch.testing.assert_close(baseline_topk_ids, test_topk_ids, atol=0, rtol=0)
|
||||
|
||||
Reference in New Issue
Block a user