diff --git a/CURRENT_ISSUE.md b/CURRENT_ISSUE.md index d873182f..bf6e164e 100644 --- a/CURRENT_ISSUE.md +++ b/CURRENT_ISSUE.md @@ -1,36 +1,43 @@ -# CURRENT_ISSUE.md — FMHA Raw CUDA Stage D +# CURRENT_ISSUE.md — FMHA 6-Warp Specialization -## Status: HD=16/64/128/256 ALL PASS ✅ (cos 0.999997+) +## Status: Milestone 1 COMPLETE ✅ (cos 0.999997+ at HD=16/64/128/256) ### What works: -- **HD=16**: Full pipeline QK(SS) → softmax → PV(SS, N=16) → epilogue. **Cosine 0.999998**. -- **HD=64**: Full pipeline with 4 N=16 PV sub-tiles. **Cosine 0.999997**. -- **HD=128**: Full pipeline with 8 N=16 PV sub-tiles. **Cosine 0.999997**. -- **HD=256**: Full pipeline with 16 N=16 PV sub-tiles. **Cosine 0.999997**. +- **6-warp kernel**: Warps 0-3 softmax/epilogue, Warp 4 MMA, Warp 5 data staging +- **All HD values**: HD=16/64/128/256 pass with cos 0.999997+ +- **Warp role separation**: MMA and data loading on separate warps +- **CTA-wide sync**: __syncthreads() between phases -### Key architectural decisions: -1. **PV via SS MMA with N=16 sub-tiles** (NOT N=64 or N=128). The `tcgen05.mma` with `make_idesc(128, N)` for N≠16,128 has a Layout D bug that skips TMEM columns (8 missing columns at N=64: cols 32-35 and 48-51). -2. **Q/K loaded one K-tile at a time** (reusing single SMEM buffer). This keeps SMEM at ~25KB regardless of HD. -3. **TMEM offset `tb + n*16`** for each PV N-sub-tile. The MMA with N=16 writes 16 columns starting at the C operand address. -4. **MMA scale is 1.0** for both QK and PV (NOT 0.5 as initially assumed). +### Architecture: +``` +Warp 0-3 (tid 0-127): Softmax + correction + epilogue + - Read S from TMEM → softmax → write P to SMEM + - After PV: read O from TMEM → BF16 → GMEM + - T=1 decode: only warp 0 processes row 0 +Warp 4 (tid 128-159): MMA + - tcgen05.mma SS for QK (N=128) and PV (N=16 sub-tiles) + - TMEM alloc/dealloc +Warp 5 (tid 160-191): Data staging + - Load Q/K/V from GMEM to SMEM (canonical layout) + - Fill sPk from s_p_vals +``` + +### Next milestones: +1. **TMA loads** (Milestone 2): Replace direct GMEM reads with cp.async.bulk.tensor + - Requires CUtensorMap creation on host + - mbarrier synchronization +2. **Pipeline overlap** (Milestone 3): Double-buffer K/V loads + - Load next K/V while computing current QK + - mbarrier producer-consumer sync between warp 5 and warp 4 +3. **Multi-row softmax** (Milestone 4): Process all 128 rows (prefill T>1) +4. **Multi-head launch** (Milestone 5): grid=(1, n_h, batch) +5. **Production integration** (Milestone 6): Hook into production.py ### Files: -- `tests/unit/test_fmha_gen.cu` — Generalized kernel + test (parameterized on HD_VAL) -- `tests/unit/test_fmha_hd{16,64,128,256}_gen.cu` — Wrapper files -- `tests/unit/test_fmha_hd64_n16_v2.cu` — Standalone HD=64 test (first to pass) -- `tests/unit/test_tmem_zero_pv.cu` — TMEM Layout D N=64 bug proof -- `tests/unit/test_tmem_all_lanes.cu` — All-lane TMEM dump for N=64 - -### Next steps: -1. **Production kernel**: Extract the generalized kernel into `fmha_sm100_tc.cuh` with proper 6-warp specialization -2. **Prefill T>1**: Fill all 128 rows of sPk (not just row 0), enable multi-row softmax -3. **Multi-head**: Per-head launch or head-packed M dimension -4. **CUDA graph**: Integration with production.py -5. **Benchmarking**: Measure actual throughput vs CuTeDSL kernel +- `dsv4/kernels/attention/fmha_6warp.cuh` — 6-warp kernel +- `tests/unit/test_fmha_6warp.cu` — Test harness +- `tests/unit/test_fmha_6warp_hd{16,64,128,256}.cu` — HD-specific wrappers ### Layout D N=64 Bug (documented for NVIDIA): -- `tcgen05.mma.cta_group::1.kind::f16` with `make_idesc(128, 64)` writes to only 56 out of 64 expected TMEM columns -- Missing columns: 32, 33, 34, 35, 48, 49, 50, 51 -- These columns contain zero after the MMA, causing output positions d=32-35 and d=48-51 to be wrong -- Workaround: use N=16 sub-tiles (4 × make_idesc(128, 16)) with different TMEM offsets -- This bug does NOT affect N=16 or N=128 — only intermediate N values +- tcgen05.mma with make_idesc(128, 64) skips TMEM cols 32-35, 48-51 +- Workaround: N=16 sub-tiles with TMEM offset n*16