From ef9cd023a96b5fee5c97d7bb5a25ee9330c031af Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 14 May 2026 13:44:42 +0000 Subject: [PATCH] =?UTF-8?q?fix:=20unpack=5Fue4m3=5Fu32=20=E2=80=94=20uint3?= =?UTF-8?q?2=20lacks=20CUDA=20bitwise=20ops,=20use=20int32?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PyTorch doesn't implement bitwise_and/shift for UInt32 on CUDA. Cast to int32 first, then extract bytes, then uint8 → view float8. --- src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py index 30a37a65..a5424840 100644 --- a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py +++ b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py @@ -32,14 +32,19 @@ def unpack_ue4m3_u32(x_u32): Each uint32 contains 4 UE4M3 values packed in bits [0:8], [8:16], [16:24], [24:32]. Must use bit reinterpret (view), NOT value cast (to) — byte 0x3F is the float8 whose bits are 0x3F (~0.984), NOT the integer 63. + + CUDA doesn't implement bitwise ops on uint32, so we cast to int32 first. """ - x_u32 = x_u32.contiguous() - M, N = x_u32.shape - # Extract 4 bytes as uint8, then bit-reinterpret to float8_e4m3fn - b0 = (x_u32 & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn) - b1 = ((x_u32 >> 8) & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn) - b2 = ((x_u32 >> 16) & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn) - b3 = ((x_u32 >> 24) & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn) + # CUDA uint32 lacks bitwise ops — use int32 + x_i32 = x_u32.to(torch.int32) + M, N = x_i32.shape + + # Extract 4 bytes, cast to uint8, then bit-reinterpret to float8_e4m3fn + b0 = (x_i32 & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn) + b1 = ((x_i32 >> 8) & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn) + b2 = ((x_i32 >> 16) & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn) + b3 = ((x_i32 >> 24) & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn) + # Interleave into (M, N*4) out = torch.empty(M, N * 4, dtype=torch.float8_e4m3fn, device=x_u32.device) out[:, 0::4] = b0