diff --git a/test_b_layout.py b/test_b_layout.py new file mode 100644 index 00000000..fb50896d --- /dev/null +++ b/test_b_layout.py @@ -0,0 +1,62 @@ +"""Test: verify B matrix weight layout by using one-hot A. +If A is one-hot (only element j is nonzero), output = row j of B. +We can verify that column n of the output matches checkpoint row n. +""" +import torch, sys +sys.path.insert(0, 'src') +from nvfp4_megamoe_kernel.cutlass_nvfp4_gemm.kernel import cutlass_nvfp4_blockscaled_gemm +from nvfp4_megamoe_kernel.nvfp4_mega_moe import _quantize_to_e2m1, _E2M1_MAGNITUDES + +torch.manual_seed(123) +device = "cuda" + +# Small dimensions: M=1, N=4, K=16 +M, N, K = 1, 4, 16 + +# Create a weight matrix with unique values so we can identify each element +# Use BF16 values that map to distinct E2M1 values +# E2M1 magnitudes: [0, 0.5, 1, 1.5, 2, 3, 4, 6] +# Create w_bf16 where each column has a different pattern +w_bf16 = torch.tensor([ + [0.5, 1.0, 2.0, 3.0], # K=0 + [1.5, 0.5, 1.0, 2.0], # K=1 + [2.0, 3.0, 0.5, 1.0], # K=2 + [1.0, 2.0, 3.0, 0.5], # K=3 + [3.0, 1.0, 0.5, 2.0], # K=4 + [0.5, 3.0, 2.0, 1.0], # K=5 + [2.0, 0.5, 1.0, 3.0], # K=6 + [1.0, 2.0, 3.0, 0.5], # K=7 + [3.0, 1.0, 2.0, 0.5], # K=8 + [0.5, 2.0, 1.0, 3.0], # K=9 + [2.0, 0.5, 3.0, 1.0], # K=10 + [1.0, 3.0, 0.5, 2.0], # K=11 + [3.0, 2.0, 1.0, 0.5], # K=12 + [0.5, 1.0, 3.0, 2.0], # K=13 + [2.0, 3.0, 0.5, 1.0], # K=14 + [1.0, 0.5, 2.0, 3.0], # K=15 +], dtype=torch.bfloat16, device=device) # (K, N) + +# Quantize weights +w_fp4, w_sf = _quantize_to_e2m1(w_bf16.T.float()) # (N, K//2), (N, K//16) +w_fp4 = w_fp4.T.contiguous() # (K//2, N) = (8, 4) +w_sf = w_sf.T.contiguous() # (K//16, N) = (1, 4) + +# BF16 reference: just sum the rows +ref = w_bf16.sum(dim=0) # (N,) = sum over K +print(f"BF16 reference sum: {ref.tolist()}") + +# Now run GEMM with all-ones A (sum all K elements) +x_bf16 = torch.ones(M, K, dtype=torch.bfloat16, device=device) +x_fp4, x_sf = _quantize_to_e2m1(x_bf16.float()) + +out = cutlass_nvfp4_blockscaled_gemm(x_fp4, x_sf, w_fp4, w_sf, M, N, K, alpha=1.0) +print(f"NVFP4 output: {out[0].tolist()}") + +# The outputs should be proportional to the column sums +# Since we're using quantized values, they won't match exactly +# But the RANK ORDER should match (column with highest sum should have highest output) +ref_order = torch.argsort(ref, descending=True).tolist() +out_order = torch.argsort(out[0], descending=True).tolist() +print(f"Reference rank order: {ref_order}") +print(f"NVFP4 rank order: {out_order}") +print(f"Rank order match: {ref_order == out_order}")