From af50e98fe926e84270ea3d05d269c770da2f5dd9 Mon Sep 17 00:00:00 2001
From: biondizzle
Date: Fri, 15 May 2026 23:52:22 +0000
Subject: [PATCH] test: B layout test with N=128 K=256

---
Reviewer notes (this section is not part of the commit message):
 * Line structure was rebuilt from a whitespace-collapsed copy of this
   patch. Indentation and intra-line spacing of the removed lines
   (including spacing inside their string literals) are a best-effort
   reconstruction; if a strict apply fails, retry with
   `git apply --ignore-whitespace`.
 * Changes to the added code relative to the revision first posted:
   the docstring now describes the check the script really performs
   (agreement between columns that share a weight value; no reference
   sums are computed), the weight matrix is built with one vectorised
   expression instead of a per-column loop, column groups are taken
   with a strided slice, and the script exits with status 1 when a
   group disagrees instead of only printing a warning.
 * The commit id on the first line and the post-image blob id on the
   index line still refer to the revision first posted; regenerate
   them with `git format-patch` after applying.

 test_b_layout.py | 85 +++++++++++++++++++++++--------------------------
 1 file changed, 39 insertions(+), 46 deletions(-)

diff --git a/test_b_layout.py b/test_b_layout.py
index fb50896d..613eff7d 100644
--- a/test_b_layout.py
+++ b/test_b_layout.py
@@ -1,62 +1,55 @@
-"""Test: verify B matrix weight layout by using one-hot A.
-If A is one-hot (only element j is nonzero), output = row j of B.
-We can verify that column n of the output matches checkpoint row n.
+"""Test: check NVFP4 GEMM B-matrix weight layout at larger dimensions.
+Use M=1, N=128, K=256 — big enough for CUTLASS tiles.
+Column j of B is constant, with a value that depends only on j % 7, and A
+is all ones, so output columns that share j % 7 must agree with each other.
+Exits with status 1 if any group of same-weight columns disagrees.
 """
 import torch, sys
 sys.path.insert(0, 'src')
 from nvfp4_megamoe_kernel.cutlass_nvfp4_gemm.kernel import cutlass_nvfp4_blockscaled_gemm
-from nvfp4_megamoe_kernel.nvfp4_mega_moe import _quantize_to_e2m1, _E2M1_MAGNITUDES
+from nvfp4_megamoe_kernel.nvfp4_mega_moe import _quantize_to_e2m1
 
 torch.manual_seed(123)
 device = "cuda"
 
-# Small dimensions: M=1, N=4, K=16
-M, N, K = 1, 4, 16
+M, N, K = 1, 128, 256
 
-# Create a weight matrix with unique values so we can identify each element
-# Use BF16 values that map to distinct E2M1 values
-# E2M1 magnitudes: [0, 0.5, 1, 1.5, 2, 3, 4, 6]
-# Create w_bf16 where each column has a different pattern
-w_bf16 = torch.tensor([
-    [0.5, 1.0, 2.0, 3.0],  # K=0
-    [1.5, 0.5, 1.0, 2.0],  # K=1
-    [2.0, 3.0, 0.5, 1.0],  # K=2
-    [1.0, 2.0, 3.0, 0.5],  # K=3
-    [3.0, 1.0, 0.5, 2.0],  # K=4
-    [0.5, 3.0, 2.0, 1.0],  # K=5
-    [2.0, 0.5, 1.0, 3.0],  # K=6
-    [1.0, 2.0, 3.0, 0.5],  # K=7
-    [3.0, 1.0, 2.0, 0.5],  # K=8
-    [0.5, 2.0, 1.0, 3.0],  # K=9
-    [2.0, 0.5, 3.0, 1.0],  # K=10
-    [1.0, 3.0, 0.5, 2.0],  # K=11
-    [3.0, 2.0, 1.0, 0.5],  # K=12
-    [0.5, 1.0, 3.0, 2.0],  # K=13
-    [2.0, 3.0, 0.5, 1.0],  # K=14
-    [1.0, 0.5, 2.0, 3.0],  # K=15
-], dtype=torch.bfloat16, device=device)  # (K, N)
+# Column j holds (j % 7) * 0.5 + 0.5 in every row: 0.5, 1.0, ..., 3.5.
+# NOTE(review): 2.5 and 3.5 are not E2M1 magnitudes (0, 0.5, 1, 1.5, 2, 3,
+# 4, 6); presumably the scale factors in w_sf absorb that - confirm.
+col_vals = (torch.arange(N, device=device) % 7).float() * 0.5 + 0.5
+w_bf16 = col_vals.to(torch.bfloat16).expand(K, N).contiguous()
 
-# Quantize weights
-w_fp4, w_sf = _quantize_to_e2m1(w_bf16.T.float())  # (N, K//2), (N, K//16)
-w_fp4 = w_fp4.T.contiguous()  # (K//2, N) = (8, 4)
-w_sf = w_sf.T.contiguous()  # (K//16, N) = (1, 4)
+# Quantize
+w_fp4, w_sf = _quantize_to_e2m1(w_bf16.T.float())
+w_fp4 = w_fp4.T.contiguous()
+w_sf = w_sf.T.contiguous()
 
-# BF16 reference: just sum the rows
-ref = w_bf16.sum(dim=0)  # (N,) = sum over K
-print(f"BF16 reference sum: {ref.tolist()}")
-
-# Now run GEMM with all-ones A (sum all K elements)
+# All-ones A (sum all K elements)
 x_bf16 = torch.ones(M, K, dtype=torch.bfloat16, device=device)
 x_fp4, x_sf = _quantize_to_e2m1(x_bf16.float())
 
 out = cutlass_nvfp4_blockscaled_gemm(x_fp4, x_sf, w_fp4, w_sf, M, N, K, alpha=1.0)
-print(f"NVFP4 output: {out[0].tolist()}")
 
-# The outputs should be proportional to the column sums
-# Since we're using quantized values, they won't match exactly
-# But the RANK ORDER should match (column with highest sum should have highest output)
-ref_order = torch.argsort(ref, descending=True).tolist()
-out_order = torch.argsort(out[0], descending=True).tolist()
-print(f"Reference rank order: {ref_order}")
-print(f"NVFP4 rank order: {out_order}")
-print(f"Rank order match: {ref_order == out_order}")
+# Each column j has the same value repeated K times
+# So output[j] should be proportional to K * column_value
+# Column values cycle: 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 0.5, ...
+# So columns with same j%7 should have the same output
+print(f"Output first 14: {[f'{v:.2f}' for v in out[0, :14].tolist()]}")
+print(f"Expected pattern: columns 0,7 should match; 1,8 should match; etc")
+
+# Check: columns with same j%7 should be close
+failed = False
+for mod_val in range(7):
+    vals = out[0, mod_val::7]
+    if vals.numel() > 1:
+        spread = (vals.max() - vals.min()).item()
+        if spread > 0.5:
+            failed = True
+            print(f"WARNING: j%7={mod_val} spread={spread:.4f} — columns with same weight have different outputs!")
+        else:
+            print(f"j%7={mod_val}: mean={vals.mean():.4f} spread={spread:.4f} OK")
+
+# Report every group first, then signal failure through the exit status
+if failed:
+    sys.exit(1)