The tma_partition output has 8 TMA coordinate dimensions, not 4.
The Python-visible shape shows 4 modes, but the TMA descriptor uses
8 coordinates. Without a no-op pre-slice that passes None for all 8
coordinates, modes 4-7 are collapsed and the GMEM tile axis (mode 4)
is pinned to 0.

Pattern that works (confirmed on B200 at n=256 in the diag test):

    tBgK = tBgK[(None,None,None,None,None,None,None,None)]  # open 8D
    cute.copy(tma_k, tBgK[None,None,None,None,kt,None,None,None], ...)

The old 4-mode indexing, tBgK[(None,None,kt,0)], fails with
'rank mismatch: got 2 and 1' because slicing the 4-mode tensor
produces the wrong rank for the TMA coordinate space.
This matches the working diag test, test_fmha_v3_diag.py, exactly.
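
The rank arithmetic behind the two indexing forms can be sketched in plain Python. This is a toy model only: it assumes the usual convention that None keeps a mode and an integer fixes one, and it calls no CuTe or TMA API. The helper sliced_rank and the value of kt are hypothetical, introduced just for illustration.

```python
def sliced_rank(coord):
    """Toy model: each None keeps a mode, each integer fixes (removes) one."""
    return sum(1 for c in coord if c is None)


kt = 3  # hypothetical K-tile index, chosen only for this sketch

# 8-coordinate form from the working pattern: only mode 4 is fixed.
open_8d = (None, None, None, None, kt, None, None, None)

# Old 4-mode form: modes 2 and 3 are both fixed.
old_4d = (None, None, kt, 0)

rank_8d = sliced_rank(open_8d)
rank_4d = sliced_rank(old_4d)
print(rank_8d, rank_4d)  # prints: 7 2
```

In this model the old form leaves 2 free modes, which happens to equal the "got 2" in the error message. The sketch does not establish that as the cause, and it says nothing about how the TMA descriptor itself counts coordinates.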