Key finding: the C-fragment and the A-fragment use different physical TMEM address mappings. St32x32bOp with the C-fragment writes to C-layout addresses, but the PV MMA reads from A-layout addresses. The forward FMHA recast has been validated for FP16 only, not BF16.
Working (all cosine similarity 0.999999): FP32 ld/st round trip; BF16 elementwise; BF16 recast ld S0 -> st S1.
Broken: C-fragment st + A-fragment read (NaN); A-fragment store + PV MMA (cosine -0.02).
Next: fix the register data flow (128 FP16 values/thread loaded vs. 64 BF16 values/thread stored).
169 lines · 8.2 KiB · Python
"""
|
|
Diagnostic: understand the C-layout vs A-layout TMEM address mapping.
|
|
|
|
After Stage A writes Q@K^T results to TMEM via C-fragment (tStS0),
|
|
read the same TMEM data back using the A-fragment (tOrP0) layout
|
|
and compare against reading it via the C-fragment layout.
|
|
|
|
This will tell us whether the composition-based store target
|
|
produces the right physical TMEM addresses for the A-fragment to read.
|
|
"""
|
|
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
|
from cutlass.cute.nvgpu import cpasync, tcgen05
|
|
from cutlass import Float32, BFloat16, Int32, const_expr
|
|
from cutlass.utils import LayoutEnum
|
|
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
|
import cuda.bindings.driver as cuda
|
|
|
|
def test():
    """Build the Q@K^T C-fragment and the P@V A-fragment and print their layouts.

    No kernel is launched. The function only constructs CuTe MMA atoms,
    fragments and layouts, then prints them so that the two TMEM views of the
    P tile can be compared:

        tStS    -- C-fragment of the QK MMA (FP32 accumulator in TMEM).
        tOrP    -- A-fragment of the PV MMA (BF16 P operand sourced from TMEM),
                   moved to the same TMEM column offset as tStS_P.
        tStS_P  -- tStS composed down to the FP32 columns that hold one row of
                   P packed as BF16 (the store target for P).
        tP      -- raw tensor wrapping the PV A-operand layout at S's base.
    """
    import cutlass.torch as ct

    torch.manual_seed(42)
    m, n, k = 128, 128, 128
    # Only the layouts / major modes of these tensors are inspected; their
    # values are never read.
    q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
    kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')

    mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
    mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))

    a_major = LayoutEnum.from_tensor(mQ).mma_major_mode()
    b_major = LayoutEnum.from_tensor(mK).mma_major_mode()

    mma_tiler_mn = (128, 128)

    qk_mma = utils.sm100.make_trivial_tiled_mma(
        BFloat16, BFloat16, a_major, b_major,
        Float32, tcgen05.CtaGroup.ONE, mma_tiler_mn, tcgen05.OperandSource.SMEM)
    # P is consumed from TMEM as the K-major A operand of the PV MMA.
    # NOTE(review): b_major is Q@K^T's B mode (K, from mK). For P@V the B
    # operand is V, whose major mode may differ -- confirm before launching a
    # real PV MMA. It does not affect the A-fragment layouts printed here.
    pv_mma = utils.sm100.make_trivial_tiled_mma(
        BFloat16, BFloat16, cute.nvgpu.OperandMajorMode.K, b_major,
        Float32, tcgen05.CtaGroup.ONE, mma_tiler_mn, tcgen05.OperandSource.TMEM)

    qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
    qk_mma_tiler = (*mma_tiler_mn, qk_inst_k * 4)
    # The K mode of P@V is the key dimension, i.e. the N mode of Q@K^T
    # (CUTLASS fmha.py builds pv_mma_tiler as (M, head_dim, N_qk)).
    # The previous value, qk_inst_k * 4, sized the A-fragment for only half of
    # the qk_mma_tiler[1] BF16 P values per row that are stored through tStS_P.
    pv_mma_tiler = (*mma_tiler_mn, qk_mma_tiler[1])

    # Build the fragments.
    qk_thr = qk_mma.get_slice(0)
    qk_acc_shape = qk_thr.partition_shape_C(mma_tiler_mn)
    tStS = qk_thr.make_fragment_C(qk_acc_shape)

    # FP32 TMEM column (relative to S) where the P tile is stored.
    tmem_p_offset = 32

    pv_thr = pv_mma.get_slice(0)
    p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)
    tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
    tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0]
    # Point the A-fragment at the same TMEM columns that tStS_P writes.
    # The offset is scaled by FP32/BF16 width (= 2), mirroring how CUTLASS
    # fmha.py offsets tOrP0/tOrP1. Presumably the A-view iterator counts BF16
    # elements while the C-view counts FP32 columns -- confirm against the
    # printed layouts.
    tOrP = cute.make_tensor(
        tOrP.iterator + Float32.width // BFloat16.width * tmem_p_offset,
        tOrP.layout)

    # FP32 TMEM columns needed to hold one row of P packed as BF16
    # (two BF16 values per 32-bit column).
    tilePlikeFP32 = qk_mma_tiler[1] * BFloat16.width // Float32.width
    tStS_P = cute.make_tensor(
        tStS.iterator + tmem_p_offset,
        cute.composition(tStS.layout,
                         cute.make_layout((mma_tiler_mn[0], tilePlikeFP32))))

    # Print the key layouts.
    print(f'tStS.layout: {tStS.layout}')
    print(f'tStS.size: {cute.size(tStS)}')
    print(f'tOrP.layout: {tOrP.layout}')
    print(f'tOrP.size: {cute.size(tOrP)}')
    print(f'tStS_P.layout: {tStS_P.layout}')
    print(f'tP.layout: {tP.layout}')
    print(f'tilePlikeFP32: {tilePlikeFP32}')

    # Notes on the layouts.
    # Values recorded from an earlier run that used the 64-wide PV K tile;
    # re-check the tOrP line after the pv_mma_tiler fix.
    #   tStS = ((128,128),1,1):((65536,1),0,0)
    #   tOrP = ((128,16),1,4):((64,1),0,16)
    #
    # TMEM addresses are 32-bit, with the lane (datapath row) in bits [31:16]
    # and the column in bits [15:0] (PTX ISA, Tensor Memory addressing).
    #   - The C-fragment's M stride 65536 == 1 << 16 advances one lane.
    #   - Its N stride 1 advances one 32-bit column.
    #   - Its large cosize (65536*127 + 127 + 1 = 8,323,200) comes from this
    #     encoding and is not a sign of a sparse layout.
    # TMEM holds 128 lanes x 512 columns of 32-bit cells (256 KiB), not 1 MB.
    #
    # tStS_P = composition(tStS.layout, (128, tilePlikeFP32)) keeps the first
    # tilePlikeFP32 FP32 columns of every lane, i.e. qk_mma_tiler[1] BF16
    # values per row packed two per column. The A-fragment's M stride (64) is
    # not a lane stride in that encoding.
    # NOTE(review): it is presumably interpreted by the MMA operand descriptor
    # rather than as a raw lane/column address. If so, the two layouts cannot
    # be compared element-by-element. What must agree is:
    #   (1) the base column (tmem_p_offset, scaled for the A view), and
    #   (2) the extent (tilePlikeFP32 FP32 columns == the PV MMA's K in BF16).
    #
    # CUTLASS fmha.py uses the same composition pattern (validated there for
    # F16). Whether BF16 changes the mapping is still open.

    print('\nDiagnostic complete. Need to check if BF16 vs F16 affects the TMEM layout mapping.')
|
|
|
|
if __name__ == '__main__':
    # Script entry point: run the layout diagnostic only when executed
    # directly, never on import.
    test()
|