Fix fold_global_scale: UE4M3 scales use .to(float32), not shift-by-23

Checkpoint stores float8_e4m3fn (standard NVFP4), not UE8M0.
The shift-by-23 was misinterpreting E4M3 bytes as E8M0 exponents.
This commit is contained in:
2026-05-12 05:52:33 +00:00
parent af092fa7ba
commit fbfeb54c9a

View File

@@ -224,9 +224,9 @@ def transform_nvfp4_weights_for_mega_moe(
"""Fold weight_scale_2 into block scales: UE4M3 * FP32 → UE4M3"""
if scale_2 is None:
return sf
# UE8M0 → float32: must reinterpret raw uint8 bits as IEEE 754 exponent,
# NOT cast float8_e4m3fn → float32 (Bug #7: E8M0 bytes misinterpreted as E4M3)
sf_f32 = (sf.view(torch.uint8).to(torch.int32) << 23).view(torch.float32)
# UE4M3 → float32: checkpoint stores float8_e4m3fn (standard NVFP4 spec)
# NOT UE8M0 — shift-by-23 was wrong (Bug #7 fix: data IS E4M3, not E8M0)
sf_f32 = sf.to(torch.float32)
if scale_2.dim() == 1:
scale_2 = scale_2.view(-1, 1, 1)
sf_f32 = sf_f32 * scale_2