Merge branch 'main' into wye-refactor-quant-folder
This commit is contained in:
@@ -470,11 +470,12 @@ __device__ inline void dequant<nv_bfloat162, vllm::kFE2M1f.id(), false>(
|
||||
frag_b[0] = __hmul2(frag_b[0], bias_reg);
|
||||
}
|
||||
|
||||
template <typename scalar_t2>
|
||||
template <typename scalar_t2, vllm::ScalarTypeId s_type_id>
|
||||
__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b);
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant_fp8_scales<half2>(int q, half2* frag_b) {
|
||||
__device__ inline void dequant_fp8_scales<half2, vllm::kFE4M3fn.id()>(
|
||||
int q, half2* frag_b) {
|
||||
int Out1 = (q & 0xFF00FF00) >> 1;
|
||||
;
|
||||
q <<= 8;
|
||||
@@ -486,8 +487,8 @@ __device__ inline void dequant_fp8_scales<half2>(int q, half2* frag_b) {
|
||||
};
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant_fp8_scales<nv_bfloat162>(int q,
|
||||
nv_bfloat162* frag_b) {
|
||||
__device__ inline void dequant_fp8_scales<nv_bfloat162, vllm::kFE4M3fn.id()>(
|
||||
int q, nv_bfloat162* frag_b) {
|
||||
constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
|
||||
constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
|
||||
constexpr int MASK = 0x7F007F00;
|
||||
@@ -502,6 +503,20 @@ __device__ inline void dequant_fp8_scales<nv_bfloat162>(int q,
|
||||
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant_fp8_scales<nv_bfloat162, vllm::kFE8M0fnu.id()>(
|
||||
int q, nv_bfloat162* frag_b) {
|
||||
// In this conversion, 2 ** -127 in FP8E8M0 would become 0 in BF16,
|
||||
// but we assume that such a extreme value would not occur in real models.
|
||||
int Out1 = (q & 0xFF00FF00) >> 1;
|
||||
q <<= 7;
|
||||
int Out2 = q & 0x7F807F80;
|
||||
|
||||
// Note: reverse indexing is intentional because weights are permuted
|
||||
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
|
||||
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
|
||||
@@ -20,6 +20,7 @@ namespace MARLIN_NAMESPACE_NAME {
|
||||
TEMPLATE = ("template __global__ void Marlin<"
|
||||
"{{scalar_t}}, "
|
||||
"{{w_type_id}}, "
|
||||
"{{s_type_id}}, "
|
||||
"{{threads}}, "
|
||||
"{{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, "
|
||||
@@ -78,7 +79,8 @@ def generate_new_kernels():
|
||||
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
|
||||
continue
|
||||
# nvfp4 only supports group_size == 16
|
||||
if scalar_type == "vllm::kFE2M1f" and group_blocks != 1:
|
||||
# mxfp4 only supports group_size == 32
|
||||
if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]:
|
||||
continue
|
||||
# other quantization methods don't support group_size = 16
|
||||
if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
|
||||
@@ -97,10 +99,23 @@ def generate_new_kernels():
|
||||
# 4bit quantization and fp16
|
||||
is_zp_float_list.append(True)
|
||||
|
||||
if scalar_type == "vllm::kFE2M1f" and group_blocks == 1:
|
||||
s_type = "vllm::kFE4M3fn"
|
||||
elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2:
|
||||
s_type = "vllm::kFE8M0fnu"
|
||||
if dtype == "fp16":
|
||||
# we cannot safely dequantize e8m0 to fp16, so skip this
|
||||
continue
|
||||
elif dtype == "fp16":
|
||||
s_type = "vllm::kFloat16"
|
||||
elif dtype == "bf16":
|
||||
s_type = "vllm::kBFloat16"
|
||||
|
||||
for is_zp_float in is_zp_float_list:
|
||||
template_str = jinja2.Template(TEMPLATE).render(
|
||||
scalar_t=c_dtype,
|
||||
w_type_id=scalar_type + ".id()",
|
||||
s_type_id=s_type + ".id()",
|
||||
threads=threads,
|
||||
thread_m_blocks=max(m_blocks, 1),
|
||||
thread_n_blocks=n_blocks,
|
||||
|
||||
@@ -48,7 +48,8 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
|
||||
|
||||
torch::Tensor gptq_marlin_gemm(
|
||||
torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
|
||||
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
|
||||
torch::Tensor& b_q_weight,
|
||||
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
|
||||
std::optional<torch::Tensor> const& b_zeros_or_none,
|
||||
std::optional<torch::Tensor> const& g_idx_or_none,
|
||||
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
|
||||
@@ -187,7 +188,12 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
|
||||
int tb_m = thread_m_blocks * 16;
|
||||
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
|
||||
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
|
||||
int sh_red_size = tb_m * (tb_n + 8);
|
||||
int sh_red_size = tb_m * (tb_n + 8) * 2;
|
||||
int sh_bias_size = tb_n * 2;
|
||||
int tmp_size =
|
||||
(sh_b_size > sh_red_size ? sh_red_size : sh_b_size) + sh_bias_size;
|
||||
tmp_size = max(max(sh_b_size, sh_red_size), tmp_size);
|
||||
|
||||
int sh_s_size =
|
||||
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
|
||||
group_size, has_act_order, is_k_full);
|
||||
@@ -202,8 +208,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
|
||||
sh_zp_size = sh_s_size / 2;
|
||||
}
|
||||
|
||||
int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size +
|
||||
sh_zp_size + sh_g_idx_size;
|
||||
int total_size =
|
||||
tmp_size + sh_a_size + sh_s_size + sh_zp_size + sh_g_idx_size;
|
||||
|
||||
return total_size;
|
||||
}
|
||||
@@ -237,20 +243,25 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
|
||||
int cache_size = get_kernel_cache_size(
|
||||
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size,
|
||||
has_act_order, is_k_full, has_zp, is_zp_float);
|
||||
return cache_size <= max_shared_mem;
|
||||
return cache_size + 512 <= max_shared_mem;
|
||||
}
|
||||
|
||||
#define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
||||
M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \
|
||||
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
|
||||
thread_n_blocks == THREAD_N_BLOCKS && \
|
||||
thread_k_blocks == THREAD_K_BLOCKS && \
|
||||
m_block_size_8 == M_BLOCK_SIZE_8 && \
|
||||
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
|
||||
is_zp_float == IS_ZP_FLOAT) { \
|
||||
kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
|
||||
THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, \
|
||||
pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
|
||||
#define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
||||
M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \
|
||||
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
|
||||
thread_n_blocks == THREAD_N_BLOCKS && \
|
||||
thread_k_blocks == THREAD_K_BLOCKS && \
|
||||
m_block_size_8 == M_BLOCK_SIZE_8 && \
|
||||
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
|
||||
is_zp_float == IS_ZP_FLOAT) { \
|
||||
constexpr auto S_TYPE = \
|
||||
W_TYPE == vllm::kFE2M1f \
|
||||
? (GROUP_BLOCKS == 1 ? vllm::kFE4M3fn : vllm::kFE8M0fnu) \
|
||||
: (std::is_same<scalar_t, half>::value ? vllm::kFloat16 \
|
||||
: vllm::kBFloat16); \
|
||||
kernel = Marlin<scalar_t, W_TYPE.id(), S_TYPE.id(), NUM_THREADS, \
|
||||
THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
||||
M_BLOCK_SIZE_8, pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
|
||||
}
|
||||
|
||||
// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false)
|
||||
@@ -315,22 +326,39 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
|
||||
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \
|
||||
BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128)
|
||||
|
||||
#define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
#define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
|
||||
|
||||
#define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
#define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
|
||||
|
||||
#define FP4_GET_IF(W_TYPE) \
|
||||
FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||
FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||
FP4_GET_IF_M1(W_TYPE, 4, 8, 128) \
|
||||
FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||
FP4_GET_IF_M234(W_TYPE, 8, 4, 128) \
|
||||
FP4_GET_IF_M234(W_TYPE, 4, 8, 128)
|
||||
#define NVFP4_GET_IF(W_TYPE) \
|
||||
NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||
NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||
NVFP4_GET_IF_M1(W_TYPE, 4, 8, 128) \
|
||||
NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||
NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128) \
|
||||
NVFP4_GET_IF_M234(W_TYPE, 4, 8, 128)
|
||||
|
||||
#define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
|
||||
|
||||
#define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
|
||||
|
||||
#define MXFP4_GET_IF(W_TYPE) \
|
||||
MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||
MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||
MXFP4_GET_IF_M1(W_TYPE, 4, 8, 128) \
|
||||
MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||
MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128) \
|
||||
MXFP4_GET_IF_M234(W_TYPE, 4, 8, 128)
|
||||
|
||||
// We currently have 4-bit models only with group_blocks == 4
|
||||
#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
@@ -384,7 +412,7 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
|
||||
COMMON_GET_IF(vllm::kU4B8)
|
||||
COMMON_GET_IF(vllm::kU8B128)
|
||||
|
||||
FP4_GET_IF(vllm::kFE2M1f)
|
||||
NVFP4_GET_IF(vllm::kFE2M1f)
|
||||
|
||||
BIGGROUP_GET_IF(vllm::kFE4M3fn)
|
||||
|
||||
@@ -396,6 +424,11 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
|
||||
}
|
||||
FZP_GET_IF(vllm::kU4)
|
||||
}
|
||||
if (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
if (false) {
|
||||
}
|
||||
MXFP4_GET_IF(vllm::kFE2M1f)
|
||||
}
|
||||
|
||||
return kernel;
|
||||
}
|
||||
@@ -453,12 +486,12 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||
void* s2, void* zp, void* g_idx, void* perm, void* a_tmp,
|
||||
int prob_m, int prob_n, int prob_k, int lda, void* workspace,
|
||||
vllm::ScalarType const& q_type, bool has_act_order,
|
||||
bool is_k_full, bool has_zp, int num_groups, int group_size,
|
||||
int dev, cudaStream_t stream, int thread_k_init,
|
||||
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
void* s, void* s2, void* zp, void* g_idx, void* perm,
|
||||
void* a_tmp, int prob_m, int prob_n, int prob_k, int lda,
|
||||
void* workspace, vllm::ScalarType const& q_type, bool has_bias,
|
||||
bool has_act_order, bool is_k_full, bool has_zp, int num_groups,
|
||||
int group_size, int dev, cudaStream_t stream, int thread_k_init,
|
||||
int thread_n_init, int sms, bool use_atomic_add,
|
||||
bool use_fp32_reduce, bool is_zp_float) {
|
||||
if (has_zp) {
|
||||
@@ -503,6 +536,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||
const int4* B_ptr = (const int4*)B;
|
||||
int4* C_ptr = (int4*)C;
|
||||
int4* C_tmp_ptr = (int4*)C_tmp;
|
||||
const int4* bias_ptr = (const int4*)b_bias;
|
||||
const int4* s_ptr = (const int4*)s;
|
||||
const uint16_t* s2_ptr = (const uint16_t*)s2;
|
||||
const int4* zp_ptr = (const int4*)zp;
|
||||
@@ -623,8 +657,9 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||
// avoid ">>>" being formatted to "> > >"
|
||||
// clang-format off
|
||||
kernel<<<blocks, num_threads, max_shared_mem_new, stream>>>(
|
||||
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr, num_groups,
|
||||
prob_m_split, prob_n, prob_k, lda, locks, part_use_atomic_add,
|
||||
A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, s_ptr, s2_ptr, zp_ptr,
|
||||
g_idx_ptr, num_groups,
|
||||
prob_m_split, prob_n, prob_k, lda, locks, has_bias, part_use_atomic_add,
|
||||
use_fp32_reduce, max_shared_mem_new);
|
||||
// clang-format on
|
||||
|
||||
@@ -638,7 +673,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||
|
||||
torch::Tensor gptq_marlin_gemm(
|
||||
torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
|
||||
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
|
||||
torch::Tensor& b_q_weight,
|
||||
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
|
||||
std::optional<torch::Tensor> const& global_scale_or_none,
|
||||
std::optional<torch::Tensor> const& b_zeros_or_none,
|
||||
std::optional<torch::Tensor> const& g_idx_or_none,
|
||||
@@ -785,12 +821,24 @@ torch::Tensor gptq_marlin_gemm(
|
||||
torch::Tensor global_scale;
|
||||
if (global_scale_or_none.has_value()) {
|
||||
global_scale = global_scale_or_none.value();
|
||||
TORCH_CHECK(b_q_type == vllm::kFE2M1f,
|
||||
"global_scale can only be used for float4_e2m1f.");
|
||||
TORCH_CHECK(b_q_type == vllm::kFE2M1f && group_size == 16,
|
||||
"global_scale can only be used for nvfp4 format.");
|
||||
} else {
|
||||
global_scale = torch::empty({0}, options);
|
||||
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f),
|
||||
"the global_scale parameter must be passed for float4_e2m1f.");
|
||||
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f && group_size == 16),
|
||||
"the global_scale parameter must be passed for nvfp4 format.");
|
||||
}
|
||||
|
||||
bool has_bias = b_bias_or_none.has_value();
|
||||
torch::Tensor b_bias;
|
||||
if (has_bias) {
|
||||
b_bias = b_bias_or_none.value();
|
||||
TORCH_CHECK(b_bias.device().is_cuda(), "b_bias is not on GPU");
|
||||
TORCH_CHECK(b_bias.is_contiguous(), "b_bias is not contiguous");
|
||||
TORCH_CHECK(b_bias.size(0) == size_n, "b_bias.size(0) != size_n");
|
||||
TORCH_CHECK(b_bias.stride(0) == 1, "b_bias.stride(0) != 1");
|
||||
} else {
|
||||
b_bias = torch::empty({0}, options);
|
||||
}
|
||||
|
||||
torch::Tensor b_zeros;
|
||||
@@ -857,34 +905,50 @@ torch::Tensor gptq_marlin_gemm(
|
||||
if (a.scalar_type() == at::ScalarType::Half) {
|
||||
void* scales_ptr;
|
||||
if (b_q_type == vllm::kFE2M1f) {
|
||||
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
|
||||
if (group_size == 16)
|
||||
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
|
||||
else if (group_size == 32)
|
||||
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
|
||||
else
|
||||
TORCH_CHECK(false,
|
||||
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
|
||||
"and group_size == 32 (MXFP4)");
|
||||
} else {
|
||||
scales_ptr = b_scales.data_ptr<at::Half>();
|
||||
}
|
||||
|
||||
marlin::marlin_mm<half>(
|
||||
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
|
||||
c_tmp.data_ptr<float>(), scales_ptr, global_scale.data_ptr<at::Half>(),
|
||||
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
|
||||
a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k, a.stride(0),
|
||||
workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
|
||||
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
|
||||
thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float);
|
||||
c_tmp.data_ptr<float>(), b_bias.data_ptr<at::Half>(), scales_ptr,
|
||||
global_scale.data_ptr<at::Half>(), b_zeros.data_ptr(), g_idx.data_ptr(),
|
||||
perm.data_ptr(), a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
|
||||
a.stride(0), workspace.data_ptr(), b_q_type, has_bias, has_act_order,
|
||||
is_k_full, has_zp, num_groups, group_size, dev,
|
||||
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
|
||||
use_atomic_add, use_fp32_reduce, is_zp_float);
|
||||
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
|
||||
void* scales_ptr;
|
||||
if (b_q_type == vllm::kFE2M1f) {
|
||||
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
|
||||
if (group_size == 16)
|
||||
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
|
||||
else if (group_size == 32)
|
||||
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
|
||||
else
|
||||
TORCH_CHECK(false,
|
||||
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
|
||||
"and group_size == 32 (MXFP4)");
|
||||
} else {
|
||||
scales_ptr = b_scales.data_ptr<at::BFloat16>();
|
||||
}
|
||||
|
||||
marlin::marlin_mm<nv_bfloat16>(
|
||||
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
|
||||
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(), scales_ptr,
|
||||
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
|
||||
b_bias.data_ptr<at::BFloat16>(), scales_ptr,
|
||||
global_scale.data_ptr<at::BFloat16>(), b_zeros.data_ptr(),
|
||||
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
|
||||
size_m, size_n, size_k, a.stride(0), workspace.data_ptr(), b_q_type,
|
||||
has_act_order, is_k_full, has_zp, num_groups, group_size, dev,
|
||||
has_bias, has_act_order, is_k_full, has_zp, num_groups, group_size, dev,
|
||||
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
|
||||
use_atomic_add, use_fp32_reduce, is_zp_float);
|
||||
} else {
|
||||
|
||||
@@ -10,15 +10,18 @@
|
||||
#define MARLIN_KERNEL_PARAMS \
|
||||
const int4 *__restrict__ A, const int4 *__restrict__ B, \
|
||||
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
|
||||
const int4 *__restrict__ b_bias_ptr, \
|
||||
const int4 *__restrict__ scales_ptr, \
|
||||
const uint16_t *__restrict__ scale2_ptr, \
|
||||
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
|
||||
int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \
|
||||
bool use_atomic_add, bool use_fp32_reduce, int max_shared_mem
|
||||
bool has_bias, bool use_atomic_add, bool use_fp32_reduce, \
|
||||
int max_shared_mem
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
template <typename scalar_t, // compute dtype, half or nv_float16
|
||||
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
||||
const vllm::ScalarTypeId s_type_id, // weight ScalarType id
|
||||
const int threads, // number of threads in a threadblock
|
||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||
// dimension (batchsize) of the
|
||||
|
||||
@@ -39,6 +39,7 @@ namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
template <typename scalar_t, // compute dtype, half or nv_float16
|
||||
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
||||
const vllm::ScalarTypeId s_type_id, // weight scale ScalarType id
|
||||
const int threads, // number of threads in a threadblock
|
||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||
// dimension (batchsize) of the
|
||||
@@ -271,6 +272,7 @@ __device__ inline void wait_negative_and_add(int* lock) {
|
||||
|
||||
template <typename scalar_t, // compute dtype, half or nv_float16
|
||||
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
||||
const vllm::ScalarTypeId s_type_id, // weight scale ScalarType id
|
||||
const int threads, // number of threads in a threadblock
|
||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||
// dimension (batchsize) of the
|
||||
@@ -290,6 +292,7 @@ __global__ void Marlin(
|
||||
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
|
||||
int4* __restrict__ C, // fp16 output buffer of shape mxn
|
||||
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
|
||||
const int4* __restrict__ b_bias_ptr,
|
||||
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
|
||||
// (k/groupsize)xn
|
||||
const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4
|
||||
@@ -297,12 +300,13 @@ __global__ void Marlin(
|
||||
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
|
||||
// (k/groupsize)x(n/pack_factor)
|
||||
const int* __restrict__ g_idx, // int32 group indices of shape k
|
||||
int num_groups, // number of scale groups per output channel
|
||||
int prob_m, // batch dimension m
|
||||
int prob_n, // output dimension n
|
||||
int prob_k, // reduction dimension k
|
||||
int lda, // A.stride(0), equal to prob_k is A is contiguous
|
||||
int* locks, // extra global storage for barrier synchronization
|
||||
int num_groups, // number of scale groups per output channel
|
||||
int prob_m, // batch dimension m
|
||||
int prob_n, // output dimension n
|
||||
int prob_k, // reduction dimension k
|
||||
int lda, // A.stride(0), equal to prob_k is A is contiguous
|
||||
int* locks, // extra global storage for barrier synchronization
|
||||
bool has_bias,
|
||||
bool use_atomic_add, // whether to use atomic add to reduce
|
||||
bool use_fp32_reduce, // whether to use fp32 global reduce
|
||||
int max_shared_mem) {
|
||||
@@ -326,18 +330,29 @@ __global__ void Marlin(
|
||||
using FragZP = typename ScalarType<scalar_t>::FragZP;
|
||||
|
||||
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
|
||||
static constexpr auto s_type = vllm::ScalarType::from_id(s_type_id);
|
||||
if constexpr (w_type == vllm::kFE2M1f) {
|
||||
static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 ||
|
||||
s_type == vllm::kFE8M0fnu && group_blocks == 2);
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
static_assert(s_type == vllm::kBFloat16);
|
||||
} else if constexpr (std::is_same<scalar_t, half>::value) {
|
||||
static_assert(s_type == vllm::kFloat16);
|
||||
}
|
||||
|
||||
constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8;
|
||||
constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 ||
|
||||
w_type == vllm::kU4B8 || w_type == vllm::kU8B128;
|
||||
// see comments of dequant.h for more details
|
||||
constexpr bool dequant_skip_flop =
|
||||
!is_int_type ||
|
||||
w_type == vllm::kFE4M3fn ||
|
||||
w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn ||
|
||||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
|
||||
has_zp && !is_zp_float && !(w_type == vllm::kU8);
|
||||
|
||||
scalar_t2 global_scale;
|
||||
|
||||
if constexpr (w_type == vllm::kFE2M1f) {
|
||||
if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
|
||||
// NVFP4 format requires global scale
|
||||
uint16_t val = scale2_ptr[0];
|
||||
global_scale = Dtype::num2num2(*reinterpret_cast<scalar_t*>(&val));
|
||||
}
|
||||
@@ -589,7 +604,7 @@ __global__ void Marlin(
|
||||
|
||||
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||
(threadIdx.x % 32) / 4;
|
||||
s_sh_rd = s_sh_rd * 2 + warp_row % 2;
|
||||
s_sh_rd = s_sh_rd * 2 + (warp_row / group_blocks) % 2;
|
||||
|
||||
} else if constexpr (group_blocks != -1)
|
||||
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||
@@ -602,6 +617,18 @@ __global__ void Marlin(
|
||||
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||
(threadIdx.x % 32) % 4;
|
||||
|
||||
int bias_sh_rd;
|
||||
if constexpr (m_block_size_8) {
|
||||
bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||
(threadIdx.x % 32) / 8;
|
||||
} else {
|
||||
bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||
(threadIdx.x % 32) % 4;
|
||||
}
|
||||
|
||||
int bias_sh_wr = threadIdx.x;
|
||||
int bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x;
|
||||
|
||||
// Zero-points have the same read layout as the scales
|
||||
// (without column-wise case)
|
||||
constexpr int num_col_threads = 8;
|
||||
@@ -670,7 +697,19 @@ __global__ void Marlin(
|
||||
constexpr int sh_b_size = stages * b_sh_stage;
|
||||
int4* sh_b = sh;
|
||||
int4* sh_red = sh;
|
||||
int4* sh_g_idx = sh_b + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
|
||||
|
||||
constexpr int sh_size_b_red_min =
|
||||
(sh_red_size < sh_b_size ? sh_red_size : sh_b_size);
|
||||
constexpr int sh_size_b_red_max =
|
||||
(sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
|
||||
constexpr int sh_bias_size = (thread_n_blocks * 16 / 8);
|
||||
constexpr int sh_b_red_bias_size =
|
||||
sh_size_b_red_max > (sh_size_b_red_min + sh_bias_size)
|
||||
? sh_size_b_red_max
|
||||
: (sh_size_b_red_min + sh_bias_size);
|
||||
|
||||
int4* sh_bias = sh + sh_size_b_red_min;
|
||||
int4* sh_g_idx = sh + sh_b_red_bias_size;
|
||||
int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
|
||||
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
|
||||
: (stages * s_sh_stage);
|
||||
@@ -680,15 +719,13 @@ __global__ void Marlin(
|
||||
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
|
||||
stages * b_sh_stage);
|
||||
int4* sh_a = sh_s + sh_s_size;
|
||||
// constexpr int shm_size_used =
|
||||
// stages * (g_idx_stage + zp_sh_stage) + sh_s_size +
|
||||
// (sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
|
||||
|
||||
// Register storage for double buffer of shared memory reads.
|
||||
FragA frag_a[2][thread_m_blocks];
|
||||
I4 frag_b_quant[2][b_thread_vecs];
|
||||
FragC frag_c[thread_m_blocks][4][2];
|
||||
FragS frag_s[2][4]; // No act-order
|
||||
FragS frag_s[2][4]; // No act-order
|
||||
FragS frag_bias[2][4];
|
||||
FragS act_frag_s[2][4][4]; // For act-order
|
||||
int frag_qzp[2][num_ints_per_thread]; // Zero-points
|
||||
FragZP frag_zp; // Zero-points in fp16
|
||||
@@ -923,10 +960,15 @@ __global__ void Marlin(
|
||||
if constexpr (w_type_id != vllm::kFE2M1f.id()) {
|
||||
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
|
||||
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
|
||||
} else {
|
||||
} else if constexpr (group_blocks == 1 || thread_k_blocks > 4) {
|
||||
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
|
||||
reinterpret_cast<int2*>(
|
||||
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
|
||||
} else {
|
||||
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
|
||||
reinterpret_cast<int2*>(
|
||||
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride) +
|
||||
k % 2];
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1139,9 +1181,9 @@ __global__ void Marlin(
|
||||
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
|
||||
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];
|
||||
|
||||
dequant_fp8_scales<scalar_t2>(s_quant_0,
|
||||
reinterpret_cast<scalar_t2*>(&frag_s[k2]));
|
||||
dequant_fp8_scales<scalar_t2>(
|
||||
dequant_fp8_scales<scalar_t2, s_type_id>(
|
||||
s_quant_0, reinterpret_cast<scalar_t2*>(&frag_s[k2]));
|
||||
dequant_fp8_scales<scalar_t2, s_type_id>(
|
||||
s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2);
|
||||
}
|
||||
|
||||
@@ -1411,7 +1453,7 @@ __global__ void Marlin(
|
||||
// Write out the reduce final result in the correct layout. We only actually
|
||||
// reshuffle matrix fragments in this step, the reduction above is performed
|
||||
// in fragment layout.
|
||||
auto write_result = [&]() {
|
||||
auto write_result = [&](bool last) {
|
||||
int c_gl_stride = prob_n / 8;
|
||||
constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
|
||||
int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
|
||||
@@ -1438,7 +1480,7 @@ __global__ void Marlin(
|
||||
int c_gl_wr_end = c_gl_stride * prob_m;
|
||||
// We first reorder in shared memory to guarantee the most efficient final
|
||||
// global write patterns
|
||||
auto write = [&](int idx, float c0, float c1, FragS& s) {
|
||||
auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) {
|
||||
scalar_t2 res =
|
||||
Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));
|
||||
|
||||
@@ -1447,12 +1489,25 @@ __global__ void Marlin(
|
||||
if constexpr (!has_act_order && group_blocks == -1 &&
|
||||
w_type.size_bits() == 4 &&
|
||||
(has_zp && dequant_skip_flop || !has_zp)) {
|
||||
res = __hmul2(res, s[0]);
|
||||
scalar_t2 tmp_scale = s[0];
|
||||
if constexpr (m_block_size_8) {
|
||||
tmp_scale = Dtype::num2num2(
|
||||
reinterpret_cast<scalar_t*>(&s[0])[(threadIdx.x % 8) / 4]);
|
||||
}
|
||||
res = __hmul2(res, tmp_scale);
|
||||
}
|
||||
|
||||
if constexpr (w_type == vllm::kFE2M1f) {
|
||||
if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
|
||||
res = __hmul2(res, global_scale);
|
||||
}
|
||||
if (has_bias && last) {
|
||||
scalar_t2 tmp_bias = b_bias[0];
|
||||
if constexpr (m_block_size_8) {
|
||||
tmp_bias = Dtype::num2num2(
|
||||
reinterpret_cast<scalar_t*>(&b_bias[0])[(threadIdx.x % 8) / 4]);
|
||||
}
|
||||
res = __hadd2(res, tmp_bias);
|
||||
}
|
||||
|
||||
if constexpr (m_block_size_8) {
|
||||
((scalar_t*)sh_red)[idx] = res.x;
|
||||
@@ -1470,19 +1525,25 @@ __global__ void Marlin(
|
||||
if constexpr (m_block_size_8) {
|
||||
int wr = c_sh_wr + 16 * j;
|
||||
write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1],
|
||||
frag_s[j / 2][2 * (j % 2) + 0]);
|
||||
frag_s[j / 2][2 * (j % 2) + 0],
|
||||
frag_bias[j / 2][2 * (j % 2) + 0]);
|
||||
write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3],
|
||||
frag_s[j / 2][2 * (j % 2) + 1]);
|
||||
frag_s[j / 2][2 * (j % 2) + 1],
|
||||
frag_bias[j / 2][2 * (j % 2) + 1]);
|
||||
} else {
|
||||
int wr = c_sh_wr + 8 * j;
|
||||
write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
|
||||
frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
|
||||
frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0],
|
||||
frag_bias[j / 2][2 * (j % 2) + 0]);
|
||||
write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
|
||||
frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
|
||||
frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0],
|
||||
frag_bias[j / 2][2 * (j % 2) + 0]);
|
||||
write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
|
||||
frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
|
||||
frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1],
|
||||
frag_bias[j / 2][2 * (j % 2) + 1]);
|
||||
write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
|
||||
frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
|
||||
frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1],
|
||||
frag_bias[j / 2][2 * (j % 2) + 1]);
|
||||
}
|
||||
}
|
||||
c_sh_wr += 16 * (4 * c_sh_stride);
|
||||
@@ -1622,6 +1683,14 @@ __global__ void Marlin(
|
||||
}
|
||||
|
||||
thread_block_reduce();
|
||||
|
||||
if (has_bias && last) {
|
||||
__syncthreads();
|
||||
cp_async4_pred(&sh_bias[bias_sh_wr], &b_bias_ptr[bias_gl_rd],
|
||||
threadIdx.x < 16 * thread_n_blocks / 8);
|
||||
cp_async_fence();
|
||||
}
|
||||
|
||||
if constexpr (!has_act_order && group_blocks == -1 &&
|
||||
(has_zp && dequant_skip_flop || !has_zp)) {
|
||||
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
|
||||
@@ -1684,11 +1753,20 @@ __global__ void Marlin(
|
||||
}
|
||||
barrier_release(&locks[locks_off], last);
|
||||
}
|
||||
|
||||
if (has_bias && last) {
|
||||
cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
reinterpret_cast<int4*>(&frag_bias)[0] = sh_bias[bias_sh_rd];
|
||||
reinterpret_cast<int4*>(&frag_bias)[1] = sh_bias[bias_sh_rd + 4];
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (use_atomic_add && slice_count > 1 && slice_idx != 0)
|
||||
wait_negative_and_add(&locks[locks_off]);
|
||||
if (last || use_atomic_add)
|
||||
// only the last block in a slice actually writes the result
|
||||
write_result();
|
||||
write_result(last);
|
||||
slice_row = 0;
|
||||
slice_col_par++;
|
||||
slice_col++;
|
||||
@@ -1706,6 +1784,7 @@ __global__ void Marlin(
|
||||
for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
|
||||
}
|
||||
|
||||
bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x;
|
||||
// Update slice k/n for scales loading
|
||||
if constexpr (has_act_order) {
|
||||
slice_k_start = tb_k * slice_row;
|
||||
|
||||
Reference in New Issue
Block a user