Files
nvfp4-megamoe-kernel/test_sf_signed.py

46 lines
2.1 KiB
Python

"""Check if float8_e4m3fn (signed) vs float_ue4m3 (unsigned) matters.
In the CUTLASS kernel, SF is float_ue4m3 (unsigned E4M3).
In our Python reference, we use .to(torch.float32) which interprets float8_e4m3fn (signed).
If the sign bit is set, signed and unsigned give different values.
"""
import torch
device = "cuda"
# Create some float8 values and compare signed vs unsigned interpretation
vals = torch.tensor([0x00, 0x3F, 0x7F, 0x80, 0xBF, 0xFF], dtype=torch.uint8, device=device)
# Signed interpretation (float8_e4m3fn)
signed = vals.view(torch.float8_e4m3fn).to(torch.float32)
print("Signed (float8_e4m3fn):", signed.tolist())
# Unsigned interpretation (float8_e4m3fnuz — unsigned zero)
# Actually, let's check if there IS an unsigned float8 type in PyTorch
print("Has float8_e4m3fnuz:", hasattr(torch, 'float8_e4m3fnuz'))
# The key question: are SF values always positive?
# UE4M3 means the sign bit is NOT used — all values are positive.
# But if we read a UE4M3 byte as signed E4M3, bytes with bit 7 set
# would be interpreted as negative.
# Let's check: for valid UE4M3 values, is bit 7 ever set?
# Signed E4M3 (float8_e4m3fn) range: -448 to 448. The encoding uses bit 7 for the actual sign.
# UE4M3: the sign bit is always 0 (positive only, range 0 to 448).
# So reading UE4M3 as signed E4M3 should give the same result
# as long as the sign bit is 0.
# Check our actual SF data
from nvfp4_megamoe_kernel.nvfp4_mega_moe import _quantize_to_e2m1
# Quantize one reproducible 1x32 random row (fixed seed, values scaled by 2)
# and inspect the scale factors the quantizer hands back.
torch.manual_seed(42)
x = torch.randn(1, 32, device=device).mul(2.0)
x_fp4, x_sf = _quantize_to_e2m1(x.float())

# View the scale factors as raw bytes so bit 7 can be tested directly.
# NOTE(review): presumably x_sf is float8_e4m3fn (one byte per element), as
# the file's header text describes -- confirm against _quantize_to_e2m1.
sf_bytes = x_sf.view(torch.uint8)
flat_bytes = sf_bytes.flatten()

# A byte >= 128 has bit 7 set, which float8_e4m3fn reads as a negative sign.
high_bit_set = torch.any(sf_bytes >= 128).item()

# The same scale factors decoded through PyTorch's signed float8 path.
sf_as_float = x_sf.to(torch.float32).flatten()

print(f"\nSF bytes: {flat_bytes[:16].tolist()}")
print(f"Any byte with bit 7 set (>= 128): {high_bit_set}")
print(f"SF as signed float: {sf_as_float[:8].tolist()}")
# Check: does CUTLASS treat SF as signed or unsigned?
# The C++ type is cutlass::float_ue4m3_t
# In the CU file we use: const cutlass::float_ue4m3_t* src
# But PyTorch passes float8_e4m3fn (signed)
# A byte with bit 7 clear decodes to the same value under both types,
# but a byte with bit 7 set (the sign bit in float8_e4m3fn) can be interpreted DIFFERENTLY