[Compressed-Tensors] Simplify NVFP4 Conditions, enable marlin support for NVFP4A16 MoEs (#30881)

This commit is contained in:
Dipika Sikka
2026-01-08 17:45:17 -05:00
committed by GitHub
parent e74698c27a
commit 5d3b6097ad
2 changed files with 42 additions and 42 deletions

View File

@@ -331,25 +331,17 @@ class CompressedTensorsConfig(QuantizationConfig):
return False
@staticmethod
def _is_fp4a4_nvfp4(weight_quant: QuantizationArgs, input_quant: QuantizationArgs):
if weight_quant is None or input_quant is None:
def _is_nvfp4_format(quant_args: QuantizationArgs):
if quant_args is None:
return False
is_tensor_group_quant = (
weight_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value
and input_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value
quant_args.strategy == QuantizationStrategy.TENSOR_GROUP.value
)
is_symmetric = weight_quant.symmetric and input_quant.symmetric
is_group_size_16 = (
weight_quant.group_size == 16 and input_quant.group_size == 16
)
is_float_type = (
weight_quant.type == QuantizationType.FLOAT
and input_quant.type == QuantizationType.FLOAT
)
is_4_bits = weight_quant.num_bits == 4 and input_quant.num_bits == 4
is_symmetric = quant_args.symmetric
is_group_size_16 = quant_args.group_size == 16
is_float_type = quant_args.type == QuantizationType.FLOAT
is_4_bits = quant_args.num_bits == 4
return (
is_tensor_group_quant
and is_float_type
@@ -358,27 +350,6 @@ class CompressedTensorsConfig(QuantizationConfig):
and is_symmetric
)
@staticmethod
def _is_fp4a16_nvfp4(weight_quant: QuantizationArgs, input_quant: QuantizationArgs):
is_weight_only = weight_quant is not None and input_quant is None
is_tensor_group_quant = (
weight_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value
)
is_symmetric = weight_quant.symmetric
is_group_size_16 = weight_quant.group_size == 16
is_float_type = weight_quant.type == QuantizationType.FLOAT
is_4_bits = weight_quant.num_bits == 4
return (
is_weight_only
and is_tensor_group_quant
and is_float_type
and is_4_bits
and is_group_size_16
and is_symmetric
)
@staticmethod
def _is_static_tensor_w8a8(
weight_quant: QuantizationArgs, input_quant: QuantizationArgs
@@ -576,7 +547,7 @@ class CompressedTensorsConfig(QuantizationConfig):
format = format if format is not None else self.quant_format
# Detect If Mixed Precision
if self._is_fp4a16_nvfp4(weight_quant, input_quant):
if self._is_nvfp4_format(weight_quant) and input_quant is None:
return CompressedTensorsW4A16Fp4()
if self._is_fp8_w4a8_sm90(weight_quant, input_quant):
@@ -614,7 +585,9 @@ class CompressedTensorsConfig(QuantizationConfig):
act_quant_format = is_activation_quantization_format(format)
if act_quant_format:
if self._is_fp4a4_nvfp4(weight_quant, input_quant):
if self._is_nvfp4_format(weight_quant) and self._is_nvfp4_format(
input_quant
):
if cutlass_fp4_supported() or envs.VLLM_USE_NVFP4_CT_EMULATIONS:
return CompressedTensorsW4A4Fp4()
else:

View File

@@ -77,6 +77,9 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_make_workspace_new,
marlin_moe_permute_scales,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
is_fp4_marlin_supported,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
convert_bf16_scales_to_fp8,
convert_packed_uint4b8_to_signed_int4_inplace,
@@ -182,8 +185,18 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
return CompressedTensorsWNA16MarlinMoEMethod(
weight_quant, input_quant, layer.moe_config
)
elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
return CompressedTensorsW4A4Nvfp4MoEMethod(layer.moe_config, layer_name)
elif quant_config._is_nvfp4_format(weight_quant):
_is_valid_nvfp4_activations = (
quant_config._is_nvfp4_format(input_quant) or input_quant is None
)
if not _is_valid_nvfp4_activations:
raise ValueError(
"For NVFP4 weights, input quantization must also be NVFP4 format "
f"or None for NVFP4A16, found {input_quant}"
)
return CompressedTensorsW4A4Nvfp4MoEMethod(
layer.moe_config, layer_name, use_marlin=input_quant is None
)
elif (
quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)
@@ -212,7 +225,12 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
def __init__(self, moe: FusedMoEConfig, layer_name: str | None = None):
def __init__(
self,
moe: FusedMoEConfig,
layer_name: str | None = None,
use_marlin: bool = False,
):
if not moe.is_act_and_mul:
raise ValueError(
"CompressedTensorsW4A4Nvfp4MoEMethod does not yet "
@@ -221,6 +239,15 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
super().__init__(moe)
self.group_size = 16
if use_marlin:
if is_fp4_marlin_supported():
self.nvfp4_backend = NvFp4MoeBackend.MARLIN
else:
raise ValueError(
"Marlin FP4 MoE kernel requested but not "
"supported on current platform."
)
else:
self.nvfp4_backend = select_nvfp4_moe_backend()
self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
self.nvfp4_backend