diff --git a/deep_gemm/mega/__init__.py b/deep_gemm/mega/__init__.py index dd455f2..7808432 100644 --- a/deep_gemm/mega/__init__.py +++ b/deep_gemm/mega/__init__.py @@ -224,7 +224,9 @@ def transform_nvfp4_weights_for_mega_moe( """Fold weight_scale_2 into block scales: UE4M3 * FP32 → UE4M3""" if scale_2 is None: return sf - sf_f32 = sf.to(torch.float32) + # UE8M0 → float32: must reinterpret raw uint8 bits as IEEE 754 exponent, + # NOT cast float8_e4m3fn → float32 (Bug #7: E8M0 bytes misinterpreted as E4M3) + sf_f32 = (sf.view(torch.uint8).to(torch.int32) << 23).view(torch.float32) if scale_2.dim() == 1: scale_2 = scale_2.view(-1, 1, 1) sf_f32 = sf_f32 * scale_2