fix: cast to int32 before >> 23 (uint32 doesn't support right-shift)

This commit is contained in:
2026-05-11 09:45:54 +00:00
parent 6d7231a50e
commit 57c629ed1b

View File

@@ -175,8 +175,9 @@ def transform_nvfp4_weights_for_mega_moe(
sf_merged = torch.maximum(sf_f32[..., 0::2], sf_f32[..., 1::2])
# Convert to UE8M0: extract exponent byte from float32 bit pattern
# UE8M0: uint8 = float32_bits[31:23] (8 exponent bits)
sf_bits = sf_merged.view(torch.uint32) # reinterpret float32 bits as uint32
sf_ue8m0 = (sf_bits >> 23).to(torch.uint8)
# Note: PyTorch doesn't support >> on uint32, cast to int32 first
sf_bits = sf_merged.view(torch.int32) # reinterpret float32 bits as int32
sf_ue8m0 = ((sf_bits >> 23) & 0xFF).to(torch.uint8)
return sf_ue8m0
l1_sf_32 = merge_block16_to_block32(l1_sf)