README: update Stage C status to WORKING, add CuTeDSL constraints and target architecture
74
README.md
@@ -68,13 +68,14 @@ Summary
 
 ---
 
-## Status (May 21, 2026 — 17:30 UTC)
+## Status (May 22, 2026 — 09:40 UTC)
 
 | Stage | Status | Description |
 |-------|--------|-------------|
 | A | ✅ COMPLETE | Q@K^T via tcgen05.mma → TMEM → GMEM |
 | B | ✅ COMPLETE | QK → identity softmax → P@V pipeline (TMEM alias, KV-tile interleaving) |
-| C | 🔨 IN PROGRESS | Real softmax: row max, exp, rescale, row sum (kernel written, needs test harness) |
+| C | ✅ WORKING | Real online softmax: row_max (fmax), exp2 scaling, P store, row_sum, O normalization. Cosine 0.993-0.996 |
+| C' | 🔨 NEXT | Cross-warp reduction, correction warps, 12-warp production pipeline, multi-tile KV |
 | D | TODO | Full decode attention: paged KV cache, multi-query, causal mask |
 | E | TODO | Production kernel: extract into dsv4/kernels/attention/, PyTorch custom op, vLLM bridge |
 
@@ -126,8 +127,10 @@ dsv4/
 
 | File | Stage | Status |
 |------|-------|--------|
-| `test_fmha_v3.py` | A+B | ✅ Full QK→softmax→PV, cosine 0.999999 |
-| `test_fmha_v3_softmax.py` | C | 🔨 Online softmax kernel (needs test harness) |
+| `test_fmha_v3.py` | A+B | ✅ Full QK→identity softmax→PV, cosine 0.999999 |
+| `test_fmha_v3_12w.py` | A+B | ✅ 12-warp QK→PV, cosine 0.999999 |
+| `test_fmha_v3_stage_c_full.py` | C | ✅ Real online softmax + O normalization, cosine 0.993-0.996 |
+| `test_fmha_v3_stage_c_min.py` | C | 🔨 Early 12-warp pipeline (broken pipeline state) |
 | `test_pv64_with_softmax.py` | B | ✅ (128,64) PV, single AB pipeline |
 | `test_128_128_vdiag.py` | A+B | ✅ (128,128) PV baseline |
 | `test_qkonly.py` | A | ✅ QK with split Q/KV pipelines |
@@ -153,48 +156,49 @@ dsv4/
 
 ---
 
-## Stage C: Online Softmax
+## Stage C: Online Softmax — WORKING
 
 ### What We Have
 
-Identity softmax in `test_fmha_v3.py`: load S FP32 → convert BF16 → store P. Proves TMEM pipeline works.
+**Working real softmax** in `test_fmha_v3_stage_c_full.py`: cosine 0.993–0.996 across 3 seeds.
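
The cosine figures quoted here compare kernel output against a reference. How the test files compute the metric is not shown in this diff, so the following is only a minimal sketch, assuming a plain cosine similarity over the flattened output and reference tensors. The function name and the sample values are hypothetical.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two flat sequences of floats."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


# Hypothetical data: a reference output and a slightly perturbed kernel output.
reference = [0.25, -1.0, 0.5, 2.0]
kernel_out = [0.26, -0.98, 0.49, 2.02]
score = cosine_similarity(reference, kernel_out)
```

A score of exactly 1 means the two outputs point in the same direction. The metric ignores a uniform scale error, so it is usually paired with an absolute or relative error check.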
 
-### What We Are Building
+### Current Architecture (6-warp)
 
-Online softmax in `test_fmha_v3_softmax.py` (kernel written, no test runner yet):
+Warps 0-3: Softmax + Epilogue — load S, real softmax, P store, O normalize, epilogue
+Warp 4: MMA (QK→S, PV→O)
+Warp 5: TMA (Q/K/V load)
 
-```
-For each KV tile:
-1. QK → S (FP32 in TMEM)
-2. tile_max = max(S[j,:])
-3. new_max = max(old_max, tile_max)
-4. O *= exp(old_max - new_max) ← TMEM rescale
-5. P = exp2((S - new_max) * scale) ← exp2 with 1/sqrt(d) * log2(e)
-6. Store P to TMEM (FMHA pattern)
-7. row_sum = row_sum * exp(old_max - new_max) + sum(P)
-8. PV: O += P @ V
-After all tiles:
-9. O /= row_sum ← final TMEM normalization
-```
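
The nine steps in the removed listing can be checked on the host with a few lines of plain Python. This is a sketch under stated assumptions, not the kernel: it handles a single query row, uses Python lists in place of TMEM, and all function names are hypothetical. It also assumes the rescale factor in steps 4 and 7 uses the same `exp2` scaling as step 5 (the listing writes `exp` there), since that is what makes the running sum and accumulator agree with a direct softmax.

```python
import math


def online_softmax_attention(q_row, k_tiles, v_tiles, head_dim):
    """One query row against a sequence of KV tiles, following steps 1-9."""
    scale = (1.0 / math.sqrt(head_dim)) * math.log2(math.e)
    row_max = -math.inf
    row_sum = 0.0
    out = [0.0] * len(v_tiles[0][0])
    for k_tile, v_tile in zip(k_tiles, v_tiles):
        # 1. S = q . k for every key in the tile (unscaled logits)
        s = [sum(qi * ki for qi, ki in zip(q_row, k)) for k in k_tile]
        # 2-3. tile max folded into the running max
        new_max = max(row_max, max(s))
        # 4. rescale the accumulator (assumed: same exp2 scaling as step 5)
        correction = 2.0 ** ((row_max - new_max) * scale)
        out = [o * correction for o in out]
        # 5. unnormalized probabilities (step 6, the TMEM store, has no analogue here)
        p = [2.0 ** ((x - new_max) * scale) for x in s]
        # 7. running row sum
        row_sum = row_sum * correction + sum(p)
        # 8. O += P @ V
        for pj, v in zip(p, v_tile):
            out = [o + pj * vi for o, vi in zip(out, v)]
        row_max = new_max
    # 9. final normalization
    return [o / row_sum for o in out]


def reference_attention(q_row, keys, values, head_dim):
    """Direct softmax(q.K^T / sqrt(d)) @ V over all keys at once."""
    s = [sum(qi * ki for qi, ki in zip(q_row, k)) / math.sqrt(head_dim) for k in keys]
    m = max(s)
    e = [math.exp(x - m) for x in s]
    total = sum(e)
    return [sum(w * v[c] for w, v in zip(e, values)) / total
            for c in range(len(values[0]))]


# Hypothetical data: 4 keys split into 2 tiles, head_dim = 2.
q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0], [-1.0, 3.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
tiled = online_softmax_attention(q, [keys[:2], keys[2:]], [values[:2], values[2:]], 2)
direct = reference_attention(q, keys, values, 2)
```

Folding `log2(e)` into `scale` lets the inner loop use `exp2` only, and the first tile needs no special case because the correction factor evaluates to 0 when the running max is still negative infinity.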
+### Target Architecture (12-warp, production)
 
-### Key Implementation Details
+Warps 0-3: Softmax — S→softmax→P, broadcast vec=[old_max, new_max]
+Warps 4-7: Correction — O rescale (TMEM), final normalization, SMEM write
+Warp 8: MMA — QK→S, PV→O with pipeline chaining
+Warp 9: TMA — Q/K/V load
+Warp 10: Epilogue — O SMEM→GMEM via TMA
+Warp 11: Empty — tmem dealloc mbar init
 
-- **Row max:** `tTMEM_LOADrS.load().reduce(cute.ReductionOp.MAX, row_max, 0)` per tile
-- **O rescale:** Load O from TMEM, multiply by `exp2(old_max - new_max)`, store back (16-col tiles via `Ld32x32b/St32x32b`)
-- **P computation:** `exp2((S - row_max) * scale)` where `scale = 1/sqrt(HEAD_DIM) * log2(e)`
-- **Row sum:** Packed `f32x2` reduction using `cute.arch.add_packed_f32x2` (4 unroll, 2-wide)
-- **Final norm:** Load O, multiply by `1/row_sum`, store (same TMEM load/store path)
+Pipeline chain: MMA → Softmax → Correction → Epilogue (plus MMA → Correction)
 
-### TMEM Layout (Current — Stage B)
+### CuTeDSL Constraints (hard-won)
 
-```
-Col: 0 32 64 96 128 192 256
-|---- S ----|---- P ----| |---- O ----|
-| QK acc | Softmax P | (gap) | PV acc |
-| 128 FP32 | 64 FP32 | 32 col | 64 FP32 |
-```
+1. `vectorize=True` loops: ONLY load/store/print — no fmax, no cmpf, no inner loops, no carry
+2. `.reduce(cute.ReductionOp.MAX)`: reduces ENTIRE C-fragment to scalar — global max, not per-row. Use `cute.arch.fmax` element-wise instead
+3. Dynamic control flow: variables need initial values BEFORE the flow starts
+4. `cute.arch.fmax`: impure for vectorizer — use plain `range()` loop
+5. Carry variables (row_max, row_sum): cannot use `vectorize=True`
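
Constraints 2 and 5 come down to the difference between one scalar for the whole fragment and one carried value per row. The sketch below illustrates that difference with plain Python lists. It makes no CuTeDSL calls, and the function names and sample fragment are hypothetical.

```python
def global_max(fragment):
    """Collapse the whole 2-D fragment to one scalar, as constraint 2 describes."""
    return max(max(row) for row in fragment)


def per_row_running_max(fragment, row_max):
    """Update one carried maximum per row with a plain element-wise loop."""
    for r in range(len(fragment)):
        for c in range(len(fragment[r])):
            # The carry (row_max[r]) is read and written on every iteration.
            row_max[r] = max(row_max[r], fragment[r][c])
    return row_max


# Hypothetical 2x3 fragment of logits.
fragment = [[0.5, 2.0, -1.0], [-3.0, -0.25, -4.0]]
row_max = per_row_running_max(fragment, [float("-inf")] * 2)
```

Here the second row's own maximum is -0.25 while the global maximum is 2.0. Subtracting the global value would push every exponent in that row further below zero than necessary, which costs precision once the results are stored in a narrow format.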
 
-For Stage C, row_max/row_sum are per-thread FP32 scalars (not in TMEM). Future stages may need TMEM-backed state for wider tiles.
+### Remaining for C' (Production Stage C)
 
+1. Cross-warp reduction for row_max and row_sum
+2. Correction warps for multi-tile KV (online O rescale in TMEM)
+3. 12-warp layout with separate softmax/correction/epilogue warps
+4. Per-row O normalization
+
+### TMEM Layout
+
+Col 0-127: S (QK acc, 128 FP32) | Col 32-95: P (Softmax, 64 FP32) | Col 128+: O (PV acc, 64 FP32)
+
+Row_max/row_sum are per-thread FP32 scalars. Correction warps will use TMEM-backed vec buffer.
 
 ---
 
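
The new TMEM layout line places P inside the column range of S, which matches the "TMEM alias" noted for Stage B. A small host-side check can make such aliasing explicit. This is a sketch, not part of the repository: the names are hypothetical, the ranges are half-open, and the upper bound of O (192) is an assumption derived from "64 FP32" starting at column 128.

```python
# Column ranges as half-open intervals [start, end).
# O's end column is assumed from "Col 128+" plus "64 FP32".
TMEM_REGIONS = {"S": (0, 128), "P": (32, 96), "O": (128, 192)}


def overlaps(a, b):
    """True when two half-open column ranges share at least one column."""
    return a[0] < b[1] and b[0] < a[1]


def aliasing_pairs(regions):
    """All pairs of region names whose column ranges overlap."""
    names = sorted(regions)
    return [(x, y)
            for i, x in enumerate(names)
            for y in names[i + 1:]
            if overlaps(regions[x], regions[y])]


pairs = aliasing_pairs(TMEM_REGIONS)
```

With these ranges only P and S overlap, so P can only be written once the S columns it covers have been consumed, while O stays independent of both.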