Files
nvfp4-megamoe-kernel/test_minimal_gemm.py

69 lines
2.4 KiB
Python

"""Ultra-minimal test: 1 element output, manual verification."""
import torch
import sys
sys.path.insert(0, 'src')
from nvfp4_megamoe_kernel.cutlass_nvfp4_gemm.kernel import cutlass_nvfp4_blockscaled_gemm
from nvfp4_megamoe_kernel.nvfp4_mega_moe import _quantize_to_e2m1, _E2M1_MAGNITUDES
# Minimal NVFP4 GEMM sanity script: build all-ones operands, quantize them,
# rebuild a dequantized reference by hand, run the CUTLASS block-scaled GEMM
# and print every intermediate so the numbers can be checked by eye.
torch.manual_seed(42)
device = "cuda"

# Simplest: M=1, N=32, K=32
M, N, K = 1, 32, 32

# All ones in BF16
x_bf16 = torch.ones(M, K, dtype=torch.bfloat16, device=device)
w_bf16 = torch.ones(K, N, dtype=torch.bfloat16, device=device)

# Reference: all-ones @ all-ones = K = 32.0 for every element.
# F.linear(x, W) computes x @ W.T, so passing w_bf16.T yields x_bf16 @ w_bf16.
ref_out = torch.nn.functional.linear(x_bf16, w_bf16.T)
print(f"BF16 ref (all-ones): {ref_out[0, :8].tolist()} (expected all 32.0)")

# Quantize.
# NOTE(review): the unpacking below assumes _quantize_to_e2m1 packs two 4-bit
# codes per byte along the LAST dim (low nibble first) and returns one scale
# per 16 elements along that dim -- confirm against the quantizer.
x_fp4, x_sf = _quantize_to_e2m1(x_bf16.float())
# The weight is quantized in (N, K) orientation so its K axis is the one that
# gets packed, then transposed back so K is the leading axis again.
w_fp4, w_sf = _quantize_to_e2m1(w_bf16.T.float())
w_fp4 = w_fp4.T
w_sf = w_sf.T

# Check what quantized values look like
x_u8 = x_fp4.view(torch.uint8)
print(f"x_fp4 first 8 bytes: {x_u8[0, :8].tolist()}")
print(f"x_sf first 4: {x_sf[0, :4].to(torch.float32).tolist()}")
w_u8 = w_fp4.view(torch.uint8)
print(f"w_fp4 first 8 bytes: {w_u8[:8, 0].tolist()}")
print(f"w_sf first 4: {w_sf[:4, 0].to(torch.float32).tolist()}")

# Magnitude lookup table, moved to the device once and shared by both operands.
mag_table = _E2M1_MAGNITUDES.to(device)

# Dequant reference for x: split each byte into (low, high) nibbles and
# interleave them along the last axis, which is the packed axis for x.
lo = (x_u8 & 0x0F).long()
hi = ((x_u8 >> 4) & 0x0F).long()
x_nib = torch.stack([lo, hi], dim=-1).reshape(M, -1)
# Bit 3 of each nibble is the sign (set -> -1, clear -> +1); the low three
# bits index the magnitude table.
x_signs = (x_nib >> 3).float() * -2 + 1
x_mags = mag_table[(x_nib & 0x07)]
x_deq = x_signs * x_mags
# Expand each scale factor over the 16 consecutive elements it covers.
sf_exp = x_sf.to(torch.float32).repeat_interleave(16, dim=-1)
x_recon = (x_deq * sf_exp).to(torch.bfloat16)
print(f"x_recon first 8: {x_recon[0, :8].tolist()}")

# Dequant reference for w. After the transpose above the packed axis is dim 0,
# so the two nibbles of every byte must become two consecutive ROWS.
# Stacking on dim=1 gives (K/2, 2, N), which reshapes to (K, N) with
# row = 2 * packed_row + nibble. Stacking on the last dim instead would place
# the nibbles in adjacent columns and scramble any non-uniform weight.
wlo = (w_u8 & 0x0F).long()
whi = ((w_u8 >> 4) & 0x0F).long()
w_nib = torch.stack([wlo, whi], dim=1).reshape(w_u8.shape[0] * 2, w_u8.shape[1])
w_signs = (w_nib >> 3).float() * -2 + 1
w_mags = mag_table[(w_nib & 0x07)]
w_deq = w_signs * w_mags
# Scale factors for w run along dim 0 (K), 16 rows per factor.
w_sf_exp = w_sf.to(torch.float32).repeat_interleave(16, dim=0)
w_recon = (w_deq * w_sf_exp).to(torch.bfloat16)
print(f"w_recon first 8: {w_recon[:8, 0].tolist()}")

# Reference product computed from the dequantized operands.
quant_ref = torch.nn.functional.linear(x_recon, w_recon.T)
print(f"Quant ref first 8: {quant_ref[0, :8].tolist()}")

# CUTLASS GEMM
nvfp4_out = cutlass_nvfp4_blockscaled_gemm(x_fp4, x_sf, w_fp4, w_sf, M, N, K, alpha=1.0)
print(f"NVFP4 out first 8: {nvfp4_out[0, :8].tolist()}")

# Compare the kernel output with the dequantized reference along the N axis.
cos = torch.nn.functional.cosine_similarity(nvfp4_out.float(), quant_ref.float(), dim=-1).item()
print(f"cosine={cos:.6f}")