diff --git a/README.md b/README.md
index 63acbe7b..c5552d9f 100644
--- a/README.md
+++ b/README.md
@@ -68,14 +68,14 @@ Summary
 
 ---
 
-## Status (May 22, 2026 — 09:40 UTC)
+## Status (May 22, 2026 — 16:30 UTC)
 
 | Stage | Status | Description |
 |-------|--------|-------------|
 | A | ✅ COMPLETE | Q@K^T via tcgen05.mma → TMEM → GMEM |
 | B | ✅ COMPLETE | QK → identity softmax → P@V pipeline (TMEM alias, KV-tile interleaving) |
-| C | ✅ WORKING | Real online softmax: row_max (fmax), exp2 scaling, P store, row_sum, O normalization. Cosine 0.993-0.996 |
-| C' | 🔨 NEXT | Cross-warp reduction, correction warps, 12-warp production pipeline, multi-tile KV |
+| C | ⚠️ SINGLE-TILE ONLY | Real online softmax works for n=128 (cosine 0.993-0.996). **Multi-tile (n>128) broken.** |
+| C' | 🔨 IN PROGRESS | Multi-tile TMA indexing fix + correction warps. See below. |
 | D | TODO | Full decode attention: paged KV cache, multi-query, causal mask |
 | E | TODO | Production kernel: extract into dsv4/kernels/attention/, PyTorch custom op, vLLM bridge |
 
@@ -156,60 +156,63 @@ dsv4/
 
 ---
 
-## Stage C: Online Softmax — WORKING
+## Stage C: Online Softmax — SINGLE-TILE ONLY
 
 ### What We Have
 
-**Working real softmax** in `test_fmha_v3_stage_c_full.py`: cosine 0.993–0.996 across 3 seeds.
+**Working real softmax** for single KV tile (n=128) in `test_fmha_v3_stage_c_full.py`: cosine 0.993-0.996.
+**Multi-tile (n>128) is broken** — see blocker below.
+
+### Multi-Tile Blocker: TMA GMEM Tile Indexing
+
+The original TMA partition slices `tBgK` with `(None, 0, None, 0)` which **hardcodes the GMEM iteration dimension to tile 0**. This means TMA always loads K/V from the first 128 tokens regardless of kt. Output is identical for all n>128.
+
+**Why you can't just index with kt:** CuTeDSL's TMA copy API accepts pipeline state values (like `kh.count`) as TMA coordinates but does NOT accept Python int from `range()`. Indexing with kt fails at operation creation.
+
+**Fix (Mike):** Combined K+V barrier — one `acquire_and_advance` per kt, two cute.copy calls sharing `kvh.barrier`. With no interleaving, `kvh.count` naturally equals kt and stays a first-class pipeline state value. See `fmha_v3_stage_c_example2.py`.
+
+**Current status of fix:** Compiles but deadlocks at runtime (even n=128). The 3-way sync between `acc_pipe`, `softmax_done_bar`, and `final_o_bar` needs debugging. Fallback: `kh.count // 2` in the original interleaved kernel (CuTeDSL Int32 overloads `__floordiv__` in recent versions).
+
+### Files
+
+| File | Status | Notes |
+|------|--------|-------|
+| `test_fmha_v3_stage_c_full.py` | OK n=128 only | Working real softmax + O normalization |
+| `fmha_v3_stage_c_example1.py` | BROKEN multi-tile | First fix attempt, TMA still loads tile 0 |
+| `fmha_v3_stage_c_example2.py` | DEADLOCK | Combined K+V barrier, compiles but deadlocks |
+| `test_fmha_v3_stage_c2.py` | DEADLOCK | 12-warp pipeline, compiles but deadlocks |
+| `test_fmha_v3_12w.py` | OK n=128 only | Identity softmax baseline |
 
 ### Current Architecture (6-warp)
 
-Warps 0-3: Softmax + Epilogue — load S, real softmax, P store, O normalize, epilogue
-Warp 4: MMA (QK→S, PV→O)
+Warps 0-3: Softmax + Epilogue
+Warp 4: MMA (QK, PV)
 Warp 5: TMA (Q/K/V load)
 
 ### Target Architecture (12-warp, production)
 
-Warps 0-3: Softmax — S→softmax→P, broadcast vec=[old_max, new_max]
-Warps 4-7: Correction — O rescale (TMEM), final normalization, SMEM write
-Warp 8: MMA — QK→S, PV→O with pipeline chaining
-Warp 9: TMA — Q/K/V load
-Warp 10: Epilogue — O SMEM→GMEM via TMA
-Warp 11: Empty — tmem dealloc mbar init
-
-Pipeline chain: MMA → Softmax → Correction → Epilogue (plus MMA → Correction)
+Warps 0-3: Softmax, Warps 4-7: Correction, Warp 8: MMA, Warp 9: TMA, Warp 10: Epilogue, Warp 11: Empty
 
 ### CuTeDSL Constraints (hard-won)
 
-1. `vectorize=True` loops: ONLY load/store/print — no fmax, no cmpf, no inner loops, no carry
-2. `.reduce(cute.ReductionOp.MAX)`: reduces ENTIRE C-fragment to scalar — global max, not per-row. Use `cute.arch.fmax` element-wise instead
-3. Dynamic control flow: variables need initial values BEFORE the flow starts
-4. `cute.arch.fmax`: impure for vectorizer — use plain `range()` loop
-5. Carry variables (row_max, row_sum): cannot use `vectorize=True`
+1. `vectorize=True` loops: ONLY load/store/print
+2. `.reduce(cute.ReductionOp.MAX)`: reduces ENTIRE C-fragment to scalar — global max, not per-row
+3. `cute.arch.fmax`: impure for vectorizer — use plain `range()` loop
+4. TMA cute.copy accepts pipeline state values as coordinates but NOT Python int
+5. `tBgK[(None, 0, None, 0)]` hardcodes GMEM iteration to tile 0
+6. `softmax_done_bar` NamedBarrier is reusable across tiles
 
 ### Remaining for C' (Production Stage C)
 
-1. Cross-warp reduction for row_max and row_sum
-2. Correction warps for multi-tile KV (online O rescale in TMEM)
-3. 12-warp layout with separate softmax/correction/epilogue warps
-4. Per-row O normalization
+1. Fix multi-tile TMA — combined K+V barrier or kh.count // 2
+2. Fix runtime deadlock in example2 (acc_pipe + final_o_bar sync)
+3. Cross-warp reduction for row_max and row_sum
+4. Correction warps for multi-tile KV (online O rescale in TMEM)
+5. 12-warp layout with separate softmax/correction/epilogue warps
 
 ### TMEM Layout
 
-Col 0-127: S (QK acc, 128 FP32) | Col 32-95: P (Softmax, 64 FP32) | Col 128+: O (PV acc, 64 FP32)
-
-Row_max/row_sum are per-thread FP32 scalars. Correction warps will use TMEM-backed vec buffer.
-
----
-
-## Stage E: Production Kernel Extraction
-
-When ready, extract from `test_fmha_v3.py` → `dsv4/kernels/attention/fmha.py`:
-1. Clean `FmhaKernel` class with `@cute.jit __call__`, no hardcoded dimensions
-2. Add real softmax (Stage C)
-3. Add paged KV cache (Stage D)
-5. Wrap as `torch.library.custom_op` in `dsv4/ops/`
-6. Integrate with vLLM
+Col 0-127: S (QK acc, 128 FP32) | Col 32-95: P (64 FP32) | Col 128+: O (PV acc, 64 FP32)
 
 ---
 
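
Review note: the README changes above describe two things that a host-side reference can check: the multi-tile online softmax (running `row_max` and `row_sum`, exp2 scaling, O rescale between KV tiles) and the blocker symptom that output is identical for all n>128 when every iteration loads tile 0. The following is a minimal pure-Python sketch of both, for one query row. It is not CuTeDSL and not taken from the repository; `reference_attention`, `online_attention`, `make_inputs`, and the `hardcode_tile0` flag are hypothetical names introduced here for illustration, and the usual 1/sqrt(d) score scale is omitted because the README does not mention it.

```python
import math
import random

TILE = 128  # KV tile size, matching the n=128 single-tile case in the README


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def make_inputs(n, d_k=8, d_v=4, seed=0):
    """Deterministic toy inputs: one query row, n keys, n values (hypothetical sizes)."""
    rng = random.Random(seed)
    q = [rng.uniform(-1, 1) for _ in range(d_k)]
    K = [[rng.uniform(-1, 1) for _ in range(d_k)] for _ in range(n)]
    V = [[rng.uniform(-1, 1) for _ in range(d_v)] for _ in range(n)]
    return q, K, V


def reference_attention(q, K, V):
    """Direct softmax(q K^T) V over all n keys at once."""
    scores = [dot(q, k) for k in K]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    return [sum(w[j] * V[j][c] for j in range(len(V))) / total
            for c in range(len(V[0]))]


def online_attention(q, K, V, tile=TILE, hardcode_tile0=False):
    """Online softmax across KV tiles with exp2 scaling and O rescale.

    hardcode_tile0=True mimics the blocker: every iteration reads tile 0,
    as if the GMEM tile coordinate were fixed to 0 regardless of kt.
    """
    log2e = math.log2(math.e)
    row_max = -math.inf            # running max, in log2 units
    row_sum = 0.0                  # running softmax denominator
    acc = [0.0] * len(V[0])        # unnormalized O accumulator
    num_tiles = (len(K) + tile - 1) // tile
    for kt in range(num_tiles):
        src = 0 if hardcode_tile0 else kt
        k_tile = K[src * tile:(src + 1) * tile]
        v_tile = V[src * tile:(src + 1) * tile]
        s = [dot(q, k) * log2e for k in k_tile]      # exp(x) == 2**(x * log2(e))
        new_max = max(row_max, max(s))
        correction = 2.0 ** (row_max - new_max)      # 0.0 on the first tile
        p = [2.0 ** (x - new_max) for x in s]
        row_sum = row_sum * correction + sum(p)
        acc = [a * correction + sum(p[j] * v_tile[j][c] for j in range(len(p)))
               for c, a in enumerate(acc)]
        row_max = new_max
    return [a / row_sum for a in acc]                # final O normalization


def max_abs_diff(a, b):
    return max(abs(x - y) for x, y in zip(a, b))
```

With correct tile indexing the tiled result should agree with the direct softmax for any multiple of the tile size. With `hardcode_tile0=True` the n=256 result collapses to the n=128 result, because processing the same tile twice scales the accumulator and the denominator by the same factor. A comparison of this kind against the kernel output for n=256 and n=384 would distinguish a tile-indexing bug from a rescale bug.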