[Kernel] Zero point support in fused MarlinMoE kernel + AWQ Fused MoE (#8973)
Co-authored-by: Dipika <dipikasikka1@gmail.com> Co-authored-by: Dipika Sikka <ds3822@columbia.edu>
This commit is contained in:
@@ -38,6 +38,7 @@ using FragA = Vec<half2, 4>;
|
||||
using FragB = Vec<half2, 2>;
|
||||
using FragC = Vec<float, 4>;
|
||||
using FragS = Vec<half2, 1>; // quantization scales
|
||||
using FragZP = Vec<half2, 4>;
|
||||
|
||||
// Predicated asynchronous global->shared copy; used for inputs A where we apply
|
||||
// predication to handle batchsizes that are not multiples of 16.
|
||||
@@ -175,6 +176,46 @@ __device__ inline FragB dequant<vllm::kU8B128.id()>(int q) {
|
||||
return frag_b;
|
||||
}
|
||||
|
||||
// Dequantize unsigned 4-bit weights (no built-in zero-point bias) to fp16.
//
// Reads four nibbles of `q`: bits 0-3 and 16-19 become the two lanes of
// result[0]; bits 4-7 and 20-23 become the two lanes of result[1]. Callers
// that need the remaining nibbles pass a shifted `q`.
template <>
__device__ inline FragB dequant<vllm::kU4.id()>(int q) {
  // Nibble selectors for both 16-bit halves of `q`, plus the fp16 bit pattern
  // of 1024.0 (0x6400) that is OR-ed onto every selected nibble.
  const int kLowNibbles = 0x000f000f;
  const int kHighNibbles = 0x00f000f0;
  const int kFp16Base = 0x64006400;

  // The immediate encodes `(a & b) | c`, which guarantees that each
  // extraction is emitted as a LOP3.
  constexpr int kAndThenOr = (0xf0 & 0xcc) | 0xaa;
  int low_bits = lop3<kAndThenOr>(q, kLowNibbles, kFp16Base);
  int high_bits = lop3<kAndThenOr>(q, kHighNibbles, kFp16Base);

  // Low nibbles hold 1024 + v, so subtracting 1024 leaves v.
  const int kLowBias = 0x64006400;
  // High nibbles hold 1024 + 16 * v; scaling by 1/16 (0x2c00) and adding
  // -64 (0xd400) leaves v.
  const int kHighScale = 0x2c002c00;
  const int kHighShift = 0xd400d400;

  FragB result;
  result[0] = __hsub2(*reinterpret_cast<half2*>(&low_bits),
                      *reinterpret_cast<const half2*>(&kLowBias));
  result[1] = __hfma2(*reinterpret_cast<half2*>(&high_bits),
                      *reinterpret_cast<const half2*>(&kHighScale),
                      *reinterpret_cast<const half2*>(&kHighShift));
  return result;
}
|
||||
|
||||
// Dequantize unsigned 8-bit weights (no built-in zero-point bias) to fp16.
//
// The four bytes of `q` are split into two half2 values: result[0] comes from
// the first selector, result[1] from the second.
template <>
__device__ inline FragB dequant<vllm::kU8.id()>(int q) {
  // 0x64 is the high byte of fp16 1024.0 (0x6400).
  // NOTE(review): `prmt` is presumably a byte permute that pairs a byte of
  // `q` with a 0x64 byte per fp16 lane -- confirm against its definition.
  static constexpr uint32_t kFp16HighBytes = 0x64646464;
  static constexpr uint32_t kSelectorFirst = 0x5250;
  static constexpr uint32_t kSelectorSecond = 0x5351;
  // Two packed fp16 copies of 1024.0, subtracted to strip the 0x6400 base.
  static constexpr uint32_t kFp16Base = 0x64006400;

  uint32_t first_pair = prmt<kFp16HighBytes, kSelectorFirst>(q);
  uint32_t second_pair = prmt<kFp16HighBytes, kSelectorSecond>(q);

  FragB result;
  result[0] = __hsub2(*reinterpret_cast<half2*>(&first_pair),
                      *reinterpret_cast<const half2*>(&kFp16Base));
  result[1] = __hsub2(*reinterpret_cast<half2*>(&second_pair),
                      *reinterpret_cast<const half2*>(&kFp16Base));
  return result;
}
|
||||
|
||||
// Multiply dequantized values by the corresponding quantization scale; used
|
||||
// only for grouped quantization.
|
||||
__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
|
||||
@@ -183,11 +224,10 @@ __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
|
||||
frag_b[1] = __hmul2(frag_b[1], s);
|
||||
}
|
||||
|
||||
// Given 2 floats multiply by 2 scales (halves)
|
||||
__device__ inline void scale_float(float* c, FragS& s) {
|
||||
__half* s_ptr = reinterpret_cast<__half*>(&s);
|
||||
c[0] = __fmul_rn(c[0], __half2float(s_ptr[0]));
|
||||
c[1] = __fmul_rn(c[1], __half2float(s_ptr[1]));
|
||||
// Subtract one fp16 zero-point from every element of a dequantized fragment.
// `i` (0 or 1) picks which half of `frag_zp` is used; that value is broadcast
// to both lanes before the subtraction.
__device__ inline void sub_zp(FragB& frag_b, half2& frag_zp, int i) {
  const __half* zp_halves = reinterpret_cast<__half*>(&frag_zp);
  const half2 zp_pair = __half2half2(zp_halves[i]);
  frag_b[0] = __hsub2(frag_b[0], zp_pair);
  frag_b[1] = __hsub2(frag_b[1], zp_pair);
}
|
||||
|
||||
// Same as above, but for act_order (each K is multiplied individually)
|
||||
@@ -205,6 +245,13 @@ __device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2,
|
||||
frag_b[1] = __hmul2(frag_b[1], s_val_3_4);
|
||||
}
|
||||
|
||||
// Scale two consecutive floats in place: c[0] by the first half of `s` and
// c[1] by the second, each product rounded to nearest.
__device__ inline void scale_float(float* c, FragS& s) {
  const __half* scale_halves = reinterpret_cast<__half*>(&s);
#pragma unroll
  for (int lane = 0; lane < 2; ++lane) {
    c[lane] = __fmul_rn(c[lane], __half2float(scale_halves[lane]));
  }
}
|
||||
|
||||
// Wait until barrier reaches `count`, then lock for current threadblock.
|
||||
__device__ inline void barrier_acquire(int* lock, int count) {
|
||||
if (threadIdx.x == 0) {
|
||||
@@ -248,10 +295,11 @@ template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
||||
const int stages, // number of stages for the async global->shared
|
||||
// fetch pipeline
|
||||
const bool has_act_order, // whether act_order is enabled
|
||||
const bool has_zp, // whether zero-points are enabled
|
||||
const int group_blocks = -1 // number of consecutive 16x16 blocks
|
||||
// with a separate quantization scale
|
||||
>
|
||||
__device__ inline void MarlinMoESingle(
|
||||
__device__ void MarlinMoESingle(
|
||||
const int4* __restrict__ A, // fp16 input matrix of shape mxk
|
||||
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
|
||||
int4* __restrict__ C, // fp16 output buffer of shape mxn
|
||||
@@ -259,6 +307,8 @@ __device__ inline void MarlinMoESingle(
|
||||
const float* __restrict__ topk_weights, // float topk weights
|
||||
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
|
||||
// (k/groupsize)xn
|
||||
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
|
||||
// (k/groupsize)x(n/pack_factor)
|
||||
const int* __restrict__ g_idx, // int32 group indices of shape k
|
||||
const int* __restrict__ expert_offsets,
|
||||
int num_groups, // number of scale groups per output channel
|
||||
@@ -400,8 +450,12 @@ __device__ inline void MarlinMoESingle(
|
||||
int tb_n_warps = thread_n_blocks / 4;
|
||||
int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;
|
||||
|
||||
constexpr int sorted_sh_stride = threads;
|
||||
constexpr int sorted_gl_stride = threads;
|
||||
// Zero-points sizes/strides
|
||||
int zp_gl_stride = (prob_n / pack_factor) / 4;
|
||||
constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4;
|
||||
constexpr int zp_tb_groups = s_tb_groups;
|
||||
constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0;
|
||||
int zp_gl_rd_delta = zp_gl_stride;
|
||||
|
||||
// Global A read index of current thread.
|
||||
int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
|
||||
@@ -442,6 +496,19 @@ __device__ inline void MarlinMoESingle(
|
||||
int s_sh_wr = threadIdx.x;
|
||||
bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
|
||||
|
||||
// Zero-points
|
||||
int zp_gl_rd;
|
||||
if constexpr (has_zp) {
|
||||
if constexpr (group_blocks == -1) {
|
||||
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
|
||||
} else {
|
||||
zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
|
||||
zp_sh_stride * slice_col + threadIdx.x;
|
||||
}
|
||||
}
|
||||
int zp_sh_wr = threadIdx.x;
|
||||
bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;
|
||||
|
||||
// We use a different scale layout for grouped and column-wise quantization as
|
||||
// we scale a `half2` tile in column-major layout in the former and in
|
||||
// row-major in the latter case.
|
||||
@@ -453,23 +520,29 @@ __device__ inline void MarlinMoESingle(
|
||||
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||
(threadIdx.x % 32) % 4;
|
||||
|
||||
// Zero-points have the same read layout as the scales
|
||||
// (without column-wise case)
|
||||
constexpr int num_col_threads = 8;
|
||||
constexpr int num_row_threads = 4;
|
||||
constexpr int num_ints_per_thread = 8 / pack_factor;
|
||||
int zp_sh_rd;
|
||||
if constexpr (has_zp) {
|
||||
zp_sh_rd = num_ints_per_thread * num_col_threads *
|
||||
((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||
num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads);
|
||||
}
|
||||
|
||||
int sh_first_group_id = -1;
|
||||
int sh_num_groups = -1;
|
||||
constexpr int sh_max_num_groups = 32;
|
||||
|
||||
int shs_size;
|
||||
if constexpr (has_act_order)
|
||||
shs_size = sh_max_num_groups * s_sh_stride + threads;
|
||||
else
|
||||
shs_size = group_blocks > 0 ? stages * s_sh_stage : threads;
|
||||
|
||||
extern __shared__ int4 sh[];
|
||||
// Shared memory storage for global fetch pipelines.
|
||||
int4* sh_a = sh;
|
||||
int4* sh_b = sh_a + (stages * a_sh_stage);
|
||||
int4* sh_g_idx = sh_b + (stages * b_sh_stage);
|
||||
int4* sh_s = sh_g_idx + (stages * g_idx_stage);
|
||||
int* sh_sorted = (int*)(sh_s + shs_size);
|
||||
int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
|
||||
int4* sh_s = sh_zp + (stages * zp_sh_stage);
|
||||
|
||||
// Precompute which thread should not read memory in which iterations; this is
|
||||
// needed if there are more threads than required for a certain tilesize or
|
||||
@@ -525,8 +598,10 @@ __device__ inline void MarlinMoESingle(
|
||||
FragA frag_a[2][thread_m_blocks];
|
||||
I4 frag_b_quant[2][b_thread_vecs];
|
||||
FragC frag_c[thread_m_blocks][4][2];
|
||||
FragS frag_s[2][4]; // No act-order
|
||||
FragS act_frag_s[2][4][4]; // For act-order
|
||||
FragS frag_s[2][4]; // No act-order
|
||||
FragS act_frag_s[2][4][4]; // For act-order
|
||||
int frag_qzp[2][num_ints_per_thread]; // Zero-points
|
||||
FragZP frag_zp; // Zero-points in fp16
|
||||
|
||||
// Zero accumulators.
|
||||
auto zero_accums = [&]() {
|
||||
@@ -633,6 +708,28 @@ __device__ inline void MarlinMoESingle(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (has_zp && group_blocks != -1) {
|
||||
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
|
||||
|
||||
if constexpr (group_blocks >= thread_k_blocks) {
|
||||
// Only fetch zero-points if this tile starts a new group
|
||||
if (pipe % (group_blocks / thread_k_blocks) == 0) {
|
||||
if (zp_sh_wr_pred) {
|
||||
cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]);
|
||||
}
|
||||
zp_gl_rd += zp_gl_rd_delta;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < zp_tb_groups; i++) {
|
||||
if (zp_sh_wr_pred) {
|
||||
cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr],
|
||||
&zp_ptr[zp_gl_rd]);
|
||||
}
|
||||
zp_gl_rd += zp_gl_rd_delta;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Insert a fence even when we are winding down the pipeline to ensure that
|
||||
@@ -640,15 +737,9 @@ __device__ inline void MarlinMoESingle(
|
||||
cp_async_fence();
|
||||
};
|
||||
|
||||
// TODO we are currently hitting illegal memory accesses when fetching
|
||||
// sorted_ids to shared data: fix this
|
||||
auto fetch_sorted_ids_to_shared = [&]() {
|
||||
const int mpt = ceildiv(prob_m, threads);
|
||||
for (int i = 0; i < mpt; i++) {
|
||||
if ((i * sorted_gl_stride) + threadIdx.x < prob_m) {
|
||||
sh_sorted[(i * sorted_sh_stride) + threadIdx.x] =
|
||||
sorted_ids[(i * sorted_gl_stride) + threadIdx.x];
|
||||
}
|
||||
auto fetch_zp_to_shared = [&]() {
|
||||
if (zp_sh_wr_pred) {
|
||||
cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -799,8 +890,83 @@ __device__ inline void MarlinMoESingle(
|
||||
}
|
||||
};
|
||||
|
||||
auto fetch_zp_to_registers = [&](int k, int full_pipe) {
|
||||
// This code does not handle group_blocks == 0,
|
||||
// which signifies act_order.
|
||||
// has_zp implies AWQ, which doesn't have act_order,
|
||||
static_assert(!has_zp || group_blocks != 0);
|
||||
|
||||
if constexpr (has_zp) {
|
||||
int pipe = full_pipe % stages;
|
||||
|
||||
if constexpr (group_blocks == -1) {
|
||||
for (int i = 0; i < num_ints_per_thread; i++) {
|
||||
frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp))[zp_sh_rd + i];
|
||||
}
|
||||
|
||||
} else if constexpr (group_blocks >= thread_k_blocks) {
|
||||
int4* sh_zp_stage =
|
||||
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
|
||||
(pipe / (group_blocks / thread_k_blocks)));
|
||||
for (int i = 0; i < num_ints_per_thread; i++) {
|
||||
frag_qzp[k % 2][i] =
|
||||
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
|
||||
}
|
||||
} else {
|
||||
int warp_id = threadIdx.x / 32;
|
||||
int n_warps = thread_n_blocks / 4;
|
||||
|
||||
int warp_row = warp_id / n_warps;
|
||||
|
||||
int cur_k = warp_row * 16;
|
||||
cur_k += k_iter_size * (k % b_sh_wr_iters);
|
||||
|
||||
int k_blocks = cur_k / 16;
|
||||
int cur_group_id = 0;
|
||||
|
||||
// Suppress bogus and persistent divide-by-zero warning
|
||||
#pragma nv_diagnostic push
|
||||
#pragma nv_diag_suppress divide_by_zero
|
||||
cur_group_id = k_blocks / group_blocks;
|
||||
#pragma nv_diagnostic pop
|
||||
|
||||
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
|
||||
|
||||
sh_zp_stage += cur_group_id * zp_sh_stride;
|
||||
|
||||
for (int i = 0; i < num_ints_per_thread; i++) {
|
||||
frag_qzp[k % 2][i] =
|
||||
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Execute the actual tensor core matmul of a sub-tile.
|
||||
auto matmul = [&](int k) {
|
||||
if constexpr (has_zp) {
|
||||
FragB frag_zp_0;
|
||||
FragB frag_zp_1;
|
||||
int zp_quant_0, zp_quant_1;
|
||||
|
||||
if constexpr (w_type.size_bits() == 4) {
|
||||
zp_quant_0 = frag_qzp[k % 2][0];
|
||||
zp_quant_1 = zp_quant_0 >> 8;
|
||||
} else {
|
||||
static_assert(w_type.size_bits() == 8);
|
||||
zp_quant_0 = frag_qzp[k % 2][0];
|
||||
zp_quant_1 = frag_qzp[k % 2][1];
|
||||
}
|
||||
|
||||
frag_zp_0 = dequant<w_type_id>(zp_quant_0);
|
||||
frag_zp_1 = dequant<w_type_id>(zp_quant_1);
|
||||
|
||||
frag_zp[0] = frag_zp_0[0];
|
||||
frag_zp[1] = frag_zp_0[1];
|
||||
frag_zp[2] = frag_zp_1[0];
|
||||
frag_zp[3] = frag_zp_1[1];
|
||||
}
|
||||
|
||||
// We have the m dimension as the inner loop in order to encourage overlapping
|
||||
// dequantization and matmul operations.
|
||||
#pragma unroll
|
||||
@@ -818,6 +984,10 @@ __device__ inline void MarlinMoESingle(
|
||||
|
||||
FragB frag_b0 = dequant<w_type_id>(b_quant_0);
|
||||
FragB frag_b1 = dequant<w_type_id>(b_quant_1);
|
||||
// Apply zero-point to frag_b0
|
||||
if constexpr (has_zp) {
|
||||
sub_zp(frag_b0, frag_zp[j], 0);
|
||||
}
|
||||
|
||||
// Apply scale to frag_b0
|
||||
if constexpr (has_act_order) {
|
||||
@@ -829,6 +999,11 @@ __device__ inline void MarlinMoESingle(
|
||||
}
|
||||
}
|
||||
|
||||
// Apply zero-point to frag_b1
|
||||
if constexpr (has_zp) {
|
||||
sub_zp(frag_b1, frag_zp[j], 1);
|
||||
}
|
||||
|
||||
// Apply scale to frag_b1
|
||||
if constexpr (has_act_order) {
|
||||
scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
|
||||
@@ -1062,9 +1237,6 @@ __device__ inline void MarlinMoESingle(
|
||||
|
||||
// Start global fetch and register load pipelines.
|
||||
auto start_pipes = [&]() {
|
||||
// TODO re-enable after fixing this function
|
||||
// fetch_sorted_ids_to_shared();
|
||||
// __syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < stages - 1; i++) {
|
||||
@@ -1075,6 +1247,12 @@ __device__ inline void MarlinMoESingle(
|
||||
}
|
||||
fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]);
|
||||
}
|
||||
|
||||
if constexpr (has_zp && group_blocks == -1) {
|
||||
if (i == 0) {
|
||||
fetch_zp_to_shared();
|
||||
}
|
||||
}
|
||||
fetch_to_shared(i, i, i < slice_iters);
|
||||
}
|
||||
|
||||
@@ -1083,6 +1261,7 @@ __device__ inline void MarlinMoESingle(
|
||||
init_same_group(0);
|
||||
fetch_to_registers(0, 0);
|
||||
fetch_scales_to_registers(0, 0);
|
||||
fetch_zp_to_registers(0, 0);
|
||||
a_gl_rd += a_gl_rd_delta_o * (stages - 1);
|
||||
slice_k_start_shared_fetch += tb_k * (stages - 1);
|
||||
};
|
||||
@@ -1102,6 +1281,7 @@ __device__ inline void MarlinMoESingle(
|
||||
for (int k = 0; k < b_sh_wr_iters; k++) {
|
||||
fetch_to_registers(k + 1, pipe % stages);
|
||||
fetch_scales_to_registers(k + 1, pipe);
|
||||
fetch_zp_to_registers(k + 1, pipe);
|
||||
if (k == b_sh_wr_iters - 2) {
|
||||
fetch_to_shared((pipe + stages - 1) % stages, pipe,
|
||||
slice_iters >= stages);
|
||||
@@ -1236,7 +1416,9 @@ __device__ inline void MarlinMoESingle(
|
||||
|
||||
} else {
|
||||
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
|
||||
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
|
||||
}
|
||||
|
||||
start_pipes();
|
||||
}
|
||||
}
|
||||
@@ -1250,6 +1432,7 @@ template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
||||
const int stages, // number of stages for the async global->shared
|
||||
// fetch pipeline
|
||||
const bool has_act_order, // whether act_order is enabled
|
||||
const bool has_zp, // whether zero-points are enabled
|
||||
const int group_blocks = -1 // number of consecutive 16x16 blocks
|
||||
// with a separate quantization scale
|
||||
>
|
||||
@@ -1261,6 +1444,8 @@ __global__ void MarlinMoE(
|
||||
const float* __restrict__ topk_weights, // float topk weights
|
||||
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
|
||||
// (k/groupsize)xn
|
||||
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
|
||||
// (k/groupsize)x(n/pack_factor)
|
||||
const int* __restrict__ g_idx, // int32 group indices of shape k
|
||||
const int* __restrict__ expert_offsets,
|
||||
int num_groups, // number of scale groups per output channel
|
||||
@@ -1309,29 +1494,29 @@ __global__ void MarlinMoE(
|
||||
|
||||
if (max_block == 1) {
|
||||
MarlinMoESingle<w_type_id, threads, 1, thread_n_blocks, thread_k_blocks,
|
||||
stages, has_act_order, group_blocks>(
|
||||
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
|
||||
stages, has_act_order, has_zp, group_blocks>(
|
||||
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx,
|
||||
expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
|
||||
prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
|
||||
current_m_block);
|
||||
} else if (max_block == 2) {
|
||||
MarlinMoESingle<w_type_id, threads, 2, thread_n_blocks, thread_k_blocks,
|
||||
stages, has_act_order, group_blocks>(
|
||||
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
|
||||
stages, has_act_order, has_zp, group_blocks>(
|
||||
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx,
|
||||
expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
|
||||
prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
|
||||
current_m_block);
|
||||
} else if (max_block == 3) {
|
||||
MarlinMoESingle<w_type_id, threads, 3, thread_n_blocks, thread_k_blocks,
|
||||
stages, has_act_order, group_blocks>(
|
||||
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
|
||||
stages, has_act_order, has_zp, group_blocks>(
|
||||
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx,
|
||||
expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
|
||||
prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
|
||||
current_m_block);
|
||||
} else {
|
||||
MarlinMoESingle<w_type_id, threads, 4, thread_n_blocks, thread_k_blocks,
|
||||
stages, has_act_order, group_blocks>(
|
||||
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
|
||||
stages, has_act_order, has_zp, group_blocks>(
|
||||
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx,
|
||||
expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
|
||||
prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
|
||||
current_m_block);
|
||||
@@ -1347,6 +1532,7 @@ template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
||||
const int stages, // number of stages for the async global->shared
|
||||
// fetch pipeline
|
||||
const bool has_act_order, // whether act_order is enabled
|
||||
const bool has_zp, // whether zero-points are enabled
|
||||
const int group_blocks = -1 // number of consecutive 16x16 blocks
|
||||
// with a separate quantization scale
|
||||
>
|
||||
@@ -1358,6 +1544,8 @@ __global__ void MarlinMoE(
|
||||
const float* __restrict__ topk_weights, // float topk weights
|
||||
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
|
||||
// (k/groupsize)xn
|
||||
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
|
||||
// (k/groupsize)x(n/pack_factor)
|
||||
const int* __restrict__ g_idx, // int32 group indices of shape k
|
||||
const int* __restrict__ expert_offsets,
|
||||
int num_groups, // number of scale groups per output channel
|
||||
@@ -1374,7 +1562,6 @@ __global__ void MarlinMoE(
|
||||
int current_m_block, // current m block to start kernel computation from
|
||||
int max_par, // maximum parallelism
|
||||
int cfg_max_m_blocks // upper bound on m blocks
|
||||
|
||||
) {
|
||||
// Marlin is not implemented yet for SM < 8.0
|
||||
assert(false);
|
||||
@@ -1389,37 +1576,41 @@ __global__ void MarlinMoE(
|
||||
const int USER_THREADS =
|
||||
256; // Note: This is only used with user-provided thread_k/n
|
||||
const int STAGES = 4; // 4 pipeline stages fit into shared memory
|
||||
// const int SHARED_MEM =
|
||||
// 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0)
|
||||
|
||||
static constexpr int min_thread_n = 64;
|
||||
static constexpr int min_thread_k = 64;
|
||||
|
||||
#define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \
|
||||
GROUP_BLOCKS, NUM_THREADS) \
|
||||
HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \
|
||||
else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \
|
||||
thread_k_blocks == THREAD_K_BLOCKS && \
|
||||
has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \
|
||||
num_threads == NUM_THREADS) { \
|
||||
has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
|
||||
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
|
||||
cudaFuncSetAttribute( \
|
||||
MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
||||
STAGES, HAS_ACT_ORDER, GROUP_BLOCKS>, \
|
||||
STAGES, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
||||
MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
||||
STAGES, HAS_ACT_ORDER, GROUP_BLOCKS> \
|
||||
STAGES, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS> \
|
||||
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
|
||||
A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \
|
||||
g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \
|
||||
zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \
|
||||
num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \
|
||||
replicate_input, apply_weights, m_block, max_par, \
|
||||
cfg_max_m_blocks); \
|
||||
}
|
||||
|
||||
#define GPTQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
|
||||
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
|
||||
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
|
||||
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
|
||||
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
|
||||
// Expands to the chain of __CALL_IF_MOE dispatch branches used for GPTQ
// weights. Every branch passes HAS_ZP = false. The first branch enables
// act_order with GROUP_BLOCKS = 0; the remaining branches disable act_order
// and cover GROUP_BLOCKS = -1, 2, 4 and 8. Because __CALL_IF_MOE begins with
// `else if`, this macro must follow an existing `if` statement.
#define GPTQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
|
||||
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
|
||||
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
|
||||
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
|
||||
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)
|
||||
|
||||
// Expands to the chain of __CALL_IF_MOE dispatch branches used for AWQ
// weights. Every branch passes HAS_ACT_ORDER = false and HAS_ZP = true, and
// the branches cover GROUP_BLOCKS = -1, 2, 4 and 8. There is no
// GROUP_BLOCKS = 0 branch, as that value is reserved for act_order. Because
// __CALL_IF_MOE begins with `else if`, this macro must follow an existing
// `if` statement.
#define AWQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
|
||||
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
|
||||
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
|
||||
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
|
||||
|
||||
} // namespace marlin_moe
|
||||
|
||||
Reference in New Issue
Block a user