diff --git a/README.md b/README.md
index b0a335be..63acbe7b 100644
--- a/README.md
+++ b/README.md
@@ -68,13 +68,14 @@ Summary
 
 ---
 
-## Status (May 21, 2026 — 17:30 UTC)
+## Status (May 22, 2026 — 09:40 UTC)
 
 | Stage | Status | Description |
 |-------|--------|-------------|
 | A | ✅ COMPLETE | Q@K^T via tcgen05.mma → TMEM → GMEM |
 | B | ✅ COMPLETE | QK → identity softmax → P@V pipeline (TMEM alias, KV-tile interleaving) |
-| C | 🔨 IN PROGRESS | Real softmax: row max, exp, rescale, row sum (kernel written, needs test harness) |
+| C | ✅ WORKING | Real online softmax: row_max (fmax), exp2 scaling, P store, row_sum, O normalization. Cosine 0.993-0.996 |
+| C' | 🔨 NEXT | Cross-warp reduction, correction warps, 12-warp production pipeline, multi-tile KV |
 | D | TODO | Full decode attention: paged KV cache, multi-query, causal mask |
 | E | TODO | Production kernel: extract into dsv4/kernels/attention/, PyTorch custom op, vLLM bridge |
 
@@ -126,8 +127,10 @@ dsv4/
 
 | File | Stage | Status |
 |------|-------|--------|
-| `test_fmha_v3.py` | A+B | ✅ Full QK→softmax→PV, cosine 0.999999 |
-| `test_fmha_v3_softmax.py` | C | 🔨 Online softmax kernel (needs test harness) |
+| `test_fmha_v3.py` | A+B | ✅ Full QK→identity softmax→PV, cosine 0.999999 |
+| `test_fmha_v3_12w.py` | A+B | ✅ 12-warp QK→PV, cosine 0.999999 |
+| `test_fmha_v3_stage_c_full.py` | C | ✅ Real online softmax + O normalization, cosine 0.993-0.996 |
+| `test_fmha_v3_stage_c_min.py` | C | 🔨 Early 12-warp pipeline (broken pipeline state) |
 | `test_pv64_with_softmax.py` | B | ✅ (128,64) PV, single AB pipeline |
 | `test_128_128_vdiag.py` | A+B | ✅ (128,128) PV baseline |
 | `test_qkonly.py` | A | ✅ QK with split Q/KV pipelines |
@@ -153,48 +156,49 @@ dsv4/
 
 ---
 
-## Stage C: Online Softmax
+## Stage C: Online Softmax — WORKING
 
 ### What We Have
 
-Identity softmax in `test_fmha_v3.py`: load S FP32 → convert BF16 → store P. Proves TMEM pipeline works.
+**Working real softmax** in `test_fmha_v3_stage_c_full.py`: cosine 0.993–0.996 across 3 seeds.
 
-### What We Are Building
+### Current Architecture (6-warp)
 
-Online softmax in `test_fmha_v3_softmax.py` (kernel written, no test runner yet):
+Warps 0-3: Softmax + Epilogue — load S, real softmax, P store, O normalize, epilogue
+Warp 4: MMA (QK→S, PV→O)
+Warp 5: TMA (Q/K/V load)
 
-```
-For each KV tile:
-  1. QK → S (FP32 in TMEM)
-  2. tile_max = max(S[j,:])
-  3. new_max = max(old_max, tile_max)
-  4. O *= exp(old_max - new_max) ← TMEM rescale
-  5. P = exp2((S - new_max) * scale) ← exp2 with 1/sqrt(d) * log2(e)
-  6. Store P to TMEM (FMHA pattern)
-  7. row_sum = row_sum * exp(old_max - new_max) + sum(P)
-  8. PV: O += P @ V
-After all tiles:
-  9. O /= row_sum ← final TMEM normalization
-```
+### Target Architecture (12-warp, production)
 
-### Key Implementation Details
+Warps 0-3: Softmax — S→softmax→P, broadcast vec=[old_max, new_max]
+Warps 4-7: Correction — O rescale (TMEM), final normalization, SMEM write
+Warp 8: MMA — QK→S, PV→O with pipeline chaining
+Warp 9: TMA — Q/K/V load
+Warp 10: Epilogue — O SMEM→GMEM via TMA
+Warp 11: Empty — tmem dealloc mbar init
 
-- **Row max:** `tTMEM_LOADrS.load().reduce(cute.ReductionOp.MAX, row_max, 0)` per tile
-- **O rescale:** Load O from TMEM, multiply by `exp2(old_max - new_max)`, store back (16-col tiles via `Ld32x32b/St32x32b`)
-- **P computation:** `exp2((S - row_max) * scale)` where `scale = 1/sqrt(HEAD_DIM) * log2(e)`
-- **Row sum:** Packed `f32x2` reduction using `cute.arch.add_packed_f32x2` (4 unroll, 2-wide)
-- **Final norm:** Load O, multiply by `1/row_sum`, store (same TMEM load/store path)
+Pipeline chain: MMA → Softmax → Correction → Epilogue (plus MMA → Correction)
 
-### TMEM Layout (Current — Stage B)
+### CuTeDSL Constraints (hard-won)
 
-```
-Col: 0 32 64 96 128 192 256
-     |---- S ----|---- P ----|         |---- O ----|
-     |  QK acc   | Softmax P |  (gap)  |  PV acc   |
-     | 128 FP32  |  64 FP32  | 32 col  |  64 FP32  |
-```
+1. `vectorize=True` loops: ONLY load/store/print — no fmax, no cmpf, no inner loops, no carry
+2. `.reduce(cute.ReductionOp.MAX)`: reduces ENTIRE C-fragment to scalar — global max, not per-row. Use `cute.arch.fmax` element-wise instead
+3. Dynamic control flow: variables need initial values BEFORE the flow starts
+4. `cute.arch.fmax`: impure for vectorizer — use plain `range()` loop
+5. Carry variables (row_max, row_sum): cannot use `vectorize=True`
 
-For Stage C, row_max/row_sum are per-thread FP32 scalars (not in TMEM). Future stages may need TMEM-backed state for wider tiles.
+### Remaining for C' (Production Stage C)
+
+1. Cross-warp reduction for row_max and row_sum
+2. Correction warps for multi-tile KV (online O rescale in TMEM)
+3. 12-warp layout with separate softmax/correction/epilogue warps
+4. Per-row O normalization
+
+### TMEM Layout
+
+Col 0-127: S (QK acc, 128 FP32) | Col 32-95: P (Softmax, 64 FP32) | Col 128+: O (PV acc, 64 FP32)
+
+Row_max/row_sum are per-thread FP32 scalars. Correction warps will use TMEM-backed vec buffer.
 
 ---
 
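Reviewer note: the Stage C recurrence this diff describes (running row_max, exp2 scaling, row_sum carry, O rescale, final O normalization) can be checked on the host with a few lines of plain Python. The sketch below is not the CuTeDSL kernel and uses no names from it; `reference_attention`, `online_attention`, `cosine`, and the `tile` parameter are hypothetical helpers introduced only for illustration. It assumes a single query row, and it assumes the correction factor uses the same `1/sqrt(d) * log2(e)` scale as P, which is a choice of this sketch and not something the README states.

```python
import math
import random

LOG2E = math.log2(math.e)


def reference_attention(q, k, v):
    """Plain softmax(q . k^T / sqrt(d)) @ v for one query row."""
    d = len(q)
    scores = [sum(a * b for a, b in zip(q, row)) / math.sqrt(d) for row in k]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * row[c] for w, row in zip(weights, v)) / total
            for c in range(len(v[0]))]


def online_attention(q, k, v, tile=4):
    """Same result, computed tile by tile with carried row_max / row_sum."""
    d = len(q)
    scale = LOG2E / math.sqrt(d)  # 1/sqrt(d) * log2(e), so 2**x replaces exp
    row_max = -math.inf           # carry variables get initial values up front
    row_sum = 0.0
    out = [0.0] * len(v[0])
    for start in range(0, len(k), tile):
        k_tile = k[start:start + tile]
        v_tile = v[start:start + tile]
        s = [sum(a * b for a, b in zip(q, row)) for row in k_tile]
        new_max = row_max
        for x in s:               # element-wise running max in a plain loop
            new_max = max(new_max, x)
        # Rescale factor for everything accumulated under the old max.
        if row_max == -math.inf:
            correction = 0.0      # first tile: nothing accumulated yet
        else:
            correction = 2.0 ** ((row_max - new_max) * scale)
        p = [2.0 ** ((x - new_max) * scale) for x in s]
        row_sum = row_sum * correction + sum(p)
        out = [o * correction + sum(pj * vr[c] for pj, vr in zip(p, v_tile))
               for c, o in enumerate(out)]
        row_max = new_max
    return [o / row_sum for o in out]  # final O normalization


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


if __name__ == "__main__":
    random.seed(0)
    q = [random.gauss(0, 1) for _ in range(8)]
    k = [[random.gauss(0, 1) for _ in range(8)] for _ in range(16)]
    v = [[random.gauss(0, 1) for _ in range(4)] for _ in range(16)]
    print("cosine:", cosine(reference_attention(q, k, v), online_attention(q, k, v)))
```

In double precision the two paths should agree far more tightly than the 0.993-0.996 cosine quoted for the kernel, so this sketch checks only the recurrence and says nothing about where the kernel's remaining error comes from.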