README + MEMORY: update Stage C status to single-tile only, document multi-tile blocker

- Stage C works for n=128 (0.993) but multi-tile (n>128) is broken
- Root cause: tBgK slice hardcodes GMEM iteration to tile 0
- CuTeDSL TMA copy doesn't accept Python int as tile index
- Mike's combined K+V barrier fix compiles but deadlocks at runtime
- Fallback: kh.count // 2 (untested)
This commit is contained in:
2026-05-22 16:32:31 +00:00
parent 3ddee26eff
commit 4fb606d41a

View File

@@ -68,14 +68,14 @@ Summary
---
## Status (May 22, 2026 — 09:40 UTC)
## Status (May 22, 2026 — 16:30 UTC)
| Stage | Status | Description |
|-------|--------|-------------|
| A | ✅ COMPLETE | Q@K^T via tcgen05.mma → TMEM → GMEM |
| B | ✅ COMPLETE | QK → identity softmax → P@V pipeline (TMEM alias, KV-tile interleaving) |
| C | ✅ WORKING | Real online softmax: row_max (fmax), exp2 scaling, P store, row_sum, O normalization. Cosine 0.993-0.996 |
| C' | 🔨 NEXT | Cross-warp reduction, correction warps, 12-warp production pipeline, multi-tile KV |
| C | ⚠️ SINGLE-TILE ONLY | Real online softmax works for n=128 (cosine 0.993-0.996). **Multi-tile (n>128) broken.** |
| C' | 🔨 IN PROGRESS | Multi-tile TMA indexing fix + correction warps. See below. |
| D | TODO | Full decode attention: paged KV cache, multi-query, causal mask |
| E | TODO | Production kernel: extract into dsv4/kernels/attention/, PyTorch custom op, vLLM bridge |
@@ -156,60 +156,63 @@ dsv4/
---
## Stage C: Online Softmax — WORKING
## Stage C: Online Softmax — SINGLE-TILE ONLY
### What We Have
**Working real softmax** in `test_fmha_v3_stage_c_full.py`: cosine 0.9930.996 across 3 seeds.
**Working real softmax** for single KV tile (n=128) in `test_fmha_v3_stage_c_full.py`: cosine 0.993-0.996.
**Multi-tile (n>128) is broken** — see blocker below.
### Multi-Tile Blocker: TMA GMEM Tile Indexing
The original TMA partition slices `tBgK` with `(None, 0, None, 0)` which **hardcodes the GMEM iteration dimension to tile 0**. This means TMA always loads K/V from the first 128 tokens regardless of kt. Output is identical for all n>128.
**Why you can't just index with kt:** CuTeDSL's TMA copy API accepts pipeline state values (like `kh.count`) as TMA coordinates but does NOT accept Python int from `range()`. Indexing with kt fails at operation creation.
**Fix (Mike):** Combined K+V barrier — one `acquire_and_advance` per kt, two cute.copy calls sharing `kvh.barrier`. With no interleaving, `kvh.count` naturally equals kt and stays a first-class pipeline state value. See `fmha_v3_stage_c_example2.py`.
**Current status of fix:** Compiles but deadlocks at runtime (even n=128). The 3-way sync between `acc_pipe`, `softmax_done_bar`, and `final_o_bar` needs debugging. Fallback: `kh.count // 2` in the original interleaved kernel (CuTeDSL Int32 overloads `__floordiv__` in recent versions).
### Files
| File | Status | Notes |
|------|--------|-------|
| `test_fmha_v3_stage_c_full.py` | OK n=128 only | Working real softmax + O normalization |
| `fmha_v3_stage_c_example1.py` | BROKEN multi-tile | First fix attempt, TMA still loads tile 0 |
| `fmha_v3_stage_c_example2.py` | DEADLOCK | Combined K+V barrier, compiles but deadlocks |
| `test_fmha_v3_stage_c2.py` | DEADLOCK | 12-warp pipeline, compiles but deadlocks |
| `test_fmha_v3_12w.py` | OK n=128 only | Identity softmax baseline |
### Current Architecture (6-warp)
Warps 0-3: Softmax + Epilogue — load S, real softmax, P store, O normalize, epilogue
Warp 4: MMA (QK→S, PV→O)
Warps 0-3: Softmax + Epilogue
Warp 4: MMA (QK, PV)
Warp 5: TMA (Q/K/V load)
### Target Architecture (12-warp, production)
Warps 0-3: Softmax — S→softmax→P, broadcast vec=[old_max, new_max]
Warps 4-7: Correction — O rescale (TMEM), final normalization, SMEM write
Warp 8: MMA — QK→S, PV→O with pipeline chaining
Warp 9: TMA — Q/K/V load
Warp 10: Epilogue — O SMEM→GMEM via TMA
Warp 11: Empty — tmem dealloc mbar init
Pipeline chain: MMA → Softmax → Correction → Epilogue (plus MMA → Correction)
Warps 0-3: Softmax, Warps 4-7: Correction, Warp 8: MMA, Warp 9: TMA, Warp 10: Epilogue, Warp 11: Empty
### CuTeDSL Constraints (hard-won)
1. `vectorize=True` loops: ONLY load/store/print — no fmax, no cmpf, no inner loops, no carry
2. `.reduce(cute.ReductionOp.MAX)`: reduces ENTIRE C-fragment to scalar — global max, not per-row. Use `cute.arch.fmax` element-wise instead
3. Dynamic control flow: variables need initial values BEFORE the flow starts
4. `cute.arch.fmax`: impure for vectorizer — use plain `range()` loop
5. Carry variables (row_max, row_sum): cannot use `vectorize=True`
1. `vectorize=True` loops: ONLY load/store/print
2. `.reduce(cute.ReductionOp.MAX)`: reduces ENTIRE C-fragment to scalar — global max, not per-row
3. `cute.arch.fmax`: impure for vectorizer — use plain `range()` loop
4. TMA cute.copy accepts pipeline state values as coordinates but NOT Python int
5. `tBgK[(None, 0, None, 0)]` hardcodes GMEM iteration to tile 0
6. `softmax_done_bar` NamedBarrier is reusable across tiles
### Remaining for C' (Production Stage C)
1. Cross-warp reduction for row_max and row_sum
2. Correction warps for multi-tile KV (online O rescale in TMEM)
3. 12-warp layout with separate softmax/correction/epilogue warps
4. Per-row O normalization
1. Fix multi-tile TMA — combined K+V barrier or kh.count // 2
2. Fix runtime deadlock in example2 (acc_pipe + final_o_bar sync)
3. Cross-warp reduction for row_max and row_sum
4. Correction warps for multi-tile KV (online O rescale in TMEM)
5. 12-warp layout with separate softmax/correction/epilogue warps
### TMEM Layout
Col 0-127: S (QK acc, 128 FP32) | Col 32-95: P (Softmax, 64 FP32) | Col 128+: O (PV acc, 64 FP32)
Row_max/row_sum are per-thread FP32 scalars. Correction warps will use TMEM-backed vec buffer.
---
## Stage E: Production Kernel Extraction
When ready, extract from `test_fmha_v3.py``dsv4/kernels/attention/fmha.py`:
1. Clean `FmhaKernel` class with `@cute.jit __call__`, no hardcoded dimensions
2. Add real softmax (Stage C)
3. Add paged KV cache (Stage D)
5. Wrap as `torch.library.custom_op` in `dsv4/ops/`
6. Integrate with vLLM
Col 0-127: S (QK acc, 128 FP32) | Col 32-95: P (64 FP32) | Col 128+: O (PV acc, 64 FP32)
---