[Bugfix] Fix marlin nvfp4 rescaling (#37502)

Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
This commit is contained in:
Jinzhen Lin
2026-04-07 23:57:17 +08:00
committed by GitHub
parent 96b5004b71
commit 7310555482

View File

@@ -43,9 +43,9 @@ def _nvfp4_compute_scale_factor(
ws_float = marlin_scales.float() * (2**7)
nonzero_mask = ws_float > 0
if nonzero_mask.any():
min_val = ws_float[nonzero_mask].min()
if min_val < 2:
sf = (2 / min_val).log2().ceil().exp2()
max_val = ws_float[nonzero_mask].max()
if max_val < 448 * (2**7):
sf = (448 * (2**7) / max_val).log2().floor().exp2()
return sf.item()
return 1.0
@@ -105,7 +105,9 @@ def nvfp4_marlin_process_scales(
if scale_factor > 1.0:
marlin_scales = (marlin_scales.float() * scale_factor).to(torch.half)
marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1
marlin_scales = marlin_scales * (2**7)
marlin_scales[marlin_scales < 2] = 0
marlin_scales = marlin_scales.view(torch.int16) << 1
marlin_scales = marlin_scales.view(torch.float8_e4m3fn)
marlin_scales = marlin_scales[:, 1::2].contiguous()