[Quantization][MoE] remove unused ep logic from moe marlin (#31571)

Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
Jinzhen Lin
2026-01-07 01:07:19 +08:00
committed by GitHub
parent 28c94770ad
commit 2f4bdee61e
6 changed files with 31 additions and 60 deletions

View File

@@ -336,14 +336,14 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
void* perm, void* a_tmp, void* sorted_token_ids,
void* expert_ids, void* num_tokens_past_padded,
void* topk_weights, int moe_block_size, int num_experts,
int top_k, bool mul_topk_weights, bool is_ep, int prob_m,
int prob_n, int prob_k, void* workspace,
vllm::ScalarType const& a_type, vllm::ScalarType const& b_type,
vllm::ScalarType const& c_type, vllm::ScalarType const& s_type,
bool has_bias, bool has_act_order, bool is_k_full, bool has_zp,
int num_groups, int group_size, int dev, cudaStream_t stream,
int thread_k, int thread_n, int sms, int blocks_per_sm,
bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) {
int top_k, bool mul_topk_weights, int prob_m, int prob_n,
int prob_k, void* workspace, vllm::ScalarType const& a_type,
vllm::ScalarType const& b_type, vllm::ScalarType const& c_type,
vllm::ScalarType const& s_type, bool has_bias,
bool has_act_order, bool is_k_full, bool has_zp, int num_groups,
int group_size, int dev, cudaStream_t stream, int thread_k,
int thread_n, int sms, int blocks_per_sm, bool use_atomic_add,
bool use_fp32_reduce, bool is_zp_float) {
int thread_m_blocks = div_ceil(moe_block_size, 16);
bool m_block_size_8 = moe_block_size == 8;
bool is_a_8bit = a_type.size_bits() == 8;
@@ -523,7 +523,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
kernel<<<blocks, num_threads, max_shared_mem, stream>>>(
A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, a_s_ptr, b_s_ptr, g_s_ptr, zp_ptr, g_idx_ptr,
sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
topk_weights_ptr, top_k, mul_topk_weights, num_groups, prob_m,
prob_n, prob_k, locks, has_bias, use_atomic_add, use_fp32_reduce);
// clang-format on
}
@@ -541,7 +541,7 @@ torch::Tensor moe_wna16_marlin_gemm(
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids,
torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights,
int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep,
int64_t moe_block_size, int64_t top_k, bool mul_topk_weights,
vllm::ScalarTypeId const& b_type_id, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
bool is_zp_float, int64_t thread_k, int64_t thread_n,
@@ -855,9 +855,9 @@ torch::Tensor moe_wna16_marlin_gemm(
perm.data_ptr(), a_tmp.data_ptr(), sorted_token_ids.data_ptr(),
expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(),
topk_weights.data_ptr(), moe_block_size, num_experts, top_k,
mul_topk_weights, is_ep, size_m, size_n, size_k, workspace.data_ptr(),
a_type, b_type, c_type, s_type, has_bias, has_act_order, is_k_full,
has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
mul_topk_weights, size_m, size_n, size_k, workspace.data_ptr(), a_type,
b_type, c_type, s_type, has_bias, has_act_order, is_k_full, has_zp,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, blocks_per_sm, use_atomic_add, use_fp32_reduce,
is_zp_float);
@@ -866,4 +866,4 @@ torch::Tensor moe_wna16_marlin_gemm(
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("moe_wna16_marlin_gemm", &moe_wna16_marlin_gemm);
}
}