diff --git a/README.md b/README.md index 9cfc677a..4b5017de 100644 --- a/README.md +++ b/README.md @@ -31,31 +31,19 @@ The softmax writes P to TMEM using the **QK C-fragment layout**. The PV MMA read - ❌ PV (128,16), V=I(128,128) → cosine 0.0 (all zeros) - ❌ PV (128,16) with P at S offset (no softmax) → NaN (FP32→BF16 reinterpret) -### Root Cause (Updated May 21 09:20 UTC) +### Root Cause (CONFIRMED May 21 09:50 UTC) -**The P/A TMEM alias is NOT the bug.** Diagnostic prints confirm the PV A-fragment layout is IDENTICAL for all PV sizes: +**Bug is NOT the TMEM alias.** The PV A-fragment layout is identical for all PV sizes (confirmed by C++ source and diagnostics): all PV sizes produce tOrP2_s = (2048, 1, 8), size=16384. -``` -(128,128) PV: tOrP2_s = (2048, 1, 8), size=16384, cosine=1.0 -(128,32) PV: tOrP2_s = (2048, 1, 8), size=16384, cosine=0.51 -(128,16) PV: tOrP2_s = (2048, 1, 8), size=16384, cosine=0.36 -``` +**The real bug: V SMEM only holds 1 K-tile (2048 BF16), but the PV MMA iterates 8 K-phases.** For non-(128,128) V, most K-phases read wrong V data. -The C++ source confirms: the A-fragment TMEM atom depends on M and K, NOT output N. The softmax writes P to the same TMEM columns regardless of PV size. +- (128,128) PV + V=I works by coincidence (V=I makes the projection self-consistent) +- (128,32) PV + V=(32,128) fails because V SMEM only has V[0:16,:], K-phases 1-7 read wrong data +- Zero-padded V works because V=(128,128) covers all 8 K-phases; rows beyond head_dim are zero -**The real bug is in the V/B staging or output C/D path.** When using (128,128) PV with zero-padded V (which keeps the V SMEM, O C-fragment, and epilogue at (128,128) dimensions), cosine=1.0. When using native (128,32) PV with V=(32,128), cosine=0.51. The difference is the V SMEM layout and/or output epilogue. +**How FMHA avoids this:** FMHA interleaves QK and PV per KV-tile. Each tile loads 16 K-rows of V, and PV processes only that tile. This ensures V SMEM always has the correct data. -**Key observations:** -- V smem_size=2048 for (128,16) PV, vs 16384 for (128,128) PV -- O tOtO_size=2048 for (128,16) PV, vs 16384 for (128,128) PV -- cta_tile_shape_mnk=(128,128,64) for BOTH — this is QK's cta tile, not PV's -- epi_tile=(128,16) for (128,16) PV — this IS correct (from PV) -- Swapping cta_tile to PV before epilogue doesn't fix the issue - -**Next steps:** -1. Test V TMA load correctness for (128,32) PV — is V data loaded correctly into SMEM? -2. Test PV MMA output directly (skip epilogue) — is the PV MMA producing correct O? -3. Check if V B-operand fragment (tCrV) has the right shape for (128,32) PV +**Workaround:** Zero-pad V to 128 K-rows (2-4x compute waste, but correct). Proper fix: FMHA-style KV-tile interleaving. ### Current Workaround