[Bugfix] Rescale NVFP4 weight scales to fix BF16 dequant underflow (#34577)
Signed-off-by: ricky-chaoju <ricky.chen@infinirc.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -27,7 +27,44 @@ def is_fp4_marlin_supported():
|
|||||||
return current_platform.has_device_capability(75)
|
return current_platform.has_device_capability(75)
|
||||||
|
|
||||||
|
|
||||||
def nvfp4_marlin_process_scales(marlin_scales):
|
def _nvfp4_compute_scale_factor(marlin_scales: torch.Tensor) -> float:
|
||||||
|
"""Compute the power-of-2 scale_factor needed so that all non-zero
|
||||||
|
values in marlin_scales * 2^7 are >= 2 after rescaling.
|
||||||
|
Returns a Python float (power of 2, >= 1.0)."""
|
||||||
|
ws_float = marlin_scales.float() * (2**7)
|
||||||
|
nonzero_mask = ws_float > 0
|
||||||
|
if nonzero_mask.any():
|
||||||
|
min_val = ws_float[nonzero_mask].min()
|
||||||
|
if min_val < 2:
|
||||||
|
sf = (2 / min_val).log2().ceil().exp2()
|
||||||
|
assert (ws_float[nonzero_mask] * sf <= 448 * (2**7)).all(), (
|
||||||
|
"NVFP4 scale dynamic range too large for rescaling"
|
||||||
|
)
|
||||||
|
return sf.item()
|
||||||
|
return 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def nvfp4_marlin_process_scales(
|
||||||
|
marlin_scales: torch.Tensor,
|
||||||
|
scale_factor: float | None = None,
|
||||||
|
) -> tuple[torch.Tensor, float]:
|
||||||
|
"""Process NVFP4 weight scales into the special S0E5M3 format for Marlin.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
marlin_scales: Weight scales tensor in half precision, already
|
||||||
|
permuted for the Marlin kernel layout.
|
||||||
|
scale_factor: Optional power-of-2 rescaling factor. If None, the
|
||||||
|
factor is computed automatically so that every non-zero scale
|
||||||
|
satisfies ``scale * 2^7 >= 2`` (i.e., the MSB of the S0E5M3
|
||||||
|
representation is always 1). When provided (e.g., for MoE
|
||||||
|
layers where all experts must share the same factor), the
|
||||||
|
given value is used directly. The caller is responsible for
|
||||||
|
dividing ``global_scale`` by the returned ``scale_factor`` to
|
||||||
|
preserve numerical correctness.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of (processed_scales, scale_factor).
|
||||||
|
"""
|
||||||
if not (marlin_scales >= 0).all():
|
if not (marlin_scales >= 0).all():
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
"NVFP4 Marlin assumes the scales to be >=0, but has encountered "
|
"NVFP4 Marlin assumes the scales to be >=0, but has encountered "
|
||||||
@@ -51,11 +88,21 @@ def nvfp4_marlin_process_scales(marlin_scales):
|
|||||||
# when weight_scale > 0. This allows us to have an exponent bias
|
# when weight_scale > 0. This allows us to have an exponent bias
|
||||||
# closer to zero after dequantization.
|
# closer to zero after dequantization.
|
||||||
|
|
||||||
|
# Rescale weight_scale so that all non-zero values have MSB=1
|
||||||
|
# after multiplying by 2^7 (i.e., weight_scale * 2^7 >= 2).
|
||||||
|
# This is needed for models whose E4M3 scales were not normalized
|
||||||
|
# to fully utilize the E4M3 dynamic range (e.g., global_scale=1).
|
||||||
|
# The caller must compensate by dividing global_scale by scale_factor.
|
||||||
|
if scale_factor is None:
|
||||||
|
scale_factor = _nvfp4_compute_scale_factor(marlin_scales)
|
||||||
|
if scale_factor > 1.0:
|
||||||
|
marlin_scales = (marlin_scales.float() * scale_factor).to(torch.half)
|
||||||
|
|
||||||
marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1
|
marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1
|
||||||
marlin_scales = marlin_scales.view(torch.float8_e4m3fn)
|
marlin_scales = marlin_scales.view(torch.float8_e4m3fn)
|
||||||
marlin_scales = marlin_scales[:, 1::2].contiguous()
|
marlin_scales = marlin_scales[:, 1::2].contiguous()
|
||||||
|
|
||||||
return marlin_scales
|
return marlin_scales, scale_factor
|
||||||
|
|
||||||
|
|
||||||
def mxfp4_marlin_process_scales(marlin_scales, input_dtype=None):
|
def mxfp4_marlin_process_scales(marlin_scales, input_dtype=None):
|
||||||
@@ -200,11 +247,12 @@ def prepare_fp4_layer_for_marlin(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if is_nvfp4:
|
if is_nvfp4:
|
||||||
weight_scale = nvfp4_marlin_process_scales(weight_scale)
|
weight_scale, scale_factor = nvfp4_marlin_process_scales(weight_scale)
|
||||||
layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
|
layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
|
||||||
|
|
||||||
weight_global_scale = layer.weight_global_scale.to(param_dtype)
|
weight_global_scale = layer.weight_global_scale.to(param_dtype)
|
||||||
weight_global_scale = nvfp4_marlin_process_global_scale(weight_global_scale)
|
weight_global_scale = nvfp4_marlin_process_global_scale(weight_global_scale)
|
||||||
|
weight_global_scale = weight_global_scale / scale_factor
|
||||||
layer.weight_global_scale = torch.nn.Parameter(
|
layer.weight_global_scale = torch.nn.Parameter(
|
||||||
weight_global_scale, requires_grad=False
|
weight_global_scale, requires_grad=False
|
||||||
)
|
)
|
||||||
@@ -303,6 +351,10 @@ def prepare_nvfp4_moe_layer_for_marlin(
|
|||||||
else:
|
else:
|
||||||
size_n, size_k = K, N
|
size_n, size_k = K, N
|
||||||
|
|
||||||
|
# All experts share one global_scale, so compute the max
|
||||||
|
# scale_factor across all experts first, then apply uniformly.
|
||||||
|
combined_scale_factor = _nvfp4_compute_scale_factor(scales)
|
||||||
|
|
||||||
for i in range(E):
|
for i in range(E):
|
||||||
scale = scales[i].T
|
scale = scales[i].T
|
||||||
marlin_scales = marlin_permute_scales(
|
marlin_scales = marlin_permute_scales(
|
||||||
@@ -312,11 +364,14 @@ def prepare_nvfp4_moe_layer_for_marlin(
|
|||||||
group_size=GROUP_SIZE,
|
group_size=GROUP_SIZE,
|
||||||
is_a_8bit=is_a_8bit,
|
is_a_8bit=is_a_8bit,
|
||||||
)
|
)
|
||||||
marlin_scales = nvfp4_marlin_process_scales(marlin_scales)
|
marlin_scales, _ = nvfp4_marlin_process_scales(
|
||||||
|
marlin_scales, scale_factor=combined_scale_factor
|
||||||
|
)
|
||||||
tensor_list.append(marlin_scales)
|
tensor_list.append(marlin_scales)
|
||||||
|
|
||||||
scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
|
scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
|
||||||
g_scales = nvfp4_marlin_process_global_scale(g_scales)
|
g_scales = nvfp4_marlin_process_global_scale(g_scales)
|
||||||
|
g_scales = g_scales / combined_scale_factor
|
||||||
return scales, g_scales
|
return scales, g_scales
|
||||||
|
|
||||||
w13_scale, w13_scale_2 = premute_scales(w13_scale, w13_scale_2, "w13")
|
w13_scale, w13_scale_2 = premute_scales(w13_scale, w13_scale_2, "w13")
|
||||||
@@ -394,6 +449,11 @@ def prepare_moe_fp4_layer_for_marlin(
|
|||||||
else:
|
else:
|
||||||
size_n, size_k = k, n
|
size_n, size_k = k, n
|
||||||
|
|
||||||
|
# For NVFP4: compute unified scale_factor across all experts
|
||||||
|
combined_scale_factor = None
|
||||||
|
if is_nvfp4:
|
||||||
|
combined_scale_factor = _nvfp4_compute_scale_factor(scales)
|
||||||
|
|
||||||
for i in range(e):
|
for i in range(e):
|
||||||
scale = scales[i].T
|
scale = scales[i].T
|
||||||
|
|
||||||
@@ -405,7 +465,9 @@ def prepare_moe_fp4_layer_for_marlin(
|
|||||||
is_a_8bit=is_a_8bit,
|
is_a_8bit=is_a_8bit,
|
||||||
)
|
)
|
||||||
if is_nvfp4:
|
if is_nvfp4:
|
||||||
marlin_scales = nvfp4_marlin_process_scales(marlin_scales)
|
marlin_scales, _ = nvfp4_marlin_process_scales(
|
||||||
|
marlin_scales, scale_factor=combined_scale_factor
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
marlin_scales = mxfp4_marlin_process_scales(
|
marlin_scales = mxfp4_marlin_process_scales(
|
||||||
marlin_scales, input_dtype=input_dtype
|
marlin_scales, input_dtype=input_dtype
|
||||||
@@ -417,7 +479,9 @@ def prepare_moe_fp4_layer_for_marlin(
|
|||||||
setattr(layer, name + "_weight_scale", scales)
|
setattr(layer, name + "_weight_scale", scales)
|
||||||
|
|
||||||
if is_nvfp4:
|
if is_nvfp4:
|
||||||
|
assert combined_scale_factor is not None
|
||||||
global_scale = nvfp4_marlin_process_global_scale(global_scale)
|
global_scale = nvfp4_marlin_process_global_scale(global_scale)
|
||||||
|
global_scale = global_scale / combined_scale_factor
|
||||||
global_scale = torch.nn.Parameter(global_scale, requires_grad=False)
|
global_scale = torch.nn.Parameter(global_scale, requires_grad=False)
|
||||||
setattr(layer, name + "_weight_scale_2", global_scale)
|
setattr(layer, name + "_weight_scale_2", global_scale)
|
||||||
|
|
||||||
@@ -488,9 +552,10 @@ def rand_marlin_weight_nvfp4_like(weight, group_size, input_dtype=None):
|
|||||||
group_size=group_size,
|
group_size=group_size,
|
||||||
is_a_8bit=is_a_8bit,
|
is_a_8bit=is_a_8bit,
|
||||||
)
|
)
|
||||||
marlin_scales = nvfp4_marlin_process_scales(marlin_scales)
|
marlin_scales, scale_factor = nvfp4_marlin_process_scales(marlin_scales)
|
||||||
|
|
||||||
global_scale = nvfp4_marlin_process_global_scale(global_scale)
|
global_scale = nvfp4_marlin_process_global_scale(global_scale)
|
||||||
|
global_scale = global_scale / scale_factor
|
||||||
|
|
||||||
return weight_ref.T, marlin_qweight, marlin_scales, global_scale
|
return weight_ref.T, marlin_qweight, marlin_scales, global_scale
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user