[Kernel] Support bias type in grouped_topk kernel (#31781)
Signed-off-by: Xin Yang <xyangx@amazon.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -457,8 +457,8 @@ __device__ inline T apply_scoring(T val) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, ScoringFunc SF>
|
||||
__device__ void topk_with_k2(T* output, T const* input, T const* bias,
|
||||
template <typename T, typename BiasT, ScoringFunc SF>
|
||||
__device__ void topk_with_k2(T* output, T const* input, BiasT const* bias,
|
||||
cg::thread_block_tile<32> const& tile,
|
||||
int32_t const lane_id,
|
||||
int const num_experts_per_group) {
|
||||
@@ -469,7 +469,7 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias,
|
||||
if (num_experts_per_group > WARP_SIZE) {
|
||||
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
|
||||
T value = apply_scoring<SF>(input[i]);
|
||||
value = value + bias[i];
|
||||
value = value + static_cast<T>(bias[i]);
|
||||
|
||||
if (value > largest) {
|
||||
second_largest = largest;
|
||||
@@ -481,7 +481,7 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias,
|
||||
} else {
|
||||
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
|
||||
T value = apply_scoring<SF>(input[i]);
|
||||
value = value + bias[i];
|
||||
value = value + static_cast<T>(bias[i]);
|
||||
largest = value;
|
||||
}
|
||||
}
|
||||
@@ -503,8 +503,8 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, ScoringFunc SF>
|
||||
__global__ void topk_with_k2_kernel(T* output, T* input, T const* bias,
|
||||
template <typename T, typename BiasT, ScoringFunc SF>
|
||||
__global__ void topk_with_k2_kernel(T* output, T* input, BiasT const* bias,
|
||||
int64_t const num_tokens,
|
||||
int64_t const num_cases,
|
||||
int64_t const n_group,
|
||||
@@ -517,7 +517,7 @@ __global__ void topk_with_k2_kernel(T* output, T* input, T const* bias,
|
||||
input += case_id * num_experts_per_group;
|
||||
// bias is per expert group, offset to current group
|
||||
int32_t group_id = case_id % n_group;
|
||||
T const* group_bias = bias + group_id * num_experts_per_group;
|
||||
BiasT const* group_bias = bias + group_id * num_experts_per_group;
|
||||
output += case_id;
|
||||
|
||||
cg::thread_block block = cg::this_thread_block();
|
||||
@@ -526,18 +526,19 @@ __global__ void topk_with_k2_kernel(T* output, T* input, T const* bias,
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.wait;");
|
||||
#endif
|
||||
topk_with_k2<T, SF>(output, input, group_bias, tile, lane_id,
|
||||
num_experts_per_group);
|
||||
topk_with_k2<T, BiasT, SF>(output, input, group_bias, tile, lane_id,
|
||||
num_experts_per_group);
|
||||
}
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.launch_dependents;");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT, ScoringFunc SF, int NGroup = -1>
|
||||
template <typename T, typename BiasT, typename IdxT, ScoringFunc SF,
|
||||
int NGroup = -1>
|
||||
__global__ void group_idx_and_topk_idx_kernel(
|
||||
T* scores, T const* group_scores, float* topk_values, IdxT* topk_indices,
|
||||
T const* bias, int64_t const num_tokens, int64_t const n_group,
|
||||
BiasT const* bias, int64_t const num_tokens, int64_t const n_group,
|
||||
int64_t const topk_group, int64_t const topk, int64_t const num_experts,
|
||||
int64_t const num_experts_per_group, bool renormalize,
|
||||
double routed_scaling_factor) {
|
||||
@@ -623,7 +624,7 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
T input = scores[offset + i];
|
||||
if (is_finite(input)) {
|
||||
T score = apply_scoring<SF>(input);
|
||||
candidates = score + bias[offset + i];
|
||||
candidates = score + static_cast<T>(bias[offset + i]);
|
||||
}
|
||||
}
|
||||
queue.add(candidates, offset + i);
|
||||
@@ -698,10 +699,10 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT, ScoringFunc SF>
|
||||
template <typename T, typename BiasT, typename IdxT, ScoringFunc SF>
|
||||
inline void launch_group_idx_and_topk_kernel(
|
||||
cudaLaunchConfig_t const& config, T* scores, T* group_scores,
|
||||
float* topk_values, IdxT* topk_indices, T const* bias,
|
||||
float* topk_values, IdxT* topk_indices, BiasT const* bias,
|
||||
int64_t const num_tokens, int64_t const n_group, int64_t const topk_group,
|
||||
int64_t const topk, int64_t const num_experts,
|
||||
int64_t const num_experts_per_group, bool const renormalize,
|
||||
@@ -715,36 +716,36 @@ inline void launch_group_idx_and_topk_kernel(
|
||||
|
||||
switch (n_group) {
|
||||
case 4: {
|
||||
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 4>);
|
||||
launch(&group_idx_and_topk_idx_kernel<T, BiasT, IdxT, SF, 4>);
|
||||
break;
|
||||
}
|
||||
case 8: {
|
||||
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 8>);
|
||||
launch(&group_idx_and_topk_idx_kernel<T, BiasT, IdxT, SF, 8>);
|
||||
break;
|
||||
}
|
||||
case 16: {
|
||||
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 16>);
|
||||
launch(&group_idx_and_topk_idx_kernel<T, BiasT, IdxT, SF, 16>);
|
||||
break;
|
||||
}
|
||||
case 32: {
|
||||
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 32>);
|
||||
launch(&group_idx_and_topk_idx_kernel<T, BiasT, IdxT, SF, 32>);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF>);
|
||||
launch(&group_idx_and_topk_idx_kernel<T, BiasT, IdxT, SF>);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT>
|
||||
template <typename T, typename BiasT, typename IdxT>
|
||||
void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
|
||||
IdxT* topk_indices, T const* bias, int64_t const num_tokens,
|
||||
int64_t const num_experts, int64_t const n_group,
|
||||
int64_t const topk_group, int64_t const topk,
|
||||
bool const renormalize, double const routed_scaling_factor,
|
||||
int const scoring_func, bool enable_pdl = false,
|
||||
cudaStream_t const stream = 0) {
|
||||
IdxT* topk_indices, BiasT const* bias,
|
||||
int64_t const num_tokens, int64_t const num_experts,
|
||||
int64_t const n_group, int64_t const topk_group,
|
||||
int64_t const topk, bool const renormalize,
|
||||
double const routed_scaling_factor, int const scoring_func,
|
||||
bool enable_pdl = false, cudaStream_t const stream = 0) {
|
||||
int64_t num_cases = num_tokens * n_group;
|
||||
int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
|
||||
cudaLaunchConfig_t config;
|
||||
@@ -765,12 +766,12 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
|
||||
};
|
||||
switch (sf) {
|
||||
case SCORING_NONE: {
|
||||
auto* kernel_instance1 = &topk_with_k2_kernel<T, SCORING_NONE>;
|
||||
auto* kernel_instance1 = &topk_with_k2_kernel<T, BiasT, SCORING_NONE>;
|
||||
launch_topk_with_k2(kernel_instance1);
|
||||
break;
|
||||
}
|
||||
case SCORING_SIGMOID: {
|
||||
auto* kernel_instance1 = &topk_with_k2_kernel<T, SCORING_SIGMOID>;
|
||||
auto* kernel_instance1 = &topk_with_k2_kernel<T, BiasT, SCORING_SIGMOID>;
|
||||
launch_topk_with_k2(kernel_instance1);
|
||||
break;
|
||||
}
|
||||
@@ -794,14 +795,14 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
|
||||
config.attrs = attrs;
|
||||
switch (sf) {
|
||||
case SCORING_NONE: {
|
||||
launch_group_idx_and_topk_kernel<T, IdxT, SCORING_NONE>(
|
||||
launch_group_idx_and_topk_kernel<T, BiasT, IdxT, SCORING_NONE>(
|
||||
config, scores, group_scores, topk_values, topk_indices, bias,
|
||||
num_tokens, n_group, topk_group, topk, num_experts,
|
||||
num_experts_per_group, renormalize, routed_scaling_factor);
|
||||
break;
|
||||
}
|
||||
case SCORING_SIGMOID: {
|
||||
launch_group_idx_and_topk_kernel<T, IdxT, SCORING_SIGMOID>(
|
||||
launch_group_idx_and_topk_kernel<T, BiasT, IdxT, SCORING_SIGMOID>(
|
||||
config, scores, group_scores, topk_values, topk_indices, bias,
|
||||
num_tokens, n_group, topk_group, topk, num_experts,
|
||||
num_experts_per_group, renormalize, routed_scaling_factor);
|
||||
@@ -812,17 +813,23 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
|
||||
}
|
||||
}
|
||||
|
||||
#define INSTANTIATE_NOAUX_TC(T, IdxT) \
|
||||
template void invokeNoAuxTc<T, IdxT>( \
|
||||
T * scores, T * group_scores, float* topk_values, IdxT* topk_indices, \
|
||||
T const* bias, int64_t const num_tokens, int64_t const num_experts, \
|
||||
int64_t const n_group, int64_t const topk_group, int64_t const topk, \
|
||||
bool const renormalize, double const routed_scaling_factor, \
|
||||
#define INSTANTIATE_NOAUX_TC(T, BiasT, IdxT) \
|
||||
template void invokeNoAuxTc<T, BiasT, IdxT>( \
|
||||
T * scores, T * group_scores, float* topk_values, IdxT* topk_indices, \
|
||||
BiasT const* bias, int64_t const num_tokens, int64_t const num_experts, \
|
||||
int64_t const n_group, int64_t const topk_group, int64_t const topk, \
|
||||
bool const renormalize, double const routed_scaling_factor, \
|
||||
int const scoring_func, bool enable_pdl, cudaStream_t const stream);
|
||||
|
||||
INSTANTIATE_NOAUX_TC(float, int32_t);
|
||||
INSTANTIATE_NOAUX_TC(half, int32_t);
|
||||
INSTANTIATE_NOAUX_TC(__nv_bfloat16, int32_t);
|
||||
INSTANTIATE_NOAUX_TC(float, float, int32_t);
|
||||
INSTANTIATE_NOAUX_TC(float, half, int32_t);
|
||||
INSTANTIATE_NOAUX_TC(float, __nv_bfloat16, int32_t);
|
||||
INSTANTIATE_NOAUX_TC(half, float, int32_t);
|
||||
INSTANTIATE_NOAUX_TC(half, half, int32_t);
|
||||
INSTANTIATE_NOAUX_TC(half, __nv_bfloat16, int32_t);
|
||||
INSTANTIATE_NOAUX_TC(__nv_bfloat16, float, int32_t);
|
||||
INSTANTIATE_NOAUX_TC(__nv_bfloat16, half, int32_t);
|
||||
INSTANTIATE_NOAUX_TC(__nv_bfloat16, __nv_bfloat16, int32_t);
|
||||
} // end namespace moe
|
||||
} // namespace vllm
|
||||
|
||||
@@ -831,6 +838,7 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
|
||||
int64_t topk, bool renormalize, double routed_scaling_factor,
|
||||
torch::Tensor const& bias, int64_t scoring_func = 0) {
|
||||
auto data_type = scores.scalar_type();
|
||||
auto bias_type = bias.scalar_type();
|
||||
auto input_size = scores.sizes();
|
||||
int64_t num_tokens = input_size[0];
|
||||
int64_t num_experts = input_size[1];
|
||||
@@ -854,39 +862,62 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
|
||||
|
||||
auto stream = c10::cuda::getCurrentCUDAStream(scores.get_device());
|
||||
|
||||
#define LAUNCH_KERNEL(T, IdxT) \
|
||||
do { \
|
||||
switch (bias_type) { \
|
||||
case torch::kFloat16: \
|
||||
vllm::moe::invokeNoAuxTc<T, half, IdxT>( \
|
||||
reinterpret_cast<T*>(scores.mutable_data_ptr()), \
|
||||
reinterpret_cast<T*>(group_scores.mutable_data_ptr()), \
|
||||
reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \
|
||||
reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \
|
||||
reinterpret_cast<half const*>(bias.data_ptr()), num_tokens, \
|
||||
num_experts, n_group, topk_group, topk, renormalize, \
|
||||
routed_scaling_factor, static_cast<int>(scoring_func), false, \
|
||||
stream); \
|
||||
break; \
|
||||
case torch::kFloat32: \
|
||||
vllm::moe::invokeNoAuxTc<T, float, IdxT>( \
|
||||
reinterpret_cast<T*>(scores.mutable_data_ptr()), \
|
||||
reinterpret_cast<T*>(group_scores.mutable_data_ptr()), \
|
||||
reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \
|
||||
reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \
|
||||
reinterpret_cast<float const*>(bias.data_ptr()), num_tokens, \
|
||||
num_experts, n_group, topk_group, topk, renormalize, \
|
||||
routed_scaling_factor, static_cast<int>(scoring_func), false, \
|
||||
stream); \
|
||||
break; \
|
||||
case torch::kBFloat16: \
|
||||
vllm::moe::invokeNoAuxTc<T, __nv_bfloat16, IdxT>( \
|
||||
reinterpret_cast<T*>(scores.mutable_data_ptr()), \
|
||||
reinterpret_cast<T*>(group_scores.mutable_data_ptr()), \
|
||||
reinterpret_cast<float*>(topk_values.mutable_data_ptr()), \
|
||||
reinterpret_cast<IdxT*>(topk_indices.mutable_data_ptr()), \
|
||||
reinterpret_cast<__nv_bfloat16 const*>(bias.data_ptr()), \
|
||||
num_tokens, num_experts, n_group, topk_group, topk, renormalize, \
|
||||
routed_scaling_factor, static_cast<int>(scoring_func), false, \
|
||||
stream); \
|
||||
break; \
|
||||
default: \
|
||||
throw std::invalid_argument( \
|
||||
"Invalid bias dtype, only supports float16, float32, and " \
|
||||
"bfloat16"); \
|
||||
break; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
switch (data_type) {
|
||||
case torch::kFloat16:
|
||||
// Handle Float16
|
||||
vllm::moe::invokeNoAuxTc<half, int32_t>(
|
||||
reinterpret_cast<half*>(scores.mutable_data_ptr()),
|
||||
reinterpret_cast<half*>(group_scores.mutable_data_ptr()),
|
||||
reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
|
||||
reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
|
||||
reinterpret_cast<half const*>(bias.data_ptr()), num_tokens,
|
||||
num_experts, n_group, topk_group, topk, renormalize,
|
||||
routed_scaling_factor, static_cast<int>(scoring_func), false, stream);
|
||||
LAUNCH_KERNEL(half, int32_t);
|
||||
break;
|
||||
case torch::kFloat32:
|
||||
// Handle Float32
|
||||
vllm::moe::invokeNoAuxTc<float, int32_t>(
|
||||
reinterpret_cast<float*>(scores.mutable_data_ptr()),
|
||||
reinterpret_cast<float*>(group_scores.mutable_data_ptr()),
|
||||
reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
|
||||
reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
|
||||
reinterpret_cast<float const*>(bias.data_ptr()), num_tokens,
|
||||
num_experts, n_group, topk_group, topk, renormalize,
|
||||
routed_scaling_factor, static_cast<int>(scoring_func), false, stream);
|
||||
LAUNCH_KERNEL(float, int32_t);
|
||||
break;
|
||||
case torch::kBFloat16:
|
||||
// Handle BFloat16
|
||||
vllm::moe::invokeNoAuxTc<__nv_bfloat16, int32_t>(
|
||||
reinterpret_cast<__nv_bfloat16*>(scores.mutable_data_ptr()),
|
||||
reinterpret_cast<__nv_bfloat16*>(group_scores.mutable_data_ptr()),
|
||||
reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
|
||||
reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
|
||||
reinterpret_cast<__nv_bfloat16 const*>(bias.data_ptr()), num_tokens,
|
||||
num_experts, n_group, topk_group, topk, renormalize,
|
||||
routed_scaling_factor, static_cast<int>(scoring_func), false, stream);
|
||||
LAUNCH_KERNEL(__nv_bfloat16, int32_t);
|
||||
break;
|
||||
default:
|
||||
// Handle other data types
|
||||
@@ -894,5 +925,6 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
|
||||
"Invalid dtype, only supports float16, float32, and bfloat16");
|
||||
break;
|
||||
}
|
||||
#undef LAUNCH_KERNEL
|
||||
return {topk_values, topk_indices};
|
||||
}
|
||||
|
||||
@@ -34,7 +34,8 @@ from vllm.utils.torch_utils import set_random_seed
|
||||
@pytest.mark.parametrize("topk_group", [2])
|
||||
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
|
||||
@pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
|
||||
@pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float32])
|
||||
@pytest.mark.parametrize("bias_dtype", [torch.float32])
|
||||
def test_grouped_topk(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
n_token: int,
|
||||
@@ -46,7 +47,8 @@ def test_grouped_topk(
|
||||
topk_group: int,
|
||||
scoring_func: str,
|
||||
routed_scaling_factor: float,
|
||||
dtype: torch.dtype,
|
||||
input_dtype: torch.dtype,
|
||||
bias_dtype: torch.dtype,
|
||||
):
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(custom_ops=["all", "+grouped_topk"])
|
||||
@@ -54,11 +56,9 @@ def test_grouped_topk(
|
||||
get_cached_compilation_config.cache_clear()
|
||||
|
||||
set_random_seed(0)
|
||||
hidden_states = torch.randn((n_token, n_hidden), dtype=dtype, device="cuda")
|
||||
gating_output = torch.randn((n_token, n_expert), dtype=dtype, device="cuda")
|
||||
e_score_correction_bias = torch.randn(
|
||||
(n_expert,), dtype=torch.float32, device="cuda"
|
||||
)
|
||||
hidden_states = torch.randn((n_token, n_hidden), dtype=input_dtype, device="cuda")
|
||||
gating_output = torch.randn((n_token, n_expert), dtype=input_dtype, device="cuda")
|
||||
e_score_correction_bias = torch.randn((n_expert,), dtype=bias_dtype, device="cuda")
|
||||
|
||||
with set_current_vllm_config(vllm_config), monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0")
|
||||
|
||||
@@ -1627,7 +1627,7 @@ def fused_grouped_topk(
|
||||
topk,
|
||||
renormalize,
|
||||
routed_scaling_factor,
|
||||
e_score_correction_bias.to(gating_output.dtype),
|
||||
e_score_correction_bias,
|
||||
1, # scoring_func=1 for sigmoid
|
||||
)
|
||||
elif scoring_func == "softmax":
|
||||
@@ -1641,7 +1641,7 @@ def fused_grouped_topk(
|
||||
topk,
|
||||
renormalize,
|
||||
routed_scaling_factor,
|
||||
e_score_correction_bias.to(gating_output.dtype),
|
||||
e_score_correction_bias,
|
||||
0, # scoring_func=0 (no activation, scores already computed)
|
||||
)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user