"""Notes from the TMEM P-aliasing experiments.

Key finding: the C-fragment and the A-fragment use different physical TMEM
address mappings. St32x32bOp with a C-fragment writes to C-layout addresses,
but the PV MMA reads from A-layout addresses. The forward FMHA recast was
validated for FP16 only, not BF16.

Working: FP32 ld/st round trip; BF16 elementwise; BF16 recast ld S0 -> st S1
(all with cosine similarity 0.999999).

Broken: C-fragment store + A-fragment read (NaN); A-fragment store + PV MMA
(cosine similarity -0.02).

Next: fix the register data flow (128 FP16/thread load vs. 64 BF16/thread
store mismatch).
"""
import torch
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
from cutlass import BFloat16, Float32
from cutlass.cute.nvgpu import OperandMajorMode, tcgen05
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset

# NOTE(review): every cute call in this script runs at module top level. CuTe
# DSL ops normally build IR inside a @cute.jit function; if this fails with a
# missing-context error, move the script body into one -- confirm.

# Operand element types shared by the QK and PV MMAs.
a_dtype = BFloat16
b_dtype = BFloat16

# Both MMAs are built with K-major operands.
# NOTE(review): the PV MMA's B operand is presumably V; if V is stored with
# head_dim contiguous it may be MN-major for this MMA. LayoutEnum (imported
# above, otherwise unused) could derive the mode from a real tensor -- confirm.
a_major = OperandMajorMode.K
b_major = OperandMajorMode.K

# (M, N) tile shared by both tiled MMAs.
mma_tiler_mn = (128, 128)

# QK MMA: A and B are read from shared memory; FP32 accumulator.
qk_mma = utils.sm100.make_trivial_tiled_mma(
    a_dtype,
    b_dtype,
    a_major,
    b_major,
    Float32,
    tcgen05.CtaGroup.ONE,
    mma_tiler_mn,
    tcgen05.OperandSource.SMEM,
)

# PV MMA: the A operand (P, K-major) is read from TMEM instead of SMEM, so it
# can be taken directly from TMEM columns; FP32 accumulator.
pv_mma = utils.sm100.make_trivial_tiled_mma(
    a_dtype,
    b_dtype,
    OperandMajorMode.K,
    b_major,
    Float32,
    tcgen05.CtaGroup.ONE,
    mma_tiler_mn,
    tcgen05.OperandSource.TMEM,
)

# Accumulator fragments for MMA slice 0 over the full (M, N) tile:
# tStS for the QK MMA, tOtO for the PV MMA (both FP32, per the MMA setup).
qk_thr = qk_mma.get_slice(0)
qk_acc_shape = qk_thr.partition_shape_C(mma_tiler_mn)
tStS = qk_thr.make_fragment_C(qk_acc_shape)

pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_thr.partition_shape_C(mma_tiler_mn)
tOtO = pv_thr.make_fragment_C(pv_acc_shape)

# K extent of a single instruction for each tiled MMA.
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])

# QK tile: K covers four instruction K-steps (not used further in this script).
qk_mma_tiler = (*mma_tiler_mn, qk_inst_k * 4)

# PV tile: the K mode of P is the N mode of the S tile that P aliases, so it
# must span the whole S row (mma_tiler_mn[1] columns). That is the row width
# the FP32 store view is sized for (tilePlikeFP32 = row * BF16 / FP32 words).
# A K of pv_inst_k * 4 (64 if the BF16 instruction K is 16) would describe
# only half of that row. mma_tiler_mn[1] must be a multiple of pv_inst_k.
pv_mma_tiler = (*mma_tiler_mn, mma_tiler_mn[1])

# P in TMEM: reuse the S accumulator's iterator with the (single-stage)
# A-operand layout that make_smem_layout_a produces for the PV MMA.
p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
tOrP_base = pv_thr.make_fragment_A(tP)
# Select index 0 of the last mode (presumably the stage mode; 1 stage above).
tOrP = tOrP_base[(None, None, None, 0)]

# Shift the BF16 P view by 32 FP32 columns. The offset is scaled by
# Float32.width // BFloat16.width (= 2), presumably because this iterator
# counts BF16 elements -- confirm. It must line up with the `+ 32` FP32
# offset used for the FP32 store view of P.
tOrP0 = cute.make_tensor(
    tOrP.iterator + Float32.width // BFloat16.width * 32,
    tOrP.layout,
)

def _print_layout_and_size(name, tensor):
    """Print ``<name>.layout`` and ``<name>.size`` (total element count)."""
    print(f"{name}.layout:", tensor.layout)
    print(f"{name}.size:", cute.size(tensor))


print('=== Layout diagnostics ===')

# QK accumulator. find_tmem_tensor_col_offset presumably reports how many
# TMEM columns the tensor occupies -- confirm against its documentation.
_print_layout_and_size('tStS', tStS)
print('tStS s_cols:', find_tmem_tensor_col_offset(tStS))
print()

# PV accumulator.
_print_layout_and_size('tOtO', tOtO)
print('tOtO o_cols:', find_tmem_tensor_col_offset(tOtO))
print()

# P as the PV MMA's A operand: base view and the column-shifted view.
_print_layout_and_size('tOrP', tOrP)
_print_layout_and_size('tOrP0', tOrP0)
print()

# FP32 view of the P region of S: a row of mma_tiler_mn[1] BF16 values packs
# into mma_tiler_mn[1] * BFloat16.width // Float32.width FP32 words
# (64 for a 128-wide tile). The view keeps all mma_tiler_mn[0] rows.
tilePlikeFP32 = mma_tiler_mn[1] * BFloat16.width // Float32.width
tStS_P_layout = cute.composition(
    tStS.layout, cute.make_layout((mma_tiler_mn[0], tilePlikeFP32))
)
print('tStS_P_layout:', tStS_P_layout)
print()

# LOAD: TMEM -> registers over the full S accumulator, using the 32x32b
# shape with 32 repetitions and FP32 elements. An identity (coordinate)
# tensor is partitioned the same way to show what slice 0 of the copy gets.
tmem_load_atom = cute.make_copy_atom(
    tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32
)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS)
thr_load = tiled_tmem_load.get_slice(0)

cS = cute.make_identity_tensor(mma_tiler_mn)
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
print('LOAD tTMEM_LOADcS.shape:', tTMEM_LOADcS.shape)
print('LOAD per-thread elements:', cute.size(tTMEM_LOADcS))

# STORE: registers -> TMEM through the FP32 view of P. The view starts 32
# FP32 elements (presumably TMEM columns) into S, the same region the BF16
# tOrP0 view is shifted to. The intent is that FP32 words written here are
# read back by the PV MMA as packed BF16 pairs.
tStS_P = cute.make_tensor(tStS.iterator + 32, tStS_P_layout)
tmem_store_atom = cute.make_copy_atom(
    tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), Float32
)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P)
thr_store = tiled_tmem_store.get_slice(0)

# NOTE(review): these coordinates come straight from tStS_P.shape, not from
# qk_thr.partition_C as on the LOAD side. Only the shapes/counts are
# comparable with the LOAD printout, not the coordinate values.
tTMEM_STOREcS = thr_store.partition_S(cute.make_identity_tensor(tStS_P.shape))
print('STORE tTMEM_STOREcS.shape:', tTMEM_STOREcS.shape)
print('STORE per-thread elements:', cute.size(tTMEM_STOREcS))

# What about the tOrP0 shape for store?
print()
print('tOrP0.shape:', getattr(tOrP0, 'shape', 'N/A'))

# Open problem: tOrP0 is BF16, so storing into it directly would need a BF16
# TMEM store atom, but cute.copy requires equal bit widths on both sides, and
# an FP32 store to a BF16 target does not work either.
# NOTE(review): CUTLASS's Blackwell FMHA example appears to sidestep this:
# - It stores FP32 words through the FP32 view (tStS_P above).
# - The register source is an FP32 fragment whose storage is recast to BF16
#   (cute.recast_ptr), so each 32-bit word carries two BF16 values.
# - 64 FP32 words per row then hold the 128 BF16 values of a loaded row.
# - The PV MMA reads the same TMEM columns through tOrP0.
# Verify against that example before relying on it.
|