[Quant] Support MXFP4 W4A16 for compressed-tensors dense models (#31926)
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
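Adds a weight-only (W4A16) MXFP4 scheme for dense compressed-tensors checkpoints: 4-bit E2M1 weights packed two per uint8, per-group E8M0 scales with group_size=32, and no global scale (unlike NVFP4). The scheme reuses the existing Marlin FP4 path (apply_fp4_marlin_linear with weight_scale_2=None) and requires compute capability >= 8.0.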
@@ -43,6 +43,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsW4A8Fp8,
     CompressedTensorsW4A8Int,
     CompressedTensorsW4A16Fp4,
+    CompressedTensorsW4A16Mxfp4,
     CompressedTensorsW4A16Sparse24,
     CompressedTensorsW8A8Fp8,
     CompressedTensorsW8A8Int8,
@@ -350,6 +351,25 @@ class CompressedTensorsConfig(QuantizationConfig):
             and is_symmetric
         )
 
+    @staticmethod
+    def _is_mxfp4(quant_args: QuantizationArgs) -> bool:
+        if quant_args is None:
+            return False
+
+        is_group_quant = quant_args.strategy == QuantizationStrategy.GROUP.value
+        is_symmetric = quant_args.symmetric
+        is_group_size_32 = quant_args.group_size == 32
+        is_float_type = quant_args.type == QuantizationType.FLOAT
+        is_4_bits = quant_args.num_bits == 4
+
+        return (
+            is_group_quant
+            and is_float_type
+            and is_4_bits
+            and is_group_size_32
+            and is_symmetric
+        )
+
     @staticmethod
     def _is_static_tensor_w8a8(
         weight_quant: QuantizationArgs, input_quant: QuantizationArgs
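For reference, a weight-quantization entry that satisfies _is_mxfp4 would look roughly like the sketch below. The dict is illustrative only, keyed by the QuantizationArgs attributes the predicate checks, not a verbatim checkpoint config:

```python
# Illustrative only: field names mirror the QuantizationArgs attributes
# checked by _is_mxfp4; an actual config.json may nest these differently.
mxfp4_weight_args = {
    "num_bits": 4,        # is_4_bits
    "type": "float",      # is_float_type (QuantizationType.FLOAT)
    "strategy": "group",  # is_group_quant (QuantizationStrategy.GROUP)
    "group_size": 32,     # is_group_size_32
    "symmetric": True,    # is_symmetric
}
```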
@@ -550,6 +570,9 @@ class CompressedTensorsConfig(QuantizationConfig):
         if self._is_nvfp4_format(weight_quant) and input_quant is None:
             return CompressedTensorsW4A16Fp4()
 
+        if self._is_mxfp4(weight_quant):
+            return CompressedTensorsW4A16Mxfp4()
+
         if self._is_fp8_w4a8_sm90(weight_quant, input_quant):
             return CompressedTensorsW4A8Fp8(
                 num_bits=weight_quant.num_bits,
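With this branch in scheme selection, an MXFP4 checkpoint should route to CompressedTensorsW4A16Mxfp4 automatically through the usual compressed-tensors path; a minimal usage sketch (the model path is a placeholder):

```python
from vllm import LLM

# Placeholder checkpoint path; any dense model saved in the
# compressed-tensors mxfp4-pack-quantized format should apply.
llm = LLM(model="path/to/mxfp4-pack-quantized-model")
print(llm.generate("Hello, my name is")[0].outputs[0].text)
```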
@@ -9,6 +9,7 @@ from .compressed_tensors_w4a16_24 import (
     W4A16SPARSE24_SUPPORTED_BITS,
     CompressedTensorsW4A16Sparse24,
 )
+from .compressed_tensors_w4a16_mxfp4 import CompressedTensorsW4A16Mxfp4
 from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4
 from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
 from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
@@ -29,6 +30,7 @@ __all__ = [
     "W4A16SPARSE24_SUPPORTED_BITS",
     "CompressedTensors24",
    "CompressedTensorsW4A16Fp4",
+    "CompressedTensorsW4A16Mxfp4",
     "CompressedTensorsW4A4Fp4",
     "CompressedTensorsW4A8Int",
     "CompressedTensorsW4A8Fp8",
@@ -0,0 +1,106 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from collections.abc import Callable
+
+import torch
+from torch.nn.parameter import Parameter
+
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
+    CompressedTensorsScheme,
+)
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
+    apply_fp4_marlin_linear,
+    prepare_fp4_layer_for_marlin,
+)
+from vllm.model_executor.parameter import (
+    GroupQuantScaleParameter,
+    ModelWeightParameter,
+)
+
+__all__ = ["CompressedTensorsW4A16Mxfp4"]
+
+
+class CompressedTensorsW4A16Mxfp4(CompressedTensorsScheme):
+    """
+    Compressed tensors scheme for MXFP4 weight-only quantization.
+
+    Supports models quantized with the compressed-tensors mxfp4-pack-quantized
+    format.
+
+    MXFP4 format:
+    - 4-bit float weights (E2M1) packed into uint8
+    - Per-group E8M0 scales with group_size=32
+    - No global scale (unlike NVFP4)
+    """
+
+    def __init__(self):
+        self.group_size = 32
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 80
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        output_partition_sizes: list[int],
+        input_size_per_partition: int,
+        params_dtype: torch.dtype,
+        weight_loader: Callable,
+        **kwargs,
+    ):
+        output_size_per_partition = sum(output_partition_sizes)
+        layer.logical_widths = output_partition_sizes
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+        layer.params_dtype = params_dtype
+
+        # Packed FP4 weights (2 values per byte)
+        weight = ModelWeightParameter(
+            data=torch.empty(
+                output_size_per_partition,
+                input_size_per_partition // 2,
+                dtype=torch.uint8,
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight_packed", weight)
+
+        # Per-group E8M0 scales
+        weight_scale = GroupQuantScaleParameter(
+            data=torch.empty(
+                output_size_per_partition,
+                input_size_per_partition // self.group_size,
+                dtype=torch.uint8,
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight_scale", weight_scale)
+
+    def process_weights_after_loading(self, layer) -> None:
+        # Rename weight_packed to weight that marlin expects
+        layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
+        del layer.weight_packed
+
+        prepare_fp4_layer_for_marlin(layer)
+
+    def apply_weights(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        return apply_fp4_marlin_linear(
+            input=x,
+            weight=layer.weight,
+            weight_scale=layer.weight_scale,
+            weight_scale_2=None,
+            workspace=layer.workspace,
+            size_n=layer.output_size_per_partition,
+            size_k=layer.input_size_per_partition,
+            bias=bias,
+        )
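For intuition about the layout that create_weights allocates, here is a minimal reference dequantizer for the packed format. This is a sketch, not the kernel used above (Marlin consumes the packed layout directly); the low-nibble-first element ordering is an assumption, not taken from this PR:

```python
import torch

# All 16 E2M1 codes: 1 sign bit, 2 exponent bits, 1 mantissa bit.
E2M1_LUT = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]
)

def dequant_mxfp4(weight_packed: torch.Tensor, weight_scale: torch.Tensor,
                  group_size: int = 32) -> torch.Tensor:
    """weight_packed: (N, K // 2) uint8; weight_scale: (N, K // group_size) uint8."""
    lo = weight_packed & 0xF   # assumed: even-indexed elements in the low nibble
    hi = weight_packed >> 4    # assumed: odd-indexed elements in the high nibble
    codes = torch.stack([lo, hi], dim=-1).flatten(-2)  # (N, K)
    values = E2M1_LUT[codes.long()]
    # E8M0 scale is a bare biased exponent: 2 ** (e - 127), one per group of 32.
    scales = torch.pow(2.0, weight_scale.float() - 127.0)
    return values * scales.repeat_interleave(group_size, dim=-1)
```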