Files
nvfp4-megamoe-kernel/test_gemm_sweep.py

78 lines
2.9 KiB
Python

"""Minimal test: CUTLASS NVFP4 GEMM with simple dimensions to isolate the bug.
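Test 1: Small dimensions (M=128, N=256, K=512) — should match the original working test —
        plus a medium case (M=128, N=512, K=1024)
Test 2: Single-row dimensions (M=1) from N=128, K=256 up to N=1024, K=2048
Test 3: Real MoE dimensions (M=1, N=6144, K=7168 and M=1, N=7168, K=3072)
Test 4: Real MoE dimensions with M=128 (N=6144, K=7168)
Test 5: Repeat of the M=1, N=6144, K=7168 case for reference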
"""
import torch
import sys
sys.path.insert(0, 'src')
from nvfp4_megamoe_kernel.cutlass_nvfp4_gemm.kernel import cutlass_nvfp4_blockscaled_gemm
from nvfp4_megamoe_kernel.nvfp4_mega_moe import _quantize_to_e2m1, _E2M1_MAGNITUDES
# Seed PyTorch's RNG so the torch.randn inputs built in test_gemm are the same on every run.
torch.manual_seed(42)
# Device string passed to every tensor that test_gemm allocates; the sweep needs a CUDA GPU.
device = "cuda"
def _dequantize_e2m1(packed, scales, pack_dim):
    """Rebuild bfloat16 values from packed E2M1 codes and their block scales.

    Args:
        packed: Tensor holding two 4-bit codes per byte along ``pack_dim``
            (low nibble first, then high nibble). Reinterpreted as uint8.
        scales: One scale per 16 unpacked elements along ``pack_dim``.
        pack_dim: Non-negative index of the packed axis.

    Returns:
        A bfloat16 tensor whose ``pack_dim`` is twice as long as in ``packed``.
    """
    raw = packed.view(torch.uint8)
    lo = (raw & 0x0F).long()
    hi = ((raw >> 4) & 0x0F).long()
    # The nibble axis must sit directly after the packed axis so that merging
    # the two interleaves lo/hi along the packed axis itself. Stacking on the
    # last axis is only correct when the packed axis *is* the last axis.
    nibbles = torch.stack([lo, hi], dim=pack_dim + 1).flatten(pack_dim, pack_dim + 1)
    # Bit 3 is the sign (0 -> +1, 1 -> -1); bits 0-2 index the magnitude table.
    signs = (nibbles >> 3).float() * -2 + 1
    mags = _E2M1_MAGNITUDES.to(device)[nibbles & 0x07]
    scale_exp = scales.to(torch.float32).repeat_interleave(16, dim=pack_dim)
    return (signs * mags * scale_exp).to(torch.bfloat16)


def test_gemm(M: int, N: int, K: int, label: str) -> None:
    """Run one NVFP4 GEMM case and print how closely it tracks a dequantized reference.

    Random bfloat16 operands are quantized with ``_quantize_to_e2m1``, multiplied
    by ``cutlass_nvfp4_blockscaled_gemm``, and compared with a matmul of the
    dequantized operands. Cosine similarity, MSE and both outputs' max magnitude
    are printed on one line.

    Args:
        M: Rows of the activation matrix.
        N: Columns of the weight matrix.
        K: Shared reduction dimension.
        label: Tag printed at the start of the result line.
    """
    x_bf16 = torch.randn(M, K, dtype=torch.bfloat16, device=device) * 2.0
    w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device) * 0.5

    # Quantize. Presumably _quantize_to_e2m1 packs pairs along the last axis of
    # its input and emits one scale per 16 elements -- TODO confirm against it.
    x_fp4, x_sf = _quantize_to_e2m1(x_bf16.float())
    # Weight: quantize transposed, then transpose back to get a (K_half, N) layout.
    w_fp4, w_sf = _quantize_to_e2m1(w_bf16.T.float())
    w_fp4 = w_fp4.T  # (K_half, N)
    w_sf = w_sf.T  # (K//16, N)

    # Dequantized reference. After the transpose the weight is packed along
    # axis 0 (K), so it must be unpacked along axis 0, not the last axis.
    x_recon = _dequantize_e2m1(x_fp4, x_sf, pack_dim=1)
    w_recon = _dequantize_e2m1(w_fp4, w_sf, pack_dim=0)
    quant_ref = torch.nn.functional.linear(x_recon, w_recon.T)

    nvfp4_out = cutlass_nvfp4_blockscaled_gemm(x_fp4, x_sf, w_fp4, w_sf, M, N, K, alpha=1.0)

    cos = torch.nn.functional.cosine_similarity(nvfp4_out.float(), quant_ref.float(), dim=-1).mean().item()
    mse = (nvfp4_out.float() - quant_ref.float()).pow(2).mean().item()
    print(f"{label}: M={M} N={N} K={K} cosine={cos:.6f} mse={mse:.4e} nvfp4_amax={nvfp4_out.abs().max():.2e} ref_amax={quant_ref.abs().max():.2e}")
# Sweep table: each entry is (M, N, K, label) and is run in the order listed.
_SWEEP_CASES = (
    # Test 1: Small (like original working test)
    (128, 256, 512, "SMALL"),
    (128, 512, 1024, "MEDIUM"),
    # Test 2: N and K divisible by 128 (tile alignment)
    (1, 128, 256, "TINY"),
    (1, 256, 512, "SMALL-M1"),
    (1, 1024, 2048, "MED-M1"),
    # Test 3: Real MoE dimensions
    (1, 6144, 7168, "REAL-L1"),
    (1, 7168, 3072, "REAL-L2"),
    # Test 4: N=6144 K=7168 with M=128 (to see if M matters at these dims)
    (128, 6144, 7168, "REAL-L1-M128"),
    # Test 5: Aligned versions
    (1, 6144, 7168, "REAL-L1"),  # same, for reference
    (1, 6144, 7168, "REAL-L1-no-alpha"),  # alpha=1.0 already
)

for _case in _SWEEP_CASES:
    test_gemm(*_case)