diff --git a/csrc/moe/grouped_topk_kernels.cu b/csrc/moe/grouped_topk_kernels.cu index 7229e420d..27e646bcd 100644 --- a/csrc/moe/grouped_topk_kernels.cu +++ b/csrc/moe/grouped_topk_kernels.cu @@ -457,8 +457,8 @@ __device__ inline T apply_scoring(T val) { } } -template -__device__ void topk_with_k2(T* output, T const* input, T const* bias, +template +__device__ void topk_with_k2(T* output, T const* input, BiasT const* bias, cg::thread_block_tile<32> const& tile, int32_t const lane_id, int const num_experts_per_group) { @@ -469,7 +469,7 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias, if (num_experts_per_group > WARP_SIZE) { for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) { T value = apply_scoring(input[i]); - value = value + bias[i]; + value = value + static_cast(bias[i]); if (value > largest) { second_largest = largest; @@ -481,7 +481,7 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias, } else { for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) { T value = apply_scoring(input[i]); - value = value + bias[i]; + value = value + static_cast(bias[i]); largest = value; } } @@ -503,8 +503,8 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias, } } -template -__global__ void topk_with_k2_kernel(T* output, T* input, T const* bias, +template +__global__ void topk_with_k2_kernel(T* output, T* input, BiasT const* bias, int64_t const num_tokens, int64_t const num_cases, int64_t const n_group, @@ -517,7 +517,7 @@ __global__ void topk_with_k2_kernel(T* output, T* input, T const* bias, input += case_id * num_experts_per_group; // bias is per expert group, offset to current group int32_t group_id = case_id % n_group; - T const* group_bias = bias + group_id * num_experts_per_group; + BiasT const* group_bias = bias + group_id * num_experts_per_group; output += case_id; cg::thread_block block = cg::this_thread_block(); @@ -526,18 +526,19 @@ __global__ void topk_with_k2_kernel(T* output, T* input, T const* bias, #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.wait;"); #endif - topk_with_k2(output, input, group_bias, tile, lane_id, - num_experts_per_group); + topk_with_k2(output, input, group_bias, tile, lane_id, + num_experts_per_group); } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.launch_dependents;"); #endif } -template +template __global__ void group_idx_and_topk_idx_kernel( T* scores, T const* group_scores, float* topk_values, IdxT* topk_indices, - T const* bias, int64_t const num_tokens, int64_t const n_group, + BiasT const* bias, int64_t const num_tokens, int64_t const n_group, int64_t const topk_group, int64_t const topk, int64_t const num_experts, int64_t const num_experts_per_group, bool renormalize, double routed_scaling_factor) { @@ -623,7 +624,7 @@ __global__ void group_idx_and_topk_idx_kernel( T input = scores[offset + i]; if (is_finite(input)) { T score = apply_scoring(input); - candidates = score + bias[offset + i]; + candidates = score + static_cast(bias[offset + i]); } } queue.add(candidates, offset + i); @@ -698,10 +699,10 @@ __global__ void group_idx_and_topk_idx_kernel( #endif } -template +template inline void launch_group_idx_and_topk_kernel( cudaLaunchConfig_t const& config, T* scores, T* group_scores, - float* topk_values, IdxT* topk_indices, T const* bias, + float* topk_values, IdxT* topk_indices, BiasT const* bias, int64_t const num_tokens, int64_t const n_group, int64_t const topk_group, int64_t const topk, int64_t const num_experts, int64_t const num_experts_per_group, bool const renormalize, @@ -715,36 +716,36 @@ inline void launch_group_idx_and_topk_kernel( switch (n_group) { case 4: { - launch(&group_idx_and_topk_idx_kernel); + launch(&group_idx_and_topk_idx_kernel); break; } case 8: { - launch(&group_idx_and_topk_idx_kernel); + launch(&group_idx_and_topk_idx_kernel); break; } case 16: { - launch(&group_idx_and_topk_idx_kernel); + launch(&group_idx_and_topk_idx_kernel); break; } case 32: { - launch(&group_idx_and_topk_idx_kernel); + launch(&group_idx_and_topk_idx_kernel); break; } default: { - launch(&group_idx_and_topk_idx_kernel); + launch(&group_idx_and_topk_idx_kernel); break; } } } -template +template void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values, - IdxT* topk_indices, T const* bias, int64_t const num_tokens, - int64_t const num_experts, int64_t const n_group, - int64_t const topk_group, int64_t const topk, - bool const renormalize, double const routed_scaling_factor, - int const scoring_func, bool enable_pdl = false, - cudaStream_t const stream = 0) { + IdxT* topk_indices, BiasT const* bias, + int64_t const num_tokens, int64_t const num_experts, + int64_t const n_group, int64_t const topk_group, + int64_t const topk, bool const renormalize, + double const routed_scaling_factor, int const scoring_func, + bool enable_pdl = false, cudaStream_t const stream = 0) { int64_t num_cases = num_tokens * n_group; int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1; cudaLaunchConfig_t config; @@ -765,12 +766,12 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values, }; switch (sf) { case SCORING_NONE: { - auto* kernel_instance1 = &topk_with_k2_kernel; + auto* kernel_instance1 = &topk_with_k2_kernel; launch_topk_with_k2(kernel_instance1); break; } case SCORING_SIGMOID: { - auto* kernel_instance1 = &topk_with_k2_kernel; + auto* kernel_instance1 = &topk_with_k2_kernel; launch_topk_with_k2(kernel_instance1); break; } @@ -794,14 +795,14 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values, config.attrs = attrs; switch (sf) { case SCORING_NONE: { - launch_group_idx_and_topk_kernel( + launch_group_idx_and_topk_kernel( config, scores, group_scores, topk_values, topk_indices, bias, num_tokens, n_group, topk_group, topk, num_experts, num_experts_per_group, renormalize, routed_scaling_factor); break; } case SCORING_SIGMOID: { - launch_group_idx_and_topk_kernel( + launch_group_idx_and_topk_kernel( config, scores, group_scores, topk_values, topk_indices, bias, num_tokens, n_group, topk_group, topk, num_experts, num_experts_per_group, renormalize, routed_scaling_factor); @@ -812,17 +813,23 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values, } } -#define INSTANTIATE_NOAUX_TC(T, IdxT) \ - template void invokeNoAuxTc( \ - T * scores, T * group_scores, float* topk_values, IdxT* topk_indices, \ - T const* bias, int64_t const num_tokens, int64_t const num_experts, \ - int64_t const n_group, int64_t const topk_group, int64_t const topk, \ - bool const renormalize, double const routed_scaling_factor, \ +#define INSTANTIATE_NOAUX_TC(T, BiasT, IdxT) \ + template void invokeNoAuxTc( \ + T * scores, T * group_scores, float* topk_values, IdxT* topk_indices, \ + BiasT const* bias, int64_t const num_tokens, int64_t const num_experts, \ + int64_t const n_group, int64_t const topk_group, int64_t const topk, \ + bool const renormalize, double const routed_scaling_factor, \ int const scoring_func, bool enable_pdl, cudaStream_t const stream); -INSTANTIATE_NOAUX_TC(float, int32_t); -INSTANTIATE_NOAUX_TC(half, int32_t); -INSTANTIATE_NOAUX_TC(__nv_bfloat16, int32_t); +INSTANTIATE_NOAUX_TC(float, float, int32_t); +INSTANTIATE_NOAUX_TC(float, half, int32_t); +INSTANTIATE_NOAUX_TC(float, __nv_bfloat16, int32_t); +INSTANTIATE_NOAUX_TC(half, float, int32_t); +INSTANTIATE_NOAUX_TC(half, half, int32_t); +INSTANTIATE_NOAUX_TC(half, __nv_bfloat16, int32_t); +INSTANTIATE_NOAUX_TC(__nv_bfloat16, float, int32_t); +INSTANTIATE_NOAUX_TC(__nv_bfloat16, half, int32_t); +INSTANTIATE_NOAUX_TC(__nv_bfloat16, __nv_bfloat16, int32_t); } // end namespace moe } // namespace vllm @@ -831,6 +838,7 @@ std::tuple grouped_topk( int64_t topk, bool renormalize, double routed_scaling_factor, torch::Tensor const& bias, int64_t scoring_func = 0) { auto data_type = scores.scalar_type(); + auto bias_type = bias.scalar_type(); auto input_size = scores.sizes(); int64_t num_tokens = input_size[0]; int64_t num_experts = input_size[1]; @@ -854,39 +862,62 @@ std::tuple grouped_topk( auto stream = c10::cuda::getCurrentCUDAStream(scores.get_device()); +#define LAUNCH_KERNEL(T, IdxT) \ + do { \ + switch (bias_type) { \ + case torch::kFloat16: \ + vllm::moe::invokeNoAuxTc( \ + reinterpret_cast(scores.mutable_data_ptr()), \ + reinterpret_cast(group_scores.mutable_data_ptr()), \ + reinterpret_cast(topk_values.mutable_data_ptr()), \ + reinterpret_cast(topk_indices.mutable_data_ptr()), \ + reinterpret_cast(bias.data_ptr()), num_tokens, \ + num_experts, n_group, topk_group, topk, renormalize, \ + routed_scaling_factor, static_cast(scoring_func), false, \ + stream); \ + break; \ + case torch::kFloat32: \ + vllm::moe::invokeNoAuxTc( \ + reinterpret_cast(scores.mutable_data_ptr()), \ + reinterpret_cast(group_scores.mutable_data_ptr()), \ + reinterpret_cast(topk_values.mutable_data_ptr()), \ + reinterpret_cast(topk_indices.mutable_data_ptr()), \ + reinterpret_cast(bias.data_ptr()), num_tokens, \ + num_experts, n_group, topk_group, topk, renormalize, \ + routed_scaling_factor, static_cast(scoring_func), false, \ + stream); \ + break; \ + case torch::kBFloat16: \ + vllm::moe::invokeNoAuxTc( \ + reinterpret_cast(scores.mutable_data_ptr()), \ + reinterpret_cast(group_scores.mutable_data_ptr()), \ + reinterpret_cast(topk_values.mutable_data_ptr()), \ + reinterpret_cast(topk_indices.mutable_data_ptr()), \ + reinterpret_cast<__nv_bfloat16 const*>(bias.data_ptr()), \ + num_tokens, num_experts, n_group, topk_group, topk, renormalize, \ + routed_scaling_factor, static_cast(scoring_func), false, \ + stream); \ + break; \ + default: \ + throw std::invalid_argument( \ + "Invalid bias dtype, only supports float16, float32, and " \ + "bfloat16"); \ + break; \ + } \ + } while (0) + switch (data_type) { case torch::kFloat16: // Handle Float16 - vllm::moe::invokeNoAuxTc( - reinterpret_cast(scores.mutable_data_ptr()), - reinterpret_cast(group_scores.mutable_data_ptr()), - reinterpret_cast(topk_values.mutable_data_ptr()), - reinterpret_cast(topk_indices.mutable_data_ptr()), - reinterpret_cast(bias.data_ptr()), num_tokens, - num_experts, n_group, topk_group, topk, renormalize, - routed_scaling_factor, static_cast(scoring_func), false, stream); + LAUNCH_KERNEL(half, int32_t); break; case torch::kFloat32: // Handle Float32 - vllm::moe::invokeNoAuxTc( - reinterpret_cast(scores.mutable_data_ptr()), - reinterpret_cast(group_scores.mutable_data_ptr()), - reinterpret_cast(topk_values.mutable_data_ptr()), - reinterpret_cast(topk_indices.mutable_data_ptr()), - reinterpret_cast(bias.data_ptr()), num_tokens, - num_experts, n_group, topk_group, topk, renormalize, - routed_scaling_factor, static_cast(scoring_func), false, stream); + LAUNCH_KERNEL(float, int32_t); break; case torch::kBFloat16: // Handle BFloat16 - vllm::moe::invokeNoAuxTc<__nv_bfloat16, int32_t>( - reinterpret_cast<__nv_bfloat16*>(scores.mutable_data_ptr()), - reinterpret_cast<__nv_bfloat16*>(group_scores.mutable_data_ptr()), - reinterpret_cast(topk_values.mutable_data_ptr()), - reinterpret_cast(topk_indices.mutable_data_ptr()), - reinterpret_cast<__nv_bfloat16 const*>(bias.data_ptr()), num_tokens, - num_experts, n_group, topk_group, topk, renormalize, - routed_scaling_factor, static_cast(scoring_func), false, stream); + LAUNCH_KERNEL(__nv_bfloat16, int32_t); break; default: // Handle other data types @@ -894,5 +925,6 @@ std::tuple grouped_topk( "Invalid dtype, only supports float16, float32, and bfloat16"); break; } +#undef LAUNCH_KERNEL return {topk_values, topk_indices}; } diff --git a/tests/kernels/moe/test_grouped_topk.py b/tests/kernels/moe/test_grouped_topk.py index 96d793d85..f676cc4fe 100644 --- a/tests/kernels/moe/test_grouped_topk.py +++ b/tests/kernels/moe/test_grouped_topk.py @@ -34,7 +34,8 @@ from vllm.utils.torch_utils import set_random_seed @pytest.mark.parametrize("topk_group", [2]) @pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"]) @pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5]) -@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("bias_dtype", [torch.float32]) def test_grouped_topk( monkeypatch: pytest.MonkeyPatch, n_token: int, @@ -46,7 +47,8 @@ def test_grouped_topk( topk_group: int, scoring_func: str, routed_scaling_factor: float, - dtype: torch.dtype, + input_dtype: torch.dtype, + bias_dtype: torch.dtype, ): vllm_config = VllmConfig( compilation_config=CompilationConfig(custom_ops=["all", "+grouped_topk"]) @@ -54,11 +56,9 @@ def test_grouped_topk( get_cached_compilation_config.cache_clear() set_random_seed(0) - hidden_states = torch.randn((n_token, n_hidden), dtype=dtype, device="cuda") - gating_output = torch.randn((n_token, n_expert), dtype=dtype, device="cuda") - e_score_correction_bias = torch.randn( - (n_expert,), dtype=torch.float32, device="cuda" - ) + hidden_states = torch.randn((n_token, n_hidden), dtype=input_dtype, device="cuda") + gating_output = torch.randn((n_token, n_expert), dtype=input_dtype, device="cuda") + e_score_correction_bias = torch.randn((n_expert,), dtype=bias_dtype, device="cuda") with set_current_vllm_config(vllm_config), monkeypatch.context() as m: m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0") diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 1e2f96c86..c4047401c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1627,7 +1627,7 @@ def fused_grouped_topk( topk, renormalize, routed_scaling_factor, - e_score_correction_bias.to(gating_output.dtype), + e_score_correction_bias, 1, # scoring_func=1 for sigmoid ) elif scoring_func == "softmax": @@ -1641,7 +1641,7 @@ def fused_grouped_topk( topk, renormalize, routed_scaling_factor, - e_score_correction_bias.to(gating_output.dtype), + e_score_correction_bias, 0, # scoring_func=0 (no activation, scores already computed) ) else: