fix: unpack_ue4m3_u32 was value-casting instead of bit-reinterpreting

Byte 0x3F was becoming float8(63.0) instead of the float8 whose bit
pattern IS 0x3F (~0.984). Pack uses .view() (correct), unpack used
.to() (wrong) — they were not inverses. This corrupted every activation
scale fed to the L1 GEMM while weight scales were fine.
This commit is contained in:
2026-05-14 12:59:20 +00:00
parent 8b7fa0c91e
commit 3bcc0ac057

View File

@@ -30,15 +30,18 @@ def unpack_ue4m3_u32(x_u32):
"""Unpack uint32 packed UE4M3 scales to float8_e4m3fn.
Each uint32 contains 4 UE4M3 values packed in bits [0:8], [8:16], [16:24], [24:32].
Must use bit reinterpret (view), NOT value cast (to) — byte 0x3F is the float8
whose bits are 0x3F (~0.984), NOT the integer 63.
"""
x_u32 = x_u32.contiguous()
M, N = x_u32.shape
# Extract 4 bytes as uint8, then bit-reinterpret to float8_e4m3fn
b0 = (x_u32 & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn)
b1 = ((x_u32 >> 8) & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn)
b2 = ((x_u32 >> 16) & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn)
b3 = ((x_u32 >> 24) & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn)
# Interleave into (M, N*4)
out = torch.empty(M, N * 4, dtype=torch.float8_e4m3fn, device=x_u32.device)
# Vectorized unpack: extract 4 bytes from each uint32
b0 = (x_u32 & 0xFF).to(torch.int32).to(torch.float8_e4m3fn)
b1 = ((x_u32 >> 8) & 0xFF).to(torch.int32).to(torch.float8_e4m3fn)
b2 = ((x_u32 >> 16) & 0xFF).to(torch.int32).to(torch.float8_e4m3fn)
b3 = ((x_u32 >> 24) & 0xFF).to(torch.int32).to(torch.float8_e4m3fn)
out[:, 0::4] = b0
out[:, 1::4] = b1
out[:, 2::4] = b2