46 lines
2.1 KiB
Python
46 lines
2.1 KiB
Python
"""Check if float8_e4m3fn (signed) vs float_ue4m3 (unsigned) matters.
|
|
In the CUTLASS kernel, SF is float_ue4m3 (unsigned E4M3).
|
|
In our Python reference, we use .to(torch.float32) which interprets float8_e4m3fn (signed).
|
|
If the sign bit is set, signed and unsigned give different values.
|
|
"""
|
|
import torch
|
|
device = "cuda"
|
|
|
|
# Create some float8 values and compare signed vs unsigned interpretation
|
|
vals = torch.tensor([0x00, 0x3F, 0x7F, 0x80, 0xBF, 0xFF], dtype=torch.uint8, device=device)
|
|
|
|
# Signed interpretation (float8_e4m3fn)
|
|
signed = vals.view(torch.float8_e4m3fn).to(torch.float32)
|
|
print("Signed (float8_e4m3fn):", signed.tolist())
|
|
|
|
# Unsigned interpretation (float8_e4m3fnuz — unsigned zero)
|
|
# Actually, let's check if there IS an unsigned float8 type in PyTorch
|
|
print("Has float8_e4m3fnuz:", hasattr(torch, 'float8_e4m3fnuz'))
|
|
|
|
# The key question: are SF values always positive?
|
|
# UE4M3 means the sign bit is NOT used — all values are positive.
|
|
# But if we read a UE4M3 byte as signed E4M3, bytes with bit 7 set
|
|
# would be interpreted as negative.
|
|
# Let's check: for valid UE4M3 values, is bit 7 ever set?
|
|
# E4M3 range: 0 to 448. The encoding uses the sign bit for actual sign.
|
|
# UE4M3: the sign bit is always 0 (positive only, range 0 to 448).
|
|
# So reading UE4M3 as signed E4M3 should give the same result
|
|
# as long as the sign bit is 0.
|
|
|
|
# Check our actual SF data
|
|
from nvfp4_megamoe_kernel.nvfp4_mega_moe import _quantize_to_e2m1
|
|
torch.manual_seed(42)
|
|
x = torch.randn(1, 32, device=device) * 2.0
|
|
x_fp4, x_sf = _quantize_to_e2m1(x.float())
|
|
sf_bytes = x_sf.view(torch.uint8)
|
|
print(f"\nSF bytes: {sf_bytes.flatten()[:16].tolist()}")
|
|
print(f"Any byte with bit 7 set (>= 128): {(sf_bytes >= 128).any().item()}")
|
|
print(f"SF as signed float: {x_sf.to(torch.float32).flatten()[:8].tolist()}")
|
|
|
|
# Check: does CUTLASS treat SF as signed or unsigned?
|
|
# The C++ type is cutlass::float_ue4m3_t
|
|
# In the CU file we use: const cutlass::float_ue4m3_t* src
|
|
# But PyTorch passes float8_e4m3fn (signed)
|
|
# These have the same bit pattern for positive values
|
|
# but DIFFERENT bit patterns for values where the sign bit is set
|