[Kernel] Add MXFP8 to Marlin GEMM/MoE and refactor Mxfp8LinearOp (#34664)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin
2026-04-01 18:41:42 +02:00
committed by GitHub
parent dc0428ebb8
commit db5d0719e1
15 changed files with 481 additions and 129 deletions

View File

@@ -108,6 +108,15 @@ QUANT_CONFIGS = [
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [2],
},
# MXFP8
{
"a_type": ["kBFloat16"],
"b_type": "kFE4M3fn",
"s_type": "kFE8M0fnu",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [2],
},
# AWQ-INT4 with INT8 activation
{
"a_type": ["kS8"],

View File

@@ -343,6 +343,8 @@ __global__ void Marlin(
if constexpr (b_type == vllm::kFE2M1f) {
static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 ||
s_type == vllm::kFE8M0fnu && group_blocks == 2);
} else if constexpr (b_type == vllm::kFE4M3fn && s_type == vllm::kFE8M0fnu) {
static_assert(group_blocks == 2);
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
static_assert(s_type == vllm::kBFloat16);
} else if constexpr (std::is_same<scalar_t, half>::value) {
@@ -357,9 +359,10 @@ __global__ void Marlin(
constexpr bool is_int_type = b_type == vllm::kU4 || b_type == vllm::kU8 ||
b_type == vllm::kS4 || b_type == vllm::kS8 ||
b_type == vllm::kU4B8 || b_type == vllm::kU8B128;
constexpr bool is_8bit_scale = s_type.size_bits() == 8;
// see comments of dequant.h for more details
constexpr bool dequant_skip_flop =
is_a_8bit || b_type == vllm::kFE4M3fn ||
is_a_8bit || (b_type == vllm::kFE4M3fn && !(s_type == vllm::kFE8M0fnu)) ||
b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn ||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
has_zp && !is_zp_float && !(b_type == vllm::kU8);
@@ -373,7 +376,7 @@ __global__ void Marlin(
const int group_size =
(!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups;
const int scales_expert_stride =
prob_n * prob_k / group_size / (b_type == vllm::kFE2M1f ? 16 : 8);
prob_n * prob_k / group_size / (is_8bit_scale ? 16 : 8);
const int zp_expert_stride =
is_zp_float ? prob_n * prob_k / group_size / 8
: prob_n * prob_k / group_size / (pack_factor * 4);
@@ -692,9 +695,8 @@ __global__ void Marlin(
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
// Scale sizes/strides without act_order
int s_gl_stride = prob_n / (b_type == vllm::kFE2M1f ? 16 : 8);
constexpr int s_sh_stride =
16 * thread_n_blocks / (b_type == vllm::kFE2M1f ? 16 : 8);
int s_gl_stride = prob_n / (is_8bit_scale ? 16 : 8);
constexpr int s_sh_stride = 16 * thread_n_blocks / (is_8bit_scale ? 16 : 8);
constexpr int s_tb_groups =
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
? thread_k_blocks / group_blocks
@@ -1131,7 +1133,7 @@ __global__ void Marlin(
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
if constexpr (b_type_id != vllm::kFE2M1f.id()) {
if constexpr (!is_8bit_scale) {
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
} else {
@@ -1140,7 +1142,7 @@ __global__ void Marlin(
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
}
} else if (group_blocks >= b_sh_wr_iters) {
if constexpr (b_type_id != vllm::kFE2M1f.id()) {
if constexpr (!is_8bit_scale) {
reinterpret_cast<int4*>(&frag_s[1])[0] =
reinterpret_cast<int4*>(&frag_s[0])[0];
} else {
@@ -1341,7 +1343,7 @@ __global__ void Marlin(
}
}
if constexpr (b_type == vllm::kFE2M1f) {
if constexpr (s_type == vllm::kFE4M3fn || s_type == vllm::kFE8M0fnu) {
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];

View File

@@ -599,6 +599,9 @@ torch::Tensor moe_wna16_marlin_gemm(
"When b_type = float4_e2m1f, b_scale scalar type must be",
"float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4).");
}
} else if (b_type_id == vllm::kFE4M3fn.id() &&
b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) {
s_type_id = vllm::kFE8M0fnu.id();
}
vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id);

View File

@@ -108,6 +108,15 @@ QUANT_CONFIGS = [
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [2],
},
# MXFP8
{
"a_type": ["kBFloat16"],
"b_type": "kFE4M3fn",
"s_type": "kFE8M0fnu",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [2],
},
# AWQ-INT4 with INT8 activation
{
"a_type": ["kS8"],

View File

@@ -591,6 +591,9 @@ torch::Tensor marlin_gemm(
"When b_type = float4_e2m1f, b_scale scalar type must be",
"float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4).");
}
} else if (b_type_id == vllm::kFE4M3fn.id() &&
b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) {
s_type_id = vllm::kFE8M0fnu.id();
}
vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id);

View File

@@ -327,6 +327,9 @@ __global__ void Marlin(
if constexpr (b_type == vllm::kFE2M1f) {
static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 ||
s_type == vllm::kFE8M0fnu && group_blocks == 2);
} else if constexpr (s_type == vllm::kFE8M0fnu) {
// MXFP8: FP8 weights with e8m0 microscaling block scales
static_assert(b_type == vllm::kFE4M3fn && group_blocks == 2);
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
static_assert(s_type == vllm::kBFloat16);
} else if constexpr (std::is_same<scalar_t, half>::value) {
@@ -334,6 +337,7 @@ __global__ void Marlin(
}
constexpr bool is_a_8bit = a_type.size_bits() == 8;
constexpr bool is_8bit_scale = s_type.size_bits() == 8;
if constexpr (!is_a_8bit) {
static_assert(std::is_same<scalar_t, c_scalar_t>::value);
}
@@ -343,7 +347,7 @@ __global__ void Marlin(
b_type == vllm::kU4B8 || b_type == vllm::kU8B128;
// see comments of dequant.h for more details
constexpr bool dequant_skip_flop =
is_a_8bit || b_type == vllm::kFE4M3fn ||
is_a_8bit || (b_type == vllm::kFE4M3fn && !(s_type == vllm::kFE8M0fnu)) ||
b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn ||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
has_zp && !is_zp_float && !(b_type == vllm::kU8);
@@ -555,9 +559,8 @@ __global__ void Marlin(
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
// Scale sizes/strides without act_order
int s_gl_stride = prob_n / (b_type == vllm::kFE2M1f ? 16 : 8);
constexpr int s_sh_stride =
16 * thread_n_blocks / (b_type == vllm::kFE2M1f ? 16 : 8);
int s_gl_stride = prob_n / (is_8bit_scale ? 16 : 8);
constexpr int s_sh_stride = 16 * thread_n_blocks / (is_8bit_scale ? 16 : 8);
constexpr int s_tb_groups =
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
? thread_k_blocks / group_blocks
@@ -997,7 +1000,7 @@ __global__ void Marlin(
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
if constexpr (b_type_id != vllm::kFE2M1f.id()) {
if constexpr (!is_8bit_scale) {
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
} else {
@@ -1006,7 +1009,7 @@ __global__ void Marlin(
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
}
} else if (group_blocks >= b_sh_wr_iters) {
if constexpr (b_type_id != vllm::kFE2M1f.id()) {
if constexpr (!is_8bit_scale) {
reinterpret_cast<int4*>(&frag_s[1])[0] =
reinterpret_cast<int4*>(&frag_s[0])[0];
} else {
@@ -1207,7 +1210,7 @@ __global__ void Marlin(
}
}
if constexpr (b_type == vllm::kFE2M1f) {
if constexpr (s_type == vllm::kFE4M3fn || s_type == vllm::kFE8M0fnu) {
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];