README: Bug 4 ROOT CAUSE CONFIRMED - V SMEM 1 K-tile + PV 8 K-phases mismatch. Zero-pad V workaround correct.

2026-05-21 09:59:37 +00:00
parent 6c0a6cf50b
commit 241ac2bf94


@@ -31,31 +31,19 @@ The softmax writes P to TMEM using the **QK C-fragment layout**. The PV MMA read
- ❌ PV (128,16), V=I(128,128) → cosine 0.0 (all zeros)
- ❌ PV (128,16) with P at S offset (no softmax) → NaN (FP32→BF16 reinterpret)
### Root Cause (Updated May 21 09:20 UTC)
### Root Cause (CONFIRMED May 21 09:50 UTC)
**The P/A TMEM alias is NOT the bug.** Diagnostic prints confirm the PV A-fragment layout is IDENTICAL for all PV sizes:
**Bug is NOT the TMEM alias.** The PV A-fragment layout is identical for all PV sizes (confirmed by C++ source and diagnostics): all PV sizes produce tOrP2_s = (2048, 1, 8), size=16384.
```
(128,128) PV: tOrP2_s = (2048, 1, 8), size=16384, cosine=1.0
(128,32) PV: tOrP2_s = (2048, 1, 8), size=16384, cosine=0.51
(128,16) PV: tOrP2_s = (2048, 1, 8), size=16384, cosine=0.36
```
**The real bug: V SMEM holds only 1 K-tile (2048 BF16), but the PV MMA iterates over 8 K-phases.** For any V smaller than (128,128), K-phases 1-7 read wrong V data.
The C++ source confirms: the A-fragment TMEM atom depends on M and K, NOT output N. The softmax writes P to the same TMEM columns regardless of PV size.
- (128,128) PV + V=I works by coincidence (V=I makes the projection self-consistent)
- (128,32) PV + V=(32,128) fails because V SMEM only has V[0:16,:], K-phases 1-7 read wrong data
- Zero-padded V works because V=(128,128) covers all 8 K-phases; rows beyond head_dim are zero
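
The size accounting behind these three cases can be sketched in a few lines of Python. This is only an illustration: it takes the figures above at face value (2048 BF16 per K-tile, 8 K-phases from `tOrP2_s = (2048, 1, 8)`), and `uncovered_k_phases` is a hypothetical helper, not part of the kernel.

```python
# Assumed constants, copied from the diagnostics in this README.
K_TILE_ELEMS = 2048  # BF16 elements per V K-tile
PV_K_PHASES = 8      # K-phases the PV MMA iterates (last mode of tOrP2_s)


def uncovered_k_phases(v_smem_elems):
    """Return the K-phase indices that have no V data staged in SMEM.

    Hypothetical helper: assumes each K-phase consumes exactly one
    K-tile of K_TILE_ELEMS elements, laid out back to back.
    """
    covered = min(v_smem_elems // K_TILE_ELEMS, PV_K_PHASES)
    return list(range(covered, PV_K_PHASES))


# V SMEM with a single K-tile: phases 1-7 have nothing valid to read.
print(uncovered_k_phases(2048))
# V SMEM sized for (128,128): every phase is covered.
print(uncovered_k_phases(16384))
```

With one K-tile staged, only phase 0 is backed by real data, which matches the "K-phases 1-7 read wrong data" observation. With 16384 elements staged, no phase is left uncovered.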
**The real bug is in the V/B staging or output C/D path.** When using (128,128) PV with zero-padded V (which keeps the V SMEM, O C-fragment, and epilogue at (128,128) dimensions), cosine=1.0. When using native (128,32) PV with V=(32,128), cosine=0.51. The difference is the V SMEM layout and/or output epilogue.
**How FMHA avoids this:** FMHA interleaves QK and PV per KV-tile. Each tile loads 16 K-rows of V, and PV processes only that tile. This ensures V SMEM always has the correct data.
**Key observations:**
- V smem_size=2048 for (128,16) PV, vs 16384 for (128,128) PV
- O tOtO_size=2048 for (128,16) PV, vs 16384 for (128,128) PV
- cta_tile_shape_mnk=(128,128,64) for BOTH — this is QK's cta tile, not PV's
- epi_tile=(128,16) for (128,16) PV — this IS correct (from PV)
- Swapping cta_tile to PV before epilogue doesn't fix the issue
**Next steps:**
1. Test V TMA load correctness for (128,32) PV — is V data loaded correctly into SMEM?
2. Test PV MMA output directly (skip epilogue) — is the PV MMA producing correct O?
3. Check if V B-operand fragment (tCrV) has the right shape for (128,32) PV
**Workaround:** Zero-pad V to 128 K-rows (2-4x compute waste, but correct). Proper fix: FMHA-style KV-tile interleaving.
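
Why zero-padding is numerically safe can be checked on the host with a tiny pure-Python model. This is a sketch under assumptions, not the kernel's code: it assumes V is stored as (head_dim, kv_len), as `V=(32,128)` suggests, so the padding adds zero rows along head_dim and the PV product is `O[m][n] = sum_k P[m][k] * V[n][k]`. The sizes are shrunk for readability and all names are hypothetical.

```python
def matmul_nt(p, v):
    """O[m][n] = sum_k p[m][k] * v[n][k], with V stored as rows of length K."""
    return [[sum(a * b for a, b in zip(p_row, v_row)) for v_row in v]
            for p_row in p]


def zero_pad_rows(v, rows):
    """Append all-zero rows to v until it has `rows` rows."""
    k = len(v[0])
    return v + [[0] * k for _ in range(rows - len(v))]


# Shrunk stand-ins for M=128, kv_len=128, head_dim=32, padded head_dim=128.
M, KV_LEN, HEAD_DIM, PADDED = 4, 8, 2, 8
p = [[(m + k) % 5 for k in range(KV_LEN)] for m in range(M)]
v = [[(3 * n + k) % 7 for k in range(KV_LEN)] for n in range(HEAD_DIM)]

o_native = matmul_nt(p, v)                         # the result we want
o_padded = matmul_nt(p, zero_pad_rows(v, PADDED))  # the workaround
o_cropped = [row[:HEAD_DIM] for row in o_padded]   # drop padded columns

assert o_cropped == o_native
assert all(x == 0 for row in o_padded for x in row[HEAD_DIM:])
```

The padded product reproduces the native output in its first `HEAD_DIM` columns and yields zeros elsewhere, so cropping recovers the correct O. The cost is the extra columns computed and then discarded, which is where the compute waste comes from.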
### Current Workaround