Files
nvfp4-megamoe-kernel/tests/test_pair_swap.py
biondizzle 467ade37b2 Stage B: C-fragment vs A-fragment TMEM layout mismatch diagnosed
Key finding: C-fragment and A-fragment use different physical TMEM address
mappings. St32x32bOp with C-fragment writes to C-layout addresses, but PV MMA
reads from A-layout addresses. Forward FMHA recast validated FP16 only, not BF16.

Working: FP32 ld/st roundtrip, BF16 elemwise, BF16 recast ld S0->st S1 (all cos 0.999999)
Broken: C-frag st + A-frag read (NaN), A-frag store + PV MMA (cos -0.02)
Next: Fix register data flow (128 FP16/thread load vs 64 BF16/thread store mismatch)
2026-05-21 00:12:47 +00:00

120 lines
5.3 KiB
Python

"""
BF16 Pair-Swap Test: compare kernel output against reference with V rows
swapped in even/odd pairs (simulating BF16 packing swap within F32 words).
"""
import torch, sys
sys.path.insert(0, '/root/dsv4-nvfp4-workspace/kernel/tests')
import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from test_stage_b_v7 import StageBIdentitySoftmax
# Fixed seed so the random inputs (and every reference derived from them)
# are reproducible from run to run.
torch.manual_seed(42)

# Problem sizes: q is (m, k, 1), kv is (n, k, 1), and the output buffer c is
# (m, n, 1).
m, n, k = 128, 128, 128

# BF16 device tensors handed to the kernel. The draw order (q, then kv) is
# part of the reproducible setup.
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')

# FP32 2-D copies, used only to build the references.
qf = q[..., 0].float()
kvf = kv[..., 0].float()

# Standard reference: P @ V where P = Q @ K^T (kv serves as both K and V).
ref = qf @ kvf.T @ kvf

# Pair-swapped reference: if the two BF16 values inside each F32 word are
# swapped, the MMA reads K[2j] and K[2j+1] exchanged, which is equivalent to
# exchanging the rows of V in even/odd pairs:
#   Output = P @ V_with_row_pairs_swapped
# XOR-ing a row index with 1 yields its pair partner (0<->1, 2<->3, ...).
pair_partner = torch.arange(n, device=kvf.device) ^ 1
V_swap = kvf[pair_partner]
ref_swap = qf @ kvf.T @ V_swap
# Offset-1 variant: swap rows (1, 2), (3, 4), ... of V instead of
# (0, 1), (2, 3), ...  Row 0 has no partner and stays in place; with an even
# row count (n = 128 here) the last row is left unpaired as well.
#
# The odd-row slice has to stop before the final row: for an even row count
# kvf[1::2] holds one more row than kvf[2::2], so assigning one into the other
# raises a shape-mismatch RuntimeError.
V_swap1 = kvf.clone()
V_swap1[1:-1:2] = kvf[2::2]
V_swap1[2::2] = kvf[1:-1:2]
ref_swap1 = qf @ kvf.T @ V_swap1
def _as_dynamic_cute(tensor):
    """Wrap *tensor* via ``ct.from_dlpack`` and mark its layout dynamic.

    The leading dimension passed to ``mark_layout_dynamic`` is the one
    reported by ``ct.get_leading_dim`` for the same tensor.
    """
    return ct.from_dlpack(tensor).mark_layout_dynamic(
        leading_dim=ct.get_leading_dim(tensor)
    )


# Kernel-side handles for Q, K/V and the output buffer, in that order.
mQ, mK, mC = (_as_dynamic_cute(t) for t in (q, kv, c))

# Use the CUDA stream PyTorch is currently issuing work on.
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

kernel = StageBIdentitySoftmax(
    mma_tiler_mn=(128, 128),
    use_2cta_instrs=False,
    use_tma_store=True,
)

print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)

print('Running...', flush=True)
compiled(mQ, mK, mC, stream)

# Block until the device is idle so that ``c`` can be read back safely.
torch.cuda.synchronize()
def _flat_cosine(lhs, rhs):
    """Return the cosine similarity of two tensors, each flattened to 1-D."""
    return torch.nn.functional.cosine_similarity(
        lhs.flatten().unsqueeze(0), rhs.flatten().unsqueeze(0)
    ).item()


# FP32 2-D copy of what the kernel wrote into the output buffer.
out = c[:, :, 0].float()

# Whole-matrix similarity of the kernel output against each candidate
# reference; the hypothesis whose reference scores closest to 1 explains
# the observed output best.
cos_ref = _flat_cosine(out, ref)
cos_swap = _flat_cosine(out, ref_swap)
cos_swap1 = _flat_cosine(out, ref_swap1)

print(f'\nCosine with standard ref: {cos_ref:.6f}')
print(f'Cosine with pair-swapped ref: {cos_swap:.6f}')
print(f'Cosine with offset-1 swapped: {cos_swap1:.6f}')
# Also try: what if the entire P row is shifted by some amount?
# Or what if P columns are reordered according to the A-fragment's K partitioning?
# The A-fragment has K laid out as (16 inner, 4 outer chunks of 16)
# If the store writes columns sequentially but the MMA reads them in the
# chunked order, we'd get a specific permutation
# The A-fragment K layout: (col, k0, k1) where col=0..15, k0=0..3, k1=0..1
# K position = col + 16*k0 + 64*k1
# If the store writes K sequentially (0,1,2,...,127) but the MMA reads
# in the fragment order, the mapping would be:
# Fragment index (col, k0, k1) -> K position (col + 16*k0 + 64*k1)
# But this is sequential (0,1,2,...), so no permutation.
# UNLESS the store doesn't fill the fragment's K blocks correctly.
# Try: what if only the FIRST 64 K values (not 128) are filled?
# The store writes 64 F32 = 128 BF16. But what if the MMA's A-fragment
# K dimension is 64 (not 128)? Then only 64 K values have P, rest are garbage.
# But we already verified nblk_pv=4 and the fragment covers 128 BF16.
# Try: what if the 64 F32 columns in the store map to 128 BF16 K positions
# but with the packing where BF16[K] and BF16[K+1] come from the same F32 word
# and K is determined by the A-fragment's partition order?
# The A-fragment reads K in groups of 16 (inner K), then 4 outer groups, then 2 BF16 per F32
# So K values are accessed as: [group0_bf16_0, group0_bf16_1, group1_bf16_0, group1_bf16_1, ...]
# where group = (col, k0, k1) and bf16_0/bf16_1 are the two BF16 in each F32 word
# This is getting complex. Instead, identify the permutation empirically:
# for a few output rows (and then a few output columns), find the reference
# row (or column) with the highest cosine similarity, using the entire
# row (or column) as the signature.
def _best_cosine_match(vec, candidates):
    """Find the row of *candidates* most cosine-similar to *vec*.

    Args:
        vec: 1-D tensor used as the signature to look up.
        candidates: 2-D tensor; each row is compared against *vec*.

    Returns:
        ``(index, cosine)`` of the best-matching row. The first row wins
        ties. NaN similarities are ignored; if there is no usable
        similarity at all the result is ``(-1, nan)``.
    """
    # One batched call instead of a Python loop with a host sync per candidate.
    sims = torch.nn.functional.cosine_similarity(
        vec.unsqueeze(0), candidates, dim=1
    )
    usable = ~torch.isnan(sims)
    if not usable.any():
        return -1, float('nan')
    # Push NaNs below every real value so they can never be selected; this
    # also lets an all-negative set of similarities report its true maximum.
    sims = sims.masked_fill(~usable, float('-inf'))
    best = int(torch.argmax(sims))
    return best, sims[best].item()


print('\nTrying to identify the column permutation for row 0...')
# A single element is not a unique signature, so whole rows are compared:
# each probed output row is matched against every reference row.
for i in range(3):
    best_j, best_cos = _best_cosine_match(out[i], ref)
    print(f' Output row {i} best matches ref row {best_j} (cos={best_cos:.4f})')

# Same idea along the other axis: match whole output columns against every
# reference column (transposing turns columns into rows for the helper).
print('\nColumn matching (output col j vs ref col k):')
out_cols = out.T
ref_cols = ref.T
for j in range(8):
    best_k, best_cos = _best_cosine_match(out_cols[j], ref_cols)
    print(f' Output col {j} best matches ref col {best_k} (cos={best_cos:.4f})')