nvfp4-megamoe-kernel/tests/test_uniform_fp4.py
biondizzle 303b6a8993 cleanup: move useful tests to tests/, nuke stale debug tests
Kept (moved to tests/):
- test_uniform_fp4.py — proves GEMM math (72.0 = 1.5² × K)
- test_b_layout.py — proves B matrix column layout
- test_quick_rand.py — quick GEMM sanity check

Removed (stale SF remap debug artifacts):
- test_forward_map.py, test_gemm_sweep.py, test_m1_gemm.py
- test_minimal_gemm.py, test_rand_gemm.py, test_sf_check.py
- test_sf_remap.py, test_sf_signed.py, test_sf_layout_diag.cu
2026-05-16 02:14:37 +00:00

"""Test: uniform FP4 values + uniform scale factors (SF), using a value other than 1.0.

If every E2M1 nibble holds the same code (e.g. code 3, which decodes to 1.5) and every SF is 1.0,
then every element of x is 1.5 and every element of w is 1.5.
Each GEMM output element is therefore (1.5^2) * K = 2.25 * 32 = 72.0.
"""
import sys
import torch
sys.path.insert(0, 'src')
from nvfp4_megamoe_kernel.cutlass_nvfp4_gemm.kernel import cutlass_nvfp4_blockscaled_gemm
device = "cuda"
M, N, K = 1, 32, 32
# Create packed FP4 where every nibble = 3 (E2M1 value 1.5)
# Packing: (nibbles[..., 1] << 4) | nibbles[..., 0]
# For both nibbles = 3: byte = (3 << 4) | 3 = 0x33
byte_val = (3 << 4) | 3 # 0x33
x_fp4 = torch.full((M, K // 2), byte_val, dtype=torch.int8, device=device)
w_fp4 = torch.full((K // 2, N), byte_val, dtype=torch.int8, device=device)
# Uniform SF = 1.0
x_sf = torch.ones(M, K // 16, dtype=torch.float8_e4m3fn, device=device)
w_sf = torch.ones(K // 16, N, dtype=torch.float8_e4m3fn, device=device)
out = cutlass_nvfp4_blockscaled_gemm(x_fp4, x_sf, w_fp4, w_sf, M, N, K, alpha=1.0)
# Reference: all x = 1.5, all w = 1.5, output = 1.5 * 1.5 * 32 = 72.0
print(f"NVFP4 output first 8: {out[0, :8].tolist()}")
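print("Expected: 72.0 for all elements")
print(f"Actual mean: {out.float().mean().item():.4f}")
print(f"All same? {torch.allclose(out, out[0,0].expand_as(out), atol=0.01)}")
# Fail loudly if the result is uniform but wrong: check the value, not just uniformity.
expected = torch.full_like(out, 72.0, dtype=torch.float32)
assert torch.allclose(out.float(), expected, atol=0.01), "GEMM output != 72.0"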
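
# --- Illustrative CPU reference (a sketch, not part of the kernel's API) ---
# The expected value 72.0 above is derived by hand. The snippet below re-derives it
# by decoding the packed bytes with the standard E2M1 value table and doing a plain
# float32 matmul. Assumptions: every scale factor is 1.0 (so it is omitted), the low
# nibble holds the first element of each pair (as in the packing comment above), and
# the names _E2M1_LUT, ref_unpack_e2m1, ref_x, ref_w, ref_out are hypothetical helpers
# introduced only for this example.
import torch

# E2M1 codes 0..15: sign bit, 2 exponent bits, 1 mantissa bit.
_E2M1_LUT = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
                          -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0])


def ref_unpack_e2m1(packed):
    """Decode (..., n) packed bytes into (..., 2n) float32 values, low nibble first."""
    b = packed.cpu().to(torch.int64) & 0xFF          # treat int8 storage as unsigned
    nibbles = torch.stack((b & 0xF, b >> 4), dim=-1)  # (..., n, 2)
    return _E2M1_LUT[nibbles].flatten(-2)             # (..., 2n)


# Same shapes as the test: M=1, N=32, K=32, every byte 0x33 (two nibbles of code 3).
ref_x = ref_unpack_e2m1(torch.full((1, 16), 0x33, dtype=torch.int8))     # (M, K)
# w is packed along K, so decode the (N, K // 2) transpose and transpose back.
ref_w = ref_unpack_e2m1(torch.full((32, 16), 0x33, dtype=torch.int8)).T  # (K, N)
ref_out = ref_x @ ref_w                                                  # (M, N)
print(f"CPU reference first 8: {ref_out[0, :8].tolist()}")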