|
|
|
|
@@ -427,11 +427,29 @@ __device__ inline bool is_finite(const T val) {
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Scoring function enums
|
|
|
|
|
enum ScoringFunc {
|
|
|
|
|
SCORING_NONE = 0, // no activation function
|
|
|
|
|
SCORING_SIGMOID = 1 // apply sigmoid
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Efficient sigmoid approximation from TensorRT-LLM
|
|
|
|
|
__device__ inline float sigmoid_accurate(float x) {
|
|
|
|
|
return 0.5f * tanhf(0.5f * x) + 0.5f;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
__device__ void topk_with_k2(T* output, T const* input,
|
|
|
|
|
__device__ inline T apply_sigmoid(T val) {
|
|
|
|
|
float f = cuda_cast<float, T>(val);
|
|
|
|
|
return cuda_cast<T, float>(sigmoid_accurate(f));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
__device__ void topk_with_k2(T* output, T const* input, T const* bias,
|
|
|
|
|
cg::thread_block_tile<32> const& tile,
|
|
|
|
|
int32_t const lane_id,
|
|
|
|
|
int const num_experts_per_group) {
|
|
|
|
|
int const num_experts_per_group,
|
|
|
|
|
int const scoring_func) {
|
|
|
|
|
// Get the top2 per thread
|
|
|
|
|
T largest = neg_inf<T>();
|
|
|
|
|
T second_largest = neg_inf<T>();
|
|
|
|
|
@@ -439,6 +457,12 @@ __device__ void topk_with_k2(T* output, T const* input,
|
|
|
|
|
if (num_experts_per_group > WARP_SIZE) {
|
|
|
|
|
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
|
|
|
|
|
T value = input[i];
|
|
|
|
|
// Apply scoring function if needed
|
|
|
|
|
if (scoring_func == SCORING_SIGMOID) {
|
|
|
|
|
value = apply_sigmoid(value);
|
|
|
|
|
}
|
|
|
|
|
value = value + bias[i];
|
|
|
|
|
|
|
|
|
|
if (value > largest) {
|
|
|
|
|
second_largest = largest;
|
|
|
|
|
largest = value;
|
|
|
|
|
@@ -448,7 +472,13 @@ __device__ void topk_with_k2(T* output, T const* input,
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
|
|
|
|
|
largest = input[i];
|
|
|
|
|
T value = input[i];
|
|
|
|
|
// Apply scoring function if needed
|
|
|
|
|
if (scoring_func == SCORING_SIGMOID) {
|
|
|
|
|
value = apply_sigmoid(value);
|
|
|
|
|
}
|
|
|
|
|
value = value + bias[i];
|
|
|
|
|
largest = value;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -472,17 +502,21 @@ __device__ void topk_with_k2(T* output, T const* input,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
__global__ void topk_with_k2_kernel(T* output, T* input,
|
|
|
|
|
__global__ void topk_with_k2_kernel(T* output, T* input, T const* bias,
|
|
|
|
|
int64_t const num_tokens,
|
|
|
|
|
int64_t const num_cases,
|
|
|
|
|
int64_t const n_group,
|
|
|
|
|
int64_t const num_experts_per_group) {
|
|
|
|
|
int64_t const num_experts_per_group,
|
|
|
|
|
int const scoring_func) {
|
|
|
|
|
int32_t warp_id = threadIdx.x / WARP_SIZE;
|
|
|
|
|
int32_t lane_id = threadIdx.x % WARP_SIZE;
|
|
|
|
|
|
|
|
|
|
int32_t case_id = blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id;
|
|
|
|
|
if (case_id < num_cases) {
|
|
|
|
|
input += case_id * num_experts_per_group;
|
|
|
|
|
// bias is per expert group, offset to current group
|
|
|
|
|
int32_t group_id = case_id % n_group;
|
|
|
|
|
T const* group_bias = bias + group_id * num_experts_per_group;
|
|
|
|
|
output += case_id;
|
|
|
|
|
|
|
|
|
|
cg::thread_block block = cg::this_thread_block();
|
|
|
|
|
@@ -491,7 +525,8 @@ __global__ void topk_with_k2_kernel(T* output, T* input,
|
|
|
|
|
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
|
|
|
|
asm volatile("griddepcontrol.wait;");
|
|
|
|
|
#endif
|
|
|
|
|
topk_with_k2(output, input, tile, lane_id, num_experts_per_group);
|
|
|
|
|
topk_with_k2(output, input, group_bias, tile, lane_id,
|
|
|
|
|
num_experts_per_group, scoring_func);
|
|
|
|
|
}
|
|
|
|
|
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
|
|
|
|
asm volatile("griddepcontrol.launch_dependents;");
|
|
|
|
|
@@ -500,16 +535,15 @@ __global__ void topk_with_k2_kernel(T* output, T* input,
|
|
|
|
|
|
|
|
|
|
template <typename T, typename IdxT>
|
|
|
|
|
__global__ void group_idx_and_topk_idx_kernel(
|
|
|
|
|
T* scores, T const* group_scores, T* topk_values, IdxT* topk_indices,
|
|
|
|
|
T* scores_with_bias, int64_t const num_tokens, int64_t const n_group,
|
|
|
|
|
T* scores, T const* group_scores, float* topk_values, IdxT* topk_indices,
|
|
|
|
|
T const* bias, int64_t const num_tokens, int64_t const n_group,
|
|
|
|
|
int64_t const topk_group, int64_t const topk, int64_t const num_experts,
|
|
|
|
|
int64_t const num_experts_per_group, bool renormalize,
|
|
|
|
|
double routed_scaling_factor) {
|
|
|
|
|
double routed_scaling_factor, int scoring_func) {
|
|
|
|
|
int32_t warp_id = threadIdx.x / WARP_SIZE;
|
|
|
|
|
int32_t lane_id = threadIdx.x % WARP_SIZE;
|
|
|
|
|
int32_t case_id =
|
|
|
|
|
blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id; // one per token
|
|
|
|
|
scores_with_bias += case_id * num_experts;
|
|
|
|
|
scores += case_id * num_experts;
|
|
|
|
|
group_scores += case_id * n_group;
|
|
|
|
|
topk_values += case_id * topk;
|
|
|
|
|
@@ -577,10 +611,16 @@ __global__ void group_idx_and_topk_idx_kernel(
|
|
|
|
|
int32_t offset = i_group * num_experts_per_group;
|
|
|
|
|
for (int32_t i = lane_id; i < align_num_experts_per_group;
|
|
|
|
|
i += WARP_SIZE) {
|
|
|
|
|
T candidates = (i < num_experts_per_group) &&
|
|
|
|
|
is_finite(scores_with_bias[offset + i])
|
|
|
|
|
? scores_with_bias[offset + i]
|
|
|
|
|
: neg_inf<T>();
|
|
|
|
|
T candidates = neg_inf<T>();
|
|
|
|
|
if (i < num_experts_per_group) {
|
|
|
|
|
// Apply scoring function (if any) and add bias
|
|
|
|
|
T input = scores[offset + i];
|
|
|
|
|
if (is_finite(input)) {
|
|
|
|
|
T score = (scoring_func == SCORING_SIGMOID) ? apply_sigmoid(input)
|
|
|
|
|
: input;
|
|
|
|
|
candidates = score + bias[offset + i];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
queue.add(candidates, offset + i);
|
|
|
|
|
}
|
|
|
|
|
if (group_scores[i_group] == topk_group_value) {
|
|
|
|
|
@@ -602,11 +642,12 @@ __global__ void group_idx_and_topk_idx_kernel(
|
|
|
|
|
for (int i = lane_id;
|
|
|
|
|
i < warp_topk::round_up_to_multiple_of<WARP_SIZE>(topk);
|
|
|
|
|
i += WARP_SIZE) {
|
|
|
|
|
T value =
|
|
|
|
|
i < topk
|
|
|
|
|
? scores[s_topk_idx[i]]
|
|
|
|
|
: cuda_cast<T, float>(0.0f); // Load the valid value of expert
|
|
|
|
|
T value = cuda_cast<T, float>(0.0f);
|
|
|
|
|
if (i < topk) {
|
|
|
|
|
// Load the score value (without bias) for normalization
|
|
|
|
|
T input = scores[s_topk_idx[i]];
|
|
|
|
|
value =
|
|
|
|
|
(scoring_func == SCORING_SIGMOID) ? apply_sigmoid(input) : input;
|
|
|
|
|
s_topk_value[i] = value;
|
|
|
|
|
}
|
|
|
|
|
topk_sum +=
|
|
|
|
|
@@ -627,12 +668,12 @@ __global__ void group_idx_and_topk_idx_kernel(
|
|
|
|
|
value = cuda_cast<float, T>(s_topk_value[i]) * routed_scaling_factor;
|
|
|
|
|
}
|
|
|
|
|
topk_indices[i] = s_topk_idx[i];
|
|
|
|
|
topk_values[i] = cuda_cast<T, float>(value);
|
|
|
|
|
topk_values[i] = value;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
for (int i = lane_id; i < topk; i += WARP_SIZE) {
|
|
|
|
|
topk_indices[i] = i;
|
|
|
|
|
topk_values[i] = cuda_cast<T, float>(1.0f / topk);
|
|
|
|
|
topk_values[i] = 1.0f / topk;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// Note: when if_proceed_next_topk==false, choose the first 8 experts as the
|
|
|
|
|
@@ -644,12 +685,12 @@ __global__ void group_idx_and_topk_idx_kernel(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T, typename IdxT>
|
|
|
|
|
void invokeNoAuxTc(T* scores, T* group_scores, T* topk_values,
|
|
|
|
|
IdxT* topk_indices, T* scores_with_bias,
|
|
|
|
|
int64_t const num_tokens, int64_t const num_experts,
|
|
|
|
|
int64_t const n_group, int64_t const topk_group,
|
|
|
|
|
int64_t const topk, bool const renormalize,
|
|
|
|
|
double const routed_scaling_factor, bool enable_pdl = false,
|
|
|
|
|
void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
|
|
|
|
|
IdxT* topk_indices, T const* bias, int64_t const num_tokens,
|
|
|
|
|
int64_t const num_experts, int64_t const n_group,
|
|
|
|
|
int64_t const topk_group, int64_t const topk,
|
|
|
|
|
bool const renormalize, double const routed_scaling_factor,
|
|
|
|
|
int const scoring_func, bool enable_pdl = false,
|
|
|
|
|
cudaStream_t const stream = 0) {
|
|
|
|
|
int64_t num_cases = num_tokens * n_group;
|
|
|
|
|
int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
|
|
|
|
|
@@ -664,8 +705,9 @@ void invokeNoAuxTc(T* scores, T* group_scores, T* topk_values,
|
|
|
|
|
attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
|
|
|
|
|
config.numAttrs = 1;
|
|
|
|
|
config.attrs = attrs;
|
|
|
|
|
cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores_with_bias,
|
|
|
|
|
num_tokens, num_cases, n_group, num_experts / n_group);
|
|
|
|
|
cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores, bias,
|
|
|
|
|
num_tokens, num_cases, n_group, num_experts / n_group,
|
|
|
|
|
scoring_func);
|
|
|
|
|
|
|
|
|
|
int64_t topk_with_k_group_num_blocks =
|
|
|
|
|
(num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
|
|
|
|
|
@@ -682,19 +724,18 @@ void invokeNoAuxTc(T* scores, T* group_scores, T* topk_values,
|
|
|
|
|
config.numAttrs = 1;
|
|
|
|
|
config.attrs = attrs;
|
|
|
|
|
cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores,
|
|
|
|
|
topk_values, topk_indices, scores_with_bias, num_tokens,
|
|
|
|
|
n_group, topk_group, topk, num_experts,
|
|
|
|
|
num_experts / n_group, renormalize, routed_scaling_factor);
|
|
|
|
|
topk_values, topk_indices, bias, num_tokens, n_group,
|
|
|
|
|
topk_group, topk, num_experts, num_experts / n_group,
|
|
|
|
|
renormalize, routed_scaling_factor, scoring_func);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#define INSTANTIATE_NOAUX_TC(T, IdxT) \
|
|
|
|
|
template void invokeNoAuxTc<T, IdxT>( \
|
|
|
|
|
T * scores, T * group_scores, T * topk_values, IdxT * topk_indices, \
|
|
|
|
|
T * scores_with_bias, int64_t const num_tokens, \
|
|
|
|
|
int64_t const num_experts, int64_t const n_group, \
|
|
|
|
|
int64_t const topk_group, int64_t const topk, bool const renormalize, \
|
|
|
|
|
double const routed_scaling_factor, bool enable_pdl, \
|
|
|
|
|
cudaStream_t const stream);
|
|
|
|
|
T * scores, T * group_scores, float* topk_values, IdxT* topk_indices, \
|
|
|
|
|
T const* bias, int64_t const num_tokens, int64_t const num_experts, \
|
|
|
|
|
int64_t const n_group, int64_t const topk_group, int64_t const topk, \
|
|
|
|
|
bool const renormalize, double const routed_scaling_factor, \
|
|
|
|
|
int const scoring_func, bool enable_pdl, cudaStream_t const stream);
|
|
|
|
|
|
|
|
|
|
INSTANTIATE_NOAUX_TC(float, int32_t);
|
|
|
|
|
INSTANTIATE_NOAUX_TC(half, int32_t);
|
|
|
|
|
@@ -703,28 +744,32 @@ INSTANTIATE_NOAUX_TC(__nv_bfloat16, int32_t);
|
|
|
|
|
} // namespace vllm
|
|
|
|
|
|
|
|
|
|
std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
|
|
|
|
|
torch::Tensor const& scores, torch::Tensor const& scores_with_bias,
|
|
|
|
|
int64_t n_group, int64_t topk_group, int64_t topk, bool renormalize,
|
|
|
|
|
double routed_scaling_factor) {
|
|
|
|
|
auto data_type = scores_with_bias.scalar_type();
|
|
|
|
|
auto input_size = scores_with_bias.sizes();
|
|
|
|
|
torch::Tensor const& scores, int64_t n_group, int64_t topk_group,
|
|
|
|
|
int64_t topk, bool renormalize, double routed_scaling_factor,
|
|
|
|
|
torch::Tensor const& bias, int64_t scoring_func = 0) {
|
|
|
|
|
auto data_type = scores.scalar_type();
|
|
|
|
|
auto input_size = scores.sizes();
|
|
|
|
|
int64_t num_tokens = input_size[0];
|
|
|
|
|
int64_t num_experts = input_size[1];
|
|
|
|
|
TORCH_CHECK(input_size.size() == 2, "scores_with_bias must be a 2D Tensor");
|
|
|
|
|
TORCH_CHECK(input_size.size() == 2, "scores must be a 2D Tensor");
|
|
|
|
|
TORCH_CHECK(num_experts % n_group == 0,
|
|
|
|
|
"num_experts should be divisible by n_group");
|
|
|
|
|
TORCH_CHECK(n_group <= 32,
|
|
|
|
|
"n_group should be smaller than or equal to 32 for now");
|
|
|
|
|
TORCH_CHECK(topk <= 32, "topk should be smaller than or equal to 32 for now");
|
|
|
|
|
TORCH_CHECK(scoring_func == vllm::moe::SCORING_NONE ||
|
|
|
|
|
scoring_func == vllm::moe::SCORING_SIGMOID,
|
|
|
|
|
"scoring_func must be SCORING_NONE (0) or SCORING_SIGMOID (1)");
|
|
|
|
|
|
|
|
|
|
torch::Tensor group_scores = torch::empty(
|
|
|
|
|
{num_tokens, n_group}, torch::dtype(data_type).device(torch::kCUDA));
|
|
|
|
|
// Always output float32 for topk_values (eliminates Python-side conversion)
|
|
|
|
|
torch::Tensor topk_values = torch::empty(
|
|
|
|
|
{num_tokens, topk}, torch::dtype(data_type).device(torch::kCUDA));
|
|
|
|
|
{num_tokens, topk}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
|
|
|
|
|
torch::Tensor topk_indices = torch::empty(
|
|
|
|
|
{num_tokens, topk}, torch::dtype(torch::kInt32).device(torch::kCUDA));
|
|
|
|
|
|
|
|
|
|
auto stream = c10::cuda::getCurrentCUDAStream(scores_with_bias.get_device());
|
|
|
|
|
auto stream = c10::cuda::getCurrentCUDAStream(scores.get_device());
|
|
|
|
|
|
|
|
|
|
switch (data_type) {
|
|
|
|
|
case torch::kFloat16:
|
|
|
|
|
@@ -732,11 +777,11 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
|
|
|
|
|
vllm::moe::invokeNoAuxTc<half, int32_t>(
|
|
|
|
|
reinterpret_cast<half*>(scores.mutable_data_ptr()),
|
|
|
|
|
reinterpret_cast<half*>(group_scores.mutable_data_ptr()),
|
|
|
|
|
reinterpret_cast<half*>(topk_values.mutable_data_ptr()),
|
|
|
|
|
reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
|
|
|
|
|
reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
|
|
|
|
|
reinterpret_cast<half*>(scores_with_bias.data_ptr()), num_tokens,
|
|
|
|
|
reinterpret_cast<half const*>(bias.data_ptr()), num_tokens,
|
|
|
|
|
num_experts, n_group, topk_group, topk, renormalize,
|
|
|
|
|
routed_scaling_factor, false, stream);
|
|
|
|
|
routed_scaling_factor, static_cast<int>(scoring_func), false, stream);
|
|
|
|
|
break;
|
|
|
|
|
case torch::kFloat32:
|
|
|
|
|
// Handle Float32
|
|
|
|
|
@@ -745,20 +790,20 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
|
|
|
|
|
reinterpret_cast<float*>(group_scores.mutable_data_ptr()),
|
|
|
|
|
reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
|
|
|
|
|
reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
|
|
|
|
|
reinterpret_cast<float*>(scores_with_bias.data_ptr()), num_tokens,
|
|
|
|
|
reinterpret_cast<float const*>(bias.data_ptr()), num_tokens,
|
|
|
|
|
num_experts, n_group, topk_group, topk, renormalize,
|
|
|
|
|
routed_scaling_factor, false, stream);
|
|
|
|
|
routed_scaling_factor, static_cast<int>(scoring_func), false, stream);
|
|
|
|
|
break;
|
|
|
|
|
case torch::kBFloat16:
|
|
|
|
|
// Handle BFloat16
|
|
|
|
|
vllm::moe::invokeNoAuxTc<__nv_bfloat16, int32_t>(
|
|
|
|
|
reinterpret_cast<__nv_bfloat16*>(scores.mutable_data_ptr()),
|
|
|
|
|
reinterpret_cast<__nv_bfloat16*>(group_scores.mutable_data_ptr()),
|
|
|
|
|
reinterpret_cast<__nv_bfloat16*>(topk_values.mutable_data_ptr()),
|
|
|
|
|
reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
|
|
|
|
|
reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
|
|
|
|
|
reinterpret_cast<__nv_bfloat16*>(scores_with_bias.data_ptr()),
|
|
|
|
|
num_tokens, num_experts, n_group, topk_group, topk, renormalize,
|
|
|
|
|
routed_scaling_factor, false, stream);
|
|
|
|
|
reinterpret_cast<__nv_bfloat16 const*>(bias.data_ptr()), num_tokens,
|
|
|
|
|
num_experts, n_group, topk_group, topk, renormalize,
|
|
|
|
|
routed_scaling_factor, static_cast<int>(scoring_func), false, stream);
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
// Handle other data types
|
|
|
|
|
|