[Misc] Use scalar type to dispatch to different gptq_marlin kernels (#7323)

Author: Lucas Wilkinson
Date: 2024-08-12 14:40:13 -04:00 (committed by GitHub)
Parent: 1137f343aa
Commit: 6aa33cb2dd
2 changed files with 332 additions and 218 deletions
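For readers skimming the diff below, the shape of the change: the kernels' `const int num_bits` template parameter becomes a `vllm::ScalarTypeId w_type_id`, the kernel recovers the full weight type with `vllm::ScalarType::from_id(w_type_id)`, and the four dequantizer families (dequant_4bit, dequant_8bit, dequant_4bit_zp, dequant_8bit_zp) collapse into a single `dequant<scalar_t, w_type_id>` template specialized per weight format (vllm::kU4B8, vllm::kU8B128, vllm::kU4, vllm::kU8). The sketch below illustrates the dispatch pattern with a deliberately simplified, hypothetical ScalarType; it is not vLLM's actual implementation (that lives in csrc/core/scalar_type.hpp), and the dequant body is a placeholder rather than the lop3-based fast path used in the real kernel.

#include <cuda_fp16.h>
#include <cstdint>

namespace sketch {

using ScalarTypeId = uint64_t;

// Simplified stand-in for vllm::ScalarType: just enough to show the
// id()/from_id() round trip the kernel relies on.
struct ScalarType {
  int size_bits;  // 4 or 8
  int bias;       // e.g. 8 for kU4B8, 128 for kU8B128, 0 for kU4 / kU8
  constexpr ScalarTypeId id() const {
    return (ScalarTypeId(size_bits) << 32) | ScalarTypeId(uint32_t(bias));
  }
  static constexpr ScalarType from_id(ScalarTypeId type_id) {
    return {int(type_id >> 32), int(uint32_t(type_id))};
  }
};

constexpr ScalarType kU4B8{4, 8};      // GPTQ-style 4-bit, bias 8
constexpr ScalarType kU8B128{8, 128};  // GPTQ-style 8-bit, bias 128

// Primary template is only declared: an unsupported (dtype, format) pair now
// fails to build instead of silently reaching a mismatched dequantizer.
template <typename scalar_t, ScalarTypeId w_type_id>
__device__ half2 dequant(int q);

// One specialization per supported weight format; this placeholder body just
// converts the two low nibbles (bias 8) instead of the real lop3 fast path.
template <>
__device__ half2 dequant<half, kU4B8.id()>(int q) {
  return __floats2half2_rn(float((q & 0xF) - 8), float(((q >> 4) & 0xF) - 8));
}

template <typename scalar_t, ScalarTypeId w_type_id>
__global__ void marlin_like(const int* B, half2* out) {
  // Recover the full type description at compile time, mirroring the
  // `static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);`
  // line this commit adds to the Marlin kernel.
  constexpr auto w_type = ScalarType::from_id(w_type_id);
  constexpr int pack_factor = 32 / w_type.size_bits;  // 8 or 4 weights per int
  static_assert(pack_factor == 4 || pack_factor == 8, "4- or 8-bit only");
  out[threadIdx.x] = dequant<scalar_t, w_type_id>(B[threadIdx.x]);
}

}  // namespace sketch

Host code can pick the instantiation at runtime while the kernel still sees a compile-time constant, e.g. marlin_like<half, sketch::kU4B8.id()><<<1, 32>>>(d_B, d_out); with d_B / d_out being hypothetical device buffers.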


@@ -42,8 +42,8 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
int4* __restrict__ out_int4_ptr, int size_m,
int size_k, int block_rows) {}
-template <typename scalar_t, // compute dtype, half or nv_float16
-const int num_bits, // number of bits used for weights
+template <typename scalar_t, // compute dtype, half or nv_float16
+const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
@@ -151,20 +151,21 @@ __device__ inline uint32_t prmt(uint32_t a) {
return res;
}
-// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
-// values. We mostly follow the strategy in the link below, with some small
-// changes:
+template <typename scalar_t, vllm::ScalarTypeId w_type_id>
+__device__ inline typename ScalarType<scalar_t>::FragB dequant(int q);
+//
+// Efficiently dequantize 4bit values packed in an int32 value into a full
+// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
+// with some small changes:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
-template <typename scalar_t>
-__device__ inline typename ScalarType<scalar_t>::FragB dequant_4bit(int q) {
-STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
-}
+//
template <>
-__device__ inline typename ScalarType<half>::FragB dequant_4bit<half>(int q) {
+__device__ inline typename ScalarType<half>::FragB
+dequant<half, vllm::kU4B8.id()>(int q) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
@@ -187,7 +188,7 @@ __device__ inline typename ScalarType<half>::FragB dequant_4bit<half>(int q) {
template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
-dequant_4bit<nv_bfloat16>(int q) {
+dequant<nv_bfloat16, vllm::kU4B8.id()>(int q) {
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t EX = 0x43004300;
@@ -210,19 +211,64 @@ dequant_4bit<nv_bfloat16>(int q) {
return frag_b;
}
+template <>
+__device__ inline typename ScalarType<half>::FragB
+dequant<half, vllm::kU4.id()>(int q) {
+const int LO = 0x000f000f;
+const int HI = 0x00f000f0;
+const int EX = 0x64006400;
+// Guarantee that the `(a & b) | c` operations are LOP3s.
+int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
+int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
+const int SUB = 0x64006400;
+const int MUL = 0x2c002c00;
+const int ADD = 0xd400d400;
+typename ScalarType<half>::FragB frag_b;
+frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
+*reinterpret_cast<const half2*>(&SUB));
+frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
+*reinterpret_cast<const half2*>(&MUL),
+*reinterpret_cast<const half2*>(&ADD));
+return frag_b;
+}
+template <>
+__device__ inline typename ScalarType<nv_bfloat16>::FragB
+dequant<nv_bfloat16, vllm::kU4.id()>(int q) {
+static constexpr uint32_t MASK = 0x000f000f;
+static constexpr uint32_t EX = 0x43004300;
+// Guarantee that the `(a & b) | c` operations are LOP3s.
+int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+q >>= 4;
+int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
+typename ScalarType<nv_bfloat16>::FragB frag_b;
+static constexpr uint32_t MUL = 0x3F803F80;
+static constexpr uint32_t ADD = 0xC300C300;
+frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
+*reinterpret_cast<const nv_bfloat162*>(&MUL),
+*reinterpret_cast<const nv_bfloat162*>(&ADD));
+frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
+*reinterpret_cast<const nv_bfloat162*>(&MUL),
+*reinterpret_cast<const nv_bfloat162*>(&ADD));
+return frag_b;
+}
+//
// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
// bf16 Reference:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
-template <typename scalar_t>
-__device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit(int q) {
-STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
-}
+//
template <>
-__device__ inline typename ScalarType<half>::FragB dequant_8bit<half>(int q) {
+__device__ inline typename ScalarType<half>::FragB
+dequant<half, vllm::kU8B128.id()>(int q) {
static constexpr uint32_t mask_for_elt_01 = 0x5250;
static constexpr uint32_t mask_for_elt_23 = 0x5351;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
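The magic constants in the dequantizers above and below (0x6400, 0x2c00, 0xd400, 0x64646464, ...) all come from one identity: fp16 0x6400 encodes 1024.0 with an all-zero 10-bit mantissa, so OR-ing a small integer q into the mantissa bits encodes exactly 1024 + q. A self-contained sketch of the 4-bit case follows (hedged: illustrative only, one value at a time, whereas the kernel processes half2 pairs and fuses the masking with lop3):

#include <cuda_fp16.h>
#include <cstdint>

// fp16 0x6400 is 1024.0 (exponent 2^10, zero mantissa), so OR-ing a 4-bit
// nibble q into the low mantissa bits encodes exactly 1024 + q.
__device__ half dequant_lo_nibble_u4(uint32_t packed) {
  uint16_t biased = uint16_t(0x6400u | (packed & 0x000Fu));  // 1024 + q
  return __hsub(__ushort_as_half(biased),
                __ushort_as_half(0x6400u));  // (1024 + q) - 1024 = q
}

// The nibble four bits higher encodes 1024 + 16*q, which the kernel folds
// into a single fma: multiply by 1/16 (0x2c00) and add -64.0 (0xd400).
__device__ half dequant_hi_nibble_u4(uint32_t packed) {
  uint16_t biased = uint16_t(0x6400u | (packed & 0x00F0u));  // 1024 + 16*q
  return __hfma(__ushort_as_half(biased),
                __ushort_as_half(0x2c00u),    // 0.0625
                __ushort_as_half(0xd400u));   // -64: (1024 + 16q)/16 - 64 = q
}

For a bias-8 format such as kU4B8 the same trick applies, except 1024 + 8 (and -64 - 8 for the high nibble) is subtracted so the result lands at q - 8; the 8-bit path plays the same game with whole bytes gathered under a 0x64 high byte (start_byte_for_fp16 = 0x64646464) via prmt.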
@@ -242,7 +288,7 @@ __device__ inline typename ScalarType<half>::FragB dequant_8bit<half>(int q) {
template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
-dequant_8bit<nv_bfloat16>(int q) {
+dequant<nv_bfloat16, vllm::kU8B128.id()>(int q) {
typename ScalarType<nv_bfloat16>::FragB frag_b;
float fp32_intermediates[4];
@@ -269,68 +315,9 @@ dequant_8bit<nv_bfloat16>(int q) {
return frag_b;
}
-// Zero-point dequantizers
-template <typename scalar_t>
-__device__ inline typename ScalarType<scalar_t>::FragB dequant_4bit_zp(int q) {
-STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
-}
-template <>
-__device__ inline typename ScalarType<half>::FragB dequant_4bit_zp<half>(
-int q) {
-const int LO = 0x000f000f;
-const int HI = 0x00f000f0;
-const int EX = 0x64006400;
-// Guarantee that the `(a & b) | c` operations are LOP3s.
-int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
-int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
-const int SUB = 0x64006400;
-const int MUL = 0x2c002c00;
-const int ADD = 0xd400d400;
-typename ScalarType<half>::FragB frag_b;
-frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
-*reinterpret_cast<const half2*>(&SUB));
-frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
-*reinterpret_cast<const half2*>(&MUL),
-*reinterpret_cast<const half2*>(&ADD));
-return frag_b;
-}
-template <>
-__device__ inline typename ScalarType<nv_bfloat16>::FragB
-dequant_4bit_zp<nv_bfloat16>(int q) {
-static constexpr uint32_t MASK = 0x000f000f;
-static constexpr uint32_t EX = 0x43004300;
-// Guarantee that the `(a & b) | c` operations are LOP3s.
-int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
-q >>= 4;
-int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
-typename ScalarType<nv_bfloat16>::FragB frag_b;
-static constexpr uint32_t MUL = 0x3F803F80;
-static constexpr uint32_t ADD = 0xC300C300;
-frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
-*reinterpret_cast<const nv_bfloat162*>(&MUL),
-*reinterpret_cast<const nv_bfloat162*>(&ADD));
-frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
-*reinterpret_cast<const nv_bfloat162*>(&MUL),
-*reinterpret_cast<const nv_bfloat162*>(&ADD));
-return frag_b;
-}
-template <typename scalar_t>
-__device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit_zp(int q) {
-STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
-}
template <>
-__device__ inline typename ScalarType<half>::FragB dequant_8bit_zp<half>(
-int q) {
+__device__ inline typename ScalarType<half>::FragB
+dequant<half, vllm::kU8.id()>(int q) {
static constexpr uint32_t mask_for_elt_01 = 0x5250;
static constexpr uint32_t mask_for_elt_23 = 0x5351;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
@@ -350,7 +337,7 @@ __device__ inline typename ScalarType<half>::FragB dequant_8bit_zp<half>(
template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
-dequant_8bit_zp<nv_bfloat16>(int q) {
+dequant<nv_bfloat16, vllm::kU8.id()>(int q) {
typename ScalarType<nv_bfloat16>::FragB frag_b;
float fp32_intermediates[4];
@@ -517,8 +504,8 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
}
}
-template <typename scalar_t, // compute dtype, half or nv_float16
-const int num_bits, // number of bits used for weights
+template <typename scalar_t, // compute dtype, half or nv_float16
+const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
@@ -568,7 +555,9 @@ __global__ void Marlin(
using FragS = typename ScalarType<scalar_t>::FragS;
using FragZP = typename ScalarType<scalar_t>::FragZP;
-constexpr int pack_factor = 32 / num_bits;
+static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
+constexpr int pack_factor = 32 / w_type.size_bits();
// For larger GEMMs we run multiple batchsize 64 versions in parallel for a
// better partitioning with less reductions
@@ -670,7 +659,7 @@ __global__ void Marlin(
// B sizes/strides
int b_gl_stride = 16 * prob_n / (pack_factor * 4);
constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;
-constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2;
+constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2;
constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;
int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
@@ -1186,19 +1175,20 @@ __global__ void Marlin(
if constexpr (has_zp) {
FragB frag_zp_0;
FragB frag_zp_1;
-if constexpr (num_bits == 4) {
-int zp_quant = frag_qzp[k % 2][0];
-int zp_quant_shift = zp_quant >> 8;
-frag_zp_0 = dequant_4bit_zp<scalar_t>(zp_quant);
-frag_zp_1 = dequant_4bit_zp<scalar_t>(zp_quant_shift);
+int zp_quant_0, zp_quant_1;
+if constexpr (w_type.size_bits() == 4) {
+zp_quant_0 = frag_qzp[k % 2][0];
+zp_quant_1 = zp_quant_0 >> 8;
} else {
-int zp_quant_0 = frag_qzp[k % 2][0];
-int zp_quant_1 = frag_qzp[k % 2][1];
-frag_zp_0 = dequant_8bit_zp<scalar_t>(zp_quant_0);
-frag_zp_1 = dequant_8bit_zp<scalar_t>(zp_quant_1);
+static_assert(w_type.size_bits() == 8);
+zp_quant_0 = frag_qzp[k % 2][0];
+zp_quant_1 = frag_qzp[k % 2][1];
}
+frag_zp_0 = dequant<scalar_t, w_type_id>(zp_quant_0);
+frag_zp_1 = dequant<scalar_t, w_type_id>(zp_quant_1);
frag_zp[0] = frag_zp_0[0];
frag_zp[1] = frag_zp_0[1];
frag_zp[2] = frag_zp_1[0];
@@ -1211,33 +1201,21 @@ __global__ void Marlin(
for (int j = 0; j < 4; j++) {
FragB frag_b0;
FragB frag_b1;
-if constexpr (num_bits == 4) {
-int b_quant = frag_b_quant[k % 2][0][j];
-int b_quant_shift = b_quant >> 8;
-if constexpr (has_zp) {
-frag_b0 = dequant_4bit_zp<scalar_t>(b_quant);
-frag_b1 = dequant_4bit_zp<scalar_t>(b_quant_shift);
-} else {
-frag_b0 = dequant_4bit<scalar_t>(b_quant);
-frag_b1 = dequant_4bit<scalar_t>(b_quant_shift);
-}
+int b_quant_0, b_quant_1;
+if constexpr (w_type.size_bits() == 4) {
+b_quant_0 = frag_b_quant[k % 2][0][j];
+b_quant_1 = b_quant_0 >> 8;
} else {
+static_assert(w_type.size_bits() == 8);
int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);
-int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
-int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
-if constexpr (has_zp) {
-frag_b0 = dequant_8bit_zp<scalar_t>(b_quant_0);
-frag_b1 = dequant_8bit_zp<scalar_t>(b_quant_1);
-} else {
-frag_b0 = dequant_8bit<scalar_t>(b_quant_0);
-frag_b1 = dequant_8bit<scalar_t>(b_quant_1);
-}
+b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
+b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
}
+frag_b0 = dequant<scalar_t, w_type_id>(b_quant_0);
+frag_b1 = dequant<scalar_t, w_type_id>(b_quant_1);
// Apply zero-point to frag_b0
if constexpr (has_zp) {
sub_zp<scalar_t>(frag_b0, frag_zp[j], 0);
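One detail worth noting in the rewritten inner loop above: w_type is constexpr, so the branch on w_type.size_bits() is resolved at compile time, the untaken side is discarded before code generation, and the static_assert enforces that only 4- and 8-bit formats can instantiate this path. A minimal, standalone illustration of that pattern (hypothetical kernel, not Marlin):

#include <cstdio>

template <int size_bits>
__global__ void unpack_first_weight(unsigned int packed) {
  if constexpr (size_bits == 4) {
    // Eight 4-bit weights per 32-bit word; only this branch is compiled
    // when size_bits == 4.
    printf("first weight: %u\n", packed & 0xFu);
  } else {
    // This branch is only instantiated for size_bits != 4, so the assert
    // fires solely for genuinely unsupported widths.
    static_assert(size_bits == 8, "only 4- and 8-bit weights are supported");
    printf("first weight: %u\n", packed & 0xFFu);
  }
}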
@@ -1477,7 +1455,8 @@ __global__ void Marlin(
// For per-column quantization we finally apply the scale here (only for
// 4-bit)
-if constexpr (!has_act_order && group_blocks == -1 && num_bits == 4) {
+if constexpr (!has_act_order && group_blocks == -1 &&
+w_type.size_bits() == 4) {
res = __hmul2(res, s[0]);
}
@@ -1605,7 +1584,7 @@ __global__ void Marlin(
// For per-column scales, we only fetch them here in the final step before
// write-out
if constexpr (!has_act_order && group_blocks == -1) {
-if constexpr (num_bits == 8) {
+if constexpr (w_type.size_bits() == 8) {
if (s_sh_wr_pred) {
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
}
@@ -1622,7 +1601,7 @@ __global__ void Marlin(
thread_block_reduce();
if constexpr (!has_act_order && group_blocks == -1) {
-if constexpr (num_bits == 8) {
+if constexpr (w_type.size_bits() == 8) {
cp_async_wait<0>();
__syncthreads();
if (threadIdx.x / 32 < thread_n_blocks / 4) {
@@ -1645,7 +1624,8 @@ __global__ void Marlin(
// For 8-bit channelwise, we apply the scale before the global reduction
// that converts the fp32 results to fp16 (so that we avoid possible
// overflow in fp16)
-if constexpr (!has_act_order && group_blocks == -1 && num_bits == 8) {
+if constexpr (!has_act_order && group_blocks == -1 &&
+w_type.size_bits() == 8) {
if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
@@ -1714,20 +1694,19 @@ __global__ void Marlin(
}
}
-#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
-THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
-NUM_THREADS) \
-else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
+#define __CALL_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
+HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \
+else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
cudaFuncSetAttribute( \
-Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
+Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
HAS_ZP, GROUP_BLOCKS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
-Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
+Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
HAS_ZP, GROUP_BLOCKS> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
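Host-side, the dispatch chain keeps its shape; only the discriminator changes from an integer bit width to the weight ScalarType. Roughly, each __CALL_IF branch expands to something like the following sketch (heavily simplified and hypothetical: the real instantiations also carry thread counts, tile shapes, act-order, zero-point and group-size parameters, and the comparison is q_type == W_TYPE on vllm::ScalarType rather than on raw ids):

#include <cuda_runtime.h>
#include <cstdint>

// Hypothetical stand-ins: the real code passes vllm::kU4B8 / vllm::kU8B128.
constexpr uint64_t kU4B8_id = 1;
constexpr uint64_t kU8B128_id = 2;

template <typename scalar_t, uint64_t w_type_id>
__global__ void marlin_sketch(const int4* A, const int4* B, int4* C) {}

template <typename scalar_t>
void dispatch_sketch(uint64_t q_type_id, const int4* A, const int4* B, int4* C,
                     int blocks, int num_threads, int max_shared_mem,
                     cudaStream_t stream) {
  if (false) {
  } else if (q_type_id == kU4B8_id) {
    // Opt the chosen instantiation into the large dynamic-smem carveout, then
    // launch it; the weight format is baked in as a template argument.
    cudaFuncSetAttribute(marlin_sketch<scalar_t, kU4B8_id>,
                         cudaFuncAttributeMaxDynamicSharedMemorySize,
                         max_shared_mem);
    marlin_sketch<scalar_t, kU4B8_id>
        <<<blocks, num_threads, max_shared_mem, stream>>>(A, B, C);
  } else if (q_type_id == kU8B128_id) {
    cudaFuncSetAttribute(marlin_sketch<scalar_t, kU8B128_id>,
                         cudaFuncAttributeMaxDynamicSharedMemorySize,
                         max_shared_mem);
    marlin_sketch<scalar_t, kU8B128_id>
        <<<blocks, num_threads, max_shared_mem, stream>>>(A, B, C);
  }
}

The GPTQ_CALL_IF / AWQ_CALL_IF macros below simply stamp out one such else-if branch for every supported combination of weight format, tile geometry, act-order/zero-point flags, and group size.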
@@ -1923,52 +1902,52 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
return exec_config_t{0, {-1, -1, -1}};
}
-#define GPTQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
-\
-__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
-\
-__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
-\
-__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
-\
-__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)
+#define GPTQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
+__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
+__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
+__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
+__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
+\
+__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
+__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
+__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
+__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
+\
+__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
+__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
+__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
+__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
+\
+__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
+__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
+__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
+__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
+\
+__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
+__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
+__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
+__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)
-#define AWQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
-\
-__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
-\
-__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
-\
-__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
-__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
+#define AWQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
+__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
+__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
+__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
+__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
+\
+__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
+__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
+__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
+__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
+\
+__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
+__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
+__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
+__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
+\
+__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
+__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
+__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
+__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
template <typename scalar_t>
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
@@ -2113,23 +2092,23 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
if (false) {
}
-GPTQ_CALL_IF(4, 16, 4, 256)
-GPTQ_CALL_IF(4, 8, 8, 256)
-GPTQ_CALL_IF(4, 8, 4, 128)
-GPTQ_CALL_IF(4, 4, 8, 128)
-GPTQ_CALL_IF(8, 16, 4, 256)
-GPTQ_CALL_IF(8, 8, 8, 256)
-GPTQ_CALL_IF(8, 8, 4, 128)
-GPTQ_CALL_IF(8, 4, 8, 128)
+GPTQ_CALL_IF(vllm::kU4B8, 16, 4, 256)
+GPTQ_CALL_IF(vllm::kU4B8, 8, 8, 256)
+GPTQ_CALL_IF(vllm::kU4B8, 8, 4, 128)
+GPTQ_CALL_IF(vllm::kU4B8, 4, 8, 128)
+GPTQ_CALL_IF(vllm::kU8B128, 16, 4, 256)
+GPTQ_CALL_IF(vllm::kU8B128, 8, 8, 256)
+GPTQ_CALL_IF(vllm::kU8B128, 8, 4, 128)
+GPTQ_CALL_IF(vllm::kU8B128, 4, 8, 128)
-AWQ_CALL_IF(4, 16, 4, 256)
-AWQ_CALL_IF(4, 8, 8, 256)
-AWQ_CALL_IF(4, 8, 4, 128)
-AWQ_CALL_IF(4, 4, 8, 128)
-AWQ_CALL_IF(8, 16, 4, 256)
-AWQ_CALL_IF(8, 8, 8, 256)
-AWQ_CALL_IF(8, 8, 4, 128)
-AWQ_CALL_IF(8, 4, 8, 128)
+AWQ_CALL_IF(vllm::kU4, 16, 4, 256)
+AWQ_CALL_IF(vllm::kU4, 8, 8, 256)
+AWQ_CALL_IF(vllm::kU4, 8, 4, 128)
+AWQ_CALL_IF(vllm::kU4, 4, 8, 128)
+AWQ_CALL_IF(vllm::kU8, 16, 4, 256)
+AWQ_CALL_IF(vllm::kU8, 8, 8, 256)
+AWQ_CALL_IF(vllm::kU8, 8, 4, 128)
+AWQ_CALL_IF(vllm::kU8, 4, 8, 128)
else {
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
", ", prob_k, "]", ", has_act_order = ", has_act_order,