docs: update README with Stage C TMEM layout mismatch findings and status

This commit is contained in:
2026-05-23 03:01:04 +00:00
parent f1ec406434
commit dbf76fbc87

151
README.md
View File

@@ -139,14 +139,13 @@ Summary
---
## Status (May 22, 2026 — 16:30 UTC)
## Status (May 23, 2026 — 02:55 UTC)
| Stage | Status | Description |
|-------|--------|-------------|
| A | ✅ COMPLETE | Q@K^T via tcgen05.mma → TMEM → GMEM |
| B | ✅ COMPLETE | QK → identity softmax → P@V pipeline (TMEM alias, KV-tile interleaving) |
| C | ⚠️ MULTI-TILE TMA FIXED | n=128 cos 0.999998 ✅. TMA fix: n=256 loads 2 tiles. Pipeline cycling needed for n≥384. O rescale needed. |
| C' | 🔨 IN PROGRESS | Multi-tile TMA indexing fix + correction warps. See below. |
| C | ⚠️ SINGLE-TILE OK, MULTI-TILE 3% ERROR | n=128 cos 0.973. n=256 cos 0.793. TMEM layout mismatch between MMA C-fragment and get_tmem_load_op. See below. |
| D | TODO | Full decode attention: paged KV cache, multi-query, causal mask |
| E | TODO | Production kernel: extract into dsv4/kernels/attention/, PyTorch custom op, vLLM bridge |
@@ -289,93 +288,94 @@ What it does:
---
## Stage C: Online Softmax — Multi-Tile In Progress
### What We Have
## Stage C: Online Softmax — TMEM Layout Mismatch Issue
**Working real softmax** for single KV tile (n=128): cosine 0.999998.
**Multi-tile TMA indexing fixed** (n=256 cosine 0.9956) — was a layout bug, NOT a JIT bug.
**Remaining:** O rescale between tiles, pipeline state cycling for n≥384, correction warps.
### Current Results (test_fmha_v3_stage_c.py)
### Multi-Tile TMA Fix (RESOLVED — was a LAYOUT bug, not a JIT bug)
| n | cos | Status |
|---|-----|--------|
| 128 | 0.973 | ⚠️ 3% error from TMEM layout mismatch |
| 256 | 0.793 | ⚠️ Two TMEM round-trips compound the error |
| 384+ | N/A | Pipeline doesn't cycle past 2 KV tiles |
After `cpasync.tma_partition()`, the output GMEM tensor has **4 modes**: `(((64,128),1), ?, KV_tiles, ?)`.
### Root Cause: TMEM Layout Mismatch
**Mode 2 is the GMEM tile dimension.** Our old pre-slice `tBgK[(None, None, 0, 0)]` kept modes 0,1 free and set mode 2 to 0, so TMA always read tile 0. The bug looked like "JIT constant-folding" but was purely a layout error.
The MMA instruction writes O to TMEM using the **C-fragment layout**. The `epilogue_tma_store` helper reads O from TMEM using `get_tmem_load_op`, which uses a **different TMEM column mapping**. When `epilogue_tma_store` reads O directly after PV (no normalize), the layout matches perfectly — **cos 0.999998** (raw PV output is correct).
**The fix:** `(None,0,None,0)` keeps modes 0,2 free, then `[None, kt]` indexes KV tiles:
The problem appears when we do a **TMEM round-trip** (load O → modify → store back) using hand-constructed `Ld32x32bOp/St32x32bOp` atoms:
```python
tBgK = tBgK[(None, 0, None, 0)]
cute.copy(tma_k, tBgK[None, kt], ...)
1. **NO-OP round-trip (load + store unchanged) → cos 0.973.** The hand-constructed atoms read/write using a different column mapping than `get_tmem_load_op`. The data gets "transcoded" — close but not exact.
2. **Normalize round-trip (load → multiply by 1/row_sum → store) → cos 0.973** (with preceding NO-OP) or **cos 0.465** (without NO-OP). Without the NO-OP, `epilogue_tma_store` reads the MMA layout directly and produces garbage.
3. **O rescale (kt > 0) + normalize → cos 0.793** at n=256. Each round-trip compounds the layout mismatch error.
### Why the NO-OP Round-Trip "Fixes" It
The MMA writes O in C-fragment TMEM layout. `epilogue_tma_store` reads in `get_tmem_load_op` layout. Without a round-trip, these layouts are incompatible → garbage output (cos 0.465).
A NO-OP round-trip through hand-constructed atoms reads the data using the hand-constructed layout (which can read the C-fragment data) and writes it back using the hand-constructed layout. After the round-trip, the data is in the hand-constructed layout, which is close to (but not identical to) the `get_tmem_load_op` layout → 3% error (cos 0.973).
### The Proper Fix: correction_epilog Pattern
The CUTLASS FMHA reference uses a one-way trip for the final epilogue:
```
TMEM --get_tmem_load_op--> reg (normalize + FP32→BF16) --get_smem_store_op--> SMEM --TMA--> GMEM
```
**Results after TMA fix (verified on B200, May 22):**
- n=128: cos 0.999998 ✅
- n=256: cos 0.71 (TMA loads 2 tiles correctly, needs O rescale for 0.9999)
- n=512/1024: output identical to n=256 — pipeline not cycling past kv_stage=2
This reads O using `get_tmem_load_op` (same layout as `epilogue_tma_store`) and writes directly to SMEM. No TMEM round-trip. No layout mismatch. The CUTLASS reference uses this pattern and gets correct results.
**Why we can't use it yet:** The TMA store from SMEM → GMEM requires `tma_partition` / `flat_divide` which hit CuTeDSL region isolation errors when called inside `if warp_idx < self.mma_warp_id` blocks. The `epilogue_tma_store` helper works because it's a regular Python function that inlines into the same MLIR region — but it always reads from TMEM, not SMEM.
**Possible solutions:**
1. Call `epilogue_tma_store` but inject the `1/row_sum` multiply into its pipeline (requires modifying the helper or replicating it inline with the scale)
2. Pre-compute TMA partitioning outside the `if` block and pass the partitioned tensors through the kernel interface
3. Use the experimental `cutlass.cute.experimental.epilogue_tma_store` API which has a cleaner structure
### Verified Facts
- **Raw PV output is perfect:** `epilogue_tma_store` with identity op gives cos 0.999998 at n=128
- **Softmax P values are correct:** Unnormalized P@V matches reference exactly (cos 0.999998)
- **Online softmax computation is correct:** row_max and row_sum tracking works
- **The ONLY issue is the TMEM round-trip for normalize/rescale**
- **Stage A/B with identity softmax: cos 0.999999** — the pipeline works, softmax is the only addition
### Architecture (6-warp, current)
**Verified tensor shapes (diag prints inside @cute.kernel on B200, n=256):**
```
Before (None,0,None,0) pre-slice:
tAgQ: (((64,128),1), Int32(?), Int32(?), Int32(?)) — 4 modes
tBgK: (((64,128),1), Int32(?), Int32(?), Int32(?)) — 4 modes
tVgV: (((64,128),1), 1, 1, 1) — 4 modes
After (None,0,None,0) pre-slice:
tAgQ: (((64,128),1), Int32(?)) — 2 modes, mode 1 = KV tiles
tBgK: (((64,128),1), Int32(?)) — 2 modes, mode 1 = KV tiles
tVgV: (((64,128),1), 1) — 2 modes, mode 1 = 1 (static)
```
### Remaining for Multi-Tile
1. O rescale between tiles: `O *= exp2(old_max - new_max)` — needed for n=256+ to hit 0.9999
2. Pipeline state cycling for n≥384 (3+ tiles with 2 pipeline stages) — output identical for all n>256, meaning only 2 KV tiles are loaded
3. Correction warps for production (separate softmax/correction/epilogue)
4. 12-warp layout
### Files
| File | Status | Notes |
|------|--------|-------|
| `fmha_v3_stage_c_example10.py` | 🔨 CURRENT | (None,0,None,0) TMA, combined K+V pipeline, O rescale, final normalize |
| `test_fmha_v3_stage_c_full.py` | OK n=128 | Working real softmax + O normalization |
| `fmha_v3_stage_c_example1.py` | BROKEN multi-tile | First fix attempt, TMA still loads tile 0 |
| `fmha_v3_stage_c_example2.py` | DEADLOCK | Combined K+V barrier, compiles but deadlocks |
| `test_fmha_v3_stage_c2.py` | DEADLOCK | 12-warp pipeline, compiles but deadlocks |
| `test_fmha_v3_12w.py` | OK n=128 only | Identity softmax baseline |
### Current Architecture (6-warp)
Warps 0-3: Softmax + Epilogue
Warps 0-3: Softmax + Epilogue (row_max, row_sum, P store, O rescale, final normalize)
Warp 4: MMA (QK, PV)
Warp 5: TMA (Q/K/V load)
### Target Architecture (12-warp, production)
Warps 0-3: Softmax, Warps 4-7: Correction, Warp 8: MMA, Warp 9: TMA, Warp 10: Epilogue, Warp 11: Empty
### CuTeDSL Constraints (hard-won)
1. `vectorize=True` loops: ONLY load/store/print
2. `.reduce(cute.ReductionOp.MAX)`: reduces ENTIRE C-fragment to scalar — global max, not per-row
3. `cute.arch.fmax`: impure for vectorizer — use plain `range()` loop
4. `tBgK`/`tVgV` have 4 modes after tma_partition — (None,0,None,0) keeps mode 2 (KV tiles) free, [None, kt] indexes it
5. `tBgK[(None, 0, None, 0)]` hardcodes GMEM iteration to tile 0
6. `softmax_done_bar` NamedBarrier is reusable across tiles
### Remaining for C' (Production Stage C)
1. Fix multi-tile TMA — combined K+V barrier or kh.count // 2
2. Fix runtime deadlock in example2 (acc_pipe + final_o_bar sync)
3. Cross-warp reduction for row_max and row_sum
4. Correction warps for multi-tile KV (online O rescale in TMEM)
5. 12-warp layout with separate softmax/correction/epilogue warps
```
### TMEM Layout
Col 0-127: S (QK acc, 128 FP32) | Col 32-95: P (64 FP32) | Col 128+: O (PV acc, 64 FP32)
```
Col 0-31: S (QK acc, 128 FP32 via Ld32x32bOp Repetition(32))
Col 32-95: P (64 FP32 via St32x32bOp Repetition(32), register bridge BF16 view)
Col 128+: O (PV acc, 64 FP32, rescale via Ld32x32bOp Repetition(16))
```
### Remaining for Multi-Tile Production
1. **Fix TMEM layout mismatch** — replace hand-constructed atom round-trips with correction_epilog pattern
2. **Pipeline state cycling for n≥384** — kv_stage=2 can only buffer 2 tiles
3. **12-warp layout** — separate softmax/correction/epilogue warps
4. **O rescale for kt > 0** — must also use paired atoms or correction_epilog
---
## CuTeDSL Constraints (hard-won)
1. **`vectorize=True` loops: ONLY load/store/print** — no fmax, no cmpf, no inner loops, no carry
2. **`.reduce(cute.ReductionOp.MAX)`:** reduces ENTIRE C-fragment to scalar — global max, not per-row
3. **`cute.arch.fmax`:** impure for vectorizer — use plain `range()` loop
4. **TMA partition tensors have 4 modes:** `(((64,128),1), ?, KV_tiles, ?)``(None,0,None,0)` keeps mode 2 (KV tiles) free, `[None, kt]` indexes it
5. **`tBgK[(None, None, 0, 0)]` pins mode 2 to 0** — silently reads tile 0 forever. Use `(None,0,None,0)` instead.
6. **`softmax_done_bar` NamedBarrier is reusable** across tiles
7. **Hand-constructed TMEM atoms corrupt data on round-trip:** `Ld32x32bOp` + `St32x32bOp` built independently introduce ~3% error. Use `get_tmem_load_op` + `get_smem_store_op` paired atoms for one-way trips.
8. **CuTeDSL region isolation:** `flat_divide` and `tma_partition` can't be called inside `if warp_idx` blocks. Do partitioning outside `if` blocks or in regular (non-`@cute.kernel`) helper functions.
9. **`composition` vs `logical_divide`:** Both re-tile a tensor, but produce different layouts. The CUTLASS `correction_rescale` uses `composition`, `correction_epilog` uses `logical_divide`. The copy atoms must match the tensor layout they were created with.
---
@@ -389,6 +389,8 @@ Col 0-127: S (QK acc, 128 FP32) | Col 32-95: P (64 FP32) | Col 128+: O (PV acc,
6. **First PV ACCUMULATE=False.** Otherwise adds uninitialized TMEM to output.
7. **FMHA P store uses QK C-fragment composition, NOT PV A-fragment.** Two aliases, same TMEM.
8. **Register bridge: FP32 backing (store partition) + BF16 view (QK-load layout).** Do not skip this.
9. **PRINT THE SHAPES. ALWAYS.** Reasoning about TMEM layouts without evidence is how we waste days.
10. **Never assume TMEM round-trips are safe.** Verify with NO-OP tests before adding logic.
---
@@ -400,3 +402,4 @@ Col 0-127: S (QK acc, 128 FP32) | Col 32-95: P (64 FP32) | Col 128+: O (PV acc,
- Model: `/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4`
- vLLM repo: `/root/dsv4-nvfp4-workspace/vllm` (modified for Blackwell)
- CUTLASS FMHA reference: `/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py`
- Local CUTLASS clone: `/home/openclaw/dev/cutlass`