From e6c6f2c79d2c1e0765d17b4dec83a4a8283342e4 Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Sat, 10 Jan 2026 09:44:35 -0500
Subject: [PATCH] [Quant] Support MXFP4 W4A16 for compressed-tensors dense
 models (#31926)

Signed-off-by: mgoin
Signed-off-by: Michael Goin
---
 .../compressed_tensors/compressed_tensors.py  |  23 ++++
 .../compressed_tensors/schemes/__init__.py    |   2 +
 .../schemes/compressed_tensors_w4a16_mxfp4.py | 106 ++++++++++++++++++
 3 files changed, 131 insertions(+)
 create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_mxfp4.py

diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index bc4fdfdda..de50c9e8f 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -43,6 +43,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsW4A8Fp8,
     CompressedTensorsW4A8Int,
     CompressedTensorsW4A16Fp4,
+    CompressedTensorsW4A16Mxfp4,
     CompressedTensorsW4A16Sparse24,
     CompressedTensorsW8A8Fp8,
     CompressedTensorsW8A8Int8,
@@ -350,6 +351,25 @@ class CompressedTensorsConfig(QuantizationConfig):
             and is_symmetric
         )
 
+    @staticmethod
+    def _is_mxfp4(quant_args: QuantizationArgs) -> bool:
+        if quant_args is None:
+            return False
+
+        is_group_quant = quant_args.strategy == QuantizationStrategy.GROUP.value
+        is_symmetric = quant_args.symmetric
+        is_group_size_32 = quant_args.group_size == 32
+        is_float_type = quant_args.type == QuantizationType.FLOAT
+        is_4_bits = quant_args.num_bits == 4
+
+        return (
+            is_group_quant
+            and is_float_type
+            and is_4_bits
+            and is_group_size_32
+            and is_symmetric
+        )
+
     @staticmethod
     def _is_static_tensor_w8a8(
         weight_quant: QuantizationArgs, input_quant: QuantizationArgs
@@ -550,6 +570,9 @@ class CompressedTensorsConfig(QuantizationConfig):
         if self._is_nvfp4_format(weight_quant) and input_quant is None:
             return CompressedTensorsW4A16Fp4()
 
+        if self._is_mxfp4(weight_quant):
+            return CompressedTensorsW4A16Mxfp4()
+
         if self._is_fp8_w4a8_sm90(weight_quant, input_quant):
             return CompressedTensorsW4A8Fp8(
                 num_bits=weight_quant.num_bits,
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
index ca286675e..6d40685f0 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
@@ -9,6 +9,7 @@ from .compressed_tensors_w4a16_24 import (
     W4A16SPARSE24_SUPPORTED_BITS,
     CompressedTensorsW4A16Sparse24,
 )
+from .compressed_tensors_w4a16_mxfp4 import CompressedTensorsW4A16Mxfp4
 from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4
 from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
 from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
@@ -29,6 +30,7 @@ __all__ = [
     "W4A16SPARSE24_SUPPORTED_BITS",
     "CompressedTensors24",
     "CompressedTensorsW4A16Fp4",
+    "CompressedTensorsW4A16Mxfp4",
     "CompressedTensorsW4A4Fp4",
     "CompressedTensorsW4A8Int",
     "CompressedTensorsW4A8Fp8",
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_mxfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_mxfp4.py
new file mode 100644
index 000000000..1c76adebe
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_mxfp4.py
@@ -0,0 +1,106 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from collections.abc import Callable
+
+import torch
+from torch.nn.parameter import Parameter
+
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
+    CompressedTensorsScheme,
+)
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
+    apply_fp4_marlin_linear,
+    prepare_fp4_layer_for_marlin,
+)
+from vllm.model_executor.parameter import (
+    GroupQuantScaleParameter,
+    ModelWeightParameter,
+)
+
+__all__ = ["CompressedTensorsW4A16Mxfp4"]
+
+
+class CompressedTensorsW4A16Mxfp4(CompressedTensorsScheme):
+    """
+    Compressed tensors scheme for MXFP4 weight-only quantization.
+
+    Supports models quantized with the compressed-tensors mxfp4-pack-quantized
+    format.
+
+    MXFP4 format:
+    - 4-bit float weights (E2M1) packed into uint8
+    - Per-group E8M0 scales with group_size=32
+    - No global scale (unlike NVFP4)
+    """
+
+    def __init__(self):
+        self.group_size = 32
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 80
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        output_partition_sizes: list[int],
+        input_size_per_partition: int,
+        params_dtype: torch.dtype,
+        weight_loader: Callable,
+        **kwargs,
+    ):
+        output_size_per_partition = sum(output_partition_sizes)
+        layer.logical_widths = output_partition_sizes
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+        layer.params_dtype = params_dtype
+
+        # Packed FP4 weights (2 values per byte)
+        weight = ModelWeightParameter(
+            data=torch.empty(
+                output_size_per_partition,
+                input_size_per_partition // 2,
+                dtype=torch.uint8,
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight_packed", weight)
+
+        # Per-group E8M0 scales
+        weight_scale = GroupQuantScaleParameter(
+            data=torch.empty(
+                output_size_per_partition,
+                input_size_per_partition // self.group_size,
+                dtype=torch.uint8,
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight_scale", weight_scale)
+
+    def process_weights_after_loading(self, layer) -> None:
+        # Rename weight_packed to weight, the name Marlin expects
+        layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
+        del layer.weight_packed
+
+        prepare_fp4_layer_for_marlin(layer)
+
+    def apply_weights(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        return apply_fp4_marlin_linear(
+            input=x,
+            weight=layer.weight,
+            weight_scale=layer.weight_scale,
+            weight_scale_2=None,
+            workspace=layer.workspace,
+            size_n=layer.output_size_per_partition,
+            size_k=layer.input_size_per_partition,
+            bias=bias,
+        )
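
Note on the format (not part of the patch): the sketch below is a reference dequantization of the layout this scheme consumes, for readers unfamiliar with MXFP4. Each uint8 byte holds two E2M1 codes, and each group of 32 elements shares an E8M0 scale decoded as 2^(e - 127) per the OCP MX spec. This is purely illustrative and is not the code path the patch takes at inference time, which goes through prepare_fp4_layer_for_marlin / apply_fp4_marlin_linear; the nibble order (low nibble = even-indexed element) and the helper name dequant_mxfp4 are assumptions, not taken from this diff.

import torch

# FP4 E2M1 code points 0..15 -> real values; bit 3 of the nibble is the sign.
E2M1_LUT = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]
)


def dequant_mxfp4(
    weight_packed: torch.Tensor,  # [N, K // 2], uint8, two FP4 codes per byte
    weight_scale: torch.Tensor,   # [N, K // 32], uint8, E8M0 exponents
    group_size: int = 32,
) -> torch.Tensor:
    # Split each byte into two 4-bit codes (low nibble first -- an assumption;
    # the authoritative layout is compressed-tensors' packing routine).
    lo = weight_packed & 0x0F
    hi = (weight_packed >> 4) & 0x0F
    codes = torch.stack([lo, hi], dim=-1).reshape(weight_packed.shape[0], -1)
    values = E2M1_LUT[codes.long()]  # [N, K]

    # E8M0 scale: exponent-only byte, value = 2^(e - 127); no sign or mantissa.
    scales = torch.pow(2.0, weight_scale.float() - 127.0)  # [N, K // 32]
    return values * scales.repeat_interleave(group_size, dim=-1)


if __name__ == "__main__":
    n, k = 4, 64
    packed = torch.randint(0, 256, (n, k // 2), dtype=torch.uint8)
    scales = torch.full((n, k // 32), 127, dtype=torch.uint8)  # 2^0 = 1.0
    print(dequant_mxfp4(packed, scales).shape)  # torch.Size([4, 64])

The shapes mirror create_weights above: a [N, K] logical weight is stored as a [N, K // 2] packed uint8 tensor plus a [N, K // 32] uint8 scale tensor, which also explains why _is_mxfp4 requires group_size == 32, 4-bit float weights, and symmetric quantization with no input quantization args.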