nvfp4-megamoe-kernel/tests/test_packing_diag.py
biondizzle 467ade37b2 Stage B: C-fragment vs A-fragment TMEM layout mismatch diagnosed
Key finding: C-fragment and A-fragment use different physical TMEM address
mappings. St32x32bOp with C-fragment writes to C-layout addresses, but PV MMA
reads from A-layout addresses. Forward FMHA recast validated FP16 only, not BF16.

Working: FP32 ld/st roundtrip, BF16 elemwise, BF16 recast ld S0->st S1 (all cos 0.999999)
Broken: C-frag st + A-frag read (NaN), A-frag store + PV MMA (cos -0.02)
Next: Fix register data flow (128 FP16/thread load vs 64 BF16/thread store mismatch)
2026-05-21 00:12:47 +00:00

"""
BF16 packing diagnostic: write specific F32 bit patterns to P TMEM via St32x32bOp,
then run the PV MMA with V = identity so the output equals P as the MMA reads it.
This reveals the BF16 packing order within F32 TMEM words.
Planned strategy:
- Skip MMA1 entirely (no QK computation)
- Write packed F32 values where low BF16 ≠ high BF16
- Use K = V = identity so output[i,j] = P[i,j]
- Compare output against both packing orders
The code below does not implement this yet; it reruns the existing P=all-ones
diag with K=V=randn and inspects the error pattern (see the notes that follow).
"""
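# --- Sketch (hypothetical helpers, not used by the kernel run below) ---
# One way to model "the BF16 packing order within F32 TMEM words" on the host.
# BF16 is the upper 16 bits of an IEEE-754 float32, so bf16_bits() truncates;
# the device conversion may round instead, which only matters for values that
# are not exactly representable in BF16. The two candidate orders are
# assumptions to test, not known hardware behavior:
#   lo_first=True  -> element 2*i in bits [15:0], element 2*i+1 in bits [31:16]
#   lo_first=False -> the opposite order
import struct


def bf16_bits(x: float) -> int:
    """Upper 16 bits of the float32 encoding of x (truncating conversion)."""
    return struct.unpack('<I', struct.pack('<f', x))[0] >> 16


def bf16_value(bits: int) -> float:
    """Decode 16 BF16 bits back to a Python float."""
    return struct.unpack('<f', struct.pack('<I', (bits & 0xFFFF) << 16))[0]


def pack_pair(even: float, odd: float, lo_first: bool = True) -> int:
    """Pack two BF16 values into one 32-bit word in the chosen order."""
    lo, hi = (even, odd) if lo_first else (odd, even)
    return (bf16_bits(hi) << 16) | bf16_bits(lo)


def unpack_pair(word: int, lo_first: bool = True) -> tuple:
    """Split a 32-bit word into (even, odd) BF16 values under the chosen order."""
    lo, hi = bf16_value(word & 0xFFFF), bf16_value(word >> 16)
    return (lo, hi) if lo_first else (hi, lo)

# Example: pack_pair(1.0, 2.0) == 0x40003F80; unpacking that word with
# lo_first=False yields (2.0, 1.0), the signature a swapped packing order leaves.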
import torch, struct, sys
sys.path.insert(0, '/root/dsv4-nvfp4-workspace/kernel/tests')
# Use the existing diag test that writes P=all-ones, but modify to write specific patterns
# The diag test is test_stage_b_diag.py - let me copy and modify v7 instead
# Actually, let me just use the EXISTING kernel (test_stage_b_v7) with:
# 1. K = identity matrix
# 2. The identity softmax running normally (it writes softmax(Q@K^T) as P)
# 3. If Q is also identity, then Q@K^T = I@I = I, softmax(I) = softmax of identity
# 4. That's not what we want either.
# Better: use the diag test (writes P=all-ones) with K=identity
# Output should be all-ones * I = all-ones. That gives no new info.
# The REAL test: modify the kernel to write specific F32 values to the BF16 recast view
# Instead of writing 1.0 BF16, write alternating different values
# Simplest approach: modify test_stage_b_diag.py to write j+1 in BF16 for each position
# Then with V=identity, output[i,j] = j+1
# But modifying the kernel requires JIT changes. Let me use the existing v7 kernel
# with a clever input choice instead.
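# --- Sketch (hypothetical helpers) for the j+1 idea above, not implemented here ---
# If the kernel wrote P[i, j] = j + 1 and V = identity, each output row is P as
# the MMA reads it. Integers up to 256 are exact in BF16 (8 significant bits), so
# for a 128-wide tile any mismatch is a placement error, not rounding. Assuming
# the only packing error is the two BF16 halves of each F32 word being read in
# the opposite order, neighboring columns would swap: [2, 1, 4, 3, ...].


def expected_row(n_cols: int, halves_swapped: bool) -> list:
    """Row of P as the PV MMA would see it after the kernel wrote P[i, j] = j + 1."""
    row = [float(j + 1) for j in range(n_cols)]
    if halves_swapped:
        for j in range(0, n_cols - 1, 2):
            row[j], row[j + 1] = row[j + 1], row[j]
    return row


def classify_row(observed: list) -> str:
    """Name the packing hypothesis that matches an observed output row exactly."""
    n_cols = len(observed)
    if observed == expected_row(n_cols, halves_swapped=False):
        return 'same order'
    if observed == expected_row(n_cols, halves_swapped=True):
        return 'halves swapped'
    return 'neither'

# 'neither' would point away from half order and toward a TMEM layout/addressing problem.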
# Approach: Use the identity softmax (which writes Q@K^T scores as P)
# With Q=randn and K=identity: Q@K^T = Q (since K=I)
# Then P = softmax(Q) and output = softmax(Q) @ V
# With V=I: output = softmax(Q)
# This tests the FULL pipeline but doesn't isolate packing.
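# --- Sketch (hypothetical helpers, plain Python, small sizes only) ---
# Host-side reference for the Q=randn, K=V=identity variant above: with K = I the
# scores Q @ K^T are Q itself, so P = softmax(Q) row by row, and with V = I the
# output is P unchanged. Useful for checking the full pipeline, not packing.
import math


def matmul(a: list, b: list) -> list:
    """Plain list-of-lists matrix product a @ b."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def softmax_rows(mat: list) -> list:
    """Numerically stable row-wise softmax (subtract each row's max first)."""
    result = []
    for row in mat:
        peak = max(row)
        exps = [math.exp(x - peak) for x in row]
        total = sum(exps)
        result.append([e / total for e in exps])
    return result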
# Let me just write the packing test kernel from scratch, minimally.
# It only needs: TMA load V, write F32 to P TMEM, PV MMA, epilogue.
# Actually, the simplest isolation test:
# Use the existing test with P=all-ones and V=K (K=randn)
# We already know cos=0.08 for this. The 0.08 is not 0 or 1.
# If I use V where V[j, :] = 1 if j even, 0 if j odd, then:
# With correct packing: output[i, :] = sum of P[i, k] over even k
# This doesn't help because P=all-ones makes both BF16 halves of every F32 word 1.0
# I think the key issue is that with P=all-ones (cos=0.08), the output should be
# EXACTLY sum(V, dim=0) for each row. Let me compare more carefully.
# Let me run the EXISTING P=all-ones diag test and compare the output values
# against the reference. The PATTERN of errors will tell us about packing.
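# --- Sketch (hypothetical helper, plain Python) of the closed-form reference ---
# ones(m, n) @ V puts sum_k V[k, j] in column j of every row, so each output row
# should equal the vector of V column sums; `ref` below is the torch version.
# An all-ones P also packs two identical BF16 1.0 values into every F32 word, so
# this test cannot see the order of the BF16 halves: any error it shows is not
# explained by half order alone.


def all_ones_reference(v: list, m_rows: int) -> list:
    """Expected ones(m_rows, len(v)) @ v: every row is the column sums of v."""
    col_total = [sum(row[j] for row in v) for j in range(len(v[0]))]
    return [list(col_total) for _ in range(m_rows)]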
print("Running existing P=all-ones diag with K=V=randn")
print("Comparing output vs reference to identify the error pattern")
import cutlass.cute as cute
import cutlass.utils as utils
from cutlass.cute.nvgpu import tcgen05
from cutlass import Float32, BFloat16
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cutlass.torch as ct
import cuda.bindings.driver as cuda
torch.manual_seed(42)
m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
kvf = kv[:,:,0].float()
# P = all-ones, so the expected output is ones(m, n) @ V: every row equals the column sums of V
ref = torch.ones(m, n, dtype=torch.float32, device='cuda') @ kvf
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
from test_stage_b_diag import StageBDiag
kernel = StageBDiag(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
print('Compiling diag test...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
print(f'\nCosine: {cos:.6f}')
print(f'Output row 0[:8]: {out[0,:8].tolist()}')
print(f'Ref row 0[:8]: {ref[0,:8].tolist()}')
print(f'Ratio out/ref row 0[:8]: {(out[0,:8] / ref[0,:8]).tolist()}')
# Check if output is a scaled version of the reference
ratio = out[0] / ref[0]
ratio_clean = ratio[torch.isfinite(ratio)]
print(f'Ratio mean: {ratio_clean.mean().item():.6f}, std: {ratio_clean.std().item():.6f}')
# Check if odd/even columns have different ratios
ratio_even = ratio[0::2][torch.isfinite(ratio[0::2])]
ratio_odd = ratio[1::2][torch.isfinite(ratio[1::2])]
print(f'Even col ratio mean: {ratio_even.mean().item():.6f}')
print(f'Odd col ratio mean: {ratio_odd.mean().item():.6f}')
# Check if output matches a different V interpretation
# What if the V SMEM is being read in a different order?
# With P=all-ones, out[i, j] = sum(V[:, j]) for every row i,
# i.e. the column sums of V, broadcast to all rows
# If V were read with its K (row) order permuted, every column sum would be unchanged;
# if its N (column) order were permuted, the output columns would only be permuted
# So every output row should equal the V column sums, possibly in a different column order
# But the output IS uniform (all rows identical); the VALUES are wrong.
# So the sums themselves are wrong: the MMA is reading different data
# (V, or P as the MMA sees it), not merely a reordered V
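# --- Sketch (hypothetical helper, plain Python) of the permutation argument ---
# Reordering V along K (its rows) leaves the all-ones output unchanged, and
# reordering it along N (its columns) only permutes the output columns. Either
# way the set of output values is preserved, so values that do not appear among
# the V column sums cannot be explained by V being read in a different order.


def colsums(v: list) -> list:
    """Column sums of a list-of-lists matrix, i.e. one row of ones @ v."""
    return [sum(row[j] for row in v) for j in range(len(v[0]))]

# e.g. colsums(v[::-1]) matches colsums(v) up to floating-point summation order,
# and swapping two columns of v swaps the same two entries of colsums(v).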
# Let me check: does the output column structure match any known transform of V?
print(f'\nV (K) row 0[:8]: {kvf[0,:8].tolist()}')
print(f'V col sums (reference): {ref[0,:4].tolist()}')
print(f'Output col 0[:4]: {out[0,:4].tolist()}')
# What if V is being read transposed?
# sum(V[j,:]) per column j = row sums of V
row_sums = kvf.sum(dim=1)
col_sums = kvf.sum(dim=0)
print(f'V row sums[:4]: {row_sums[:4].tolist()}')
print(f'V col sums[:4]: {col_sums[:4].tolist()}')
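# --- Sketch (hypothetical helper, plain Python) for turning the prints into a verdict ---
# Compare an observed output row against candidate explanations, with a tolerance
# loose enough for BF16 output rounding (the rtol/atol defaults are assumptions).
def matching_candidates(observed: list, candidates: dict,
                        rtol: float = 1e-2, atol: float = 1e-1) -> list:
    """Names of candidate rows that agree with `observed` element-wise within tolerance."""
    hits = []
    for name, row in candidates.items():
        if len(row) == len(observed) and all(
                abs(o - r) <= atol + rtol * abs(r) for o, r in zip(observed, row)):
            hits.append(name)
    return hits

# Possible usage with the tensors above:
# print(matching_candidates(out[0].tolist(), {'V col sums': col_sums.tolist(),
#                                            'V row sums': row_sums.tolist()}))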