Kept (moved to tests/):
- test_uniform_fp4.py — proves the GEMM math (72.0 = 1.5² × K)
- test_b_layout.py — proves the B-matrix column layout
- test_quick_rand.py — quick random-input GEMM sanity check

Removed (stale scale-factor (SF) remap debug artifacts):
- test_forward_map.py, test_gemm_sweep.py, test_m1_gemm.py
- test_minimal_gemm.py, test_rand_gemm.py, test_sf_check.py
- test_sf_remap.py, test_sf_signed.py, test_sf_layout_diag.cu
37 lines
1.7 KiB
Python
37 lines
1.7 KiB
Python
"""Quick random test at N=32 K=32 — if cosize fix works, cosine should be ~1.0

Quantizes random bf16 activations (M x K) and weights (K x N) to E2M1 (FP4)
with one scale factor per 16 consecutive K elements. It then runs the CUTLASS
block-scaled NVFP4 GEMM and compares the result against a bf16 reference
built from the *dequantized* operands. The reference therefore contains the
same quantization error, and only kernel/layout mistakes lower the cosine.

Requires a CUDA device. Run as a script: ``python tests/test_quick_rand.py``.
"""

import sys
from pathlib import Path

import torch

# Make the in-repo package importable. The __file__-relative entry works from
# any working directory, assuming this file sits one level below the repo root
# (e.g. tests/) — adjust if it moves. The cwd-relative "src" is inserted last
# so it still takes precedence when run from the repo root, as before.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "src"))
sys.path.insert(0, "src")

from nvfp4_megamoe_kernel.cutlass_nvfp4_gemm.kernel import cutlass_nvfp4_blockscaled_gemm  # noqa: E402
from nvfp4_megamoe_kernel.nvfp4_mega_moe import _quantize_to_e2m1, _E2M1_MAGNITUDES  # noqa: E402

# Number of consecutive elements along the packed (last) dim that share one
# scale factor.
SF_BLOCK = 16


def _dequantize_e2m1(packed, scales, magnitudes):
    """Dequantize packed E2M1 data whose nibble pairs run along the last dim.

    Args:
        packed: 2-D tensor of packed FP4 bytes, shape (rows, cols // 2). The low
            nibble is taken as element 2i and the high nibble as element 2i+1.
        scales: per-block scale factors, shape (rows, cols // SF_BLOCK);
            converted to float32 before use.
        magnitudes: lookup table indexed by the 3 magnitude bits of a nibble;
            must live on the same device as ``packed``.

    Returns:
        float32 tensor of shape (rows, cols).
    """
    u8 = packed.view(torch.uint8)
    lo = (u8 & 0x0F).long()
    hi = ((u8 >> 4) & 0x0F).long()
    # Interleave along the last dim so element 2i <- lo nibble, 2i+1 <- hi nibble.
    nib = torch.stack([lo, hi], dim=-1).reshape(u8.shape[0], -1)
    # Bit 3 is the sign bit: map {0, 1} -> {+1.0, -1.0}.
    sign = 1.0 - 2.0 * (nib >> 3).float()
    values = sign * magnitudes[nib & 0x07]
    # Broadcast each block's scale across its SF_BLOCK elements.
    return values * scales.to(torch.float32).repeat_interleave(SF_BLOCK, dim=-1)


def main():
    """Run the kernel-vs-reference comparison once and print the results.

    Returns:
        Mean (over the M rows) cosine similarity between the NVFP4 kernel
        output and the dequantized bf16 reference.
    """
    torch.manual_seed(42)
    device = "cuda"
    M, N, K = 1, 32, 32

    x_bf16 = torch.randn(M, K, dtype=torch.bfloat16, device=device) * 2.0
    w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device) * 0.5

    # Weights are quantized in (N, K) layout so that, like x, nibble pairs and
    # scale blocks run along K (the last dim).
    x_fp4, x_sf = _quantize_to_e2m1(x_bf16.float())
    w_fp4_nk, w_sf_nk = _quantize_to_e2m1(w_bf16.T.float())

    magnitudes = _E2M1_MAGNITUDES.to(device)  # hoisted: shared by x and w
    x_recon = _dequantize_e2m1(x_fp4, x_sf, magnitudes).to(torch.bfloat16)
    # Dequantize the weights in their packed (N, K) layout, *before*
    # transposing. After the transpose the nibble pairs run along dim 0, so
    # interleaving them along the last dim (N) would scramble the reference.
    w_recon_nk = _dequantize_e2m1(w_fp4_nk, w_sf_nk, magnitudes).to(torch.bfloat16)

    # linear(x, W) computes x @ W.T with W of shape (N, K) -> result is (M, N).
    quant_ref = torch.nn.functional.linear(x_recon, w_recon_nk)

    # The kernel receives the transposed weight/scale views, exactly as the
    # original script passed them (w_fp4.T / w_sf.T).
    nvfp4_out = cutlass_nvfp4_blockscaled_gemm(
        x_fp4, x_sf, w_fp4_nk.T, w_sf_nk.T, M, N, K, alpha=1.0
    )

    cos = torch.nn.functional.cosine_similarity(
        nvfp4_out.float(), quant_ref.float(), dim=-1
    ).mean().item()

    print(f"M={M} N={N} K={K} cosine={cos:.6f}")
    print(f"NVFP4 first 8: {nvfp4_out[0,:8].tolist()}")
    print(f"REF first 8: {quant_ref[0,:8].tolist()}")
    return cos


if __name__ == "__main__":
    main()