Key finding: the C-fragment and the A-fragment use different physical TMEM address mappings. St32x32bOp with a C-fragment layout writes to C-layout addresses, but the PV MMA reads P from A-layout addresses. The forward FMHA recast was validated only for FP16, not BF16.
Working (all cos 0.999999):
- FP32 ld/st roundtrip
- BF16 elementwise
- BF16 recast: ld S0 -> st S1
Broken:
- C-frag st + A-frag read: NaN
- A-frag store + PV MMA: cos -0.02
Next: fix the register data flow (128 FP16/thread load vs 64 BF16/thread store mismatch).
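The "Next" item comes down to word counts. Two BF16 values share one 32-bit word, so 128 BF16 values per thread fit in 64 words, while 128 FP32 values need 128 words. The host-side sketch below illustrates this with two assumptions:

- Host memory is little-endian (true on x86 and typical ARM hosts).
- The data uses an illustrative j+1 pattern.

Whether TMEM and the MMA A-operand use the same low-half-first order as host memory is exactly what the diagnostic script is meant to determine. Treat that order as an assumption, not a result.

```python
import torch

# 128 values per thread; integers up to 256 are exactly representable in BF16.
vals = torch.arange(1, 129, dtype=torch.float32)
as_bf16 = vals.to(torch.bfloat16)

# Reinterpret the 128 BF16 values (2 bytes each) as 32-bit words (4 bytes each).
words = as_bf16.view(torch.int32)
print(vals.numel(), "FP32 words vs", words.numel(), "packed BF16 words")
# prints: 128 FP32 words vs 64 packed BF16 words

# Split each word into its 16-bit halves (all values are positive, so no sign issues).
bits = as_bf16.view(torch.int16).to(torch.int32)
lo = words & 0xFFFF
hi = (words >> 16) & 0xFFFF

# On a little-endian host, element 2w sits in the low half of word w.
assert torch.equal(lo, bits[0::2]) and torch.equal(hi, bits[1::2])

# Reinterpreting the words back as BF16 recovers the original values exactly.
assert torch.equal(words.view(torch.bfloat16), as_bf16)
```

The sketch uses `Tensor.view(dtype)` rather than manual shifts for packing. That reinterprets the same bytes, which is the host analogue of a recast view over TMEM.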
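"""
BF16 Packing Diagnostic: write specific F32 bit patterns to P TMEM via St32x32bOp,
then run the PV MMA with V = identity, so the output equals P as the MMA reads it.

This reveals the order in which two BF16 values are packed into each F32 (32-bit) TMEM word.

Strategy:
- Skip MMA1 entirely (no QK computation)
- Write packed F32 values whose low BF16 half ≠ high BF16 half
- Use K = V = identity so output[i,j] = P[i,j]
- Compare the output against both packing orders

Note: the code below does not implement this strategy yet. It falls back to the
existing P=all-ones diag kernel (test_stage_b_diag.StageBDiag) with K = V = randn
and inspects the error pattern against the reference.
"""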
import torch, struct, sys
sys.path.insert(0, '/root/dsv4-nvfp4-workspace/kernel/tests')
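# --- Sketch: hypothetical helpers, not used by the diag run below ---
# Once the kernel writes a pattern whose low and high BF16 halves differ (e.g. P[i, j] = j+1),
# these helpers name the packing order the PV MMA applied. This is the "compare against both
# packing orders" step from the docstring.
# Assumptions: V = identity, so the output row equals P as the MMA reads it; and the writer
# puts column 2w in the LOW 16 bits of 32-bit word w. If the MMA reads the halves the other
# way, adjacent column pairs come out swapped.
import torch


def expected_rows(p_row: torch.Tensor) -> dict:
    """Output row the MMA would produce under each candidate packing order."""
    assert p_row.numel() % 2 == 0, "two BF16 values per 32-bit word"
    pairs = p_row.reshape(-1, 2)  # row w = (column 2w, column 2w+1)
    return {
        "low_half_is_even_col": p_row.clone(),                # order matches the writer
        "high_half_is_even_col": pairs.flip(-1).reshape(-1),  # halves swapped in each word
    }


def classify_packing(out_row: torch.Tensor, p_row: torch.Tensor, atol: float = 1e-2) -> str:
    """Return the candidate order that reproduces out_row, or 'neither'."""
    for name, cand in expected_rows(p_row).items():
        if torch.allclose(out_row.float(), cand.float(), atol=atol):
            return name
    return "neither"


# Quick self-check with the j+1 pattern (integers up to 256 are exact in BF16).
_p = torch.arange(1, 129, dtype=torch.float32)
assert classify_packing(_p, _p) == "low_half_is_even_col"
assert classify_packing(expected_rows(_p)["high_half_is_even_col"], _p) == "high_half_is_even_col"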

# Use the existing diag test that writes P=all-ones, but modify it to write specific patterns
# The diag test is test_stage_b_diag.py - let me copy and modify v7 instead

# Actually, let me just use the EXISTING kernel (test_stage_b_v7) with:
# 1. K = identity matrix
# 2. The identity softmax running normally (it writes softmax(Q@K^T) as P)
# 3. If Q is also identity, then Q@K^T = I@I = I, so P = softmax(I)
# 4. That's not what we want either.

# Better: use the diag test (writes P=all-ones) with K=identity
# Output should be all-ones @ I = all-ones. That gives no new info.

# The REAL test: modify the kernel to write specific F32 values to the BF16 recast view
# Instead of writing 1.0 BF16, write alternating different values

# Simplest approach: modify test_stage_b_diag.py to write j+1 in BF16 for each position
# Then with V=identity, output[i,j] = j+1

# But modifying the kernel requires JIT changes. Let me use the existing v7 kernel
# with a clever input choice instead.

# Approach: use the identity softmax (which writes softmax(Q@K^T) as P)
# With Q=randn and K=identity: Q@K^T = Q (since K=I)
# Then P = softmax(Q) and output = softmax(Q) @ V
# With V=I: output = softmax(Q)
# This tests the FULL pipeline but doesn't isolate packing.

# Let me just write the packing test kernel from scratch, minimally.
# It only needs: TMA load V, write F32 to P TMEM, PV MMA, epilogue.

# Actually, the simplest isolation test:
# Use the existing test with P=all-ones and V=K (K=randn)
# We already know cos=0.08 for this. The 0.08 is neither 0 nor 1.
# If I use V where V[j] = 1 if j is even and 0 if j is odd, then:
# With correct packing: output = sum of P over the even K positions, per row
# This doesn't help because P=all-ones means all K positions are 1.0

# I think the key issue is that with P=all-ones (cos=0.08), the output should be
# EXACTLY sum(V, dim=0) for each row. Let me compare more carefully.

# Let me run the EXISTING P=all-ones diag test and compare the output values
# against the reference. The PATTERN of errors will tell us about packing.

print("Running existing P=all-ones diag with K=V=randn")
print("Comparing output vs reference to identify the error pattern")

import cutlass.cute as cute
import cutlass.utils as utils
from cutlass.cute.nvgpu import tcgen05
from cutlass import Float32, BFloat16
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cutlass.torch as ct
import cuda.bindings.driver as cuda

torch.manual_seed(42)
m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')  # used as both K and V
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')

# Reference for P = all-ones (m x n) times V = kv (n x k): every row is the column sum of V.
# The output c is m x n, so this comparison relies on n == k.
kvf = kv[:,:,0].float()
ref = torch.ones(m, n, dtype=torch.float32, device='cuda') @ kvf

mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

from test_stage_b_diag import StageBDiag
kernel = StageBDiag(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
print('Compiling diag test...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()

out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()

print(f'\nCosine: {cos:.6f}')
print(f'Output row 0[:8]: {out[0,:8].tolist()}')
print(f'Ref row 0[:8]: {ref[0,:8].tolist()}')
print(f'Ratio out/ref row 0[:8]: {(out[0,:8] / ref[0,:8]).tolist()}')

# Check if output is a scaled version of the reference
ratio = out[0] / ref[0]
ratio_clean = ratio[torch.isfinite(ratio)]
print(f'Ratio mean: {ratio_clean.mean().item():.6f}, std: {ratio_clean.std().item():.6f}')

# Check if odd/even columns have different ratios
ratio_even = ratio[0::2]
ratio_even = ratio_even[torch.isfinite(ratio_even)]
ratio_odd = ratio[1::2]
ratio_odd = ratio_odd[torch.isfinite(ratio_odd)]
print(f'Even col ratio mean: {ratio_even.mean().item():.6f}')
print(f'Odd col ratio mean: {ratio_odd.mean().item():.6f}')

# Check if output matches a different V interpretation
# What if the V SMEM is being read in a different order?
# With P=all-ones, output[i, j] = sum_r V[r, j] for every row i,
# i.e. the column sums of V broadcast to all rows.
# If V's rows (the reduction dim) are read in a different order, every column sum is unchanged;
# if V's columns are permuted, the output columns are a permutation of the same sums.
# So every row of the output should be identical, with column j = sum of V's column j.
# The output IS uniform (all rows identical), but the VALUES are wrong.
# That means the SUM is wrong: either the V values being read are wrong, or the P the
# MMA reads from TMEM is not actually all-ones.
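# --- Sketch: standalone check on random CPU data, independent of the kernel ---
# Reordering V alone cannot change the set of values in a P = all-ones output row.
# Permuting V's rows (the reduction dim) leaves every column sum unchanged. Permuting V's
# columns only moves each sum to a different output column. If the output values are not a
# permutation of V's column sums, the data being read differs, not just its order.
import torch

_v = torch.randn(16, 8)
_row_perm = torch.randperm(16)
_col_perm = torch.randperm(8)
_sums = _v.sum(dim=0)
assert torch.allclose(_v[_row_perm].sum(dim=0), _sums, atol=1e-5)            # same sums
assert torch.allclose(_v[:, _col_perm].sum(dim=0), _sums[_col_perm], atol=1e-5)  # permuted sums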

# Let me check: does the output column structure match any known transform of V?
print(f'\nV (K) row 0[:8]: {kvf[0,:8].tolist()}')
print(f'V col sums (reference): {ref[0,:4].tolist()}')
print(f'Output row 0[:4]: {out[0,:4].tolist()}')

# What if V is being read transposed?
# Then output column j = sum(V[j, :]), i.e. the row sums of V
row_sums = kvf.sum(dim=1)
col_sums = kvf.sum(dim=0)
print(f'V row sums[:4]: {row_sums[:4].tolist()}')
print(f'V col sums[:4]: {col_sums[:4].tolist()}')
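# --- Sketch: hypothetical helper, not part of the original diag ---
# The prints above show the candidates side by side but never score them. This ranks which
# candidate transform best explains an output row, using cosine similarity. The candidate
# names are examples. In this script it could be called as
# best_match(out[0], {'col_sums': col_sums, 'row_sums': row_sums}).
import torch
import torch.nn.functional as F


def best_match(out_row: torch.Tensor, candidates: dict) -> tuple:
    """Return (name, cosine) of the candidate vector most similar to out_row."""
    scores = {
        name: F.cosine_similarity(out_row.float().flatten(), cand.float().flatten(), dim=0).item()
        for name, cand in candidates.items()
    }
    name = max(scores, key=scores.get)
    return name, scores[name]


# Self-check on synthetic data: an output built from row sums is matched to 'row_sums'.
_v = torch.randn(32, 32)
_name, _score = best_match(_v.sum(dim=1), {"col_sums": _v.sum(dim=0), "row_sums": _v.sum(dim=1)})
assert _name == "row_sums" and _score > 0.999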