[Compressed-Tensors] Simplify NVFP4 Conditions, enable marlin support for NVFP4A16 MoEs (#30881)

Dipika Sikka committed via GitHub on 2026-01-08 17:45:17 -05:00
parent e74698c27a
commit 5d3b6097ad
2 changed files with 42 additions and 42 deletions

File: vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

@@ -331,25 +331,17 @@ class CompressedTensorsConfig(QuantizationConfig):
         return False
 
     @staticmethod
-    def _is_fp4a4_nvfp4(weight_quant: QuantizationArgs, input_quant: QuantizationArgs):
-        if weight_quant is None or input_quant is None:
+    def _is_nvfp4_format(quant_args: QuantizationArgs):
+        if quant_args is None:
             return False
         is_tensor_group_quant = (
-            weight_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value
-            and input_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value
+            quant_args.strategy == QuantizationStrategy.TENSOR_GROUP.value
         )
-        is_symmetric = weight_quant.symmetric and input_quant.symmetric
-        is_group_size_16 = (
-            weight_quant.group_size == 16 and input_quant.group_size == 16
-        )
-        is_float_type = (
-            weight_quant.type == QuantizationType.FLOAT
-            and input_quant.type == QuantizationType.FLOAT
-        )
-        is_4_bits = weight_quant.num_bits == 4 and input_quant.num_bits == 4
+        is_symmetric = quant_args.symmetric
+        is_group_size_16 = quant_args.group_size == 16
+        is_float_type = quant_args.type == QuantizationType.FLOAT
+        is_4_bits = quant_args.num_bits == 4
         return (
             is_tensor_group_quant
             and is_float_type
@@ -358,27 +350,6 @@ class CompressedTensorsConfig(QuantizationConfig):
             and is_symmetric
         )
 
-    @staticmethod
-    def _is_fp4a16_nvfp4(weight_quant: QuantizationArgs, input_quant: QuantizationArgs):
-        is_weight_only = weight_quant is not None and input_quant is None
-        is_tensor_group_quant = (
-            weight_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value
-        )
-        is_symmetric = weight_quant.symmetric
-        is_group_size_16 = weight_quant.group_size == 16
-        is_float_type = weight_quant.type == QuantizationType.FLOAT
-        is_4_bits = weight_quant.num_bits == 4
-        return (
-            is_weight_only
-            and is_tensor_group_quant
-            and is_float_type
-            and is_4_bits
-            and is_group_size_16
-            and is_symmetric
-        )
-
     @staticmethod
     def _is_static_tensor_w8a8(
         weight_quant: QuantizationArgs, input_quant: QuantizationArgs
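Note: a minimal sketch of the single recipe the unified predicate checks for, assuming compressed-tensors' public QuantizationArgs API (illustrative, not part of the diff):

from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationStrategy,
    QuantizationType,
)

# The NVFP4 recipe: 4-bit float, symmetric, TENSOR_GROUP strategy, group_size 16.
nvfp4 = QuantizationArgs(
    num_bits=4,
    type=QuantizationType.FLOAT,
    strategy=QuantizationStrategy.TENSOR_GROUP,
    symmetric=True,
    group_size=16,
)

# NVFP4A4:  _is_nvfp4_format(weight_quant) and _is_nvfp4_format(input_quant)
# NVFP4A16: _is_nvfp4_format(weight_quant) and input_quant is None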
@@ -576,7 +547,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         format = format if format is not None else self.quant_format
 
         # Detect If Mixed Precision
-        if self._is_fp4a16_nvfp4(weight_quant, input_quant):
+        if self._is_nvfp4_format(weight_quant) and input_quant is None:
             return CompressedTensorsW4A16Fp4()
 
         if self._is_fp8_w4a8_sm90(weight_quant, input_quant):
@@ -614,7 +585,9 @@ class CompressedTensorsConfig(QuantizationConfig):
         act_quant_format = is_activation_quantization_format(format)
         if act_quant_format:
-            if self._is_fp4a4_nvfp4(weight_quant, input_quant):
+            if self._is_nvfp4_format(weight_quant) and self._is_nvfp4_format(
+                input_quant
+            ):
                 if cutlass_fp4_supported() or envs.VLLM_USE_NVFP4_CT_EMULATIONS:
                     return CompressedTensorsW4A4Fp4()
                 else:
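In checkpoint terms, the two branches above correspond to config-group fragments like these (field names mirror the QuantizationArgs fields checked by the predicate; illustrative, not copied from a real checkpoint):

nvfp4_weights = {
    "num_bits": 4,
    "type": "float",
    "strategy": "tensor_group",
    "symmetric": True,
    "group_size": 16,
}

# NVFP4A16: weight-only -> CompressedTensorsW4A16Fp4()
nvfp4a16_group = {"weights": nvfp4_weights, "input_activations": None}

# NVFP4A4: weights + activations -> CompressedTensorsW4A4Fp4()
nvfp4a4_group = {"weights": nvfp4_weights, "input_activations": nvfp4_weights}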

File: vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

@@ -77,6 +77,9 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     marlin_make_workspace_new,
     marlin_moe_permute_scales,
 )
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
+    is_fp4_marlin_supported,
+)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     convert_bf16_scales_to_fp8,
     convert_packed_uint4b8_to_signed_int4_inplace,
@@ -182,8 +185,18 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
             return CompressedTensorsWNA16MarlinMoEMethod(
                 weight_quant, input_quant, layer.moe_config
             )
-        elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
-            return CompressedTensorsW4A4Nvfp4MoEMethod(layer.moe_config, layer_name)
+        elif quant_config._is_nvfp4_format(weight_quant):
+            _is_valid_nvfp4_activations = (
+                quant_config._is_nvfp4_format(input_quant) or input_quant is None
+            )
+            if not _is_valid_nvfp4_activations:
+                raise ValueError(
+                    "For NVFP4 weights, input quantization must also be NVFP4 format "
+                    f"or None for NVFP4A16, found {input_quant}"
+                )
+            return CompressedTensorsW4A4Nvfp4MoEMethod(
+                layer.moe_config, layer_name, use_marlin=input_quant is None
+            )
         elif (
             quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
             or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)
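Restated as a standalone sketch (helper names from the diff above; the wrapper function itself is hypothetical), the new MoE dispatch reduces to:

def _nvfp4_moe_use_marlin(quant_config, weight_quant, input_quant) -> bool:
    # Assumes weight_quant already satisfied quant_config._is_nvfp4_format.
    if quant_config._is_nvfp4_format(input_quant):
        return False  # NVFP4A4: fused FP4 path, backend chosen in __init__
    if input_quant is None:
        return True   # NVFP4A16: weight-only, served by the Marlin kernel
    raise ValueError("NVFP4 weights require NVFP4 or unquantized activations")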
@@ -212,7 +225,12 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
 class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
-    def __init__(self, moe: FusedMoEConfig, layer_name: str | None = None):
+    def __init__(
+        self,
+        moe: FusedMoEConfig,
+        layer_name: str | None = None,
+        use_marlin: bool = False,
+    ):
         if not moe.is_act_and_mul:
             raise ValueError(
                 "CompressedTensorsW4A4Nvfp4MoEMethod does not yet "
@@ -221,6 +239,15 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
         super().__init__(moe)
         self.group_size = 16
 
-        self.nvfp4_backend = select_nvfp4_moe_backend()
+        if use_marlin:
+            if is_fp4_marlin_supported():
+                self.nvfp4_backend = NvFp4MoeBackend.MARLIN
+            else:
+                raise ValueError(
+                    "Marlin FP4 MoE kernel requested but not "
+                    "supported on current platform."
+                )
+        else:
+            self.nvfp4_backend = select_nvfp4_moe_backend()
         self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
             self.nvfp4_backend
         )
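Net effect: NVFP4A16 MoE checkpoints, previously rejected by the A4-only predicate, now load through the Marlin FP4 kernel. A hypothetical smoke test (the model path is a placeholder):

from vllm import LLM

# Placeholder path: any compressed-tensors MoE checkpoint with NVFP4 weights
# and unquantized activations (NVFP4A16).
llm = LLM(model="my-org/my-nvfp4a16-moe")
print(llm.generate("Hello, world")[0].outputs[0].text)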