- Remove vectorize=True from exp2 computation loop (carry variable)
- Add row_sum accumulation from P values in exp2 pass
- Compute row_max via fmax in separate pass
The .reduce() on the C-fragment gives global max across all rows,
not per-row max. Compute row_max element-wise from S values before
the exp2 pass. Also accumulate row_sum in the exp2 pass.
- Replace identity softmax with online softmax (row_max, exp2 scaling, P store)
- Add row_sum accumulation from P values
- After softmax loop, normalize O in TMEM by 1/row_sum using TMEM load/modify/store
- Then epilogue writes normalized O from TMEM to GMEM
- Reference test uses softmax(Q@K^T/sqrt(d))@V
- test_fmha_v3_per_row_min.py: minimal per-row test (no C6/C9, no barriers)
Still hangs — likely CuTe DSL issue with logical_divide + explicit loops
- Replaced .load().reduce() on sliced tensors with explicit loops
- Very long compilation times suggest CuTe DSL is struggling
Key conclusion: per-row fix requires correction warp group.
The 6-warp code can't bridge 4 QK rows to 1 PV row per thread.
Need 128 correction threads (1 per output row) reading TMEM vector.
- test_fmha_v3_per_row.py: Mike's per-row patch with deadlock fix
(moved C6 O-rescale after softmax_done_bar, fixed pv_done_bar for kt=0)
GPU still hangs; needs further debugging
- test_fmha_v3_fixed_v.py: s_k parameter + acc_pipe consumer fix
Same cosine as original (V TMA handles data shape correctly)
- Baseline: n=128→0.993, n=256→0.725, n=384→0.620
Key insight: QK TMEM load fragment has 4 rows × 32 cols per thread.
Fragment-level row_max/row_sum is wrong for per-row operations.
Per-row tracking (4 separate row_max/row_sum per thread) is needed.
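
To illustrate the difference between fragment-level and per-row statistics, here is a minimal plain-Python sketch (not CuTe DSL). The 4 x 32 fragment shape is taken from the note above; the fragment contents and the helper names `fragment_level_stats` and `per_row_stats` are made up for illustration.

```python
ROWS_PER_THREAD = 4   # from the note above
COLS_PER_THREAD = 32  # from the note above


def fragment_level_stats(frag):
    """One max and one sum over the whole fragment (what a flat reduce yields)."""
    flat = [x for row in frag for x in row]
    m = max(flat)
    s = sum(2.0 ** (x - m) for x in flat)
    return m, s


def per_row_stats(frag):
    """One (row_max, row_sum) pair for each row the thread owns."""
    stats = []
    for row in frag:
        m = max(row)
        s = sum(2.0 ** (x - m) for x in row)
        stats.append((m, s))
    return stats


# Hypothetical fragment: row r holds r, r-1, r-2, ... so every row has a
# different maximum.
frag = [[float(r) - c for c in range(COLS_PER_THREAD)]
        for r in range(ROWS_PER_THREAD)]

frag_max, frag_sum = fragment_level_stats(frag)
rows = per_row_stats(frag)
row_maxes = [m for m, _ in rows]
row_sums = [s for _, s in rows]
```

Here `frag_max` equals only the last row's maximum, and `frag_sum` mixes contributions from all four rows, so neither can stand in for the per-row values.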
- test_fmha_v3_scalar: direct acc_scale for C6 O-rescale (no vector)
- test_fmha_v3_vec_c9: TMEM vector for C9 row_sum transfer
- test_fmha_v3_noop_c9: hardcoded inv_row_sum=1.0 (no normalization)
- test_fmha_v3_debug: row_sum-based C9 normalization
- test_fmha_v3_proper: 11-warp correction warp group (in progress)
Key findings:
- QK and PV C-fragments map threads to same logical rows
- pv_row_sum (PV-based P read) gives cosine 0.993 for n=128
- row_sum (QK-accumulated) gives cosine 0.514 for n=128
- Noop (inv_row_sum=1.0) gives cosine 0.866 for n=128
- pv_row_sum is NOT 1.0 - it corrects PV MMA accumulator errors
- The C9 normalization is essential even for single-tile case
Pipeline init uses __syncthreads (all 320 threads participate).
Pipeline groups match 6-warp exactly.
Only difference: threads_per_cta=320 vs 192.
Direct comparison: 6-warp output [15,-129,-77.5,65,59]
vs 10-warp output [-7.5,2.2,-22.7,7.3,12.0] for row 0.
Completely different values.
Something in CuTe DSL runtime uses blockDim.x or total CTA size
in a way that breaks computation when CTA size changes from 192 to 320.
The pipeline_init_wait calls agent_sync(ThreadBlock) = __syncthreads
which all 320 threads reach. NamedBarriers use specific thread counts.
TMA atoms are created from MMA thread layout, not CTA size.
Hypothesis: the PipelineTmaUmma or PipelineUmmaAsync internally
uses blockDim.x for barrier arithmetic, making the barriers expect
more participants than the actual working threads.
Adding 4 idle warps (4-7), giving a 320-thread CTA:
- No crash, no deadlock (idle warps just pass)
- But output is garbage: cosine 0.29 vs 0.999999
Same softmax+MMA code, same TMEM layout, same barriers.
Only difference: mma_warp_id=8 (was 4), threads_per_cta=320 (was 192)
and 4 idle warps 4-7.
Something in the pipeline/barrier system assumes the old 6-warp topology.
Need to identify which component uses threads_per_cta or warp_idx
in a way that breaks with more warps.
C9 fix: instead of using QK-partitioned row_sum (which maps to wrong PV rows),
read P from TMEM using PV partition and sum via .reduce(ADD).
QK: thread N owns row N//4, PV: thread N owns row N.
Reading P via PV partition gives each thread its correct row P values.
n=128: cosine 0.993 (was 0.514)
n=256: cosine 0.725 (C6 still broken for multi-tile)
n=384: cosine 0.676 (same C6 issue)
Remaining: C6 O-rescale for multi-tile needs same PV-partitioned fix.
Small accuracy gap (0.993 vs 0.999) likely from BF16 P store/load round-trip.
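
For reference, a minimal plain-Python sketch of a BF16 store/load round-trip and its per-element rounding bound. This only models round-to-nearest-even on finite values; it does not model the kernel's actual conversion path, and whether this bound accounts for the 0.993 vs 0.999 gap would still need to be measured. The helper name `to_bf16_and_back` is hypothetical.

```python
import struct


def to_bf16_and_back(x: float) -> float:
    """Round a finite float to bfloat16 (nearest-even) and widen it again."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Rounding bias: 0x7FFF plus the lowest bit that will be kept.
    bits += 0x7FFF + ((bits >> 16) & 1)
    # Keep only the upper 16 bits (sign, exponent, 7 mantissa bits).
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


# Sample P values in (0, 1]; these are made-up inputs.
samples = [0.3, 0.7, 0.011, 0.999, 1.0, 0.5]
rel_errors = [abs(to_bf16_and_back(p) - p) / p for p in samples]
```

With 8 significant bits, each rounded value is within a relative error of 2**-8 of the original, and powers of two such as 1.0 and 0.5 survive exactly.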
row_sum is PROVEN correct (29.25 vs 29.22 for row 0, ratio 1.001).
The ONLY bug is QK→PV row mapping in C9 normalization.
Tried: composition(tStS,(128,1)) for write, composition(tOtO,(128,1)) for read.
Same result: the composition preserves the fragment's internal thread-to-address
mapping, so the same thread writes and reads the same TMEM address regardless
of which fragment layout is used for the composition.
Need: absolute row-coordinate indexed TMEM vector. Each QK thread writes
inv_row_sum to vec[QK_row_id], each PV thread reads from vec[PV_row_id].
The row_id comes from the identity tensor coordinate.
Alternative: implement FMHA correction_epilog pattern with dedicated
correction warp group that reads row metadata from the vector.
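
The row-indexed vector idea can be sketched in plain Python. The thread-to-row maps below are hypothetical stand-ins (the real ones come from the QK and PV fragment layouts), and 8 rows stand in for 128.

```python
NUM_ROWS = 8  # small stand-in for the 128 rows discussed above

# Hypothetical ownership maps (thread id -> row id). The real maps are
# defined by the fragment layouts and are NOT reproduced here.
writer_row_of_thread = [(t * 3) % NUM_ROWS for t in range(NUM_ROWS)]  # "QK side"
reader_row_of_thread = list(range(NUM_ROWS))                          # "PV side"

# Made-up per-row values standing in for inv_row_sum.
inv_row_sum_of_row = [1.0 / (r + 1) for r in range(NUM_ROWS)]

# Thread-indexed exchange: slot t is written and read by the same thread t,
# so the reader gets the value for the WRITER's row.
slots_by_thread = [inv_row_sum_of_row[writer_row_of_thread[t]]
                   for t in range(NUM_ROWS)]
read_by_thread = [slots_by_thread[t] for t in range(NUM_ROWS)]

# Row-indexed exchange: the writer stores at vec[row], the reader loads
# vec[row], so the mapping between the two partitions no longer matters.
vec = [0.0] * NUM_ROWS
for t in range(NUM_ROWS):
    row = writer_row_of_thread[t]
    vec[row] = inv_row_sum_of_row[row]
read_by_row = [vec[reader_row_of_thread[t]] for t in range(NUM_ROWS)]

expected = [inv_row_sum_of_row[reader_row_of_thread[t]]
            for t in range(NUM_ROWS)]
```

Only the row-indexed exchange gives every reader thread the value for its own row.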
The packed f32x2 reduction SHOULD sum all 128 exp2 P values but gives
a result ~5.3x too small. Need to debug inside the kernel with print
statements to see what values the reduction is actually summing.
Unnormalized P@V is perfect (cosine 0.999998). row_max is correct
(because P is correct). The bug is specifically in row_sum computation.
The C9 TMEM round-trip IS modifying O (confirmed by epilogue * 2.0 test).
But inv_row_sum is wrong: each thread computes row_max via .reduce(MAX) and
row_sum via packed f32x2 reduction, but the result appears to be the same for all threads.
Next: need to dump the QK C-fragment coordinate tensor to understand
which rows each thread actually owns in the TMEM load partition.
The QK composition(tStS, (128,64)) view of O TMEM region does not align
with the actual PV C-fragment layout. Cannot read O with QK partition.
Need to use TMEM vector approach:
1. Store inv_row_sum via QK partition (composition(tStS, (128,1)))
2. Read inv_row_sum via PV partition (need PV-partitioned view of vector)
3. Apply normalization in PV-partitioned O TMEM access
The key challenge: creating a PV-partitioned read of the vector TMEM region
that was written with QK partition. This is what CUTLASS FMHA does with
its correction warp group.
The softmax math (exp2, P store, PV) is correct for single-tile.
The bug is ONLY in C6/C9 normalization: applying inv_row_sum
using PV partition instead of QK partition.
n=128 (single tile): cosine 0.999998 PASS
n=256/384 (multi-tile): C6 O-rescale using wrong partition = FAIL
Fix: normalize O using QK row coordinates, not PV row coordinates.
Can use TMEM vector to bridge QK partition to PV partition.
Key finding: the root cause is that each epilogue thread owns MULTIPLE rows
in the QK C-fragment, so scalar row_max/row_sum are wrong (global across
all rows, not per-row). The V=ones diagnostic confirmed: all 128 threads
use the same row_sum (from row 114).
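
One way a V=ones diagnostic can expose a shared row_sum, sketched in plain Python: with V all ones, each unnormalized output row equals that row's true sum of P, so the normalized output is 1.0 exactly when the correct row_sum was applied. The scores and helper names are made up, and the actual test may have been set up differently.

```python
def unnormalized_p(scores):
    """exp2(s - row_max) per row, without dividing by the row sum."""
    out = []
    for row in scores:
        m = max(row)
        out.append([2.0 ** (s - m) for s in row])
    return out


def pv_with_ones(p_rows, head_dim):
    """P @ V with V all ones: every output element is the sum of the P row."""
    return [[sum(p)] * head_dim for p in p_rows]


scores = [[0.0, 1.0, 2.0], [3.0, 3.0, 3.0], [-1.0, 0.0, -2.0]]
p = unnormalized_p(scores)
o = pv_with_ones(p, head_dim=2)
true_sums = [sum(row) for row in p]

# Correct normalization: each row divided by its own row_sum.
good = [[x / true_sums[i] for x in o[i]] for i in range(len(o))]

# Buggy normalization: every row divided by row 1's row_sum.
shared = true_sums[1]
bad = [[x / shared for x in row] for row in o]
```

Rows whose output differs from 1.0 reveal the ratio between their true row_sum and the one that was actually applied.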
Tried: TMEM vector store+load of row_sum (composition(tStS, (128,2))).
This is a no-op because both write and read use the SAME QK partition
with a scalar row_sum. The vector approach only helps when different
partitions are used for write vs read, or when per-row values are stored.
Next steps:
1. Need PER-ROW row_max and row_sum, not per-thread scalar
2. The CUTLASS FMHA works because each thread owns exactly 1 row
3. Options: restructure thread layout, or compute per-row values differently
4. The vector must store ALL 128 per-row values, then read per-row in C9
Key finding: cute.size(v, mode=[0]) in @cute.jit produces wrong code.
Hardcoding s_k=128 (matching Stage B) fixes the base pipeline.
Current status: kernel produces non-zero output but softmax math is still wrong.
Applied fixes: pv_done_bar, acc_scale with scale, fastmath=True
Need to debug row_sum computation and C9 normalization.
- scale_softmax_log2 was missing from _setup (patch artifact)
- C9 normalization: load O from TMEM, multiply by 1/row_sum, store back
instead of trying to capture runtime value in const_expr lambda
- Then use standard epilogue_tma_store with identity transform
- C1: Real softmax reference (torch.softmax, not identity)
- C2: Per-thread row_max/row_sum registers
- C3: QK scale folded (1/sqrt(d) * log2(e))
- C4: Row max via .reduce(MAX)
- C5: Rescale factor (exp2(old_max - new_max))
- C6: O rescale in TMEM (correction_rescale pattern)
- C7: Real exp2 for P computation
- C8: Row sum via packed f32x2 reduction
- C9: Final normalization (1/row_sum in epilogue)
- Dynamic s_k for V FMHA reconstruction
- fastmath=False for correctness first
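
Steps C3 to C9 can be sketched for a single query row in plain Python (not CuTe DSL, no TMEM). Function names, the tile size, and the test data are assumptions for illustration; the reference follows softmax(Q@K^T/sqrt(d))@V.

```python
import math

LOG2E = math.log2(math.e)


def online_softmax_attention(q_row, k_rows, v_rows, tile):
    """Online softmax over KV tiles for one query row."""
    scale_log2 = LOG2E / math.sqrt(len(q_row))       # C3: folded QK scale
    row_max = -math.inf                              # C2: running row max
    row_sum = 0.0                                    # C2: running row sum
    o = [0.0] * len(v_rows[0])
    for start in range(0, len(k_rows), tile):
        ks = k_rows[start:start + tile]
        vs = v_rows[start:start + tile]
        s = [scale_log2 * sum(a * b for a, b in zip(q_row, k)) for k in ks]
        new_max = max(row_max, max(s))               # C4: row max
        correction = 2.0 ** (row_max - new_max)      # C5: 0.0 on the first tile
        o = [x * correction for x in o]              # C6: O rescale
        p = [2.0 ** (x - new_max) for x in s]        # C7: exp2 for P
        row_sum = row_sum * correction + sum(p)      # C8: row sum
        for pj, v in zip(p, vs):                     # PV accumulation
            o = [oi + pj * vi for oi, vi in zip(o, v)]
        row_max = new_max
    return [x / row_sum for x in o]                  # C9: final normalization


def reference_attention(q_row, k_rows, v_rows):
    """softmax(q . K^T / sqrt(d)) @ V computed in one pass with exp."""
    d = len(q_row)
    s = [sum(a * b for a, b in zip(q_row, k)) / math.sqrt(d) for k in k_rows]
    m = max(s)
    e = [math.exp(x - m) for x in s]
    z = sum(e)
    return [sum(ei * v[j] for ei, v in zip(e, v_rows)) / z
            for j in range(len(v_rows[0]))]


# Made-up deterministic inputs: 6 keys, head dim 4, 3 value columns.
q = [0.5, -1.0, 0.25, 2.0]
K = [[math.sin(4 * i + j) for j in range(4)] for i in range(6)]
V = [[math.cos(3 * i + j) for j in range(3)] for i in range(6)]

multi_tile = online_softmax_attention(q, K, V, tile=2)
single_tile = online_softmax_attention(q, K, V, tile=6)
ref = reference_attention(q, K, V)
```

The multi-tile and single-tile results should agree with the reference to rounding error; a mismatch only in the multi-tile case points at the C5/C6 rescale path.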
Root cause: PV output O started at TMEM column 64 (from find_tmem_tensor_col_offset),
overlapping with P at columns [32,96). PV MMA reading P while writing O to overlapping
columns corrupted the A operand mid-computation.
For (128,128) PV, O started at 128 (no overlap) so it worked by accident.
For (128,64) PV, O started at 64, overlapping P [32,96) -> NaN/garbage.
Fix: Place O at column 128 (after both S [0,128) and P [32,96)).
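
The column ranges above can be checked with a small plain-Python overlap test. The S and P ranges and the O start columns are from the notes; the assumption that O occupies one column per output column (64 or 128 wide) is mine.

```python
def overlaps(a, b):
    """True if half-open column ranges a=[a0,a1) and b=[b0,b1) intersect."""
    return a[0] < b[1] and b[0] < a[1]


S_COLS = (0, 128)   # from the note above
P_COLS = (32, 96)   # from the note above


def o_cols(start, width):
    """Half-open column range for O, assuming one column per output column."""
    return (start, start + width)


o_broken = o_cols(64, 64)      # (128,64) PV, O placed at column 64
o_fixed = o_cols(128, 64)      # (128,64) PV, O placed at column 128
o_square = o_cols(128, 128)    # (128,128) PV, O placed at column 128

broken_hits_p = overlaps(o_broken, P_COLS)
fixed_hits_p = overlaps(o_fixed, P_COLS)
fixed_hits_s = overlaps(o_fixed, S_COLS)
square_hits_p = overlaps(o_square, P_COLS)
```

A check like this could be asserted at setup time so that a future change to the tile shape cannot silently move O back onto P.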
Also added FMHA-style V reconstruction: logical (HEAD_DIM, s_k, 1) stride (1, hd, hd*s_k)
instead of passing DLPack V directly to TMA.
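
A plain-Python sketch of what the (HEAD_DIM, s_k, 1) shape with stride (1, hd, hd*s_k) addresses. It assumes the underlying V buffer is row-major (s_k, hd), which the notes do not state; sizes are shrunk for readability.

```python
def offset(coord, stride):
    """Linear offset of a coordinate under a stride tuple."""
    return sum(c * s for c, s in zip(coord, stride))


HEAD_DIM, S_K = 4, 3                      # small stand-ins
stride = (1, HEAD_DIM, HEAD_DIM * S_K)    # stride from the note above

# Assumed storage: row-major (S_K, HEAD_DIM), element (k, d) at k*HEAD_DIM + d.
flat = list(range(S_K * HEAD_DIM))

# Logical view indexed as (d, k, 0).
view = [[flat[offset((d, k, 0), stride)] for k in range(S_K)]
        for d in range(HEAD_DIM)]
```

Under that assumption, `view[d][k]` reads the same element as `V[k][d]`, so the logical tensor is a transposed view of V with no data movement.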
test_fmha_v3.py: (128,64) PV with random V -> cosine 0.999999 PASS
Key findings:
- P/A alias WORKS: PV reads non-zero P from TMEM at offset 32 (proven by no-softmax test)
- V mode bug: V=(128,64) only loads 64 K-values, PV needs 128. Output = sum(S[:,:64]), cosine 0.67
- FMHA-style V reconstruction (hd,n,1) stride (1,hd) gives NaN for (128,64) PV
- K-major V (64,128) contiguous gives NaN for (128,64) PV
- Square (128,128) PV works with ALL V approaches (cosine 0.999999)
- Non-square PV consistently broken regardless of V layout
Test files:
- test_128_128_fmha_v.py: (128,128) with FMHA V - PASS
- test_pv64_fmha_v.py: (128,64) with FMHA V - NaN
- test_pv64_kmajor_v.py: (128,64) with K-major V - NaN
- test_pv64_with_softmax.py: (128,64) with original V - 0.67
- test_pv64_no_softmax.py: proves P/A alias works
- test_fmha_v3.py: full pipeline with QK C-fragment composition store
- test_pv64.py: (128,64) PV with separate V SMEM, single ab pipeline
Result: cosine 0.669848 — data path works but P layout mismatch
Softmax writes P via QK C-fragment layout, PV reads via PV A-fragment layout
These differ for non-(128,128) PV — Bug 1 from README
- test_fmha_v2_fixed.py: KV-tile interleaved pipeline with fixes
Fix 1: per-pipeline tx_count (Q vs KV separate byte counts)
Fix 2: NamedBarrier for softmax-done signal (replaces double-acquire deadlock)
Fix 3: Separate SMEM for V (no recast_ptr overlap with K)
Still produces zeros — needs P layout fix (same root cause as test_pv64)