fix: UE8M0→float32 reinterpret in fold_global_scale (Bug #7)
This commit is contained in:
@@ -224,7 +224,9 @@ def transform_nvfp4_weights_for_mega_moe(
|
||||
"""Fold weight_scale_2 into block scales: UE4M3 * FP32 → UE4M3"""
|
||||
if scale_2 is None:
|
||||
return sf
|
||||
sf_f32 = sf.to(torch.float32)
|
||||
# UE8M0 → float32: must reinterpret raw uint8 bits as IEEE 754 exponent,
|
||||
# NOT cast float8_e4m3fn → float32 (Bug #7: E8M0 bytes misinterpreted as E4M3)
|
||||
sf_f32 = (sf.view(torch.uint8).to(torch.int32) << 23).view(torch.float32)
|
||||
if scale_2.dim() == 1:
|
||||
scale_2 = scale_2.view(-1, 1, 1)
|
||||
sf_f32 = sf_f32 * scale_2
|
||||
|
||||
Reference in New Issue
Block a user