fix: UE8M0→float32 reinterpret in fold_global_scale (Bug #7)

This commit is contained in:
2026-05-11 19:40:01 +00:00
parent 47621bb990
commit 2cd86ff5e7

View File

@@ -224,7 +224,9 @@ def transform_nvfp4_weights_for_mega_moe(
"""Fold weight_scale_2 into block scales: UE4M3 * FP32 → UE4M3"""
if scale_2 is None:
return sf
sf_f32 = sf.to(torch.float32)
# UE8M0 → float32: must reinterpret raw uint8 bits as IEEE 754 exponent,
# NOT cast float8_e4m3fn → float32 (Bug #7: E8M0 bytes misinterpreted as E4M3)
sf_f32 = (sf.view(torch.uint8).to(torch.int32) << 23).view(torch.float32)
if scale_2.dim() == 1:
scale_2 = scale_2.view(-1, 1, 1)
sf_f32 = sf_f32 * scale_2