The MMA loop (cutlass.range) and MMA consumer loop (range) also used
cute.size(gK, mode=[3]) which returns 1 for all n. Fixed all 3 loops:
1. TMA load loop (cutlass.range, line 215)
2. MMA consumer loop (range, line 231)
3. Softmax loop (range, line 324)
This was causing the deadlock — MMA only produced S[0] while softmax
waited for S[1].
cute.size(gK, mode=[3]) returns 1 for ALL n values — mode 3 is batch,
not KV tiles. self.n_kv_tiles = s_k // 128 is the correct Python int.
This is why softmax only processed kt=0 for all n.
Setup the correction_rescale atoms BEFORE the softmax loop so they can be
shared between per-tile O rescale and final normalize. Uses the working
2D register tensor pattern for final normalize. O rescale uses simple
1D rmem tensor per sub-tile (same as example10).
Previous O rescale attempt broke n=128 (0.464773).
Revert to known-good softmax code, only apply TMA fix:
tBgK[(None,None,0,0)] → tBgK[(None,0,None,0)]
Expected: n=128 cos 0.999998 (same as working), n=256 cos 0.71 (TMA fix loads 2 tiles but no O rescale)
The make_rmem_tensor(tTMEM_LOADcO.shape) creates a 1D tensor that doesn't
match the paired atom layout. The working pattern uses a 2D register tensor
with sub-tile composition (tTMrO_i_ = tTMrO[None, i] + composition).
- Moves correction_rescale atom setup before softmax loop (needed for O rescale)
- Adds O *= acc_scale for kt > 0, before softmax_done_bar.arrive()
- Uses same paired Ld32x32bOp/St32x32bOp(corr_tile_size=16) atoms as final normalize
- Final normalize (O *= 1/row_sum) uses same atoms, no duplicate setup
- Fixes softmax loop to use self.n_kv_tiles (Python int) not n_kv_tiles (CuTeDSL symbolic)
- This should fix n=256 cos 0.71 → 0.9999
The tma_partition output has 8 TMA coordinate dimensions, not 4.
The Python-visible shape shows 4 modes, but the TMA descriptor uses
8 coordinates. Without the 8-None no-op pre-slice, modes 4-7 are
collapsed and the GMEM tile axis (mode 4) is pinned to 0.
Pattern that works (confirmed on B200 at n=256 in diag test):
tBgK = tBgK[(None,None,None,None,None,None,None,None)] # open 8D
cute.copy(tma_k, tBgK[None,None,None,None,kt,None,None,None], ...)
The old 4-mode indexing tBgK[(None,None,kt,0)] fails with
'rank mismatch: got 2 and 1' because slicing a 4-mode tensor
produces wrong rank for the TMA coordinate space.
Matches working diag test test_fmha_v3_diag.py exactly.
The 8-mode indexing (tBgK[None,None,None,None,kt,None,None,None]) fails at
JIT compilation with 'coord and shape are weakly congruent' error. The actual
MLIR tensor shape is (((64,128),1),?,?,?) — 4 modes, not 8.
The working fix from commit 845ad98 on the B200 used 4-mode indexing all along:
tBgK[(None, None, kt, 0)] — mode 2 = GMEM tile dim
tVgV[(None, 0, kt, 0)] — mode 2 = GMEM tile dim
Updated all files: example10, test_fmha_v3_stage_c, README, docstrings.
K from QK MMA B-partition has GMEM iter at mode 1, NOT mode 2.
(None,0,None,0) hardcodes mode 1 to 0 → TMA always loads tile 0.
(None,None,0,0) keeps mode 1 free → correct multi-tile loading.
Proof: diag n=256 went from cos 0.711 → 0.999999 with this one change.
cute.size() returns a CuTeDSL symbol, not a Python int.
range() on a symbol can't iterate — the loop never unrolls.
Now n_kv_tiles is computed in __init__ as s_k // 128 (Python int).
cutlass.range traces once - kv_coord/kt are trace-time values,
not runtime loop-carried state. Python range() fully unrolls at
trace time, emitting distinct Int32(k) constants per iteration.
Int32(1) hardcoded already proved TMA CAN load from tile 1.
Key findings to relay to CUTLASS LLM:
- kv_coord=Int32(1) hardcode CHANGES the output (TMA CAN load from different tiles)
- kv_coord=Int32(0) + kv_coord += 1 does NOT increment at runtime
(all multi-tile outputs identical to kv_coord=0)
- kv_coord=0 (plain Python int) also doesn't work
- Pipeline handle .count doesn't work either
- The TMA GMEM tile coordinate must be dynamic at kernel runtime,
but CuTeDSL appears to constant-fold or not propagate the increment