From 3bcc0ac057cd97c7897fe1ea3b4f0facf45c3658 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 14 May 2026 12:59:20 +0000 Subject: [PATCH] fix: unpack_ue4m3_u32 was value-casting instead of bit-reinterpreting MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Byte 0x3F was becoming float8(63.0) instead of the float8 whose bit pattern IS 0x3F (1.875). Pack uses .view() (correct), unpack used .to() (wrong) — they were not inverses. This corrupted every activation scale fed to the L1 GEMM while weight scales were fine. --- src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py index 881ac113..87a50976 100644 --- a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py +++ b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py @@ -30,15 +30,18 @@ def unpack_ue4m3_u32(x_u32): """Unpack uint32 packed UE4M3 scales to float8_e4m3fn. Each uint32 contains 4 UE4M3 values packed in bits [0:8], [8:16], [16:24], [24:32]. + Must use bit reinterpret (view), NOT value cast (to) — byte 0x3F is the float8 + whose bits are 0x3F (1.875), NOT the integer 63. 
""" x_u32 = x_u32.contiguous() M, N = x_u32.shape + # Extract 4 bytes as uint8, then bit-reinterpret to float8_e4m3fn + b0 = (x_u32 & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn) + b1 = ((x_u32 >> 8) & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn) + b2 = ((x_u32 >> 16) & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn) + b3 = ((x_u32 >> 24) & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn) + # Interleave into (M, N*4) out = torch.empty(M, N * 4, dtype=torch.float8_e4m3fn, device=x_u32.device) - # Vectorized unpack: extract 4 bytes from each uint32 - b0 = (x_u32 & 0xFF).to(torch.int32).to(torch.float8_e4m3fn) - b1 = ((x_u32 >> 8) & 0xFF).to(torch.int32).to(torch.float8_e4m3fn) - b2 = ((x_u32 >> 16) & 0xFF).to(torch.int32).to(torch.float8_e4m3fn) - b3 = ((x_u32 >> 24) & 0xFF).to(torch.int32).to(torch.float8_e4m3fn) out[:, 0::4] = b0 out[:, 1::4] = b1 out[:, 2::4] = b2