[Model][Quantization] HQQ support through Marlin kernel expansion (#9766)
Signed-off-by: ElizaWszola <eliza@neuralmagic.com>
This commit is contained in:
@@ -54,9 +54,10 @@ template <typename scalar_t, // compute dtype, half or nv_bfloat16
|
||||
const int thread_k_blocks, // same for k dimension (reduction)
|
||||
const int stages, // number of stages for the async global->shared
|
||||
// fetch pipeline
|
||||
const bool has_act_order, // whether act_order is enabled
|
||||
const int group_blocks = -1 // number of consecutive 16x16 blocks
|
||||
// with a separate quantization scale
|
||||
const bool has_act_order, // whether act_order is enabled
|
||||
const int group_blocks = -1, // number of consecutive 16x16 blocks
|
||||
// with a separate quantization scale
|
||||
const bool is_zp_float // is zero point of float16 type?
|
||||
>
|
||||
__global__ void Marlin(
|
||||
const int4* __restrict__ A, // fp16 input matrix of shape mxk
|
||||
@@ -82,7 +83,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
torch::Tensor& workspace,
|
||||
vllm::ScalarTypeId const b_q_type_id,
|
||||
int64_t size_m, int64_t size_n, int64_t size_k,
|
||||
bool is_k_full, bool has_zp) {
|
||||
bool is_k_full, bool has_zp, bool is_zp_float) {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false,
|
||||
"marlin_gemm(..) requires CUDA_ARCH >= 8.0");
|
||||
return torch::empty({1, 1});
|
||||
@@ -516,10 +517,11 @@ template <typename scalar_t, // compute dtype, half or nv_bfloat16
|
||||
const int thread_k_blocks, // same for k dimension (reduction)
|
||||
const int stages, // number of stages for the async global->shared
|
||||
// fetch pipeline
|
||||
const bool has_act_order, // whether act_order is enabled
|
||||
const bool has_zp, // whether zero-points are enabled
|
||||
const int group_blocks = -1 // number of consecutive 16x16 blocks
|
||||
// with a separate quantization scale
|
||||
const bool has_act_order, // whether act_order is enabled
|
||||
const bool has_zp, // whether zero-points are enabled
|
||||
const int group_blocks = -1, // number of consecutive 16x16 blocks
|
||||
// with a separate quantization scale
|
||||
const bool is_zp_float // is zero point of float16 type?
|
||||
>
|
||||
__global__ void Marlin(
|
||||
const int4* __restrict__ A, // fp16 input matrix of shape mxk
|
||||
@@ -692,8 +694,10 @@ __global__ void Marlin(
|
||||
int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;
|
||||
|
||||
// Zero-points sizes/strides
|
||||
int zp_gl_stride = (prob_n / pack_factor) / 4;
|
||||
constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4;
|
||||
int zp_gl_stride = is_zp_float ? prob_n / 8 : (prob_n / pack_factor) / 4;
|
||||
constexpr int zp_sh_stride = is_zp_float
|
||||
? 16 * thread_n_blocks / 8
|
||||
: ((16 * thread_n_blocks) / pack_factor) / 4;
|
||||
constexpr int zp_tb_groups = s_tb_groups;
|
||||
constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0;
|
||||
int zp_gl_rd_delta = zp_gl_stride;
|
||||
@@ -768,9 +772,16 @@ __global__ void Marlin(
|
||||
constexpr int num_ints_per_thread = 8 / pack_factor;
|
||||
int zp_sh_rd;
|
||||
if constexpr (has_zp) {
|
||||
zp_sh_rd = num_ints_per_thread * num_col_threads *
|
||||
((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||
num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads);
|
||||
if constexpr (is_zp_float) {
|
||||
if constexpr (group_blocks != -1) {
|
||||
zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||
(threadIdx.x % 32) / 4;
|
||||
}
|
||||
} else {
|
||||
zp_sh_rd = num_ints_per_thread * num_col_threads *
|
||||
((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||
num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads);
|
||||
}
|
||||
}
|
||||
|
||||
// Precompute which thread should not read memory in which iterations; this is
|
||||
@@ -832,6 +843,7 @@ __global__ void Marlin(
|
||||
FragS act_frag_s[2][4][4]; // For act-order
|
||||
int frag_qzp[2][num_ints_per_thread]; // Zero-points
|
||||
FragZP frag_zp; // Zero-points in fp16
|
||||
FragZP frag_zpf[2]; // Zero-points in fp16 in HQQ
|
||||
|
||||
// Zero accumulators.
|
||||
auto zero_accums = [&]() {
|
||||
@@ -1126,7 +1138,7 @@ __global__ void Marlin(
|
||||
// has_zp implies AWQ or HQQ, neither of which has act_order (act_order kernels use group_blocks == 0)
|
||||
static_assert(!has_zp || group_blocks != 0);
|
||||
|
||||
if constexpr (has_zp) {
|
||||
if constexpr (has_zp && !is_zp_float) {
|
||||
int pipe = full_pipe % stages;
|
||||
|
||||
if constexpr (group_blocks == -1) {
|
||||
@@ -1170,11 +1182,44 @@ __global__ void Marlin(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
else if constexpr (has_zp && is_zp_float) {
|
||||
int pipe = full_pipe % stages;
|
||||
|
||||
if constexpr (group_blocks != -1) {
|
||||
if constexpr (group_blocks >= thread_k_blocks) {
|
||||
int4* sh_zp_stage =
|
||||
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
|
||||
(pipe / (group_blocks / thread_k_blocks)));
|
||||
reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd];
|
||||
} else {
|
||||
int warp_id = threadIdx.x / 32;
|
||||
int n_warps = thread_n_blocks / 4;
|
||||
|
||||
int warp_row = warp_id / n_warps;
|
||||
|
||||
int cur_k = warp_row * 16;
|
||||
cur_k += k_iter_size * (k % b_sh_wr_iters);
|
||||
|
||||
int k_blocks = cur_k / 16;
|
||||
// Suppress bogus and persistent divide-by-zero warning
|
||||
#pragma nv_diagnostic push
|
||||
#pragma nv_diag_suppress divide_by_zero
|
||||
int cur_group_id = k_blocks / group_blocks;
|
||||
#pragma nv_diagnostic pop
|
||||
|
||||
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
|
||||
|
||||
reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] =
|
||||
sh_zp_stage[zp_sh_rd + cur_group_id * zp_sh_stride];
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Execute the actual tensor core matmul of a sub-tile.
|
||||
auto matmul = [&](int k) {
|
||||
if constexpr (has_zp) {
|
||||
if constexpr (has_zp && !is_zp_float) {
|
||||
FragB frag_zp_0;
|
||||
FragB frag_zp_1;
|
||||
int zp_quant_0, zp_quant_1;
|
||||
@@ -1219,10 +1264,14 @@ __global__ void Marlin(
|
||||
frag_b1 = dequant<scalar_t, w_type_id>(b_quant_1);
|
||||
|
||||
// Apply zero-point to frag_b0
|
||||
if constexpr (has_zp) {
|
||||
if constexpr (has_zp && !is_zp_float) {
|
||||
sub_zp<scalar_t>(frag_b0, frag_zp[j], 0);
|
||||
}
|
||||
|
||||
else if constexpr (has_zp && is_zp_float && group_blocks != -1) {
|
||||
sub_zp<scalar_t>(frag_b0, frag_zpf[k % 2][j], 0);
|
||||
}
|
||||
|
||||
// Apply scale to frag_b0
|
||||
if constexpr (has_act_order) {
|
||||
scale4<scalar_t>(frag_b0, act_frag_s[k % 2][0][j],
|
||||
@@ -1235,10 +1284,14 @@ __global__ void Marlin(
|
||||
}
|
||||
|
||||
// Apply zero-point to frag_b1
|
||||
if constexpr (has_zp) {
|
||||
if constexpr (has_zp && !is_zp_float) {
|
||||
sub_zp<scalar_t>(frag_b1, frag_zp[j], 1);
|
||||
}
|
||||
|
||||
else if constexpr (has_zp && is_zp_float && group_blocks != -1) {
|
||||
sub_zp<scalar_t>(frag_b1, frag_zpf[k % 2][j], 1);
|
||||
}
|
||||
|
||||
// Apply scale to frag_b1
|
||||
if constexpr (has_act_order) {
|
||||
scale4<scalar_t>(frag_b1, act_frag_s[k % 2][0][j],
|
||||
@@ -1510,7 +1563,7 @@ __global__ void Marlin(
|
||||
fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]);
|
||||
}
|
||||
|
||||
if constexpr (has_zp && group_blocks == -1) {
|
||||
if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
|
||||
if (i == 0) {
|
||||
fetch_zp_to_shared();
|
||||
}
|
||||
@@ -1697,23 +1750,27 @@ __global__ void Marlin(
|
||||
}
|
||||
|
||||
#define __CALL_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
||||
HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \
|
||||
HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS, \
|
||||
IS_ZP_FLOAT) \
|
||||
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
|
||||
thread_n_blocks == THREAD_N_BLOCKS && \
|
||||
thread_k_blocks == THREAD_K_BLOCKS && \
|
||||
has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
|
||||
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
|
||||
cudaFuncSetAttribute( \
|
||||
Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
|
||||
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
|
||||
HAS_ZP, GROUP_BLOCKS>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
||||
Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
|
||||
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
|
||||
HAS_ZP, GROUP_BLOCKS> \
|
||||
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
|
||||
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \
|
||||
num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \
|
||||
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
|
||||
is_zp_float == IS_ZP_FLOAT) { \
|
||||
if constexpr (!IS_ZP_FLOAT || std::is_same<scalar_t, half>::value) { \
|
||||
cudaFuncSetAttribute( \
|
||||
Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
|
||||
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, \
|
||||
HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
||||
Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
|
||||
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
|
||||
HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT> \
|
||||
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
|
||||
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \
|
||||
num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \
|
||||
} \
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
@@ -1905,51 +1962,96 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
|
||||
}
|
||||
|
||||
#define GPTQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \
|
||||
false) \
|
||||
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \
|
||||
false) \
|
||||
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \
|
||||
false) \
|
||||
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \
|
||||
false) \
|
||||
\
|
||||
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
|
||||
false) \
|
||||
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \
|
||||
false) \
|
||||
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \
|
||||
false) \
|
||||
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \
|
||||
false) \
|
||||
\
|
||||
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
|
||||
false) \
|
||||
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \
|
||||
false) \
|
||||
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \
|
||||
false) \
|
||||
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \
|
||||
false) \
|
||||
\
|
||||
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
|
||||
false) \
|
||||
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \
|
||||
false) \
|
||||
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \
|
||||
false) \
|
||||
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \
|
||||
false) \
|
||||
\
|
||||
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)
|
||||
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \
|
||||
false) \
|
||||
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \
|
||||
false) \
|
||||
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \
|
||||
false) \
|
||||
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \
|
||||
false)
|
||||
|
||||
#define AWQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
|
||||
false) \
|
||||
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \
|
||||
false) \
|
||||
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
|
||||
false) \
|
||||
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \
|
||||
false) \
|
||||
\
|
||||
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
|
||||
false) \
|
||||
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \
|
||||
false) \
|
||||
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
|
||||
false) \
|
||||
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \
|
||||
false) \
|
||||
\
|
||||
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
|
||||
false) \
|
||||
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \
|
||||
false) \
|
||||
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
|
||||
false) \
|
||||
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \
|
||||
false) \
|
||||
\
|
||||
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
|
||||
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \
|
||||
false) \
|
||||
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \
|
||||
false) \
|
||||
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
|
||||
false) \
|
||||
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, false)
|
||||
|
||||
// HQQ dispatch table: expands to one __CALL_IF branch for each
// thread_m_blocks value from 1 to 4, with has_act_order = false,
// has_zp = true, group_blocks = 4 and is_zp_float = true.
// We currently have 4-bit models only with group_blocks == 4, so no other
// group_blocks values are instantiated here.
// Note: __CALL_IF only sets up and launches the kernel for is_zp_float
// branches when scalar_t is half.
|
||||
#define HQQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
|
||||
true) \
|
||||
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
|
||||
true) \
|
||||
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \
|
||||
true) \
|
||||
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, true)
|
||||
|
||||
template <typename scalar_t>
|
||||
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||
@@ -1958,7 +2060,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||
vllm::ScalarType const& q_type, bool has_act_order,
|
||||
bool is_k_full, bool has_zp, int num_groups, int group_size,
|
||||
int dev, cudaStream_t stream, int thread_k, int thread_n,
|
||||
int sms, int max_par, bool use_fp32_reduce) {
|
||||
int sms, int max_par, bool use_fp32_reduce, bool is_zp_float) {
|
||||
if (has_zp) {
|
||||
TORCH_CHECK(
|
||||
q_type == vllm::kU4 || q_type == vllm::kU8,
|
||||
@@ -2111,6 +2213,11 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||
AWQ_CALL_IF(vllm::kU8, 8, 8, 256)
|
||||
AWQ_CALL_IF(vllm::kU8, 8, 4, 128)
|
||||
AWQ_CALL_IF(vllm::kU8, 4, 8, 128)
|
||||
|
||||
HQQ_CALL_IF(vllm::kU4, 16, 4, 256)
|
||||
HQQ_CALL_IF(vllm::kU4, 8, 8, 256)
|
||||
HQQ_CALL_IF(vllm::kU4, 8, 4, 128)
|
||||
HQQ_CALL_IF(vllm::kU4, 4, 8, 128)
|
||||
else {
|
||||
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
|
||||
", ", prob_k, "]", ", has_act_order = ", has_act_order,
|
||||
@@ -2135,7 +2242,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
vllm::ScalarTypeId const& b_q_type_id,
|
||||
int64_t size_m, int64_t size_n, int64_t size_k,
|
||||
bool is_k_full, bool has_zp,
|
||||
bool use_fp32_reduce) {
|
||||
bool use_fp32_reduce, bool is_zp_float) {
|
||||
vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
|
||||
if (has_zp) {
|
||||
TORCH_CHECK(
|
||||
@@ -2148,6 +2255,12 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
b_q_type.str());
|
||||
}
|
||||
|
||||
if (has_zp && is_zp_float) {
|
||||
TORCH_CHECK(a.scalar_type() == at::ScalarType::Half,
|
||||
"Computation type must be float16 (half) when using float zero "
|
||||
"points.");
|
||||
}
|
||||
|
||||
int pack_factor = 32 / b_q_type.size_bits();
|
||||
|
||||
// Verify A
|
||||
@@ -2257,12 +2370,22 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
if (has_zp) {
|
||||
int rank = b_zeros.sizes().size();
|
||||
TORCH_CHECK(rank == 2, "b_zeros rank = ", rank, " is not 2");
|
||||
TORCH_CHECK(b_zeros.size(0) == num_groups,
|
||||
"b_zeros dim 0 = ", b_zeros.size(0),
|
||||
" is not num_groups = ", num_groups);
|
||||
TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor,
|
||||
"b_zeros dim 1 = ", b_zeros.size(1),
|
||||
" is not size_n / pack_factor = ", size_n / pack_factor);
|
||||
if (is_zp_float) {
|
||||
TORCH_CHECK(b_zeros.size(1) == size_n,
|
||||
"b_zeros dim 1 = ", b_zeros.size(1),
|
||||
" is not size_n = ", size_n);
|
||||
TORCH_CHECK(num_groups == b_zeros.size(0),
|
||||
"b_zeros dim 0 = ", b_zeros.size(0),
|
||||
" is not num_groups = ", num_groups);
|
||||
TORCH_CHECK(num_groups != -1, "num_groups must be != -1");
|
||||
} else {
|
||||
TORCH_CHECK(b_zeros.size(0) == num_groups,
|
||||
"b_zeros dim 0 = ", b_zeros.size(0),
|
||||
" is not num_groups = ", num_groups);
|
||||
TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor,
|
||||
"b_zeros dim 1 = ", b_zeros.size(1),
|
||||
" is not size_n / pack_factor = ", size_n / pack_factor);
|
||||
}
|
||||
}
|
||||
|
||||
// Verify workspace size
|
||||
@@ -2282,7 +2405,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
|
||||
workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
|
||||
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
|
||||
thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
|
||||
thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_zp_float);
|
||||
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
|
||||
marlin::marlin_mm<nv_bfloat16>(
|
||||
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
|
||||
@@ -2291,7 +2414,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
|
||||
workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
|
||||
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
|
||||
thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
|
||||
thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_zp_float);
|
||||
} else {
|
||||
TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user