# Key finding: C-fragment and A-fragment use different physical TMEM address
# mappings. St32x32bOp with a C-fragment writes to C-layout addresses, but the
# PV MMA reads from A-layout addresses. The forward FMHA recast was validated
# for FP16 only, not BF16.
#
# Working: FP32 ld/st roundtrip, BF16 elementwise, BF16 recast ld S0 -> st S1
#          (all cos 0.999999).
# Broken:  C-frag st + A-frag read (NaN); A-frag store + PV MMA (cos -0.02).
# Next:    Fix the register data flow (128 FP16/thread load vs. 64 BF16/thread
#          store mismatch).
"""
BF16 Pair-Swap Test: compare kernel output against reference with V rows
swapped in even/odd pairs (simulating BF16 packing swap within F32 words).
"""

import sys

import torch

# Extend the module search path before the project-local import below;
# presumably this directory is where test_stage_b_v7 lives (verify on the
# target machine). The statement order is kept as in the original script.
sys.path.insert(0, '/root/dsv4-nvfp4-workspace/kernel/tests')

import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from test_stage_b_v7 import StageBIdentitySoftmax

# Fixed seed so the random inputs are reproducible from run to run.
torch.manual_seed(42)

# Problem extents: q is (m, k, 1), kv is (n, k, 1), the output buffer c is
# (m, n, 1). The trailing unit dimension is stripped again for host-side math.
m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')

# FP32 2-D copies of the inputs, used to build the reference results.
# NOTE: the references computed from these are (m, k) while c is (m, n); the
# two shapes coincide only because n == k here.
qf = q[:, :, 0].float()
kvf = kv[:, :, 0].float()

def _swap_row_pairs(rows, start=0):
    """Return a copy of ``rows`` with adjacent row pairs exchanged.

    Pairs are (start, start+1), (start+2, start+3), ...  Rows before
    ``start`` and a trailing row that has no partner are left unchanged.

    Args:
        rows: 2-D tensor whose rows are to be swapped; it is not modified.
        start: index of the first row of the first pair (0 or 1 in this test).

    Returns:
        A new tensor of the same shape as ``rows``.
    """
    swapped = rows.clone()
    n_pairs = (rows.shape[0] - start) // 2
    # Bounding both slices by `stop` gives them the same length. The previous
    # open-ended form (`[1::2]` vs `[2::2]`) produced 64 vs 63 rows for a
    # 128-row input and raised a shape-mismatch error on assignment.
    stop = start + 2 * n_pairs
    swapped[start:stop:2] = rows[start + 1:stop:2]
    swapped[start + 1:stop:2] = rows[start:stop:2]
    return swapped


# Standard reference: P @ V where P = Q @ K^T
ref = qf @ kvf.T @ kvf

# Pair-swapped reference: if the two BF16 values within each F32 word are
# swapped, the MMA reads K[2j] and K[2j+1] swapped, which is equivalent to
# swapping V rows in pairs (0,1), (2,3), ...
# Output = P @ V_with_row_pairs_swapped
V_swap = _swap_row_pairs(kvf, start=0)
ref_swap = qf @ kvf.T @ V_swap

# Same idea with the pairing shifted by one row: swap rows (1,2), (3,4), ...
# Row 0 and the last row (which has no partner) stay unchanged.
V_swap1 = _swap_row_pairs(kvf, start=1)
ref_swap1 = qf @ kvf.T @ V_swap1

def _as_cute_tensor(tensor):
    """Wrap a torch tensor for the CuTe kernel via DLPack.

    The layout is marked dynamic along the tensor's leading dimension, as
    reported by ``ct.get_leading_dim``.
    """
    return ct.from_dlpack(tensor).mark_layout_dynamic(
        leading_dim=ct.get_leading_dim(tensor)
    )


mQ = _as_cute_tensor(q)
mK = _as_cute_tensor(kv)
mC = _as_cute_tensor(c)

# Launch on the stream PyTorch is currently using.
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
# Wait for the kernel to finish before the result buffer is read below.
torch.cuda.synchronize()

def _flat_cosine(a, b):
    """Cosine similarity between two tensors, each flattened to one vector."""
    return torch.nn.functional.cosine_similarity(
        a.flatten().unsqueeze(0), b.flatten().unsqueeze(0)
    ).item()


# Kernel result as an FP32 2-D matrix, comparable with the references.
out = c[:, :, 0].float()
cos_ref = _flat_cosine(out, ref)
cos_swap = _flat_cosine(out, ref_swap)
cos_swap1 = _flat_cosine(out, ref_swap1)

# A value near 1.0 for one of these indicates which row ordering of V the
# kernel effectively used.
print(f'\nCosine with standard ref: {cos_ref:.6f}')
print(f'Cosine with pair-swapped ref: {cos_swap:.6f}')
print(f'Cosine with offset-1 swapped: {cos_swap1:.6f}')

# Also try: what if the entire P row is shifted by some amount?
# Or what if the P columns are reordered according to the A-fragment's K
# partitioning? The A-fragment has K laid out as (16 inner, 4 outer chunks of
# 16). If the store writes columns sequentially but the MMA reads them in the
# chunked order, we'd get a specific permutation.

# The A-fragment K layout is (col, k0, k1) where col = 0..15, k0 = 0..3,
# k1 = 0..1, and K position = col + 16*k0 + 64*k1.
# If the store writes K sequentially (0, 1, 2, ..., 127) but the MMA reads
# in the fragment order, the mapping would be:
#   fragment index (col, k0, k1) -> K position (col + 16*k0 + 64*k1)
# But this is itself sequential (0, 1, 2, ...), so there is no permutation,
# UNLESS the store doesn't fill the fragment's K blocks correctly.

# Try: what if only the FIRST 64 K values (not 128) are filled?
# The store writes 64 F32 = 128 BF16. But what if the MMA's A-fragment
# K dimension is 64 (not 128)? Then only 64 K values would hold P and the rest
# would be garbage. However, we already verified nblk_pv=4 and that the
# fragment covers 128 BF16.

# Try: what if the 64 F32 columns in the store map to 128 BF16 K positions,
# but with a packing where BF16[K] and BF16[K+1] come from the same F32 word
# and K is determined by the A-fragment's partition order?

# The A-fragment reads K in groups of 16 (inner K), then 4 outer groups, then
# 2 BF16 per F32, so K values are accessed as:
#   [group0_bf16_0, group0_bf16_1, group1_bf16_0, group1_bf16_1, ...]
# where group = (col, k0, k1) and bf16_0/bf16_1 are the two BF16 values in
# each F32 word.

# This is getting complex. Let me just check the specific permutation by
# matching output columns to reference columns for row 0.

# For each output column j, find which reference column it matches
# (using the entire row as a signature)
print('\nTrying to identify the column permutation for row 0...')
# Since all values might not be unique, use the dot product with a known vector
# Actually, the simplest: for output[0, j], check if it equals ref[0, k] for any k
# But we need a unique signature. Let's use the output of the ENTIRE row.

# Compare output row i with reference rows: for the first three output rows,
# report the reference row with the highest (strictly positive) cosine
# similarity. If no similarity exceeds 0 (e.g. the output is NaN), the
# reported index stays -1.
for i in [0, 1, 2]:
    best_cos = 0
    best_j = -1
    for j in range(ref.shape[0]):
        # Named `sim` rather than `c` so the kernel's output tensor `c` is
        # not overwritten by a scalar.
        sim = torch.nn.functional.cosine_similarity(out[i].unsqueeze(0), ref[j].unsqueeze(0)).item()
        if sim > best_cos:
            best_cos = sim
            best_j = j
    print(f' Output row {i} best matches ref row {best_j} (cos={best_cos:.4f})')

# Also: compare output column j with reference columns. For the first eight
# output columns, report the reference column with the highest (strictly
# positive) cosine similarity; -1 means no positive match was found.
print('\nColumn matching (output col j vs ref col k):')
for j in range(8):
    best_cos = 0
    best_k = -1
    # Loop variable and temporary are named so that the module-level `k`
    # (problem size) and `c` (output tensor) are not overwritten.
    for ref_col in range(ref.shape[1]):
        sim = torch.nn.functional.cosine_similarity(out[:, j].unsqueeze(0), ref[:, ref_col].unsqueeze(0)).item()
        if sim > best_cos:
            best_cos = sim
            best_k = ref_col
    print(f' Output col {j} best matches ref col {best_k} (cos={best_cos:.4f})')