README + MEMORY: update Stage C status to single-tile only, document multi-tile blocker

- Stage C works for n=128 (0.993) but multi-tile (n>128) is broken
- Root cause: tBgK slice hardcodes GMEM iteration to tile 0
- CuTeDSL TMA copy doesn't accept Python int as tile index
- Mike's combined K+V barrier fix compiles but deadlocks at runtime
- Fallback: kh.count // 2 (untested)
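
The root cause in the commit message can be illustrated outside CuTeDSL. The following is a minimal pure-Python sketch, not the kernel code: `TILE`, `attend`, `gather_tiles`, `buggy_index`, and `fixed_index` are hypothetical names, and a tile of 2 rows stands in for the real 128-token KV tile. It shows why a loader that always fetches tile 0 produces the same output for every sequence length: softmax attention over repeated copies of the same keys and values equals attention over a single copy.

```python
import math

TILE = 2  # stand-in for the real 128-token KV tile (assumption for illustration)


def attend(q, keys, values):
    """Reference softmax attention for a single query row."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(dim)]


def gather_tiles(rows, n, tile_index_of):
    """Concatenate the tiles a KV loop would load for sequence length n."""
    out = []
    for kt in range(n // TILE):
        t = tile_index_of(kt)
        out.extend(rows[t * TILE:(t + 1) * TILE])
    return out


def buggy_index(kt):
    return 0   # models the slice that pins the GMEM iteration to tile 0


def fixed_index(kt):
    return kt  # models a tile coordinate that follows the loop counter


q = [0.5, -1.0]
K = [[1.0, 0.0], [0.0, 1.0], [2.0, 1.0], [-1.0, 3.0], [0.5, 0.5], [1.5, -2.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]


def run(n, index_fn):
    return attend(q, gather_tiles(K, n, index_fn), gather_tiles(V, n, index_fn))


if __name__ == "__main__":
    for n in (2, 4, 6):
        print(n, "buggy:", run(n, buggy_index), "fixed:", run(n, fixed_index))
```

With the buggy index the printed output stays the same as `n` grows, which matches the symptom described in the README change; with the fixed index it follows the full sequence.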
79
README.md
@@ -68,14 +68,14 @@ Summary

 ---

-## Status (May 22, 2026 — 09:40 UTC)
+## Status (May 22, 2026 — 16:30 UTC)

 | Stage | Status | Description |
 |-------|--------|-------------|
 | A | ✅ COMPLETE | Q@K^T via tcgen05.mma → TMEM → GMEM |
 | B | ✅ COMPLETE | QK → identity softmax → P@V pipeline (TMEM alias, KV-tile interleaving) |
-| C | ✅ WORKING | Real online softmax: row_max (fmax), exp2 scaling, P store, row_sum, O normalization. Cosine 0.993-0.996 |
-| C' | 🔨 NEXT | Cross-warp reduction, correction warps, 12-warp production pipeline, multi-tile KV |
+| C | ⚠️ SINGLE-TILE ONLY | Real online softmax works for n=128 (cosine 0.993-0.996). **Multi-tile (n>128) broken.** |
+| C' | 🔨 IN PROGRESS | Multi-tile TMA indexing fix + correction warps. See below. |
 | D | TODO | Full decode attention: paged KV cache, multi-query, causal mask |
 | E | TODO | Production kernel: extract into dsv4/kernels/attention/, PyTorch custom op, vLLM bridge |

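
The Stage C row above lists the steps of online softmax (row_max, exp2 scaling, row_sum, O normalization), and the C' row adds the multi-tile rescale. As a reference for what a correct multi-tile result should look like, here is a hedged pure-Python sketch for a single query row. It is not the kernel: `online_softmax_attention`, `direct_attention`, and the tile size of 2 are illustrative assumptions, and the usual 1/sqrt(d) score scale is omitted.

```python
import math

LOG2E = math.log2(math.e)


def exp2(x):
    """2**x; returns 0.0 for x = -inf, which the first tile relies on."""
    return 2.0 ** x


def online_softmax_attention(q, K, V, tile=2):
    """Attend one query row over K/V tile by tile with a running max and sum."""
    dim = len(V[0])
    row_max = -math.inf
    row_sum = 0.0
    acc = [0.0] * dim  # unnormalized O accumulator
    for start in range(0, len(K), tile):
        k_tile = K[start:start + tile]
        v_tile = V[start:start + tile]
        # Scores moved to the log2 domain so exp2 can stand in for exp.
        s = [LOG2E * sum(a * b for a, b in zip(q, k)) for k in k_tile]
        new_max = max(row_max, max(s))
        # Correction factor for everything accumulated under the old max.
        scale = exp2(row_max - new_max)
        p = [exp2(x - new_max) for x in s]
        row_sum = row_sum * scale + sum(p)
        acc = [a * scale + sum(pi * v[d] for pi, v in zip(p, v_tile))
               for d, a in enumerate(acc)]
        row_max = new_max
    # Final O normalization by the accumulated row_sum.
    return [a / row_sum for a in acc]


def direct_attention(q, K, V):
    """Plain softmax attention over the whole sequence, for comparison."""
    s = [sum(a * b for a, b in zip(q, k)) for k in K]
    m = max(s)
    w = [math.exp(x - m) for x in s]
    total = sum(w)
    return [sum(wi * v[d] for wi, v in zip(w, V)) / total
            for d in range(len(V[0]))]


if __name__ == "__main__":
    q = [0.5, -1.0]
    K = [[1.0, 0.0], [0.0, 1.0], [2.0, 1.0], [-1.0, 3.0]]
    V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
    print(online_softmax_attention(q, K, V))
    print(direct_attention(q, K, V))
```

The two printed rows should agree to floating-point tolerance regardless of the tile size, which is the property a multi-tile fix would need to restore.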
@@ -156,60 +156,63 @@ dsv4/

 ---

-## Stage C: Online Softmax — WORKING
+## Stage C: Online Softmax — SINGLE-TILE ONLY

 ### What We Have

-**Working real softmax** in `test_fmha_v3_stage_c_full.py`: cosine 0.993–0.996 across 3 seeds.
+**Working real softmax** for single KV tile (n=128) in `test_fmha_v3_stage_c_full.py`: cosine 0.993-0.996.
+**Multi-tile (n>128) is broken** — see blocker below.
+
+### Multi-Tile Blocker: TMA GMEM Tile Indexing
+
+The original TMA partition slices `tBgK` with `(None, 0, None, 0)` which **hardcodes the GMEM iteration dimension to tile 0**. This means TMA always loads K/V from the first 128 tokens regardless of kt. Output is identical for all n>128.
+
+**Why you can't just index with kt:** CuTeDSL's TMA copy API accepts pipeline state values (like `kh.count`) as TMA coordinates but does NOT accept Python int from `range()`. Indexing with kt fails at operation creation.
+
+**Fix (Mike):** Combined K+V barrier — one `acquire_and_advance` per kt, two cute.copy calls sharing `kvh.barrier`. With no interleaving, `kvh.count` naturally equals kt and stays a first-class pipeline state value. See `fmha_v3_stage_c_example2.py`.
+
+**Current status of fix:** Compiles but deadlocks at runtime (even n=128). The 3-way sync between `acc_pipe`, `softmax_done_bar`, and `final_o_bar` needs debugging. Fallback: `kh.count // 2` in the original interleaved kernel (CuTeDSL Int32 overloads `__floordiv__` in recent versions).
+
+### Files
+
+| File | Status | Notes |
+|------|--------|-------|
+| `test_fmha_v3_stage_c_full.py` | OK n=128 only | Working real softmax + O normalization |
+| `fmha_v3_stage_c_example1.py` | BROKEN multi-tile | First fix attempt, TMA still loads tile 0 |
+| `fmha_v3_stage_c_example2.py` | DEADLOCK | Combined K+V barrier, compiles but deadlocks |
+| `test_fmha_v3_stage_c2.py` | DEADLOCK | 12-warp pipeline, compiles but deadlocks |
+| `test_fmha_v3_12w.py` | OK n=128 only | Identity softmax baseline |

 ### Current Architecture (6-warp)

-Warps 0-3: Softmax + Epilogue — load S, real softmax, P store, O normalize, epilogue
-Warp 4: MMA (QK→S, PV→O)
+Warps 0-3: Softmax + Epilogue
+Warp 4: MMA (QK, PV)
 Warp 5: TMA (Q/K/V load)

 ### Target Architecture (12-warp, production)

-Warps 0-3: Softmax — S→softmax→P, broadcast vec=[old_max, new_max]
-Warps 4-7: Correction — O rescale (TMEM), final normalization, SMEM write
-Warp 8: MMA — QK→S, PV→O with pipeline chaining
-Warp 9: TMA — Q/K/V load
-Warp 10: Epilogue — O SMEM→GMEM via TMA
-Warp 11: Empty — tmem dealloc mbar init
-
-Pipeline chain: MMA → Softmax → Correction → Epilogue (plus MMA → Correction)
+Warps 0-3: Softmax, Warps 4-7: Correction, Warp 8: MMA, Warp 9: TMA, Warp 10: Epilogue, Warp 11: Empty

 ### CuTeDSL Constraints (hard-won)

-1. `vectorize=True` loops: ONLY load/store/print — no fmax, no cmpf, no inner loops, no carry
-2. `.reduce(cute.ReductionOp.MAX)`: reduces ENTIRE C-fragment to scalar — global max, not per-row. Use `cute.arch.fmax` element-wise instead
-3. Dynamic control flow: variables need initial values BEFORE the flow starts
-4. `cute.arch.fmax`: impure for vectorizer — use plain `range()` loop
-5. Carry variables (row_max, row_sum): cannot use `vectorize=True`
+1. `vectorize=True` loops: ONLY load/store/print
+2. `.reduce(cute.ReductionOp.MAX)`: reduces ENTIRE C-fragment to scalar — global max, not per-row
+3. `cute.arch.fmax`: impure for vectorizer — use plain `range()` loop
+4. TMA cute.copy accepts pipeline state values as coordinates but NOT Python int
+5. `tBgK[(None, 0, None, 0)]` hardcodes GMEM iteration to tile 0
+6. `softmax_done_bar` NamedBarrier is reusable across tiles

 ### Remaining for C' (Production Stage C)

-1. Cross-warp reduction for row_max and row_sum
-2. Correction warps for multi-tile KV (online O rescale in TMEM)
-3. 12-warp layout with separate softmax/correction/epilogue warps
-4. Per-row O normalization
+1. Fix multi-tile TMA — combined K+V barrier or kh.count // 2
+2. Fix runtime deadlock in example2 (acc_pipe + final_o_bar sync)
+3. Cross-warp reduction for row_max and row_sum
+4. Correction warps for multi-tile KV (online O rescale in TMEM)
+5. 12-warp layout with separate softmax/correction/epilogue warps

 ### TMEM Layout

-Col 0-127: S (QK acc, 128 FP32) | Col 32-95: P (Softmax, 64 FP32) | Col 128+: O (PV acc, 64 FP32)
-
-Row_max/row_sum are per-thread FP32 scalars. Correction warps will use TMEM-backed vec buffer.
-
----
-
-## Stage E: Production Kernel Extraction
-
-When ready, extract from `test_fmha_v3.py` → `dsv4/kernels/attention/fmha.py`:
-1. Clean `FmhaKernel` class with `@cute.jit __call__`, no hardcoded dimensions
-2. Add real softmax (Stage C)
-3. Add paged KV cache (Stage D)
-5. Wrap as `torch.library.custom_op` in `dsv4/ops/`
-6. Integrate with vLLM
+Col 0-127: S (QK acc, 128 FP32) | Col 32-95: P (64 FP32) | Col 128+: O (PV acc, 64 FP32)

 ---

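
The constraint list in the diff above distinguishes a whole-fragment max reduction from the per-row max that online softmax needs. The difference can be shown with a small pure-Python sketch. It does not call CuTeDSL; `global_max`, `per_row_max`, and the fragment `S` are hypothetical names, and the inner loop only mirrors the idea of a carried `row_max` updated element by element.

```python
import math


def global_max(fragment):
    """One scalar for the whole fragment, as a full reduction would return."""
    return max(x for row in fragment for x in row)


def per_row_max(fragment):
    """One running max per row, updated element by element."""
    result = []
    for row in fragment:
        m = -math.inf
        for x in row:  # plain loop carrying m across iterations
            m = max(m, x)
        result.append(m)
    return result


S = [[0.5, 2.0, -1.0],
     [-3.0, -4.0, -2.5]]

if __name__ == "__main__":
    print(global_max(S))   # 2.0
    print(per_row_max(S))  # [2.0, -2.5]
```

Only the first row's maximum matches the global value; the second row would be shifted by the wrong reference if the scalar were used for every row.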