Files
nvfp4-megamoe-kernel/tests/diag_layouts.py
biondizzle 467ade37b2 Stage B: C-fragment vs A-fragment TMEM layout mismatch diagnosed
Key finding: C-fragment and A-fragment use different physical TMEM address
mappings. St32x32bOp with C-fragment writes to C-layout addresses, but PV MMA
reads from A-layout addresses. Forward FMHA recast validated FP16 only, not BF16.

Working: FP32 ld/st roundtrip, BF16 elemwise, BF16 recast ld S0->st S1 (all cos 0.999999)
Broken: C-frag st + A-frag read (NaN), A-frag store + PV MMA (cos -0.02)
Next: Fix register data flow (128 FP16/thread load vs 64 BF16/thread store mismatch)
2026-05-21 00:12:47 +00:00

86 lines
3.3 KiB
Python

import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils
from cutlass.cute.nvgpu import tcgen05
from cutlass import Float32, BFloat16
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
# --- Setup: the two tcgen05 GEMMs of an attention forward pass ---
# QK MMA produces S (FP32 accumulator); PV MMA consumes P as its A operand
# and produces O. Both use BF16 inputs and a single-CTA 128x128 tile.
a_dtype = BFloat16; b_dtype = BFloat16
from cutlass.cute.nvgpu import OperandMajorMode
a_major = OperandMajorMode.K
b_major = OperandMajorMode.K
# (M, N) tile shared by both MMAs.
mma_tiler_mn = (128, 128)
# QK MMA: A operand sourced from SMEM, FP32 accumulator.
qk_mma = utils.sm100.make_trivial_tiled_mma(
a_dtype, b_dtype, a_major, b_major, Float32, tcgen05.CtaGroup.ONE, mma_tiler_mn, tcgen05.OperandSource.SMEM)
# PV MMA: same configuration except the A operand (P) is sourced from TMEM.
# Its A major mode is spelled as a literal OperandMajorMode.K rather than
# a_major (same value here).
pv_mma = utils.sm100.make_trivial_tiled_mma(
a_dtype, b_dtype, OperandMajorMode.K, b_major, Float32, tcgen05.CtaGroup.ONE, mma_tiler_mn, tcgen05.OperandSource.TMEM)
# Thread 0's view of each MMA, used only to build fragments for printing.
qk_thr = qk_mma.get_slice(0)
qk_acc_shape = qk_thr.partition_shape_C(mma_tiler_mn)
# tStS: the S accumulator as a C-fragment (TMEM-resident; its column offset
# is reported via find_tmem_tensor_col_offset below).
tStS = qk_thr.make_fragment_C(qk_acc_shape)
pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_thr.partition_shape_C(mma_tiler_mn)
# tOtO: the O accumulator as a C-fragment.
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
# K extent of one MMA instruction (mode 2 of shape_mnk); the (M, N, K)
# tilers below span four instructions along K.
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
qk_mma_tiler = (*mma_tiler_mn, qk_inst_k * 4)  # NOTE: not referenced later in this script.
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
pv_mma_tiler = (*mma_tiler_mn, pv_inst_k * 4)
# A-operand (P) layout for the PV MMA in BF16. The trailing 1 is presumably
# the stage count -- confirm against make_smem_layout_a's signature.
p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)
# Alias S's storage (tStS.iterator) with the A-operand layout: P is read
# from the same TMEM region as S, but through A-fragment coordinates.
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
tOrP_base = pv_thr.make_fragment_A(tP)
# Keep only index 0 of the trailing mode (stage 0, if the layout is staged).
tOrP = tOrP_base[(None, None, None, 0)]
# P0 view: shifted by Float32.width // BFloat16.width * 32 = 64 elements.
# Presumably the iterator advances in BF16 elements, so this equals the
# 32-FP32-column offset the FP32 store view below applies to tStS -- verify.
tOrP0 = cute.make_tensor(
tOrP.iterator + Float32.width // BFloat16.width * 32,
tOrP.layout)
# --- Layout diagnostics: print the C-fragment (S, O) and A-fragment (P)
# layouts, sizes and TMEM column offsets side by side for comparison. ---
print('=== Layout diagnostics ===')
print('tStS.layout:', tStS.layout)
print('tStS.size:', cute.size(tStS))
print('tStS s_cols:', find_tmem_tensor_col_offset(tStS))
print()
print('tOtO.layout:', tOtO.layout)
print('tOtO.size:', cute.size(tOtO))
print('tOtO o_cols:', find_tmem_tensor_col_offset(tOtO))
print()
print('tOrP.layout:', tOrP.layout)
print('tOrP.size:', cute.size(tOrP))
print('tOrP0.layout:', tOrP0.layout)
print('tOrP0.size:', cute.size(tOrP0))
print()
# Number of 32-bit columns one row of the 16-bit P tile occupies when viewed
# as FP32: N BF16 values pack into N * 16 / 32 FP32 slots (128 -> 64).
# Derived from the tile and dtype widths instead of hard-coding 128 * 16 // 32,
# so it tracks mma_tiler_mn if the tile is changed.
tilePlikeFP32 = mma_tiler_mn[1] * BFloat16.width // Float32.width
# Restrict S's C-fragment layout to the (M, tilePlikeFP32) sub-tile that
# holds P when P is written back as packed FP32 words.
tStS_P_layout = cute.composition(
    tStS.layout, cute.make_layout((mma_tiler_mn[0], tilePlikeFP32)))
print('tStS_P_layout:', tStS_P_layout)
print()
# LOAD: 32x32b TMEM load atom (32 repetitions, FP32) over the S accumulator.
# Thread 0's share is reported through a coordinate (identity) tensor.
tmem_load_atom = cute.make_copy_atom(
    tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS)
thr_load = tiled_tmem_load.get_slice(0)
# Coordinates for the full (M, N) tile; taken from mma_tiler_mn rather than a
# literal (128, 128) so they always match the tile tStS was built for.
cS = cute.make_identity_tensor(mma_tiler_mn)
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
print('LOAD tTMEM_LOADcS.shape:', tTMEM_LOADcS.shape)
print('LOAD per-thread elements:', cute.size(tTMEM_LOADcS))
# STORE (composition): FP32 view of the P region, 32 columns past the start
# of S, using the (M, tilePlikeFP32) sub-layout of S's C-fragment.
tStS_P = cute.make_tensor(tStS.iterator + 32, tStS_P_layout)
tmem_store_atom = cute.make_copy_atom(
    tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), Float32)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P)
thr_store = tiled_tmem_store.get_slice(0)
tTMEM_STOREcS = thr_store.partition_S(cute.make_identity_tensor(tStS_P.shape))
print('STORE tTMEM_STOREcS.shape:', tTMEM_STOREcS.shape)
print('STORE per-thread elements:', cute.size(tTMEM_STOREcS))
# Compare against the shape of the BF16 A-fragment view (tOrP0) the PV MMA
# reads P through.
print()
print('tOrP0.shape:', tOrP0.shape)
# Open problem: tOrP0 is BF16, so storing into it directly would need a
# 16-bit store atom, but cute.copy requires source and destination element
# widths to match; an FP32 store into the BF16 target does not work either.
# NOTE(review): the store above therefore targets the FP32 C-fragment view
# (tStS_P), whose physical TMEM addressing differs from the A-fragment view
# the PV MMA reads -- this is the mismatch the script is diagnosing.