[Quant] Support MXFP4 W4A16 for compressed-tensors dense models (#31926)

Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Author:    Michael Goin
Date:      2026-01-10 09:44:35 -05:00
Committed: GitHub
Parent:    07286ec5a6
Commit:    e6c6f2c79d

3 changed files with 131 additions and 0 deletions

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

@@ -43,6 +43,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsW4A8Fp8,
     CompressedTensorsW4A8Int,
     CompressedTensorsW4A16Fp4,
+    CompressedTensorsW4A16Mxfp4,
     CompressedTensorsW4A16Sparse24,
     CompressedTensorsW8A8Fp8,
     CompressedTensorsW8A8Int8,
@@ -350,6 +351,25 @@ class CompressedTensorsConfig(QuantizationConfig):
             and is_symmetric
         )
 
+    @staticmethod
+    def _is_mxfp4(quant_args: QuantizationArgs) -> bool:
+        if quant_args is None:
+            return False
+        is_group_quant = quant_args.strategy == QuantizationStrategy.GROUP.value
+        is_symmetric = quant_args.symmetric
+        is_group_size_32 = quant_args.group_size == 32
+        is_float_type = quant_args.type == QuantizationType.FLOAT
+        is_4_bits = quant_args.num_bits == 4
+        return (
+            is_group_quant
+            and is_float_type
+            and is_4_bits
+            and is_group_size_32
+            and is_symmetric
+        )
+
     @staticmethod
     def _is_static_tensor_w8a8(
         weight_quant: QuantizationArgs, input_quant: QuantizationArgs
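A minimal sketch of the weight-quantization arguments this predicate matches, written as a plain Python dict (the exact checkpoint JSON layout is not shown in this diff, so the field spelling here follows compressed-tensors' QuantizationArgs and is otherwise an assumption):

    # Hypothetical compressed-tensors weight args; _is_mxfp4 returns True
    # only for exactly this combination of fields.
    mxfp4_weight_args = {
        "num_bits": 4,        # is_4_bits
        "type": "float",      # is_float_type (FP4 E2M1)
        "strategy": "group",  # is_group_quant
        "group_size": 32,     # is_group_size_32 (MX block size)
        "symmetric": True,    # is_symmetric
    }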
@@ -550,6 +570,9 @@ class CompressedTensorsConfig(QuantizationConfig):
         if self._is_nvfp4_format(weight_quant) and input_quant is None:
             return CompressedTensorsW4A16Fp4()
 
+        if self._is_mxfp4(weight_quant):
+            return CompressedTensorsW4A16Mxfp4()
+
         if self._is_fp8_w4a8_sm90(weight_quant, input_quant):
             return CompressedTensorsW4A8Fp8(
                 num_bits=weight_quant.num_bits,
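With this dispatch in place, vLLM should pick the MXFP4 scheme up automatically from the checkpoint's quantization_config; a usage sketch (the model ID below is a placeholder, not a real repository):

    from vllm import LLM, SamplingParams

    # quantization is auto-detected from the checkpoint's quantization_config;
    # the Marlin FP4 path requires compute capability >= 8.0 (Ampere or newer),
    # per get_min_capability() in the new scheme below.
    llm = LLM(model="org/dense-model-MXFP4")
    out = llm.generate(["Hello"], SamplingParams(max_tokens=8))
    print(out[0].outputs[0].text)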

vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py

@@ -9,6 +9,7 @@ from .compressed_tensors_w4a16_24 import (
     W4A16SPARSE24_SUPPORTED_BITS,
     CompressedTensorsW4A16Sparse24,
 )
+from .compressed_tensors_w4a16_mxfp4 import CompressedTensorsW4A16Mxfp4
 from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4
 from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
 from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
@@ -29,6 +30,7 @@ __all__ = [
     "W4A16SPARSE24_SUPPORTED_BITS",
     "CompressedTensors24",
     "CompressedTensorsW4A16Fp4",
+    "CompressedTensorsW4A16Mxfp4",
    "CompressedTensorsW4A4Fp4",
     "CompressedTensorsW4A8Int",
     "CompressedTensorsW4A8Fp8",

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_mxfp4.py (new file)

@@ -0,0 +1,106 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from collections.abc import Callable
+
+import torch
+from torch.nn.parameter import Parameter
+
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
+    CompressedTensorsScheme,
+)
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
+    apply_fp4_marlin_linear,
+    prepare_fp4_layer_for_marlin,
+)
+from vllm.model_executor.parameter import (
+    GroupQuantScaleParameter,
+    ModelWeightParameter,
+)
+
+__all__ = ["CompressedTensorsW4A16Mxfp4"]
+
+
+class CompressedTensorsW4A16Mxfp4(CompressedTensorsScheme):
+    """
+    Compressed tensors scheme for MXFP4 weight-only quantization.
+
+    Supports models quantized with the compressed-tensors
+    mxfp4-pack-quantized format.
+
+    MXFP4 format:
+    - 4-bit float weights (E2M1) packed into uint8
+    - Per-group E8M0 scales with group_size=32
+    - No global scale (unlike NVFP4)
+    """
+
+    def __init__(self):
+        self.group_size = 32
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 80
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        output_partition_sizes: list[int],
+        input_size_per_partition: int,
+        params_dtype: torch.dtype,
+        weight_loader: Callable,
+        **kwargs,
+    ):
+        output_size_per_partition = sum(output_partition_sizes)
+        layer.logical_widths = output_partition_sizes
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+        layer.params_dtype = params_dtype
+
+        # Packed FP4 weights (2 values per byte)
+        weight = ModelWeightParameter(
+            data=torch.empty(
+                output_size_per_partition,
+                input_size_per_partition // 2,
+                dtype=torch.uint8,
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight_packed", weight)
+
+        # Per-group E8M0 scales
+        weight_scale = GroupQuantScaleParameter(
+            data=torch.empty(
+                output_size_per_partition,
+                input_size_per_partition // self.group_size,
+                dtype=torch.uint8,
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight_scale", weight_scale)
+
+    def process_weights_after_loading(self, layer) -> None:
+        # Rename weight_packed to the "weight" name that Marlin expects
+        layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
+        del layer.weight_packed
+        prepare_fp4_layer_for_marlin(layer)
+
+    def apply_weights(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        return apply_fp4_marlin_linear(
+            input=x,
+            weight=layer.weight,
+            weight_scale=layer.weight_scale,
+            weight_scale_2=None,
+            workspace=layer.workspace,
+            size_n=layer.output_size_per_partition,
+            size_k=layer.input_size_per_partition,
+            bias=bias,
+        )
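To make the packed layout concrete, here is a reference dequantizer for the MXFP4 format described in the docstring. This is an illustrative sketch only: the nibble order within each byte is an assumption, and the real inference path runs inside the fused Marlin FP4 kernel, not through anything like this.

    import torch

    # The eight non-negative E2M1 magnitudes; bit 3 of each nibble is the sign.
    E2M1_MAGNITUDES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

    def dequant_mxfp4(packed: torch.Tensor, scales: torch.Tensor,
                      group_size: int = 32) -> torch.Tensor:
        """packed: (N, K // 2) uint8, scales: (N, K // group_size) uint8."""
        lo = packed & 0x0F  # first value of each pair (order is an assumption)
        hi = packed >> 4    # second value of each pair
        nibbles = torch.stack([lo, hi], dim=-1).flatten(start_dim=-2)  # (N, K)
        sign = 1.0 - 2.0 * ((nibbles >> 3) & 1).float()
        vals = sign * E2M1_MAGNITUDES[(nibbles & 0x7).long()]
        # E8M0 scale: a biased power-of-two exponent, one per 32-wide group.
        group_scale = torch.pow(2.0, scales.float() - 127.0)
        return vals * group_scale.repeat_interleave(group_size, dim=-1)

For input_size_per_partition K = 4096, this consumes exactly the shapes allocated in create_weights above: weight_packed is (N, 2048) uint8 (4096 // 2 packed bytes per row) and weight_scale is (N, 128) uint8 (4096 // 32 group scales per row).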