Kept (moved to tests/): - test_uniform_fp4.py — proves GEMM math (72.0 = 1.5² × K) - test_b_layout.py — proves B matrix column layout - test_quick_rand.py — quick GEMM sanity check Removed (stale SF remap debug artifacts): - test_forward_map.py, test_gemm_sweep.py, test_m1_gemm.py - test_minimal_gemm.py, test_rand_gemm.py, test_sf_check.py - test_sf_remap.py, test_sf_signed.py, test_sf_layout_diag.cu
31 lines
1.3 KiB
Python
31 lines
1.3 KiB
Python
"""Test: uniform FP4 + uniform SF, different from all-ones.

If all E2M1 values are the same (e.g. value 3 = 1.5) and all SF=1.0,
then x = 1.5 for all elements and w = 1.5 for all elements, so
GEMM output = (1.5^2) * K = 2.25 * 32 = 72.0 for every element.
"""
import sys

import torch

# 'src' is a relative path, so it is resolved against the current working
# directory: run this script from the directory that contains src/.
sys.path.insert(0, 'src')

from nvfp4_megamoe_kernel.cutlass_nvfp4_gemm.kernel import cutlass_nvfp4_blockscaled_gemm  # noqa: E402

device = "cuda"
M, N, K = 1, 32, 32

# Tolerance shared by the uniformity report and the expected-value check.
ATOL = 0.01

# Every 4-bit code is 3, which is the E2M1 encoding of 1.5.
# Packing: (nibbles[..., 1] << 4) | nibbles[..., 0]
# With both nibbles = 3: byte = (3 << 4) | 3 = 0x33
E2M1_CODE = 3
E2M1_VALUE = 1.5
byte_val = (E2M1_CODE << 4) | E2M1_CODE  # 0x33

# Two FP4 values per byte, hence K // 2 along the reduction dimension.
x_fp4 = torch.full((M, K // 2), byte_val, dtype=torch.int8, device=device)
w_fp4 = torch.full((K // 2, N), byte_val, dtype=torch.int8, device=device)

# Uniform SF = 1.0; the shapes use K // 16, i.e. one scale per 16 K-elements.
x_sf = torch.ones(M, K // 16, dtype=torch.float8_e4m3fn, device=device)
w_sf = torch.ones(K // 16, N, dtype=torch.float8_e4m3fn, device=device)

out = cutlass_nvfp4_blockscaled_gemm(x_fp4, x_sf, w_fp4, w_sf, M, N, K, alpha=1.0)

# Reference: all x = 1.5, all w = 1.5, output = 1.5 * 1.5 * K = 72.0 for K = 32.
expected = E2M1_VALUE * E2M1_VALUE * K

print(f"NVFP4 output first 8: {out[0, :8].tolist()}")
print(f"Expected: {expected} for all elements")
print(f"Actual mean: {out.float().mean().item():.4f}")
print(f"All same? {torch.allclose(out, out[0,0].expand_as(out), atol=ATOL)}")

# The uniformity line above only compares the output with itself, so a result
# that is uniformly wrong would still look fine. Compare against the analytic
# value as well; torch.allclose also returns False for NaN outputs.
reference = torch.full(out.shape, expected, dtype=torch.float32, device=out.device)
if not torch.allclose(out.float(), reference, atol=ATOL):
    raise AssertionError(
        f"NVFP4 GEMM mismatch: expected {expected} for all elements, "
        f"got mean {out.float().mean().item():.4f}"
    )