Key finding: the C-fragment and A-fragment use different physical TMEM address mappings. St32x32bOp with a C-fragment writes to C-layout addresses, but the PV MMA reads from A-layout addresses. The forward-FMHA recast was validated for FP16 only, not BF16. Working: FP32 ld/st round trip, BF16 elementwise, BF16 recast ld S0 -> st S1 (all cos 0.999999). Broken: C-frag st + A-frag read (NaN); A-frag store + PV MMA (cos -0.02). Next: fix the register data flow (mismatch between the 128 FP16/thread load and the 64 BF16/thread store).
96 lines
3.6 KiB
Python
96 lines
3.6 KiB
Python
"""
|
|
BF16 Pair-Swap Test: compare kernel output against reference with V rows
|
|
swapped in even/odd pairs (simulating BF16 packing swap within F32 words).
|
|
Also tries to identify the exact column permutation.
|
|
"""
|
|
import torch, sys
|
|
sys.path.insert(0, '/root/dsv4-nvfp4-workspace/kernel/tests')
|
|
|
|
import cutlass.cute as cute
|
|
import cutlass.torch as ct
|
|
import cuda.bindings.driver as cuda
|
|
from test_stage_b_v7 import StageBIdentitySoftmax
|
|
|
|
# Deterministic inputs: Q is (m, k, 1), the single K/V tensor is (n, k, 1) and
# the BF16 output buffer C is (m, n, 1), zero-initialised.
torch.manual_seed(42)
m = n = k = 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')

# FP32 copies of the 2-D slices, used only for the references.
qf, kvf = (t[:, :, 0].float() for t in (q, kv))

# Standard reference: the same tensor serves as both K and V.
ref = qf @ kvf.T @ kvf

# Pair-swapped V: rows 0<->1, 2<->3, 4<->5, ...
# Group rows into consecutive pairs, reverse each pair, then flatten back.
V_swap = kvf.reshape(-1, 2, kvf.shape[1]).flip(1).reshape(kvf.shape)
ref_swap = qf @ kvf.T @ V_swap
def _dynamic_cute_tensor(t):
    """Wrap torch tensor *t* for CuTe, marking its layout dynamic along its leading dim."""
    return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))


# Wrapped in the same order as before: Q, then K/V, then C.
mQ, mK, mC = map(_dynamic_cute_tensor, (q, kv, c))

# Launch on PyTorch's current CUDA stream.
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)

print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)

print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
# Wait for the kernel so that `c` holds its results before it is read back.
torch.cuda.synchronize()
# Kernel output as an FP32 2-D matrix.
out = c[:, :, 0].float()


def _flat_cosine(a, b):
    """Cosine similarity of *a* and *b*, each flattened into one row vector."""
    return torch.nn.functional.cosine_similarity(a.reshape(1, -1), b.reshape(1, -1)).item()


cos_ref, cos_swap = _flat_cosine(out, ref), _flat_cosine(out, ref_swap)

print(f'\nCosine with standard ref: {cos_ref:.6f}')
print(f'Cosine with pair-swapped ref: {cos_swap:.6f}')
# Column permutation identification.
# Hypothesis: each output column out[:, j] matches some reference column
# ref[:, r], i.e. the output columns are permuted. This is a separate check
# from the V-row pair swap scored above.
print('\nColumn permutation (output col j matches ref col k):')

# Cosine between every output column and every reference column. This uses one
# matmul of column-normalised matrices rather than one cosine_similarity call
# (plus an .item() host sync) per (j, r) pair. The loop variable no longer
# clobbers the module-level head dim `k`.
out_unit = torch.nn.functional.normalize(out, dim=0)
ref_unit = torch.nn.functional.normalize(ref, dim=0)
col_cos = out_unit.T @ ref_unit  # col_cos[j, r] == cos(out[:, j], ref[:, r])

# Use the true maximum even when it is negative. An uncorrelated or
# anti-correlated column must still report its best candidate, not a fake
# "cos=0 / -1".
best_cos_t, best_ref_t = col_cos.max(dim=1)

# Some columns have no meaningful match and are reported as -1:
# all-zero columns (c is zero-initialised, so an unwritten column stays zero)
# and columns containing NaN.
no_match = (out.norm(dim=0) == 0) | torch.isnan(best_cos_t)
best_ref_t[no_match] = -1
full_perm = best_ref_t.tolist()
full_best_cos = best_cos_t.tolist()

# Details for the first (up to) 16 output columns.
n_show = min(16, len(full_perm))
perm = full_perm[:n_show]
for j in range(n_show):
    print(f' out col {j} -> ref col {perm[j]} (cos={full_best_cos[j]:.4f})')

# Check if permutation has a pattern
print(f'\nPermutation (first 16): {perm}')
diffs = [b - a for a, b in zip(perm, perm[1:])]
print(f'Permutation diffs: {diffs}')

# Is perm[j] == j with every pair swapped? (0->1, 1->0, 2->3, 3->2, ...)
expected_pair_swap = [j ^ 1 for j in range(len(full_perm))]  # XOR with 1 swaps even/odd
matches_pair_swap = all(p == expected_pair_swap[j] for j, p in enumerate(perm))
print(f'Matches pair-swap (0↔1, 2↔3, ...): {matches_pair_swap}')

# Same check over every output column.
print(f'\nFull permutation: {full_perm}')
all_pair_swap = all(p == (j ^ 1) for j, p in enumerate(full_perm))
print(f'All columns match pair-swap: {all_pair_swap}')

# If not a pair swap, print the permutation in 16-column blocks to expose its pattern.
if not all_pair_swap:
    for block, start in enumerate(range(0, len(full_perm), 16)):
        block_perm = full_perm[start:start + 16]
        print(f' Block {block}: {block_perm}')