[Compressed-Tensors] Simplify NVFP4 Conditions, enable marlin support for NVFP4A16 MoEs (#30881)
@@ -331,25 +331,17 @@ class CompressedTensorsConfig(QuantizationConfig):
         return False
 
     @staticmethod
-    def _is_fp4a4_nvfp4(weight_quant: QuantizationArgs, input_quant: QuantizationArgs):
-        if weight_quant is None or input_quant is None:
+    def _is_nvfp4_format(quant_args: QuantizationArgs):
+        if quant_args is None:
             return False
 
         is_tensor_group_quant = (
-            weight_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value
-            and input_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value
+            quant_args.strategy == QuantizationStrategy.TENSOR_GROUP.value
         )
-        is_symmetric = weight_quant.symmetric and input_quant.symmetric
-
-        is_group_size_16 = (
-            weight_quant.group_size == 16 and input_quant.group_size == 16
-        )
-        is_float_type = (
-            weight_quant.type == QuantizationType.FLOAT
-            and input_quant.type == QuantizationType.FLOAT
-        )
-        is_4_bits = weight_quant.num_bits == 4 and input_quant.num_bits == 4
+        is_symmetric = quant_args.symmetric
+
+        is_group_size_16 = quant_args.group_size == 16
+        is_float_type = quant_args.type == QuantizationType.FLOAT
+        is_4_bits = quant_args.num_bits == 4
         return (
             is_tensor_group_quant
             and is_float_type
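This hunk collapses the paired weight/activation check into a predicate over a single QuantizationArgs. A minimal, self-contained sketch of the resulting check, with stand-in enum and dataclass definitions (the real ones come from the compressed-tensors library):

from dataclasses import dataclass
from enum import Enum

# Stand-ins so the sketch runs without compressed-tensors installed.
class QuantizationStrategy(str, Enum):
    TENSOR_GROUP = "tensor_group"

class QuantizationType(str, Enum):
    FLOAT = "float"

@dataclass
class QuantizationArgs:
    strategy: str
    type: QuantizationType
    num_bits: int
    group_size: int
    symmetric: bool

def _is_nvfp4_format(quant_args: QuantizationArgs | None) -> bool:
    # NVFP4 = 4-bit float (E2M1) with symmetric tensor-group scales
    # over groups of 16 elements.
    if quant_args is None:
        return False
    return (
        quant_args.strategy == QuantizationStrategy.TENSOR_GROUP.value
        and quant_args.type == QuantizationType.FLOAT
        and quant_args.num_bits == 4
        and quant_args.group_size == 16
        and quant_args.symmetric
    )

nvfp4 = QuantizationArgs("tensor_group", QuantizationType.FLOAT, 4, 16, True)
assert _is_nvfp4_format(nvfp4) and not _is_nvfp4_format(None)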
@@ -358,27 +350,6 @@ class CompressedTensorsConfig(QuantizationConfig):
             and is_symmetric
         )
 
-    @staticmethod
-    def _is_fp4a16_nvfp4(weight_quant: QuantizationArgs, input_quant: QuantizationArgs):
-        is_weight_only = weight_quant is not None and input_quant is None
-        is_tensor_group_quant = (
-            weight_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value
-        )
-        is_symmetric = weight_quant.symmetric
-
-        is_group_size_16 = weight_quant.group_size == 16
-        is_float_type = weight_quant.type == QuantizationType.FLOAT
-        is_4_bits = weight_quant.num_bits == 4
-
-        return (
-            is_weight_only
-            and is_tensor_group_quant
-            and is_float_type
-            and is_4_bits
-            and is_group_size_16
-            and is_symmetric
-        )
-
     @staticmethod
     def _is_static_tensor_w8a8(
         weight_quant: QuantizationArgs, input_quant: QuantizationArgs
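With the dedicated _is_fp4a16_nvfp4 helper deleted, both NVFP4 variants are now spelled at the call sites in terms of the single predicate, as the next two hunks show. Restated as standalone functions, reusing _is_nvfp4_format from the sketch above (the function names here are illustrative, not vLLM API):

def is_nvfp4_a16(weight_quant, input_quant) -> bool:
    # NVFP4A16: NVFP4 weights, unquantized 16-bit activations.
    return _is_nvfp4_format(weight_quant) and input_quant is None

def is_nvfp4_a4(weight_quant, input_quant) -> bool:
    # NVFP4A4: NVFP4 weights and NVFP4 activations.
    return _is_nvfp4_format(weight_quant) and _is_nvfp4_format(input_quant)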
@@ -576,7 +547,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         format = format if format is not None else self.quant_format
 
         # Detect If Mixed Precision
-        if self._is_fp4a16_nvfp4(weight_quant, input_quant):
+        if self._is_nvfp4_format(weight_quant) and input_quant is None:
             return CompressedTensorsW4A16Fp4()
 
         if self._is_fp8_w4a8_sm90(weight_quant, input_quant):
@@ -614,7 +585,9 @@ class CompressedTensorsConfig(QuantizationConfig):
 
         act_quant_format = is_activation_quantization_format(format)
         if act_quant_format:
-            if self._is_fp4a4_nvfp4(weight_quant, input_quant):
+            if self._is_nvfp4_format(weight_quant) and self._is_nvfp4_format(
+                input_quant
+            ):
                 if cutlass_fp4_supported() or envs.VLLM_USE_NVFP4_CT_EMULATIONS:
                     return CompressedTensorsW4A4Fp4()
                 else:
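The W4A4 path still requires either native CUTLASS FP4 support or the VLLM_USE_NVFP4_CT_EMULATIONS escape hatch. A hedged sketch of that gate; only cutlass_fp4_supported and the env flag come from the diff, and the helper and "1" parsing convention below are assumptions:

import os

def can_run_w4a4_fp4(cutlass_fp4_ok: bool) -> bool:
    # Native FP4 GEMM needs hardware support (checked by
    # cutlass_fp4_supported() in the diff); otherwise setting
    # VLLM_USE_NVFP4_CT_EMULATIONS opts into slower emulated kernels.
    emulate = os.environ.get("VLLM_USE_NVFP4_CT_EMULATIONS", "0") == "1"
    return cutlass_fp4_ok or emulate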
@@ -77,6 +77,9 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     marlin_make_workspace_new,
     marlin_moe_permute_scales,
 )
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
+    is_fp4_marlin_supported,
+)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     convert_bf16_scales_to_fp8,
     convert_packed_uint4b8_to_signed_int4_inplace,
@@ -182,8 +185,18 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
             return CompressedTensorsWNA16MarlinMoEMethod(
                 weight_quant, input_quant, layer.moe_config
             )
-        elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
-            return CompressedTensorsW4A4Nvfp4MoEMethod(layer.moe_config, layer_name)
+        elif quant_config._is_nvfp4_format(weight_quant):
+            _is_valid_nvfp4_activations = (
+                quant_config._is_nvfp4_format(input_quant) or input_quant is None
+            )
+            if not _is_valid_nvfp4_activations:
+                raise ValueError(
+                    "For NVFP4 weights, input quantization must also be NVFP4 format ",
+                    f"or None for NVFP4A16, found {input_quant}",
+                )
+            return CompressedTensorsW4A4Nvfp4MoEMethod(
+                layer.moe_config, layer_name, use_marlin=input_quant is None
+            )
         elif (
             quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
             or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)
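This is the hunk that actually enables NVFP4A16 MoEs: NVFP4 weights with input_quant=None no longer fall through to the generic error path, but reuse CompressedTensorsW4A4Nvfp4MoEMethod with use_marlin=True. Condensed as a standalone sketch (the function name is illustrative, not vLLM API):

def select_nvfp4_moe_method(quant_config, weight_quant, input_quant,
                            moe_config, layer_name):
    # Mirrors the branch above: NVFP4 weights admit either NVFP4 or
    # absent activation quantization; anything else is a config error.
    assert quant_config._is_nvfp4_format(weight_quant)
    if input_quant is not None and not quant_config._is_nvfp4_format(input_quant):
        raise ValueError(
            f"For NVFP4 weights, input quantization must be NVFP4 "
            f"or None for NVFP4A16, found {input_quant}"
        )
    # input_quant is None => NVFP4A16 => force the Marlin kernel.
    return CompressedTensorsW4A4Nvfp4MoEMethod(
        moe_config, layer_name, use_marlin=input_quant is None
    )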
@@ -212,7 +225,12 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
 
 
 class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
-    def __init__(self, moe: FusedMoEConfig, layer_name: str | None = None):
+    def __init__(
+        self,
+        moe: FusedMoEConfig,
+        layer_name: str | None = None,
+        use_marlin: bool = False,
+    ):
         if not moe.is_act_and_mul:
             raise ValueError(
                 "CompressedTensorsW4A4Nvfp4MoEMethod does not yet "
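The new keyword defaults to False, so existing W4A4 construction is unchanged; only the NVFP4A16 call site opts in. Hypothetical construction for the two modes (moe_config and the layer name are placeholders):

# W4A4: backend chosen automatically by select_nvfp4_moe_backend().
w4a4 = CompressedTensorsW4A4Nvfp4MoEMethod(moe_config, "model.layers.0.mlp")
# NVFP4A16: Marlin kernel forced.
w4a16 = CompressedTensorsW4A4Nvfp4MoEMethod(
    moe_config, "model.layers.0.mlp", use_marlin=True
)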
@@ -221,7 +239,16 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
 
         super().__init__(moe)
         self.group_size = 16
-        self.nvfp4_backend = select_nvfp4_moe_backend()
+        if use_marlin:
+            if is_fp4_marlin_supported():
+                self.nvfp4_backend = NvFp4MoeBackend.MARLIN
+            else:
+                raise ValueError(
+                    "Marlin FP4 MoE kernel requested but not ",
+                    "supported on current platform.",
+                )
+        else:
+            self.nvfp4_backend = select_nvfp4_moe_backend()
         self.use_global_sf = is_global_sf_supported_for_nvfp4_backend(
             self.nvfp4_backend
         )
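Backend resolution in one place: use_marlin short-circuits the automatic choice and fails fast when the platform lacks Marlin FP4 support. A sketch of the resulting precedence, assuming the names imported in the diff (the wrapper function itself is illustrative):

def resolve_nvfp4_moe_backend(use_marlin: bool):
    if use_marlin:
        # NVFP4A16: the Marlin kernel handles FP4 weights against
        # 16-bit activations, so this path requires it outright.
        if not is_fp4_marlin_supported():
            raise ValueError(
                "Marlin FP4 MoE kernel requested but not "
                "supported on current platform."
            )
        return NvFp4MoeBackend.MARLIN
    # W4A4: keep whatever select_nvfp4_moe_backend() would have picked.
    return select_nvfp4_moe_backend()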