[Kernel] optimize performance of gptq marlin kernel when n is small (#14138)
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
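In brief (a summary of the diff below, not part of the original commit message): the Marlin GEMM kernel gains a `use_atomic_add` flag that is threaded from the `gptq_marlin_gemm` op down through `marlin_mm` to the kernel launch. When enabled, the threadblocks that share an output slice accumulate their partial results directly into `C` with packed fp16x2/bf16x2 `atomicAdd`, instead of taking turns through the lock-protected global-reduce path. With small `n` the work for each output tile is split across many blocks along `k`, which is presumably why skipping that serialized reduce pays off.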
@@ -538,6 +538,7 @@ __global__ void Marlin(
     int prob_n,           // output dimension n
     int prob_k,           // reduction dimension k
     int* locks,           // extra global storage for barrier synchronization
+    bool use_atomic_add,  // whether to use atomic add to reduce
     bool use_fp32_reduce  // whether to use fp32 global reduce
 ) {
   // Each threadblock processes one "stripe" of the B matrix with (roughly) the
@@ -1542,7 +1543,17 @@ __global__ void Marlin(
          i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
          i++) {
       if (c_gl_wr < c_gl_wr_end) {
-        C[c_gl_wr] = sh_red[c_sh_rd];
+        if (use_atomic_add && slice_count > 1) {
+          scalar_t2* C_half2 = reinterpret_cast<scalar_t2*>(&C[c_gl_wr]);
+          scalar_t2* sh_red_half2 =
+              reinterpret_cast<scalar_t2*>(&sh_red[c_sh_rd]);
+  #pragma unroll
+          for (int a = 0; a < 4; a++) {
+            atomicAdd(&C_half2[a], sh_red_half2[a]);
+          }
+        } else {
+          C[c_gl_wr] = sh_red[c_sh_rd];
+        }
         c_gl_wr += c_gl_wr_delta;
         c_sh_rd += c_sh_rd_delta;
       }
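A standalone sketch of the packed-atomic pattern used in the hunk above (kernel name, buffer names, and launch shape are illustrative, not from the commit). In the kernel, `scalar_t2` is the packed pair type (`half2` or `nv_bfloat162`); `atomicAdd` on a `__half2` is a single hardware atomic on sm_60+ and commits two fp16 lanes at once, so the 4-iteration loop above commits eight halves with four atomics. Compile with `nvcc -arch=sm_60` or newer:

#include <cuda_fp16.h>
#include <cstdio>

// Each "slice member" block adds its partial tile into C.
// C must be zero-initialized before the first launch.
__global__ void commit_partial(__half2* C, const __half2* partial, int n2) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n2) atomicAdd(&C[i], partial[i]);  // one atomic, two fp16 lanes
}

int main() {
  const int n2 = 1024;  // number of half2 elements
  __half2 *C, *partial;
  cudaMalloc(&C, n2 * sizeof(__half2));
  cudaMalloc(&partial, n2 * sizeof(__half2));
  cudaMemset(C, 0, n2 * sizeof(__half2));        // zero-init, like torch::zeros
  cudaMemset(partial, 0, n2 * sizeof(__half2));  // stand-in partial sums
  // Launch twice to mimic two blocks of the same slice accumulating into C.
  commit_partial<<<4, 256>>>(C, partial, n2);
  commit_partial<<<4, 256>>>(C, partial, n2);
  cudaDeviceSynchronize();
  printf("done: %s\n", cudaGetErrorString(cudaGetLastError()));
  cudaFree(C);
  cudaFree(partial);
  return 0;
}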
@@ -1644,7 +1655,7 @@ __global__ void Marlin(
         }
         cp_async_fence();
       } else {
-        if (last) {
+        if (last || use_atomic_add) {
           if (s_sh_wr_pred) {
             cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
           }
@@ -1664,7 +1675,7 @@ __global__ void Marlin(
       }

     } else {
-      if (last) {
+      if (last || use_atomic_add) {
         cp_async_wait<0>();
         __syncthreads();
         if (threadIdx.x / 32 < thread_n_blocks / 4) {
@@ -1703,8 +1714,8 @@ __global__ void Marlin(
       }
     }

-    if (slice_count > 1) {  // only globally reduce if there is more than one
-                            // block in a slice
+    if (slice_count > 1 && !use_atomic_add) {
+      // only globally reduce if there is more than one block in a slice
       barrier_acquire(&locks[slice_col], slice_idx);
       if (use_fp32_reduce) {
         global_reduce_fp32(slice_idx == 0, last);
@@ -1713,7 +1724,8 @@ __global__ void Marlin(
       }
       barrier_release(&locks[slice_col], last);
     }
-    if (last)  // only the last block in a slice actually writes the result
+    if (last || use_atomic_add)
+      // only the last block in a slice actually writes the result
       write_result();
     slice_row = 0;
     slice_col_par++;
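Taken together with the two `last || use_atomic_add` hunks above: under atomic add, the `slice_count > 1` global-reduce path, which admits one block at a time per slice via `barrier_acquire`/`barrier_release`, is skipped entirely, and every block of a slice (not just the last one) loads scales and runs `write_result`, committing its contribution through the atomic branch. The "only the last block writes the result" invariant therefore no longer holds when accumulation is atomic, which is why the scale loads are now gated on `last || use_atomic_add` as well.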
@@ -1768,7 +1780,8 @@ __global__ void Marlin(
           HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT>                               \
           <<<blocks, NUM_THREADS, max_shared_mem, stream>>>(               \
               A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr,    \
-              num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \
+              num_groups, prob_m, prob_n, prob_k, locks, use_atomic_add,   \
+              use_fp32_reduce);                                            \
     }                                                                      \
   }

@@ -2062,7 +2075,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
                vllm::ScalarType const& q_type, bool has_act_order,
                bool is_k_full, bool has_zp, int num_groups, int group_size,
                int dev, cudaStream_t stream, int thread_k, int thread_n,
-               int sms, int max_par, bool use_fp32_reduce, bool is_zp_float) {
+               int sms, int max_par, bool use_atomic_add, bool use_fp32_reduce,
+               bool is_zp_float) {
   if (has_zp) {
     TORCH_CHECK(
         q_type == vllm::kU4 || q_type == vllm::kU8,
@@ -2243,7 +2257,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                torch::Tensor& workspace,
                                vllm::ScalarTypeId const& b_q_type_id,
                                int64_t size_m, int64_t size_n, int64_t size_k,
-                               bool is_k_full, bool has_zp,
+                               bool is_k_full, bool has_zp, bool use_atomic_add,
                                bool use_fp32_reduce, bool is_zp_float) {
   vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
   if (has_zp) {
@@ -2306,19 +2320,34 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   // Alloc buffers
   const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
   auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
-  torch::Tensor c = torch::empty({size_m, size_n}, options);
-  torch::Tensor a_tmp = torch::empty({size_m, size_k}, options);
+  torch::Tensor c;
+  if (use_atomic_add) {
+    c = torch::zeros({size_m, size_n}, options);
+  } else {
+    c = torch::empty({size_m, size_n}, options);
+  }
+
+  torch::Tensor a_tmp;
+  bool has_act_order = g_idx.size(0) != 0;
+  if (has_act_order) {
+    a_tmp = torch::empty({size_m, size_k}, options);
+  } else {
+    a_tmp = torch::empty({0}, options);
+  }

   // Alloc C tmp buffer that is going to be used for the global reduce
+  torch::Tensor c_tmp;
   int reduce_max_m = marlin::determine_reduce_max_m(size_m, marlin::max_par);
   int reduce_n = size_n;
   auto options_fp32 =
       torch::TensorOptions().dtype(at::kFloat).device(a.device());
-  if (!use_fp32_reduce) {
+  if (use_fp32_reduce) {
+    c_tmp = torch::empty({reduce_max_m, reduce_n}, options_fp32);
+  } else {
     reduce_max_m = 0;
     reduce_n = 0;
+    c_tmp = torch::empty({0}, options_fp32);
   }
-  torch::Tensor c_tmp = torch::empty({reduce_max_m, reduce_n}, options_fp32);

   // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
   // auto -1)
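The hunk above is the host-side cost of the atomic path: accumulation only produces correct results if `C` starts at zero, so `torch::zeros` (one extra memset kernel) replaces `torch::empty` when the flag is set, and `a_tmp` shrinks to an empty tensor whenever there is no act-order permutation to apply. A condensed, behavior-equivalent sketch of the output allocation (same logic as the hunk, ternary form; `options` carries the dtype/device of `a`):

// The atomic path pays one extra memset to zero C; the overwrite path
// keeps the cheaper uninitialized allocation.
torch::Tensor c = use_atomic_add ? torch::zeros({size_m, size_n}, options)
                                 : torch::empty({size_m, size_n}, options);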
@@ -2339,7 +2368,6 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   // Detect groupsize and act_order
   int num_groups = -1;
   int group_size = -1;
-  bool has_act_order = g_idx.size(0) != 0;

   int rank = b_scales.sizes().size();
   TORCH_CHECK(rank == 2, "b_scales rank = ", rank, " is not 2");
@@ -2407,7 +2435,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
         a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
         workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
         num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
-        thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_zp_float);
+        thread_k, thread_n, sms, marlin::max_par, use_atomic_add,
+        use_fp32_reduce, is_zp_float);
   } else if (a.scalar_type() == at::ScalarType::BFloat16) {
     marlin::marlin_mm<nv_bfloat16>(
         a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
@@ -2416,7 +2445,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
         perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
         workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
         num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
-        thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_zp_float);
+        thread_k, thread_n, sms, marlin::max_par, use_atomic_add,
+        use_fp32_reduce, is_zp_float);
   } else {
     TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
   }
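One caveat worth noting (not enforced anywhere in this diff): packed `atomicAdd` on `__half2` requires compute capability 6.x or newer, and on `__nv_bfloat162` 8.x or newer, so callers presumably gate `use_atomic_add` on the device architecture and dtype before passing it to `gptq_marlin_gemm`.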