[Kernel] Optimize grouped topk kernel (#34206)
Signed-off-by: Xin Yang <xyangx@amazon.com>
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
/*
|
||||
* Adapted from
|
||||
* https://github.com/NVIDIA/TensorRT-LLM/blob/v0.21.0/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu
|
||||
* https://github.com/NVIDIA/TensorRT-LLM/blob/v1.3.0rc2/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu
|
||||
* Copyright (c) 2025, The vLLM team.
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
|
||||
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
|
||||
@@ -17,8 +17,10 @@
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "moeTopKFuncs.cuh"
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/all.h>
|
||||
#include <cmath>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda/std/limits>
|
||||
@@ -30,7 +32,17 @@ namespace vllm {
|
||||
namespace moe {
|
||||
|
||||
constexpr unsigned FULL_WARP_MASK = 0xffffffff;
|
||||
constexpr int32_t WARP_SIZE = 32;
|
||||
static constexpr int WARP_SIZE = 32;
|
||||
static constexpr int NumNemotronExperts = 512;
|
||||
static constexpr int NumKimiK2Experts = 384;
|
||||
static constexpr int NumDeepseekExperts = 256;
|
||||
static constexpr int MaxSupportedExpertCount =
|
||||
std::max({NumNemotronExperts, NumKimiK2Experts, NumDeepseekExperts});
|
||||
static constexpr int MaxNumExpertsUnit = 128;
|
||||
static constexpr int NumTopGroupScores = 2;
|
||||
static constexpr int DefaultMaxNumTopExperts = 8;
|
||||
static constexpr int MaxSupportedTopExperts = 22;
|
||||
static constexpr int MaxNumTopGroups = 4;
|
||||
|
||||
namespace warp_topk {
|
||||
|
||||
@@ -657,76 +669,335 @@ __global__ void grouped_topk_fused_kernel(
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, typename BiasT, typename IdxT>
|
||||
template <typename T, typename BiasT, typename IdxT, ScoringFunc SF,
|
||||
int MaxNumExperts, bool UseGroups,
|
||||
int MaxNumTopExperts = DefaultMaxNumTopExperts>
|
||||
__global__ void grouped_topk_fused_small_expert_count_kernel(
|
||||
T* scores, float* topkValues, IdxT* topkIndices, BiasT const* routingBias,
|
||||
int64_t const numTokens, int64_t const numGroup, int64_t const topkGroup,
|
||||
int64_t const topk, int64_t const numExperts,
|
||||
int64_t const numExpertsPerGroup, bool const renormalize,
|
||||
double const routedScalingFactor) {
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
cudaGridDependencySynchronize();
|
||||
#endif
|
||||
// declare shared memory structure
|
||||
// number of experts is bounded by number of threads
|
||||
__shared__ float __attribute((aligned(128))) smemScoreSigmoid[MaxNumExperts];
|
||||
__shared__ float __attribute((aligned(128))) smemScoreBias[MaxNumExperts];
|
||||
// number of expert groups is bounded by number of warps
|
||||
int constexpr NumWarps = MaxNumExperts / WARP_SIZE;
|
||||
__shared__ float __attribute((aligned(128))) smemGroupScores[NumWarps];
|
||||
|
||||
// needed for warp reduce
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
|
||||
// for the final reduction of weight norm, only some lanes need to participate
|
||||
int32_t laneIdx = threadIdx.x % WARP_SIZE;
|
||||
int32_t warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WARP_SIZE, 0);
|
||||
|
||||
if constexpr (UseGroups) {
|
||||
if (warpIdx >= numGroup) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
// note that for invalid scores, we simply use a negative value:
|
||||
// they work well even with the compacted format used in topK, and
|
||||
// sigmoid / bias activated scores cannot be negative
|
||||
const float invalidScoreFloat = float{-INFINITY};
|
||||
|
||||
// load bias already; each warp represents one expert group
|
||||
auto threadExpert = threadIdx.x;
|
||||
bool expertSelected = threadExpert < numExperts;
|
||||
if constexpr (UseGroups) {
|
||||
threadExpert = warpIdx * numExpertsPerGroup + laneIdx;
|
||||
expertSelected = laneIdx < numExpertsPerGroup;
|
||||
}
|
||||
|
||||
auto scoreIdx = int64_t{blockIdx.x} * int64_t{numExperts} + threadExpert;
|
||||
auto biasVal = expertSelected ? static_cast<float>(routingBias[threadExpert])
|
||||
: invalidScoreFloat;
|
||||
topkValues += blockIdx.x * topk;
|
||||
topkIndices += blockIdx.x * topk;
|
||||
|
||||
// get our assigned thread score; each warp represents one expert group
|
||||
float score =
|
||||
expertSelected ? static_cast<float>(scores[scoreIdx]) : invalidScoreFloat;
|
||||
auto scoreSigmoid = apply_scoring<SF>(score);
|
||||
// write the sigmoid score to shared for later use
|
||||
if (expertSelected) {
|
||||
smemScoreSigmoid[threadExpert] = scoreSigmoid;
|
||||
}
|
||||
|
||||
// get the score with bias
|
||||
// note that with invalid values, because sigmoid is < 1 and bias is -1,
|
||||
// we must get a negative value, which is smaller than any valid value
|
||||
auto scoreBias = float{scoreSigmoid + float{biasVal}};
|
||||
|
||||
if (expertSelected) {
|
||||
smemScoreBias[threadExpert] = scoreBias;
|
||||
}
|
||||
|
||||
// registers for top group score reduction
|
||||
float topExpGroupScores[NumTopGroupScores];
|
||||
[[maybe_unused]] int32_t topExpGroupIdx[NumTopGroupScores];
|
||||
float topGroups[MaxNumTopGroups]; // bound of numGroup
|
||||
int32_t topGroupIdx[MaxNumTopGroups];
|
||||
float expertScoreGroup[MaxNumTopGroups];
|
||||
int32_t expertIdxGroup[MaxNumTopGroups];
|
||||
float topScores[MaxNumTopExperts]; // bound of topk
|
||||
int32_t topExperts[MaxNumTopExperts];
|
||||
|
||||
if constexpr (UseGroups) {
|
||||
reduce_topk::reduceTopK(warp, topExpGroupScores, topExpGroupIdx, scoreBias,
|
||||
threadExpert,
|
||||
/* minValue */ invalidScoreFloat);
|
||||
|
||||
// get the final group score and write it to shared
|
||||
if (warp.thread_rank() == 0) {
|
||||
auto groupScore = topExpGroupScores[0] + topExpGroupScores[1];
|
||||
smemGroupScores[warpIdx] = groupScore;
|
||||
}
|
||||
}
|
||||
|
||||
// make group scores available to all warps
|
||||
__syncthreads();
|
||||
|
||||
if constexpr (UseGroups) {
|
||||
if (warpIdx == 0) {
|
||||
// a single warp performs the selection of top groups, and goes on to
|
||||
// select the final experts
|
||||
float groupScore =
|
||||
laneIdx < numGroup ? smemGroupScores[laneIdx] : invalidScoreFloat;
|
||||
|
||||
reduce_topk::reduceTopK(warp, topGroups, topGroupIdx, groupScore, laneIdx,
|
||||
/* minValue */ invalidScoreFloat);
|
||||
// final expert selection: get relevant indexes and scores from shared
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < MaxNumTopGroups; ++ii) { // bound of numGroup
|
||||
auto groupIdx = topGroupIdx[ii];
|
||||
expertIdxGroup[ii] = groupIdx * numExpertsPerGroup + laneIdx;
|
||||
|
||||
expertScoreGroup[ii] = (ii < topkGroup) && expertSelected
|
||||
? smemScoreBias[expertIdxGroup[ii]]
|
||||
: invalidScoreFloat;
|
||||
}
|
||||
|
||||
reduce_topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup,
|
||||
expertIdxGroup, /* minValue */ invalidScoreFloat,
|
||||
topk);
|
||||
}
|
||||
} else if constexpr (MaxNumExperts > MaxNumExpertsUnit) {
|
||||
// without groups, and the expert number is larger than MaxNumExpertsUnit,
|
||||
// we need to use multiple warps to calculate the intermediate topk results
|
||||
|
||||
int constexpr NumExpertWarps = (MaxNumExperts - 1) / MaxNumExpertsUnit + 1;
|
||||
int constexpr NumInterTopK = NumExpertWarps * MaxNumTopExperts;
|
||||
__shared__ float
|
||||
__attribute((aligned(128))) smemInterTopScores[NumInterTopK];
|
||||
__shared__ int32_t
|
||||
__attribute((aligned(128))) smemInterTopExperts[NumInterTopK];
|
||||
if (warpIdx < NumExpertWarps) {
|
||||
int offset = warpIdx * WARP_SIZE * MaxNumTopGroups;
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < MaxNumTopGroups; ++ii) {
|
||||
auto expertIdx = ii * WARP_SIZE + laneIdx;
|
||||
expertIdxGroup[ii] = offset + expertIdx;
|
||||
expertScoreGroup[ii] = offset + expertIdx < numExperts
|
||||
? smemScoreBias[offset + expertIdx]
|
||||
: invalidScoreFloat;
|
||||
}
|
||||
reduce_topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup,
|
||||
expertIdxGroup,
|
||||
/* minValue */ invalidScoreFloat, topk);
|
||||
|
||||
if (laneIdx < topk) {
|
||||
smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] =
|
||||
topScores[laneIdx];
|
||||
smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] =
|
||||
topExperts[laneIdx];
|
||||
} else if (laneIdx >= topk && laneIdx < MaxNumTopExperts) {
|
||||
smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] =
|
||||
invalidScoreFloat;
|
||||
smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] =
|
||||
MaxNumExperts - 1;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
if (warpIdx == 0) {
|
||||
int constexpr NumInterTopKPerThread = (NumInterTopK - 1) / WARP_SIZE + 1;
|
||||
float intermediateScore[NumInterTopKPerThread];
|
||||
int32_t intermediateExpert[NumInterTopKPerThread];
|
||||
for (int i = laneIdx; i < NumInterTopKPerThread * WARP_SIZE;
|
||||
i += WARP_SIZE) {
|
||||
int ii = i / WARP_SIZE;
|
||||
if (i < NumInterTopK) {
|
||||
intermediateScore[ii] = smemInterTopScores[i];
|
||||
intermediateExpert[ii] = smemInterTopExperts[i];
|
||||
} else {
|
||||
intermediateScore[ii] = invalidScoreFloat;
|
||||
intermediateExpert[ii] = MaxNumExperts - 1;
|
||||
}
|
||||
}
|
||||
reduce_topk::reduceTopK(warp, topScores, topExperts, intermediateScore,
|
||||
intermediateExpert,
|
||||
/* minValue */ invalidScoreFloat, topk);
|
||||
}
|
||||
} else {
|
||||
// without groups, and the expert number is smaller than MaxNumExpertsUnit
|
||||
// each thread just takes `MaxNumTopGroups` experts
|
||||
if (warpIdx == 0) {
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < MaxNumTopGroups; ++ii) {
|
||||
auto expertIdx = ii * WARP_SIZE + laneIdx;
|
||||
expertIdxGroup[ii] = expertIdx;
|
||||
expertScoreGroup[ii] = expertIdx < numExperts ? smemScoreBias[expertIdx]
|
||||
: invalidScoreFloat;
|
||||
}
|
||||
reduce_topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup,
|
||||
expertIdxGroup,
|
||||
/* minValue */ invalidScoreFloat, topk);
|
||||
}
|
||||
}
|
||||
|
||||
if (warpIdx == 0) {
|
||||
// determine our lane's expert index and write to output
|
||||
int32_t expertIdx =
|
||||
laneIdx < topk ? topExperts[laneIdx] : MaxNumExperts - 1;
|
||||
float scoreNorm = laneIdx < topk ? smemScoreSigmoid[expertIdx] : 0.F;
|
||||
float finalScore = static_cast<float>(scoreNorm * routedScalingFactor);
|
||||
// norm the value
|
||||
if (renormalize) {
|
||||
auto redNorm = cg::reduce(warp, scoreNorm, cg::plus<float>{});
|
||||
finalScore /= (redNorm + 1e-20);
|
||||
}
|
||||
// store the topk scores and experts to output
|
||||
if (laneIdx < topk) {
|
||||
topkValues[laneIdx] = finalScore;
|
||||
topkIndices[laneIdx] = expertIdx;
|
||||
}
|
||||
}
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, typename BiasT, typename IdxT, ScoringFunc SF>
|
||||
void invokeNoAuxTc(T* scores, float* topk_values, IdxT* topk_indices,
|
||||
BiasT const* bias, int64_t const num_tokens,
|
||||
int64_t const num_experts, int64_t const n_group,
|
||||
int64_t const topk_group, int64_t const topk,
|
||||
bool const renormalize, double const routed_scaling_factor,
|
||||
int const scoring_func, bool enable_pdl = false,
|
||||
cudaStream_t const stream = 0) {
|
||||
bool enable_pdl = false, cudaStream_t const stream = 0) {
|
||||
cudaLaunchConfig_t config;
|
||||
// One block per token; one warp per group.
|
||||
config.gridDim = static_cast<uint32_t>(num_tokens);
|
||||
config.blockDim = static_cast<uint32_t>(n_group) * WARP_SIZE;
|
||||
// Dynamic shared memory: WarpSelect staging + per-group topk buffers.
|
||||
int32_t const num_warps = static_cast<int32_t>(n_group);
|
||||
size_t const val_bytes =
|
||||
static_cast<size_t>(num_warps) * WARP_SIZE * sizeof(T);
|
||||
size_t const val_bytes_aligned =
|
||||
warp_topk::round_up_to_multiple_of<256>(val_bytes);
|
||||
size_t const idx_bytes =
|
||||
static_cast<size_t>(num_warps) * WARP_SIZE * sizeof(int32_t);
|
||||
size_t const internal_bytes = val_bytes_aligned + idx_bytes;
|
||||
size_t const extra_bytes = 16 + static_cast<size_t>(n_group) * sizeof(T);
|
||||
config.dynamicSmemBytes = internal_bytes + extra_bytes;
|
||||
config.stream = stream;
|
||||
cudaLaunchAttribute attrs[1];
|
||||
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
|
||||
attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
|
||||
config.numAttrs = 1;
|
||||
config.attrs = attrs;
|
||||
auto const sf = static_cast<ScoringFunc>(scoring_func);
|
||||
switch (sf) {
|
||||
case SCORING_NONE: {
|
||||
auto* kernel_instance =
|
||||
&grouped_topk_fused_kernel<T, BiasT, IdxT, SCORING_NONE>;
|
||||
cudaLaunchKernelEx(&config, kernel_instance, scores, topk_values,
|
||||
topk_indices, bias, num_tokens, num_experts, n_group,
|
||||
topk_group, topk, renormalize, routed_scaling_factor);
|
||||
return;
|
||||
|
||||
// Check if we can use the optimized
|
||||
// grouped_topk_fused_small_expert_count_kernel
|
||||
bool const is_single_group =
|
||||
(n_group == 1) && (topk_group == 1) &&
|
||||
(num_experts <= MaxSupportedExpertCount) &&
|
||||
(topk <= DefaultMaxNumTopExperts || topk == MaxSupportedTopExperts);
|
||||
|
||||
int64_t const experts_per_group = num_experts / n_group;
|
||||
bool const is_multi_group =
|
||||
(n_group > 1) && (num_experts <= NumDeepseekExperts) &&
|
||||
(experts_per_group <= WARP_SIZE) &&
|
||||
(experts_per_group * topk_group <= MaxNumExpertsUnit) &&
|
||||
(topk <= DefaultMaxNumTopExperts) && (topk_group <= MaxNumTopGroups);
|
||||
|
||||
if (is_single_group || is_multi_group) {
|
||||
auto* kernel_instance =
|
||||
&grouped_topk_fused_small_expert_count_kernel<T, BiasT, IdxT, SF,
|
||||
NumDeepseekExperts, true>;
|
||||
int num_threads = NumDeepseekExperts;
|
||||
if (is_single_group) {
|
||||
// Special case for Nemotron, which selects top 22 from 512 experts, and 1
|
||||
// group only.
|
||||
if (num_experts == NumNemotronExperts && n_group == 1 &&
|
||||
topk == MaxSupportedTopExperts) {
|
||||
kernel_instance = &grouped_topk_fused_small_expert_count_kernel<
|
||||
T, BiasT, IdxT, SF, NumNemotronExperts, false,
|
||||
MaxSupportedTopExperts>;
|
||||
num_threads = NumNemotronExperts;
|
||||
} else if (num_experts > NumKimiK2Experts &&
|
||||
num_experts <= MaxSupportedExpertCount) {
|
||||
kernel_instance = &grouped_topk_fused_small_expert_count_kernel<
|
||||
T, BiasT, IdxT, SF, MaxSupportedExpertCount, false>;
|
||||
num_threads = MaxSupportedExpertCount;
|
||||
} else if (num_experts > MaxNumExpertsUnit &&
|
||||
num_experts <= NumKimiK2Experts) {
|
||||
kernel_instance = &grouped_topk_fused_small_expert_count_kernel<
|
||||
T, BiasT, IdxT, SF, NumKimiK2Experts, false>;
|
||||
num_threads = NumKimiK2Experts;
|
||||
} else {
|
||||
kernel_instance = &grouped_topk_fused_small_expert_count_kernel<
|
||||
T, BiasT, IdxT, SF, MaxNumExpertsUnit, false>;
|
||||
num_threads = MaxNumExpertsUnit;
|
||||
}
|
||||
}
|
||||
case SCORING_SIGMOID: {
|
||||
auto* kernel_instance =
|
||||
&grouped_topk_fused_kernel<T, BiasT, IdxT, SCORING_SIGMOID>;
|
||||
cudaLaunchKernelEx(&config, kernel_instance, scores, topk_values,
|
||||
topk_indices, bias, num_tokens, num_experts, n_group,
|
||||
topk_group, topk, renormalize, routed_scaling_factor);
|
||||
return;
|
||||
}
|
||||
default:
|
||||
// should be guarded by higher level checks.
|
||||
TORCH_CHECK(false, "Unsupported scoring_func in invokeNoAuxTc");
|
||||
config.gridDim = num_tokens;
|
||||
config.blockDim = num_threads;
|
||||
config.dynamicSmemBytes = 0;
|
||||
cudaLaunchKernelEx(&config, kernel_instance, scores, topk_values,
|
||||
topk_indices, bias, num_tokens, n_group, topk_group,
|
||||
topk, num_experts, num_experts / n_group, renormalize,
|
||||
routed_scaling_factor);
|
||||
} else {
|
||||
auto* kernel_instance = &grouped_topk_fused_kernel<T, BiasT, IdxT, SF>;
|
||||
// One block per token; one warp per group.
|
||||
config.gridDim = static_cast<uint32_t>(num_tokens);
|
||||
config.blockDim = static_cast<uint32_t>(n_group) * WARP_SIZE;
|
||||
// Dynamic shared memory: WarpSelect staging + per-group topk buffers.
|
||||
int32_t const num_warps = static_cast<int32_t>(n_group);
|
||||
size_t const val_bytes =
|
||||
static_cast<size_t>(num_warps) * WARP_SIZE * sizeof(T);
|
||||
size_t const val_bytes_aligned =
|
||||
warp_topk::round_up_to_multiple_of<256>(val_bytes);
|
||||
size_t const idx_bytes =
|
||||
static_cast<size_t>(num_warps) * WARP_SIZE * sizeof(int32_t);
|
||||
size_t const internal_bytes = val_bytes_aligned + idx_bytes;
|
||||
size_t const extra_bytes = 16 + static_cast<size_t>(n_group) * sizeof(T);
|
||||
config.dynamicSmemBytes = internal_bytes + extra_bytes;
|
||||
cudaLaunchKernelEx(&config, kernel_instance, scores, topk_values,
|
||||
topk_indices, bias, num_tokens, num_experts, n_group,
|
||||
topk_group, topk, renormalize, routed_scaling_factor);
|
||||
}
|
||||
}
|
||||
|
||||
#define INSTANTIATE_NOAUX_TC(T, BiasT, IdxT) \
|
||||
template void invokeNoAuxTc<T, BiasT, IdxT>( \
|
||||
#define INSTANTIATE_NOAUX_TC(T, BiasT, IdxT, SF) \
|
||||
template void invokeNoAuxTc<T, BiasT, IdxT, SF>( \
|
||||
T * scores, float* topk_values, IdxT* topk_indices, BiasT const* bias, \
|
||||
int64_t const num_tokens, int64_t const num_experts, \
|
||||
int64_t const n_group, int64_t const topk_group, int64_t const topk, \
|
||||
bool const renormalize, double const routed_scaling_factor, \
|
||||
int const scoring_func, bool enable_pdl, cudaStream_t const stream);
|
||||
bool enable_pdl, cudaStream_t const stream);
|
||||
|
||||
INSTANTIATE_NOAUX_TC(float, float, int32_t);
|
||||
INSTANTIATE_NOAUX_TC(float, half, int32_t);
|
||||
INSTANTIATE_NOAUX_TC(float, __nv_bfloat16, int32_t);
|
||||
INSTANTIATE_NOAUX_TC(half, float, int32_t);
|
||||
INSTANTIATE_NOAUX_TC(half, half, int32_t);
|
||||
INSTANTIATE_NOAUX_TC(half, __nv_bfloat16, int32_t);
|
||||
INSTANTIATE_NOAUX_TC(__nv_bfloat16, float, int32_t);
|
||||
INSTANTIATE_NOAUX_TC(__nv_bfloat16, half, int32_t);
|
||||
INSTANTIATE_NOAUX_TC(__nv_bfloat16, __nv_bfloat16, int32_t);
|
||||
INSTANTIATE_NOAUX_TC(float, float, int32_t, SCORING_SIGMOID);
|
||||
INSTANTIATE_NOAUX_TC(float, half, int32_t, SCORING_SIGMOID);
|
||||
INSTANTIATE_NOAUX_TC(float, __nv_bfloat16, int32_t, SCORING_SIGMOID);
|
||||
INSTANTIATE_NOAUX_TC(half, float, int32_t, SCORING_SIGMOID);
|
||||
INSTANTIATE_NOAUX_TC(half, half, int32_t, SCORING_SIGMOID);
|
||||
INSTANTIATE_NOAUX_TC(half, __nv_bfloat16, int32_t, SCORING_SIGMOID);
|
||||
INSTANTIATE_NOAUX_TC(__nv_bfloat16, float, int32_t, SCORING_SIGMOID);
|
||||
INSTANTIATE_NOAUX_TC(__nv_bfloat16, half, int32_t, SCORING_SIGMOID);
|
||||
INSTANTIATE_NOAUX_TC(__nv_bfloat16, __nv_bfloat16, int32_t, SCORING_SIGMOID);
|
||||
INSTANTIATE_NOAUX_TC(float, float, int32_t, SCORING_NONE);
|
||||
INSTANTIATE_NOAUX_TC(float, half, int32_t, SCORING_NONE);
|
||||
INSTANTIATE_NOAUX_TC(float, __nv_bfloat16, int32_t, SCORING_NONE);
|
||||
INSTANTIATE_NOAUX_TC(half, float, int32_t, SCORING_NONE);
|
||||
INSTANTIATE_NOAUX_TC(half, half, int32_t, SCORING_NONE);
|
||||
INSTANTIATE_NOAUX_TC(half, __nv_bfloat16, int32_t, SCORING_NONE);
|
||||
INSTANTIATE_NOAUX_TC(__nv_bfloat16, float, int32_t, SCORING_NONE);
|
||||
INSTANTIATE_NOAUX_TC(__nv_bfloat16, half, int32_t, SCORING_NONE);
|
||||
INSTANTIATE_NOAUX_TC(__nv_bfloat16, __nv_bfloat16, int32_t, SCORING_NONE);
|
||||
} // end namespace moe
|
||||
} // namespace vllm
|
||||
|
||||
@@ -762,46 +1033,53 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
|
||||
{num_tokens, topk}, torch::dtype(torch::kInt32).device(torch::kCUDA));
|
||||
|
||||
auto stream = c10::cuda::getCurrentCUDAStream(scores.get_device());
|
||||
auto const sf = static_cast<vllm::moe::ScoringFunc>(scoring_func);
|
||||
|
||||
#define LAUNCH_KERNEL(T, IdxT) \
|
||||
do { \
|
||||
switch (bias_type) { \
|
||||
case torch::kFloat16: \
|
||||
vllm::moe::invokeNoAuxTc<T, half, IdxT>( \
|
||||
reinterpret_cast<T*>(scores.mutable_data_ptr()), \
|
||||
reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \
|
||||
reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \
|
||||
reinterpret_cast<half const*>(bias.data_ptr()), num_tokens, \
|
||||
num_experts, n_group, topk_group, topk, renormalize, \
|
||||
routed_scaling_factor, static_cast<int>(scoring_func), false, \
|
||||
stream); \
|
||||
break; \
|
||||
case torch::kFloat32: \
|
||||
vllm::moe::invokeNoAuxTc<T, float, IdxT>( \
|
||||
reinterpret_cast<T*>(scores.mutable_data_ptr()), \
|
||||
reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \
|
||||
reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \
|
||||
reinterpret_cast<float const*>(bias.data_ptr()), num_tokens, \
|
||||
num_experts, n_group, topk_group, topk, renormalize, \
|
||||
routed_scaling_factor, static_cast<int>(scoring_func), false, \
|
||||
stream); \
|
||||
break; \
|
||||
case torch::kBFloat16: \
|
||||
vllm::moe::invokeNoAuxTc<T, __nv_bfloat16, IdxT>( \
|
||||
reinterpret_cast<T*>(scores.mutable_data_ptr()), \
|
||||
reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \
|
||||
reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \
|
||||
reinterpret_cast<__nv_bfloat16 const*>(bias.data_ptr()), \
|
||||
num_tokens, num_experts, n_group, topk_group, topk, renormalize, \
|
||||
routed_scaling_factor, static_cast<int>(scoring_func), false, \
|
||||
stream); \
|
||||
break; \
|
||||
default: \
|
||||
throw std::invalid_argument( \
|
||||
"Invalid bias dtype, only supports float16, float32, and " \
|
||||
"bfloat16"); \
|
||||
break; \
|
||||
} \
|
||||
#define LAUNCH_KERNEL_SF(T, BiasT, IdxT) \
|
||||
do { \
|
||||
switch (sf) { \
|
||||
case vllm::moe::SCORING_NONE: \
|
||||
vllm::moe::invokeNoAuxTc<T, BiasT, IdxT, vllm::moe::SCORING_NONE>( \
|
||||
reinterpret_cast<T*>(scores.mutable_data_ptr()), \
|
||||
reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \
|
||||
reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \
|
||||
reinterpret_cast<BiasT const*>(bias.data_ptr()), num_tokens, \
|
||||
num_experts, n_group, topk_group, topk, renormalize, \
|
||||
routed_scaling_factor, false, stream); \
|
||||
break; \
|
||||
case vllm::moe::SCORING_SIGMOID: \
|
||||
vllm::moe::invokeNoAuxTc<T, BiasT, IdxT, vllm::moe::SCORING_SIGMOID>( \
|
||||
reinterpret_cast<T*>(scores.mutable_data_ptr()), \
|
||||
reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \
|
||||
reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \
|
||||
reinterpret_cast<BiasT const*>(bias.data_ptr()), num_tokens, \
|
||||
num_experts, n_group, topk_group, topk, renormalize, \
|
||||
routed_scaling_factor, false, stream); \
|
||||
break; \
|
||||
default: \
|
||||
throw std::invalid_argument("Unsupported scoring_func"); \
|
||||
break; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define LAUNCH_KERNEL(T, IdxT) \
|
||||
do { \
|
||||
switch (bias_type) { \
|
||||
case torch::kFloat16: \
|
||||
LAUNCH_KERNEL_SF(T, half, IdxT); \
|
||||
break; \
|
||||
case torch::kFloat32: \
|
||||
LAUNCH_KERNEL_SF(T, float, IdxT); \
|
||||
break; \
|
||||
case torch::kBFloat16: \
|
||||
LAUNCH_KERNEL_SF(T, __nv_bfloat16, IdxT); \
|
||||
break; \
|
||||
default: \
|
||||
throw std::invalid_argument( \
|
||||
"Invalid bias dtype, only supports float16, float32, and " \
|
||||
"bfloat16"); \
|
||||
break; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
switch (data_type) {
|
||||
@@ -824,5 +1102,6 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
|
||||
break;
|
||||
}
|
||||
#undef LAUNCH_KERNEL
|
||||
#undef LAUNCH_KERNEL_SF
|
||||
return {topk_values, topk_indices};
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user