[Bugfix] Rescale NVFP4 weight scales to fix BF16 dequant underflow (#34577)
Signed-off-by: ricky-chaoju <ricky.chen@infinirc.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -27,7 +27,44 @@ def is_fp4_marlin_supported():
|
||||
return current_platform.has_device_capability(75)
|
||||
|
||||
|
||||
def nvfp4_marlin_process_scales(marlin_scales):
|
||||
def _nvfp4_compute_scale_factor(marlin_scales: torch.Tensor) -> float:
|
||||
"""Compute the power-of-2 scale_factor needed so that all non-zero
|
||||
values in marlin_scales * 2^7 are >= 2 after rescaling.
|
||||
Returns a Python float (power of 2, >= 1.0)."""
|
||||
ws_float = marlin_scales.float() * (2**7)
|
||||
nonzero_mask = ws_float > 0
|
||||
if nonzero_mask.any():
|
||||
min_val = ws_float[nonzero_mask].min()
|
||||
if min_val < 2:
|
||||
sf = (2 / min_val).log2().ceil().exp2()
|
||||
assert (ws_float[nonzero_mask] * sf <= 448 * (2**7)).all(), (
|
||||
"NVFP4 scale dynamic range too large for rescaling"
|
||||
)
|
||||
return sf.item()
|
||||
return 1.0
|
||||
|
||||
|
||||
def nvfp4_marlin_process_scales(
|
||||
marlin_scales: torch.Tensor,
|
||||
scale_factor: float | None = None,
|
||||
) -> tuple[torch.Tensor, float]:
|
||||
"""Process NVFP4 weight scales into the special S0E5M3 format for Marlin.
|
||||
|
||||
Args:
|
||||
marlin_scales: Weight scales tensor in half precision, already
|
||||
permuted for the Marlin kernel layout.
|
||||
scale_factor: Optional power-of-2 rescaling factor. If None, the
|
||||
factor is computed automatically so that every non-zero scale
|
||||
satisfies ``scale * 2^7 >= 2`` (i.e., the MSB of the S0E5M3
|
||||
representation is always 1). When provided (e.g., for MoE
|
||||
layers where all experts must share the same factor), the
|
||||
given value is used directly. The caller is responsible for
|
||||
dividing ``global_scale`` by the returned ``scale_factor`` to
|
||||
preserve numerical correctness.
|
||||
|
||||
Returns:
|
||||
A tuple of (processed_scales, scale_factor).
|
||||
"""
|
||||
if not (marlin_scales >= 0).all():
|
||||
logger.warning_once(
|
||||
"NVFP4 Marlin assumes the scales to be >=0, but has encountered "
|
||||
@@ -51,11 +88,21 @@ def nvfp4_marlin_process_scales(marlin_scales):
|
||||
# when weight_scale > 0. This allows us to have an exponent bias
|
||||
# closer to zero after dequantization.
|
||||
|
||||
# Rescale weight_scale so that all non-zero values have MSB=1
|
||||
# after multiplying by 2^7 (i.e., weight_scale * 2^7 >= 2).
|
||||
# This is needed for models whose E4M3 scales were not normalized
|
||||
# to fully utilize the E4M3 dynamic range (e.g., global_scale=1).
|
||||
# The caller must compensate by dividing global_scale by scale_factor.
|
||||
if scale_factor is None:
|
||||
scale_factor = _nvfp4_compute_scale_factor(marlin_scales)
|
||||
if scale_factor > 1.0:
|
||||
marlin_scales = (marlin_scales.float() * scale_factor).to(torch.half)
|
||||
|
||||
marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1
|
||||
marlin_scales = marlin_scales.view(torch.float8_e4m3fn)
|
||||
marlin_scales = marlin_scales[:, 1::2].contiguous()
|
||||
|
||||
return marlin_scales
|
||||
return marlin_scales, scale_factor
|
||||
|
||||
|
||||
def mxfp4_marlin_process_scales(marlin_scales, input_dtype=None):
|
||||
@@ -200,11 +247,12 @@ def prepare_fp4_layer_for_marlin(
|
||||
)
|
||||
|
||||
if is_nvfp4:
|
||||
weight_scale = nvfp4_marlin_process_scales(weight_scale)
|
||||
weight_scale, scale_factor = nvfp4_marlin_process_scales(weight_scale)
|
||||
layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
|
||||
|
||||
weight_global_scale = layer.weight_global_scale.to(param_dtype)
|
||||
weight_global_scale = nvfp4_marlin_process_global_scale(weight_global_scale)
|
||||
weight_global_scale = weight_global_scale / scale_factor
|
||||
layer.weight_global_scale = torch.nn.Parameter(
|
||||
weight_global_scale, requires_grad=False
|
||||
)
|
||||
@@ -303,6 +351,10 @@ def prepare_nvfp4_moe_layer_for_marlin(
|
||||
else:
|
||||
size_n, size_k = K, N
|
||||
|
||||
# All experts share one global_scale, so compute the max
|
||||
# scale_factor across all experts first, then apply uniformly.
|
||||
combined_scale_factor = _nvfp4_compute_scale_factor(scales)
|
||||
|
||||
for i in range(E):
|
||||
scale = scales[i].T
|
||||
marlin_scales = marlin_permute_scales(
|
||||
@@ -312,11 +364,14 @@ def prepare_nvfp4_moe_layer_for_marlin(
|
||||
group_size=GROUP_SIZE,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
marlin_scales = nvfp4_marlin_process_scales(marlin_scales)
|
||||
marlin_scales, _ = nvfp4_marlin_process_scales(
|
||||
marlin_scales, scale_factor=combined_scale_factor
|
||||
)
|
||||
tensor_list.append(marlin_scales)
|
||||
|
||||
scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
|
||||
g_scales = nvfp4_marlin_process_global_scale(g_scales)
|
||||
g_scales = g_scales / combined_scale_factor
|
||||
return scales, g_scales
|
||||
|
||||
w13_scale, w13_scale_2 = premute_scales(w13_scale, w13_scale_2, "w13")
|
||||
@@ -394,6 +449,11 @@ def prepare_moe_fp4_layer_for_marlin(
|
||||
else:
|
||||
size_n, size_k = k, n
|
||||
|
||||
# For NVFP4: compute unified scale_factor across all experts
|
||||
combined_scale_factor = None
|
||||
if is_nvfp4:
|
||||
combined_scale_factor = _nvfp4_compute_scale_factor(scales)
|
||||
|
||||
for i in range(e):
|
||||
scale = scales[i].T
|
||||
|
||||
@@ -405,7 +465,9 @@ def prepare_moe_fp4_layer_for_marlin(
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
if is_nvfp4:
|
||||
marlin_scales = nvfp4_marlin_process_scales(marlin_scales)
|
||||
marlin_scales, _ = nvfp4_marlin_process_scales(
|
||||
marlin_scales, scale_factor=combined_scale_factor
|
||||
)
|
||||
else:
|
||||
marlin_scales = mxfp4_marlin_process_scales(
|
||||
marlin_scales, input_dtype=input_dtype
|
||||
@@ -417,7 +479,9 @@ def prepare_moe_fp4_layer_for_marlin(
|
||||
setattr(layer, name + "_weight_scale", scales)
|
||||
|
||||
if is_nvfp4:
|
||||
assert combined_scale_factor is not None
|
||||
global_scale = nvfp4_marlin_process_global_scale(global_scale)
|
||||
global_scale = global_scale / combined_scale_factor
|
||||
global_scale = torch.nn.Parameter(global_scale, requires_grad=False)
|
||||
setattr(layer, name + "_weight_scale_2", global_scale)
|
||||
|
||||
@@ -488,9 +552,10 @@ def rand_marlin_weight_nvfp4_like(weight, group_size, input_dtype=None):
|
||||
group_size=group_size,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
marlin_scales = nvfp4_marlin_process_scales(marlin_scales)
|
||||
marlin_scales, scale_factor = nvfp4_marlin_process_scales(marlin_scales)
|
||||
|
||||
global_scale = nvfp4_marlin_process_global_scale(global_scale)
|
||||
global_scale = global_scale / scale_factor
|
||||
|
||||
return weight_ref.T, marlin_qweight, marlin_scales, global_scale
|
||||
|
||||
|
||||
Reference in New Issue
Block a user