From 6d7231a50e64c98bccdca9f5fdfd1c3214350a6e Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 11 May 2026 09:42:03 +0000 Subject: [PATCH] fix: reinterpret float32 bits as uint32 before >> 23 for UE8M0 --- deep_gemm/mega/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/deep_gemm/mega/__init__.py b/deep_gemm/mega/__init__.py index 5417f0d..cd17be6 100644 --- a/deep_gemm/mega/__init__.py +++ b/deep_gemm/mega/__init__.py @@ -173,10 +173,10 @@ def transform_nvfp4_weights_for_mega_moe( sf_f32 = sf.to(torch.float32) # Take max of adjacent pairs sf_merged = torch.maximum(sf_f32[..., 0::2], sf_f32[..., 1::2]) - # Convert to UE8M0: 2^(floor(log2(v)) - 127 + 127) = extract exponent byte - # UE8M0 encoding: uint8 = float32_exponent_bits >> 23 - sf_u32 = sf_merged.view(torch.uint32) - sf_ue8m0 = (sf_u32 >> 23).to(torch.uint8) + # Convert to UE8M0: extract exponent byte from float32 bit pattern + # UE8M0: uint8 = float32_bits[30:23] (8 exponent bits) + sf_bits = sf_merged.view(torch.uint32) # reinterpret float32 bits as uint32 + sf_ue8m0 = (sf_bits >> 23).to(torch.uint8) return sf_ue8m0 l1_sf_32 = merge_block16_to_block32(l1_sf)