From 57c629ed1b93fe9919227272152c08518d758cce Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 11 May 2026 09:45:54 +0000 Subject: [PATCH] fix: cast to int32 before >> 23 (uint32 doesn't support right-shift) --- deep_gemm/mega/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/deep_gemm/mega/__init__.py b/deep_gemm/mega/__init__.py index cd17be6..f30294e 100644 --- a/deep_gemm/mega/__init__.py +++ b/deep_gemm/mega/__init__.py @@ -175,8 +175,9 @@ def transform_nvfp4_weights_for_mega_moe( sf_merged = torch.maximum(sf_f32[..., 0::2], sf_f32[..., 1::2]) # Convert to UE8M0: extract exponent byte from float32 bit pattern # UE8M0: uint8 = float32_bits[31:23] (8 exponent bits) - sf_bits = sf_merged.view(torch.uint32) # reinterpret float32 bits as uint32 - sf_ue8m0 = (sf_bits >> 23).to(torch.uint8) + # Note: PyTorch doesn't support >> on uint32, so reinterpret the bits as int32 and mask with 0xFF after the shift to drop the sign bit + sf_bits = sf_merged.view(torch.int32) # reinterpret float32 bits as int32 + sf_ue8m0 = ((sf_bits >> 23) & 0xFF).to(torch.uint8) return sf_ue8m0 l1_sf_32 = merge_block16_to_block32(l1_sf)