diff --git a/README.md b/README.md
index bd78b435..04069365 100644
--- a/README.md
+++ b/README.md
@@ -139,14 +139,13 @@ Summary

 ---

-## Status (May 22, 2026 — 16:30 UTC)
+## Status (May 23, 2026 — 02:55 UTC)

 | Stage | Status | Description |
 |-------|--------|-------------|
 | A | ✅ COMPLETE | Q@K^T via tcgen05.mma → TMEM → GMEM |
 | B | ✅ COMPLETE | QK → identity softmax → P@V pipeline (TMEM alias, KV-tile interleaving) |
-| C | ⚠️ MULTI-TILE TMA FIXED | n=128 cos 0.999998 ✅. TMA fix: n=256 loads 2 tiles. Pipeline cycling needed for n≥384. O rescale needed. |
-| C' | 🔨 IN PROGRESS | Multi-tile TMA indexing fix + correction warps. See below. |
+| C | ⚠️ SINGLE-TILE OK, MULTI-TILE 3% ERROR | n=128 cos 0.973. n=256 cos 0.793. TMEM layout mismatch between MMA C-fragment and get_tmem_load_op. See below. |
 | D | TODO | Full decode attention: paged KV cache, multi-query, causal mask |
 | E | TODO | Production kernel: extract into dsv4/kernels/attention/, PyTorch custom op, vLLM bridge |

@@ -289,93 +288,94 @@ What it does:

 ---

-## Stage C: Online Softmax — Multi-Tile In Progress

-### What We Have
+## Stage C: Online Softmax — TMEM Layout Mismatch Issue

-**Working real softmax** for single KV tile (n=128): cosine 0.999998.
-**Multi-tile TMA indexing fixed** (n=256 cosine 0.9956) — was a layout bug, NOT a JIT bug.
-**Remaining:** O rescale between tiles, pipeline state cycling for n≥384, correction warps.
+### Current Results (test_fmha_v3_stage_c.py)

-### Multi-Tile TMA Fix (RESOLVED — was a LAYOUT bug, not a JIT bug)
+| n | cos | Status |
+|---|-----|--------|
+| 128 | 0.973 | ⚠️ 3% error from TMEM layout mismatch |
+| 256 | 0.793 | ⚠️ Two TMEM round-trips compound the error |
+| 384+ | N/A | Pipeline doesn't cycle past 2 KV tiles |

-After `cpasync.tma_partition()`, the output GMEM tensor has **4 modes**: `(((64,128),1), ?, KV_tiles, ?)`.
+### Root Cause: TMEM Layout Mismatch

-**Mode 2 is the GMEM tile dimension.** Our old pre-slice `tBgK[(None, None, 0, 0)]` kept modes 0,1 free and set mode 2 to 0, so TMA always read tile 0. The bug looked like "JIT constant-folding" but was purely a layout error.
+The MMA instruction writes O to TMEM using the **C-fragment layout**. The `epilogue_tma_store` helper reads O from TMEM using `get_tmem_load_op`, which uses a **different TMEM column mapping**. When `epilogue_tma_store` reads O directly after PV (no normalize), the layout matches perfectly — **cos 0.999998** (raw PV output is correct).

-**The fix:** `(None,0,None,0)` keeps modes 0,2 free, then `[None, kt]` indexes KV tiles:
+The problem appears when we do a **TMEM round-trip** (load O → modify → store back) using hand-constructed `Ld32x32bOp/St32x32bOp` atoms:

-```python
-tBgK = tBgK[(None, 0, None, 0)]
-cute.copy(tma_k, tBgK[None, kt], ...)
+1. **NO-OP round-trip (load + store unchanged) → cos 0.973.** The hand-constructed atoms read/write using a different column mapping than `get_tmem_load_op`. The data gets "transcoded" — close but not exact.
+2. **Normalize round-trip (load → multiply by 1/row_sum → store) → cos 0.973** (with preceding NO-OP) or **cos 0.465** (without NO-OP). Without the NO-OP, `epilogue_tma_store` reads the MMA layout directly and produces garbage.
+3. **O rescale (kt > 0) + normalize → cos 0.793** at n=256. Each round-trip compounds the layout mismatch error.
+
+### Why the NO-OP Round-Trip "Fixes" It
+
+The MMA writes O in C-fragment TMEM layout. `epilogue_tma_store` reads in `get_tmem_load_op` layout. Without a round-trip, these layouts are incompatible → garbage output (cos 0.465).
+
+A NO-OP round-trip through hand-constructed atoms reads the data using the hand-constructed layout (which can read the C-fragment data) and writes it back using the hand-constructed layout. After the round-trip, the data is in the hand-constructed layout, which is close to (but not identical to) the `get_tmem_load_op` layout → 3% error (cos 0.973).
+
+### The Proper Fix: correction_epilog Pattern
+
+The CUTLASS FMHA reference uses a one-way trip for the final epilogue:
+
+```
+TMEM --get_tmem_load_op--> reg (normalize + FP32→BF16) --get_smem_store_op--> SMEM --TMA--> GMEM
 ```

-**Results after TMA fix (verified on B200, May 22):**
-- n=128: cos 0.999998 ✅
-- n=256: cos 0.71 (TMA loads 2 tiles correctly, needs O rescale for 0.9999)
-- n=512/1024: output identical to n=256 — pipeline not cycling past kv_stage=2
+This reads O using `get_tmem_load_op` (same layout as `epilogue_tma_store`) and writes directly to SMEM. No TMEM round-trip. No layout mismatch. The CUTLASS reference uses this pattern and gets correct results.
+
+**Why we can't use it yet:** The TMA store from SMEM → GMEM requires `tma_partition` / `flat_divide` which hit CuTeDSL region isolation errors when called inside `if warp_idx < self.mma_warp_id` blocks. The `epilogue_tma_store` helper works because it's a regular Python function that inlines into the same MLIR region — but it always reads from TMEM, not SMEM.
+
+**Possible solutions:**
+1. Call `epilogue_tma_store` but inject the `1/row_sum` multiply into its pipeline (requires modifying the helper or replicating it inline with the scale)
+2. Pre-compute TMA partitioning outside the `if` block and pass the partitioned tensors through the kernel interface
+3. Use the experimental `cutlass.cute.experimental.epilogue_tma_store` API which has a cleaner structure
+
+### Verified Facts
+
+- **Raw PV output is perfect:** `epilogue_tma_store` with identity op gives cos 0.999998 at n=128
+- **Softmax P values are correct:** Unnormalized P@V matches reference exactly (cos 0.999998)
+- **Online softmax computation is correct:** row_max and row_sum tracking works
+- **The ONLY issue is the TMEM round-trip for normalize/rescale**
+- **Stage A/B with identity softmax: cos 0.999999** — the pipeline works, softmax is the only addition
+
+### Architecture (6-warp, current)
+

-**Verified tensor shapes (diag prints inside @cute.kernel on B200, n=256):**
 ```
-Before (None,0,None,0) pre-slice:
- tAgQ: (((64,128),1), Int32(?), Int32(?), Int32(?)) — 4 modes
- tBgK: (((64,128),1), Int32(?), Int32(?), Int32(?)) — 4 modes
- tVgV: (((64,128),1), 1, 1, 1) — 4 modes
-
-After (None,0,None,0) pre-slice:
- tAgQ: (((64,128),1), Int32(?)) — 2 modes, mode 1 = KV tiles
- tBgK: (((64,128),1), Int32(?)) — 2 modes, mode 1 = KV tiles
- tVgV: (((64,128),1), 1) — 2 modes, mode 1 = 1 (static)
-```
-
-### Remaining for Multi-Tile
-
-1. O rescale between tiles: `O *= exp2(old_max - new_max)` — needed for n=256+ to hit 0.9999
-2. Pipeline state cycling for n≥384 (3+ tiles with 2 pipeline stages) — output identical for all n>256, meaning only 2 KV tiles are loaded
-3. Correction warps for production (separate softmax/correction/epilogue)
-4. 12-warp layout
-
-### Files
-
-| File | Status | Notes |
-|------|--------|-------|
-| `fmha_v3_stage_c_example10.py` | 🔨 CURRENT | (None,0,None,0) TMA, combined K+V pipeline, O rescale, final normalize |
-| `test_fmha_v3_stage_c_full.py` | OK n=128 | Working real softmax + O normalization |
-| `fmha_v3_stage_c_example1.py` | BROKEN multi-tile | First fix attempt, TMA still loads tile 0 |
-| `fmha_v3_stage_c_example2.py` | DEADLOCK | Combined K+V barrier, compiles but deadlocks |
-| `test_fmha_v3_stage_c2.py` | DEADLOCK | 12-warp pipeline, compiles but deadlocks |
-| `test_fmha_v3_12w.py` | OK n=128 only | Identity softmax baseline |
-
-### Current Architecture (6-warp)
-
-Warps 0-3: Softmax + Epilogue
+Warps 0-3: Softmax + Epilogue (row_max, row_sum, P store, O rescale, final normalize)
 Warp 4: MMA (QK, PV)
 Warp 5: TMA (Q/K/V load)
-
-### Target Architecture (12-warp, production)
-
-Warps 0-3: Softmax, Warps 4-7: Correction, Warp 8: MMA, Warp 9: TMA, Warp 10: Epilogue, Warp 11: Empty
-
-### CuTeDSL Constraints (hard-won)
-
-1. `vectorize=True` loops: ONLY load/store/print
-2. `.reduce(cute.ReductionOp.MAX)`: reduces ENTIRE C-fragment to scalar — global max, not per-row
-3. `cute.arch.fmax`: impure for vectorizer — use plain `range()` loop
-4. `tBgK`/`tVgV` have 4 modes after tma_partition — (None,0,None,0) keeps mode 2 (KV tiles) free, [None, kt] indexes it
-5. `tBgK[(None, 0, None, 0)]` hardcodes GMEM iteration to tile 0
-6. `softmax_done_bar` NamedBarrier is reusable across tiles
-
-### Remaining for C' (Production Stage C)
-
-1. Fix multi-tile TMA — combined K+V barrier or kh.count // 2
-2. Fix runtime deadlock in example2 (acc_pipe + final_o_bar sync)
-3. Cross-warp reduction for row_max and row_sum
-4. Correction warps for multi-tile KV (online O rescale in TMEM)
-5. 12-warp layout with separate softmax/correction/epilogue warps
+```

 ### TMEM Layout

-Col 0-127: S (QK acc, 128 FP32) | Col 32-95: P (64 FP32) | Col 128+: O (PV acc, 64 FP32)
+```
+Col 0-31: S (QK acc, 128 FP32 via Ld32x32bOp Repetition(32))
+Col 32-95: P (64 FP32 via St32x32bOp Repetition(32), register bridge BF16 view)
+Col 128+: O (PV acc, 64 FP32, rescale via Ld32x32bOp Repetition(16))
+```
+
+### Remaining for Multi-Tile Production
+
+1. **Fix TMEM layout mismatch** — replace hand-constructed atom round-trips with correction_epilog pattern
+2. **Pipeline state cycling for n≥384** — kv_stage=2 can only buffer 2 tiles
+3. **12-warp layout** — separate softmax/correction/epilogue warps
+4. **O rescale for kt > 0** — must also use paired atoms or correction_epilog
+
+---
+
+## CuTeDSL Constraints (hard-won)
+
+1. **`vectorize=True` loops: ONLY load/store/print** — no fmax, no cmpf, no inner loops, no carry
+2. **`.reduce(cute.ReductionOp.MAX)`:** reduces ENTIRE C-fragment to scalar — global max, not per-row
+3. **`cute.arch.fmax`:** impure for vectorizer — use plain `range()` loop
+4. **TMA partition tensors have 4 modes:** `(((64,128),1), ?, KV_tiles, ?)` — `(None,0,None,0)` keeps mode 2 (KV tiles) free, `[None, kt]` indexes it
+5. **`tBgK[(None, None, 0, 0)]` pins mode 2 to 0** — silently reads tile 0 forever. Use `(None,0,None,0)` instead.
+6. **`softmax_done_bar` NamedBarrier is reusable** across tiles
+7. **Hand-constructed TMEM atoms corrupt data on round-trip:** `Ld32x32bOp` + `St32x32bOp` built independently introduce ~3% error. Use `get_tmem_load_op` + `get_smem_store_op` paired atoms for one-way trips.
+8. **CuTeDSL region isolation:** `flat_divide` and `tma_partition` can't be called inside `if warp_idx` blocks. Do partitioning outside `if` blocks or in regular (non-`@cute.kernel`) helper functions.
+9. **`composition` vs `logical_divide`:** Both re-tile a tensor, but produce different layouts. The CUTLASS `correction_rescale` uses `composition`, `correction_epilog` uses `logical_divide`. The copy atoms must match the tensor layout they were created with.

 ---

@@ -389,6 +389,8 @@ Col 0-127: S (QK acc, 128 FP32) | Col 32-95: P (64 FP32) | Col 128+: O (PV acc,
 6. **First PV ACCUMULATE=False.** Otherwise adds uninitialized TMEM to output.
 7. **FMHA P store uses QK C-fragment composition, NOT PV A-fragment.** Two aliases, same TMEM.
 8. **Register bridge: FP32 backing (store partition) + BF16 view (QK-load layout).** Do not skip this.
+9. **PRINT THE SHAPES. ALWAYS.** Reasoning about TMEM layouts without evidence is how we waste days.
+10. **Never assume TMEM round-trips are safe.** Verify with NO-OP tests before adding logic.

 ---

@@ -400,3 +402,4 @@ Col 0-127: S (QK acc, 128 FP32) | Col 32-95: P (64 FP32) | Col 128+: O (PV acc,
 - Model: `/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4`
 - vLLM repo: `/root/dsv4-nvfp4-workspace/vllm` (modified for Blackwell)
 - CUTLASS FMHA reference: `/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py`
+- Local CUTLASS clone: `/home/openclaw/dev/cutlass`
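
For reference, the rescale-and-normalize arithmetic described in the Stage C notes (`row_max`/`row_sum` tracking, O rescale for `kt > 0`, final `1/row_sum` multiply) can be checked on the host. The following is a minimal NumPy sketch, not the kernel: the function names, the 128-row KV tile, and the omission of the usual `1/sqrt(d)` score scaling are assumptions made for illustration, and it says nothing about TMEM layouts. It may be useful as the reference side of a cosine check like the ones quoted above.

```python
import numpy as np

LOG2E = np.log2(np.e)  # exp(x) == exp2(x * LOG2E)


def cosine(a, b):
    """Cosine similarity between two arrays, flattened to vectors."""
    a = np.asarray(a, dtype=np.float64).ravel()
    b = np.asarray(b, dtype=np.float64).ravel()
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


def reference_attention(q, k, v):
    """Plain softmax(Q K^T) V over the full KV length (no score scaling)."""
    s = q @ k.T
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v


def online_softmax_attention(q, k, v, tile=128):
    """Same result, computed one KV tile at a time with a running max/sum."""
    m, d = q.shape[0], v.shape[1]
    row_max = np.full((m, 1), -np.inf)
    row_sum = np.zeros((m, 1))
    o = np.zeros((m, d))
    for start in range(0, k.shape[0], tile):
        k_t = k[start:start + tile]
        v_t = v[start:start + tile]
        s = (q @ k_t.T) * LOG2E                      # scores in the base-2 domain
        new_max = np.maximum(row_max, s.max(axis=1, keepdims=True))
        scale = np.exp2(row_max - new_max)           # 0 on the first tile, since row_max is -inf
        p = np.exp2(s - new_max)                     # unnormalized probabilities
        row_sum = row_sum * scale + p.sum(axis=1, keepdims=True)
        o = o * scale + p @ v_t                      # O rescale, then accumulate P@V
        row_max = new_max
    return o / row_sum                               # final normalize by 1/row_sum
```

Calling `cosine(online_softmax_attention(q, k, v), reference_attention(q, k, v))` on random inputs spanning three KV tiles should give a value very close to 1. When a kernel result falls short of that, this helps separate arithmetic errors from layout errors.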