diff --git a/csrc/moe/marlin_moe_wna16/kernel.h b/csrc/moe/marlin_moe_wna16/kernel.h index 57f5a1793..eb83df22c 100644 --- a/csrc/moe/marlin_moe_wna16/kernel.h +++ b/csrc/moe/marlin_moe_wna16/kernel.h @@ -7,20 +7,20 @@ #include "quantization/gptq_marlin/marlin_dtypes.cuh" #include "core/scalar_type.hpp" -#define MARLIN_KERNEL_PARAMS \ - const int4 *__restrict__ A, const int4 *__restrict__ B, \ - int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ - const int4 *__restrict__ b_bias_ptr, \ - const float *__restrict__ a_scales_ptr, \ - const int4 *__restrict__ scales_ptr, \ - const uint16_t *__restrict__ global_scale_ptr, \ - const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ - const int32_t *__restrict__ sorted_token_ids_ptr, \ - const int32_t *__restrict__ expert_ids_ptr, \ - const int32_t *__restrict__ num_tokens_past_padded_ptr, \ - const float *__restrict__ topk_weights_ptr, int top_k, \ - bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \ - int prob_n, int prob_k, int *locks, bool has_bias, bool use_atomic_add, \ +#define MARLIN_KERNEL_PARAMS \ + const int4 *__restrict__ A, const int4 *__restrict__ B, \ + int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ + const int4 *__restrict__ b_bias_ptr, \ + const float *__restrict__ a_scales_ptr, \ + const int4 *__restrict__ scales_ptr, \ + const uint16_t *__restrict__ global_scale_ptr, \ + const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ + const int32_t *__restrict__ sorted_token_ids_ptr, \ + const int32_t *__restrict__ expert_ids_ptr, \ + const int32_t *__restrict__ num_tokens_past_padded_ptr, \ + const float *__restrict__ topk_weights_ptr, int top_k, \ + bool mul_topk_weights, int num_groups, int prob_m, int prob_n, \ + int prob_k, int *locks, bool has_bias, bool use_atomic_add, \ bool use_fp32_reduce namespace MARLIN_NAMESPACE_NAME { diff --git a/csrc/moe/marlin_moe_wna16/marlin_template.h b/csrc/moe/marlin_moe_wna16/marlin_template.h index 138197b76..5aac69b5c 100644 --- a/csrc/moe/marlin_moe_wna16/marlin_template.h +++ b/csrc/moe/marlin_moe_wna16/marlin_template.h @@ -71,7 +71,6 @@ __global__ void Marlin( const float* __restrict__ topk_weights_ptr, // moe top weights int top_k, // num of experts per token bool mul_topk_weights, // mul topk weights or not - bool is_ep, // expert parallelism int num_groups, // number of scale groups per output channel int prob_m, // batch dimension m int prob_n, // output dimension n @@ -273,7 +272,6 @@ __global__ void Marlin( const float* __restrict__ topk_weights_ptr, // moe top weights int top_k, // num of experts per token bool mul_topk_weights, // mul topk weights or not - bool is_ep, // expert parallelism int num_groups, // number of scale groups per output channel int prob_m, // batch dimension m int prob_n, // output dimension n @@ -376,14 +374,6 @@ __global__ void Marlin( // parallel: num valid moe blocks int parallel = num_tokens_past_padded / moe_block_size; - int num_valid_blocks = parallel; - if (is_ep) { - for (int i = 0; i < parallel; i++) { - if (expert_ids_ptr[i] == -1) num_valid_blocks--; - } - } - int num_invalid_blocks = parallel - num_valid_blocks; - parallel = num_valid_blocks; int k_tiles = prob_k / 16 / thread_k_blocks; int n_tiles = prob_n / 16 / thread_n_blocks; @@ -538,22 +528,8 @@ __global__ void Marlin( if (par_id >= parallel) return; old_expert_id = expert_id; - if (num_invalid_blocks > 0) { - int skip_count = par_id; - for (int i = 0; i < num_tokens_past_padded / moe_block_size; i++) { - expert_id = expert_ids_ptr[i]; - if (expert_id != -1) { - if (skip_count == 0) { - block_id = i; - break; - }; - skip_count--; - }; - } - } else { - block_id = par_id; - expert_id = expert_ids_ptr[block_id]; - } + block_id = par_id; + expert_id = expert_ids_ptr[block_id]; if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { uint16_t val = global_scale_ptr[expert_id]; diff --git a/csrc/moe/marlin_moe_wna16/ops.cu b/csrc/moe/marlin_moe_wna16/ops.cu index 8ac169122..00b17f075 100644 --- a/csrc/moe/marlin_moe_wna16/ops.cu +++ b/csrc/moe/marlin_moe_wna16/ops.cu @@ -336,14 +336,14 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, void* perm, void* a_tmp, void* sorted_token_ids, void* expert_ids, void* num_tokens_past_padded, void* topk_weights, int moe_block_size, int num_experts, - int top_k, bool mul_topk_weights, bool is_ep, int prob_m, - int prob_n, int prob_k, void* workspace, - vllm::ScalarType const& a_type, vllm::ScalarType const& b_type, - vllm::ScalarType const& c_type, vllm::ScalarType const& s_type, - bool has_bias, bool has_act_order, bool is_k_full, bool has_zp, - int num_groups, int group_size, int dev, cudaStream_t stream, - int thread_k, int thread_n, int sms, int blocks_per_sm, - bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) { + int top_k, bool mul_topk_weights, int prob_m, int prob_n, + int prob_k, void* workspace, vllm::ScalarType const& a_type, + vllm::ScalarType const& b_type, vllm::ScalarType const& c_type, + vllm::ScalarType const& s_type, bool has_bias, + bool has_act_order, bool is_k_full, bool has_zp, int num_groups, + int group_size, int dev, cudaStream_t stream, int thread_k, + int thread_n, int sms, int blocks_per_sm, bool use_atomic_add, + bool use_fp32_reduce, bool is_zp_float) { int thread_m_blocks = div_ceil(moe_block_size, 16); bool m_block_size_8 = moe_block_size == 8; bool is_a_8bit = a_type.size_bits() == 8; @@ -523,7 +523,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, kernel<<>>( A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, a_s_ptr, b_s_ptr, g_s_ptr, zp_ptr, g_idx_ptr, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr, - topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m, + topk_weights_ptr, top_k, mul_topk_weights, num_groups, prob_m, prob_n, prob_k, locks, has_bias, use_atomic_add, use_fp32_reduce); // clang-format on } @@ -541,7 +541,7 @@ torch::Tensor moe_wna16_marlin_gemm( std::optional const& perm_or_none, torch::Tensor& workspace, torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids, torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights, - int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep, + int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, vllm::ScalarTypeId const& b_type_id, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float, int64_t thread_k, int64_t thread_n, @@ -855,9 +855,9 @@ torch::Tensor moe_wna16_marlin_gemm( perm.data_ptr(), a_tmp.data_ptr(), sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(), moe_block_size, num_experts, top_k, - mul_topk_weights, is_ep, size_m, size_n, size_k, workspace.data_ptr(), - a_type, b_type, c_type, s_type, has_bias, has_act_order, is_k_full, - has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), + mul_topk_weights, size_m, size_n, size_k, workspace.data_ptr(), a_type, + b_type, c_type, s_type, has_bias, has_act_order, is_k_full, has_zp, + num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, blocks_per_sm, use_atomic_add, use_fp32_reduce, is_zp_float); @@ -866,4 +866,4 @@ torch::Tensor moe_wna16_marlin_gemm( TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { m.impl("moe_wna16_marlin_gemm", &moe_wna16_marlin_gemm); -} \ No newline at end of file +} diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 779ad70ad..7d44db21d 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -71,7 +71,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "Tensor sorted_token_ids," "Tensor! expert_ids, Tensor! num_tokens_past_padded," "Tensor! topk_weights, int moe_block_size, int top_k, " - "bool mul_topk_weights, bool is_ep, int b_type_id," + "bool mul_topk_weights, int b_type_id," "int size_m, int size_n, int size_k," "bool is_full_k, bool use_atomic_add," "bool use_fp32_reduce, bool is_zp_float," diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index c1519fc17..0d6d545fe 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -2162,7 +2162,6 @@ def moe_wna16_marlin_gemm( moe_block_size: int, top_k: int, mul_topk_weights: bool, - is_ep: bool, b_q_type: ScalarType, size_m: int, size_n: int, @@ -2194,7 +2193,6 @@ def moe_wna16_marlin_gemm( moe_block_size, top_k, mul_topk_weights, - is_ep, b_q_type.id, size_m, size_n, @@ -2256,7 +2254,6 @@ if hasattr(torch.ops, "_moe_C") and hasattr(torch.ops._moe_C, "marlin_gemm_moe") moe_block_size: int, top_k: int, mul_topk_weights: bool, - is_ep: bool, b_q_type: ScalarType, size_m: int, size_n: int, diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 295a2a281..c031d9efc 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -135,7 +135,6 @@ def _fused_marlin_moe( moe_block_size=block_size_m, top_k=num_topk, mul_topk_weights=apply_router_weight_on_input, - is_ep=expert_map is not None, b_q_type=quant_type, size_m=M, size_n=2 * N, @@ -187,7 +186,6 @@ def _fused_marlin_moe( moe_block_size=block_size_m, top_k=1, mul_topk_weights=not apply_router_weight_on_input, - is_ep=expert_map is not None, b_q_type=quant_type, size_m=M * num_topk, size_n=K,