""" BF16 Packing Diagnostic: Run identity softmax with K=V=randn, compare output vs reference to identify the error pattern. """ import torch, sys sys.path.insert(0, '/root/dsv4-nvfp4-workspace/kernel/tests') import cutlass.cute as cute import cutlass.torch as ct import cuda.bindings.driver as cuda from test_stage_b_v7 import StageBIdentitySoftmax torch.manual_seed(42) m, n, k = 128, 128, 128 q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda') kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda') c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda') qf = q[:,:,0].float() kvf = kv[:,:,0].float() ref = qf @ kvf.T @ kvf # identity softmax: (Q @ K^T) @ V mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv)) mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c)) stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True) print('Compiling...', flush=True) compiled = cute.compile(kernel, mQ, mK, mC, stream) print('Running...', flush=True) compiled(mQ, mK, mC, stream) torch.cuda.synchronize() out = c[:,:,0].float() cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item() print(f'\nCosine: {cos:.6f}') print(f'Output row 0[:8]: {out[0,:8].tolist()}') print(f'Ref row 0[:8]: {ref[0,:8].tolist()}') # Key diagnostic: compare Q@K^T stage (which we know is correct) vs PV stage # If Q@K^T is correct but PV is wrong, the output will show the PV error pattern # With identity softmax, P = Q@K^T. So output = P @ V = (Q@K^T) @ V # If V is being read wrong, the output will be P @ V_permuted # Check: is the output a permutation of the reference? # Sort both and compare out_sorted = out.flatten().sort()[0] ref_sorted = ref.flatten().sort()[0] cos_sorted = torch.nn.functional.cosine_similarity(out_sorted.unsqueeze(0), ref_sorted.unsqueeze(0)).item() print(f'\nCosine after sorting: {cos_sorted:.6f}') print(f'(If ~1.0, output is a permutation of reference)') # Check: does output match P @ V^T? (V transposed) ref_vt = qf @ kvf.T @ kvf.T # (Q@K^T) @ V^T cos_vt = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref_vt.flatten().unsqueeze(0)).item() print(f'Cosine with P @ V^T: {cos_vt:.6f}') # Check: does output match P^T @ V? (P transposed) # P = Q@K^T, so P^T = K@Q^T ref_pt = kvf @ qf.T @ kvf cos_pt = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref_pt.flatten().unsqueeze(0)).item() print(f'Cosine with P^T @ V: {cos_pt:.6f}') # Check: is output simply the Q@K^T scores (MMA2 produced identity)? # If MMA2 didn't run or produced P unchanged, output = P = Q@K^T qkt = qf @ kvf.T cos_qkt = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), qkt.flatten().unsqueeze(0)).item() print(f'Cosine with just Q@K^T: {cos_qkt:.6f}') # Check: all output rows identical? (means P has identical rows, like all-ones) all_same = torch.allclose(out[0], out[1], atol=1e-3) print(f'All output rows identical: {all_same}') if not all_same: cos_r01 = torch.nn.functional.cosine_similarity(out[0].unsqueeze(0), out[1].unsqueeze(0)).item() print(f'Cosine between row 0 and row 1: {cos_r01:.6f}') # Check: is output just V (P=I case)? cos_v = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), kvf.flatten().unsqueeze(0)).item() print(f'Cosine with V alone: {cos_v:.6f}') # Print some output statistics print(f'\nOutput stats: min={out.min().item():.4f}, max={out.max().item():.4f}, mean={out.mean().item():.4f}') print(f'Ref stats: min={ref.min().item():.4f}, max={ref.max().item():.4f}, mean={ref.mean().item():.4f}')