Files
nvfp4-megamoe-kernel/tests/test_pair_swap2.py
biondizzle 467ade37b2 Stage B: C-fragment vs A-fragment TMEM layout mismatch diagnosed
Key finding: C-fragment and A-fragment use different physical TMEM address
mappings. St32x32bOp with C-fragment writes to C-layout addresses, but PV MMA
reads from A-layout addresses. Forward FMHA recast validated FP16 only, not BF16.

Working: FP32 ld/st roundtrip, BF16 elemwise, BF16 recast ld S0->st S1 (all cos 0.999999)
Broken: C-frag st + A-frag read (NaN), A-frag store + PV MMA (cos -0.02)
Next: Fix register data flow (128 FP16/thread load vs 64 BF16/thread store mismatch)
2026-05-21 00:12:47 +00:00

96 lines
3.6 KiB
Python

"""
BF16 Pair-Swap Test: compare kernel output against reference with V rows
swapped in even/odd pairs (simulating BF16 packing swap within F32 words).
Also tries to identify the exact column permutation.
"""
import torch, sys
sys.path.insert(0, '/root/dsv4-nvfp4-workspace/kernel/tests')
import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from test_stage_b_v7 import StageBIdentitySoftmax
# Fixed seed so successive diagnostic runs see the same inputs.
torch.manual_seed(42)
m, n, k = 128, 128, 128

# Kernel operands (BF16, on the GPU): Q is (m, k, 1), a single K/V tensor is
# (n, k, 1), and C is a zero-filled (m, n, 1) output buffer.
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')

# FP32 copies of the 2-D slices, used for the host-side references.
qf, kvf = (t[:, :, 0].float() for t in (q, kv))

# Reference (Q @ K^T) @ V with V == K and no softmax applied -- presumably what
# StageBIdentitySoftmax is meant to compute; confirm against the kernel.
scores = qf @ kvf.T
ref = scores @ kvf

# Pair-swapped V: rows 0<->1, 2<->3, 4<->5, ... Group the rows as (n/2, 2)
# pairs and flip each pair (requires n to be even).
V_swap = kvf.reshape(n // 2, 2, k).flip(1).reshape(n, k)
ref_swap = scores @ V_swap
def _as_cute_tensor(t):
    """Wrap torch tensor ``t`` via DLPack and mark its layout dynamic.

    The dynamic dimension is whatever ``ct.get_leading_dim`` reports for ``t``.
    """
    return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))


mQ, mK, mC = (_as_cute_tensor(t) for t in (q, kv, c))

# Launch on PyTorch's current CUDA stream.
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

kernel = StageBIdentitySoftmax(
    mma_tiler_mn=(128, 128),
    use_2cta_instrs=False,
    use_tma_store=True,
)
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
# Wait for the device so C holds the kernel's result before it is read back.
torch.cuda.synchronize()
out = c[:, :, 0].float()
def _flat_cosine(a, b):
    """Return the cosine similarity of ``a`` and ``b``, each flattened to one vector."""
    return torch.nn.functional.cosine_similarity(a.reshape(1, -1), b.reshape(1, -1)).item()


# Whole-output agreement with the standard and the pair-swapped references.
cos_ref, cos_swap = (_flat_cosine(out, r) for r in (ref, ref_swap))
print(f'\nCosine with standard ref: {cos_ref:.6f}')
print(f'Cosine with pair-swapped ref: {cos_swap:.6f}')
# Column permutation identification: for every output column j, find the
# reference column with the highest cosine similarity. A correct kernel maps
# j -> j; a BF16 pair swap would map j -> j ^ 1.


def _best_matching_columns(actual, expected):
    """Match each column of ``actual`` to its most similar column of ``expected``.

    Args:
        actual: 2-D tensor whose columns are queried.
        expected: 2-D tensor with the same number of rows as ``actual``.

    Returns:
        ``(indices, scores)`` as Python lists with one entry per column of
        ``actual``. ``indices[j]`` is the column of ``expected`` with the
        highest cosine similarity to ``actual[:, j]`` (the first such column
        on ties), and ``scores[j]`` is that similarity. A column with no
        strictly positive match, including all-zero and NaN columns, gets
        index -1 and score 0.0.
    """
    # Unit-normalise the columns (in FP64, so TF32 matmul settings cannot
    # perturb the result). One matmul then yields every pairwise cosine:
    # sim[j, i] = cos(actual[:, j], expected[:, i]).
    actual_n = torch.nn.functional.normalize(actual.double(), dim=0)
    expected_n = torch.nn.functional.normalize(expected.double(), dim=0)
    sim = actual_n.T @ expected_n
    # A NaN similarity must never be chosen as the best match; cosines lie in
    # [-1, 1], so -2 always loses.
    sim = sim.masked_fill(sim.isnan(), -2.0)
    best_cos, best_idx = sim.max(dim=1)
    found = best_cos > 0
    best_idx = torch.where(found, best_idx, torch.full_like(best_idx, -1))
    best_cos = torch.where(found, best_cos, torch.zeros_like(best_cos))
    return best_idx.tolist(), best_cos.tolist()


# Compute the full mapping once; the first-16 preview is a slice of it.
full_perm, full_cos = _best_matching_columns(out, ref)
preview = 16

print('\nColumn permutation (output col j matches ref col k):')
perm = full_perm[:preview]
for j, (best_k, best_cos) in enumerate(zip(perm, full_cos[:preview])):
    print(f' out col {j} -> ref col {best_k} (cos={best_cos:.4f})')
# Check if permutation has a pattern
print(f'\nPermutation (first 16): {perm}')
diffs = [b - a for a, b in zip(perm, perm[1:])]
print(f'Permutation diffs: {diffs}')
# XOR with 1 swaps even/odd indices (0<->1, 2<->3, ...), the mapping a BF16
# pair swap would produce.
expected_pair_swap = [j ^ 1 for j in range(len(full_perm))]
matches_pair_swap = perm == expected_pair_swap[:len(perm)]
print(f'Matches pair-swap (0↔1, 2↔3, ...): {matches_pair_swap}')

print(f'\nFull permutation: {full_perm}')
all_pair_swap = full_perm == expected_pair_swap
print(f'All columns match pair-swap: {all_pair_swap}')
# If the mapping is not a clean pair swap, dump it in 16-column chunks so any
# block-level structure in the permutation is easy to spot. The chunk count
# follows the number of columns instead of assuming 128 (8 chunks).
if not all_pair_swap:
    chunk = 16
    for block, start in enumerate(range(0, len(full_perm), chunk)):
        block_perm = full_perm[start:start + chunk]
        print(f' Block {block}: {block_perm}')