fix: unpack_ue4m3_u32 — uint32 lacks CUDA bitwise ops, use int32
PyTorch does not implement bitwise AND or shift operations for uint32 tensors on CUDA. Cast to int32 first, extract each byte with shift-and-mask, convert it to uint8, then bit-reinterpret it as float8_e4m3fn with view().
This commit is contained in:
@@ -32,14 +32,19 @@ def unpack_ue4m3_u32(x_u32):
|
||||
Each uint32 contains 4 UE4M3 values packed in bits [0:8], [8:16], [16:24], [24:32].
|
||||
Must use bit reinterpret (view), NOT value cast (to) — byte 0x3F is the float8
|
||||
whose bits are 0x3F (~0.984), NOT the integer 63.
|
||||
|
||||
CUDA doesn't implement bitwise ops on uint32, so we cast to int32 first.
|
||||
"""
|
||||
x_u32 = x_u32.contiguous()
|
||||
M, N = x_u32.shape
|
||||
# Extract 4 bytes as uint8, then bit-reinterpret to float8_e4m3fn
|
||||
b0 = (x_u32 & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn)
|
||||
b1 = ((x_u32 >> 8) & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn)
|
||||
b2 = ((x_u32 >> 16) & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn)
|
||||
b3 = ((x_u32 >> 24) & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn)
|
||||
# CUDA uint32 lacks bitwise ops — use int32
|
||||
x_i32 = x_u32.to(torch.int32)
|
||||
M, N = x_i32.shape
|
||||
|
||||
# Extract 4 bytes, cast to uint8, then bit-reinterpret to float8_e4m3fn
|
||||
b0 = (x_i32 & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn)
|
||||
b1 = ((x_i32 >> 8) & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn)
|
||||
b2 = ((x_i32 >> 16) & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn)
|
||||
b3 = ((x_i32 >> 24) & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn)
|
||||
|
||||
# Interleave into (M, N*4)
|
||||
out = torch.empty(M, N * 4, dtype=torch.float8_e4m3fn, device=x_u32.device)
|
||||
out[:, 0::4] = b0
|
||||
|
||||
Reference in New Issue
Block a user