diff --git a/csrc/moe/grouped_topk_kernels.cu b/csrc/moe/grouped_topk_kernels.cu index eaebf4e35..6a4dad3be 100644 --- a/csrc/moe/grouped_topk_kernels.cu +++ b/csrc/moe/grouped_topk_kernels.cu @@ -1,6 +1,6 @@ /* * Adapted from - * https://github.com/NVIDIA/TensorRT-LLM/blob/v0.21.0/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu + * https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc2/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu * Copyright (c) 2025, The vLLM team. * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 @@ -17,8 +17,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "moeTopKFuncs.cuh" #include #include +#include #include #include #include @@ -30,7 +32,17 @@ namespace vllm { namespace moe { constexpr unsigned FULL_WARP_MASK = 0xffffffff; -constexpr int32_t WARP_SIZE = 32; +static constexpr int WARP_SIZE = 32; +static constexpr int NumNemotronExperts = 512; +static constexpr int NumKimiK2Experts = 384; +static constexpr int NumDeepseekExperts = 256; +static constexpr int MaxSupportedExpertCount = + std::max({NumNemotronExperts, NumKimiK2Experts, NumDeepseekExperts}); +static constexpr int MaxNumExpertsUnit = 128; +static constexpr int NumTopGroupScores = 2; +static constexpr int DefaultMaxNumTopExperts = 8; +static constexpr int MaxSupportedTopExperts = 22; +static constexpr int MaxNumTopGroups = 4; namespace warp_topk { @@ -657,76 +669,335 @@ __global__ void grouped_topk_fused_kernel( #endif } -template +template +__global__ void grouped_topk_fused_small_expert_count_kernel( + T* scores, float* topkValues, IdxT* topkIndices, BiasT const* routingBias, + int64_t const numTokens, int64_t const numGroup, int64_t const topkGroup, + int64_t const topk, int64_t const numExperts, + int64_t const numExpertsPerGroup, bool const renormalize, + double const routedScalingFactor) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif + // declare shared memory structure + // number of experts is bounded by number of threads + __shared__ float __attribute((aligned(128))) smemScoreSigmoid[MaxNumExperts]; + __shared__ float __attribute((aligned(128))) smemScoreBias[MaxNumExperts]; + // number of expert groups is bounded by number of warps + int constexpr NumWarps = MaxNumExperts / WARP_SIZE; + __shared__ float __attribute((aligned(128))) smemGroupScores[NumWarps]; + + // needed for warp reduce + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + // for the final reduction of weight norm, only some lanes need to participate + int32_t laneIdx = threadIdx.x % WARP_SIZE; + int32_t warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WARP_SIZE, 0); + + if constexpr (UseGroups) { + if (warpIdx >= numGroup) { + return; + } + } + // note that for invalid scores, we simply use a negative value: + // they work well even with the compacted format used in topK, and + // sigmoid / bias activated scores cannot be negative + const float invalidScoreFloat = float{-INFINITY}; + + // load bias already; each warp represents one expert group + auto threadExpert = threadIdx.x; + bool expertSelected = threadExpert < numExperts; + if constexpr (UseGroups) { + threadExpert = warpIdx * numExpertsPerGroup + laneIdx; + expertSelected = laneIdx < numExpertsPerGroup; + } + + auto scoreIdx = int64_t{blockIdx.x} * int64_t{numExperts} + threadExpert; + auto biasVal = expertSelected ? 
static_cast(routingBias[threadExpert]) + : invalidScoreFloat; + topkValues += blockIdx.x * topk; + topkIndices += blockIdx.x * topk; + + // get our assigned thread score; each warp represents one expert group + float score = + expertSelected ? static_cast(scores[scoreIdx]) : invalidScoreFloat; + auto scoreSigmoid = apply_scoring(score); + // write the sigmoid score to shared for later use + if (expertSelected) { + smemScoreSigmoid[threadExpert] = scoreSigmoid; + } + + // get the score with bias + // note that with invalid values, because sigmoid is < 1 and bias is -1, + // we must get a negative value, which is smaller than any valid value + auto scoreBias = float{scoreSigmoid + float{biasVal}}; + + if (expertSelected) { + smemScoreBias[threadExpert] = scoreBias; + } + + // registers for top group score reduction + float topExpGroupScores[NumTopGroupScores]; + [[maybe_unused]] int32_t topExpGroupIdx[NumTopGroupScores]; + float topGroups[MaxNumTopGroups]; // bound of numGroup + int32_t topGroupIdx[MaxNumTopGroups]; + float expertScoreGroup[MaxNumTopGroups]; + int32_t expertIdxGroup[MaxNumTopGroups]; + float topScores[MaxNumTopExperts]; // bound of topk + int32_t topExperts[MaxNumTopExperts]; + + if constexpr (UseGroups) { + reduce_topk::reduceTopK(warp, topExpGroupScores, topExpGroupIdx, scoreBias, + threadExpert, + /* minValue */ invalidScoreFloat); + + // get the final group score and write it to shared + if (warp.thread_rank() == 0) { + auto groupScore = topExpGroupScores[0] + topExpGroupScores[1]; + smemGroupScores[warpIdx] = groupScore; + } + } + + // make group scores available to all warps + __syncthreads(); + + if constexpr (UseGroups) { + if (warpIdx == 0) { + // a single warp performs the selection of top groups, and goes on to + // select the final experts + float groupScore = + laneIdx < numGroup ? smemGroupScores[laneIdx] : invalidScoreFloat; + + reduce_topk::reduceTopK(warp, topGroups, topGroupIdx, groupScore, laneIdx, + /* minValue */ invalidScoreFloat); + // final expert selection: get relevant indexes and scores from shared +#pragma unroll + for (int ii = 0; ii < MaxNumTopGroups; ++ii) { // bound of numGroup + auto groupIdx = topGroupIdx[ii]; + expertIdxGroup[ii] = groupIdx * numExpertsPerGroup + laneIdx; + + expertScoreGroup[ii] = (ii < topkGroup) && expertSelected + ? smemScoreBias[expertIdxGroup[ii]] + : invalidScoreFloat; + } + + reduce_topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, + expertIdxGroup, /* minValue */ invalidScoreFloat, + topk); + } + } else if constexpr (MaxNumExperts > MaxNumExpertsUnit) { + // without groups, and the expert number is larger than MaxNumExpertsUnit, + // we need to use multiple warps to calculate the intermediate topk results + + int constexpr NumExpertWarps = (MaxNumExperts - 1) / MaxNumExpertsUnit + 1; + int constexpr NumInterTopK = NumExpertWarps * MaxNumTopExperts; + __shared__ float + __attribute((aligned(128))) smemInterTopScores[NumInterTopK]; + __shared__ int32_t + __attribute((aligned(128))) smemInterTopExperts[NumInterTopK]; + if (warpIdx < NumExpertWarps) { + int offset = warpIdx * WARP_SIZE * MaxNumTopGroups; +#pragma unroll + for (int ii = 0; ii < MaxNumTopGroups; ++ii) { + auto expertIdx = ii * WARP_SIZE + laneIdx; + expertIdxGroup[ii] = offset + expertIdx; + expertScoreGroup[ii] = offset + expertIdx < numExperts + ? 
smemScoreBias[offset + expertIdx] + : invalidScoreFloat; + } + reduce_topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, + expertIdxGroup, + /* minValue */ invalidScoreFloat, topk); + + if (laneIdx < topk) { + smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] = + topScores[laneIdx]; + smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] = + topExperts[laneIdx]; + } else if (laneIdx >= topk && laneIdx < MaxNumTopExperts) { + smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] = + invalidScoreFloat; + smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] = + MaxNumExperts - 1; + } + } + __syncthreads(); + if (warpIdx == 0) { + int constexpr NumInterTopKPerThread = (NumInterTopK - 1) / WARP_SIZE + 1; + float intermediateScore[NumInterTopKPerThread]; + int32_t intermediateExpert[NumInterTopKPerThread]; + for (int i = laneIdx; i < NumInterTopKPerThread * WARP_SIZE; + i += WARP_SIZE) { + int ii = i / WARP_SIZE; + if (i < NumInterTopK) { + intermediateScore[ii] = smemInterTopScores[i]; + intermediateExpert[ii] = smemInterTopExperts[i]; + } else { + intermediateScore[ii] = invalidScoreFloat; + intermediateExpert[ii] = MaxNumExperts - 1; + } + } + reduce_topk::reduceTopK(warp, topScores, topExperts, intermediateScore, + intermediateExpert, + /* minValue */ invalidScoreFloat, topk); + } + } else { + // without groups, and the expert number is smaller than MaxNumExpertsUnit + // each thread just takes `MaxNumTopGroups` experts + if (warpIdx == 0) { +#pragma unroll + for (int ii = 0; ii < MaxNumTopGroups; ++ii) { + auto expertIdx = ii * WARP_SIZE + laneIdx; + expertIdxGroup[ii] = expertIdx; + expertScoreGroup[ii] = expertIdx < numExperts ? smemScoreBias[expertIdx] + : invalidScoreFloat; + } + reduce_topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, + expertIdxGroup, + /* minValue */ invalidScoreFloat, topk); + } + } + + if (warpIdx == 0) { + // determine our lane's expert index and write to output + int32_t expertIdx = + laneIdx < topk ? topExperts[laneIdx] : MaxNumExperts - 1; + float scoreNorm = laneIdx < topk ? smemScoreSigmoid[expertIdx] : 0.F; + float finalScore = static_cast(scoreNorm * routedScalingFactor); + // norm the value + if (renormalize) { + auto redNorm = cg::reduce(warp, scoreNorm, cg::plus{}); + finalScore /= (redNorm + 1e-20); + } + // store the topk scores and experts to output + if (laneIdx < topk) { + topkValues[laneIdx] = finalScore; + topkIndices[laneIdx] = expertIdx; + } + } + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif +} + +template void invokeNoAuxTc(T* scores, float* topk_values, IdxT* topk_indices, BiasT const* bias, int64_t const num_tokens, int64_t const num_experts, int64_t const n_group, int64_t const topk_group, int64_t const topk, bool const renormalize, double const routed_scaling_factor, - int const scoring_func, bool enable_pdl = false, - cudaStream_t const stream = 0) { + bool enable_pdl = false, cudaStream_t const stream = 0) { cudaLaunchConfig_t config; - // One block per token; one warp per group. - config.gridDim = static_cast(num_tokens); - config.blockDim = static_cast(n_group) * WARP_SIZE; - // Dynamic shared memory: WarpSelect staging + per-group topk buffers. 
- int32_t const num_warps = static_cast(n_group); - size_t const val_bytes = - static_cast(num_warps) * WARP_SIZE * sizeof(T); - size_t const val_bytes_aligned = - warp_topk::round_up_to_multiple_of<256>(val_bytes); - size_t const idx_bytes = - static_cast(num_warps) * WARP_SIZE * sizeof(int32_t); - size_t const internal_bytes = val_bytes_aligned + idx_bytes; - size_t const extra_bytes = 16 + static_cast(n_group) * sizeof(T); - config.dynamicSmemBytes = internal_bytes + extra_bytes; config.stream = stream; cudaLaunchAttribute attrs[1]; attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; config.numAttrs = 1; config.attrs = attrs; - auto const sf = static_cast(scoring_func); - switch (sf) { - case SCORING_NONE: { - auto* kernel_instance = - &grouped_topk_fused_kernel; - cudaLaunchKernelEx(&config, kernel_instance, scores, topk_values, - topk_indices, bias, num_tokens, num_experts, n_group, - topk_group, topk, renormalize, routed_scaling_factor); - return; + + // Check if we can use the optimized + // grouped_topk_fused_small_expert_count_kernel + bool const is_single_group = + (n_group == 1) && (topk_group == 1) && + (num_experts <= MaxSupportedExpertCount) && + (topk <= DefaultMaxNumTopExperts || topk == MaxSupportedTopExperts); + + int64_t const experts_per_group = num_experts / n_group; + bool const is_multi_group = + (n_group > 1) && (num_experts <= NumDeepseekExperts) && + (experts_per_group <= WARP_SIZE) && + (experts_per_group * topk_group <= MaxNumExpertsUnit) && + (topk <= DefaultMaxNumTopExperts) && (topk_group <= MaxNumTopGroups); + + if (is_single_group || is_multi_group) { + auto* kernel_instance = + &grouped_topk_fused_small_expert_count_kernel; + int num_threads = NumDeepseekExperts; + if (is_single_group) { + // Special case for Nemotron, which selects top 22 from 512 experts, and 1 + // group only. + if (num_experts == NumNemotronExperts && n_group == 1 && + topk == MaxSupportedTopExperts) { + kernel_instance = &grouped_topk_fused_small_expert_count_kernel< + T, BiasT, IdxT, SF, NumNemotronExperts, false, + MaxSupportedTopExperts>; + num_threads = NumNemotronExperts; + } else if (num_experts > NumKimiK2Experts && + num_experts <= MaxSupportedExpertCount) { + kernel_instance = &grouped_topk_fused_small_expert_count_kernel< + T, BiasT, IdxT, SF, MaxSupportedExpertCount, false>; + num_threads = MaxSupportedExpertCount; + } else if (num_experts > MaxNumExpertsUnit && + num_experts <= NumKimiK2Experts) { + kernel_instance = &grouped_topk_fused_small_expert_count_kernel< + T, BiasT, IdxT, SF, NumKimiK2Experts, false>; + num_threads = NumKimiK2Experts; + } else { + kernel_instance = &grouped_topk_fused_small_expert_count_kernel< + T, BiasT, IdxT, SF, MaxNumExpertsUnit, false>; + num_threads = MaxNumExpertsUnit; + } } - case SCORING_SIGMOID: { - auto* kernel_instance = - &grouped_topk_fused_kernel; - cudaLaunchKernelEx(&config, kernel_instance, scores, topk_values, - topk_indices, bias, num_tokens, num_experts, n_group, - topk_group, topk, renormalize, routed_scaling_factor); - return; - } - default: - // should be guarded by higher level checks. 
- TORCH_CHECK(false, "Unsupported scoring_func in invokeNoAuxTc"); + config.gridDim = num_tokens; + config.blockDim = num_threads; + config.dynamicSmemBytes = 0; + cudaLaunchKernelEx(&config, kernel_instance, scores, topk_values, + topk_indices, bias, num_tokens, n_group, topk_group, + topk, num_experts, num_experts / n_group, renormalize, + routed_scaling_factor); + } else { + auto* kernel_instance = &grouped_topk_fused_kernel; + // One block per token; one warp per group. + config.gridDim = static_cast(num_tokens); + config.blockDim = static_cast(n_group) * WARP_SIZE; + // Dynamic shared memory: WarpSelect staging + per-group topk buffers. + int32_t const num_warps = static_cast(n_group); + size_t const val_bytes = + static_cast(num_warps) * WARP_SIZE * sizeof(T); + size_t const val_bytes_aligned = + warp_topk::round_up_to_multiple_of<256>(val_bytes); + size_t const idx_bytes = + static_cast(num_warps) * WARP_SIZE * sizeof(int32_t); + size_t const internal_bytes = val_bytes_aligned + idx_bytes; + size_t const extra_bytes = 16 + static_cast(n_group) * sizeof(T); + config.dynamicSmemBytes = internal_bytes + extra_bytes; + cudaLaunchKernelEx(&config, kernel_instance, scores, topk_values, + topk_indices, bias, num_tokens, num_experts, n_group, + topk_group, topk, renormalize, routed_scaling_factor); } } -#define INSTANTIATE_NOAUX_TC(T, BiasT, IdxT) \ - template void invokeNoAuxTc( \ +#define INSTANTIATE_NOAUX_TC(T, BiasT, IdxT, SF) \ + template void invokeNoAuxTc( \ T * scores, float* topk_values, IdxT* topk_indices, BiasT const* bias, \ int64_t const num_tokens, int64_t const num_experts, \ int64_t const n_group, int64_t const topk_group, int64_t const topk, \ bool const renormalize, double const routed_scaling_factor, \ - int const scoring_func, bool enable_pdl, cudaStream_t const stream); + bool enable_pdl, cudaStream_t const stream); -INSTANTIATE_NOAUX_TC(float, float, int32_t); -INSTANTIATE_NOAUX_TC(float, half, int32_t); -INSTANTIATE_NOAUX_TC(float, __nv_bfloat16, int32_t); -INSTANTIATE_NOAUX_TC(half, float, int32_t); -INSTANTIATE_NOAUX_TC(half, half, int32_t); -INSTANTIATE_NOAUX_TC(half, __nv_bfloat16, int32_t); -INSTANTIATE_NOAUX_TC(__nv_bfloat16, float, int32_t); -INSTANTIATE_NOAUX_TC(__nv_bfloat16, half, int32_t); -INSTANTIATE_NOAUX_TC(__nv_bfloat16, __nv_bfloat16, int32_t); +INSTANTIATE_NOAUX_TC(float, float, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC(float, half, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC(float, __nv_bfloat16, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC(half, float, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC(half, half, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC(half, __nv_bfloat16, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC(__nv_bfloat16, float, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC(__nv_bfloat16, half, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC(__nv_bfloat16, __nv_bfloat16, int32_t, SCORING_SIGMOID); +INSTANTIATE_NOAUX_TC(float, float, int32_t, SCORING_NONE); +INSTANTIATE_NOAUX_TC(float, half, int32_t, SCORING_NONE); +INSTANTIATE_NOAUX_TC(float, __nv_bfloat16, int32_t, SCORING_NONE); +INSTANTIATE_NOAUX_TC(half, float, int32_t, SCORING_NONE); +INSTANTIATE_NOAUX_TC(half, half, int32_t, SCORING_NONE); +INSTANTIATE_NOAUX_TC(half, __nv_bfloat16, int32_t, SCORING_NONE); +INSTANTIATE_NOAUX_TC(__nv_bfloat16, float, int32_t, SCORING_NONE); +INSTANTIATE_NOAUX_TC(__nv_bfloat16, half, int32_t, SCORING_NONE); +INSTANTIATE_NOAUX_TC(__nv_bfloat16, __nv_bfloat16, int32_t, SCORING_NONE); } // end namespace moe } // namespace 
vllm @@ -762,46 +1033,53 @@ std::tuple grouped_topk( {num_tokens, topk}, torch::dtype(torch::kInt32).device(torch::kCUDA)); auto stream = c10::cuda::getCurrentCUDAStream(scores.get_device()); + auto const sf = static_cast(scoring_func); -#define LAUNCH_KERNEL(T, IdxT) \ - do { \ - switch (bias_type) { \ - case torch::kFloat16: \ - vllm::moe::invokeNoAuxTc( \ - reinterpret_cast(scores.mutable_data_ptr()), \ - reinterpret_cast(topk_values.mutable_data_ptr()), \ - reinterpret_cast(topk_indices.mutable_data_ptr()), \ - reinterpret_cast(bias.data_ptr()), num_tokens, \ - num_experts, n_group, topk_group, topk, renormalize, \ - routed_scaling_factor, static_cast(scoring_func), false, \ - stream); \ - break; \ - case torch::kFloat32: \ - vllm::moe::invokeNoAuxTc( \ - reinterpret_cast(scores.mutable_data_ptr()), \ - reinterpret_cast(topk_values.mutable_data_ptr()), \ - reinterpret_cast(topk_indices.mutable_data_ptr()), \ - reinterpret_cast(bias.data_ptr()), num_tokens, \ - num_experts, n_group, topk_group, topk, renormalize, \ - routed_scaling_factor, static_cast(scoring_func), false, \ - stream); \ - break; \ - case torch::kBFloat16: \ - vllm::moe::invokeNoAuxTc( \ - reinterpret_cast(scores.mutable_data_ptr()), \ - reinterpret_cast(topk_values.mutable_data_ptr()), \ - reinterpret_cast(topk_indices.mutable_data_ptr()), \ - reinterpret_cast<__nv_bfloat16 const*>(bias.data_ptr()), \ - num_tokens, num_experts, n_group, topk_group, topk, renormalize, \ - routed_scaling_factor, static_cast(scoring_func), false, \ - stream); \ - break; \ - default: \ - throw std::invalid_argument( \ - "Invalid bias dtype, only supports float16, float32, and " \ - "bfloat16"); \ - break; \ - } \ +#define LAUNCH_KERNEL_SF(T, BiasT, IdxT) \ + do { \ + switch (sf) { \ + case vllm::moe::SCORING_NONE: \ + vllm::moe::invokeNoAuxTc( \ + reinterpret_cast(scores.mutable_data_ptr()), \ + reinterpret_cast(topk_values.mutable_data_ptr()), \ + reinterpret_cast(topk_indices.mutable_data_ptr()), \ + reinterpret_cast(bias.data_ptr()), num_tokens, \ + num_experts, n_group, topk_group, topk, renormalize, \ + routed_scaling_factor, false, stream); \ + break; \ + case vllm::moe::SCORING_SIGMOID: \ + vllm::moe::invokeNoAuxTc( \ + reinterpret_cast(scores.mutable_data_ptr()), \ + reinterpret_cast(topk_values.mutable_data_ptr()), \ + reinterpret_cast(topk_indices.mutable_data_ptr()), \ + reinterpret_cast(bias.data_ptr()), num_tokens, \ + num_experts, n_group, topk_group, topk, renormalize, \ + routed_scaling_factor, false, stream); \ + break; \ + default: \ + throw std::invalid_argument("Unsupported scoring_func"); \ + break; \ + } \ + } while (0) + +#define LAUNCH_KERNEL(T, IdxT) \ + do { \ + switch (bias_type) { \ + case torch::kFloat16: \ + LAUNCH_KERNEL_SF(T, half, IdxT); \ + break; \ + case torch::kFloat32: \ + LAUNCH_KERNEL_SF(T, float, IdxT); \ + break; \ + case torch::kBFloat16: \ + LAUNCH_KERNEL_SF(T, __nv_bfloat16, IdxT); \ + break; \ + default: \ + throw std::invalid_argument( \ + "Invalid bias dtype, only supports float16, float32, and " \ + "bfloat16"); \ + break; \ + } \ } while (0) switch (data_type) { @@ -824,5 +1102,6 @@ std::tuple grouped_topk( break; } #undef LAUNCH_KERNEL +#undef LAUNCH_KERNEL_SF return {topk_values, topk_indices}; } diff --git a/csrc/moe/moeTopKFuncs.cuh b/csrc/moe/moeTopKFuncs.cuh new file mode 100644 index 000000000..70e21cf87 --- /dev/null +++ b/csrc/moe/moeTopKFuncs.cuh @@ -0,0 +1,257 @@ +/* + * Adapted from + * 
https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc2/cpp/tensorrt_llm/kernels/moeTopKFuncs.cuh + * Copyright (c) 2026, The vLLM team. + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION. All rights + * reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include + +namespace vllm { +namespace moe { +namespace reduce_topk { +namespace cg = cooperative_groups; +static constexpr int kWARP_SIZE = 32; + +template +struct TopKRedType { + using T = T_; + static_assert( + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v, + "Top K reduction only implemented for int, float, float16 and bfloat16"); + + using TypeCmp = std::conditional_t; + using IdxT = std::conditional_t; + + static constexpr int kMoveBits = (sizeof(T) == 4) ? 32 : 16; + static constexpr int kMaxIdx = 65535; + TypeCmp compValIdx; + + static __host__ __device__ inline TypeCmp makeCmpVal(T val, int32_t idx = 0) { + auto valueBits = cub::Traits::TwiddleIn( + reinterpret_cast::UnsignedBits&>(val)); + TypeCmp compactTmp = valueBits; + compactTmp = (compactTmp << kMoveBits) | (0xFFFF & (kMaxIdx - idx)); + // Use 65535 minus idx to give higher priority to elements with smaller + // indices. 
+ return compactTmp; + } + + static __host__ __device__ void unpack(T& value, int32_t& index, + TypeCmp cmp) { + // Since “65535-idx” is always smaller than 65536 and positive, we can + // directly use it as the lower 16 bits + index = kMaxIdx - static_cast((cmp & 0xFFFF)); + + auto compactTmp = cmp >> kMoveBits; + auto valueBits = cub::Traits::TwiddleOut( + reinterpret_cast::UnsignedBits&>(compactTmp)); + value = reinterpret_cast(valueBits); + } + + __host__ __device__ TopKRedType() = default; + + __host__ __device__ TopKRedType(T val, int32_t idx) + : compValIdx(makeCmpVal(val, idx)) {} + + __host__ __device__ operator TypeCmp() const noexcept { return compValIdx; } + + __device__ inline TypeCmp reduce( + cg::thread_block_tile const& warp) { + return cg::reduce(warp, compValIdx, cg::greater{}); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TopKIdx { + // by default, empty +}; + +template +struct TopKIdx { + static constexpr int K = K_; + int32_t val[K]; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define TOPK_SWAP(I, J) \ + { \ + auto pairMin = min(topK[I].compValIdx, topK[J].compValIdx); \ + auto pairMax = max(topK[I].compValIdx, topK[J].compValIdx); \ + topK[I].compValIdx = pairMax; \ + topK[J].compValIdx = pairMin; \ + } + +template +struct Sort; + +template +struct Sort<1, RedType> { + static __device__ void run(RedType* topK) {} +}; + +template +struct Sort<2, RedType> { + static __device__ void run(RedType* topK) { TOPK_SWAP(0, 1); } +}; + +template +struct Sort<3, RedType> { + static __device__ void run(RedType* topK) { + TOPK_SWAP(0, 1); + TOPK_SWAP(1, 2); + TOPK_SWAP(0, 1); + } +}; + +template +struct Sort<4, RedType> { + static __device__ void run(RedType* topK) { + TOPK_SWAP(0, 2); + TOPK_SWAP(1, 3); + TOPK_SWAP(0, 1); + TOPK_SWAP(2, 3); + TOPK_SWAP(1, 2); + } +}; + +template +__forceinline__ __device__ void reduceTopK( + cg::thread_block_tile const& warp, Type (&out)[K], + int32_t (&outIdx)[K], Type value, int32_t idx, Type const minValue, + int actualK = K) { + static_assert(K > 0, "Top K must have K > 0"); + static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE"); + using RedType = TopKRedType; + RedType topK{value, idx}; + typename RedType::TypeCmp packedMax{}; +#pragma unroll + for (int kk = 0; kk < actualK; ++kk) { + topK = + kk > 0 && packedMax == topK.compValIdx ? RedType{minValue, idx} : topK; + // get the next largest value + packedMax = topK.reduce(warp); + RedType::unpack(out[kk], outIdx[kk], packedMax); + } +}; + +template +__device__ void reduceTopKFunc(cg::thread_block_tile const& warp, + Type (&out)[K], int32_t (&outIdx)[K], + Type (&value)[N], int32_t (&idx)[N], + Type minValue, int actualK = K) { + static_assert(K > 0, "Top K must have K > 0"); + static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE"); + static_assert(N > 0, "Top K must have N > 0"); + static_assert(N < 5, + "Only support candidates number less than or equal to 128"); + using RedType = TopKRedType; + RedType topK[N]; +#pragma unroll + for (int nn = 0; nn < N; ++nn) { + topK[nn] = RedType{value[nn], idx[nn]}; + } + + if constexpr (!IsSorted) { + Sort::run(topK); + } + typename RedType::TypeCmp packedMax{}; +#pragma unroll + for (int kk = 0; kk < actualK; ++kk) { + bool update = kk > 0 && packedMax == topK[0].compValIdx; +#pragma unroll + for (int nn = 0; nn < N; ++nn) { + topK[nn] = update && nn == N - 1 ? 
RedType{minValue, idx[nn]} + : update ? topK[nn + 1] + : topK[nn]; + } + // get the next largest value + packedMax = topK[0].reduce(warp); + RedType::unpack(out[kk], outIdx[kk], packedMax); + } +}; + +template +__forceinline__ __device__ void reduceTopK( + cg::thread_block_tile const& warp, Type (&out)[K], + int32_t (&outIdx)[K], Type (&value)[N], int32_t (&idx)[N], + Type const minValue, int actualK = K) { + static_assert(K > 0, "Top K must have K > 0"); + static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE"); + static_assert(N > 0, "Top K must have N > 0"); + static_assert( + N <= 16, + "Only support candidates number less than or equal to 16*32=512"); + static_assert(N <= 4 || N % 4 == 0, + "Only support candidates number is a multiple of 4*32=128 or " + "less than or equal to 4"); + using RedType = TopKRedType; + + if constexpr (N <= 4) { + reduceTopKFunc(warp, out, outIdx, value, idx, minValue, + actualK); + } else { + constexpr int numLoops = N / 4; + constexpr int numResults = (numLoops * K - 1) / kWARP_SIZE + 1; + + Type topKBufferValue[numResults]; + int32_t topKBufferIdx[numResults]; + int32_t laneIdx = threadIdx.x % kWARP_SIZE; + + for (int ii = 0; ii < numResults; ++ii) { + topKBufferValue[ii] = minValue; + topKBufferIdx[ii] = ii * kWARP_SIZE - 1; + } + for (int loop = 0; loop < numLoops; ++loop) { + int start = loop * 4; + Type topKValue[K]; + int32_t topKIdx[K]; + Type inValue[4]; + int32_t inIdx[4]; + for (int i = 0; i < 4; ++i) { + inValue[i] = value[start + i]; + inIdx[i] = idx[start + i]; + } + reduceTopKFunc(warp, topKValue, topKIdx, inValue, inIdx, + minValue, actualK); + int inOffset = laneIdx % K; + if (laneIdx >= loop * K && laneIdx < (loop + 1) * K) { + topKBufferValue[0] = topKValue[inOffset]; + topKBufferIdx[0] = topKIdx[inOffset]; + } + if (loop == numLoops - 1 && (laneIdx < (numLoops * K - kWARP_SIZE))) { + topKBufferValue[1] = topKValue[inOffset]; + topKBufferIdx[1] = topKIdx[inOffset]; + } + } + + reduceTopKFunc(warp, out, outIdx, topKBufferValue, + topKBufferIdx, minValue, actualK); + } +}; + +#undef TOPK_SWAP + +} // namespace reduce_topk +} // namespace moe +} // namespace vllm diff --git a/tests/kernels/moe/test_grouped_topk.py b/tests/kernels/moe/test_grouped_topk.py index 2a974206d..70c7285ac 100644 --- a/tests/kernels/moe/test_grouped_topk.py +++ b/tests/kernels/moe/test_grouped_topk.py @@ -8,6 +8,7 @@ Run `pytest tests/kernels/moe/test_grouped_topk.py`. 
 import pytest
 import torch
+import vllm.model_executor.layers.batch_invariant as batch_invariant
 from vllm.config import (
     CompilationConfig,
     VllmConfig,
@@ -27,11 +28,17 @@ from vllm.utils.torch_utils import set_random_seed
 )
 @pytest.mark.parametrize("n_token", [1, 33, 64])
 @pytest.mark.parametrize("n_hidden", [1024, 2048])
-@pytest.mark.parametrize("n_expert", [16])
-@pytest.mark.parametrize("topk", [2])
+@pytest.mark.parametrize(
+    "n_expert,topk,num_expert_group,topk_group",
+    [
+        (16, 2, 8, 2),
+        (128, 2, 8, 2),
+        (256, 8, 8, 4),
+        (384, 8, 1, 1),
+        (512, 22, 1, 1),
+    ],
+)
 @pytest.mark.parametrize("renormalize", [True, False])
-@pytest.mark.parametrize("num_expert_group", [8])
-@pytest.mark.parametrize("topk_group", [2])
 @pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
 @pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5])
 @pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float32])
@@ -42,9 +49,9 @@ def test_grouped_topk(
     n_hidden: int,
     n_expert: int,
     topk: int,
-    renormalize: bool,
     num_expert_group: int,
     topk_group: int,
+    renormalize: bool,
     scoring_func: str,
     routed_scaling_factor: float,
     input_dtype: torch.dtype,
@@ -62,6 +69,7 @@ def test_grouped_topk(
 
     with set_current_vllm_config(vllm_config), monkeypatch.context() as m:
         m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0")
+        m.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", True)
         grouped_topk = GroupedTopk(
             topk=topk,
             renormalize=renormalize,
@@ -89,8 +97,7 @@ def test_grouped_topk(
         e_score_correction_bias=e_score_correction_bias,
     )
 
-    if renormalize:
-        torch.testing.assert_close(
-            baseline_topk_weights, test_topk_weights, atol=2e-2, rtol=0
-        )
+    torch.testing.assert_close(
+        baseline_topk_weights, test_topk_weights, atol=2e-2, rtol=0
+    )
     torch.testing.assert_close(baseline_topk_ids, test_topk_ids, atol=0, rtol=0)
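
Note for reviewers new to the TensorRT-LLM-style top-k primitives imported here: the core trick in TopKRedType::makeCmpVal/unpack is packing a score and its index into one 64-bit key whose unsigned ordering sorts first by score and, on ties, prefers the smaller index (via 65535 - idx). Below is a minimal host-side sketch of that encoding; the helper names (ordered_bits, make_key, unpack_key) are illustrative only, and the manual float twiddle stands in for cub::Traits<float>::TwiddleIn/TwiddleOut.

// Host-side sketch of the packed (score, index) key used by TopKRedType
// (kMoveBits = 32, kMaxIdx = 65535 for 4-byte types). Helper names are
// illustrative, not part of the patch.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Order-preserving float -> uint32 mapping (same idea as cub's TwiddleIn):
// flip all bits for negatives, set the sign bit for non-negatives.
static uint32_t ordered_bits(float v) {
  uint32_t u;
  std::memcpy(&u, &v, sizeof(u));
  return (u & 0x80000000u) ? ~u : (u | 0x80000000u);
}

static float from_ordered_bits(uint32_t u) {
  u = (u & 0x80000000u) ? (u & 0x7FFFFFFFu) : ~u;
  float v;
  std::memcpy(&v, &u, sizeof(v));
  return v;
}

// Score in the high 32 bits, (65535 - idx) in the low 16 bits: an unsigned
// max() then prefers larger scores and, on ties, smaller indices.
static uint64_t make_key(float score, int idx) {
  return (static_cast<uint64_t>(ordered_bits(score)) << 32) |
         static_cast<uint64_t>(0xFFFFu & (65535 - idx));
}

static void unpack_key(uint64_t key, float& score, int& idx) {
  idx = 65535 - static_cast<int>(key & 0xFFFFu);
  score = from_ordered_bits(static_cast<uint32_t>(key >> 32));
}

int main() {
  uint64_t a = make_key(0.75f, 7);
  uint64_t b = make_key(0.75f, 3);  // tied score, smaller index -> should win
  uint64_t c = make_key(-1.0f, 0);  // "invalid" low score loses to any valid one
  float score;
  int idx;
  unpack_key(std::max({a, b, c}), score, idx);
  std::printf("winner: score=%f index=%d\n", score, idx);  // expect 0.750000, 3
  return 0;
}

With this encoding, a single warp-wide max over 64-bit keys yields both the best score and its index at once, which is what makes the register-only warp reduction in reduceTopK possible.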
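
The warp-level selection in reduce_topk::reduceTopK then repeats a warp-wide max over these packed keys topk times, with the lane that produced the previous winner demoting its candidate to the -INFINITY sentinel before the next round. The standalone CUDA sketch below mirrors that loop using the same cg::reduce / cg::greater primitives the kernel uses; the names pack_key and warp_topk_demo are assumptions for illustration, not kernel APIs, and a single warp is launched.

// Minimal CUDA sketch of the iterative per-warp top-k loop (one candidate per lane).
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cmath>
#include <cstdint>
#include <cstdio>

namespace cg = cooperative_groups;

constexpr int kK = 4;  // number of winners to extract (stands in for `topk`)

__device__ uint64_t pack_key(float v, int idx) {
  // Order-preserving float twiddle (what cub::Traits<float>::TwiddleIn does),
  // then value in the high 32 bits and (65535 - idx) in the low 16 bits,
  // matching TopKRedType::makeCmpVal for 4-byte types.
  uint32_t u = __float_as_uint(v);
  u = (u & 0x80000000u) ? ~u : (u | 0x80000000u);
  return (static_cast<uint64_t>(u) << 32) |
         static_cast<uint64_t>(0xFFFFu & (65535 - idx));
}

__global__ void warp_topk_demo(const float* scores, float* out_val, int* out_idx) {
  auto warp = cg::tiled_partition<32>(cg::this_thread_block());
  int lane = warp.thread_rank();

  uint64_t key = pack_key(scores[lane], lane);
  uint64_t winner = 0;

  for (int kk = 0; kk < kK; ++kk) {
    // The lane whose candidate won the previous round demotes itself to the
    // sentinel so the next reduction surfaces the next-largest candidate.
    if (kk > 0 && key == winner) key = pack_key(-INFINITY, lane);
    winner = cg::reduce(warp, key, cg::greater<uint64_t>());
    if (lane == 0) {
      out_idx[kk] = 65535 - static_cast<int>(winner & 0xFFFFu);
      uint32_t u = static_cast<uint32_t>(winner >> 32);
      u = (u & 0x80000000u) ? (u & 0x7FFFFFFFu) : ~u;  // undo the twiddle
      out_val[kk] = __uint_as_float(u);
    }
  }
}

int main() {
  float h_scores[32];
  for (int i = 0; i < 32; ++i) h_scores[i] = static_cast<float>((i * 37) % 32) / 32.0f;
  float *d_scores, *d_val;
  int* d_idx;
  cudaMalloc(&d_scores, sizeof(h_scores));
  cudaMalloc(&d_val, kK * sizeof(float));
  cudaMalloc(&d_idx, kK * sizeof(int));
  cudaMemcpy(d_scores, h_scores, sizeof(h_scores), cudaMemcpyHostToDevice);
  warp_topk_demo<<<1, 32>>>(d_scores, d_val, d_idx);
  float h_val[kK];
  int h_idx[kK];
  cudaMemcpy(h_val, d_val, sizeof(h_val), cudaMemcpyDeviceToHost);
  cudaMemcpy(h_idx, d_idx, sizeof(h_idx), cudaMemcpyDeviceToHost);
  for (int k = 0; k < kK; ++k)
    std::printf("top-%d: index %d, score %.3f\n", k + 1, h_idx[k], h_val[k]);
  cudaFree(d_scores);
  cudaFree(d_val);
  cudaFree(d_idx);
  return 0;
}

The production kernel does the same thing per warp, except that each lane may hold several candidates (MaxNumTopGroups of them) and the multi-candidate variant pre-sorts them so the demotion step can shift the next candidate into slot 0.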
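
Finally, the dispatch conditions in invokeNoAuxTc are easier to review when pulled out of the launch code. The host-side restatement below is a sketch only: the Dispatch struct and pick_kernel helper are illustrative names, and the numeric thresholds are the constants defined at the top of grouped_topk_kernels.cu (MaxNumExpertsUnit=128, NumDeepseekExperts=256, NumKimiK2Experts=384, NumNemotronExperts=512, DefaultMaxNumTopExperts=8, MaxSupportedTopExperts=22, MaxNumTopGroups=4, WARP_SIZE=32).

// Host-side sketch of the kernel-variant selection in invokeNoAuxTc.
#include <cstdint>

struct Dispatch {
  bool fused_small_expert_kernel;  // use grouped_topk_fused_small_expert_count_kernel?
  bool use_groups;                 // UseGroups template parameter
  int block_threads;               // blockDim.x for the launch
  int max_top_experts;             // MaxNumTopExperts template parameter
};

inline Dispatch pick_kernel(int64_t num_experts, int64_t n_group,
                            int64_t topk_group, int64_t topk) {
  bool const single_group =
      n_group == 1 && topk_group == 1 && num_experts <= 512 &&
      (topk <= 8 || topk == 22);

  int64_t const experts_per_group = num_experts / n_group;
  bool const multi_group =
      n_group > 1 && num_experts <= 256 && experts_per_group <= 32 &&
      experts_per_group * topk_group <= 128 && topk <= 8 && topk_group <= 4;

  if (!single_group && !multi_group) {
    return {false, false, 0, 0};  // fall back to grouped_topk_fused_kernel
  }
  if (multi_group) {
    // One warp per expert group (e.g. DeepSeek-V3: 256 experts, 8 groups, top-8).
    return {true, true, 256, 8};
  }
  // Single group: pick the smallest expert-count bucket that fits. Nemotron's
  // 512-expert / top-22 case is the only one instantiated with MaxNumTopExperts = 22.
  int const threads = num_experts <= 128 ? 128 : num_experts <= 384 ? 384 : 512;
  int const max_top = (num_experts == 512 && topk == 22) ? 22 : 8;
  return {true, false, threads, max_top};
}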