From 2cd86ff5e73afd5fcfe459262b69a63e24f7a45e Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 11 May 2026 19:40:01 +0000 Subject: [PATCH] =?UTF-8?q?fix:=20UE8M0=E2=86=92float32=20reinterpret=20in?= =?UTF-8?q?=20fold=5Fglobal=5Fscale=20(Bug=20#7)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- deep_gemm/mega/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/deep_gemm/mega/__init__.py b/deep_gemm/mega/__init__.py index dd455f2..7808432 100644 --- a/deep_gemm/mega/__init__.py +++ b/deep_gemm/mega/__init__.py @@ -224,7 +224,9 @@ def transform_nvfp4_weights_for_mega_moe( """Fold weight_scale_2 into block scales: UE4M3 * FP32 → UE4M3""" if scale_2 is None: return sf - sf_f32 = sf.to(torch.float32) + # UE8M0 → float32: must reinterpret raw uint8 bits as IEEE 754 exponent, + # NOT cast float8_e4m3fn → float32 (Bug #7: E8M0 bytes misinterpreted as E4M3) + sf_f32 = (sf.view(torch.uint8).to(torch.int32) << 23).view(torch.float32) if scale_2.dim() == 1: scale_2 = scale_2.view(-1, 1, 1) sf_f32 = sf_f32 * scale_2