"""BF16 pair-swap test.

Compare kernel output against a reference with V rows swapped in even/odd
pairs (simulating a BF16 packing swap within F32 words). Also tries to
identify the exact column permutation.
"""
import torch, sys

# Must precede the `test_stage_b_v7` import below so that module resolves.
sys.path.insert(0, '/root/dsv4-nvfp4-workspace/kernel/tests')

import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda

from test_stage_b_v7 import StageBIdentitySoftmax


def _best_column_matches(out, ref):
    """Find, for each column of ``out``, the most similar column of ``ref``.

    Similarity is the cosine similarity between whole columns. All pairs are
    evaluated in a single broadcast call rather than one ``.item()`` round
    trip per (output column, reference column) pair.

    Args:
        out: 2-D tensor whose columns are to be identified.
        ref: 2-D tensor with the same number of rows as ``out``.

    Returns:
        ``(indices, scores)``: two Python lists with one entry per column of
        ``out``. ``indices[j]`` is the index of the best-matching column of
        ``ref`` and ``scores[j]`` its cosine similarity. A column with no
        strictly positive similarity to any reference column is reported as
        index ``-1`` with score ``0``.
    """
    # (out_cols, 1, rows) vs (1, ref_cols, rows) -> (out_cols, ref_cols).
    sims = torch.nn.functional.cosine_similarity(
        out.T.unsqueeze(1), ref.T.unsqueeze(0), dim=-1
    )
    best_cos, best_idx = sims.max(dim=1)
    # Only strictly positive similarities count as a match; this also maps
    # NaN rows (e.g. from a NaN-filled output column) to "no match".
    matched = best_cos > 0
    best_idx = torch.where(matched, best_idx, torch.full_like(best_idx, -1))
    best_cos = torch.where(matched, best_cos, torch.zeros_like(best_cos))
    return best_idx.tolist(), best_cos.tolist()


PREVIEW_COLS = 16  # number of leading columns printed individually
BLOCK_COLS = 16    # block width used when dumping a non-pair-swap permutation

torch.manual_seed(42)

m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')

qf = q[:, :, 0].float()
kvf = kv[:, :, 0].float()

# Reference without any softmax: Q @ K^T @ V, with K and V both taken from
# `kv`. NOTE(review): presumably this mirrors what StageBIdentitySoftmax
# computes -- confirm against the kernel.
ref = qf @ kvf.T @ kvf

# Pair-swapped: V rows 0↔1, 2↔3, 4↔5, ...
V_swap = torch.empty_like(kvf)
V_swap[0::2] = kvf[1::2]
V_swap[1::2] = kvf[0::2]
ref_swap = qf @ kvf.T @ V_swap

mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))

stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

kernel = StageBIdentitySoftmax(
    mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True
)

print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)

print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()

out = c[:, :, 0].float()
num_cols = out.shape[1]

# Whole-matrix agreement with each candidate reference.
out_flat = out.flatten().unsqueeze(0)
cos_ref = torch.nn.functional.cosine_similarity(
    out_flat, ref.flatten().unsqueeze(0)
).item()
cos_swap = torch.nn.functional.cosine_similarity(
    out_flat, ref_swap.flatten().unsqueeze(0)
).item()
print(f'\nCosine with standard ref: {cos_ref:.6f}')
print(f'Cosine with pair-swapped ref: {cos_swap:.6f}')

# Column permutation identification. The match is computed once for every
# column; the per-column preview below is just a prefix of that result.
full_perm, full_cos = _best_column_matches(out, ref)

print('\nColumn permutation (output col j matches ref col k):')
preview = min(PREVIEW_COLS, num_cols)
perm = full_perm[:preview]
for j, (best_k, best_cos) in enumerate(zip(perm, full_cos[:preview])):
    print(f' out col {j} -> ref col {best_k} (cos={best_cos:.4f})')

# Check if permutation has a pattern
print(f'\nPermutation (first 16): {perm}')
diffs = [nxt - cur for cur, nxt in zip(perm, perm[1:])]
print(f'Permutation diffs: {diffs}')

# Check: is perm[j] = j with every pair swapped? (0→1, 1→0, 2→3, 3→2, ...)
expected_pair_swap = [j ^ 1 for j in range(num_cols)]  # XOR with 1 swaps even/odd
matches_pair_swap = all(
    got == want for got, want in zip(perm, expected_pair_swap)
)
print(f'Matches pair-swap (0↔1, 2↔3, ...): {matches_pair_swap}')

# Also check the full permutation for all columns
print(f'\nFull permutation: {full_perm}')

# Check if it's pair-swap for all columns
all_pair_swap = full_perm == expected_pair_swap
print(f'All columns match pair-swap: {all_pair_swap}')

# If not pair-swap, what's the pattern?
if not all_pair_swap:
    # Dump in 16-element blocks to make a block-wise permutation visible.
    for block in range(num_cols // BLOCK_COLS):
        block_perm = full_perm[block * BLOCK_COLS:(block + 1) * BLOCK_COLS]
        print(f' Block {block}: {block_perm}')