[Kernel] Add MXFP8 to Marlin GEMM/MoE and refactor Mxfp8LinearOp (#34664)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -108,6 +108,15 @@ QUANT_CONFIGS = [
|
||||
"thread_m_blocks": THREAD_M_BLOCKS,
|
||||
"group_blocks": [2],
|
||||
},
|
||||
# MXFP8
|
||||
{
|
||||
"a_type": ["kBFloat16"],
|
||||
"b_type": "kFE4M3fn",
|
||||
"s_type": "kFE8M0fnu",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": THREAD_M_BLOCKS,
|
||||
"group_blocks": [2],
|
||||
},
|
||||
# AWQ-INT4 with INT8 activation
|
||||
{
|
||||
"a_type": ["kS8"],
|
||||
|
||||
@@ -343,6 +343,8 @@ __global__ void Marlin(
|
||||
if constexpr (b_type == vllm::kFE2M1f) {
|
||||
static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 ||
|
||||
s_type == vllm::kFE8M0fnu && group_blocks == 2);
|
||||
} else if constexpr (b_type == vllm::kFE4M3fn && s_type == vllm::kFE8M0fnu) {
|
||||
static_assert(group_blocks == 2);
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
static_assert(s_type == vllm::kBFloat16);
|
||||
} else if constexpr (std::is_same<scalar_t, half>::value) {
|
||||
@@ -357,9 +359,10 @@ __global__ void Marlin(
|
||||
constexpr bool is_int_type = b_type == vllm::kU4 || b_type == vllm::kU8 ||
|
||||
b_type == vllm::kS4 || b_type == vllm::kS8 ||
|
||||
b_type == vllm::kU4B8 || b_type == vllm::kU8B128;
|
||||
constexpr bool is_8bit_scale = s_type.size_bits() == 8;
|
||||
// see comments of dequant.h for more details
|
||||
constexpr bool dequant_skip_flop =
|
||||
is_a_8bit || b_type == vllm::kFE4M3fn ||
|
||||
is_a_8bit || (b_type == vllm::kFE4M3fn && !(s_type == vllm::kFE8M0fnu)) ||
|
||||
b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn ||
|
||||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
|
||||
has_zp && !is_zp_float && !(b_type == vllm::kU8);
|
||||
@@ -373,7 +376,7 @@ __global__ void Marlin(
|
||||
const int group_size =
|
||||
(!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups;
|
||||
const int scales_expert_stride =
|
||||
prob_n * prob_k / group_size / (b_type == vllm::kFE2M1f ? 16 : 8);
|
||||
prob_n * prob_k / group_size / (is_8bit_scale ? 16 : 8);
|
||||
const int zp_expert_stride =
|
||||
is_zp_float ? prob_n * prob_k / group_size / 8
|
||||
: prob_n * prob_k / group_size / (pack_factor * 4);
|
||||
@@ -692,9 +695,8 @@ __global__ void Marlin(
|
||||
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
|
||||
|
||||
// Scale sizes/strides without act_order
|
||||
int s_gl_stride = prob_n / (b_type == vllm::kFE2M1f ? 16 : 8);
|
||||
constexpr int s_sh_stride =
|
||||
16 * thread_n_blocks / (b_type == vllm::kFE2M1f ? 16 : 8);
|
||||
int s_gl_stride = prob_n / (is_8bit_scale ? 16 : 8);
|
||||
constexpr int s_sh_stride = 16 * thread_n_blocks / (is_8bit_scale ? 16 : 8);
|
||||
constexpr int s_tb_groups =
|
||||
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
|
||||
? thread_k_blocks / group_blocks
|
||||
@@ -1131,7 +1133,7 @@ __global__ void Marlin(
|
||||
|
||||
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
|
||||
|
||||
if constexpr (b_type_id != vllm::kFE2M1f.id()) {
|
||||
if constexpr (!is_8bit_scale) {
|
||||
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
|
||||
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
|
||||
} else {
|
||||
@@ -1140,7 +1142,7 @@ __global__ void Marlin(
|
||||
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
|
||||
}
|
||||
} else if (group_blocks >= b_sh_wr_iters) {
|
||||
if constexpr (b_type_id != vllm::kFE2M1f.id()) {
|
||||
if constexpr (!is_8bit_scale) {
|
||||
reinterpret_cast<int4*>(&frag_s[1])[0] =
|
||||
reinterpret_cast<int4*>(&frag_s[0])[0];
|
||||
} else {
|
||||
@@ -1341,7 +1343,7 @@ __global__ void Marlin(
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (b_type == vllm::kFE2M1f) {
|
||||
if constexpr (s_type == vllm::kFE4M3fn || s_type == vllm::kFE8M0fnu) {
|
||||
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
|
||||
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];
|
||||
|
||||
|
||||
@@ -599,6 +599,9 @@ torch::Tensor moe_wna16_marlin_gemm(
|
||||
"When b_type = float4_e2m1f, b_scale scalar type must be",
|
||||
"float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4).");
|
||||
}
|
||||
} else if (b_type_id == vllm::kFE4M3fn.id() &&
|
||||
b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) {
|
||||
s_type_id = vllm::kFE8M0fnu.id();
|
||||
}
|
||||
|
||||
vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id);
|
||||
|
||||
@@ -108,6 +108,15 @@ QUANT_CONFIGS = [
|
||||
"thread_m_blocks": THREAD_M_BLOCKS,
|
||||
"group_blocks": [2],
|
||||
},
|
||||
# MXFP8
|
||||
{
|
||||
"a_type": ["kBFloat16"],
|
||||
"b_type": "kFE4M3fn",
|
||||
"s_type": "kFE8M0fnu",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": THREAD_M_BLOCKS,
|
||||
"group_blocks": [2],
|
||||
},
|
||||
# AWQ-INT4 with INT8 activation
|
||||
{
|
||||
"a_type": ["kS8"],
|
||||
|
||||
@@ -591,6 +591,9 @@ torch::Tensor marlin_gemm(
|
||||
"When b_type = float4_e2m1f, b_scale scalar type must be",
|
||||
"float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4).");
|
||||
}
|
||||
} else if (b_type_id == vllm::kFE4M3fn.id() &&
|
||||
b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) {
|
||||
s_type_id = vllm::kFE8M0fnu.id();
|
||||
}
|
||||
|
||||
vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id);
|
||||
|
||||
@@ -327,6 +327,9 @@ __global__ void Marlin(
|
||||
if constexpr (b_type == vllm::kFE2M1f) {
|
||||
static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 ||
|
||||
s_type == vllm::kFE8M0fnu && group_blocks == 2);
|
||||
} else if constexpr (s_type == vllm::kFE8M0fnu) {
|
||||
// MXFP8: FP8 weights with e8m0 microscaling block scales
|
||||
static_assert(b_type == vllm::kFE4M3fn && group_blocks == 2);
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
static_assert(s_type == vllm::kBFloat16);
|
||||
} else if constexpr (std::is_same<scalar_t, half>::value) {
|
||||
@@ -334,6 +337,7 @@ __global__ void Marlin(
|
||||
}
|
||||
|
||||
constexpr bool is_a_8bit = a_type.size_bits() == 8;
|
||||
constexpr bool is_8bit_scale = s_type.size_bits() == 8;
|
||||
if constexpr (!is_a_8bit) {
|
||||
static_assert(std::is_same<scalar_t, c_scalar_t>::value);
|
||||
}
|
||||
@@ -343,7 +347,7 @@ __global__ void Marlin(
|
||||
b_type == vllm::kU4B8 || b_type == vllm::kU8B128;
|
||||
// see comments of dequant.h for more details
|
||||
constexpr bool dequant_skip_flop =
|
||||
is_a_8bit || b_type == vllm::kFE4M3fn ||
|
||||
is_a_8bit || (b_type == vllm::kFE4M3fn && !(s_type == vllm::kFE8M0fnu)) ||
|
||||
b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn ||
|
||||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
|
||||
has_zp && !is_zp_float && !(b_type == vllm::kU8);
|
||||
@@ -555,9 +559,8 @@ __global__ void Marlin(
|
||||
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
|
||||
|
||||
// Scale sizes/strides without act_order
|
||||
int s_gl_stride = prob_n / (b_type == vllm::kFE2M1f ? 16 : 8);
|
||||
constexpr int s_sh_stride =
|
||||
16 * thread_n_blocks / (b_type == vllm::kFE2M1f ? 16 : 8);
|
||||
int s_gl_stride = prob_n / (is_8bit_scale ? 16 : 8);
|
||||
constexpr int s_sh_stride = 16 * thread_n_blocks / (is_8bit_scale ? 16 : 8);
|
||||
constexpr int s_tb_groups =
|
||||
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
|
||||
? thread_k_blocks / group_blocks
|
||||
@@ -997,7 +1000,7 @@ __global__ void Marlin(
|
||||
|
||||
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
|
||||
|
||||
if constexpr (b_type_id != vllm::kFE2M1f.id()) {
|
||||
if constexpr (!is_8bit_scale) {
|
||||
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
|
||||
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
|
||||
} else {
|
||||
@@ -1006,7 +1009,7 @@ __global__ void Marlin(
|
||||
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
|
||||
}
|
||||
} else if (group_blocks >= b_sh_wr_iters) {
|
||||
if constexpr (b_type_id != vllm::kFE2M1f.id()) {
|
||||
if constexpr (!is_8bit_scale) {
|
||||
reinterpret_cast<int4*>(&frag_s[1])[0] =
|
||||
reinterpret_cast<int4*>(&frag_s[0])[0];
|
||||
} else {
|
||||
@@ -1207,7 +1210,7 @@ __global__ void Marlin(
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (b_type == vllm::kFE2M1f) {
|
||||
if constexpr (s_type == vllm::kFE4M3fn || s_type == vllm::kFE8M0fnu) {
|
||||
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
|
||||
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];
|
||||
|
||||
|
||||
Reference in New Issue
Block a user