[Quantization][MoE] remove unused ep logic from moe marlin (#31571)
Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -7,20 +7,20 @@
|
||||
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
|
||||
#include "core/scalar_type.hpp"
|
||||
|
||||
#define MARLIN_KERNEL_PARAMS \
|
||||
const int4 *__restrict__ A, const int4 *__restrict__ B, \
|
||||
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
|
||||
const int4 *__restrict__ b_bias_ptr, \
|
||||
const float *__restrict__ a_scales_ptr, \
|
||||
const int4 *__restrict__ scales_ptr, \
|
||||
const uint16_t *__restrict__ global_scale_ptr, \
|
||||
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
|
||||
const int32_t *__restrict__ sorted_token_ids_ptr, \
|
||||
const int32_t *__restrict__ expert_ids_ptr, \
|
||||
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
|
||||
const float *__restrict__ topk_weights_ptr, int top_k, \
|
||||
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
|
||||
int prob_n, int prob_k, int *locks, bool has_bias, bool use_atomic_add, \
|
||||
#define MARLIN_KERNEL_PARAMS \
|
||||
const int4 *__restrict__ A, const int4 *__restrict__ B, \
|
||||
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
|
||||
const int4 *__restrict__ b_bias_ptr, \
|
||||
const float *__restrict__ a_scales_ptr, \
|
||||
const int4 *__restrict__ scales_ptr, \
|
||||
const uint16_t *__restrict__ global_scale_ptr, \
|
||||
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
|
||||
const int32_t *__restrict__ sorted_token_ids_ptr, \
|
||||
const int32_t *__restrict__ expert_ids_ptr, \
|
||||
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
|
||||
const float *__restrict__ topk_weights_ptr, int top_k, \
|
||||
bool mul_topk_weights, int num_groups, int prob_m, int prob_n, \
|
||||
int prob_k, int *locks, bool has_bias, bool use_atomic_add, \
|
||||
bool use_fp32_reduce
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
@@ -71,7 +71,6 @@ __global__ void Marlin(
|
||||
const float* __restrict__ topk_weights_ptr, // moe top weights
|
||||
int top_k, // num of experts per token
|
||||
bool mul_topk_weights, // mul topk weights or not
|
||||
bool is_ep, // expert parallelism
|
||||
int num_groups, // number of scale groups per output channel
|
||||
int prob_m, // batch dimension m
|
||||
int prob_n, // output dimension n
|
||||
@@ -273,7 +272,6 @@ __global__ void Marlin(
|
||||
const float* __restrict__ topk_weights_ptr, // moe top weights
|
||||
int top_k, // num of experts per token
|
||||
bool mul_topk_weights, // mul topk weights or not
|
||||
bool is_ep, // expert parallelism
|
||||
int num_groups, // number of scale groups per output channel
|
||||
int prob_m, // batch dimension m
|
||||
int prob_n, // output dimension n
|
||||
@@ -376,14 +374,6 @@ __global__ void Marlin(
|
||||
|
||||
// parallel: num valid moe blocks
|
||||
int parallel = num_tokens_past_padded / moe_block_size;
|
||||
int num_valid_blocks = parallel;
|
||||
if (is_ep) {
|
||||
for (int i = 0; i < parallel; i++) {
|
||||
if (expert_ids_ptr[i] == -1) num_valid_blocks--;
|
||||
}
|
||||
}
|
||||
int num_invalid_blocks = parallel - num_valid_blocks;
|
||||
parallel = num_valid_blocks;
|
||||
|
||||
int k_tiles = prob_k / 16 / thread_k_blocks;
|
||||
int n_tiles = prob_n / 16 / thread_n_blocks;
|
||||
@@ -538,22 +528,8 @@ __global__ void Marlin(
|
||||
if (par_id >= parallel) return;
|
||||
|
||||
old_expert_id = expert_id;
|
||||
if (num_invalid_blocks > 0) {
|
||||
int skip_count = par_id;
|
||||
for (int i = 0; i < num_tokens_past_padded / moe_block_size; i++) {
|
||||
expert_id = expert_ids_ptr[i];
|
||||
if (expert_id != -1) {
|
||||
if (skip_count == 0) {
|
||||
block_id = i;
|
||||
break;
|
||||
};
|
||||
skip_count--;
|
||||
};
|
||||
}
|
||||
} else {
|
||||
block_id = par_id;
|
||||
expert_id = expert_ids_ptr[block_id];
|
||||
}
|
||||
block_id = par_id;
|
||||
expert_id = expert_ids_ptr[block_id];
|
||||
|
||||
if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
|
||||
uint16_t val = global_scale_ptr[expert_id];
|
||||
|
||||
@@ -336,14 +336,14 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
void* perm, void* a_tmp, void* sorted_token_ids,
|
||||
void* expert_ids, void* num_tokens_past_padded,
|
||||
void* topk_weights, int moe_block_size, int num_experts,
|
||||
int top_k, bool mul_topk_weights, bool is_ep, int prob_m,
|
||||
int prob_n, int prob_k, void* workspace,
|
||||
vllm::ScalarType const& a_type, vllm::ScalarType const& b_type,
|
||||
vllm::ScalarType const& c_type, vllm::ScalarType const& s_type,
|
||||
bool has_bias, bool has_act_order, bool is_k_full, bool has_zp,
|
||||
int num_groups, int group_size, int dev, cudaStream_t stream,
|
||||
int thread_k, int thread_n, int sms, int blocks_per_sm,
|
||||
bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) {
|
||||
int top_k, bool mul_topk_weights, int prob_m, int prob_n,
|
||||
int prob_k, void* workspace, vllm::ScalarType const& a_type,
|
||||
vllm::ScalarType const& b_type, vllm::ScalarType const& c_type,
|
||||
vllm::ScalarType const& s_type, bool has_bias,
|
||||
bool has_act_order, bool is_k_full, bool has_zp, int num_groups,
|
||||
int group_size, int dev, cudaStream_t stream, int thread_k,
|
||||
int thread_n, int sms, int blocks_per_sm, bool use_atomic_add,
|
||||
bool use_fp32_reduce, bool is_zp_float) {
|
||||
int thread_m_blocks = div_ceil(moe_block_size, 16);
|
||||
bool m_block_size_8 = moe_block_size == 8;
|
||||
bool is_a_8bit = a_type.size_bits() == 8;
|
||||
@@ -523,7 +523,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
kernel<<<blocks, num_threads, max_shared_mem, stream>>>(
|
||||
A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, a_s_ptr, b_s_ptr, g_s_ptr, zp_ptr, g_idx_ptr,
|
||||
sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
|
||||
topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
|
||||
topk_weights_ptr, top_k, mul_topk_weights, num_groups, prob_m,
|
||||
prob_n, prob_k, locks, has_bias, use_atomic_add, use_fp32_reduce);
|
||||
// clang-format on
|
||||
}
|
||||
@@ -541,7 +541,7 @@ torch::Tensor moe_wna16_marlin_gemm(
|
||||
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
|
||||
torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids,
|
||||
torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights,
|
||||
int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep,
|
||||
int64_t moe_block_size, int64_t top_k, bool mul_topk_weights,
|
||||
vllm::ScalarTypeId const& b_type_id, int64_t size_m, int64_t size_n,
|
||||
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
|
||||
bool is_zp_float, int64_t thread_k, int64_t thread_n,
|
||||
@@ -855,9 +855,9 @@ torch::Tensor moe_wna16_marlin_gemm(
|
||||
perm.data_ptr(), a_tmp.data_ptr(), sorted_token_ids.data_ptr(),
|
||||
expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(),
|
||||
topk_weights.data_ptr(), moe_block_size, num_experts, top_k,
|
||||
mul_topk_weights, is_ep, size_m, size_n, size_k, workspace.data_ptr(),
|
||||
a_type, b_type, c_type, s_type, has_bias, has_act_order, is_k_full,
|
||||
has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
|
||||
mul_topk_weights, size_m, size_n, size_k, workspace.data_ptr(), a_type,
|
||||
b_type, c_type, s_type, has_bias, has_act_order, is_k_full, has_zp,
|
||||
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
|
||||
thread_k, thread_n, sms, blocks_per_sm, use_atomic_add, use_fp32_reduce,
|
||||
is_zp_float);
|
||||
|
||||
@@ -866,4 +866,4 @@ torch::Tensor moe_wna16_marlin_gemm(
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("moe_wna16_marlin_gemm", &moe_wna16_marlin_gemm);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,7 +71,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||
"Tensor sorted_token_ids,"
|
||||
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
|
||||
"Tensor! topk_weights, int moe_block_size, int top_k, "
|
||||
"bool mul_topk_weights, bool is_ep, int b_type_id,"
|
||||
"bool mul_topk_weights, int b_type_id,"
|
||||
"int size_m, int size_n, int size_k,"
|
||||
"bool is_full_k, bool use_atomic_add,"
|
||||
"bool use_fp32_reduce, bool is_zp_float,"
|
||||
|
||||
Reference in New Issue
Block a user