[Bugfix] Rescale NVFP4 weight scales to fix BF16 dequant underflow (#34577)

Signed-off-by: ricky-chaoju <ricky.chen@infinirc.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Chao-Ju Chen
2026-03-18 04:48:42 +08:00
committed by GitHub
parent 1204cf0a9d
commit 245758992e


@@ -27,7 +27,44 @@ def is_fp4_marlin_supported():
     return current_platform.has_device_capability(75)
 
 
-def nvfp4_marlin_process_scales(marlin_scales):
+def _nvfp4_compute_scale_factor(marlin_scales: torch.Tensor) -> float:
+    """Compute the power-of-2 scale_factor needed so that all non-zero
+    values in marlin_scales * 2^7 are >= 2 after rescaling.
+
+    Returns a Python float (power of 2, >= 1.0)."""
+    ws_float = marlin_scales.float() * (2**7)
+    nonzero_mask = ws_float > 0
+    if nonzero_mask.any():
+        min_val = ws_float[nonzero_mask].min()
+        if min_val < 2:
+            sf = (2 / min_val).log2().ceil().exp2()
+            assert (ws_float[nonzero_mask] * sf <= 448 * (2**7)).all(), (
+                "NVFP4 scale dynamic range too large for rescaling"
+            )
+            return sf.item()
+    return 1.0
+
+
+def nvfp4_marlin_process_scales(
+    marlin_scales: torch.Tensor,
+    scale_factor: float | None = None,
+) -> tuple[torch.Tensor, float]:
+    """Process NVFP4 weight scales into the special S0E5M3 format for Marlin.
+
+    Args:
+        marlin_scales: Weight scales tensor in half precision, already
+            permuted for the Marlin kernel layout.
+        scale_factor: Optional power-of-2 rescaling factor. If None, the
+            factor is computed automatically so that every non-zero scale
+            satisfies ``scale * 2^7 >= 2`` (i.e., the MSB of the S0E5M3
+            representation is always 1). When provided (e.g., for MoE
+            layers where all experts must share the same factor), the
+            given value is used directly. The caller is responsible for
+            dividing ``global_scale`` by the returned ``scale_factor`` to
+            preserve numerical correctness.
+
+    Returns:
+        A tuple of (processed_scales, scale_factor).
+    """
     if not (marlin_scales >= 0).all():
         logger.warning_once(
             "NVFP4 Marlin assumes the scales to be >=0, but has encountered "
@@ -51,11 +88,21 @@ def nvfp4_marlin_process_scales(marlin_scales):
     # when weight_scale > 0. This allows us to have an exponent bias
     # closer to zero after dequantization.
 
+    # Rescale weight_scale so that all non-zero values have MSB=1
+    # after multiplying by 2^7 (i.e., weight_scale * 2^7 >= 2).
+    # This is needed for models whose E4M3 scales were not normalized
+    # to fully utilize the E4M3 dynamic range (e.g., global_scale=1).
+    # The caller must compensate by dividing global_scale by scale_factor.
+    if scale_factor is None:
+        scale_factor = _nvfp4_compute_scale_factor(marlin_scales)
+    if scale_factor > 1.0:
+        marlin_scales = (marlin_scales.float() * scale_factor).to(torch.half)
+
     marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1
     marlin_scales = marlin_scales.view(torch.float8_e4m3fn)
     marlin_scales = marlin_scales[:, 1::2].contiguous()
-    return marlin_scales
+    return marlin_scales, scale_factor
 
 
 def mxfp4_marlin_process_scales(marlin_scales, input_dtype=None):
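
The packing above exploits fp16's S1E5M10 bit layout: once every non-zero value times 2**7 is at least 2, shifting the int16 view left by one discards the sign bit, and the high byte of each element carries the 5 exponent bits plus the top 3 mantissa bits, i.e. the S0E5M3 encoding. A minimal sketch of that byte extraction (assuming little-endian storage; float8_e4m3fn serves only as a byte container here, as in the function above):

import torch

# fp16 1.0 * 2**7 = 128.0 has bit pattern 0x5800 (sign 0, biased
# exponent 22, mantissa 0); shifted left by one it becomes 0xB000,
# and the high byte 0xB0 is the S0E5M3 value.
s = torch.tensor([[1.0]], dtype=torch.half)
packed = (s * (2**7)).view(torch.int16) << 1
high_byte = packed.view(torch.float8_e4m3fn)[:, 1::2].contiguous()
assert high_byte.view(torch.uint8).item() == 0xB0
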
@@ -200,11 +247,12 @@ def prepare_fp4_layer_for_marlin(
     )
 
     if is_nvfp4:
-        weight_scale = nvfp4_marlin_process_scales(weight_scale)
+        weight_scale, scale_factor = nvfp4_marlin_process_scales(weight_scale)
         layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
 
         weight_global_scale = layer.weight_global_scale.to(param_dtype)
         weight_global_scale = nvfp4_marlin_process_global_scale(weight_global_scale)
+        weight_global_scale = weight_global_scale / scale_factor
         layer.weight_global_scale = torch.nn.Parameter(
             weight_global_scale, requires_grad=False
         )
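
The compensating division keeps dequantization value-preserving: multiplying every weight scale by a power of 2 and dividing the shared global scale by the same power cancels exactly in floating point. A quick check with hypothetical values:

import torch

weight_scale = torch.tensor([2**-9, 2**-4], dtype=torch.half)
global_scale = torch.tensor(1.0)
scale_factor = 8.0  # power of 2, as returned by _nvfp4_compute_scale_factor

before = weight_scale.float() * global_scale
after = (weight_scale.float() * scale_factor) * (global_scale / scale_factor)
torch.testing.assert_close(before, after)  # power-of-2 factors cancel exactly
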
@@ -303,6 +351,10 @@ def prepare_nvfp4_moe_layer_for_marlin(
         else:
             size_n, size_k = K, N
 
+        # All experts share one global_scale, so compute the max
+        # scale_factor across all experts first, then apply uniformly.
+        combined_scale_factor = _nvfp4_compute_scale_factor(scales)
+
         for i in range(E):
             scale = scales[i].T
             marlin_scales = marlin_permute_scales(
@@ -312,11 +364,14 @@ def prepare_nvfp4_moe_layer_for_marlin(
                 group_size=GROUP_SIZE,
                 is_a_8bit=is_a_8bit,
             )
-            marlin_scales = nvfp4_marlin_process_scales(marlin_scales)
+            marlin_scales, _ = nvfp4_marlin_process_scales(
+                marlin_scales, scale_factor=combined_scale_factor
+            )
             tensor_list.append(marlin_scales)
 
         scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
         g_scales = nvfp4_marlin_process_global_scale(g_scales)
+        g_scales = g_scales / combined_scale_factor
 
         return scales, g_scales
 
     w13_scale, w13_scale_2 = premute_scales(w13_scale, w13_scale_2, "w13")
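
Because _nvfp4_compute_scale_factor depends only on the minimum non-zero scale, calling it once on the stacked [E, ...] tensor yields the largest of the per-expert factors, so a single factor is safe for every expert. A sketch with hypothetical per-expert scales (assuming the helper added above is in scope):

import torch

e0 = torch.tensor([2**-4, 2**-6], dtype=torch.half)  # min ws_float = 2.0  -> sf 1.0
e1 = torch.tensor([2**-9, 2**-5], dtype=torch.half)  # min ws_float = 0.25 -> sf 8.0
stacked = torch.stack([e0, e1])

# One call on the stacked tensor equals the max over per-expert calls.
assert _nvfp4_compute_scale_factor(stacked) == max(
    _nvfp4_compute_scale_factor(e0), _nvfp4_compute_scale_factor(e1)
) == 8.0
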
@@ -394,6 +449,11 @@ def prepare_moe_fp4_layer_for_marlin(
         else:
             size_n, size_k = k, n
 
+        # For NVFP4: compute unified scale_factor across all experts
+        combined_scale_factor = None
+        if is_nvfp4:
+            combined_scale_factor = _nvfp4_compute_scale_factor(scales)
+
         for i in range(e):
             scale = scales[i].T
@@ -405,7 +465,9 @@ def prepare_moe_fp4_layer_for_marlin(
                 is_a_8bit=is_a_8bit,
             )
             if is_nvfp4:
-                marlin_scales = nvfp4_marlin_process_scales(marlin_scales)
+                marlin_scales, _ = nvfp4_marlin_process_scales(
+                    marlin_scales, scale_factor=combined_scale_factor
+                )
             else:
                 marlin_scales = mxfp4_marlin_process_scales(
                     marlin_scales, input_dtype=input_dtype
@@ -417,7 +479,9 @@ def prepare_moe_fp4_layer_for_marlin(
         setattr(layer, name + "_weight_scale", scales)
 
         if is_nvfp4:
+            assert combined_scale_factor is not None
             global_scale = nvfp4_marlin_process_global_scale(global_scale)
+            global_scale = global_scale / combined_scale_factor
             global_scale = torch.nn.Parameter(global_scale, requires_grad=False)
             setattr(layer, name + "_weight_scale_2", global_scale)
@@ -488,9 +552,10 @@ def rand_marlin_weight_nvfp4_like(weight, group_size, input_dtype=None):
         group_size=group_size,
         is_a_8bit=is_a_8bit,
     )
-    marlin_scales = nvfp4_marlin_process_scales(marlin_scales)
+    marlin_scales, scale_factor = nvfp4_marlin_process_scales(marlin_scales)
 
     global_scale = nvfp4_marlin_process_global_scale(global_scale)
+    global_scale = global_scale / scale_factor
 
     return weight_ref.T, marlin_qweight, marlin_scales, global_scale