Files
nvfp4-megamoe-kernel/test_forward_map.py

85 lines
4.0 KiB
Python

"""Test: verify that layout_sf(make_coord(m, k*16)) produces correct dst indices.
If the forward mapping is wrong, this will show it."""
import torch, sys
sys.path.insert(0, 'src')
from nvfp4_megamoe_kernel.cutlass_nvfp4_gemm.kernel import cutlass_nvfp4_blockscaled_gemm
from nvfp4_megamoe_kernel.nvfp4_mega_moe import _quantize_to_e2m1, _E2M1_MAGNITUDES
torch.manual_seed(42)
device = "cuda"
# Test 1: all-ones SF. With every scale factor equal to 1.0 the kernel output
# depends only on the E2M1 payloads, so it should match the plain dequantized
# reference (cosine ~1.0) even if the SF layout mapping were wrong.
M, N, K = 1, 32, 32
x_bf16 = torch.randn(M, K, dtype=torch.bfloat16, device=device) * 2.0
w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device) * 0.5
# NOTE(review): assumes _quantize_to_e2m1 packs two nibbles per byte along the
# last dim (low nibble = even index) and emits one SF per 16 elements, which is
# how the x reference below unpacks it. W is quantized as an (N, K) matrix and
# then transposed back, so w_fp4 / w_sf are presumably (K // 2, N) / (K // 16, N)
# views packed along dim 0.
x_fp4, x_sf = _quantize_to_e2m1(x_bf16.float())
w_fp4, w_sf = _quantize_to_e2m1(w_bf16.T.float())
w_fp4 = w_fp4.T
w_sf = w_sf.T
# Test with uniform SF
x_sf_ones = torch.ones_like(x_sf)
w_sf_ones = torch.ones_like(w_sf)
out_uni = cutlass_nvfp4_blockscaled_gemm(x_fp4, x_sf_ones, w_fp4, w_sf_ones, M, N, K, alpha=1.0)
# Dequant reference with uniform SF. Each E2M1 nibble: bit 3 is the sign
# (mapped to +1 / -1), bits 0-2 index the magnitude table.
x_u8 = x_fp4.view(torch.uint8)
lo = (x_u8 & 0x0F).long()
hi = ((x_u8 >> 4) & 0x0F).long()
# x is packed along its last (K) axis, so interleave lo/hi on the last dim.
x_nib = torch.stack([lo, hi], dim=-1).reshape(M, -1)
x_deq = ((x_nib >> 3).float() * -2 + 1) * _E2M1_MAGNITUDES.to(device)[(x_nib & 0x07)]
x_recon = (x_deq * 1.0).to(torch.bfloat16)
w_u8 = w_fp4.view(torch.uint8)
wlo = (w_u8 & 0x0F).long()
whi = ((w_u8 >> 4) & 0x0F).long()
# W is packed along dim 0 of its (K // 2, N) view, so lo/hi must be
# interleaved along dim=1: (K // 2, 2, N) -> (K, N). Stacking on the last dim
# (as for x) would interleave along N instead and scramble the reference.
w_nib = torch.stack([wlo, whi], dim=1).reshape(w_u8.shape[0] * 2, w_u8.shape[1])
w_deq = ((w_nib >> 3).float() * -2 + 1) * _E2M1_MAGNITUDES.to(device)[(w_nib & 0x07)]
w_recon = (w_deq * 1.0).to(torch.bfloat16)
# linear(x, W) computes x @ W.T, so passing w_recon.T yields x_recon @ w_recon.
ref_uni = torch.nn.functional.linear(x_recon, w_recon.T)
cos_uni = torch.nn.functional.cosine_similarity(out_uni.float(), ref_uni.float(), dim=-1).mean().item()
print(f"Uniform SF: cosine={cos_uni:.6f}")
# Test 2: prepacked SFB path. The weight scale factors are passed through
# prepack_sfb and the kernel is told (sfb_prepacked=True) to use them as given.
# NOTE(review): presumably prepack_sfb rearranges w_sf into the kernel's native
# SF layout — confirm against its implementation.
from nvfp4_megamoe_kernel.cutlass_nvfp4_gemm.kernel import prepack_sfb
w_sf_packed = prepack_sfb(w_sf, M, N, K)
out_prepacked = cutlass_nvfp4_blockscaled_gemm(
    x_fp4, x_sf, w_fp4, w_sf_packed, M, N, K,
    alpha=1.0, sfb_prepacked=True,
)
# Full dequant reference: broadcast each SF across its 16-element block along K
# (last dim for x, dim 0 for the (K, N) weight).
_x_scale_full = x_sf.to(torch.float32).repeat_interleave(16, dim=-1)
_w_scale_full = w_sf.to(torch.float32).repeat_interleave(16, dim=0)
x_recon_real = (x_deq * _x_scale_full).to(torch.bfloat16)
w_recon_real = (w_deq * _w_scale_full).to(torch.bfloat16)
ref_real = torch.nn.functional.linear(x_recon_real, w_recon_real.T)
_cos_pre_rows = torch.nn.functional.cosine_similarity(
    out_prepacked.float(), ref_real.float(), dim=-1
)
cos_pre = _cos_pre_rows.mean().item()
print(f"Prepacked SFB: cosine={cos_pre:.6f}")
# Test 3: no prepack — pass the raw w_sf and let the kernel remap the SFB
# layout on the fly; compare against the same full-dequant reference as Test 2.
out_live = cutlass_nvfp4_blockscaled_gemm(
    x_fp4, x_sf, w_fp4, w_sf, M, N, K, alpha=1.0,
)
_cos_live_rows = torch.nn.functional.cosine_similarity(
    out_live.float(), ref_real.float(), dim=-1
)
cos_live = _cos_live_rows.mean().item()
print(f"Live SFB: cosine={cos_live:.6f}")
# Test 4: N=128, K=256 (bigger dims), real scale factors, default
# (non-prepacked) SFB path.
M2, N2, K2 = 1, 128, 256
x2 = torch.randn(M2, K2, dtype=torch.bfloat16, device=device) * 2.0
w2 = torch.randn(K2, N2, dtype=torch.bfloat16, device=device) * 0.5
x2_fp4, x2_sf = _quantize_to_e2m1(x2.float())
# Quantize W as an (N, K) matrix, then transpose back; w2_fp4 / w2_sf are
# presumably (K // 2, N) / (K // 16, N) views packed along dim 0.
w2_fp4, w2_sf = _quantize_to_e2m1(w2.T.float())
w2_fp4 = w2_fp4.T
w2_sf = w2_sf.T
out2 = cutlass_nvfp4_blockscaled_gemm(x2_fp4, x2_sf, w2_fp4, w2_sf, M2, N2, K2, alpha=1.0)
# Dequant ref: unpack E2M1 nibbles (low nibble first; bit 3 = sign, bits 0-2 =
# magnitude index) and apply each SF to its 16-element block along K.
x2_u8 = x2_fp4.view(torch.uint8)
lo2 = (x2_u8 & 0x0F).long()
hi2 = ((x2_u8 >> 4) & 0x0F).long()
# x is packed along its last (K) axis, so interleave lo/hi on the last dim.
x2_nib = torch.stack([lo2, hi2], dim=-1).reshape(M2, -1)
x2_deq = ((x2_nib >> 3).float() * -2 + 1) * _E2M1_MAGNITUDES.to(device)[(x2_nib & 0x07)]
x2_recon = (x2_deq * x2_sf.to(torch.float32).repeat_interleave(16, dim=-1)).to(torch.bfloat16)
w2_u8 = w2_fp4.view(torch.uint8)
w2lo = (w2_u8 & 0x0F).long()
w2hi = ((w2_u8 >> 4) & 0x0F).long()
# W is packed along dim 0 of its (K // 2, N) view, so lo/hi must be
# interleaved along dim=1: (K // 2, 2, N) -> (K, N). Stacking on the last dim
# would interleave along N and silently scramble the reference weights.
w2_nib = torch.stack([w2lo, w2hi], dim=1).reshape(w2_u8.shape[0] * 2, w2_u8.shape[1])
w2_deq = ((w2_nib >> 3).float() * -2 + 1) * _E2M1_MAGNITUDES.to(device)[(w2_nib & 0x07)]
w2_recon = (w2_deq * w2_sf.to(torch.float32).repeat_interleave(16, dim=0)).to(torch.bfloat16)
# linear(x, W) computes x @ W.T, so passing w2_recon.T yields x2_recon @ w2_recon.
ref2 = torch.nn.functional.linear(x2_recon, w2_recon.T)
cos2 = torch.nn.functional.cosine_similarity(out2.float(), ref2.float(), dim=-1).mean().item()
print(f"M=1 N=128 K=256: cosine={cos2:.6f}")