[Kernel] Increase precision of GPTQ/AWQ Marlin kernel (#6795)
This commit is contained in:
committed by
GitHub
parent
fad5576c58
commit
75acdaa4b6
@@ -59,14 +59,16 @@ __global__ void Marlin(
|
||||
const int4* __restrict__ A, // fp16 input matrix of shape mxk
|
||||
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
|
||||
int4* __restrict__ C, // fp16 output buffer of shape mxn
|
||||
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
|
||||
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
|
||||
// (k/groupsize)xn
|
||||
const int* __restrict__ g_idx, // int32 group indices of shape k
|
||||
int num_groups, // number of scale groups per output channel
|
||||
int prob_m, // batch dimension m
|
||||
int prob_n, // output dimension n
|
||||
int prob_k, // reduction dimension k
|
||||
int* locks // extra global storage for barrier synchronization
|
||||
int num_groups, // number of scale groups per output channel
|
||||
int prob_m, // batch dimension m
|
||||
int prob_n, // output dimension n
|
||||
int prob_k, // reduction dimension k
|
||||
int* locks, // extra global storage for barrier synchronization
|
||||
bool use_fp32_reduce // whether to use fp32 global reduce
|
||||
) {}
|
||||
|
||||
} // namespace gptq_marlin
|
||||
@@ -532,16 +534,18 @@ __global__ void Marlin(
|
||||
const int4* __restrict__ A, // fp16 input matrix of shape mxk
|
||||
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
|
||||
int4* __restrict__ C, // fp16 output buffer of shape mxn
|
||||
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
|
||||
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
|
||||
// (k/groupsize)xn
|
||||
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
|
||||
// (k/groupsize)x(n/pack_factor)
|
||||
const int* __restrict__ g_idx, // int32 group indices of shape k
|
||||
int num_groups, // number of scale groups per output channel
|
||||
int prob_m, // batch dimension m
|
||||
int prob_n, // output dimension n
|
||||
int prob_k, // reduction dimension k
|
||||
int* locks // extra global storage for barrier synchronization
|
||||
int num_groups, // number of scale groups per output channel
|
||||
int prob_m, // batch dimension m
|
||||
int prob_n, // output dimension n
|
||||
int prob_k, // reduction dimension k
|
||||
int* locks, // extra global storage for barrier synchronization
|
||||
bool use_fp32_reduce // whether to use fp32 global reduce
|
||||
) {
|
||||
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
|
||||
// same size, which might involve multiple column "slices" (of width 16 *
|
||||
@@ -595,6 +599,8 @@ __global__ void Marlin(
|
||||
int slice_idx; // index of threadblock in current slice; numbered bottom to
|
||||
// top
|
||||
|
||||
int par_id = 0;
|
||||
|
||||
// We can easily implement parallel problem execution by just remapping
|
||||
// indices and advancing global pointers
|
||||
if (slice_col_par >= n_tiles) {
|
||||
@@ -602,6 +608,7 @@ __global__ void Marlin(
|
||||
C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
|
||||
locks += (slice_col_par / n_tiles) * n_tiles;
|
||||
slice_col = slice_col_par % n_tiles;
|
||||
par_id = slice_col_par / n_tiles;
|
||||
}
|
||||
|
||||
// Compute all information about the current slice which is required for
|
||||
@@ -632,6 +639,7 @@ __global__ void Marlin(
|
||||
C += 16 * thread_m_blocks * prob_n / 8;
|
||||
locks += n_tiles;
|
||||
slice_col = 0;
|
||||
par_id++;
|
||||
}
|
||||
};
|
||||
init_slice();
|
||||
@@ -1321,7 +1329,7 @@ __global__ void Marlin(
|
||||
// finally have to globally reduce over the results. As the striped
|
||||
// partitioning minimizes the number of such reductions and our outputs are
|
||||
// usually rather small, we perform this reduction serially in L2 cache.
|
||||
auto global_reduce = [&](bool first = false, bool last = false) {
|
||||
auto global_reduce_fp16 = [&](bool first = false, bool last = false) {
|
||||
// We are very careful here to reduce directly in the output buffer to
|
||||
// maximize L2 cache utilization in this step. To do this, we write out
|
||||
// results in FP16 (but still reduce with FP32 compute).
|
||||
@@ -1382,6 +1390,53 @@ __global__ void Marlin(
|
||||
}
|
||||
};
|
||||
|
||||
// Globally reduce over threadblocks that compute the same column block.
|
||||
// We use a tmp C buffer to reduce in full fp32 precision.
|
||||
auto global_reduce_fp32 = [&](bool first = false, bool last = false) {
|
||||
constexpr int tb_m = thread_m_blocks * 16;
|
||||
constexpr int tb_n = thread_n_blocks * 16;
|
||||
|
||||
constexpr int c_size = tb_m * tb_n * sizeof(float) / 16;
|
||||
|
||||
constexpr int active_threads = 32 * thread_n_blocks / 4;
|
||||
bool is_th_active = threadIdx.x < active_threads;
|
||||
|
||||
int par_offset = c_size * n_tiles * par_id;
|
||||
int slice_offset = c_size * slice_col;
|
||||
|
||||
constexpr int num_floats = thread_m_blocks * 4 * 2 * 4;
|
||||
constexpr int th_size = num_floats * sizeof(float) / 16;
|
||||
|
||||
int c_cur_offset = par_offset + slice_offset;
|
||||
|
||||
if (!is_th_active) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!first) {
|
||||
float* frag_c_ptr = reinterpret_cast<float*>(&frag_c);
|
||||
#pragma unroll
|
||||
for (int k = 0; k < th_size; k++) {
|
||||
sh[threadIdx.x] =
|
||||
C_tmp[c_cur_offset + active_threads * k + threadIdx.x];
|
||||
|
||||
float* sh_c_ptr = reinterpret_cast<float*>(&sh[threadIdx.x]);
|
||||
#pragma unroll
|
||||
for (int f = 0; f < 4; f++) {
|
||||
frag_c_ptr[k * 4 + f] += sh_c_ptr[f];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!last) {
|
||||
int4* frag_c_ptr = reinterpret_cast<int4*>(&frag_c);
|
||||
#pragma unroll
|
||||
for (int k = 0; k < th_size; k++) {
|
||||
C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Write out the reduce final result in the correct layout. We only actually
|
||||
// reshuffle matrix fragments in this step, the reduction above is performed
|
||||
// in fragment layout.
|
||||
@@ -1606,7 +1661,11 @@ __global__ void Marlin(
|
||||
if (slice_count > 1) { // only globally reduce if there is more than one
|
||||
// block in a slice
|
||||
barrier_acquire(&locks[slice_col], slice_idx);
|
||||
global_reduce(slice_idx == 0, last);
|
||||
if (use_fp32_reduce) {
|
||||
global_reduce_fp32(slice_idx == 0, last);
|
||||
} else {
|
||||
global_reduce_fp16(slice_idx == 0, last);
|
||||
}
|
||||
barrier_release(&locks[slice_col], last);
|
||||
}
|
||||
if (last) // only the last block in a slice actually writes the result
|
||||
@@ -1661,8 +1720,8 @@ __global__ void Marlin(
|
||||
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
|
||||
HAS_ZP, GROUP_BLOCKS> \
|
||||
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
|
||||
A_ptr, B_ptr, C_ptr, s_ptr, zp_ptr, g_idx_ptr, num_groups, \
|
||||
prob_m, prob_n, prob_k, locks); \
|
||||
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \
|
||||
num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
@@ -1801,6 +1860,27 @@ bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
|
||||
return true;
|
||||
}
|
||||
|
||||
int determine_reduce_max_m(int prob_m, int max_par) {
|
||||
constexpr int tile_m_size = 16;
|
||||
|
||||
if (prob_m <= tile_m_size) {
|
||||
return tile_m_size;
|
||||
|
||||
} else if (prob_m <= tile_m_size * 2) {
|
||||
return tile_m_size * 2;
|
||||
|
||||
} else if (prob_m <= tile_m_size * 3) {
|
||||
return tile_m_size * 3;
|
||||
|
||||
} else if (prob_m <= tile_m_size * 4) {
|
||||
return tile_m_size * 4;
|
||||
|
||||
} else {
|
||||
int cur_par = min(div_ceil(prob_m, tile_m_size * 4), max_par);
|
||||
return tile_m_size * 4 * cur_par;
|
||||
}
|
||||
}
|
||||
|
||||
exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
|
||||
int num_bits, int group_size,
|
||||
bool has_act_order, bool is_k_full,
|
||||
@@ -1880,13 +1960,13 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
|
||||
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
|
||||
|
||||
template <typename scalar_t>
|
||||
void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp,
|
||||
void* g_idx, void* perm, void* a_tmp, int prob_m,
|
||||
int prob_n, int prob_k, void* workspace, int num_bits,
|
||||
bool has_act_order, bool is_k_full, bool has_zp,
|
||||
int num_groups, int group_size, int dev,
|
||||
void marlin_mm_f16i4(const void* A, const void* B, void* C, void* C_tmp,
|
||||
void* s, void* zp, void* g_idx, void* perm, void* a_tmp,
|
||||
int prob_m, int prob_n, int prob_k, void* workspace,
|
||||
int num_bits, bool has_act_order, bool is_k_full,
|
||||
bool has_zp, int num_groups, int group_size, int dev,
|
||||
cudaStream_t stream, int thread_k, int thread_n, int sms,
|
||||
int max_par) {
|
||||
int max_par, bool use_fp32_reduce) {
|
||||
TORCH_CHECK(num_bits == 4 || num_bits == 8,
|
||||
"num_bits must be 4 or 8. Got = ", num_bits);
|
||||
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
|
||||
@@ -1970,6 +2050,7 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp,
|
||||
const int4* A_ptr = (const int4*)A;
|
||||
const int4* B_ptr = (const int4*)B;
|
||||
int4* C_ptr = (int4*)C;
|
||||
int4* C_tmp_ptr = (int4*)C_tmp;
|
||||
const int4* s_ptr = (const int4*)s;
|
||||
const int4* zp_ptr = (const int4*)zp;
|
||||
const int* g_idx_ptr = (const int*)g_idx;
|
||||
@@ -2049,7 +2130,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
torch::Tensor& g_idx, torch::Tensor& perm,
|
||||
torch::Tensor& workspace, int64_t num_bits,
|
||||
int64_t size_m, int64_t size_n, int64_t size_k,
|
||||
bool is_k_full, bool has_zp) {
|
||||
bool is_k_full, bool has_zp,
|
||||
bool use_fp32_reduce) {
|
||||
// Verify num_bits
|
||||
TORCH_CHECK(num_bits == 4 || num_bits == 8,
|
||||
"num_bits must be 4 or 8. Got = ", num_bits);
|
||||
@@ -2099,6 +2181,17 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
torch::Tensor c = torch::empty({size_m, size_n}, options);
|
||||
torch::Tensor a_tmp = torch::empty({size_m, size_k}, options);
|
||||
|
||||
// Alloc C tmp buffer that is going to be used for the global reduce
|
||||
int reduce_max_m = marlin::determine_reduce_max_m(size_m, marlin::max_par);
|
||||
int reduce_n = size_n;
|
||||
auto options_fp32 =
|
||||
torch::TensorOptions().dtype(at::kFloat).device(a.device());
|
||||
if (!use_fp32_reduce) {
|
||||
reduce_max_m = 0;
|
||||
reduce_n = 0;
|
||||
}
|
||||
torch::Tensor c_tmp = torch::empty({reduce_max_m, reduce_n}, options_fp32);
|
||||
|
||||
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
|
||||
// auto -1)
|
||||
int thread_k = -1;
|
||||
@@ -2171,20 +2264,21 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
if (a.scalar_type() == at::ScalarType::Half) {
|
||||
marlin::marlin_mm_f16i4<half>(
|
||||
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
|
||||
b_scales.data_ptr<at::Half>(), b_zeros.data_ptr(), g_idx.data_ptr(),
|
||||
perm.data_ptr(), a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
|
||||
c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
|
||||
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
|
||||
a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
|
||||
workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp,
|
||||
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
|
||||
thread_k, thread_n, sms, marlin::max_par);
|
||||
thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
|
||||
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
|
||||
marlin::marlin_mm_f16i4<nv_bfloat16>(
|
||||
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
|
||||
c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(),
|
||||
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
|
||||
a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
|
||||
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
|
||||
b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
|
||||
perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
|
||||
workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp,
|
||||
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
|
||||
thread_k, thread_n, sms, marlin::max_par);
|
||||
thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
|
||||
} else {
|
||||
TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user