fix: reinterpret float32 bits as uint32 before >> 23 for UE8M0
This commit is contained in:
@@ -173,10 +173,10 @@ def transform_nvfp4_weights_for_mega_moe(
|
||||
sf_f32 = sf.to(torch.float32)
|
||||
# Take max of adjacent pairs
|
||||
sf_merged = torch.maximum(sf_f32[..., 0::2], sf_f32[..., 1::2])
|
||||
# Convert to UE8M0: byte = floor(log2(v)) + 127, i.e. the float32 biased exponent byte (for normal, positive v)
|
||||
# UE8M0 encoding: uint8 = float32 bit pattern >> 23 (the biased exponent, given a zero sign bit)
|
||||
sf_u32 = sf_merged.view(torch.uint32)
|
||||
sf_ue8m0 = (sf_u32 >> 23).to(torch.uint8)
|
||||
# Convert to UE8M0: extract exponent byte from float32 bit pattern
|
||||
# UE8M0: uint8 = float32_bits[30:23] (8 exponent bits; assumes sign bit 31 is 0, i.e. non-negative scales)
|
||||
sf_bits = sf_merged.view(torch.uint32) # reinterpret float32 bits as uint32
|
||||
sf_ue8m0 = (sf_bits >> 23).to(torch.uint8)
|
||||
return sf_ue8m0
|
||||
|
||||
l1_sf_32 = merge_block16_to_block32(l1_sf)
|
||||
|
||||
Reference in New Issue
Block a user