nvfp4-megamoe-kernel/tests/test_b_layout.py
biondizzle 303b6a8993 cleanup: move useful tests to tests/, nuke stale debug tests
Kept (moved to tests/):
- test_uniform_fp4.py — proves GEMM math (72.0 = 1.5² × K)
- test_b_layout.py — proves B matrix column layout
- test_quick_rand.py — quick GEMM sanity check
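
The identity quoted for test_uniform_fp4.py can be checked by hand. The sketch below assumes both operands are filled uniformly with 1.5 and that K = 32, which is the value the equation 72.0 = 1.5² × K implies; the helper name `uniform_dot` is hypothetical and not taken from the repository.

```python
def uniform_dot(value_a, value_b, k):
    """Dot product of two length-k vectors filled with constant values."""
    return sum(value_a * value_b for _ in range(k))

K = 32  # assumption: inferred from 72.0 / 1.5**2, not read from the test file
result = uniform_dot(1.5, 1.5, K)
print(result)  # 72.0
```

Both 1.5 and 2.25 are exact in binary floating point, so a uniform 1.5 input is a convenient probe: any deviation from 72.0 would point at the kernel or the scale factors rather than at rounding.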

Removed (stale SF remap debug artifacts):
- test_forward_map.py, test_gemm_sweep.py, test_m1_gemm.py
- test_minimal_gemm.py, test_rand_gemm.py, test_sf_check.py
- test_sf_remap.py, test_sf_signed.py, test_sf_layout_diag.cu
2026-05-16 02:14:37 +00:00


"""Test: verify B matrix weight layout with larger dimensions.
Use M=1, N=128, K=256 — big enough for CUTLASS tiles.
Fill B columns with distinct patterns and check that columns sharing the
same weight value produce matching GEMM outputs.
"""
import sys

import torch

sys.path.insert(0, 'src')
from nvfp4_megamoe_kernel.cutlass_nvfp4_gemm.kernel import cutlass_nvfp4_blockscaled_gemm
from nvfp4_megamoe_kernel.nvfp4_mega_moe import _quantize_to_e2m1
torch.manual_seed(123)
device = "cuda"
M, N, K = 1, 128, 256
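# Create weight with a column-dependent pattern.
# Column j holds (j % 7) * 0.5 + 0.5, giving seven distinct values (0.5 to 3.5).
# Note: 2.5 and 3.5 are not raw E2M1 magnitudes, so how they are represented
# depends on the scale factors returned by the quantizer.
w_bf16 = torch.zeros(K, N, dtype=torch.bfloat16, device=device)
for j in range(N):
    w_bf16[:, j] = (j % 7) * 0.5 + 0.5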
# Quantize
w_fp4, w_sf = _quantize_to_e2m1(w_bf16.T.float())
w_fp4 = w_fp4.T.contiguous()
w_sf = w_sf.T.contiguous()
# All-ones A (sum all K elements)
x_bf16 = torch.ones(M, K, dtype=torch.bfloat16, device=device)
x_fp4, x_sf = _quantize_to_e2m1(x_bf16.float())
out = cutlass_nvfp4_blockscaled_gemm(x_fp4, x_sf, w_fp4, w_sf, M, N, K, alpha=1.0)
# Each column j has the same value repeated K times
# So output[j] should be proportional to K * column_value
# Column values cycle: 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 0.5, ...
# So columns with same j%7 should have the same output
print(f"Output first 14: {[f'{v:.2f}' for v in out[0, :14].tolist()]}")
print(f"Expected pattern: columns 0,7 should match; 1,8 should match; etc")
# Check: columns with same j%7 should be close
for mod_val in range(7):
    cols = [j for j in range(N) if j % 7 == mod_val]
    vals = out[0, cols]
    if len(cols) > 1:
        spread = (vals.max() - vals.min()).item()
        if spread > 0.5:
            print(f"WARNING: j%7={mod_val} spread={spread:.4f} — columns with same weight have different outputs!")
        else:
            print(f"j%7={mod_val}: mean={vals.mean():.4f} spread={spread:.4f} OK")
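
The comments in the script state that output[j] should be proportional to K * column_value, but the script itself only checks that columns with the same j % 7 agree with each other. The sketch below works out the ideal reference values in pure Python, assuming lossless quantization and alpha=1.0; the real kernel output may differ by quantization error. The helpers `column_value`, `expected_output`, and `group_spreads` are hypothetical names, not part of the repository.

```python
M, N, K = 1, 128, 256  # same shape as the test above

def column_value(j):
    """Weight stored in every row of column j."""
    return (j % 7) * 0.5 + 0.5

def expected_output(n, k):
    """With an all-ones A, out[0][j] = sum over k rows of B[:, j] = k * column_value(j)."""
    return [k * column_value(j) for j in range(n)]

def group_spreads(row, period=7):
    """Max minus min of the outputs for each residue class j % period."""
    spreads = {}
    for r in range(period):
        vals = [row[j] for j in range(len(row)) if j % period == r]
        spreads[r] = max(vals) - min(vals)
    return spreads

expected = expected_output(N, K)
print(expected[:7])             # [128.0, 256.0, 384.0, 512.0, 640.0, 768.0, 896.0]
print(group_spreads(expected))  # every spread is 0.0 for the ideal result
```

Comparing the kernel's out[0] against a list like `expected` (with a tolerance chosen for FP4 precision) would catch a layout bug that permutes columns within the same residue class, which the spread check alone cannot detect.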