[Kernel] Add GPTQv2 format support for low-bit or asymmetric quantization, by adapting gptq_gemm (#26092)

This commit is contained in:
Xiangyu Li
2025-10-24 11:26:13 +08:00
committed by GitHub
parent 1f9460c4c1
commit 5cc6bddb6e
8 changed files with 295 additions and 98 deletions

View File

@@ -307,7 +307,7 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
bool use_exllama, int64_t bit);
bool use_exllama, bool use_v2_format, int64_t bit);
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);

View File

@@ -185,7 +185,7 @@ typedef void (*fp_gemm_half_q_half_gptq_kernel)(const half*, const uint32_t*,
const uint32_t*, const half*,
half*, const int, const int,
const int, const int,
const int*);
const bool, const int*);
template <bool first_block, int m_count>
__global__ void gemm_half_q_half_gptq_4bit_kernel(
@@ -193,12 +193,15 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel(
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales, half* __restrict__ c,
const int size_m, const int size_n, const int size_k, const int groups,
const int* __restrict__ b_q_perm) {
const bool use_v2_format, const int* __restrict__ b_q_perm) {
MatrixView_half a_(a, size_m, size_k);
MatrixView_half_rw c_(c, size_m, size_n);
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
// GPTQv2 and GPTQv1 handles zero points differently
int zero_offset = use_v2_format ? 0 : 1;
auto t = threadIdx.x;
// Block
@@ -256,10 +259,10 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel(
half2 y1y16[4][2];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_f(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
dequant_4bit_8_prep_zero(zeros[0] + zero_offset, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + zero_offset, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + zero_offset, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + zero_offset, z1z16[3], y1y16[3]);
// Column result
float block_c[m_count][4] = {};
@@ -272,10 +275,10 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel(
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_f(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
dequant_4bit_8_prep_zero(zeros[0] + zero_offset, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + zero_offset, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + zero_offset, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + zero_offset, z1z16[3], y1y16[3]);
}
#pragma unroll
@@ -329,12 +332,15 @@ __global__ void gemm_half_q_half_gptq_2bit_kernel(
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales, half* __restrict__ c,
const int size_m, const int size_n, const int size_k, const int groups,
const int* __restrict__ b_q_perm) {
const bool use_v2_format, const int* __restrict__ b_q_perm) {
MatrixView_half a_(a, size_m, size_k);
MatrixView_half_rw c_(c, size_m, size_n);
MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
// GPTQv2 and GPTQv1 handles zero points differently
int zero_offset = use_v2_format ? 0 : 1;
auto t = threadIdx.x;
// Block
@@ -409,10 +415,10 @@ __global__ void gemm_half_q_half_gptq_2bit_kernel(
int4 load_int4 = *b_ptr4;
half2 dq[4][8];
dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1);
dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1);
dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1);
dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1);
dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + zero_offset);
dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + zero_offset);
dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + zero_offset);
dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + zero_offset);
#pragma unroll
for (int m = 0; m < m_count; m++) {
@@ -448,12 +454,15 @@ __global__ void gemm_half_q_half_gptq_3bit_kernel(
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales, half* __restrict__ c,
const int size_m, const int size_n, const int size_k, const int groups,
const int* __restrict__ b_q_perm) {
const bool use_v2_format, const int* __restrict__ b_q_perm) {
MatrixView_half a_(a, size_m, size_k);
MatrixView_half_rw c_(c, size_m, size_n);
MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
// GPTQv2 and GPTQv1 handles zero points differently
int zero_offset = use_v2_format ? 0 : 1;
auto t = threadIdx.x;
// Block
@@ -534,13 +543,13 @@ __global__ void gemm_half_q_half_gptq_3bit_kernel(
half2 dq[4][16];
dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0],
size_n, zeros[0] + 1);
size_n, zeros[0] + zero_offset);
dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1],
size_n, zeros[1] + 1);
size_n, zeros[1] + zero_offset);
dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2],
size_n, zeros[2] + 1);
size_n, zeros[2] + zero_offset);
dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3],
size_n, zeros[3] + 1);
size_n, zeros[3] + zero_offset);
#pragma unroll
for (int m = 0; m < m_count; m++) {
@@ -574,12 +583,15 @@ __global__ void gemm_half_q_half_gptq_8bit_kernel(
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales, half* __restrict__ c,
const int size_m, const int size_n, const int size_k, const int groups,
const int* __restrict__ b_q_perm) {
const bool use_v2_format, const int* __restrict__ b_q_perm) {
MatrixView_half a_(a, size_m, size_k);
MatrixView_half_rw c_(c, size_m, size_n);
MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
// GPTQv2 and GPTQv1 handles zero points differently
int zero_offset = use_v2_format ? 0 : 1;
auto t = threadIdx.x;
// Block
@@ -658,13 +670,13 @@ __global__ void gemm_half_q_half_gptq_8bit_kernel(
half2 dq[4][4];
dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n,
zeros[0] + 1);
zeros[0] + zero_offset);
dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n,
zeros[1] + 1);
zeros[1] + zero_offset);
dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n,
zeros[2] + 1);
zeros[2] + zero_offset);
dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n,
zeros[3] + 1);
zeros[3] + zero_offset);
for (int m = 0; m < m_count; m++) {
block_c[m][0] =
@@ -730,7 +742,8 @@ void gemm_half_q_half_cuda_part(const half* a, const uint32_t* b_q_weight,
const uint32_t* b_gptq_qzeros,
const half* b_gptq_scales, const int* b_q_perm,
half* c, int size_m, int size_n, int size_k,
int m_count, int groups, int bit) {
int m_count, int groups, bool use_v2_format,
int bit) {
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1;
@@ -743,20 +756,23 @@ void gemm_half_q_half_cuda_part(const half* a, const uint32_t* b_q_weight,
pick_gemm_half_q_half_gptq_kernel(true, m_count, bit);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
kernel<<<gridDim, blockDim, 0, stream>>>(a, b_q_weight, b_gptq_qzeros,
b_gptq_scales, c, size_m, size_n,
size_k, groups, b_q_perm);
kernel<<<gridDim, blockDim, 0, stream>>>(
a, b_q_weight, b_gptq_qzeros, b_gptq_scales, c, size_m, size_n, size_k,
groups, use_v2_format, b_q_perm);
}
__global__ void reconstruct_exllama_8bit_kernel(
const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales, const int size_k, const int size_n,
const int groups, half* __restrict__ b) {
const int groups, const bool use_v2_format, half* __restrict__ b) {
MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
// GPTQv2 and GPTQv1 handles zero points differently
int zero_offset = use_v2_format ? 0 : 1;
auto offset_k = BLOCK_KN_SIZE * blockIdx.y;
auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
@@ -812,13 +828,13 @@ __global__ void reconstruct_exllama_8bit_kernel(
half2 dq[4][4];
dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n,
zeros[0] + 1);
zeros[0] + zero_offset);
dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n,
zeros[1] + 1);
zeros[1] + zero_offset);
dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n,
zeros[2] + 1);
zeros[2] + zero_offset);
dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n,
zeros[3] + 1);
zeros[3] + zero_offset);
// half* dqh = (half*)dq;
if (b_q_perm) {
@@ -849,11 +865,14 @@ __global__ void reconstruct_exllama_4bit_kernel(
const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales, const int size_k, const int size_n,
const int groups, half* __restrict__ b) {
const int groups, const bool use_v2_format, half* __restrict__ b) {
MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
// GPTQv2 and GPTQv1 handles zero points differently
int zero_offset = use_v2_format ? 0 : 1;
auto offset_k = BLOCK_KN_SIZE * blockIdx.y;
auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
@@ -888,10 +907,10 @@ __global__ void reconstruct_exllama_4bit_kernel(
half2 y1y16[4][2];
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
dequant_4bit_8_prep_zero(zeros[0] + zero_offset, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + zero_offset, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + zero_offset, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + zero_offset, z1z16[3], y1y16[3]);
__syncthreads();
@@ -904,10 +923,10 @@ __global__ void reconstruct_exllama_4bit_kernel(
nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
dequant_4bit_8_prep_zero(zeros[0] + zero_offset, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + zero_offset, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + zero_offset, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + zero_offset, z1z16[3], y1y16[3]);
}
for (int p = 0; p < 4; p++) {
@@ -954,11 +973,14 @@ __global__ void reconstruct_exllama_3bit_kernel(
const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales, const int size_k, const int size_n,
const int groups, half* __restrict__ b) {
const int groups, const bool use_v2_format, half* __restrict__ b) {
MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
// GPTQv2 and GPTQv1 handles zero points differently
int zero_offset = use_v2_format ? 0 : 1;
auto offset_k = BLOCK_KN_SIZE * blockIdx.y;
auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
@@ -1016,13 +1038,13 @@ __global__ void reconstruct_exllama_3bit_kernel(
half2 dq[4][16];
dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0],
size_n, zeros[0] + 1);
size_n, zeros[0] + zero_offset);
dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1],
size_n, zeros[1] + 1);
size_n, zeros[1] + zero_offset);
dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2],
size_n, zeros[2] + 1);
size_n, zeros[2] + zero_offset);
dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3],
size_n, zeros[3] + 1);
size_n, zeros[3] + zero_offset);
if (b_q_perm) {
for (int j = 0; j < 16; j++) {
@@ -1052,11 +1074,14 @@ __global__ void reconstruct_exllama_2bit_kernel(
const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales, const int size_k, const int size_n,
const int groups, half* __restrict__ b) {
const int groups, const bool use_v2_format, half* __restrict__ b) {
MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
// GPTQv2 and GPTQv1 handles zero points differently
int zero_offset = use_v2_format ? 0 : 1;
auto offset_k = BLOCK_KN_SIZE * blockIdx.y;
auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
@@ -1108,10 +1133,10 @@ __global__ void reconstruct_exllama_2bit_kernel(
int4 load_int4 = *b_ptr4;
half2 dq[4][8];
dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1);
dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1);
dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1);
dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1);
dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + zero_offset);
dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + zero_offset);
dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + zero_offset);
dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + zero_offset);
b_ptr += size_n;
// half* dqh = (half*)dq;
@@ -1143,7 +1168,7 @@ void reconstruct_exllama(const uint32_t* b_q_weight,
const uint32_t* b_gptq_qzeros,
const half* b_gptq_scales, const int* b_q_perm,
half* out, int height, int width, int groups,
int bit) {
bool use_v2_format, int bit) {
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1;
@@ -1162,14 +1187,14 @@ void reconstruct_exllama(const uint32_t* b_q_weight,
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
reconstruct_exllama_kernel<<<gridDim, blockDim, 0, stream>>>(
b_q_weight, b_q_perm, b_gptq_qzeros, b_gptq_scales, height, width, groups,
out);
use_v2_format, out);
}
__global__ void gemm_half_q_half_alt_4bit_kernel(
const half2* __restrict__ vec, const uint32_t* __restrict__ mat,
half* __restrict__ mul, const half* __restrict__ scales,
const uint32_t* __restrict__ zeros, const int* __restrict__ g_idx,
int batch, int height, int width) {
int batch, int height, int width, bool use_v2_format) {
int zero_width = width / 8;
int vec_height = height * 4;
const int blockwidth2 = BLOCK_KN_SIZE / 2;
@@ -1179,6 +1204,9 @@ __global__ void gemm_half_q_half_alt_4bit_kernel(
int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4;
auto w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
// GPTQv2 and GPTQv1 handles zero points differently
int zero_offset = use_v2_format ? 0 : 1;
__shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2];
if (threadIdx.x < h_end) {
for (int m = 0; m < b_end; ++m) {
@@ -1223,10 +1251,11 @@ __global__ void gemm_half_q_half_alt_4bit_kernel(
half2 zero = __halves2half2(
__hmul(scale_f,
__int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xF) -
1)),
__hmul(scale_f2,
__int2half_rn(
-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) - 1)));
zero_offset)),
__hmul(
scale_f2,
__int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) -
zero_offset)));
scales_tmp[tmp_k] = scale;
zeros_tmp[tmp_k] = zero;
}
@@ -1268,7 +1297,7 @@ __global__ void gemm_half_q_half_alt_8bit_kernel(
const half2* __restrict__ vec, const uint32_t* __restrict__ mat,
half* __restrict__ mul, const half* __restrict__ scales,
const uint32_t* __restrict__ zeros, const int* __restrict__ g_idx,
int batch, int height, int width) {
int batch, int height, int width, bool use_v2_format) {
int zero_width = width / 4;
int vec_height = height * 2;
const int blockwidth2 = BLOCK_KN_SIZE / 2;
@@ -1278,6 +1307,9 @@ __global__ void gemm_half_q_half_alt_8bit_kernel(
int h_end = min(BLOCK_KN_SIZE / 4, height - h) * 2;
auto w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
// GPTQv2 and GPTQv1 handles zero points differently
int zero_offset = use_v2_format ? 0 : 1;
__shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2];
if (threadIdx.x < h_end) {
for (int m = 0; m < b_end; ++m) {
@@ -1312,12 +1344,13 @@ __global__ void gemm_half_q_half_alt_8bit_kernel(
half scale_f2 = scales[g2 * width + w];
half2 scale = __halves2half2(scale_f, scale_f2);
half2 zero = __halves2half2(
__hmul(scale_f,
__int2half_rn(
-((zeros[g * zero_width + z_w] >> z_mod) & 0xff) - 1)),
__hmul(scale_f2,
__int2half_rn(
-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xff) - 1)));
__hmul(scale_f, __int2half_rn(
-((zeros[g * zero_width + z_w] >> z_mod) & 0xff) -
zero_offset)),
__hmul(
scale_f2,
__int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xff) -
zero_offset)));
scales_tmp[tmp_k] = scale;
zeros_tmp[tmp_k] = zero;
}
@@ -1355,7 +1388,7 @@ void gemm_half_q_half_alt(const half* a, const uint32_t* b_q_weight,
const uint32_t* b_gptq_qzeros,
const half* b_gptq_scales, const int* b_g_idx,
half* c, int size_m, int size_n, int size_k,
int bit) {
bool use_v2_format, int bit) {
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1;
@@ -1372,17 +1405,15 @@ void gemm_half_q_half_alt(const half* a, const uint32_t* b_q_weight,
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
kernel<<<gridDim, blockDim, 0, stream>>>(
(const half2*)a, b_q_weight, c, b_gptq_scales, b_gptq_qzeros, b_g_idx,
size_m, size_k / 32 * bit, size_n);
size_m, size_k / 32 * bit, size_n, use_v2_format);
}
template <class T, int bit>
__global__ void reconstruct_gptq_kernel(const uint32_t* __restrict__ w,
const half* __restrict__ w_scales,
const uint32_t* __restrict__ w_zeros,
const int* __restrict__ g_idx,
const int height, const int width,
const int group,
half* __restrict__ out) {
__global__ void reconstruct_gptq_kernel(
const uint32_t* __restrict__ w, const half* __restrict__ w_scales,
const uint32_t* __restrict__ w_zeros, const int* __restrict__ g_idx,
const int height, const int width, const int group,
const bool use_v2_format, half* __restrict__ out) {
// Start of block
auto column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
@@ -1395,6 +1426,9 @@ __global__ void reconstruct_gptq_kernel(const uint32_t* __restrict__ w,
MatrixView_half w_scales_(w_scales, group, width);
T w_zeros_(w_zeros, group, width);
// GPTQv2 and GPTQv1 handles zero points differently
int zero_offset = use_v2_format ? 0 : 1;
uint32_t w_read = w[blockIdx.y * width + column];
half* out_ptr = out_.item_ptr(row, column);
@@ -1402,7 +1436,7 @@ __global__ void reconstruct_gptq_kernel(const uint32_t* __restrict__ w,
for (int s = 0; s < 32; s += bit) {
int group = g_idx[row + s / bit];
half w_scale = w_scales_.item(group, column);
uint32_t w_zero = w_zeros_.item(group, column) + 1;
uint32_t w_zero = w_zeros_.item(group, column) + zero_offset;
half w_item =
__hmul(__int2half_rn((int)((w_read >> s) & ((1 << bit) - 1)) - w_zero),
w_scale);
@@ -1415,7 +1449,7 @@ __global__ void reconstruct_gptq_3bit_kernel(
const uint32_t* __restrict__ w, const half* __restrict__ w_scales,
const uint32_t* __restrict__ w_zeros, const int* __restrict__ g_idx,
const int height, const int width, const int group,
half* __restrict__ out) {
const bool use_v2_format, half* __restrict__ out) {
// Start of block
auto column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
auto row = blockIdx.y * 32;
@@ -1427,6 +1461,9 @@ __global__ void reconstruct_gptq_3bit_kernel(
MatrixView_half w_scales_(w_scales, group, width);
MatrixView_q3_row w_zeros_(w_zeros, group, width);
// GPTQv2 and GPTQv1 handles zero points differently
int zero_offset = use_v2_format ? 0 : 1;
uint32_t w1 = w[(blockIdx.y * 3) * width + column];
uint32_t w2 = w[(blockIdx.y * 3 + 1) * width + column];
uint32_t w3 = w[(blockIdx.y * 3 + 2) * width + column];
@@ -1436,7 +1473,7 @@ __global__ void reconstruct_gptq_3bit_kernel(
for (int i = 0; i < 32; i += 1) {
int group = g_idx[row + i];
half w_scale = w_scales_.item(group, column);
uint32_t w_zero = w_zeros_.item(group, column) + 1;
uint32_t w_zero = w_zeros_.item(group, column) + zero_offset;
int w_item;
if (i == 10) {
w_item = (w1 >> 30) | ((w2 << 2) & 0x4);
@@ -1456,7 +1493,8 @@ __global__ void reconstruct_gptq_3bit_kernel(
void reconstruct_gptq(const uint32_t* b_q_weight, const uint32_t* b_gptq_qzeros,
const half* b_gptq_scales, const int* b_g_idx, half* out,
int height, int width, int groups, int bit) {
int height, int width, int groups, bool use_v2_format,
int bit) {
dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1;
@@ -1476,7 +1514,7 @@ void reconstruct_gptq(const uint32_t* b_q_weight, const uint32_t* b_gptq_qzeros,
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
kernel<<<gridDim, blockDim, 0, stream>>>(b_q_weight, b_gptq_scales,
b_gptq_qzeros, b_g_idx, height,
width, groups, out);
width, groups, use_v2_format, out);
}
void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a,
@@ -1484,7 +1522,8 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a,
const uint32_t* b_gptq_qzeros,
const half* b_gptq_scales, const int* b_g_idx,
half* c, half* temp_dq, int size_m, int size_n,
int size_k, int groups, bool use_exllama, int bit) {
int size_k, int groups, bool use_exllama,
bool use_v2_format, int bit) {
bool use_reconstruct;
if (use_exllama) {
use_reconstruct = ((bit == 8 && size_m > MAX_Q_GEMM_ROWS_8BIT) ||
@@ -1498,10 +1537,10 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a,
// Reconstruct FP16 matrix, then cuBLAS
if (use_exllama) {
reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
temp_dq, size_k, size_n, groups, bit);
temp_dq, size_k, size_n, groups, use_v2_format, bit);
} else {
reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
temp_dq, size_k, size_n, groups, bit);
temp_dq, size_k, size_n, groups, use_v2_format, bit);
}
const half alpha = __float2half(1.0f);
@@ -1517,18 +1556,18 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a,
if (max_chunks) {
gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
b_g_idx, c, last_chunk, size_n, size_k,
BLOCK_M_SIZE_MAX, groups, bit);
BLOCK_M_SIZE_MAX, groups, use_v2_format, bit);
}
if (last_chunk_size) {
gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight,
b_gptq_qzeros, b_gptq_scales, b_g_idx,
c + last_chunk * size_n, last_chunk_size,
size_n, size_k, last_chunk_size, groups, bit);
gemm_half_q_half_cuda_part(
a + last_chunk * size_k, b_q_weight, b_gptq_qzeros, b_gptq_scales,
b_g_idx, c + last_chunk * size_n, last_chunk_size, size_n, size_k,
last_chunk_size, groups, use_v2_format, bit);
}
} else {
gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
c, size_m, size_n, size_k, bit);
c, size_m, size_n, size_k, use_v2_format, bit);
}
}
@@ -1815,7 +1854,7 @@ void shuffle_exllama_weight(uint32_t* q_weight, int* q_perm, int height,
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
bool use_exllama, int64_t bit) {
bool use_exllama, bool use_v2_format, int64_t bit) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options);
@@ -1833,7 +1872,7 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
c.size(1), // n
a.size(1), // k
b_gptq_qzeros.size(0), // group number
use_exllama, bit);
use_exllama, use_v2_format, bit);
return c;
}

View File

@@ -557,7 +557,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// to prevent the meta function registry.
ops.def(
"gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, "
"Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, int bit) "
"Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, bool "
"use_v2_format, int bit) "
"-> Tensor",
{stride_tag});
ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);