diff --git a/test_sf_check.py b/test_sf_check.py new file mode 100644 index 00000000..0611ace3 --- /dev/null +++ b/test_sf_check.py @@ -0,0 +1,45 @@ +"""Check if size != cosize for small dimensions.""" +import torch, sys +sys.path.insert(0, 'src') + +# We need to construct the layouts to check +# Replicate what the CU code does +import cutlass_nvfp4_gemm._C as _C +# Actually we can't easily call CUTE from Python. +# Let's just test with progressively larger N until size != cosize matters. + +# Alternative approach: test the _C.forward directly and check SF remap +# by passing known SF values and seeing if they end up in the right places. + +# Simpler: test with M=128, N=128, K=256 where tile padding definitely kicks in +from nvfp4_megamoe_kernel.cutlass_nvfp4_gemm.kernel import cutlass_nvfp4_blockscaled_gemm +from nvfp4_megamoe_kernel.nvfp4_mega_moe import _quantize_to_e2m1, _E2M1_MAGNITUDES + +torch.manual_seed(42) +device = "cuda" + +# Test at dimensions where tiling definitely matters +for M, N, K in [(128, 128, 256), (128, 256, 512), (1, 6144, 7168)]: + x_bf16 = torch.randn(M, K, dtype=torch.bfloat16, device=device) * 2.0 + w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device) * 0.5 + + x_fp4, x_sf = _quantize_to_e2m1(x_bf16.float()) + w_fp4, w_sf = _quantize_to_e2m1(w_bf16.T.float()) + w_fp4 = w_fp4.T; w_sf = w_sf.T + + x_u8 = x_fp4.view(torch.uint8) + lo = (x_u8 & 0x0F).long(); hi = ((x_u8 >> 4) & 0x0F).long() + x_nib = torch.stack([lo, hi], dim=-1).reshape(M, -1) + x_deq = ((x_nib >> 3).float() * -2 + 1) * _E2M1_MAGNITUDES.to(device)[(x_nib & 0x07)] + x_recon = (x_deq * x_sf.to(torch.float32).repeat_interleave(16, dim=-1)).to(torch.bfloat16) + + w_u8 = w_fp4.view(torch.uint8) + wlo = (w_u8 & 0x0F).long(); whi = ((w_u8 >> 4) & 0x0F).long() + w_nib = torch.stack([wlo, whi], dim=-1).reshape(w_u8.shape[0]*2, w_u8.shape[1]) + w_deq = ((w_nib >> 3).float() * -2 + 1) * _E2M1_MAGNITUDES.to(device)[(w_nib & 0x07)] + w_recon = (w_deq * w_sf.to(torch.float32).repeat_interleave(16, dim=0)).to(torch.bfloat16) + + quant_ref = torch.nn.functional.linear(x_recon, w_recon.T) + nvfp4_out = cutlass_nvfp4_blockscaled_gemm(x_fp4, x_sf, w_fp4, w_sf, M, N, K, alpha=1.0) + cos = torch.nn.functional.cosine_similarity(nvfp4_out.float(), quant_ref.float(), dim=-1).mean().item() + print(f"M={M} N={N} K={K} cosine={cos:.6f}") diff --git a/test_sf_remap.py b/test_sf_remap.py new file mode 100644 index 00000000..bcf70b81 --- /dev/null +++ b/test_sf_remap.py @@ -0,0 +1,75 @@ +"""Verify the SF remap by comparing CUTLASS output with and without SF remap. + +Strategy: +1. Run GEMM with identity SF (all 1.0) — both A and B +2. Run GEMM with a single non-1.0 SF value — see if it affects the right output elements +3. This tells us if the remap is placing SF values correctly + +Actually, simpler: run GEMM with prepack_sfb=False (remap on the fly) and +prepack_sfb=True (pre-remapped), compare. If they differ, the remap is wrong. +""" +import torch, sys +sys.path.insert(0, 'src') +from nvfp4_megamoe_kernel.cutlass_nvfp4_gemm.kernel import ( + cutlass_nvfp4_blockscaled_gemm, prepack_sfb +) +from nvfp4_megamoe_kernel.nvfp4_mega_moe import _quantize_to_e2m1, _E2M1_MAGNITUDES + +torch.manual_seed(42) +device = "cuda" + +M, N, K = 1, 32, 32 +x_bf16 = torch.randn(M, K, dtype=torch.bfloat16, device=device) * 2.0 +w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device) * 0.5 + +x_fp4, x_sf = _quantize_to_e2m1(x_bf16.float()) +w_fp4, w_sf = _quantize_to_e2m1(w_bf16.T.float()) +w_fp4 = w_fp4.T; w_sf = w_sf.T + +# Test 1: with remap (sfb_prepacked=False) +out_remap = cutlass_nvfp4_blockscaled_gemm(x_fp4, x_sf, w_fp4, w_sf, M, N, K, alpha=1.0, sfb_prepacked=False) + +# Test 2: with prepacked SFB +w_sf_packed = prepack_sfb(w_sf, M, N, K) +out_prepacked = cutlass_nvfp4_blockscaled_gemm(x_fp4, x_sf, w_fp4, w_sf_packed, M, N, K, alpha=1.0, sfb_prepacked=True) + +print(f"Remap output first 8: {out_remap[0,:8].tolist()}") +print(f"Prepacked output first 8: {out_prepacked[0,:8].tolist()}") +print(f"Match: {torch.allclose(out_remap, out_prepacked, atol=0.01)}") +diff = (out_remap - out_prepacked).abs().max().item() +print(f"Max diff: {diff:.4e}") + +# Test 3: uniform SF — should match perfectly +x_sf_ones = torch.ones_like(x_sf) +w_sf_ones = torch.ones_like(w_sf) +out_uni_remap = cutlass_nvfp4_blockscaled_gemm(x_fp4, x_sf_ones, w_fp4, w_sf_ones, M, N, K, alpha=1.0, sfb_prepacked=False) +out_uni_pre = cutlass_nvfp4_blockscaled_gemm(x_fp4, x_sf_ones, w_fp4, prepack_sfb(w_sf_ones, M, N, K), M, N, K, alpha=1.0, sfb_prepacked=True) +print(f"\nUniform SF remap vs prepacked: {torch.allclose(out_uni_remap, out_uni_pre, atol=0.01)}") + +# Test 4: SFA remap — try with all-1.0 SFA and actual SFB, vs actual SFA and all-1.0 SFB +# This isolates which remap (SFA or SFB) is broken +out_real_sfa = cutlass_nvfp4_blockscaled_gemm(x_fp4, x_sf, w_fp4, w_sf_ones, M, N, K, alpha=1.0) +out_real_sfb = cutlass_nvfp4_blockscaled_gemm(x_fp4, x_sf_ones, w_fp4, w_sf, M, N, K, alpha=1.0) + +# Compute BF16 references +x_u8 = x_fp4.view(torch.uint8) +lo = (x_u8 & 0x0F).long(); hi = ((x_u8 >> 4) & 0x0F).long() +x_nib = torch.stack([lo, hi], dim=-1).reshape(M, -1) +x_deq = ((x_nib >> 3).float() * -2 + 1) * _E2M1_MAGNITUDES.to(device)[(x_nib & 0x07)] +x_recon = (x_deq * x_sf.to(torch.float32).repeat_interleave(16, dim=-1)).to(torch.bfloat16) +x_recon_ones = (x_deq * 1.0).to(torch.bfloat16) # uniform SF + +w_u8 = w_fp4.view(torch.uint8) +wlo = (w_u8 & 0x0F).long(); whi = ((w_u8 >> 4) & 0x0F).long() +w_nib = torch.stack([wlo, whi], dim=-1).reshape(w_u8.shape[0]*2, w_u8.shape[1]) +w_deq = ((w_nib >> 3).float() * -2 + 1) * _E2M1_MAGNITUDES.to(device)[(w_nib & 0x07)] +w_recon = (w_deq * w_sf.to(torch.float32).repeat_interleave(16, dim=0)).to(torch.bfloat16) +w_recon_ones = (w_deq * 1.0).to(torch.bfloat16) + +ref_real_sfa = torch.nn.functional.linear(x_recon, w_recon_ones.T) +ref_real_sfb = torch.nn.functional.linear(x_recon_ones, w_recon.T) + +cos_sfa = torch.nn.functional.cosine_similarity(out_real_sfa.float(), ref_real_sfa.float(), dim=-1).mean().item() +cos_sfb = torch.nn.functional.cosine_similarity(out_real_sfb.float(), ref_real_sfb.float(), dim=-1).mean().item() +print(f"\nSFA remap cosine (real SFA, uniform SFB): {cos_sfa:.6f}") +print(f"SFB remap cosine (uniform SFA, real SFB): {cos_sfb:.6f}")