diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py
index 16d2c64a8..e4a2ab413 100644
--- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py
+++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py
@@ -27,7 +27,44 @@ def is_fp4_marlin_supported():
     return current_platform.has_device_capability(75)
 
 
-def nvfp4_marlin_process_scales(marlin_scales):
+def _nvfp4_compute_scale_factor(marlin_scales: torch.Tensor) -> float:
+    """Compute the power-of-2 scale_factor needed so that all non-zero
+    values in marlin_scales * 2^7 are >= 2 after rescaling.
+    Returns a Python float (power of 2, >= 1.0)."""
+    ws_float = marlin_scales.float() * (2**7)
+    nonzero_mask = ws_float > 0
+    if nonzero_mask.any():
+        min_val = ws_float[nonzero_mask].min()
+        if min_val < 2:
+            sf = (2 / min_val).log2().ceil().exp2()
+            assert (ws_float[nonzero_mask] * sf <= 448 * (2**7)).all(), (
+                "NVFP4 scale dynamic range too large for rescaling"
+            )
+            return sf.item()
+    return 1.0
+
+
+def nvfp4_marlin_process_scales(
+    marlin_scales: torch.Tensor,
+    scale_factor: float | None = None,
+) -> tuple[torch.Tensor, float]:
+    """Process NVFP4 weight scales into the special S0E5M3 format for Marlin.
+
+    Args:
+        marlin_scales: Weight scales tensor in half precision, already
+            permuted for the Marlin kernel layout.
+        scale_factor: Optional power-of-2 rescaling factor. If None, the
+            factor is computed automatically so that every non-zero scale
+            satisfies ``scale * 2^7 >= 2`` (i.e., the MSB of the S0E5M3
+            representation is always 1). When provided (e.g., for MoE
+            layers where all experts must share the same factor), the
+            given value is used directly. The caller is responsible for
+            dividing ``global_scale`` by the returned ``scale_factor`` to
+            preserve numerical correctness.
+
+    Returns:
+        A tuple of (processed_scales, scale_factor).
+    """
     if not (marlin_scales >= 0).all():
         logger.warning_once(
             "NVFP4 Marlin assumes the scales to be >=0, but has encountered "
@@ -51,11 +88,21 @@ def nvfp4_marlin_process_scales(marlin_scales):
     # when weight_scale > 0. This allows us to have an exponent bias
     # closer to zero after dequantization.
 
+    # Rescale weight_scale so that all non-zero values have MSB=1
+    # after multiplying by 2^7 (i.e., weight_scale * 2^7 >= 2).
+    # This is needed for models whose E4M3 scales were not normalized
+    # to fully utilize the E4M3 dynamic range (e.g., global_scale=1).
+    # The caller must compensate by dividing global_scale by scale_factor.
+    if scale_factor is None:
+        scale_factor = _nvfp4_compute_scale_factor(marlin_scales)
+    if scale_factor > 1.0:
+        marlin_scales = (marlin_scales.float() * scale_factor).to(torch.half)
+
     marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1
     marlin_scales = marlin_scales.view(torch.float8_e4m3fn)
     marlin_scales = marlin_scales[:, 1::2].contiguous()
 
-    return marlin_scales
+    return marlin_scales, scale_factor
 
 
 def mxfp4_marlin_process_scales(marlin_scales, input_dtype=None):
@@ -200,11 +247,12 @@ def prepare_fp4_layer_for_marlin(
     )
 
     if is_nvfp4:
-        weight_scale = nvfp4_marlin_process_scales(weight_scale)
+        weight_scale, scale_factor = nvfp4_marlin_process_scales(weight_scale)
         layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
 
         weight_global_scale = layer.weight_global_scale.to(param_dtype)
         weight_global_scale = nvfp4_marlin_process_global_scale(weight_global_scale)
+        weight_global_scale = weight_global_scale / scale_factor
         layer.weight_global_scale = torch.nn.Parameter(
             weight_global_scale, requires_grad=False
         )
@@ -303,6 +351,10 @@ def prepare_nvfp4_moe_layer_for_marlin(
         else:
             size_n, size_k = K, N
 
+        # All experts share one global_scale, so compute the max
+        # scale_factor across all experts first, then apply uniformly.
+        combined_scale_factor = _nvfp4_compute_scale_factor(scales)
+
         for i in range(E):
             scale = scales[i].T
             marlin_scales = marlin_permute_scales(
@@ -312,11 +364,14 @@
                 group_size=GROUP_SIZE,
                 is_a_8bit=is_a_8bit,
             )
-            marlin_scales = nvfp4_marlin_process_scales(marlin_scales)
+            marlin_scales, _ = nvfp4_marlin_process_scales(
+                marlin_scales, scale_factor=combined_scale_factor
+            )
             tensor_list.append(marlin_scales)
         scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
         g_scales = nvfp4_marlin_process_global_scale(g_scales)
+        g_scales = g_scales / combined_scale_factor
 
         return scales, g_scales
 
     w13_scale, w13_scale_2 = premute_scales(w13_scale, w13_scale_2, "w13")
@@ -394,6 +449,11 @@ def prepare_moe_fp4_layer_for_marlin(
     else:
         size_n, size_k = k, n
 
+    # For NVFP4: compute unified scale_factor across all experts
+    combined_scale_factor = None
+    if is_nvfp4:
+        combined_scale_factor = _nvfp4_compute_scale_factor(scales)
+
     for i in range(e):
         scale = scales[i].T
 
@@ -405,7 +465,9 @@
             is_a_8bit=is_a_8bit,
         )
         if is_nvfp4:
-            marlin_scales = nvfp4_marlin_process_scales(marlin_scales)
+            marlin_scales, _ = nvfp4_marlin_process_scales(
+                marlin_scales, scale_factor=combined_scale_factor
+            )
         else:
            marlin_scales = mxfp4_marlin_process_scales(
                marlin_scales, input_dtype=input_dtype
@@ -417,7 +479,9 @@
     setattr(layer, name + "_weight_scale", scales)
 
     if is_nvfp4:
+        assert combined_scale_factor is not None
         global_scale = nvfp4_marlin_process_global_scale(global_scale)
+        global_scale = global_scale / combined_scale_factor
         global_scale = torch.nn.Parameter(global_scale, requires_grad=False)
         setattr(layer, name + "_weight_scale_2", global_scale)
 
@@ -488,9 +552,10 @@ def rand_marlin_weight_nvfp4_like(weight, group_size, input_dtype=None):
         group_size=group_size,
         is_a_8bit=is_a_8bit,
     )
-    marlin_scales = nvfp4_marlin_process_scales(marlin_scales)
+    marlin_scales, scale_factor = nvfp4_marlin_process_scales(marlin_scales)
 
     global_scale = nvfp4_marlin_process_global_scale(global_scale)
+    global_scale = global_scale / scale_factor
     return weight_ref.T, marlin_qweight, marlin_scales, global_scale
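
For reviewers, a minimal standalone sketch of the invariant this diff relies on: multiplying every per-group weight scale by a power-of-2 `scale_factor` while dividing the shared `global_scale` by the same factor leaves the dequantized product unchanged, but lifts every non-zero scale to `scale * 2^7 >= 2`, so the MSB required by the S0E5M3 packing is always set. The helper and tensor values below are illustrative stand-ins, not the vLLM code itself (the sketch re-implements `_nvfp4_compute_scale_factor` locally rather than importing it).

```python
import torch


def compute_scale_factor(scales: torch.Tensor) -> float:
    """Smallest power of 2 such that every non-zero scale * 2^7 * sf >= 2.

    Local stand-in mirroring _nvfp4_compute_scale_factor above.
    """
    ws = scales.float() * (2**7)
    nz = ws > 0
    if nz.any():
        min_val = ws[nz].min()
        if min_val < 2:
            return (2 / min_val).log2().ceil().exp2().item()
    return 1.0


# Toy per-group scales from a checkpoint whose E4M3 scales were left
# unnormalized (e.g., stored with global_scale = 1), so some are tiny.
scales = torch.tensor([2**-10, 2**-8, 0.5], dtype=torch.half)
global_scale = torch.tensor(1.0)

sf = compute_scale_factor(scales)  # here: 2^4 = 16.0
# After rescaling, every non-zero scale satisfies scale * 2^7 >= 2.
assert ((scales.float()[scales > 0] * sf) * (2**7) >= 2).all()

# Dividing global_scale by the same factor preserves the effective
# dequantization scale: (scale * sf) * (global / sf) == scale * global.
before = scales.float() * global_scale
after = (scales.float() * sf) * (global_scale / sf)
torch.testing.assert_close(before, after)
```

Because the factor is a power of two, the rescale only shifts exponents and is exact in fp16 (barring overflow, which the `assert` against `448 * 2^7` guards), which is why the diff can divide `global_scale` afterwards without precision loss.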