[Perf] Optimize grouped topk kernel, 1.2%~2% E2E Throughput improvement (#32058)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
@@ -31,8 +31,6 @@ namespace moe {
 constexpr unsigned FULL_WARP_MASK = 0xffffffff;
 constexpr int32_t WARP_SIZE = 32;
 constexpr int32_t BLOCK_SIZE = 512;
 constexpr int32_t NUM_WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE;
 
 namespace warp_topk {
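Aside: with these constants every block is exactly 16 full warps. A trivial compile-time restatement, added here only as a sanity sketch:

// Sanity sketch: 512-thread blocks split into 16 warps of 32 lanes.
static_assert(NUM_WARPS_PER_BLOCK == 16, "512 / 32 warps per block");
static_assert(BLOCK_SIZE % WARP_SIZE == 0, "block must be whole warps");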
@@ -65,14 +63,6 @@ __forceinline__ __device__ bool is_better_than(T val, T baseline, idxT index,
   return res;
 }
 
-template <typename T, typename idxT>
-int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) {
-  int64_t cache_topk = (sizeof(T) + sizeof(idxT)) * num_of_warp * k;
-  int64_t n = std::max<int>(num_of_warp / 2 * k, num_of_warp * WARP_SIZE);
-  return max(cache_topk,
-             round_up_to_multiple_of<256>(n * sizeof(T)) + n * sizeof(idxT));
-}
-
 template <int size, bool ascending, bool reverse, typename T, typename idxT,
           bool is_stable>
 struct BitonicMerge {
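The sizing math above and in the fused kernel below leans on round_up_to_multiple_of, whose definition is not part of this diff. A minimal sketch of the assumed semantics (round n up to the next multiple of the template argument; not this file's actual definition):

// Assumed semantics of the rounding helper used in the smem sizing.
template <int mul, typename T>
constexpr T round_up_to_multiple_of(T n) {
  return (n + mul - 1) / mul * mul;  // integer ceiling, then scale back up
}
static_assert(round_up_to_multiple_of<256>(1000) == 1024, "");
static_assert(round_up_to_multiple_of<32>(64) == 64, "");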
@@ -267,6 +257,15 @@ class WarpSort {
   }
 }
 
+  // Accessors for per-lane selected value/index.
+  // NOTE: For the common case `capacity == WARP_SIZE`, `max_arr_len_ == 1`
+  // and callers should use `i == 0`.
+  __device__ __forceinline__ idxT get_idx(int i = 0) const {
+    return idx_arr_[i];
+  }
+
+  __device__ __forceinline__ T get_val(int i = 0) const { return val_arr_[i]; }
+
  protected:
   static constexpr int max_arr_len_ = capacity / WARP_SIZE;
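After done(), with capacity == WARP_SIZE, lane i holds the i-th best (value, index) pair, so results are read per-lane and shared via shuffles. A usage sketch (the queue name and helper are illustrative, not from this file):

// Sketch: gather the warp's top-k indices; lane i holds the i-th best pair.
// `sel` is a finished WarpSelect<WARP_SIZE, ...>; requires k <= WARP_SIZE.
template <typename Queue>
__device__ void read_topk_indices(Queue const& sel, int k, int lane_id,
                                  int32_t* out) {
  if (lane_id < k) {
    out[lane_id] = sel.get_idx(0);  // i == 0: max_arr_len_ == 1 here
  }
}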
@@ -285,6 +284,7 @@ class WarpSelect : public WarpSort<capacity, greater, T, idxT, is_stable> {
   __device__ WarpSelect(idxT k, T dummy)
       : WarpSort<capacity, greater, T, idxT, is_stable>(k, dummy),
         k_th_(dummy),
+        k_th_idx_(0),
         k_th_lane_((k - 1) % WARP_SIZE) {
     extern __shared__ char smem_buf[];  // extern __shared__ T smem_buf[];
 
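The k_th_lane_ initializer pins the k-th best element (1-based k) to lane (k - 1) % WARP_SIZE; for the full-warp capacity used here that is simply lane k - 1, e.g. k = 8 lives on lane 7. A compile-time restatement of that arithmetic:

// Restates the initializer above; assumes capacity == WARP_SIZE == 32.
constexpr int kth_lane(int k) { return (k - 1) % 32; }
static_assert(kth_lane(8) == 7, "");
static_assert(kth_lane(32) == 31, "");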
@@ -346,9 +346,6 @@ class WarpSelect : public WarpSort<capacity, greater, T, idxT, is_stable> {
     idxT idx = (lane_ < smem_buf_len_) ? idx_smem_[lane_] : 0;
     merge_buf_(val, idx);
   }
-
-  // after done(), smem is used for merging results among warps
-  __syncthreads();
 }
 
 private:
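WarpSelect::add() is warp-collective: every lane must call it the same number of times, so callers pad the loop bound up to a full warp and feed a -inf sentinel from out-of-range lanes. A minimal sketch of the pattern the fused kernel below uses (the helper function itself is illustrative):

// Illustrative only: pads the candidate range to a warp multiple so all 32
// lanes make the same number of collective add() calls.
template <typename T, typename Queue>
__device__ void add_all(Queue& q, T const* data, int32_t n, int32_t lane_id) {
  int32_t const n_aligned = warp_topk::round_up_to_multiple_of<WARP_SIZE>(n);
  for (int32_t i = lane_id; i < n_aligned; i += WARP_SIZE) {
    T v = (i < n) ? data[i] : neg_inf<T>();  // sentinel never wins
    q.add(v, i);
  }
  q.done();  // merge per-lane queues across the warp
}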
@@ -503,255 +500,186 @@ __device__ void topk_with_k2(T* output, T const* input, BiasT const* bias,
   }
 }
 
-template <typename T, typename BiasT, ScoringFunc SF>
-__global__ void topk_with_k2_kernel(T* output, T* input, BiasT const* bias,
-                                    int64_t const num_tokens,
-                                    int64_t const num_cases,
-                                    int64_t const n_group,
-                                    int64_t const num_experts_per_group) {
-  int32_t warp_id = threadIdx.x / WARP_SIZE;
-  int32_t lane_id = threadIdx.x % WARP_SIZE;
-
-  int32_t case_id = blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id;
-  if (case_id < num_cases) {
-    input += case_id * num_experts_per_group;
-    // bias is per expert group, offset to current group
-    int32_t group_id = case_id % n_group;
-    BiasT const* group_bias = bias + group_id * num_experts_per_group;
-    output += case_id;
-
-    cg::thread_block block = cg::this_thread_block();
-    cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);
-
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.wait;");
-#endif
-    topk_with_k2<T, BiasT, SF>(output, input, group_bias, tile, lane_id,
-                               num_experts_per_group);
-  }
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-  asm volatile("griddepcontrol.launch_dependents;");
-#endif
-}
-
-template <typename T, typename BiasT, typename IdxT, ScoringFunc SF,
-          int NGroup = -1>
-__global__ void group_idx_and_topk_idx_kernel(
-    T* scores, T const* group_scores, float* topk_values, IdxT* topk_indices,
-    BiasT const* bias, int64_t const num_tokens, int64_t const n_group,
-    int64_t const topk_group, int64_t const topk, int64_t const num_experts,
-    int64_t const num_experts_per_group, bool renormalize,
+template <typename T, typename BiasT, typename IdxT, ScoringFunc SF>
+__global__ void grouped_topk_fused_kernel(
+    T* scores, float* topk_values, IdxT* topk_indices, BiasT const* bias,
+    int64_t const num_tokens, int64_t const num_experts, int64_t const n_group,
+    int64_t const topk_group, int64_t const topk, bool renormalize,
     double routed_scaling_factor) {
-  int32_t warp_id = threadIdx.x / WARP_SIZE;
-  int32_t lane_id = threadIdx.x % WARP_SIZE;
-  int32_t case_id =
-      blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id;  // one per token
-  scores += case_id * num_experts;
-  group_scores += case_id * n_group;
-  topk_values += case_id * topk;
-  topk_indices += case_id * topk;
+  int32_t const token_id = static_cast<int32_t>(blockIdx.x);
+  if (token_id >= num_tokens) {
+    return;
+  }
 
-  constexpr bool kUseStaticNGroup = (NGroup > 0);
-  // use int32 to avoid implicit conversion
-  int32_t const n_group_i32 =
-      kUseStaticNGroup ? NGroup : static_cast<int32_t>(n_group);
+  int32_t const warp_id = threadIdx.x / WARP_SIZE;
+  int32_t const lane_id = threadIdx.x % WARP_SIZE;
 
-  int32_t align_num_experts_per_group =
-      warp_topk::round_up_to_multiple_of<WARP_SIZE>(num_experts_per_group);
+  int32_t const n_group_i32 = static_cast<int32_t>(n_group);
+  int32_t const topk_group_i32 = static_cast<int32_t>(topk_group);
+  int32_t const topk_i32 = static_cast<int32_t>(topk);
+  int32_t const num_experts_i32 = static_cast<int32_t>(num_experts);
+
+  int32_t const num_warps = blockDim.x / WARP_SIZE;
+  if (warp_id >= n_group_i32 || num_warps < n_group_i32) {
+    return;
+  }
+
+  int32_t const num_experts_per_group = num_experts_i32 / n_group_i32;
+
+  T* scores_token = scores + static_cast<int64_t>(token_id) * num_experts;
 
   cg::thread_block block = cg::this_thread_block();
   cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);
 
-  extern __shared__ char smem_buf[];  // NOTE: reuse the shared memory here to
-                                      // store the target topk idx
-  int32_t* s_topk_idx = reinterpret_cast<int32_t*>(smem_buf);
-  T* s_topk_value =
-      reinterpret_cast<T*>(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) +
-      warp_id * topk;
-  s_topk_idx += warp_id * topk;
+  extern __shared__ char smem_buf[];
+  // warpSelect internal staging buffer layout
+  size_t const val_bytes =
+      static_cast<size_t>(num_warps) * WARP_SIZE * sizeof(T);
+  size_t const val_bytes_aligned =
+      warp_topk::round_up_to_multiple_of<256>(val_bytes);
+  size_t const idx_bytes =
+      static_cast<size_t>(num_warps) * WARP_SIZE * sizeof(int32_t);
+  size_t const internal_bytes = val_bytes_aligned + idx_bytes;
 
-  T value = neg_inf<T>();
-  T topk_group_value = neg_inf<T>();
-  int32_t num_equalto_topkth_group;
+  // user-managed shared memory starts after warpSelect internal staging.
+  uintptr_t ptr_u = reinterpret_cast<uintptr_t>(smem_buf + internal_bytes);
+  ptr_u = (ptr_u + 15) & ~static_cast<uintptr_t>(15);  // align to 16B
+  T* s_group_scores = reinterpret_cast<T*>(ptr_u);
 
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
   asm volatile("griddepcontrol.wait;");  // I think all prolog can be put before
                                          // acqbulk because it's ptr arithmetic
 #endif
-  if (case_id < num_tokens) {
-    // calculate group_idx
-    int32_t target_num_min =
-        WARP_SIZE - n_group_i32 + static_cast<int32_t>(topk_group);
-    // The check is necessary to avoid abnormal input
-    if (lane_id < n_group_i32 && is_finite(group_scores[lane_id])) {
-      value = group_scores[lane_id];
-    }
+  // phase 1: per-group scan
+  int32_t const group_offset = warp_id * num_experts_per_group;
+  topk_with_k2<T, BiasT, SF>(s_group_scores + warp_id,
+                             scores_token + group_offset, bias + group_offset,
+                             tile, lane_id, num_experts_per_group);
 
-    int count_equal_to_top_value = WARP_SIZE - n_group_i32;
-    int pre_count_equal_to_top_value = 0;
-    // Use loop to find the largest top_group
-    while (count_equal_to_top_value < target_num_min) {
-      topk_group_value = cg::reduce(tile, value, cg::greater<T>());
-      if (value == topk_group_value) {
-        value = neg_inf<T>();
-      }
-      pre_count_equal_to_top_value = count_equal_to_top_value;
-      count_equal_to_top_value =
-          __popc(__ballot_sync(FULL_WARP_MASK, (value == neg_inf<T>())));
-    }
-    num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
-  }
   __syncthreads();
 
+  // phase 2: warp0 selects groups + merges candidates to final topk
+  if (warp_id != 0) {
+    return;
+  }
+
+  topk_values += static_cast<int64_t>(token_id) * topk;
+  topk_indices += static_cast<int64_t>(token_id) * topk;
+
+  // select topk_group groups by group score
   warp_topk::WarpSelect</*capacity*/ WARP_SIZE, /*greater*/ true, T, int32_t,
                         /* is_stable */ true>
-      queue((int32_t)topk, neg_inf<T>());
+      group_sel(static_cast<int32_t>(topk_group_i32), neg_inf<T>());
 
-  int count_equalto_topkth_group = 0;
-  bool if_proceed_next_topk = topk_group_value != neg_inf<T>();
-  if (case_id < num_tokens && if_proceed_next_topk) {
-    auto process_group = [&](int i_group) {
-      if ((group_scores[i_group] > topk_group_value) ||
-          ((group_scores[i_group] == topk_group_value) &&
-           (count_equalto_topkth_group < num_equalto_topkth_group))) {
-        int32_t offset = i_group * num_experts_per_group;
-        for (int32_t i = lane_id; i < align_num_experts_per_group;
-             i += WARP_SIZE) {
-          T candidates = neg_inf<T>();
-          if (i < num_experts_per_group) {
-            // apply scoring function (if any) and add bias
-            T input = scores[offset + i];
-            if (is_finite(input)) {
-              T score = apply_scoring<SF>(input);
-              candidates = score + static_cast<T>(bias[offset + i]);
-            }
-          }
-          queue.add(candidates, offset + i);
-        }
-        if (group_scores[i_group] == topk_group_value) {
-          count_equalto_topkth_group++;
+  // all lanes must participate in WarpSelect::add().
+  T gscore = (lane_id < n_group_i32) ? s_group_scores[lane_id] : neg_inf<T>();
+  group_sel.add(gscore, lane_id);
+  group_sel.done();
 
+  // proceed only if the k-th selected group score is not -inf
+  bool proceed = false;
+  if (topk_group_i32 > 0) {
+    int const kth_lane = topk_group_i32 - 1;
+    // broadcast the k-th selected group score to all lanes
+    T kth_val = __shfl_sync(FULL_WARP_MASK, group_sel.get_val(0), kth_lane);
+    proceed = (kth_val != neg_inf<T>());
+  }
+
+  if (!proceed) {
+    for (int i = lane_id; i < topk_i32; i += WARP_SIZE) {
+      topk_indices[i] = static_cast<IdxT>(i);
+      topk_values[i] = 1.0f / static_cast<float>(topk_i32);
+    }
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    asm volatile("griddepcontrol.launch_dependents;");
+#endif
+    return;
+  }
+
+  // merge per-group topk candidates for selected groups, then select topk
+  warp_topk::WarpSelect</*capacity*/ WARP_SIZE, /*greater*/ true, T, int32_t,
+                        /* is_stable */ true>
+      expert_sel(static_cast<int32_t>(topk_i32), neg_inf<T>());
+
+  // selected group ids reside in lanes [0, topk_group)
+  int32_t sel_gid_lane = (lane_id < topk_group_i32) ? group_sel.get_idx(0) : 0;
+
+  // add candidates from selected groups to expert_sel
+  for (int32_t g = 0; g < topk_group_i32; ++g) {
+    int32_t gid = __shfl_sync(FULL_WARP_MASK, sel_gid_lane, g);
+    int32_t const offset = gid * num_experts_per_group;
+    int32_t const align_num_experts_per_group =
+        warp_topk::round_up_to_multiple_of<WARP_SIZE>(num_experts_per_group);
+    for (int32_t i = lane_id; i < align_num_experts_per_group; i += WARP_SIZE) {
+      // all lanes must call `add()` the same number of times.
+      T cand = neg_inf<T>();
+      int32_t idx = 0;
+      if (i < num_experts_per_group) {
+        idx = offset + i;
+        T input = scores_token[idx];
+        if (is_finite(input)) {
+          T score = apply_scoring<SF>(input);
+          cand = score + static_cast<T>(bias[idx]);
+        }
+      }
-    };
 
-    if constexpr (kUseStaticNGroup) {
-#pragma unroll
-      for (int i_group = 0; i_group < NGroup; ++i_group) {
-        process_group(i_group);
-      }
-    } else {
-      for (int i_group = 0; i_group < n_group_i32; ++i_group) {
-        process_group(i_group);
-      }
-    }
-    queue.done();
-    // Get the topk_idx
-    queue.dumpIdx(s_topk_idx);
-  }
-
-  // Load the valid score value
-  // Calculate the summation
-  float topk_sum = 1e-20;
-  if (case_id < num_tokens && if_proceed_next_topk) {
-    for (int i = lane_id;
-         i < warp_topk::round_up_to_multiple_of<WARP_SIZE>(topk);
-         i += WARP_SIZE) {
-      T value = cuda_cast<T, float>(0.0f);
-      if (i < topk) {
-        // Load the score value (without bias) for normalization
-        T input = scores[s_topk_idx[i]];
-        value = apply_scoring<SF>(input);
-        s_topk_value[i] = value;
-      }
-      if (renormalize) {
-        topk_sum +=
-            cg::reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
-      }
+      expert_sel.add(cand, idx);
     }
   }
+  expert_sel.done();
 
-  __syncthreads();
-
-  if (case_id < num_tokens) {
-    if (if_proceed_next_topk) {
-      float scale = routed_scaling_factor;
-      if (renormalize) {
-        scale /= topk_sum;
-      }
-      for (int i = lane_id; i < topk; i += WARP_SIZE) {
-        float base = cuda_cast<float, T>(s_topk_value[i]);
-        float value = base * scale;
-        topk_indices[i] = s_topk_idx[i];
-        topk_values[i] = value;
-      }
-    } else {
-      for (int i = lane_id; i < topk; i += WARP_SIZE) {
-        topk_indices[i] = i;
-        topk_values[i] = 1.0f / topk;
-      }
-    }
-    // Note: when if_proceed_next_topk==false, choose the first 8 experts as the
-    // default result.
+  // compute unbiased routing weights + optional renorm.
+  float lane_unbiased = 0.0f;
+  IdxT lane_idx = 0;
+  if (lane_id < topk_i32) {
+    lane_idx = static_cast<IdxT>(expert_sel.get_idx(0));
+    T in = scores_token[static_cast<int32_t>(lane_idx)];
+    lane_unbiased = cuda_cast<float, T>(apply_scoring<SF>(in));
+  }
+
+  float topk_sum = 1e-20f;
+  if (renormalize) {
+    topk_sum += cg::reduce(tile, lane_unbiased, cg::plus<float>());
+  }
+
+  float scale = static_cast<float>(routed_scaling_factor);
+  if (renormalize) {
+    scale /= topk_sum;
+  }
+
+  if (lane_id < topk_i32) {
+    topk_indices[lane_id] = lane_idx;
+    topk_values[lane_id] = lane_unbiased * scale;
+  }
 
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
   asm volatile("griddepcontrol.launch_dependents;");
 #endif
 }
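As a concreteness check on the shared-memory layout used by the fused kernel: for a DeepSeek-V3-like shape (n_group = 8, T = float), staging is 8 * 32 * 4 = 1024 bytes of values (already 256-aligned) plus 1024 bytes of indices, and the user region adds up to 16 alignment bytes plus 8 group scores. A host-side sketch of the same arithmetic (sizes taken from the code above, not measured):

#include <cstddef>
#include <cstdint>

// Mirrors the kernel's dynamic-smem budget: per-warp value staging (aligned
// to 256), per-warp index staging, then a 16B-aligned n_group score region.
size_t grouped_topk_smem_bytes(int n_group, size_t sizeof_t) {
  size_t const warp_size = 32;
  size_t const val_bytes = n_group * warp_size * sizeof_t;
  size_t const val_aligned = (val_bytes + 255) / 256 * 256;
  size_t const idx_bytes = n_group * warp_size * sizeof(int32_t);
  return val_aligned + idx_bytes + 16 + n_group * sizeof_t;
}
// grouped_topk_smem_bytes(8, sizeof(float)) == 1024 + 1024 + 16 + 32 == 2096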
 
-template <typename T, typename BiasT, typename IdxT, ScoringFunc SF>
-inline void launch_group_idx_and_topk_kernel(
-    cudaLaunchConfig_t const& config, T* scores, T* group_scores,
-    float* topk_values, IdxT* topk_indices, BiasT const* bias,
-    int64_t const num_tokens, int64_t const n_group, int64_t const topk_group,
-    int64_t const topk, int64_t const num_experts,
-    int64_t const num_experts_per_group, bool const renormalize,
-    double const routed_scaling_factor) {
-  auto launch = [&](auto* kernel_instance2) {
-    cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores,
-                       topk_values, topk_indices, bias, num_tokens, n_group,
-                       topk_group, topk, num_experts, num_experts_per_group,
-                       renormalize, routed_scaling_factor);
-  };
-
-  switch (n_group) {
-    case 4: {
-      launch(&group_idx_and_topk_idx_kernel<T, BiasT, IdxT, SF, 4>);
-      break;
-    }
-    case 8: {
-      launch(&group_idx_and_topk_idx_kernel<T, BiasT, IdxT, SF, 8>);
-      break;
-    }
-    case 16: {
-      launch(&group_idx_and_topk_idx_kernel<T, BiasT, IdxT, SF, 16>);
-      break;
-    }
-    case 32: {
-      launch(&group_idx_and_topk_idx_kernel<T, BiasT, IdxT, SF, 32>);
-      break;
-    }
-    default: {
-      launch(&group_idx_and_topk_idx_kernel<T, BiasT, IdxT, SF>);
-      break;
-    }
-  }
-}
-
 template <typename T, typename BiasT, typename IdxT>
-void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
-                   IdxT* topk_indices, BiasT const* bias,
-                   int64_t const num_tokens, int64_t const num_experts,
-                   int64_t const n_group, int64_t const topk_group,
-                   int64_t const topk, bool const renormalize,
-                   double const routed_scaling_factor, int const scoring_func,
-                   bool enable_pdl = false, cudaStream_t const stream = 0) {
-  int64_t num_cases = num_tokens * n_group;
-  int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
+void invokeNoAuxTc(T* scores, float* topk_values, IdxT* topk_indices,
+                   BiasT const* bias, int64_t const num_tokens,
+                   int64_t const num_experts, int64_t const n_group,
+                   int64_t const topk_group, int64_t const topk,
+                   bool const renormalize, double const routed_scaling_factor,
+                   int const scoring_func, bool enable_pdl = false,
+                   cudaStream_t const stream = 0) {
   cudaLaunchConfig_t config;
-  config.gridDim = topk_with_k2_num_blocks;
-  config.blockDim = BLOCK_SIZE;
-  config.dynamicSmemBytes = 0;
+  // One block per token; one warp per group.
+  config.gridDim = static_cast<uint32_t>(num_tokens);
+  config.blockDim = static_cast<uint32_t>(n_group) * WARP_SIZE;
+  // Dynamic shared memory: WarpSelect staging + per-group topk buffers.
+  int32_t const num_warps = static_cast<int32_t>(n_group);
+  size_t const val_bytes =
+      static_cast<size_t>(num_warps) * WARP_SIZE * sizeof(T);
+  size_t const val_bytes_aligned =
+      warp_topk::round_up_to_multiple_of<256>(val_bytes);
+  size_t const idx_bytes =
+      static_cast<size_t>(num_warps) * WARP_SIZE * sizeof(int32_t);
+  size_t const internal_bytes = val_bytes_aligned + idx_bytes;
+  size_t const extra_bytes = 16 + static_cast<size_t>(n_group) * sizeof(T);
+  config.dynamicSmemBytes = internal_bytes + extra_bytes;
   config.stream = stream;
   cudaLaunchAttribute attrs[1];
   attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
@@ -759,66 +687,35 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
   config.numAttrs = 1;
   config.attrs = attrs;
   auto const sf = static_cast<ScoringFunc>(scoring_func);
-  int64_t const num_experts_per_group = num_experts / n_group;
-  auto launch_topk_with_k2 = [&](auto* kernel_instance1) {
-    cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores, bias,
-                       num_tokens, num_cases, n_group, num_experts_per_group);
-  };
   switch (sf) {
     case SCORING_NONE: {
-      auto* kernel_instance1 = &topk_with_k2_kernel<T, BiasT, SCORING_NONE>;
-      launch_topk_with_k2(kernel_instance1);
-      break;
+      auto* kernel_instance =
+          &grouped_topk_fused_kernel<T, BiasT, IdxT, SCORING_NONE>;
+      cudaLaunchKernelEx(&config, kernel_instance, scores, topk_values,
+                         topk_indices, bias, num_tokens, num_experts, n_group,
+                         topk_group, topk, renormalize, routed_scaling_factor);
+      return;
     }
     case SCORING_SIGMOID: {
-      auto* kernel_instance1 = &topk_with_k2_kernel<T, BiasT, SCORING_SIGMOID>;
-      launch_topk_with_k2(kernel_instance1);
-      break;
+      auto* kernel_instance =
+          &grouped_topk_fused_kernel<T, BiasT, IdxT, SCORING_SIGMOID>;
+      cudaLaunchKernelEx(&config, kernel_instance, scores, topk_values,
+                         topk_indices, bias, num_tokens, num_experts, n_group,
+                         topk_group, topk, renormalize, routed_scaling_factor);
+      return;
     }
     default:
       // should be guarded by higher level checks.
       TORCH_CHECK(false, "Unsupported scoring_func in invokeNoAuxTc");
   }
-
-  int64_t topk_with_k_group_num_blocks =
-      (num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
-  size_t dynamic_smem_in_bytes =
-      warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
-                                                           topk);
-  config.gridDim = topk_with_k_group_num_blocks;
-  config.blockDim = BLOCK_SIZE;
-  config.dynamicSmemBytes = dynamic_smem_in_bytes;
-  config.stream = stream;
-  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
-  attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
-  config.numAttrs = 1;
-  config.attrs = attrs;
-  switch (sf) {
-    case SCORING_NONE: {
-      launch_group_idx_and_topk_kernel<T, BiasT, IdxT, SCORING_NONE>(
-          config, scores, group_scores, topk_values, topk_indices, bias,
-          num_tokens, n_group, topk_group, topk, num_experts,
-          num_experts_per_group, renormalize, routed_scaling_factor);
-      break;
-    }
-    case SCORING_SIGMOID: {
-      launch_group_idx_and_topk_kernel<T, BiasT, IdxT, SCORING_SIGMOID>(
-          config, scores, group_scores, topk_values, topk_indices, bias,
-          num_tokens, n_group, topk_group, topk, num_experts,
-          num_experts_per_group, renormalize, routed_scaling_factor);
-      break;
-    }
-    default:
-      TORCH_CHECK(false, "Unsupported scoring_func in invokeNoAuxTc");
-  }
 }
 
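The new launch shape is one block per token with one warp per group, e.g. num_tokens = 4096 and n_group = 8 gives gridDim = 4096 and blockDim = 256. A hedged call-site sketch (buffer names are placeholders; shapes assumed: scores [num_tokens, num_experts], bias [num_experts], outputs [num_tokens, topk]):

// Hypothetical host-side usage of the entry point above.
void example_launch(float* d_scores, float* d_topk_values,
                    int32_t* d_topk_indices, float const* d_bias,
                    cudaStream_t stream) {
  vllm::moe::invokeNoAuxTc<float, float, int32_t>(
      d_scores, d_topk_values, d_topk_indices, d_bias,
      /*num_tokens=*/4096, /*num_experts=*/256, /*n_group=*/8,
      /*topk_group=*/4, /*topk=*/8, /*renormalize=*/true,
      /*routed_scaling_factor=*/2.5, /*scoring_func=*/1 /*SCORING_SIGMOID*/,
      /*enable_pdl=*/false, stream);
}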
-#define INSTANTIATE_NOAUX_TC(T, BiasT, IdxT)                                  \
-  template void invokeNoAuxTc<T, BiasT, IdxT>(                                \
-      T * scores, T * group_scores, float* topk_values, IdxT* topk_indices,   \
-      BiasT const* bias, int64_t const num_tokens, int64_t const num_experts, \
-      int64_t const n_group, int64_t const topk_group, int64_t const topk,    \
-      bool const renormalize, double const routed_scaling_factor,             \
+#define INSTANTIATE_NOAUX_TC(T, BiasT, IdxT)                                 \
+  template void invokeNoAuxTc<T, BiasT, IdxT>(                               \
+      T * scores, float* topk_values, IdxT* topk_indices, BiasT const* bias, \
+      int64_t const num_tokens, int64_t const num_experts,                   \
+      int64_t const n_group, int64_t const topk_group, int64_t const topk,   \
+      bool const renormalize, double const routed_scaling_factor,            \
       int const scoring_func, bool enable_pdl, cudaStream_t const stream);
 
 INSTANTIATE_NOAUX_TC(float, float, int32_t);
@@ -843,17 +740,21 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
   int64_t num_tokens = input_size[0];
   int64_t num_experts = input_size[1];
   TORCH_CHECK(input_size.size() == 2, "scores must be a 2D Tensor");
+  TORCH_CHECK(n_group > 0, "n_group must be positive");
+  TORCH_CHECK(topk > 0, "topk must be positive");
+  TORCH_CHECK(topk_group > 0, "topk_group must be positive");
+  TORCH_CHECK(topk_group <= n_group, "topk_group must be <= n_group");
   TORCH_CHECK(num_experts % n_group == 0,
               "num_experts should be divisible by n_group");
   TORCH_CHECK(n_group <= 32,
               "n_group should be smaller than or equal to 32 for now");
   TORCH_CHECK(topk <= 32, "topk should be smaller than or equal to 32 for now");
+  TORCH_CHECK(topk <= topk_group * (num_experts / n_group),
+              "topk must be <= topk_group * (num_experts / n_group)");
   TORCH_CHECK(scoring_func == vllm::moe::SCORING_NONE ||
                   scoring_func == vllm::moe::SCORING_SIGMOID,
               "scoring_func must be SCORING_NONE (0) or SCORING_SIGMOID (1)");
 
-  torch::Tensor group_scores = torch::empty(
-      {num_tokens, n_group}, torch::dtype(data_type).device(torch::kCUDA));
   // Always output float32 for topk_values (eliminates Python-side conversion)
   torch::Tensor topk_values = torch::empty(
       {num_tokens, topk}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
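For reference, the weight each selected expert receives is scoring(score) * routed_scaling_factor, divided by the sum of the topk scoring outputs when renormalize is set. A tiny host-side restatement of that math (sigmoid assumed for apply_scoring; illustrative, not part of this file):

#include <cmath>
#include <vector>

// Reference of the routing-weight math in the fused kernel:
// w_i = sigmoid(s_i) * scale, scale = routed_scaling_factor / sum_j sigmoid(s_j)
// when renormalize is on (the 1e-20 guard matches the kernel's epsilon).
std::vector<float> reference_weights(std::vector<float> const& topk_scores,
                                     bool renormalize, float scaling) {
  std::vector<float> w(topk_scores.size());
  float sum = 1e-20f;
  for (size_t i = 0; i < w.size(); ++i) {
    w[i] = 1.0f / (1.0f + std::exp(-topk_scores[i]));  // SCORING_SIGMOID
    sum += w[i];
  }
  float const scale = renormalize ? scaling / sum : scaling;
  for (auto& x : w) x *= scale;
  return w;
}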
@@ -868,7 +769,6 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
     case torch::kFloat16:                                              \
       vllm::moe::invokeNoAuxTc<T, half, IdxT>(                         \
           reinterpret_cast<T*>(scores.mutable_data_ptr()),             \
-          reinterpret_cast<T*>(group_scores.mutable_data_ptr()),       \
           reinterpret_cast<float*>(topk_values.mutable_data_ptr()),    \
           reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()),    \
           reinterpret_cast<half const*>(bias.data_ptr()), num_tokens,  \
@@ -879,7 +779,6 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
     case torch::kFloat32:                                              \
       vllm::moe::invokeNoAuxTc<T, float, IdxT>(                        \
           reinterpret_cast<T*>(scores.mutable_data_ptr()),             \
-          reinterpret_cast<T*>(group_scores.mutable_data_ptr()),       \
           reinterpret_cast<float*>(topk_values.mutable_data_ptr()),    \
           reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()),    \
           reinterpret_cast<float const*>(bias.data_ptr()), num_tokens, \
@@ -890,7 +789,6 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
     case torch::kBFloat16:                                             \
       vllm::moe::invokeNoAuxTc<T, __nv_bfloat16, IdxT>(                \
           reinterpret_cast<T*>(scores.mutable_data_ptr()),             \
-          reinterpret_cast<T*>(group_scores.mutable_data_ptr()),       \
           reinterpret_cast<float*>(topk_values.mutable_data_ptr()),    \
           reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()),    \
           reinterpret_cast<__nv_bfloat16 const*>(bias.data_ptr()),     \