From 5d3b6097ad6f8ddb156b92525cce19e08e5c7b89 Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Thu, 8 Jan 2026 17:45:17 -0500
Subject: [PATCH] [Compressed-Tensors] Simplify NVFP4 Conditions, enable marlin support for NVFP4A16 MoEs (#30881)

---
 .../compressed_tensors/compressed_tensors.py | 49 +++++--------
 .../compressed_tensors_moe.py                | 35 +++++++++++--
 2 files changed, 42 insertions(+), 42 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index f83558421..bc4fdfdda 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -331,25 +331,17 @@ class CompressedTensorsConfig(QuantizationConfig):
         return False
 
     @staticmethod
-    def _is_fp4a4_nvfp4(weight_quant: QuantizationArgs, input_quant: QuantizationArgs):
-        if weight_quant is None or input_quant is None:
+    def _is_nvfp4_format(quant_args: QuantizationArgs):
+        if quant_args is None:
             return False
-
         is_tensor_group_quant = (
-            weight_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value
-            and input_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value
+            quant_args.strategy == QuantizationStrategy.TENSOR_GROUP.value
         )
-        is_symmetric = weight_quant.symmetric and input_quant.symmetric
-
-        is_group_size_16 = (
-            weight_quant.group_size == 16 and input_quant.group_size == 16
-        )
-        is_float_type = (
-            weight_quant.type == QuantizationType.FLOAT
-            and input_quant.type == QuantizationType.FLOAT
-        )
-        is_4_bits = weight_quant.num_bits == 4 and input_quant.num_bits == 4
+        is_symmetric = quant_args.symmetric
+        is_group_size_16 = quant_args.group_size == 16
+        is_float_type = quant_args.type == QuantizationType.FLOAT
+        is_4_bits = quant_args.num_bits == 4
 
         return (
             is_tensor_group_quant
             and is_float_type
@@ -358,27 +350,6 @@ class CompressedTensorsConfig(QuantizationConfig):
             and is_symmetric
         )
 
-    @staticmethod
-    def _is_fp4a16_nvfp4(weight_quant: QuantizationArgs, input_quant: QuantizationArgs):
-        is_weight_only = weight_quant is not None and input_quant is None
-        is_tensor_group_quant = (
-            weight_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value
-        )
-        is_symmetric = weight_quant.symmetric
-
-        is_group_size_16 = weight_quant.group_size == 16
-        is_float_type = weight_quant.type == QuantizationType.FLOAT
-        is_4_bits = weight_quant.num_bits == 4
-
-        return (
-            is_weight_only
-            and is_tensor_group_quant
-            and is_float_type
-            and is_4_bits
-            and is_group_size_16
-            and is_symmetric
-        )
-
     @staticmethod
     def _is_static_tensor_w8a8(
         weight_quant: QuantizationArgs, input_quant: QuantizationArgs
@@ -576,7 +547,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         format = format if format is not None else self.quant_format
 
         # Detect If Mixed Precision
-        if self._is_fp4a16_nvfp4(weight_quant, input_quant):
+        if self._is_nvfp4_format(weight_quant) and input_quant is None:
             return CompressedTensorsW4A16Fp4()
 
         if self._is_fp8_w4a8_sm90(weight_quant, input_quant):
@@ -614,7 +585,9 @@ class CompressedTensorsConfig(QuantizationConfig):
         act_quant_format = is_activation_quantization_format(format)
 
         if act_quant_format:
-            if self._is_fp4a4_nvfp4(weight_quant, input_quant):
+            if self._is_nvfp4_format(weight_quant) and self._is_nvfp4_format(
+                input_quant
+            ):
                 if cutlass_fp4_supported() or envs.VLLM_USE_NVFP4_CT_EMULATIONS:
                     return CompressedTensorsW4A4Fp4()
                 else:
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 86878c84a..12f37598d 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -77,6 +77,9 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     marlin_make_workspace_new,
     marlin_moe_permute_scales,
 )
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
+    is_fp4_marlin_supported,
+)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     convert_bf16_scales_to_fp8,
     convert_packed_uint4b8_to_signed_int4_inplace,
@@ -182,8 +185,18 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
             return CompressedTensorsWNA16MarlinMoEMethod(
                 weight_quant, input_quant, layer.moe_config
             )
-        elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
-            return CompressedTensorsW4A4Nvfp4MoEMethod(layer.moe_config, layer_name)
+        elif quant_config._is_nvfp4_format(weight_quant):
+            _is_valid_nvfp4_activations = (
+                quant_config._is_nvfp4_format(input_quant) or input_quant is None
+            )
+            if not _is_valid_nvfp4_activations:
+                raise ValueError(
+                    "For NVFP4 weights, input quantization must also be NVFP4 format ",
+                    f"or None for NVFP4A16, found {input_quant}",
+                )
+            return CompressedTensorsW4A4Nvfp4MoEMethod(
+                layer.moe_config, layer_name, use_marlin=input_quant is None
+            )
         elif (
             quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
             or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)
@@ -212,7 +225,12 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
 
 
 class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
-    def __init__(self, moe: FusedMoEConfig, layer_name: str | None = None):
+    def __init__(
+        self,
+        moe: FusedMoEConfig,
+        layer_name: str | None = None,
+        use_marlin: bool = False,
+    ):
         if not moe.is_act_and_mul:
             raise ValueError(
                 "CompressedTensorsW4A4Nvfp4MoEMethod does not yet "
@@ -221,7 +239,16 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
 
         super().__init__(moe)
         self.group_size = 16
-        self.nvfp4_backend = select_nvfp4_moe_backend()
+        if use_marlin:
+            if is_fp4_marlin_supported():
+                self.nvfp4_backend = NvFp4MoeBackend.MARLIN
+            else:
+                raise ValueError(
+                    "Marlin FP4 MoE kernel requested but not ",
+                    "supported on current platform.",
+                )
+        else:
+            self.nvfp4_backend = select_nvfp4_moe_backend()
         self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
             self.nvfp4_backend
         )