[Kernel] fp4 marlin kernel (#17687)
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
This commit is contained in:
@@ -292,9 +292,11 @@ __global__ void Marlin(
|
||||
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
|
||||
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
|
||||
// (k/groupsize)xn
|
||||
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
|
||||
// (k/groupsize)x(n/pack_factor)
|
||||
const int* __restrict__ g_idx, // int32 group indices of shape k
|
||||
const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4
|
||||
// only)
|
||||
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
|
||||
// (k/groupsize)x(n/pack_factor)
|
||||
const int* __restrict__ g_idx, // int32 group indices of shape k
|
||||
int num_groups, // number of scale groups per output channel
|
||||
int prob_m, // batch dimension m
|
||||
int prob_n, // output dimension n
|
||||
@@ -325,6 +327,21 @@ __global__ void Marlin(
|
||||
|
||||
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
|
||||
constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8;
|
||||
constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 ||
|
||||
w_type == vllm::kU4B8 || w_type == vllm::kU8B128;
|
||||
// see comments of dequant.h for more details
|
||||
constexpr bool dequant_skip_flop =
|
||||
!is_int_type ||
|
||||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
|
||||
has_zp && !is_zp_float && !(w_type == vllm::kU8);
|
||||
|
||||
scalar_t2 global_scale;
|
||||
|
||||
if constexpr (w_type == vllm::kFE2M1f) {
|
||||
uint16_t val = scale2_ptr[0];
|
||||
global_scale = Dtype::num2num2(*reinterpret_cast<scalar_t*>(&val));
|
||||
}
|
||||
|
||||
constexpr bool has_act_order = group_blocks == 0;
|
||||
constexpr int m_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
|
||||
|
||||
@@ -481,7 +498,7 @@ __global__ void Marlin(
|
||||
constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
|
||||
constexpr int s_tb_groups =
|
||||
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
|
||||
? thread_k_blocks / group_blocks
|
||||
? thread_k_blocks / group_blocks / (w_type == vllm::kFE2M1f ? 2 : 1)
|
||||
: 1;
|
||||
constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
|
||||
int s_gl_rd_delta = s_gl_stride;
|
||||
@@ -540,7 +557,8 @@ __global__ void Marlin(
|
||||
if constexpr (group_blocks == -1) {
|
||||
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
|
||||
} else {
|
||||
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
|
||||
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) /
|
||||
(w_type == vllm::kFE2M1f ? 2 : 1) +
|
||||
s_sh_stride * slice_col + threadIdx.x;
|
||||
}
|
||||
}
|
||||
@@ -564,10 +582,20 @@ __global__ void Marlin(
|
||||
// we scale a `half2` tile in column-major layout in the former and in
|
||||
// row-major in the latter case.
|
||||
int s_sh_rd;
|
||||
if constexpr (group_blocks != -1)
|
||||
if constexpr (group_blocks != -1 && w_type == vllm::kFE2M1f) {
|
||||
auto warp_id = threadIdx.x / 32;
|
||||
int n_warps = thread_n_blocks / 4;
|
||||
int warp_row = warp_id / n_warps;
|
||||
|
||||
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||
(threadIdx.x % 32) / 4;
|
||||
else if constexpr (group_blocks == -1 && (m_block_size_8 || has_zp))
|
||||
s_sh_rd = s_sh_rd * 2 + warp_row % 2;
|
||||
|
||||
} else if constexpr (group_blocks != -1)
|
||||
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||
(threadIdx.x % 32) / 4;
|
||||
else if constexpr (group_blocks == -1 &&
|
||||
(m_block_size_8 || (has_zp && !dequant_skip_flop)))
|
||||
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||
(threadIdx.x % 32) / 8;
|
||||
else
|
||||
@@ -681,7 +709,7 @@ __global__ void Marlin(
|
||||
sh_first_group_id = first_group_id;
|
||||
sh_num_groups = last_group_id - first_group_id + 1;
|
||||
|
||||
if (sh_num_groups < act_s_max_num_groups) {
|
||||
if (sh_num_groups > act_s_max_num_groups) {
|
||||
sh_num_groups = act_s_max_num_groups;
|
||||
}
|
||||
|
||||
@@ -887,12 +915,19 @@ __global__ void Marlin(
|
||||
cur_k += k_iter_size * (k % b_sh_wr_iters);
|
||||
|
||||
int k_blocks = cur_k / 16;
|
||||
int cur_group_id = k_blocks / group_blocks;
|
||||
int cur_group_id =
|
||||
k_blocks / (group_blocks * (w_type == vllm::kFE2M1f ? 2 : 1));
|
||||
|
||||
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
|
||||
|
||||
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
|
||||
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
|
||||
if constexpr (w_type_id != vllm::kFE2M1f.id()) {
|
||||
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
|
||||
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
|
||||
} else {
|
||||
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
|
||||
reinterpret_cast<int2*>(
|
||||
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1065,22 +1100,7 @@ __global__ void Marlin(
|
||||
};
|
||||
|
||||
auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) {
|
||||
if constexpr (has_zp && is_zp_float || !has_zp) {
|
||||
dequant<scalar_t2, w_type_id>(q, frag_b_ptr);
|
||||
} else {
|
||||
static_assert(has_zp && !is_zp_float);
|
||||
static_assert(w_type_id == vllm::kU4.id() || w_type_id == vllm::kU8.id());
|
||||
// If (has_zp && !is_zp_float),
|
||||
// we use not-zp version `dequant` function
|
||||
// to improve numerical accuracy.
|
||||
// Since both weight and zero point are dequanted using this logic,
|
||||
// the final dequanted weight would be correct.
|
||||
if constexpr (w_type_id == vllm::kU4.id()) {
|
||||
dequant<scalar_t2, vllm::kU4B8.id()>(q, frag_b_ptr);
|
||||
} else if constexpr (w_type_id == vllm::kU8.id()) {
|
||||
dequant<scalar_t2, vllm::kU8B128.id()>(q, frag_b_ptr);
|
||||
}
|
||||
}
|
||||
dequant<scalar_t2, w_type_id, dequant_skip_flop>(q, frag_b_ptr);
|
||||
};
|
||||
|
||||
// Execute the actual tensor core matmul of a sub-tile.
|
||||
@@ -1110,13 +1130,23 @@ __global__ void Marlin(
|
||||
dequant_data(zp_quant_1, reinterpret_cast<scalar_t2*>(&frag_zp) + 2);
|
||||
}
|
||||
}
|
||||
if constexpr (has_zp && is_zp_float) {
|
||||
if constexpr (!dequant_skip_flop && has_zp && is_zp_float) {
|
||||
if (is_new_zp) {
|
||||
reinterpret_cast<int4*>(&frag_zp)[0] =
|
||||
reinterpret_cast<int4*>(&frag_zpf[k2])[0];
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (w_type == vllm::kFE2M1f) {
|
||||
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
|
||||
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];
|
||||
|
||||
dequant_fp8_scales<scalar_t2>(s_quant_0,
|
||||
reinterpret_cast<scalar_t2*>(&frag_s[k2]));
|
||||
dequant_fp8_scales<scalar_t2>(
|
||||
s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2);
|
||||
}
|
||||
|
||||
// We have the m dimension as the inner loop in order to encourage overlapping
|
||||
// dequantization and matmul operations.
|
||||
#pragma unroll
|
||||
@@ -1125,7 +1155,10 @@ __global__ void Marlin(
|
||||
FragB frag_b1;
|
||||
int b_quant_0, b_quant_1;
|
||||
|
||||
if constexpr (w_type.size_bits() == 4) {
|
||||
if constexpr (w_type_id == vllm::kFE2M1f.id()) {
|
||||
b_quant_1 = frag_b_quant[k2][0][j];
|
||||
b_quant_0 = b_quant_1 << 8;
|
||||
} else if constexpr (w_type.size_bits() == 4) {
|
||||
b_quant_0 = frag_b_quant[k2][0][j];
|
||||
b_quant_1 = b_quant_0 >> 8;
|
||||
} else {
|
||||
@@ -1138,6 +1171,11 @@ __global__ void Marlin(
|
||||
dequant_data(b_quant_0, reinterpret_cast<scalar_t2*>(&frag_b0));
|
||||
dequant_data(b_quant_1, reinterpret_cast<scalar_t2*>(&frag_b1));
|
||||
|
||||
if constexpr (dequant_skip_flop && has_zp && !is_zp_float) {
|
||||
sub_zp<scalar_t>(frag_b0, frag_zp[j], 0);
|
||||
sub_zp<scalar_t>(frag_b1, frag_zp[j], 1);
|
||||
}
|
||||
|
||||
// Apply scale to frag_b0
|
||||
if constexpr (has_act_order) {
|
||||
static_assert(group_blocks != -1);
|
||||
@@ -1145,7 +1183,8 @@ __global__ void Marlin(
|
||||
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0);
|
||||
scale4<scalar_t>(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
|
||||
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1);
|
||||
} else if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
|
||||
} else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float &&
|
||||
group_blocks == -1) {
|
||||
int idx = (threadIdx.x / 4) % 2;
|
||||
scalar_t2 s2 = Dtype::nums2num2(
|
||||
reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 0])[idx],
|
||||
@@ -1153,7 +1192,7 @@ __global__ void Marlin(
|
||||
if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2);
|
||||
scale_and_sub<scalar_t>(frag_b0, s2.x, frag_zp[j].x);
|
||||
scale_and_sub<scalar_t>(frag_b1, s2.y, frag_zp[j].y);
|
||||
} else if constexpr (has_zp && group_blocks != -1) {
|
||||
} else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) {
|
||||
if (is_new_zp)
|
||||
frag_zp[j] = __hmul2(frag_zp[j],
|
||||
*reinterpret_cast<scalar_t2*>(&frag_s[k2][j]));
|
||||
@@ -1408,10 +1447,15 @@ __global__ void Marlin(
|
||||
// For per-column quantization we finally apply the scale here (only for
|
||||
// 4-bit)
|
||||
if constexpr (!has_act_order && group_blocks == -1 &&
|
||||
w_type.size_bits() == 4 && !has_zp) {
|
||||
w_type.size_bits() == 4 &&
|
||||
(has_zp && dequant_skip_flop || !has_zp)) {
|
||||
res = __hmul2(res, s[0]);
|
||||
}
|
||||
|
||||
if constexpr (w_type == vllm::kFE2M1f) {
|
||||
res = __hmul2(res, global_scale);
|
||||
}
|
||||
|
||||
if constexpr (m_block_size_8) {
|
||||
((scalar_t*)sh_red)[idx] = res.x;
|
||||
((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y;
|
||||
@@ -1488,7 +1532,9 @@ __global__ void Marlin(
|
||||
if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
|
||||
if (i == 0) {
|
||||
fetch_col_zp_to_shared();
|
||||
fetch_col_scale_to_shared();
|
||||
if constexpr (!dequant_skip_flop) {
|
||||
fetch_col_scale_to_shared();
|
||||
}
|
||||
}
|
||||
}
|
||||
fetch_to_shared(i, i, i < slice_iters);
|
||||
@@ -1563,7 +1609,8 @@ __global__ void Marlin(
|
||||
bool last = slice_idx == slice_count - 1;
|
||||
// For per-column scales, we only fetch them here in the final step before
|
||||
// write-out
|
||||
if constexpr (!has_act_order && group_blocks == -1 && !has_zp) {
|
||||
if constexpr (!has_act_order && group_blocks == -1 &&
|
||||
(has_zp && dequant_skip_flop || !has_zp)) {
|
||||
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
|
||||
if (s_sh_wr_pred) {
|
||||
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
|
||||
@@ -1573,7 +1620,8 @@ __global__ void Marlin(
|
||||
}
|
||||
|
||||
thread_block_reduce();
|
||||
if constexpr (!has_act_order && group_blocks == -1 && !has_zp) {
|
||||
if constexpr (!has_act_order && group_blocks == -1 &&
|
||||
(has_zp && dequant_skip_flop || !has_zp)) {
|
||||
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
|
||||
cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
@@ -1597,7 +1645,8 @@ __global__ void Marlin(
|
||||
// that converts the fp32 results to fp16 (so that we avoid possible
|
||||
// overflow in fp16)
|
||||
if constexpr (!has_act_order && group_blocks == -1 &&
|
||||
w_type.size_bits() == 8 && !has_zp) {
|
||||
w_type.size_bits() == 8 &&
|
||||
(has_zp && dequant_skip_flop || !has_zp)) {
|
||||
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks; i++) {
|
||||
|
||||
Reference in New Issue
Block a user