[Bugfix] Fix marlin nvfp4 rescaling (#37502)
Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
This commit is contained in:
@@ -43,9 +43,9 @@ def _nvfp4_compute_scale_factor(
     ws_float = marlin_scales.float() * (2**7)
     nonzero_mask = ws_float > 0
     if nonzero_mask.any():
-        min_val = ws_float[nonzero_mask].min()
-        if min_val < 2:
-            sf = (2 / min_val).log2().ceil().exp2()
+        max_val = ws_float[nonzero_mask].max()
+        if max_val < 448 * (2**7):
+            sf = (448 * (2**7) / max_val).log2().floor().exp2()
             return sf.item()
     return 1.0
 
@@ -105,7 +105,9 @@ def nvfp4_marlin_process_scales(
     if scale_factor > 1.0:
         marlin_scales = (marlin_scales.float() * scale_factor).to(torch.half)
 
-    marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1
+    marlin_scales = marlin_scales * (2**7)
+    marlin_scales[marlin_scales < 2] = 0
+    marlin_scales = marlin_scales.view(torch.int16) << 1
     marlin_scales = marlin_scales.view(torch.float8_e4m3fn)
     marlin_scales = marlin_scales[:, 1::2].contiguous()
 
Reference in New Issue
Block a user