""" BF16 Pair-Swap Test: compare kernel output against reference with V rows swapped in even/odd pairs (simulating BF16 packing swap within F32 words). """ import torch, sys sys.path.insert(0, '/root/dsv4-nvfp4-workspace/kernel/tests') import cutlass.cute as cute import cutlass.torch as ct import cuda.bindings.driver as cuda from test_stage_b_v7 import StageBIdentitySoftmax torch.manual_seed(42) m, n, k = 128, 128, 128 q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda') kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda') c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda') qf = q[:,:,0].float() kvf = kv[:,:,0].float() # Standard reference: P @ V where P = Q @ K^T ref = qf @ kvf.T @ kvf # Pair-swapped reference: if BF16 within each F32 are swapped, # MMA reads K[2j] and K[2j+1] swapped, which means V rows are swapped in pairs # Output = P @ V_with_row_pairs_swapped V_swap = kvf.clone() V_swap[0::2], V_swap[1::2] = kvf[1::2].clone(), kvf[0::2].clone() ref_swap = qf @ kvf.T @ V_swap # Also try: just the P@K^T with V where every 2 consecutive rows are swapped # but starting from different offsets # Try 1-based offset: swap rows (1,2), (3,4), ... 
# Offset-1 pair swap: exchange rows (1,2), (3,4), ...; row 0 and any unpaired
# trailing row (row n-1 when n is even) stay in place.
# The previous slicing (V[1::2] <- kvf[2::2]) paired 64 destination rows with
# 63 source rows for n=128 and raised a shape-mismatch error before the kernel
# ever ran; both slices are now limited to the same paired region.
_pair_end = 1 + 2 * ((n - 1) // 2)  # exclusive end of the paired rows
V_swap1 = kvf.clone()
V_swap1[1:_pair_end:2], V_swap1[2:_pair_end:2] = (
    kvf[2:_pair_end:2].clone(),
    kvf[1:_pair_end:2].clone(),
)
ref_swap1 = qf @ kvf.T @ V_swap1

# Wrap the torch tensors for CuTe via DLPack and mark their layouts dynamic
# along each tensor's leading dimension.
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()  # kernel writes into `c` asynchronously
out = c[:, :, 0].float()


def _flat_cosine(a, b):
    """Return the cosine similarity of *a* and *b*, each flattened to one vector."""
    return torch.nn.functional.cosine_similarity(
        a.flatten().unsqueeze(0), b.flatten().unsqueeze(0)
    ).item()


cos_ref = _flat_cosine(out, ref)
cos_swap = _flat_cosine(out, ref_swap)
cos_swap1 = _flat_cosine(out, ref_swap1)
print(f'\nCosine with standard ref: {cos_ref:.6f}')
print(f'Cosine with pair-swapped ref: {cos_swap:.6f}')
print(f'Cosine with offset-1 swapped: {cos_swap1:.6f}')

# Also try: what if the entire P row is shifted by some amount?
# Or what if P columns are reordered according to the A-fragment's K partitioning?
# The A-fragment has K laid out as (16 inner, 4 outer chunks of 16).
# If the store writes columns sequentially but the MMA reads them in the
# chunked order, we'd get a specific permutation.
# The A-fragment K layout: (col, k0, k1) where col=0..15, k0=0..3, k1=0..1
# K position = col + 16*k0 + 64*k1
# If the store writes K sequentially (0,1,2,...,127) but the MMA reads
# in the fragment order, the mapping would be:
# Fragment index (col, k0, k1) -> K position (col + 16*k0 + 64*k1)
# But this is sequential (0,1,2,...), so no permutation.
# UNLESS the store doesn't fill the fragment's K blocks correctly.
# Hypothesis: only the FIRST 64 K values (not 128) are filled?
#   The store writes 64 F32 = 128 BF16. If the MMA's A-fragment K dimension
#   were 64 (not 128), only 64 K values would hold P and the rest garbage.
#   But nblk_pv=4 was already verified and the fragment covers 128 BF16.
# Hypothesis: the 64 F32 columns in the store map to 128 BF16 K positions,
#   with BF16[K] and BF16[K+1] packed in the same F32 word, and K determined
#   by the A-fragment's partition order.
#   The A-fragment reads K in groups of 16 (inner K), then 4 outer groups,
#   then 2 BF16 per F32, so K values are accessed as
#   [group0_bf16_0, group0_bf16_1, group1_bf16_0, group1_bf16_1, ...]
#   where group = (col, k0, k1) and bf16_0/bf16_1 are the two BF16 in each F32 word.
# Rather than derive the permutation analytically, recover it empirically:
# match each output row / column against every reference row / column, using
# the whole row (column) as its signature since single values may not be unique.
print('\nTrying to identify the column permutation for row 0...')

# For each probed output row, score it against every reference row at once
# (broadcast along dim=1) and report the arg-max. argmax reports the best match
# even when every cosine is <= 0, and returns the first index on ties.
# Loop names avoid rebinding the module-level `c` (output buffer) and `k`.
for row in (0, 1, 2):
    sims = torch.nn.functional.cosine_similarity(out[row].unsqueeze(0), ref, dim=1)
    best_row = int(torch.argmax(sims))
    best_cos = sims[best_row].item()
    print(f'  Output row {row} best matches ref row {best_row} (cos={best_cos:.4f})')

# Same idea for columns: broadcast an output column against every reference
# column along dim=0.
print('\nColumn matching (output col j vs ref col k):')
for col in range(8):
    sims = torch.nn.functional.cosine_similarity(out[:, col].unsqueeze(1), ref, dim=0)
    best_col = int(torch.argmax(sims))
    best_cos = sims[best_col].item()
    print(f'  Output col {col} best matches ref col {best_col} (cos={best_cos:.4f})')