cute.size() returns a CuTeDSL symbol, not a Python int.
range() on a symbol can't iterate, so the loop never unrolls.
Now n_kv_tiles is computed in __init__ as s_k // 128 (Python int).
cutlass.range traces the loop body once, so kv_coord/kt are trace-time
values, not runtime loop-carried state. Python range() fully unrolls at
trace time, emitting a distinct Int32(k) constant per iteration.
Hardcoding Int32(1) already proved that TMA CAN load from tile 1.
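
Toy model of the difference described above, in plain Python. This is NOT CuTeDSL and does not call any CUTLASS API; the function names and the "tma_load" strings are hypothetical, and the traced-once behavior is modeled on the diagnosis in these notes rather than on verified DSL semantics.

```python
def trace_with_python_range(n_kv_tiles):
    """Python range() runs at trace time, so every iteration emits its own constant."""
    emitted = []
    for k in range(n_kv_tiles):
        emitted.append(f"tma_load(tile={k})")
    return emitted


def trace_once(n_kv_tiles):
    """A body traced once only ever sees the trace-time value of the coordinate."""
    kv_coord = 0                          # trace-time value, baked in as a constant
    body = f"tma_load(tile={kv_coord})"   # body is recorded here, exactly once
    kv_coord += 1                         # too late: the recorded body is unchanged
    return [body] * n_kv_tiles            # the same body replays on every iteration


unrolled = trace_with_python_range(4)
traced = trace_once(4)
print(unrolled)  # one distinct tile index per iteration
print(traced)    # tile 0 repeated
```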
Key findings to relay to CUTLASS LLM:
- kv_coord=Int32(1) hardcode CHANGES the output (TMA CAN load from different tiles)
- kv_coord=Int32(0) + kv_coord += 1 does NOT increment at runtime
(all multi-tile outputs identical to kv_coord=0)
- kv_coord=0 (plain Python int) also doesn't work
- Pipeline handle .count doesn't work either
- The TMA GMEM tile coordinate must be dynamic at kernel runtime,
but CuTeDSL appears to constant-fold or not propagate the increment
Both GMEM and SMEM sides must be sliced to the same rank for cute.copy.
K (QK MMA B-partition): slice [(None,None,0,0)] keeps modes 0,1
- mode 1 = GMEM iteration, indexed by kvh.count
V (PV MMA B-partition): slice [(None,0,None,0)] keeps modes 0,2
- mode 2 = GMEM iteration, indexed by kvh.count
Q: only 1 tile, (None,0,None,0) hardcode is fine.
Root cause of multi-tile failure: (None,0,None,0) slice hardcodes the
GMEM tile dimension to 0, so TMA always loads from tile 0 regardless
of kvh.count. K from QK MMA has GMEM iter at mode 1, V from PV MMA
has it at mode 2 (different layouts: K,D,L vs D,K,L).
Fix follows CUTLASS FMHA reference:
- K: tBgK[(None,None,None,0)] + tBgK[(None, kvh.count, None)]
- V: tVgV[(None,0,None,0)] + tVgV[(None, kvh.count)]
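
Sketch of the root cause with a nested list standing in for the partitioned GMEM tensor. The data and helper names are hypothetical; real cute tensors and slicing are not modeled here, only the effect of fixing the tile mode at slice time versus indexing it per iteration.

```python
# Hypothetical stand-in: one entry per KV tile along the GMEM iteration mode.
gK_tiles = [[10, 11], [20, 21], [30, 31]]


def load_hardcoded(tiles, count):
    # Tile mode fixed to 0 when slicing; count is never consulted.
    sliced = tiles[0]
    return sliced


def load_indexed(tiles, count):
    # Tile mode kept open (None in the slice), then indexed by the pipeline count.
    kept = tiles
    return kept[count]


hardcoded = [load_hardcoded(gK_tiles, c) for c in range(3)]
indexed = [load_indexed(gK_tiles, c) for c in range(3)]
print(hardcoded)  # tile 0 every time
print(indexed)    # a different tile per count
```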
row_max is already in log2(scaled) space (S * scale_log2), so
old_row_max - row_max_safe is the correct exponent for exp2.
The old code computed exp2(scale_log2 * (old_row_max - row_max_safe)),
which is exp2(scale_log2^2 * (old_max_S - new_max_S)) and therefore wrong.
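
Numeric check of the exponent argument above. It assumes scale_log2 = log2(e) / sqrt(head_dim); head_dim = 64 and the two max values are arbitrary illustration values, not numbers from the kernel.

```python
import math

head_dim = 64                                # assumed for illustration
scale = 1.0 / math.sqrt(head_dim)
scale_log2 = scale * math.log2(math.e)

old_max_S, new_max_S = 3.0, 5.0              # raw QK maxima (made-up values)
old_row_max = old_max_S * scale_log2         # stored in the scaled log2 domain
row_max_safe = new_max_S * scale_log2

# Reference correction factor in natural-exp form.
reference = math.exp(scale * (old_max_S - new_max_S))

# Correct: the maxima already carry scale_log2, so subtract and exp2.
correct = 2.0 ** (old_row_max - row_max_safe)

# Old code: multiplies by scale_log2 a second time.
double_scaled = 2.0 ** (scale_log2 * (old_row_max - row_max_safe))

print(reference, correct, double_scaled)
```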
- Combined K+V barrier (one acquire per kt, kvh.count == kt)
- O rescale for kt > 0 (online softmax O correction)
- final_o_bar sync (MMA signals before producer_tail)
- s_k as constructor param (compile-time for V layout)
- kv_tx_bytes covers both K and V transfers
- Test covers n=128, 256, 512, 1024
row_max is in scaled domain (s_val * scale_log2). The O rescaling
should be exp2(old_max - new_max) without extra scale_log2 because
the max values already include the scaling factor.
row_max should be the max of the raw QK scores, not pre-scaled.
The scale_log2 is applied during exp2 and rescaling, not stored in row_max.
This fixes the double-scaling bug that broke multi-tile O rescaling.
- Move O TMEM load/store setup before softmax loop
- After P store: rescale O in TMEM by exp2((old_max - new_max) * scale)
- Only rescale for kt > 0 (first tile has no prior O to rescale)
- Use same TMEM load/modify/store pattern as final normalization
- Test both n=128 (1 tile) and n=256 (2 tiles)
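
Reference sketch of the multi-tile O rescale for a single query row, in plain Python. It follows the convention stated above (row_max holds raw QK scores, scale_log2 is applied inside exp2 and in the rescale). Function names, the tile size and the random test data are hypothetical; TMEM, warps and barriers are not modeled.

```python
import math
import random


def attention_row_reference(q, keys, values, scale):
    """Plain softmax(q . K^T * scale) @ V for one query row."""
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    p = [math.exp(s - m) for s in scores]
    total = sum(p)
    return [sum(p[j] * values[j][c] for j in range(len(values))) / total
            for c in range(len(values[0]))]


def attention_row_online(q, keys, values, scale, tile):
    """Online softmax over KV tiles with O rescale for kt > 0."""
    scale_log2 = scale * math.log2(math.e)
    row_max = -math.inf                  # max of raw QK scores seen so far
    row_sum = 0.0
    o = [0.0] * len(values[0])
    for kt, start in enumerate(range(0, len(keys), tile)):
        s = [sum(a * b for a, b in zip(q, k)) for k in keys[start:start + tile]]
        new_max = max(row_max, max(s))
        if kt > 0:                       # first tile has no prior O to rescale
            corr = 2.0 ** ((row_max - new_max) * scale_log2)
            o = [x * corr for x in o]
            row_sum *= corr
        p = [2.0 ** ((x - new_max) * scale_log2) for x in s]
        row_sum += sum(p)
        for j, pj in enumerate(p):       # O += P @ V for this tile
            for c in range(len(o)):
                o[c] += pj * values[start + j][c]
        row_max = new_max
    return [x / row_sum for x in o]      # final normalization by 1/row_sum


rng = random.Random(0)
d, n = 4, 6
q = [rng.uniform(-1, 1) for _ in range(d)]
keys = [[rng.uniform(-1, 1) for _ in range(d)] for _ in range(n)]
values = [[rng.uniform(-1, 1) for _ in range(d)] for _ in range(n)]
scale = 1.0 / math.sqrt(d)

ref = attention_row_reference(q, keys, values, scale)
one_tile = attention_row_online(q, keys, values, scale, tile=n)
multi_tile = attention_row_online(q, keys, values, scale, tile=2)
```

With tile=n the kt > 0 branch never runs, which mirrors the single-tile case; tile=2 exercises the rescale on two later tiles.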
MMA must wait for softmax to produce P in TMEM before starting PV.
Without this, MMA reads stale P data from TMEM, causing deadlock.
softmax_done_bar: softmax warps arrive after P store, MMA waits before PV.
- Add acc_bar to SS struct
- Create acc_pipe (full pipeline) before if blocks
- Pass acc_pipe to epilogue_tma_store (needs full pipeline, not participant)
- Remove vectorize=True from exp2 computation loop (carry variable)
- Add row_sum accumulation from P values in exp2 pass
- Compute row_max via fmax in separate pass
The .reduce() on the C-fragment gives global max across all rows,
not per-row max. Compute row_max element-wise from S values before
the exp2 pass. Also accumulate row_sum in the exp2 pass.
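
Small illustration of fragment-wide max versus per-row max, with row_sum accumulated in the exp2 pass. The 2x3 score values are made up, and the scale factor is omitted to keep the numbers exact.

```python
import math

# Hypothetical S fragment: 2 rows x 3 columns of QK scores.
S = [[1.0, 4.0, 2.0],
     [7.0, 0.5, 3.0]]

# What a reduce over the whole fragment yields: one max for all rows.
fragment_max = max(x for row in S for x in row)

# Pass 1: element-wise running max, kept separately per row.
row_max = [-math.inf] * len(S)
for r, row in enumerate(S):
    for x in row:
        row_max[r] = max(row_max[r], x)

# Pass 2: exp2 against the per-row max, accumulating row_sum from the P values.
row_sum = [0.0] * len(S)
P = []
for r, row in enumerate(S):
    p_row = []
    for x in row:
        p = 2.0 ** (x - row_max[r])
        row_sum[r] += p
        p_row.append(p)
    P.append(p_row)

print(fragment_max, row_max, row_sum)
```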
- Replace identity softmax with online softmax (row_max, exp2 scaling, P store)
- Add row_sum accumulation from P values
- After softmax loop, normalize O in TMEM by 1/row_sum using TMEM load/modify/store
- Then epilogue writes normalized O from TMEM to GMEM
- Reference test uses softmax(Q@K^T/sqrt(d))@V
- test_fmha_v3_per_row_min.py: minimal per-row test (no C6/C9, no barriers)
Still hangs; likely a CuTe DSL issue with logical_divide + explicit loops
- Replaced .load().reduce() on sliced tensors with explicit loops
- Very long compilation times suggest CuTe DSL is struggling
Key conclusion: per-row fix requires correction warp group.
The 6-warp code can't bridge 4 QK rows to 1 PV row per thread.
Need 128 correction threads (1 per output row) reading TMEM vector.
- test_fmha_v3_per_row.py: Mike's per-row patch with deadlock fix
(moved C6 O-rescale after softmax_done_bar, fixed pv_done_bar for kt=0)
GPU still hangs; needs further debugging
- test_fmha_v3_fixed_v.py: s_k parameter + acc_pipe consumer fix
Same cosine as original (V TMA handles data shape correctly)
- Baseline: n=128→0.993, n=256→0.725, n=384→0.620
Key insight: QK TMEM load fragment has 4 rows × 32 cols per thread.
Fragment-level row_max/row_sum is wrong for per-row operations.
Per-row tracking (4 separate row_max/row_sum per thread) is needed.
- test_fmha_v3_scalar: direct acc_scale for C6 O-rescale (no vector)
- test_fmha_v3_vec_c9: TMEM vector for C9 row_sum transfer
- test_fmha_v3_noop_c9: hardcoded inv_row_sum=1.0 (no normalization)
- test_fmha_v3_debug: row_sum-based C9 normalization
- test_fmha_v3_proper: 11-warp correction warp group (in progress)
Key findings:
- QK and PV C-fragments map threads to same logical rows
- pv_row_sum (PV-based P read) gives cosine 0.993 for n=128
- row_sum (QK-accumulated) gives cosine 0.514 for n=128
- Noop (inv_row_sum=1.0) gives cosine 0.866 for n=128
- pv_row_sum is NOT 1.0 - it corrects PV MMA accumulator errors
- The C9 normalization is essential even for single-tile case