test: B layout test with N=128 K=256
This commit is contained in:
@@ -1,62 +1,49 @@
|
||||
"""Test: verify B matrix weight layout by using one-hot A.
|
||||
If A is one-hot (only element j is nonzero), output = row j of B.
|
||||
We can verify that column n of the output matches checkpoint row n.
|
||||
"""Test: verify B matrix weight layout with larger dimensions.
|
||||
Use M=1, N=128, K=256 — big enough for CUTLASS tiles.
|
||||
Fill B columns with distinct patterns and check if GEMM output
|
||||
matches the expected column sums.
|
||||
"""
|
||||
import torch, sys
|
||||
sys.path.insert(0, 'src')
|
||||
from nvfp4_megamoe_kernel.cutlass_nvfp4_gemm.kernel import cutlass_nvfp4_blockscaled_gemm
|
||||
from nvfp4_megamoe_kernel.nvfp4_mega_moe import _quantize_to_e2m1, _E2M1_MAGNITUDES
|
||||
from nvfp4_megamoe_kernel.nvfp4_mega_moe import _quantize_to_e2m1
|
||||
|
||||
torch.manual_seed(123)
|
||||
device = "cuda"
|
||||
|
||||
# Small dimensions: M=1, N=4, K=16
|
||||
M, N, K = 1, 4, 16
|
||||
M, N, K = 1, 128, 256
|
||||
|
||||
# Create a weight matrix with unique values so we can identify each element
|
||||
# Use BF16 values that map to distinct E2M1 values
|
||||
# E2M1 magnitudes: [0, 0.5, 1, 1.5, 2, 3, 4, 6]
|
||||
# Create w_bf16 where each column has a different pattern
|
||||
w_bf16 = torch.tensor([
|
||||
[0.5, 1.0, 2.0, 3.0], # K=0
|
||||
[1.5, 0.5, 1.0, 2.0], # K=1
|
||||
[2.0, 3.0, 0.5, 1.0], # K=2
|
||||
[1.0, 2.0, 3.0, 0.5], # K=3
|
||||
[3.0, 1.0, 0.5, 2.0], # K=4
|
||||
[0.5, 3.0, 2.0, 1.0], # K=5
|
||||
[2.0, 0.5, 1.0, 3.0], # K=6
|
||||
[1.0, 2.0, 3.0, 0.5], # K=7
|
||||
[3.0, 1.0, 2.0, 0.5], # K=8
|
||||
[0.5, 2.0, 1.0, 3.0], # K=9
|
||||
[2.0, 0.5, 3.0, 1.0], # K=10
|
||||
[1.0, 3.0, 0.5, 2.0], # K=11
|
||||
[3.0, 2.0, 1.0, 0.5], # K=12
|
||||
[0.5, 1.0, 3.0, 2.0], # K=13
|
||||
[2.0, 3.0, 0.5, 1.0], # K=14
|
||||
[1.0, 0.5, 2.0, 3.0], # K=15
|
||||
], dtype=torch.bfloat16, device=device) # (K, N)
|
||||
# Create weight with column-dependent pattern
|
||||
# Column j has value (j % 7) * 0.5 + 0.5 to get distinct E2M1 values
|
||||
w_bf16 = torch.zeros(K, N, dtype=torch.bfloat16, device=device)
|
||||
for j in range(N):
|
||||
w_bf16[:, j] = (j % 7) * 0.5 + 0.5
|
||||
|
||||
# Quantize weights
|
||||
w_fp4, w_sf = _quantize_to_e2m1(w_bf16.T.float()) # (N, K//2), (N, K//16)
|
||||
w_fp4 = w_fp4.T.contiguous() # (K//2, N) = (8, 4)
|
||||
w_sf = w_sf.T.contiguous() # (K//16, N) = (1, 4)
|
||||
# Quantize
|
||||
w_fp4, w_sf = _quantize_to_e2m1(w_bf16.T.float())
|
||||
w_fp4 = w_fp4.T.contiguous()
|
||||
w_sf = w_sf.T.contiguous()
|
||||
|
||||
# BF16 reference: just sum the rows
|
||||
ref = w_bf16.sum(dim=0) # (N,) = sum over K
|
||||
print(f"BF16 reference sum: {ref.tolist()}")
|
||||
|
||||
# Now run GEMM with all-ones A (sum all K elements)
|
||||
# All-ones A (sum all K elements)
|
||||
x_bf16 = torch.ones(M, K, dtype=torch.bfloat16, device=device)
|
||||
x_fp4, x_sf = _quantize_to_e2m1(x_bf16.float())
|
||||
|
||||
out = cutlass_nvfp4_blockscaled_gemm(x_fp4, x_sf, w_fp4, w_sf, M, N, K, alpha=1.0)
|
||||
print(f"NVFP4 output: {out[0].tolist()}")
|
||||
|
||||
# The outputs should be proportional to the column sums
|
||||
# Since we're using quantized values, they won't match exactly
|
||||
# But the RANK ORDER should match (column with highest sum should have highest output)
|
||||
ref_order = torch.argsort(ref, descending=True).tolist()
|
||||
out_order = torch.argsort(out[0], descending=True).tolist()
|
||||
print(f"Reference rank order: {ref_order}")
|
||||
print(f"NVFP4 rank order: {out_order}")
|
||||
print(f"Rank order match: {ref_order == out_order}")
|
||||
# Each column j has the same value repeated K times
|
||||
# So output[j] should be proportional to K * column_value
|
||||
# Column values cycle: 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 0.5, ...
|
||||
# So columns with same j%7 should have the same output
|
||||
print(f"Output first 14: {[f'{v:.2f}' for v in out[0, :14].tolist()]}")
|
||||
print(f"Expected pattern: columns 0,7 should match; 1,8 should match; etc")
|
||||
|
||||
# Check: columns with same j%7 should be close
|
||||
for mod_val in range(7):
|
||||
cols = [j for j in range(N) if j % 7 == mod_val]
|
||||
vals = out[0, cols]
|
||||
if len(cols) > 1:
|
||||
spread = (vals.max() - vals.min()).item()
|
||||
if spread > 0.5:
|
||||
print(f"WARNING: j%7={mod_val} spread={spread:.4f} — columns with same weight have different outputs!")
|
||||
else:
|
||||
print(f"j%7={mod_val}: mean={vals.mean():.4f} spread={spread:.4f} OK")
|
||||
|
||||
Reference in New Issue
Block a user