- Combined K+V barrier (one acquire per kt, kvh.count == kt)
- O rescale for kt > 0 (online softmax O correction)
- final_o_bar sync (MMA signals before producer_tail)
- s_k as constructor param (compile-time for V layout)
- kv_tx_bytes covers both K and V transfers
- Test covers n=128, 256, 512, 1024
- Stage C works for n=128 (0.993) but multi-tile (n>128) is broken
- Root cause: tBgK slice hardcodes GMEM iteration to tile 0
- CuTeDSL TMA copy doesn't accept Python int as tile index
- Mike's combined K+V barrier fix compiles but deadlocks at runtime
- Fallback: kh.count // 2 (untested)
Key changes from Mike:
1. Combined K+V TMA barrier: one acquire per kt, both cute.copys share
kvh.barrier. kvh.count naturally == kt (no interleaving problem).
tx_count = K_bytes + V_bytes. Also fixes the sK[0]/sV[1] slot quirk.
2. final_o_bar NamedBarrier: MMA .arrive() after acc_pipe.producer_tail;
softmax .arrive_and_wait() before reading O for normalize. Prevents
softmax racing MMA's PV[N-1] on the final O read.
3. acc_pipe producer in MMA: producer_acquire before loop, commit+advance
after loop, producer_tail after. Consumer in epilogue as before.
4. O rescale re-enabled for kt>0 with acc_scale before softmax_done_bar.
CuTeDSL TMA copy API doesn't support dynamic GMEM tile indexing.
kh.count works for single tile. For multi-tile, need to either:
1. Map pipeline count to tile index (kh.count // 2 for interleaved K/V)
2. Separate K and V into non-interleaved TMA loops
3. Use gK/gV layouts that iterate naturally with pipeline count
This is the architectural blocker for multi-tile FMHA.
The pre-slice (None,0,None,0) hardcoded GMEM iteration to tile 0.
Instead, keep the original tBgK and index with (None, kt, None, 0)
inside the TMA loop, where kt selects the correct GMEM tile.
This preserves 2D rank matching with the SMEM tensor.
The slice (None,0,None,0) was hardcoding the GMEM iteration dim to 0,
meaning TMA always loaded K/V from tile 0 regardless of kt.
Changed to (None,None,None,0) to keep gmem_iter free,
then index with (None, kt, None) in the TMA copy loop.
This is the root cause of multi-tile failure: TMA was always reading
the first 128 tokens for ALL KV tiles.
Race condition: softmax reads O to normalize while MMA may still be
writing PV[N-1]. Single-tile wins by luck; multi-tile drifts.
Move acc_cons_st construction before the wait so epilogue reuses it.
Fix 1: s_k must equal actual n. With s_k < n, v_fmha layout only spans
first s_k V tokens and TMA reads OOB on later tiles.
Fix 2: TMA producer indexes K and V by kt (loop variable), NOT by the
pipeline's interleaved count. The kv pipeline interleaves K and V, so
pipeline count goes 0,1,2,3 but GMEM tiles should be K[0],V[0],K[1],V[1].
Fix 3: Online O rescale before softmax_done_bar. When row_max grows,
O must be multiplied by exp2(old_max - new_max) before MMA starts next PV.
row_max is in scaled domain (s_val * scale_log2). The O rescaling
should be exp2(old_max - new_max) without extra scale_log2 because
the max values already include the scaling factor.
row_max should be the max of the raw QK scores, not pre-scaled.
The scale_log2 is applied during exp2 and rescaling, not stored in row_max.
This fixes the double-scaling bug that broke multi-tile O rescaling.
- Move O TMEM load/store setup before softmax loop
- After P store: rescale O in TMEM by exp2((old_max - new_max) * scale)
- Only rescale for kt > 0 (first tile has no prior O to rescale)
- Use same TMEM load/modify/store pattern as final normalization
- Test both n=128 (1 tile) and n=256 (2 tiles)
MMA must wait for softmax to produce P in TMEM before starting PV.
Without this, MMA reads stale P data from TMEM, causing deadlock.
softmax_done_bar: softmax warps arrive after P store, MMA waits before PV.
- Add acc_bar to SS struct
- Create acc_pipe (full pipeline) before if blocks
- Pass acc_pipe to epilogue_tma_store (needs full pipeline, not participant)
- Remove vectorize=True from exp2 computation loop (carry variable)
- Add row_sum accumulation from P values in exp2 pass
- Compute row_max via fmax in separate pass
The .reduce() on the C-fragment gives global max across all rows,
not per-row max. Compute row_max element-wise from S values before
the exp2 pass. Also accumulate row_sum in the exp2 pass.
- Replace identity softmax with online softmax (row_max, exp2 scaling, P store)
- Add row_sum accumulation from P values
- After softmax loop, normalize O in TMEM by 1/row_sum using TMEM load/modify/store
- Then epilogue writes normalized O from TMEM to GMEM
- Reference test uses softmax(Q@K^T/sqrt(d))@V
- test_fmha_v3_per_row_min.py: minimal per-row test (no C6/C9, no barriers)
Still hangs — likely CuTe DSL issue with logical_divide + explicit loops
- Replaced .load().reduce() on sliced tensors with explicit loops
- Very long compilation times suggest CuTe DSL is struggling
Key conclusion: per-row fix requires correction warp group.
The 6-warp code cant bridge 4 QK rows to 1 PV row per thread.
Need 128 correction threads (1 per output row) reading TMEM vector.
- test_fmha_v3_per_row.py: Mike's per-row patch with deadlock fix
(moved C6 O-rescale after softmax_done_bar, fixed pv_done_bar for kt=0)
Still GPU hangs — needs further debugging
- test_fmha_v3_fixed_v.py: s_k parameter + acc_pipe consumer fix
Same cosine as original (V TMA handles data shape correctly)
- Baseline: n=128→0.993, n=256→0.725, n=384→0.620
Key insight: QK TMEM load fragment has 4 rows × 32 cols per thread.
Fragment-level row_max/row_sum is wrong for per-row operations.
Per-row tracking (4 separate row_max/row_sum per thread) is needed.
- test_fmha_v3_scalar: direct acc_scale for C6 O-rescale (no vector)
- test_fmha_v3_vec_c9: TMEM vector for C9 row_sum transfer
- test_fmha_v3_noop_c9: hardcoded inv_row_sum=1.0 (no normalization)
- test_fmha_v3_debug: row_sum-based C9 normalization
- test_fmha_v3_proper: 11-warp correction warp group (in progress)
Key findings:
- QK and PV C-fragments map threads to same logical rows
- pv_row_sum (PV-based P read) gives cosine 0.993 for n=128
- row_sum (QK-accumulated) gives cosine 0.514 for n=128
- Noop (inv_row_sum=1.0) gives cosine 0.866 for n=128
- pv_row_sum is NOT 1.0 - it corrects PV MMA accumulator errors
- The C9 normalization is essential even for single-tile case
Root cause of Xid 13 crash: extern __shared__ with reinterpret_cast
chain caused alignment faults on SM100. Switched to static __shared__
arrays (s_heap_scores[1024], s_heap_blocks[1024], s_w[64], s_lock).
Also fixed the FP4 key addressing: keys are stored flat as
[num_blocks, epb, n_h*c_I/2] total bytes per entry. Head h starts
at byte offset h*(c_I/2) and group offset h*(c_I/16) within each
entry. Previous code used per-head n_groups indexing which was wrong
for the flat layout.
Kernel now runs successfully on B200. FP4 quantization noise causes
ranking differences vs FP32 oracle (expected — the tcgen05 FP4 MMA
path with FP32 accumulation will fix this). Top-k structure and heap
logic verified correct via separate heap-only test (exact match vs torch.topk).
Pipeline init uses __syncthreads (all 320 threads participate).
Pipeline groups match 6-warp exactly.
Only difference: threads_per_cta=320 vs 192.
Direct comparison: 6-warp output [15,-129,-77.5,65,59]
vs 10-warp output [-7.5,2.2,-22.7,7.3,12.0] for row 0.
Completely different values.
Something in CuTe DSL runtime uses blockDim.x or total CTA size
in a way that breaks computation when CTA size changes from 192 to 320.
The pipeline_init_wait calls agent_sync(ThreadBlock) = __syncthreads
which all 320 threads reach. NamedBarriers use specific thread counts.
TMA atoms are created from MMA thread layout, not CTA size.
Hypothesis: the PipelineTmaUmma or PipelineUmmaAsync internally
uses blockDim.x for barrier arithmetic, making the barriers expect
more participants than the actual working threads.