[Misc] Use scalar type to dispatch to different gptq_marlin kernels (#7323)
This commit is contained in:
@@ -42,8 +42,8 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
|
||||
int4* __restrict__ out_int4_ptr, int size_m,
|
||||
int size_k, int block_rows) {}
|
||||
|
||||
template <typename scalar_t, // compute dtype, half or nv_bfloat16
|
||||
const int num_bits, // number of bits used for weights
|
||||
template <typename scalar_t, // compute dtype, half or nv_bfloat16
|
||||
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
||||
const int threads, // number of threads in a threadblock
|
||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||
// dimension (batchsize) of the
|
||||
@@ -151,20 +151,21 @@ __device__ inline uint32_t prmt(uint32_t a) {
|
||||
return res;
|
||||
}
|
||||
|
||||
// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
|
||||
// values. We mostly follow the strategy in the link below, with some small
|
||||
// changes:
|
||||
// Primary template for weight dequantization: takes packed quantized values in
// `q` and returns a FragB for the compute type `scalar_t`. The weight format is
// selected at compile time through `w_type_id`. Only the declaration lives
// here; the behavior comes from the explicit specializations for
// half / nv_bfloat16 combined with vllm::kU4B8, kU4, kU8B128 and kU8.
template <typename scalar_t, vllm::ScalarTypeId w_type_id>
|
||||
__device__ inline typename ScalarType<scalar_t>::FragB dequant(int q);
|
||||
|
||||
//
|
||||
// Efficiently dequantize 4bit values packed in an int32 value into a full
|
||||
// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
|
||||
// with some small changes:
|
||||
// - FP16:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
|
||||
// - BF16:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
|
||||
// Generic dequant_4bit: performs no conversion. Its body only invokes
// STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) (presumably a compile-time check
// that scalar_t is a supported type -- confirm against the macro definition).
// The actual work is done by the dequant_4bit<half> and
// dequant_4bit<nv_bfloat16> specializations.
template <typename scalar_t>
|
||||
__device__ inline typename ScalarType<scalar_t>::FragB dequant_4bit(int q) {
|
||||
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
|
||||
}
|
||||
|
||||
//
|
||||
template <>
|
||||
__device__ inline typename ScalarType<half>::FragB dequant_4bit<half>(int q) {
|
||||
__device__ inline typename ScalarType<half>::FragB
|
||||
dequant<half, vllm::kU4B8.id()>(int q) {
|
||||
const int LO = 0x000f000f;
|
||||
const int HI = 0x00f000f0;
|
||||
const int EX = 0x64006400;
|
||||
@@ -187,7 +188,7 @@ __device__ inline typename ScalarType<half>::FragB dequant_4bit<half>(int q) {
|
||||
|
||||
template <>
|
||||
__device__ inline typename ScalarType<nv_bfloat16>::FragB
|
||||
dequant_4bit<nv_bfloat16>(int q) {
|
||||
dequant<nv_bfloat16, vllm::kU4B8.id()>(int q) {
|
||||
static constexpr uint32_t MASK = 0x000f000f;
|
||||
static constexpr uint32_t EX = 0x43004300;
|
||||
|
||||
@@ -210,19 +211,64 @@ dequant_4bit<nv_bfloat16>(int q) {
|
||||
return frag_b;
|
||||
}
|
||||
|
||||
// Specialization of dequant for fp16 compute and the vllm::kU4 weight type.
// Unpacks four 4-bit fields of `q` (bits 0-3, 4-7, 16-19 and 20-23) and writes
// them as two half2 values into frag_b[0] and frag_b[1]. Once the magic-number
// bias is removed, each output equals the raw field value (0..15); no further
// offset is subtracted in this function.
template <>
|
||||
__device__ inline typename ScalarType<half>::FragB
|
||||
dequant<half, vllm::kU4.id()>(int q) {
|
||||
// LO keeps bits 0-3 of each 16-bit half of q.
const int LO = 0x000f000f;
|
||||
// HI keeps bits 4-7 of each 16-bit half of q.
const int HI = 0x00f000f0;
|
||||
// 0x6400 is 1024.0 in fp16. OR-ing a nibble into its mantissa yields
// 1024 + x for the LO nibble and 1024 + 16 * x for the HI nibble.
const int EX = 0x64006400;
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
||||
|
||||
// Per-lane fp16 constants: SUB = 1024.0, MUL = 1/16 (0x2c00), ADD = -64.0
// (0xd400).
const int SUB = 0x64006400;
|
||||
const int MUL = 0x2c002c00;
|
||||
const int ADD = 0xd400d400;
|
||||
typename ScalarType<half>::FragB frag_b;
|
||||
// (1024 + x) - 1024 = x
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
||||
*reinterpret_cast<const half2*>(&SUB));
|
||||
// (1024 + 16 * x) * (1/16) - 64 = x
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
|
||||
*reinterpret_cast<const half2*>(&MUL),
|
||||
*reinterpret_cast<const half2*>(&ADD));
|
||||
return frag_b;
|
||||
}
|
||||
|
||||
// Specialization of dequant for bf16 compute and the vllm::kU4 weight type.
// Unpacks four 4-bit fields of `q` (bits 0-3, 4-7, 16-19 and 20-23) and writes
// them as two nv_bfloat162 values into frag_b[0] and frag_b[1]. Each output
// equals the raw field value (0..15); no further offset is subtracted in this
// function.
template <>
|
||||
__device__ inline typename ScalarType<nv_bfloat16>::FragB
|
||||
dequant<nv_bfloat16, vllm::kU4.id()>(int q) {
|
||||
// MASK keeps bits 0-3 of each 16-bit half of its input.
static constexpr uint32_t MASK = 0x000f000f;
|
||||
// 0x4300 is 128.0 in bf16. OR-ing a nibble into its low mantissa bits yields
// 128 + x.
static constexpr uint32_t EX = 0x43004300;
|
||||
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
// Move the next nibble of each half into the position selected by MASK.
q >>= 4;
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
|
||||
typename ScalarType<nv_bfloat16>::FragB frag_b;
|
||||
// Per-lane bf16 constants: MUL = 1.0 (0x3F80), ADD = -128.0 (0xC300).
static constexpr uint32_t MUL = 0x3F803F80;
|
||||
static constexpr uint32_t ADD = 0xC300C300;
|
||||
|
||||
// (128 + x) * 1.0 - 128 = x, for both the low and the shifted nibbles.
frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
||||
frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
||||
return frag_b;
|
||||
}
|
||||
|
||||
//
|
||||
// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8-bit int values to fp16 or
|
||||
// bf16. Reference:
|
||||
// - FP16:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
|
||||
// - BF16:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
|
||||
// Generic dequant_8bit: performs no conversion. Its body only invokes
// STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) (presumably a compile-time check
// that scalar_t is a supported type -- confirm against the macro definition).
// The actual work is done by the dequant_8bit<half> and
// dequant_8bit<nv_bfloat16> specializations.
template <typename scalar_t>
|
||||
__device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit(int q) {
|
||||
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
|
||||
}
|
||||
|
||||
//
|
||||
template <>
|
||||
__device__ inline typename ScalarType<half>::FragB dequant_8bit<half>(int q) {
|
||||
__device__ inline typename ScalarType<half>::FragB
|
||||
dequant<half, vllm::kU8B128.id()>(int q) {
|
||||
static constexpr uint32_t mask_for_elt_01 = 0x5250;
|
||||
static constexpr uint32_t mask_for_elt_23 = 0x5351;
|
||||
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
|
||||
@@ -242,7 +288,7 @@ __device__ inline typename ScalarType<half>::FragB dequant_8bit<half>(int q) {
|
||||
|
||||
template <>
|
||||
__device__ inline typename ScalarType<nv_bfloat16>::FragB
|
||||
dequant_8bit<nv_bfloat16>(int q) {
|
||||
dequant<nv_bfloat16, vllm::kU8B128.id()>(int q) {
|
||||
typename ScalarType<nv_bfloat16>::FragB frag_b;
|
||||
|
||||
float fp32_intermediates[4];
|
||||
@@ -269,68 +315,9 @@ dequant_8bit<nv_bfloat16>(int q) {
|
||||
return frag_b;
|
||||
}
|
||||
|
||||
// Zero-point dequantizers
|
||||
|
||||
// Generic dequant_4bit_zp (4-bit dequantizer used on the zero-point path):
// performs no conversion. Its body only invokes
// STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) (presumably a compile-time check
// that scalar_t is a supported type -- confirm against the macro definition).
// The actual work is done by the half and nv_bfloat16 specializations.
template <typename scalar_t>
|
||||
__device__ inline typename ScalarType<scalar_t>::FragB dequant_4bit_zp(int q) {
|
||||
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
|
||||
}
|
||||
|
||||
// fp16 specialization of the 4-bit zero-point dequantizer. Unpacks four 4-bit
// fields of `q` (bits 0-3, 4-7, 16-19 and 20-23) and writes them as two half2
// values into frag_b[0] and frag_b[1]. Once the magic-number bias is removed,
// each output equals the raw field value (0..15); the zero point itself is not
// applied in this function.
template <>
|
||||
__device__ inline typename ScalarType<half>::FragB dequant_4bit_zp<half>(
|
||||
int q) {
|
||||
// LO keeps bits 0-3 of each 16-bit half of q.
const int LO = 0x000f000f;
|
||||
// HI keeps bits 4-7 of each 16-bit half of q.
const int HI = 0x00f000f0;
|
||||
// 0x6400 is 1024.0 in fp16. OR-ing a nibble into its mantissa yields
// 1024 + x for the LO nibble and 1024 + 16 * x for the HI nibble.
const int EX = 0x64006400;
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
||||
|
||||
// Per-lane fp16 constants: SUB = 1024.0, MUL = 1/16 (0x2c00), ADD = -64.0
// (0xd400).
const int SUB = 0x64006400;
|
||||
const int MUL = 0x2c002c00;
|
||||
const int ADD = 0xd400d400;
|
||||
typename ScalarType<half>::FragB frag_b;
|
||||
// (1024 + x) - 1024 = x
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
||||
*reinterpret_cast<const half2*>(&SUB));
|
||||
// (1024 + 16 * x) * (1/16) - 64 = x
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
|
||||
*reinterpret_cast<const half2*>(&MUL),
|
||||
*reinterpret_cast<const half2*>(&ADD));
|
||||
return frag_b;
|
||||
}
|
||||
|
||||
// bf16 specialization of the 4-bit zero-point dequantizer. Unpacks four 4-bit
// fields of `q` (bits 0-3, 4-7, 16-19 and 20-23) and writes them as two
// nv_bfloat162 values into frag_b[0] and frag_b[1]. Each output equals the raw
// field value (0..15); the zero point itself is not applied in this function.
template <>
|
||||
__device__ inline typename ScalarType<nv_bfloat16>::FragB
|
||||
dequant_4bit_zp<nv_bfloat16>(int q) {
|
||||
// MASK keeps bits 0-3 of each 16-bit half of its input.
static constexpr uint32_t MASK = 0x000f000f;
|
||||
// 0x4300 is 128.0 in bf16. OR-ing a nibble into its low mantissa bits yields
// 128 + x.
static constexpr uint32_t EX = 0x43004300;
|
||||
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
// Move the next nibble of each half into the position selected by MASK.
q >>= 4;
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
|
||||
typename ScalarType<nv_bfloat16>::FragB frag_b;
|
||||
// Per-lane bf16 constants: MUL = 1.0 (0x3F80), ADD = -128.0 (0xC300).
static constexpr uint32_t MUL = 0x3F803F80;
|
||||
static constexpr uint32_t ADD = 0xC300C300;
|
||||
|
||||
// (128 + x) * 1.0 - 128 = x, for both the low and the shifted nibbles.
frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
||||
frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
||||
return frag_b;
|
||||
}
|
||||
|
||||
// Generic dequant_8bit_zp (8-bit dequantizer used on the zero-point path):
// performs no conversion. Its body only invokes
// STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) (presumably a compile-time check
// that scalar_t is a supported type -- confirm against the macro definition).
// The actual work is done by the half and nv_bfloat16 specializations.
template <typename scalar_t>
|
||||
__device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit_zp(int q) {
|
||||
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline typename ScalarType<half>::FragB dequant_8bit_zp<half>(
|
||||
int q) {
|
||||
__device__ inline typename ScalarType<half>::FragB
|
||||
dequant<half, vllm::kU8.id()>(int q) {
|
||||
static constexpr uint32_t mask_for_elt_01 = 0x5250;
|
||||
static constexpr uint32_t mask_for_elt_23 = 0x5351;
|
||||
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
|
||||
@@ -350,7 +337,7 @@ __device__ inline typename ScalarType<half>::FragB dequant_8bit_zp<half>(
|
||||
|
||||
template <>
|
||||
__device__ inline typename ScalarType<nv_bfloat16>::FragB
|
||||
dequant_8bit_zp<nv_bfloat16>(int q) {
|
||||
dequant<nv_bfloat16, vllm::kU8.id()>(int q) {
|
||||
typename ScalarType<nv_bfloat16>::FragB frag_b;
|
||||
|
||||
float fp32_intermediates[4];
|
||||
@@ -517,8 +504,8 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, // compute dtype, half or nv_bfloat16
|
||||
const int num_bits, // number of bits used for weights
|
||||
template <typename scalar_t, // compute dtype, half or nv_bfloat16
|
||||
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
||||
const int threads, // number of threads in a threadblock
|
||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||
// dimension (batchsize) of the
|
||||
@@ -568,7 +555,9 @@ __global__ void Marlin(
|
||||
using FragS = typename ScalarType<scalar_t>::FragS;
|
||||
using FragZP = typename ScalarType<scalar_t>::FragZP;
|
||||
|
||||
constexpr int pack_factor = 32 / num_bits;
|
||||
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
|
||||
|
||||
constexpr int pack_factor = 32 / w_type.size_bits();
|
||||
|
||||
// For larger GEMMs we run multiple batchsize 64 versions in parallel for a
|
||||
// better partitioning with less reductions
|
||||
@@ -670,7 +659,7 @@ __global__ void Marlin(
|
||||
// B sizes/strides
|
||||
int b_gl_stride = 16 * prob_n / (pack_factor * 4);
|
||||
constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;
|
||||
constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2;
|
||||
constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2;
|
||||
constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;
|
||||
|
||||
int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
|
||||
@@ -1186,19 +1175,20 @@ __global__ void Marlin(
|
||||
if constexpr (has_zp) {
|
||||
FragB frag_zp_0;
|
||||
FragB frag_zp_1;
|
||||
if constexpr (num_bits == 4) {
|
||||
int zp_quant = frag_qzp[k % 2][0];
|
||||
int zp_quant_shift = zp_quant >> 8;
|
||||
frag_zp_0 = dequant_4bit_zp<scalar_t>(zp_quant);
|
||||
frag_zp_1 = dequant_4bit_zp<scalar_t>(zp_quant_shift);
|
||||
int zp_quant_0, zp_quant_1;
|
||||
|
||||
if constexpr (w_type.size_bits() == 4) {
|
||||
zp_quant_0 = frag_qzp[k % 2][0];
|
||||
zp_quant_1 = zp_quant_0 >> 8;
|
||||
} else {
|
||||
int zp_quant_0 = frag_qzp[k % 2][0];
|
||||
int zp_quant_1 = frag_qzp[k % 2][1];
|
||||
frag_zp_0 = dequant_8bit_zp<scalar_t>(zp_quant_0);
|
||||
frag_zp_1 = dequant_8bit_zp<scalar_t>(zp_quant_1);
|
||||
static_assert(w_type.size_bits() == 8);
|
||||
zp_quant_0 = frag_qzp[k % 2][0];
|
||||
zp_quant_1 = frag_qzp[k % 2][1];
|
||||
}
|
||||
|
||||
frag_zp_0 = dequant<scalar_t, w_type_id>(zp_quant_0);
|
||||
frag_zp_1 = dequant<scalar_t, w_type_id>(zp_quant_1);
|
||||
|
||||
frag_zp[0] = frag_zp_0[0];
|
||||
frag_zp[1] = frag_zp_0[1];
|
||||
frag_zp[2] = frag_zp_1[0];
|
||||
@@ -1211,33 +1201,21 @@ __global__ void Marlin(
|
||||
for (int j = 0; j < 4; j++) {
|
||||
FragB frag_b0;
|
||||
FragB frag_b1;
|
||||
if constexpr (num_bits == 4) {
|
||||
int b_quant = frag_b_quant[k % 2][0][j];
|
||||
int b_quant_shift = b_quant >> 8;
|
||||
|
||||
if constexpr (has_zp) {
|
||||
frag_b0 = dequant_4bit_zp<scalar_t>(b_quant);
|
||||
frag_b1 = dequant_4bit_zp<scalar_t>(b_quant_shift);
|
||||
|
||||
} else {
|
||||
frag_b0 = dequant_4bit<scalar_t>(b_quant);
|
||||
frag_b1 = dequant_4bit<scalar_t>(b_quant_shift);
|
||||
}
|
||||
int b_quant_0, b_quant_1;
|
||||
|
||||
if constexpr (w_type.size_bits() == 4) {
|
||||
b_quant_0 = frag_b_quant[k % 2][0][j];
|
||||
b_quant_1 = b_quant_0 >> 8;
|
||||
} else {
|
||||
static_assert(w_type.size_bits() == 8);
|
||||
int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);
|
||||
int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
|
||||
int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
|
||||
|
||||
if constexpr (has_zp) {
|
||||
frag_b0 = dequant_8bit_zp<scalar_t>(b_quant_0);
|
||||
frag_b1 = dequant_8bit_zp<scalar_t>(b_quant_1);
|
||||
} else {
|
||||
frag_b0 = dequant_8bit<scalar_t>(b_quant_0);
|
||||
frag_b1 = dequant_8bit<scalar_t>(b_quant_1);
|
||||
}
|
||||
b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
|
||||
b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
|
||||
}
|
||||
|
||||
frag_b0 = dequant<scalar_t, w_type_id>(b_quant_0);
|
||||
frag_b1 = dequant<scalar_t, w_type_id>(b_quant_1);
|
||||
|
||||
// Apply zero-point to frag_b0
|
||||
if constexpr (has_zp) {
|
||||
sub_zp<scalar_t>(frag_b0, frag_zp[j], 0);
|
||||
@@ -1477,7 +1455,8 @@ __global__ void Marlin(
|
||||
|
||||
// For per-column quantization we finally apply the scale here (only for
|
||||
// 4-bit)
|
||||
if constexpr (!has_act_order && group_blocks == -1 && num_bits == 4) {
|
||||
if constexpr (!has_act_order && group_blocks == -1 &&
|
||||
w_type.size_bits() == 4) {
|
||||
res = __hmul2(res, s[0]);
|
||||
}
|
||||
|
||||
@@ -1605,7 +1584,7 @@ __global__ void Marlin(
|
||||
// For per-column scales, we only fetch them here in the final step before
|
||||
// write-out
|
||||
if constexpr (!has_act_order && group_blocks == -1) {
|
||||
if constexpr (num_bits == 8) {
|
||||
if constexpr (w_type.size_bits() == 8) {
|
||||
if (s_sh_wr_pred) {
|
||||
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
|
||||
}
|
||||
@@ -1622,7 +1601,7 @@ __global__ void Marlin(
|
||||
|
||||
thread_block_reduce();
|
||||
if constexpr (!has_act_order && group_blocks == -1) {
|
||||
if constexpr (num_bits == 8) {
|
||||
if constexpr (w_type.size_bits() == 8) {
|
||||
cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
||||
@@ -1645,7 +1624,8 @@ __global__ void Marlin(
|
||||
// For 8-bit channelwise, we apply the scale before the global reduction
|
||||
// that converts the fp32 results to fp16 (so that we avoid possible
|
||||
// overflow in fp16)
|
||||
if constexpr (!has_act_order && group_blocks == -1 && num_bits == 8) {
|
||||
if constexpr (!has_act_order && group_blocks == -1 &&
|
||||
w_type.size_bits() == 8) {
|
||||
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks; i++) {
|
||||
@@ -1714,20 +1694,19 @@ __global__ void Marlin(
|
||||
}
|
||||
}
|
||||
|
||||
#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
|
||||
THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
|
||||
NUM_THREADS) \
|
||||
else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
|
||||
#define __CALL_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
||||
HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \
|
||||
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
|
||||
thread_n_blocks == THREAD_N_BLOCKS && \
|
||||
thread_k_blocks == THREAD_K_BLOCKS && \
|
||||
has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
|
||||
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
|
||||
cudaFuncSetAttribute( \
|
||||
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
|
||||
Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
|
||||
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
|
||||
HAS_ZP, GROUP_BLOCKS>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
||||
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
|
||||
Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
|
||||
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
|
||||
HAS_ZP, GROUP_BLOCKS> \
|
||||
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
|
||||
@@ -1923,52 +1902,52 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
|
||||
return exec_config_t{0, {-1, -1, -1}};
|
||||
}
|
||||
|
||||
// Expands to the chain of __CALL_IF branches for one
// (NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) configuration, always with
// HAS_ZP = false:
//   - THREAD_M_BLOCKS 1..4 with HAS_ACT_ORDER = true and GROUP_BLOCKS = 0
//   - THREAD_M_BLOCKS 1..4 with HAS_ACT_ORDER = false and GROUP_BLOCKS in
//     {-1, 2, 4, 8}
// Each __CALL_IF expands to an `else if`, so this macro must follow an `if`.
#define GPTQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
|
||||
\
|
||||
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
|
||||
\
|
||||
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
|
||||
\
|
||||
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
|
||||
\
|
||||
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)
|
||||
// Expands to the chain of __CALL_IF branches for one
// (W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) configuration, always with
// HAS_ZP = false:
//   - THREAD_M_BLOCKS 1..4 with HAS_ACT_ORDER = true and GROUP_BLOCKS = 0
//   - THREAD_M_BLOCKS 1..4 with HAS_ACT_ORDER = false and GROUP_BLOCKS in
//     {-1, 2, 4, 8}
// W_TYPE is a weight scalar type (invoked with vllm::kU4B8 and vllm::kU8B128).
// Each __CALL_IF expands to an `else if`, so this macro must follow an `if`.
#define GPTQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
|
||||
\
|
||||
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
|
||||
\
|
||||
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
|
||||
\
|
||||
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
|
||||
\
|
||||
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)
|
||||
|
||||
// Expands to the chain of __CALL_IF branches for one
// (NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) configuration on the zero-point
// path: HAS_ACT_ORDER = false and HAS_ZP = true, for THREAD_M_BLOCKS 1..4 and
// GROUP_BLOCKS in {-1, 2, 4, 8}.
// Each __CALL_IF expands to an `else if`, so this macro must follow an `if`.
#define AWQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
|
||||
\
|
||||
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
|
||||
\
|
||||
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
|
||||
\
|
||||
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
|
||||
// Expands to the chain of __CALL_IF branches for one
// (W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) configuration on the zero-point
// path: HAS_ACT_ORDER = false and HAS_ZP = true, for THREAD_M_BLOCKS 1..4 and
// GROUP_BLOCKS in {-1, 2, 4, 8}.
// W_TYPE is a weight scalar type (invoked with vllm::kU4 and vllm::kU8).
// Each __CALL_IF expands to an `else if`, so this macro must follow an `if`.
#define AWQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
|
||||
\
|
||||
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
|
||||
\
|
||||
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
|
||||
\
|
||||
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
|
||||
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
|
||||
|
||||
template <typename scalar_t>
|
||||
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||
@@ -2113,23 +2092,23 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||
|
||||
if (false) {
|
||||
}
|
||||
GPTQ_CALL_IF(4, 16, 4, 256)
|
||||
GPTQ_CALL_IF(4, 8, 8, 256)
|
||||
GPTQ_CALL_IF(4, 8, 4, 128)
|
||||
GPTQ_CALL_IF(4, 4, 8, 128)
|
||||
GPTQ_CALL_IF(8, 16, 4, 256)
|
||||
GPTQ_CALL_IF(8, 8, 8, 256)
|
||||
GPTQ_CALL_IF(8, 8, 4, 128)
|
||||
GPTQ_CALL_IF(8, 4, 8, 128)
|
||||
GPTQ_CALL_IF(vllm::kU4B8, 16, 4, 256)
|
||||
GPTQ_CALL_IF(vllm::kU4B8, 8, 8, 256)
|
||||
GPTQ_CALL_IF(vllm::kU4B8, 8, 4, 128)
|
||||
GPTQ_CALL_IF(vllm::kU4B8, 4, 8, 128)
|
||||
GPTQ_CALL_IF(vllm::kU8B128, 16, 4, 256)
|
||||
GPTQ_CALL_IF(vllm::kU8B128, 8, 8, 256)
|
||||
GPTQ_CALL_IF(vllm::kU8B128, 8, 4, 128)
|
||||
GPTQ_CALL_IF(vllm::kU8B128, 4, 8, 128)
|
||||
|
||||
AWQ_CALL_IF(4, 16, 4, 256)
|
||||
AWQ_CALL_IF(4, 8, 8, 256)
|
||||
AWQ_CALL_IF(4, 8, 4, 128)
|
||||
AWQ_CALL_IF(4, 4, 8, 128)
|
||||
AWQ_CALL_IF(8, 16, 4, 256)
|
||||
AWQ_CALL_IF(8, 8, 8, 256)
|
||||
AWQ_CALL_IF(8, 8, 4, 128)
|
||||
AWQ_CALL_IF(8, 4, 8, 128)
|
||||
AWQ_CALL_IF(vllm::kU4, 16, 4, 256)
|
||||
AWQ_CALL_IF(vllm::kU4, 8, 8, 256)
|
||||
AWQ_CALL_IF(vllm::kU4, 8, 4, 128)
|
||||
AWQ_CALL_IF(vllm::kU4, 4, 8, 128)
|
||||
AWQ_CALL_IF(vllm::kU8, 16, 4, 256)
|
||||
AWQ_CALL_IF(vllm::kU8, 8, 8, 256)
|
||||
AWQ_CALL_IF(vllm::kU8, 8, 4, 128)
|
||||
AWQ_CALL_IF(vllm::kU8, 4, 8, 128)
|
||||
else {
|
||||
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
|
||||
", ", prob_k, "]", ", has_act_order = ", has_act_order,
|
||||
|
||||
Reference in New Issue
Block a user