[Kernel] Add MXFP8 to Marlin GEMM/MoE and refactor Mxfp8LinearOp (#34664)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin
2026-04-01 18:41:42 +02:00
committed by GitHub
parent dc0428ebb8
commit db5d0719e1
15 changed files with 481 additions and 129 deletions

View File

@@ -108,6 +108,15 @@ QUANT_CONFIGS = [
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [2],
},
# MXFP8
{
"a_type": ["kBFloat16"],
"b_type": "kFE4M3fn",
"s_type": "kFE8M0fnu",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [2],
},
# AWQ-INT4 with INT8 activation
{
"a_type": ["kS8"],

View File

@@ -343,6 +343,8 @@ __global__ void Marlin(
if constexpr (b_type == vllm::kFE2M1f) {
static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 ||
s_type == vllm::kFE8M0fnu && group_blocks == 2);
} else if constexpr (b_type == vllm::kFE4M3fn && s_type == vllm::kFE8M0fnu) {
static_assert(group_blocks == 2);
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
static_assert(s_type == vllm::kBFloat16);
} else if constexpr (std::is_same<scalar_t, half>::value) {
@@ -357,9 +359,10 @@ __global__ void Marlin(
constexpr bool is_int_type = b_type == vllm::kU4 || b_type == vllm::kU8 ||
b_type == vllm::kS4 || b_type == vllm::kS8 ||
b_type == vllm::kU4B8 || b_type == vllm::kU8B128;
constexpr bool is_8bit_scale = s_type.size_bits() == 8;
// see comments of dequant.h for more details
constexpr bool dequant_skip_flop =
is_a_8bit || b_type == vllm::kFE4M3fn ||
is_a_8bit || (b_type == vllm::kFE4M3fn && !(s_type == vllm::kFE8M0fnu)) ||
b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn ||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
has_zp && !is_zp_float && !(b_type == vllm::kU8);
@@ -373,7 +376,7 @@ __global__ void Marlin(
const int group_size =
(!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups;
const int scales_expert_stride =
prob_n * prob_k / group_size / (b_type == vllm::kFE2M1f ? 16 : 8);
prob_n * prob_k / group_size / (is_8bit_scale ? 16 : 8);
const int zp_expert_stride =
is_zp_float ? prob_n * prob_k / group_size / 8
: prob_n * prob_k / group_size / (pack_factor * 4);
@@ -692,9 +695,8 @@ __global__ void Marlin(
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
// Scale sizes/strides without act_order
int s_gl_stride = prob_n / (b_type == vllm::kFE2M1f ? 16 : 8);
constexpr int s_sh_stride =
16 * thread_n_blocks / (b_type == vllm::kFE2M1f ? 16 : 8);
int s_gl_stride = prob_n / (is_8bit_scale ? 16 : 8);
constexpr int s_sh_stride = 16 * thread_n_blocks / (is_8bit_scale ? 16 : 8);
constexpr int s_tb_groups =
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
? thread_k_blocks / group_blocks
@@ -1131,7 +1133,7 @@ __global__ void Marlin(
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
if constexpr (b_type_id != vllm::kFE2M1f.id()) {
if constexpr (!is_8bit_scale) {
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
} else {
@@ -1140,7 +1142,7 @@ __global__ void Marlin(
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
}
} else if (group_blocks >= b_sh_wr_iters) {
if constexpr (b_type_id != vllm::kFE2M1f.id()) {
if constexpr (!is_8bit_scale) {
reinterpret_cast<int4*>(&frag_s[1])[0] =
reinterpret_cast<int4*>(&frag_s[0])[0];
} else {
@@ -1341,7 +1343,7 @@ __global__ void Marlin(
}
}
if constexpr (b_type == vllm::kFE2M1f) {
if constexpr (s_type == vllm::kFE4M3fn || s_type == vllm::kFE8M0fnu) {
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];

View File

@@ -599,6 +599,9 @@ torch::Tensor moe_wna16_marlin_gemm(
"When b_type = float4_e2m1f, b_scale scalar type must be",
"float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4).");
}
} else if (b_type_id == vllm::kFE4M3fn.id() &&
b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) {
s_type_id = vllm::kFE8M0fnu.id();
}
vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id);

View File

@@ -108,6 +108,15 @@ QUANT_CONFIGS = [
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [2],
},
# MXFP8
{
"a_type": ["kBFloat16"],
"b_type": "kFE4M3fn",
"s_type": "kFE8M0fnu",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [2],
},
# AWQ-INT4 with INT8 activation
{
"a_type": ["kS8"],

View File

@@ -591,6 +591,9 @@ torch::Tensor marlin_gemm(
"When b_type = float4_e2m1f, b_scale scalar type must be",
"float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4).");
}
} else if (b_type_id == vllm::kFE4M3fn.id() &&
b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) {
s_type_id = vllm::kFE8M0fnu.id();
}
vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id);

View File

@@ -327,6 +327,9 @@ __global__ void Marlin(
if constexpr (b_type == vllm::kFE2M1f) {
static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 ||
s_type == vllm::kFE8M0fnu && group_blocks == 2);
} else if constexpr (s_type == vllm::kFE8M0fnu) {
// MXFP8: FP8 weights with e8m0 microscaling block scales
static_assert(b_type == vllm::kFE4M3fn && group_blocks == 2);
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
static_assert(s_type == vllm::kBFloat16);
} else if constexpr (std::is_same<scalar_t, half>::value) {
@@ -334,6 +337,7 @@ __global__ void Marlin(
}
constexpr bool is_a_8bit = a_type.size_bits() == 8;
constexpr bool is_8bit_scale = s_type.size_bits() == 8;
if constexpr (!is_a_8bit) {
static_assert(std::is_same<scalar_t, c_scalar_t>::value);
}
@@ -343,7 +347,7 @@ __global__ void Marlin(
b_type == vllm::kU4B8 || b_type == vllm::kU8B128;
// see comments of dequant.h for more details
constexpr bool dequant_skip_flop =
is_a_8bit || b_type == vllm::kFE4M3fn ||
is_a_8bit || (b_type == vllm::kFE4M3fn && !(s_type == vllm::kFE8M0fnu)) ||
b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn ||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
has_zp && !is_zp_float && !(b_type == vllm::kU8);
@@ -555,9 +559,8 @@ __global__ void Marlin(
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
// Scale sizes/strides without act_order
int s_gl_stride = prob_n / (b_type == vllm::kFE2M1f ? 16 : 8);
constexpr int s_sh_stride =
16 * thread_n_blocks / (b_type == vllm::kFE2M1f ? 16 : 8);
int s_gl_stride = prob_n / (is_8bit_scale ? 16 : 8);
constexpr int s_sh_stride = 16 * thread_n_blocks / (is_8bit_scale ? 16 : 8);
constexpr int s_tb_groups =
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
? thread_k_blocks / group_blocks
@@ -997,7 +1000,7 @@ __global__ void Marlin(
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
if constexpr (b_type_id != vllm::kFE2M1f.id()) {
if constexpr (!is_8bit_scale) {
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
} else {
@@ -1006,7 +1009,7 @@ __global__ void Marlin(
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
}
} else if (group_blocks >= b_sh_wr_iters) {
if constexpr (b_type_id != vllm::kFE2M1f.id()) {
if constexpr (!is_8bit_scale) {
reinterpret_cast<int4*>(&frag_s[1])[0] =
reinterpret_cast<int4*>(&frag_s[0])[0];
} else {
@@ -1207,7 +1210,7 @@ __global__ void Marlin(
}
}
if constexpr (b_type == vllm::kFE2M1f) {
if constexpr (s_type == vllm::kFE4M3fn || s_type == vllm::kFE8M0fnu) {
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];

View File

@@ -151,6 +151,12 @@ MOE_MARLIN_QUANT_TEST_CONFIGS = [
"b_type": scalar_types.float4_e2m1f,
"group_blocks": [2],
},
# MXFP8
{
"a_type": [scalar_types.bfloat16],
"b_type": scalar_types.float8_e4m3fn,
"group_blocks": [2],
},
# AWQ-INT4 with INT8 activation
{
"a_type": [scalar_types.int8],

View File

@@ -23,7 +23,7 @@ from tests.quantization.utils import is_quant_method_supported
from ..utils import check_logprobs_close
# A small MoE model that fits on a single GPU and has both linear + MoE layers.
MOE_MODEL = "Qwen/Qwen3-30B-A3B"
MOE_MODEL = "allenai/OLMoE-1B-7B-0125-Instruct"
# A small dense model (no MoE) to validate the linear-only path.
DENSE_MODEL = "Qwen/Qwen3-0.6B"

View File

@@ -41,6 +41,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticChannelSym,
kFp8StaticTensorSym,
kMxfp4Static,
kMxfp8Static,
kNvfp4Static,
)
from vllm.platforms import current_platform
@@ -582,6 +583,7 @@ class MarlinExpertsBase(mk.FusedMoEExpertsModular):
kFp8StaticChannelSym,
kFp8StaticTensorSym,
kMxfp4Static,
kMxfp8Static,
kNvfp4Static,
]
return weight_key in SUPPORTED_W
@@ -609,7 +611,6 @@ class MarlinExpertsBase(mk.FusedMoEExpertsModular):
@property
def quant_type_id(self) -> int:
# uint4b8 will be set for int4 weight and float4_e2m1f will be used for mxfp4
if self.quant_config.use_int4_w4a16:
return scalar_types.uint4b8.id
elif self.quant_config.use_mxfp4_w4a16 or self.quant_config.use_nvfp4_w4a16:

View File

@@ -436,13 +436,27 @@ def convert_to_fp8_moe_kernel_format(
elif fp8_backend == Fp8MoeBackend.AITER:
w13, w2 = rocm_aiter_ops.shuffle_weights(w13, w2)
elif fp8_backend == Fp8MoeBackend.MARLIN:
w13, w2, w13_scale, w2_scale = prepare_fp8_moe_layer_for_marlin(
layer,
w13,
w2,
w13_scale,
w2_scale,
)
weight_block_size = getattr(layer, "weight_block_size", None)
if weight_block_size == [1, 32]:
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_mxfp8_moe_layer_for_marlin,
)
w13, w2, w13_scale, w2_scale = prepare_mxfp8_moe_layer_for_marlin(
layer,
w13,
w2,
w13_scale,
w2_scale,
)
else:
w13, w2, w13_scale, w2_scale = prepare_fp8_moe_layer_for_marlin(
layer,
w13,
w2,
w13_scale,
w2_scale,
)
elif fp8_backend in [
Fp8MoeBackend.FLASHINFER_CUTLASS,
Fp8MoeBackend.FLASHINFER_TRTLLM,

View File

@@ -15,14 +15,14 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
logger = init_logger(__name__)
_SUPPORTED_BACKENDS: frozenset[Fp8MoeBackend] = frozenset(
{
Fp8MoeBackend.FLASHINFER_TRTLLM,
}
_SUPPORTED_BACKENDS = (
Fp8MoeBackend.FLASHINFER_TRTLLM,
Fp8MoeBackend.MARLIN,
)
_BACKEND_NAME_MAP: dict[str, Fp8MoeBackend] = {
"flashinfer_trtllm": Fp8MoeBackend.FLASHINFER_TRTLLM,
"marlin": Fp8MoeBackend.MARLIN,
}
@@ -81,7 +81,11 @@ def select_mxfp8_moe_backend(
# Auto-select: pick the first supported backend.
for backend in _SUPPORTED_BACKENDS:
try:
experts_cls = _select_kernel_cls(backend, config)
except ValueError:
continue
logger.info_once("Using '%s' MxFp8 MoE backend.", backend.value)
return backend, _select_kernel_cls(backend, config)
return backend, experts_cls
raise ValueError("No MXFP8 MoE backends available.")

View File

@@ -67,10 +67,8 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
MXFP8_BLOCK_SIZE,
MXFP8_SCALE_DTYPE,
MXFP8_VALUE_DTYPE,
Mxfp8LinearBackend,
Mxfp8LinearOp,
mxfp8_e4m3_quantize,
swizzle_mxfp8_scale,
)
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
apply_nvfp4_linear,
@@ -1499,8 +1497,8 @@ class ModelOptMxFp8Config(ModelOptQuantConfigBase):
@classmethod
def get_min_capability(cls) -> int:
# MXFP8 hardware acceleration requires Blackwell (SM100) or newer
return 100
# Marlin kernel supports MXFP8 on SM80+
return 80
@classmethod
def override_quantization_method(
@@ -1555,9 +1553,7 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
"Dynamic quantization is not supported."
)
self.backend: Mxfp8LinearBackend = Mxfp8LinearBackend.FLASHINFER_CUTLASS
self.mxfp8_linear_op = Mxfp8LinearOp(backend=self.backend)
logger.info_once("Using %s backend for MXFP8 GEMM", self.backend.value)
self.mxfp8_linear_op = Mxfp8LinearOp()
def create_weights(
self,
@@ -1615,36 +1611,6 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
)
layer.register_parameter("weight_scale", weight_scale)
def _process_weights_after_loading_scale_2d(self, layer: torch.nn.Module) -> None:
"""Not swizzled - MXFP8 GEMM emulation"""
weight = layer.weight.data # [N, K]
N, K = weight.shape
scale_k = K // MXFP8_BLOCK_SIZE
# Slice weight_scale to match weight dimensions (handles padding)
weight_scale = layer.weight_scale.data[:N, :scale_k].contiguous()
layer.weight = Parameter(weight.contiguous(), requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
def _process_weights_after_loading_scale_1d(self, layer: torch.nn.Module) -> None:
"""Swizzled - MXFP8 GEMM Flashinfer CUTLASS"""
weight = layer.weight.data # [N, K]
N, K = weight.shape
# 2D weight scale
weight_scale = layer.weight_scale.data
# Swizzle the weight scales
scale_k = K // MXFP8_BLOCK_SIZE
weight_scale_2d = weight_scale[:N, :scale_k].contiguous()
weight_scale_swizzled = swizzle_mxfp8_scale(weight_scale_2d, M=N, K=K)
layer.weight = Parameter(weight.contiguous(), requires_grad=False)
layer.weight_scale = Parameter(
weight_scale_swizzled.contiguous(), requires_grad=False
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Validate weight tensor
if layer.weight.ndim != 2:
@@ -1669,14 +1635,7 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
f" got {layer.weight_scale.dtype}"
)
if self.backend == Mxfp8LinearBackend.EMULATION:
# Swizzled layout is not used
self._process_weights_after_loading_scale_2d(layer)
return
assert self.backend == Mxfp8LinearBackend.FLASHINFER_CUTLASS
# Swizzled layout is required for Flashinfer CUTLASS
self._process_weights_after_loading_scale_1d(layer)
self.mxfp8_linear_op.process_weights(layer)
def apply(
self,
@@ -1684,22 +1643,15 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
if layer.weight.dtype != MXFP8_VALUE_DTYPE:
raise ValueError(
f"Weight dtype {layer.weight.dtype} != expected {MXFP8_VALUE_DTYPE}"
)
if layer.weight_scale.dtype != MXFP8_SCALE_DTYPE:
raise ValueError(
f"Weight scale dtype {layer.weight_scale.dtype} != "
f"expected {MXFP8_SCALE_DTYPE}"
)
return self.mxfp8_linear_op.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=x.dtype,
bias=bias,
workspace=getattr(layer, "workspace", None),
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
)

View File

@@ -34,10 +34,8 @@ from vllm.model_executor.layers.quantization.fp8 import (
)
from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
MXFP8_BLOCK_SIZE,
Mxfp8LinearBackend,
Mxfp8LinearOp,
mxfp8_e4m3_quantize,
swizzle_mxfp8_scale,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped,
@@ -71,7 +69,8 @@ class Mxfp8Config(Fp8Config):
@classmethod
def get_min_capability(cls) -> int:
return 100
# Marlin kernel supports MXFP8 on SM80+
return 80
@classmethod
def from_config(cls, config: dict[str, Any]) -> "Mxfp8Config":
@@ -128,24 +127,7 @@ class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod):
def __init__(self, quant_config: "Mxfp8Config"):
self.quant_config = quant_config
self.out_dtype = torch.get_default_dtype()
self.mxfp8_linear = Mxfp8LinearOp(self._select_backend())
logger.info_once(
"Using %s backend for MXFP8 GEMM", self.mxfp8_linear.backend.value
)
@staticmethod
def _select_backend() -> Mxfp8LinearBackend:
try:
from vllm.utils import flashinfer as fi
_ = fi.mm_mxfp8
return Mxfp8LinearBackend.FLASHINFER_CUTLASS
except Exception:
logger.warning(
"FlashInfer mm_mxfp8 not available, "
"falling back to MXFP8 emulation backend."
)
return Mxfp8LinearBackend.EMULATION
self.mxfp8_linear = Mxfp8LinearOp()
def create_weights(
self,
@@ -180,14 +162,12 @@ class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod):
weight_fp8, weight_scale = mxfp8_e4m3_quantize(layer.weight.contiguous())
if self.mxfp8_linear.backend == Mxfp8LinearBackend.FLASHINFER_CUTLASS:
N, K = layer.weight.shape[0], layer.weight.shape[1]
weight_scale = swizzle_mxfp8_scale(weight_scale, N, K)
layer.input_scale = None
replace_parameter(layer, "weight", weight_fp8.data)
replace_parameter(layer, "weight_scale", weight_scale.data)
self.mxfp8_linear.process_weights(layer)
layer._already_called_process_weights_after_loading = True
def apply(
@@ -202,6 +182,9 @@ class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod):
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
bias=bias,
workspace=getattr(layer, "workspace", None),
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
)
@@ -255,17 +238,24 @@ class Mxfp8OnlineMoEMethod(Fp8OnlineMoEMethod):
self, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
"""Batch quantization: bf16/fp16 weights -> MXFP8 (fp8 + uint8 scales)."""
num_batches = weight.size(0)
w_quant = []
w_scales = []
for i in range(num_batches):
mx_fp8_quant, mx_fp8_scale = mxfp8_e4m3_quantize(
E = weight.size(0)
first_q, first_s = mxfp8_e4m3_quantize(weight[0], is_sf_swizzled_layout=False)
# Pre-allocate the output tensors rather than stacking.
# This is important for consistent memory layout.
w_quant = torch.empty(
(E, *first_q.shape), dtype=first_q.dtype, device=weight.device
)
w_scales = torch.empty(
(E, *first_s.shape), dtype=first_s.dtype, device=weight.device
)
w_quant[0] = first_q
w_scales[0] = first_s
for i in range(1, E):
w_quant[i], w_scales[i] = mxfp8_e4m3_quantize(
weight[i], is_sf_swizzled_layout=False
)
w_quant.append(mx_fp8_quant)
w_scales.append(mx_fp8_scale)
return torch.stack(w_quant), torch.stack(w_scales)
return w_quant, w_scales
def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False):

View File

@@ -336,6 +336,202 @@ def pack_fp8_to_int32(
return int32_tensor.T.contiguous() if size_k_first else int32_tensor
def mxfp8_marlin_process_scales(marlin_scales: torch.Tensor) -> torch.Tensor:
    """Interleave scales for the e8m0 kernel layout and cast to float8_e8m0fnu.

    Within every group of 4 consecutive scales, the middle pair is swapped
    ((0, 1, 2, 3) -> (0, 2, 1, 3)) to match the layout the fp8 dequantization
    path in the Marlin kernel expects.
    """
    num_rows = marlin_scales.size(0)
    # Swap the middle pair of each 4-element group, preserving row count.
    grouped = marlin_scales.view(-1, 4)
    reordered = grouped[:, [0, 2, 1, 3]].view(num_rows, -1)
    return reordered.to(torch.float8_e8m0fnu)
def apply_mxfp8_marlin_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    workspace: torch.Tensor,
    size_n: int,
    size_k: int,
    bias: torch.Tensor | None = None,
    use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
) -> torch.Tensor:
    """Run an MXFP8 linear layer (FP8 weights + e8m0 block scales) via Marlin.

    Args:
        input: Activation tensor; all leading dims are flattened into M.
        weight: Marlin-repacked float8_e4m3fn weight.
        weight_scale: Marlin-permuted e8m0 block scales.
        workspace: Marlin workspace buffer (see marlin_make_workspace_new).
        size_n: Output feature dimension.
        size_k: Reduction (input feature) dimension.
        bias: Optional bias; permuted to Marlin layout at prepare time.
        use_fp32_reduce: Whether the kernel reduces partial sums in fp32.

    Returns:
        Output tensor with the input's leading shape and last dim size_n.
    """
    # Collapse all batch-like leading dimensions into a single M dimension.
    reshaped_x = input.reshape(-1, input.shape[-1])
    out_shape = input.shape[:-1] + (size_n,)

    # Atomic-add reduction is only profitable for certain (m, n, k) shapes,
    # devices and dtypes; delegate the decision to the shared heuristic.
    use_atomic_add = should_use_atomic_add_reduce(
        m=reshaped_x.size(0),
        n=size_n,
        k=size_k,
        device=input.device,
        dtype=input.dtype,
    )

    # b_q_type identifies the weight format; the scale type is inferred by
    # the op from b_scales' dtype (float8_e8m0fnu for MXFP8).
    output = ops.marlin_gemm(
        a=reshaped_x,
        c=None,
        b_q_weight=weight,
        b_bias=bias,
        b_scales=weight_scale,
        a_scales=None,
        global_scale=None,
        b_zeros=None,
        g_idx=None,
        perm=None,
        workspace=workspace,
        b_q_type=scalar_types.float8_e4m3fn,
        size_m=reshaped_x.size(0),
        size_n=size_n,
        size_k=size_k,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=use_fp32_reduce,
    )
    return output.reshape(out_shape)
def prepare_mxfp8_layer_for_marlin(layer: torch.nn.Module) -> None:
    """Repack MXFP8 weights and scales into Marlin kernel format.

    Expects the layer to have:
    - weight: [N, K] float8_e4m3fn
    - weight_scale: [N, K//32] uint8 (e8m0 encoded)
    - input_size_per_partition / output_size_per_partition

    Mutates the layer in place: replaces `weight` and `weight_scale` (and
    `bias`, if present) with Marlin-formatted tensors and attaches a
    `workspace` buffer.
    """
    part_size_n = layer.output_size_per_partition
    part_size_k = layer.input_size_per_partition
    group_size = 32  # MX standard block size
    device = layer.weight.device

    # WORKSPACE
    layer.workspace = marlin_make_workspace_new(device)

    # WEIGHT - repack FP8 weights to Marlin format
    # Empty perm tensor: no act-order permutation for this path.
    perm = torch.empty(0, dtype=torch.int, device=device)
    qweight = pack_fp8_to_int32(layer.weight, size_k_first=False)
    qweight = qweight.T.contiguous()
    marlin_qweight = ops.gptq_marlin_repack(
        b_q_weight=qweight,
        perm=perm,
        size_k=part_size_k,
        size_n=part_size_n,
        num_bits=8,
    )
    replace_parameter(layer, "weight", marlin_qweight)

    # WEIGHT SCALES
    # Convert uint8 scales -> e8m0fnu -> param_dtype for permutation
    # Scales are [N, K//32], need [K//32, N] for marlin_permute_scales
    param_dtype = torch.get_default_dtype()
    # Slicing handles scale tensors padded beyond the partition size.
    scales = layer.weight_scale.data[:part_size_n, : part_size_k // group_size]
    scales = scales.contiguous()
    # Bit-reinterpret uint8 as e8m0, then widen so the permute helper can
    # operate in the default parameter dtype.
    scales = scales.view(torch.float8_e8m0fnu).to(param_dtype)
    scales = scales.T.contiguous()

    # Permute scales to Marlin layout
    marlin_scales = marlin_permute_scales(
        s=scales,
        size_k=part_size_k,
        size_n=part_size_n,
        group_size=group_size,
    )
    # Reorder for e8m0 kernel layout and convert back to e8m0fnu
    marlin_scales = mxfp8_marlin_process_scales(marlin_scales)
    replace_parameter(layer, "weight_scale", marlin_scales)

    # BIAS
    if hasattr(layer, "bias") and layer.bias is not None:
        assert layer.bias.shape == (part_size_n,)
        bias = marlin_permute_bias(layer.bias)
        replace_parameter(layer, "bias", bias)
def prepare_mxfp8_moe_layer_for_marlin(
    layer: torch.nn.Module,
    w13: torch.Tensor,
    w2: torch.Tensor,
    w13_scale: torch.Tensor,
    w2_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Repack MXFP8 MoE weights and scales into Marlin kernel format.

    Args:
        layer: MoE layer (used to read params_dtype and attach workspace).
        w13: [E, 2*N, K] float8_e4m3fn weights.
        w2: [E, K, N] float8_e4m3fn weights.
        w13_scale: [E, 2*N, K//32] uint8 e8m0 scales.
        w2_scale: [E, K, N//32] uint8 e8m0 scales.

    Returns:
        (w13, w2, w13_scale, w2_scale) in Marlin format.
    """
    group_size = 32  # MX standard block size
    e = w13.shape[0]      # number of experts
    w13_n = w13.shape[1]  # 2*N (fused gate+up projection output dim)
    k = w13.shape[2]      # hidden size
    n = w2.shape[2]       # intermediate size
    device = w13.device
    param_dtype = torch.get_default_dtype()
    # Empty perm tensor: no act-order permutation for this path.
    perm = torch.empty(0, dtype=torch.int, device=device)

    layer.workspace = marlin_make_workspace_new(device, 4)

    def repack_weight(weight: torch.Tensor, name: str) -> torch.Tensor:
        # Repack one stacked [E, size_n, size_k] FP8 tensor expert-by-expert.
        if "w13" in name:
            size_n, size_k = w13_n, k
        else:
            # w2 maps intermediate (n) back to hidden (k).
            size_n, size_k = k, n
        assert weight.shape == (e, size_n, size_k)
        tensor_list = []
        for i in range(e):
            qweight = pack_fp8_to_int32(weight[i], size_k_first=False)
            qweight = qweight.T.contiguous()
            marlin_qweight = ops.gptq_marlin_repack(
                b_q_weight=qweight,
                perm=perm,
                size_k=size_k,
                size_n=size_n,
                num_bits=8,
            )
            tensor_list.append(marlin_qweight)
        # Re-stack per-expert results along a new leading expert dim.
        return torch.cat([x.unsqueeze(0) for x in tensor_list], 0)

    w13 = repack_weight(w13, "w13")
    w2 = repack_weight(w2, "w2")

    def permute_scales(scales: torch.Tensor, name: str) -> torch.Tensor:
        # Permute one stacked e8m0 scale tensor expert-by-expert.
        if "w13" in name:
            size_n, size_k = w13_n, k
        else:
            size_n, size_k = k, n
        tensor_list = []
        for i in range(e):
            # Slicing handles scale tensors padded beyond the partition size.
            s = scales[i][:size_n, : size_k // group_size].contiguous()
            # uint8 bits -> e8m0 -> param_dtype for the permute helper.
            s = s.view(torch.float8_e8m0fnu).to(param_dtype)
            s = s.T.contiguous()
            marlin_s = marlin_permute_scales(
                s=s,
                size_k=size_k,
                size_n=size_n,
                group_size=group_size,
            )
            # Reorder for the e8m0 kernel layout and cast back to e8m0fnu.
            marlin_s = mxfp8_marlin_process_scales(marlin_s)
            tensor_list.append(marlin_s)
        return torch.cat([x.unsqueeze(0) for x in tensor_list], 0)

    w13_scale = permute_scales(w13_scale, "w13")
    w2_scale = permute_scales(w2_scale, "w2")
    return w13, w2, w13_scale, w2_scale
def marlin_quant_fp8_torch(weight, group_size, input_dtype=None):
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
if is_a_8bit:

View File

@@ -4,6 +4,7 @@
from enum import Enum
import torch
from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.utils import flashinfer as vllm_flashinfer
@@ -15,6 +16,7 @@ logger = init_logger(__name__)
class Mxfp8LinearBackend(Enum):
EMULATION = "emulation"
FLASHINFER_CUTLASS = "flashinfer-cutlass"
MARLIN = "marlin"
# MXFP8 constants
@@ -23,6 +25,28 @@ MXFP8_SCALE_DTYPE = torch.uint8
MXFP8_BLOCK_SIZE = 32
def select_mxfp8_linear_backend() -> Mxfp8LinearBackend:
    """Select the best MXFP8 linear backend for the current device.

    - SM100+ (Blackwell): FLASHINFER_CUTLASS (native MXFP8 W8A8 GEMM)
    - SM80+ (Ampere/Ada): MARLIN (MXFP8 W8A16 GEMM)
    - Otherwise: EMULATION (dequant to BF16 fallback)
    """
    # NOTE(review): imports are function-local, presumably to avoid circular
    # imports at module load time — confirm before hoisting them.
    from vllm.platforms import current_platform

    if current_platform.has_device_capability(100):
        return Mxfp8LinearBackend.FLASHINFER_CUTLASS

    from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
        is_fp8_marlin_supported,
    )

    if is_fp8_marlin_supported():
        return Mxfp8LinearBackend.MARLIN
    return Mxfp8LinearBackend.EMULATION
def swizzle_mxfp8_scale(sf: torch.Tensor, M: int, K: int) -> torch.Tensor:
"""Swizzle MXFP8 scales from row-major 2D to F8_128x4 layout."""
scaling_vector_size = MXFP8_BLOCK_SIZE # 32 for MXFP8
@@ -47,17 +71,71 @@ def swizzle_mxfp8_scale(sf: torch.Tensor, M: int, K: int) -> torch.Tensor:
return sf_swizzled.contiguous().view(-1)
def _mxfp8_e4m3_quantize_torch(
    x: torch.Tensor,
    is_sf_swizzled_layout: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Naive MXFP8 quantization.

    For each block of 32 elements along the last dimension, compute a
    shared e8m0 scale (the biased exponent of the block-wise amax)
    and quantize each element to float8_e4m3fn.

    Returns (quantized_values [same shape, fp8], scales uint8).
    Scale shape depends on is_sf_swizzled_layout:
        False -> [..., K//32] (row-major 2D)
        True -> [flat swizzled 1D]
    """
    assert x.shape[-1] % MXFP8_BLOCK_SIZE == 0
    orig_shape = x.shape
    num_blocks = x.shape[-1] // MXFP8_BLOCK_SIZE

    # Do the exponent math in fp32 regardless of the input dtype.
    x_fp32 = x.to(torch.float32)
    x_blocked = x_fp32.view(*orig_shape[:-1], num_blocks, MXFP8_BLOCK_SIZE)

    amax = x_blocked.abs().amax(dim=-1)
    # Clamp away exact zeros so log2 below is well-defined for all-zero blocks.
    amax = amax.clamp(min=torch.finfo(torch.float32).tiny)
    # Biased e8m0 exponent: floor(log2(amax)) + 127, clamped to [0, 254]
    # (255 would be the e8m0 NaN encoding).
    scale_biased = torch.floor(torch.log2(amax)) + 127.0
    scale_biased = scale_biased.clamp(0, 254)
    scales_uint8 = scale_biased.to(torch.uint8)

    # Divide each block by 2^(unbiased exponent) before the fp8 cast.
    descale = torch.exp2(scale_biased - 127.0)
    x_scaled = x_blocked / descale.unsqueeze(-1)
    x_fp8 = x_scaled.view(orig_shape).to(MXFP8_VALUE_DTYPE)

    if x.ndim == 2:
        M, K = x.shape
        scales_uint8 = scales_uint8.view(M, -1)
        if is_sf_swizzled_layout:
            scales_uint8 = swizzle_mxfp8_scale(scales_uint8, M=M, K=K)
    elif x.ndim == 3:
        B, M, K = x.shape
        scales_uint8 = scales_uint8.view(B, M, -1)
        if is_sf_swizzled_layout:
            # Swizzle each batch slice independently; swizzle_mxfp8_scale
            # returns a flat 1D tensor, so the result is concatenated flat.
            swizzled = []
            for i in range(B):
                swizzled.append(swizzle_mxfp8_scale(scales_uint8[i], M=M, K=K))
            scales_uint8 = torch.cat(swizzled)
    return x_fp8, scales_uint8
def _mxfp8_e4m3_quantize_impl(
x: torch.Tensor, is_sf_swizzled_layout: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
from flashinfer import mxfp8_quantize as flashinfer_mxfp8_quantize
from vllm.platforms import current_platform
x_q, x_scales = flashinfer_mxfp8_quantize(
x, is_sf_swizzled_layout=is_sf_swizzled_layout
)
if x_scales.ndim == 1 and x.ndim == 2 and not is_sf_swizzled_layout:
x_scales = x_scales.view(x.size(0), -1)
return x_q, x_scales
if current_platform.has_device_capability(100):
from flashinfer import mxfp8_quantize as flashinfer_mxfp8_quantize
x_q, x_scales = flashinfer_mxfp8_quantize(
x, is_sf_swizzled_layout=is_sf_swizzled_layout
)
if x_scales.ndim == 1 and x.ndim == 2 and not is_sf_swizzled_layout:
x_scales = x_scales.view(x.size(0), -1)
return x_q, x_scales
return _mxfp8_e4m3_quantize_torch(x, is_sf_swizzled_layout)
def mxfp8_e4m3_quantize(
@@ -128,11 +206,51 @@ direct_register_custom_op(
class Mxfp8LinearOp:
def __init__(self, backend: Mxfp8LinearBackend):
if backend not in Mxfp8LinearBackend:
raise ValueError(f"Unsupported backend: {backend}")
def __init__(self):
self.backend = select_mxfp8_linear_backend()
logger.info_once("Using %s backend for MXFP8 GEMM", self.backend)
self.backend = backend
    def process_weights(self, layer: torch.nn.Module) -> None:
        """Process MXFP8 weights after loading into backend-specific format.

        Dispatches on the backend chosen at construction time; each helper
        mutates the layer's `weight`/`weight_scale` parameters in place.
        """
        if self.backend == Mxfp8LinearBackend.MARLIN:
            self._process_weights_marlin(layer)
        elif self.backend == Mxfp8LinearBackend.FLASHINFER_CUTLASS:
            self._process_weights_flashinfer_cutlass(layer)
        else:
            # Fallback: EMULATION keeps plain 2D uint8 scales.
            self._process_weights_emulation(layer)
    def _process_weights_emulation(self, layer: torch.nn.Module) -> None:
        """Keep scales as 2D uint8 for dequant-to-BF16 emulation."""
        weight = layer.weight.data  # [N, K]
        N, K = weight.shape
        scale_k = K // MXFP8_BLOCK_SIZE
        # Slice the scale tensor to [N, K//32] in case it is stored padded.
        weight_scale = layer.weight_scale.data[:N, :scale_k].contiguous()
        layer.weight = Parameter(weight.contiguous(), requires_grad=False)
        layer.weight_scale = Parameter(weight_scale, requires_grad=False)
    def _process_weights_flashinfer_cutlass(self, layer: torch.nn.Module) -> None:
        """Swizzle scales to F8_128x4 layout for flashinfer CUTLASS."""
        weight = layer.weight.data  # [N, K]
        N, K = weight.shape
        scale_k = K // MXFP8_BLOCK_SIZE
        # Slice off any padding, then swizzle to the CUTLASS scale layout
        # (swizzle_mxfp8_scale returns a flat 1D tensor).
        weight_scale_2d = layer.weight_scale.data[:N, :scale_k].contiguous()
        weight_scale_swizzled = swizzle_mxfp8_scale(weight_scale_2d, M=N, K=K)
        layer.weight = Parameter(weight.contiguous(), requires_grad=False)
        layer.weight_scale = Parameter(
            weight_scale_swizzled.contiguous(), requires_grad=False
        )
    def _process_weights_marlin(self, layer: torch.nn.Module) -> None:
        """Repack MXFP8 weights and scales into Marlin kernel format."""
        # NOTE(review): function-local import, presumably to avoid a circular
        # import between this module and marlin_utils_fp8 — confirm.
        from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
            prepare_mxfp8_layer_for_marlin,
        )

        prepare_mxfp8_layer_for_marlin(layer)
def _apply_emulation(
self,
@@ -142,7 +260,6 @@ class Mxfp8LinearOp:
out_dtype: torch.dtype,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
# Validate weight_scale dtype and shape (must be 2D for TORCH backend)
if weight_scale.dtype != MXFP8_SCALE_DTYPE:
raise ValueError(
f"TORCH backend requires {MXFP8_SCALE_DTYPE} weight_scale dtype, "
@@ -219,6 +336,32 @@ class Mxfp8LinearOp:
output_shape = (*input_shape[:-1], N)
return output.view(output_shape)
    def _apply_marlin(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        out_dtype: torch.dtype,
        bias: torch.Tensor | None = None,
        *,
        workspace: torch.Tensor,
        size_n: int,
        size_k: int,
    ) -> torch.Tensor:
        """Dispatch to the Marlin MXFP8 W8A16 GEMM.

        `out_dtype` is accepted for interface symmetry with the other
        backends but is not forwarded to the Marlin path.
        """
        from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
            apply_mxfp8_marlin_linear,
        )

        return apply_mxfp8_marlin_linear(
            input=input,
            weight=weight,
            weight_scale=weight_scale,
            workspace=workspace,
            size_n=size_n,
            size_k=size_k,
            bias=bias,
        )
def apply(
self,
input: torch.Tensor,
@@ -226,10 +369,27 @@ class Mxfp8LinearOp:
weight_scale: torch.Tensor,
out_dtype: torch.dtype,
bias: torch.Tensor | None = None,
*,
workspace: torch.Tensor | None = None,
size_n: int = 0,
size_k: int = 0,
) -> torch.Tensor:
if self.backend == Mxfp8LinearBackend.EMULATION:
return self._apply_emulation(input, weight, weight_scale, out_dtype, bias)
if self.backend == Mxfp8LinearBackend.MARLIN:
assert workspace is not None
return self._apply_marlin(
input,
weight,
weight_scale,
out_dtype,
bias,
workspace=workspace,
size_n=size_n,
size_k=size_k,
)
assert self.backend == Mxfp8LinearBackend.FLASHINFER_CUTLASS
return self._apply_flashinfer_cutlass(
input, weight, weight_scale, out_dtype, bias