[Quantization][MoE] remove unused ep logic from moe marlin (#31571)

Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
Jinzhen Lin
2026-01-07 01:07:19 +08:00
committed by GitHub
parent 28c94770ad
commit 2f4bdee61e
6 changed files with 31 additions and 60 deletions

View File

@@ -7,20 +7,20 @@
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
#include "core/scalar_type.hpp"
#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ b_bias_ptr, \
const float *__restrict__ a_scales_ptr, \
const int4 *__restrict__ scales_ptr, \
const uint16_t *__restrict__ global_scale_ptr, \
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
const int32_t *__restrict__ sorted_token_ids_ptr, \
const int32_t *__restrict__ expert_ids_ptr, \
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
int prob_n, int prob_k, int *locks, bool has_bias, bool use_atomic_add, \
#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ b_bias_ptr, \
const float *__restrict__ a_scales_ptr, \
const int4 *__restrict__ scales_ptr, \
const uint16_t *__restrict__ global_scale_ptr, \
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
const int32_t *__restrict__ sorted_token_ids_ptr, \
const int32_t *__restrict__ expert_ids_ptr, \
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, int num_groups, int prob_m, int prob_n, \
int prob_k, int *locks, bool has_bias, bool use_atomic_add, \
bool use_fp32_reduce
namespace MARLIN_NAMESPACE_NAME {

View File

@@ -71,7 +71,6 @@ __global__ void Marlin(
const float* __restrict__ topk_weights_ptr, // moe top weights
int top_k, // num of experts per token
bool mul_topk_weights, // mul topk weights or not
bool is_ep, // expert parallelism
int num_groups, // number of scale groups per output channel
int prob_m, // batch dimension m
int prob_n, // output dimension n
@@ -273,7 +272,6 @@ __global__ void Marlin(
const float* __restrict__ topk_weights_ptr, // moe top weights
int top_k, // num of experts per token
bool mul_topk_weights, // mul topk weights or not
bool is_ep, // expert parallelism
int num_groups, // number of scale groups per output channel
int prob_m, // batch dimension m
int prob_n, // output dimension n
@@ -376,14 +374,6 @@ __global__ void Marlin(
// parallel: num valid moe blocks
int parallel = num_tokens_past_padded / moe_block_size;
int num_valid_blocks = parallel;
if (is_ep) {
for (int i = 0; i < parallel; i++) {
if (expert_ids_ptr[i] == -1) num_valid_blocks--;
}
}
int num_invalid_blocks = parallel - num_valid_blocks;
parallel = num_valid_blocks;
int k_tiles = prob_k / 16 / thread_k_blocks;
int n_tiles = prob_n / 16 / thread_n_blocks;
@@ -538,22 +528,8 @@ __global__ void Marlin(
if (par_id >= parallel) return;
old_expert_id = expert_id;
if (num_invalid_blocks > 0) {
int skip_count = par_id;
for (int i = 0; i < num_tokens_past_padded / moe_block_size; i++) {
expert_id = expert_ids_ptr[i];
if (expert_id != -1) {
if (skip_count == 0) {
block_id = i;
break;
};
skip_count--;
};
}
} else {
block_id = par_id;
expert_id = expert_ids_ptr[block_id];
}
block_id = par_id;
expert_id = expert_ids_ptr[block_id];
if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
uint16_t val = global_scale_ptr[expert_id];

View File

@@ -336,14 +336,14 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
void* perm, void* a_tmp, void* sorted_token_ids,
void* expert_ids, void* num_tokens_past_padded,
void* topk_weights, int moe_block_size, int num_experts,
int top_k, bool mul_topk_weights, bool is_ep, int prob_m,
int prob_n, int prob_k, void* workspace,
vllm::ScalarType const& a_type, vllm::ScalarType const& b_type,
vllm::ScalarType const& c_type, vllm::ScalarType const& s_type,
bool has_bias, bool has_act_order, bool is_k_full, bool has_zp,
int num_groups, int group_size, int dev, cudaStream_t stream,
int thread_k, int thread_n, int sms, int blocks_per_sm,
bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) {
int top_k, bool mul_topk_weights, int prob_m, int prob_n,
int prob_k, void* workspace, vllm::ScalarType const& a_type,
vllm::ScalarType const& b_type, vllm::ScalarType const& c_type,
vllm::ScalarType const& s_type, bool has_bias,
bool has_act_order, bool is_k_full, bool has_zp, int num_groups,
int group_size, int dev, cudaStream_t stream, int thread_k,
int thread_n, int sms, int blocks_per_sm, bool use_atomic_add,
bool use_fp32_reduce, bool is_zp_float) {
int thread_m_blocks = div_ceil(moe_block_size, 16);
bool m_block_size_8 = moe_block_size == 8;
bool is_a_8bit = a_type.size_bits() == 8;
@@ -523,7 +523,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
kernel<<<blocks, num_threads, max_shared_mem, stream>>>(
A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, a_s_ptr, b_s_ptr, g_s_ptr, zp_ptr, g_idx_ptr,
sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
topk_weights_ptr, top_k, mul_topk_weights, num_groups, prob_m,
prob_n, prob_k, locks, has_bias, use_atomic_add, use_fp32_reduce);
// clang-format on
}
@@ -541,7 +541,7 @@ torch::Tensor moe_wna16_marlin_gemm(
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids,
torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights,
int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep,
int64_t moe_block_size, int64_t top_k, bool mul_topk_weights,
vllm::ScalarTypeId const& b_type_id, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
bool is_zp_float, int64_t thread_k, int64_t thread_n,
@@ -855,9 +855,9 @@ torch::Tensor moe_wna16_marlin_gemm(
perm.data_ptr(), a_tmp.data_ptr(), sorted_token_ids.data_ptr(),
expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(),
topk_weights.data_ptr(), moe_block_size, num_experts, top_k,
mul_topk_weights, is_ep, size_m, size_n, size_k, workspace.data_ptr(),
a_type, b_type, c_type, s_type, has_bias, has_act_order, is_k_full,
has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
mul_topk_weights, size_m, size_n, size_k, workspace.data_ptr(), a_type,
b_type, c_type, s_type, has_bias, has_act_order, is_k_full, has_zp,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, blocks_per_sm, use_atomic_add, use_fp32_reduce,
is_zp_float);
@@ -866,4 +866,4 @@ torch::Tensor moe_wna16_marlin_gemm(
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("moe_wna16_marlin_gemm", &moe_wna16_marlin_gemm);
}
}

View File

@@ -71,7 +71,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"Tensor sorted_token_ids,"
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
"Tensor! topk_weights, int moe_block_size, int top_k, "
"bool mul_topk_weights, bool is_ep, int b_type_id,"
"bool mul_topk_weights, int b_type_id,"
"int size_m, int size_n, int size_k,"
"bool is_full_k, bool use_atomic_add,"
"bool use_fp32_reduce, bool is_zp_float,"