[Kernel][Quantization] add w4a8 support for marlin kernel (#24722)
Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Michael Goin <mgoin@redhat.com>
This commit is contained in:
@@ -53,7 +53,7 @@ torch::Tensor gptq_marlin_gemm(
|
||||
std::optional<torch::Tensor> const& b_zeros_or_none,
|
||||
std::optional<torch::Tensor> const& g_idx_or_none,
|
||||
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
|
||||
vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n,
|
||||
vllm::ScalarTypeId const& b_type_id, int64_t size_m, int64_t size_n,
|
||||
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
|
||||
bool is_zp_float) {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false,
|
||||
@@ -243,204 +243,29 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
|
||||
int cache_size = get_kernel_cache_size(
|
||||
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size,
|
||||
has_act_order, is_k_full, has_zp, is_zp_float);
|
||||
return cache_size + 512 <= max_shared_mem;
|
||||
return cache_size <= max_shared_mem;
|
||||
}
|
||||
|
||||
#define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
||||
M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \
|
||||
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
|
||||
thread_n_blocks == THREAD_N_BLOCKS && \
|
||||
thread_k_blocks == THREAD_K_BLOCKS && \
|
||||
m_block_size_8 == M_BLOCK_SIZE_8 && \
|
||||
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
|
||||
is_zp_float == IS_ZP_FLOAT) { \
|
||||
constexpr auto S_TYPE = \
|
||||
W_TYPE == vllm::kFE2M1f \
|
||||
? (GROUP_BLOCKS == 1 ? vllm::kFE4M3fn : vllm::kFE8M0fnu) \
|
||||
: (std::is_same<scalar_t, half>::value ? vllm::kFloat16 \
|
||||
: vllm::kBFloat16); \
|
||||
kernel = Marlin<scalar_t, W_TYPE.id(), S_TYPE.id(), NUM_THREADS, \
|
||||
THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
||||
M_BLOCK_SIZE_8, pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
|
||||
}
|
||||
|
||||
// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false)
|
||||
// this is the most common cases
|
||||
// BIGGROUP: cases for big group size (group_blocks in [-1, 8])
|
||||
// FZP: cases for float-zero-point (is_zp_float = true)
|
||||
// ACT: cases for act order case (group_blocks == 0)
|
||||
// FP4: cases for nvfp4(e2m1) (group_blocks == 1)
|
||||
#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
|
||||
|
||||
#define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
|
||||
\
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
|
||||
\
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
|
||||
|
||||
#define COMMON_GET_IF(W_TYPE) \
|
||||
COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||
COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||
COMMON_GET_IF_M1(W_TYPE, 4, 8, 128) \
|
||||
COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||
COMMON_GET_IF_M234(W_TYPE, 8, 4, 128) \
|
||||
COMMON_GET_IF_M234(W_TYPE, 4, 8, 128)
|
||||
|
||||
#define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
|
||||
|
||||
#define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
|
||||
|
||||
#define BIGGROUP_GET_IF(W_TYPE) \
|
||||
BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||
BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||
BIGGROUP_GET_IF_M1(W_TYPE, 4, 8, 128) \
|
||||
BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \
|
||||
BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128)
|
||||
|
||||
#define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
|
||||
|
||||
#define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
|
||||
|
||||
#define NVFP4_GET_IF(W_TYPE) \
|
||||
NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||
NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||
NVFP4_GET_IF_M1(W_TYPE, 4, 8, 128) \
|
||||
NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||
NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128) \
|
||||
NVFP4_GET_IF_M234(W_TYPE, 4, 8, 128)
|
||||
|
||||
#define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
|
||||
|
||||
#define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
|
||||
|
||||
#define MXFP4_GET_IF(W_TYPE) \
|
||||
MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||
MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||
MXFP4_GET_IF_M1(W_TYPE, 4, 8, 128) \
|
||||
MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||
MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128) \
|
||||
MXFP4_GET_IF_M234(W_TYPE, 4, 8, 128)
|
||||
|
||||
// We currently have 4-bit models only with group_blocks == 4
|
||||
#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
|
||||
|
||||
#define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
|
||||
|
||||
#define FZP_GET_IF(W_TYPE) \
|
||||
FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||
FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||
FZP_GET_IF_M1(W_TYPE, 4, 8, 128) \
|
||||
FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||
FZP_GET_IF_M234(W_TYPE, 8, 4, 128) \
|
||||
FZP_GET_IF_M234(W_TYPE, 4, 8, 128)
|
||||
|
||||
// We currently have 4-bit models only with group_blocks == 4
|
||||
#define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
|
||||
|
||||
#define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
|
||||
|
||||
#define ACT_GET_IF(W_TYPE) \
|
||||
ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||
ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||
ACT_GET_IF_M1(W_TYPE, 4, 8, 128) \
|
||||
ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||
ACT_GET_IF_M234(W_TYPE, 8, 4, 128) \
|
||||
ACT_GET_IF_M234(W_TYPE, 4, 8, 128)
|
||||
|
||||
template <typename scalar_t>
|
||||
MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
|
||||
int thread_m_blocks, int thread_n_blocks,
|
||||
int thread_k_blocks, bool m_block_size_8,
|
||||
bool has_act_order, bool has_zp,
|
||||
int group_blocks, int num_threads,
|
||||
bool is_zp_float) {
|
||||
int num_bits = q_type.size_bits();
|
||||
MarlinFuncPtr get_marlin_kernel(
|
||||
const vllm::ScalarType a_type, const vllm::ScalarType b_type,
|
||||
const vllm::ScalarType c_type, const vllm::ScalarType s_type,
|
||||
int thread_m_blocks, int thread_n_blocks, int thread_k_blocks,
|
||||
bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks,
|
||||
int threads, bool is_zp_float) {
|
||||
int num_bits = b_type.size_bits();
|
||||
auto kernel = MarlinDefault;
|
||||
if (false) {
|
||||
}
|
||||
|
||||
COMMON_GET_IF(vllm::kU4)
|
||||
COMMON_GET_IF(vllm::kU4B8)
|
||||
COMMON_GET_IF(vllm::kU8B128)
|
||||
|
||||
NVFP4_GET_IF(vllm::kFE2M1f)
|
||||
|
||||
BIGGROUP_GET_IF(vllm::kFE4M3fn)
|
||||
|
||||
ACT_GET_IF(vllm::kU4B8)
|
||||
ACT_GET_IF(vllm::kU8B128)
|
||||
|
||||
if (std::is_same<scalar_t, half>::value) {
|
||||
if (false) {
|
||||
}
|
||||
FZP_GET_IF(vllm::kU4)
|
||||
}
|
||||
if (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
if (false) {
|
||||
}
|
||||
MXFP4_GET_IF(vllm::kFE2M1f)
|
||||
}
|
||||
#include "kernel_selector.h"
|
||||
|
||||
return kernel;
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
|
||||
int prob_n, int prob_k, int thread_m_blocks,
|
||||
bool m_block_size_8, int num_bits,
|
||||
int group_size, bool has_act_order,
|
||||
bool is_k_full, bool has_zp,
|
||||
bool is_zp_float, int max_shared_mem,
|
||||
int sms) {
|
||||
exec_config_t determine_exec_config(
|
||||
const vllm::ScalarType& a_type, const vllm::ScalarType& b_type,
|
||||
const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m,
|
||||
int prob_n, int prob_k, int thread_m_blocks, bool m_block_size_8,
|
||||
int num_bits, int group_size, bool has_act_order, bool is_k_full,
|
||||
bool has_zp, bool is_zp_float, int max_shared_mem, int sms) {
|
||||
exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
|
||||
thread_config_t* thread_configs = thread_m_blocks > 1
|
||||
? large_batch_thread_configs
|
||||
@@ -455,7 +280,7 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
|
||||
|
||||
if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k,
|
||||
num_bits, group_size, has_act_order, is_k_full, has_zp,
|
||||
is_zp_float, max_shared_mem)) {
|
||||
is_zp_float, max_shared_mem - 512)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -468,10 +293,11 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
|
||||
group_blocks = group_size == -1 ? -1 : group_size / 16;
|
||||
}
|
||||
|
||||
auto kernel = get_marlin_kernel<scalar_t>(
|
||||
q_type, thread_m_blocks, th_config.thread_n / 16,
|
||||
th_config.thread_k / 16, m_block_size_8, has_act_order, has_zp,
|
||||
group_blocks, th_config.num_threads, is_zp_float);
|
||||
auto kernel =
|
||||
get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks,
|
||||
th_config.thread_n / 16, th_config.thread_k / 16,
|
||||
m_block_size_8, has_act_order, has_zp, group_blocks,
|
||||
th_config.num_threads, is_zp_float);
|
||||
|
||||
if (kernel == MarlinDefault) continue;
|
||||
|
||||
@@ -485,28 +311,16 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
|
||||
return exec_cfg;
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
void* s, void* s2, void* zp, void* g_idx, void* perm,
|
||||
void* a_tmp, int prob_m, int prob_n, int prob_k, int lda,
|
||||
void* workspace, vllm::ScalarType const& q_type, bool has_bias,
|
||||
void* a_s, void* b_s, void* g_s, void* zp, void* g_idx,
|
||||
void* perm, void* a_tmp, int prob_m, int prob_n, int prob_k,
|
||||
int lda, void* workspace, vllm::ScalarType const& a_type,
|
||||
vllm::ScalarType const& b_type, vllm::ScalarType const& c_type,
|
||||
vllm::ScalarType const& s_type, bool has_bias,
|
||||
bool has_act_order, bool is_k_full, bool has_zp, int num_groups,
|
||||
int group_size, int dev, cudaStream_t stream, int thread_k_init,
|
||||
int thread_n_init, int sms, bool use_atomic_add,
|
||||
bool use_fp32_reduce, bool is_zp_float) {
|
||||
if (has_zp) {
|
||||
TORCH_CHECK(
|
||||
q_type == vllm::kU4 || q_type == vllm::kU8,
|
||||
"q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
q_type == vllm::kU4B8 || q_type == vllm::kU8B128 ||
|
||||
q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f,
|
||||
"q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when "
|
||||
"has_zp = False. Got = ",
|
||||
q_type.str());
|
||||
}
|
||||
|
||||
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
|
||||
", ", prob_n, ", ", prob_k, "]");
|
||||
|
||||
@@ -531,19 +345,21 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
}
|
||||
}
|
||||
|
||||
int num_bits = q_type.size_bits();
|
||||
int num_bits = b_type.size_bits();
|
||||
const int4* A_ptr = (const int4*)A;
|
||||
const int4* B_ptr = (const int4*)B;
|
||||
int4* C_ptr = (int4*)C;
|
||||
int4* C_tmp_ptr = (int4*)C_tmp;
|
||||
|
||||
const int4* bias_ptr = (const int4*)b_bias;
|
||||
const int4* s_ptr = (const int4*)s;
|
||||
const uint16_t* s2_ptr = (const uint16_t*)s2;
|
||||
const float* a_s_ptr = (const float*)a_s;
|
||||
const int4* b_s_ptr = (const int4*)b_s;
|
||||
const uint16_t* g_s_ptr = (const uint16_t*)g_s;
|
||||
|
||||
const int4* zp_ptr = (const int4*)zp;
|
||||
const int* g_idx_ptr = (const int*)g_idx;
|
||||
const int* perm_ptr = (const int*)perm;
|
||||
int4* a_tmp_ptr = (int4*)a_tmp;
|
||||
|
||||
int* locks = (int*)workspace;
|
||||
|
||||
if (has_act_order) {
|
||||
@@ -568,6 +384,21 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
|
||||
TORCH_CHECK(max_shared_mem > 0);
|
||||
|
||||
int major_capability, minor_capability;
|
||||
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
|
||||
dev);
|
||||
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
|
||||
dev);
|
||||
TORCH_CHECK(major_capability * 10 + minor_capability >= 80,
|
||||
"marlin kernel only support Ampere or newer GPUs.");
|
||||
if (a_type == vllm::kFE4M3fn) {
|
||||
TORCH_CHECK(
|
||||
major_capability * 10 + minor_capability == 89 ||
|
||||
major_capability * 10 + minor_capability == 120,
|
||||
"Marlin W4A8-FP8 only support SM89 or SM120 device (It is slower than "
|
||||
"Marlin W4A16 on other devices).");
|
||||
}
|
||||
|
||||
int max_par = 16;
|
||||
if (prob_n <= 4096) max_par = 16 * 8;
|
||||
int max_shared_mem_new = max_shared_mem;
|
||||
@@ -583,7 +414,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
int thread_n = thread_n_init;
|
||||
|
||||
int thread_m_blocks = min(div_ceil(prob_m_split, 16), max_thread_m_blocks);
|
||||
int m_block_size_8 = prob_m_split <= 8;
|
||||
int m_block_size_8 = prob_m_split <= 8 && a_type.size_bits() == 16;
|
||||
|
||||
// Set thread config
|
||||
exec_config_t exec_cfg;
|
||||
@@ -597,11 +428,25 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
" is not divisible by thread_k = ", thread_k);
|
||||
} else {
|
||||
// Auto config
|
||||
exec_cfg = determine_exec_config<scalar_t>(
|
||||
q_type, prob_m_split, prob_n, prob_k, thread_m_blocks, m_block_size_8,
|
||||
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float,
|
||||
max_shared_mem, sms);
|
||||
exec_cfg = determine_exec_config(
|
||||
a_type, b_type, c_type, s_type, prob_m_split, prob_n, prob_k,
|
||||
thread_m_blocks, m_block_size_8, num_bits, group_size, has_act_order,
|
||||
is_k_full, has_zp, is_zp_float, max_shared_mem, sms);
|
||||
thread_tfg = exec_cfg.tb_cfg;
|
||||
if (thread_tfg.thread_n != -1) {
|
||||
if (prob_n / thread_tfg.thread_n *
|
||||
div_ceil(prob_m_split, thread_m_blocks * 16) * 4 <=
|
||||
sms) {
|
||||
if (is_valid_config({128, 64, 128}, thread_m_blocks, prob_m_split,
|
||||
prob_n, prob_k, num_bits, group_size,
|
||||
has_act_order, is_k_full, has_zp, is_zp_float,
|
||||
max_shared_mem_new)) {
|
||||
thread_tfg = {128, 64, 128};
|
||||
exec_cfg = {1, thread_tfg};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (thread_tfg.thread_k == -1 && max_thread_m_blocks > 1) {
|
||||
max_thread_m_blocks--;
|
||||
continue;
|
||||
@@ -632,10 +477,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float,
|
||||
", max_shared_mem_new = ", max_shared_mem_new);
|
||||
|
||||
auto kernel = get_marlin_kernel<scalar_t>(
|
||||
q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks,
|
||||
m_block_size_8, has_act_order, has_zp, group_blocks, num_threads,
|
||||
is_zp_float);
|
||||
auto kernel = get_marlin_kernel(
|
||||
a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks,
|
||||
thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks,
|
||||
num_threads, is_zp_float);
|
||||
|
||||
if (kernel == MarlinDefault) {
|
||||
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
|
||||
@@ -657,13 +502,15 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
// avoid ">>>" being formatted to "> > >"
|
||||
// clang-format off
|
||||
kernel<<<blocks, num_threads, max_shared_mem_new, stream>>>(
|
||||
A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, s_ptr, s2_ptr, zp_ptr,
|
||||
A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, a_s_ptr, b_s_ptr, g_s_ptr, zp_ptr,
|
||||
g_idx_ptr, num_groups,
|
||||
prob_m_split, prob_n, prob_k, lda, locks, has_bias, part_use_atomic_add,
|
||||
use_fp32_reduce, max_shared_mem_new);
|
||||
// clang-format on
|
||||
|
||||
A_ptr += prob_m_split * (lda / 8);
|
||||
bool is_a_8bit = a_type.size_bits() == 8;
|
||||
A_ptr += prob_m_split * (lda / (is_a_8bit ? 16 : 8));
|
||||
a_s_ptr += prob_m_split;
|
||||
C_ptr += prob_m_split * (prob_n / 8);
|
||||
rest_m -= prob_m_split;
|
||||
}
|
||||
@@ -675,15 +522,73 @@ torch::Tensor gptq_marlin_gemm(
|
||||
torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
|
||||
torch::Tensor& b_q_weight,
|
||||
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
|
||||
std::optional<torch::Tensor> const& a_scales_or_none,
|
||||
std::optional<torch::Tensor> const& global_scale_or_none,
|
||||
std::optional<torch::Tensor> const& b_zeros_or_none,
|
||||
std::optional<torch::Tensor> const& g_idx_or_none,
|
||||
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
|
||||
vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n,
|
||||
vllm::ScalarTypeId const& b_type_id, int64_t size_m, int64_t size_n,
|
||||
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
|
||||
bool is_zp_float) {
|
||||
vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
|
||||
int pack_factor = 32 / b_q_type.size_bits();
|
||||
vllm::ScalarTypeId a_type_id, c_type_id, s_type_id;
|
||||
|
||||
auto c_dtype = a.dtype();
|
||||
if (a.scalar_type() == at::ScalarType::Half) {
|
||||
a_type_id = vllm::kFloat16.id();
|
||||
c_type_id = vllm::kFloat16.id();
|
||||
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
|
||||
a_type_id = vllm::kBFloat16.id();
|
||||
c_type_id = vllm::kBFloat16.id();
|
||||
} else {
|
||||
c_dtype = b_scales.dtype();
|
||||
if (b_scales.scalar_type() == at::ScalarType::Half) {
|
||||
c_type_id = vllm::kFloat16.id();
|
||||
} else if (b_scales.scalar_type() == at::ScalarType::BFloat16) {
|
||||
c_type_id = vllm::kBFloat16.id();
|
||||
} else {
|
||||
c_type_id = vllm::kBFloat16.id();
|
||||
|
||||
TORCH_CHECK(c_or_none.has_value(), "c must be passed for W4A8-FP4");
|
||||
torch::Tensor c = c_or_none.value();
|
||||
c_dtype = c.dtype();
|
||||
|
||||
if (c.scalar_type() == at::ScalarType::Half) {
|
||||
c_type_id = vllm::kFloat16.id();
|
||||
} else if (c.scalar_type() == at::ScalarType::BFloat16) {
|
||||
c_type_id = vllm::kBFloat16.id();
|
||||
} else {
|
||||
TORCH_CHECK(false, "unsupported c dtype");
|
||||
}
|
||||
}
|
||||
|
||||
if (a.scalar_type() == at::ScalarType::Float8_e4m3fn) {
|
||||
a_type_id = vllm::kFE4M3fn.id();
|
||||
} else if (a.scalar_type() == at::ScalarType::Char) {
|
||||
a_type_id = vllm::kS8.id();
|
||||
} else {
|
||||
TORCH_CHECK(false, "unsupported `a` scalar_type");
|
||||
}
|
||||
}
|
||||
|
||||
s_type_id = c_type_id;
|
||||
if (b_type_id == vllm::kFE2M1f.id()) {
|
||||
if (b_scales.scalar_type() == at::ScalarType::Float8_e4m3fn) {
|
||||
s_type_id = vllm::kFE4M3fn.id();
|
||||
} else if (b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) {
|
||||
s_type_id = vllm::kFE8M0fnu.id();
|
||||
} else {
|
||||
TORCH_CHECK(false,
|
||||
"When b_type = float4_e2m1f, b_scale scalar type must be",
|
||||
"float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4).");
|
||||
}
|
||||
}
|
||||
|
||||
vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id);
|
||||
vllm::ScalarType b_type = vllm::ScalarType::from_id(b_type_id);
|
||||
vllm::ScalarType c_type = vllm::ScalarType::from_id(c_type_id);
|
||||
vllm::ScalarType s_type = vllm::ScalarType::from_id(s_type_id);
|
||||
|
||||
int pack_factor = 32 / b_type.size_bits();
|
||||
|
||||
// Verify A
|
||||
TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
|
||||
@@ -721,6 +626,21 @@ torch::Tensor gptq_marlin_gemm(
|
||||
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
|
||||
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
|
||||
|
||||
torch::Tensor a_scales;
|
||||
auto options = torch::TensorOptions().dtype(c_dtype).device(a.device());
|
||||
auto options_fp32 =
|
||||
torch::TensorOptions().dtype(at::kFloat).device(a.device());
|
||||
|
||||
if (a_scales_or_none.has_value()) {
|
||||
a_scales = a_scales_or_none.value();
|
||||
TORCH_CHECK(a_type.size_bits() == 8,
|
||||
"a_scales can only be used for 8bit activation.");
|
||||
} else {
|
||||
a_scales = torch::empty({0}, options_fp32);
|
||||
TORCH_CHECK(a_type.size_bits() != 8,
|
||||
"the a_scales parameter must be passed for 8bit activation.");
|
||||
}
|
||||
|
||||
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
|
||||
// auto -1)
|
||||
int thread_k = -1;
|
||||
@@ -733,7 +653,6 @@ torch::Tensor gptq_marlin_gemm(
|
||||
|
||||
// Alloc buffers
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
|
||||
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
|
||||
torch::Tensor c;
|
||||
if (c_or_none.has_value()) {
|
||||
c = c_or_none.value();
|
||||
@@ -750,8 +669,6 @@ torch::Tensor gptq_marlin_gemm(
|
||||
|
||||
// Alloc C tmp buffer that is going to be used for the global reduce
|
||||
torch::Tensor c_tmp;
|
||||
auto options_fp32 =
|
||||
torch::TensorOptions().dtype(at::kFloat).device(a.device());
|
||||
if (use_fp32_reduce) {
|
||||
int max_m_block_size = (size_m + 16 - 1) / 16 * 16;
|
||||
max_m_block_size = min(max_m_block_size, 64);
|
||||
@@ -821,11 +738,11 @@ torch::Tensor gptq_marlin_gemm(
|
||||
torch::Tensor global_scale;
|
||||
if (global_scale_or_none.has_value()) {
|
||||
global_scale = global_scale_or_none.value();
|
||||
TORCH_CHECK(b_q_type == vllm::kFE2M1f && group_size == 16,
|
||||
TORCH_CHECK(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn,
|
||||
"global_scale can only be used for nvfp4 format.");
|
||||
} else {
|
||||
global_scale = torch::empty({0}, options);
|
||||
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f && group_size == 16),
|
||||
TORCH_CHECK(!(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn),
|
||||
"the global_scale parameter must be passed for nvfp4 format.");
|
||||
}
|
||||
|
||||
@@ -852,15 +769,15 @@ torch::Tensor gptq_marlin_gemm(
|
||||
bool has_zp = b_zeros.size(-1) > 0;
|
||||
if (has_zp) {
|
||||
TORCH_CHECK(
|
||||
b_q_type == vllm::kU4 || b_q_type == vllm::kU8,
|
||||
"b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str());
|
||||
b_type == vllm::kU4 || b_type == vllm::kU8,
|
||||
"b_type must be u4 or u8 when has_zp = True. Got = ", b_type.str());
|
||||
} else {
|
||||
TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 ||
|
||||
b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f,
|
||||
"b_q_type must be uint4b8, uint8b128, float8_e4m3fn or "
|
||||
"float4_e2m1f when "
|
||||
"has_zp = False. Got = ",
|
||||
b_q_type.str());
|
||||
TORCH_CHECK(b_type == vllm::kU4B8 || b_type == vllm::kU8B128 ||
|
||||
b_type == vllm::kS4 || b_type == vllm::kS8 ||
|
||||
b_type == vllm::kFE4M3fn || b_type == vllm::kFE2M1f,
|
||||
"b_type must be uint4b8, uint8b128, int4, int8, "
|
||||
"float8_e4m3fn or float4_e2m1f when has_zp = False. Got = ",
|
||||
b_type.str());
|
||||
}
|
||||
|
||||
if (has_zp && is_zp_float) {
|
||||
@@ -902,59 +819,27 @@ torch::Tensor gptq_marlin_gemm(
|
||||
" is below min_workspace_size = ", min_workspace_size);
|
||||
|
||||
int dev = a.get_device();
|
||||
if (a.scalar_type() == at::ScalarType::Half) {
|
||||
void* scales_ptr;
|
||||
if (b_q_type == vllm::kFE2M1f) {
|
||||
if (group_size == 16)
|
||||
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
|
||||
else if (group_size == 32)
|
||||
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
|
||||
else
|
||||
TORCH_CHECK(false,
|
||||
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
|
||||
"and group_size == 32 (MXFP4)");
|
||||
} else {
|
||||
scales_ptr = b_scales.data_ptr<at::Half>();
|
||||
}
|
||||
|
||||
marlin::marlin_mm<half>(
|
||||
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
|
||||
c_tmp.data_ptr<float>(), b_bias.data_ptr<at::Half>(), scales_ptr,
|
||||
global_scale.data_ptr<at::Half>(), b_zeros.data_ptr(), g_idx.data_ptr(),
|
||||
perm.data_ptr(), a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
|
||||
a.stride(0), workspace.data_ptr(), b_q_type, has_bias, has_act_order,
|
||||
is_k_full, has_zp, num_groups, group_size, dev,
|
||||
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
|
||||
use_atomic_add, use_fp32_reduce, is_zp_float);
|
||||
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
|
||||
void* scales_ptr;
|
||||
if (b_q_type == vllm::kFE2M1f) {
|
||||
if (group_size == 16)
|
||||
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
|
||||
else if (group_size == 32)
|
||||
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
|
||||
else
|
||||
TORCH_CHECK(false,
|
||||
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
|
||||
"and group_size == 32 (MXFP4)");
|
||||
} else {
|
||||
scales_ptr = b_scales.data_ptr<at::BFloat16>();
|
||||
}
|
||||
|
||||
marlin::marlin_mm<nv_bfloat16>(
|
||||
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
|
||||
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
|
||||
b_bias.data_ptr<at::BFloat16>(), scales_ptr,
|
||||
global_scale.data_ptr<at::BFloat16>(), b_zeros.data_ptr(),
|
||||
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
|
||||
size_m, size_n, size_k, a.stride(0), workspace.data_ptr(), b_q_type,
|
||||
has_bias, has_act_order, is_k_full, has_zp, num_groups, group_size, dev,
|
||||
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
|
||||
use_atomic_add, use_fp32_reduce, is_zp_float);
|
||||
} else {
|
||||
TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
|
||||
TORCH_CHECK(a_scales.scalar_type() == at::ScalarType::Float,
|
||||
"scalar type of a_scales must be float");
|
||||
TORCH_CHECK(global_scale.scalar_type() == c.scalar_type(),
|
||||
"scalar type of global_scale must be the same with c");
|
||||
if (a_type.size_bits() == 16) {
|
||||
TORCH_CHECK(
|
||||
a.scalar_type() == c.scalar_type(),
|
||||
"scalar type of a must be the same with c for 16 bit activation");
|
||||
}
|
||||
|
||||
marlin::marlin_mm(
|
||||
a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), c_tmp.data_ptr(),
|
||||
b_bias.data_ptr(), a_scales.data_ptr(), b_scales.data_ptr(),
|
||||
global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(),
|
||||
perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, a.stride(0),
|
||||
workspace.data_ptr(), a_type, b_type, c_type, s_type, has_bias,
|
||||
has_act_order, is_k_full, has_zp, num_groups, group_size, dev,
|
||||
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
|
||||
use_atomic_add, use_fp32_reduce, is_zp_float);
|
||||
|
||||
return c;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user