Files
nvfp4-megamoe-kernel/tests/test_tmem_layout_diag.py
biondizzle 1eca10e39f Stage B: C-fragment vs A-fragment TMEM layout mismatch diagnosed
Key finding: C-fragment and A-fragment use different physical TMEM address
mappings. St32x32bOp with C-fragment writes to C-layout addresses, but PV MMA
reads from A-layout addresses. Forward FMHA recast validated FP16 only, not BF16.

Working: FP32 ld/st roundtrip, BF16 elemwise, BF16 recast ld S0->st S1 (all cos 0.999999)
Broken: C-frag st + A-frag read (NaN), A-frag store + PV MMA (cos -0.02)
Next: Fix register data flow (128 FP16/thread load vs 64 BF16/thread store mismatch)
2026-05-21 00:12:47 +00:00

169 lines
8.2 KiB
Python

"""
Diagnostic: understand the C-layout vs A-layout TMEM address mapping.
The intended experiment: after Stage A writes Q@K^T results to TMEM via the
C-fragment (tStS0), read the same TMEM data back using the A-fragment (tOrP0)
layout and compare against reading it via the C-fragment layout.
That comparison would show whether the composition-based store target
produces the right physical TMEM addresses for the A-fragment to read.
As written, this script does not launch a kernel: it only builds the
fragments and prints their layouts for manual comparison.
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
def test():
    """Build the QK / PV TMEM fragments for a 128x128 tile and print their layouts.

    Nothing is launched on the device and nothing is asserted.  The function
    constructs
      * the accumulator (C) fragment of the Q@K^T MMA (``tStS``),
      * the A-operand fragment of the P@V MMA (``tOrP``), and
      * the composed C-fragment view used as the store target for P
        (``tStS_P``),
    then prints each layout so they can be compared by hand.  The comment
    block after the prints records the working analysis of those layouts.

    NOTE(review): the ``cute`` calls below run on the host, outside any
    ``@cute.jit`` function -- confirm the DSL permits that.
    """
    import cutlass.torch as ct

    torch.manual_seed(42)
    m, n, k = 128, 128, 128
    q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
    kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')

    # The tensors are only used to derive the operand major modes.
    mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
    mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
    a_major = LayoutEnum.from_tensor(mQ).mma_major_mode()
    b_major = LayoutEnum.from_tensor(mK).mma_major_mode()

    mma_tiler_mn = (128, 128)
    # Q@K^T reads both operands from SMEM; P@V reads its A operand (P) from TMEM.
    qk_mma = utils.sm100.make_trivial_tiled_mma(
        BFloat16, BFloat16, a_major, b_major,
        Float32, tcgen05.CtaGroup.ONE, mma_tiler_mn, tcgen05.OperandSource.SMEM)
    pv_mma = utils.sm100.make_trivial_tiled_mma(
        BFloat16, BFloat16, cute.nvgpu.OperandMajorMode.K, b_major,
        Float32, tcgen05.CtaGroup.ONE, mma_tiler_mn, tcgen05.OperandSource.TMEM)
    qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
    qk_mma_tiler = (*mma_tiler_mn, qk_inst_k * 4)
    pv_mma_tiler = (*mma_tiler_mn, qk_inst_k * 4)

    # Build the fragments
    qk_thr = qk_mma.get_slice(0)
    qk_acc_shape = qk_thr.partition_shape_C(mma_tiler_mn)
    tStS = qk_thr.make_fragment_C(qk_acc_shape)

    pv_thr = pv_mma.get_slice(0)
    p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)
    tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
    tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0]

    # Width of the P tile when its BF16 elements are viewed as 32-bit words.
    tilePlikeFP32 = qk_mma_tiler[1] * BFloat16.width // 32
    # Offset of the P store target relative to the start of S.
    # Presumably mirrors fmha.py's tmem_p0_offset -- TODO confirm.
    p_col_offset = 32
    # NOTE(review): tOrP above is built at tStS's base iterator, while tStS_P
    # is offset by p_col_offset.  The printed layouts are unaffected, but any
    # address comparison between the two must account for that offset.
    tStS_P = cute.make_tensor(
        tStS.iterator + p_col_offset,
        cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32))))

    # Print the key layouts
    print(f'tStS.layout: {tStS.layout}')
    print(f'tStS.size: {cute.size(tStS)}')
    print(f'tOrP.layout: {tOrP.layout}')
    print(f'tOrP.size: {cute.size(tOrP)}')
    print(f'tStS_P.layout: {tStS_P.layout}')
    print(f'tP.layout: {tP.layout}')
    print(f'tilePlikeFP32: {tilePlikeFP32}')

    # ------------------------------------------------------------------
    # Working notes (hypotheses recorded while debugging, not verified).
    # ------------------------------------------------------------------
    # Observed layouts:
    #   C-fragment: ((128,128),1,1):((65536,1),0,0)
    #       logical (m, col)             -> addr = 65536*m + col
    #   A-fragment: ((128,16),1,4):((64,1),0,16)
    #       logical ((m, k16), 1, blk)   -> addr = 64*m + k16 + 16*blk
    #
    # The KEY question: for the same physical TMEM column c and row m,
    # what does the C-layout say the logical index is, vs the A-layout?
    # If both refer to the same physical TMEM, then
    #       65536*m_c + col = 64*m_a + k16 + 16*blk
    # which can only hold if the address spaces differ (C uses a different
    # base) OR if the C-layout addresses wrap modulo the TMEM size.
    #
    # The C-fragment cosize is 8323200, far larger than the 16384 data
    # elements: the layout is strided with an M stride of 65536, i.e. it is
    # NOT dense in address space.  For M=128, N=128 the max address is
    #       65536*127 + 127 = 8,323,199
    #
    # NOTE(review): an earlier version of these notes assumed TMEM is
    # "1MB = 256 columns * 4096 rows" and concluded that the C-fragment
    # layout must be virtual.  The PTX tcgen05 documentation describes
    # tensor memory as 128 lanes x 512 columns of 32-bit cells, addressed as
    # (lane << 16) | column.  Since 65536 == 1 << 16, the M stride presumably
    # IS the lane field of a physical TMEM address -- confirm before relying
    # on the "virtual layout" reasoning that follows.
    #
    # Hypothesis at the time: the C-fragment layout is VIRTUAL.  cute.compile
    # maps it to physical TMEM addresses when the MMA runs; the strides in the
    # fragment layout are logical, and the MMA hardware knows the physical
    # column mapping.
    #
    # For the store, make_tmem_copy(St32x32bOp, tStS_P) creates a copy plan
    # that writes to tStS_P's layout addresses.  These ARE the addresses the
    # MMA hardware will read for the A-fragment, so the question becomes:
    # is the COMPOSITION tStS_P correct?
    #   tStS.layout = ((128,128),1,1):((65536,1),0,0)
    #   composition with (128, 64) produces (128,64):(65536,1),
    #   i.e. the first 64 columns of the C-layout.
    #   tOrP.layout = ((128,16),1,4):((64,1),0,16)
    #   i.e. the A-layout for 4 K=16 blocks.
    #
    # Addresses generated for the same (m, k), m in [0..127], k in [0..63]:
    #   C-layout: logical (m, k)                -> addr = 65536*m + k
    #   A-layout: k = 16*blk + k_inner
    #             logical ((m, k_inner), 1, blk) -> addr = 64*m + k_inner + 16*blk
    #                                                    = 64*m + k
    # These are DIFFERENT for m > 0.
    #
    # Open questions about the A-layout stride of 64:
    #   * TMEM is organized as columns of 128 rows of 32 bits each.
    #   * FP32 accumulator: each column holds 128 FP32 values (one per DP lane).
    #   * BF16 view: is each FP32 column 2 packed BF16 values, or separate
    #     columns?  128 BF16 values need 64 FP32 columns (2 BF16 per FP32),
    #     which would match the stride of 64 for m in the A-layout.
    #   * In the (128,16) sub-shape, stride (64,1) means m has stride 64 and
    #     k_inner has stride 1.
    #
    # Open questions about the C-layout stride of 65536:
    #   * 65536 = 512 * 128 = 2^16.
    #   * Elements in the M direction are 65536 addresses apart, yet there are
    #     only 128*128 = 16384 elements; cosize = 8323200 >> 16384, so the
    #     layout is very sparse.
    #   * Working belief: the copy atom's partition_S / partition_D functions
    #     create the right physical mappings between logical and physical.
    #
    # The real question for Stage B is simpler:
    #   After Q@K^T writes to tStS0 (C-fragment), and we ld the scores, apply
    #   identity softmax, and st them to tStS_P (composed C-fragment), does
    #   the PV MMA read the same data from tOrP0 (A-fragment)?
    #   fmha.py proves this works for F16 and the composition pattern is
    #   identical, so either (1) something differs between BF16 and F16, or
    #   (2) something else in this code is wrong.
    print('\nDiagnostic complete. Need to check if BF16 vs F16 affects the TMEM layout mapping.')
# Script entry point: executing this file directly runs the layout
# diagnostic once; importing it does nothing beyond defining test().
if __name__ == '__main__':
    test()