Both GMEM and SMEM sides must be sliced to the same rank for cute.copy.
K (QK MMA B-partition): slice [(None,None,0,0)] keeps modes 0,1
- mode 1 = GMEM iteration, indexed by kvh.count
V (PV MMA B-partition): slice [(None,0,None,0)] keeps modes 0,2
- mode 2 = GMEM iteration, indexed by kvh.count
Q: only 1 tile, (None,0,None,0) hardcode is fine.
Root cause of multi-tile failure: (None,0,None,0) slice hardcodes the
GMEM tile dimension to 0, so TMA always loads from tile 0 regardless
of kvh.count. K from QK MMA has GMEM iter at mode 1, V from PV MMA
has it at mode 2 (different layouts: K,D,L vs D,K,L).
Fix follows CUTLASS FMHA reference:
- K: tBgK[(None,None,None,0)] + tBgK[(None, kvh.count, None)]
- V: tVgV[(None,0,None,0)] + tVgV[(None, kvh.count)]
row_max is already in log2(scaled) space (S * scale_log2), so
old_row_max - row_max_safe is the correct exponent for exp2.
The old code computed exp2(scale_log2 * (old_row_max - row_max_safe))
which is exp2(scale_log2^2 * (old_max_S - new_max_S)) — wrong.
- Combined K+V barrier (one acquire per kt, kvh.count == kt)
- O rescale for kt > 0 (online softmax O correction)
- final_o_bar sync (MMA signals before producer_tail)
- s_k as constructor param (compile-time for V layout)
- kv_tx_bytes covers both K and V transfers
- Test covers n=128, 256, 512, 1024
row_max is in scaled domain (s_val * scale_log2). The O rescaling
should be exp2(old_max - new_max) without extra scale_log2 because
the max values already include the scaling factor.
row_max should be the max of the raw QK scores, not pre-scaled.
The scale_log2 is applied during exp2 and rescaling, not stored in row_max.
This fixes the double-scaling bug that broke multi-tile O rescaling.
- Move O TMEM load/store setup before softmax loop
- After P store: rescale O in TMEM by exp2((old_max - new_max) * scale)
- Only rescale for kt > 0 (first tile has no prior O to rescale)
- Use same TMEM load/modify/store pattern as final normalization
- Test both n=128 (1 tile) and n=256 (2 tiles)
MMA must wait for softmax to produce P in TMEM before starting PV.
Without this, MMA reads stale P data from TMEM, causing deadlock.
softmax_done_bar: softmax warps arrive after P store, MMA waits before PV.
- Add acc_bar to SS struct
- Create acc_pipe (full pipeline) before if blocks
- Pass acc_pipe to epilogue_tma_store (needs full pipeline, not participant)
- Remove vectorize=True from exp2 computation loop (carry variable)
- Add row_sum accumulation from P values in exp2 pass
- Compute row_max via fmax in separate pass
The .reduce() on the C-fragment gives global max across all rows,
not per-row max. Compute row_max element-wise from S values before
the exp2 pass. Also accumulate row_sum in the exp2 pass.
- Replace identity softmax with online softmax (row_max, exp2 scaling, P store)
- Add row_sum accumulation from P values
- After softmax loop, normalize O in TMEM by 1/row_sum using TMEM load/modify/store
- Then epilogue writes normalized O from TMEM to GMEM
- Reference test uses softmax(Q@K^T/sqrt(d))@V
- test_fmha_v3_per_row_min.py: minimal per-row test (no C6/C9, no barriers)
Still hangs — likely CuTe DSL issue with logical_divide + explicit loops
- Replaced .load().reduce() on sliced tensors with explicit loops
- Very long compilation times suggest CuTe DSL is struggling
Key conclusion: per-row fix requires correction warp group.
The 6-warp code cant bridge 4 QK rows to 1 PV row per thread.
Need 128 correction threads (1 per output row) reading TMEM vector.
- test_fmha_v3_per_row.py: Mike's per-row patch with deadlock fix
(moved C6 O-rescale after softmax_done_bar, fixed pv_done_bar for kt=0)
Still GPU hangs — needs further debugging
- test_fmha_v3_fixed_v.py: s_k parameter + acc_pipe consumer fix
Same cosine as original (V TMA handles data shape correctly)
- Baseline: n=128→0.993, n=256→0.725, n=384→0.620
Key insight: QK TMEM load fragment has 4 rows × 32 cols per thread.
Fragment-level row_max/row_sum is wrong for per-row operations.
Per-row tracking (4 separate row_max/row_sum per thread) is needed.
- test_fmha_v3_scalar: direct acc_scale for C6 O-rescale (no vector)
- test_fmha_v3_vec_c9: TMEM vector for C9 row_sum transfer
- test_fmha_v3_noop_c9: hardcoded inv_row_sum=1.0 (no normalization)
- test_fmha_v3_debug: row_sum-based C9 normalization
- test_fmha_v3_proper: 11-warp correction warp group (in progress)
Key findings:
- QK and PV C-fragments map threads to same logical rows
- pv_row_sum (PV-based P read) gives cosine 0.993 for n=128
- row_sum (QK-accumulated) gives cosine 0.514 for n=128
- Noop (inv_row_sum=1.0) gives cosine 0.866 for n=128
- pv_row_sum is NOT 1.0 - it corrects PV MMA accumulator errors
- The C9 normalization is essential even for single-tile case
Pipeline init uses __syncthreads (all 320 threads participate).
Pipeline groups match 6-warp exactly.
Only difference: threads_per_cta=320 vs 192.
Direct comparison: 6-warp output [15,-129,-77.5,65,59]
vs 10-warp output [-7.5,2.2,-22.7,7.3,12.0] for row 0.
Completely different values.
Something in CuTe DSL runtime uses blockDim.x or total CTA size
in a way that breaks computation when CTA size changes from 192 to 320.
The pipeline_init_wait calls agent_sync(ThreadBlock) = __syncthreads
which all 320 threads reach. NamedBarriers use specific thread counts.
TMA atoms are created from MMA thread layout, not CTA size.
Hypothesis: the PipelineTmaUmma or PipelineUmmaAsync internally
uses blockDim.x for barrier arithmetic, making the barriers expect
more participants than the actual working threads.
Adding 4 idle warps (4-7) to 320-thread CTA:
- No crash, no deadlock (idle warps just pass)
- But output is garbage: cosine 0.29 vs 0.999999
Same softmax+MMA code, same TMEM layout, same barriers.
Only difference: mma_warp_id=8 (was 4), threads_per_cta=320 (was 192)
and 4 idle warps 4-7.
Something in the pipeline/barrier system assumes the old 6-warp topology.
Need to identify which component uses threads_per_cta or warp_idx
in a way that breaks with more warps.