From fbfeb54c9adca50b0b23b3f7455f5af8cdb69c03 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 12 May 2026 05:52:33 +0000 Subject: [PATCH] Fix fold_global_scale: UE4M3 scales use .to(float32), not shift-by-23 Checkpoint stores float8_e4m3fn (standard NVFP4), not UE8M0. The shift-by-23 was misinterpreting E4M3 bytes as E8M0 exponents. --- deep_gemm/mega/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/deep_gemm/mega/__init__.py b/deep_gemm/mega/__init__.py index c4940c6..ef9a4a3 100644 --- a/deep_gemm/mega/__init__.py +++ b/deep_gemm/mega/__init__.py @@ -224,9 +224,9 @@ def transform_nvfp4_weights_for_mega_moe( """Fold weight_scale_2 into block scales: UE4M3 * FP32 → UE4M3""" if scale_2 is None: return sf - # UE8M0 → float32: must reinterpret raw uint8 bits as IEEE 754 exponent, - # NOT cast float8_e4m3fn → float32 (Bug #7: E8M0 bytes misinterpreted as E4M3) - sf_f32 = (sf.view(torch.uint8).to(torch.int32) << 23).view(torch.float32) + # UE4M3 → float32: checkpoint stores float8_e4m3fn (standard NVFP4 spec) + # NOT UE8M0 — shift-by-23 was wrong (Bug #7 fix: data IS E4M3, not E8M0) + sf_f32 = sf.to(torch.float32) if scale_2.dim() == 1: scale_2 = scale_2.view(-1, 1, 1) sf_f32 = sf_f32 * scale_2