diff --git a/CURRENT_ISSUE.md b/CURRENT_ISSUE.md
index efd44612..a9257af3 100644
--- a/CURRENT_ISSUE.md
+++ b/CURRENT_ISSUE.md
@@ -1,12 +1,17 @@
 # CURRENT_ISSUE.md — FMHA 6-Warp Specialization
 
-## Status: Milestone 1 COMPLETE ✅ (cos 0.999997+ at HD=16/64/128/256)
+## Status: Milestone 5 COMPLETE ✅ (multi-head grid launch with MHA/MQA/batch)
 
 ### What works:
 - **6-warp kernel**: Warps 0-3 softmax/epilogue, Warp 4 MMA, Warp 5 data staging
 - **All HD values**: HD=16/64/128/256 pass with cos 0.999997+
 - **Warp role separation**: MMA and data loading on separate warps
 - **CTA-wide sync**: __syncthreads() between phases
+- **Multi-head grid launch**: grid=(1, n_h, batch), each CTA handles one head
+- **MQA**: k_head_stride=0 / v_head_stride=0 for shared KV heads
+- **Batched**: blockIdx.z for batch dimension
+- **LSE output**: per-row LSE for multi-segment KV merge
+- **FmhaParams struct**: stride-based tensor addressing, future-proof for GQA
 
 ### Architecture:
 ```
@@ -26,17 +31,23 @@ Warp 5 (tid 160-191): Data staging
 1. **TMA loads** (Milestone 2): Replace direct GMEM reads with cp.async.bulk.tensor
    - Requires CUtensorMap creation on host
    - mbarrier synchronization
+   - BLOCKED: cuTensorMapEncodeTiled 2D/3D/5D returns INVALID_VALUE on B200 driver v580.126.20
+   - Alternative: Study CuTeDSL's TMA descriptor creation source code
 2. **Pipeline overlap** (Milestone 3): Double-buffer K/V loads
    - Load next K/V while computing current QK
    - mbarrier producer-consumer sync between warp 5 and warp 4
+   - Depends on TMA loads (Milestone 2)
 3. **Multi-row softmax** (Milestone 4): Process all 128 rows (prefill T>1)
-4. **Multi-head launch** (Milestone 5): grid=(1, n_h, batch)
+   - All 4 softmax warps process rows in parallel
+   - Warp w handles rows [w*32, (w+1)*32) ∩ [0, T)
+4. ~~**Multi-head launch** (Milestone 5): grid=(1, n_h, batch)~~ ✅ DONE
 5. **Production integration** (Milestone 6): Hook into production.py
 
 ### Files:
-- `dsv4/kernels/attention/fmha_6warp.cuh` — 6-warp kernel
-- `tests/unit/test_fmha_6warp.cu` — Test harness
-- `tests/unit/test_fmha_6warp_hd{16,64,128,256}.cu` — HD-specific wrappers
+- `dsv4/kernels/attention/fmha_6warp.cuh` — 6-warp kernel (single-head)
+- `dsv4/kernels/attention/fmha_6warp_multihead.cuh` — Multi-head grid launch kernel
+- `tests/unit/test_fmha_6warp_multihead.cu` — Multi-head test harness
+- `tests/unit/test_fmha_6warp_multihead_hd{16,64,128,256}.cu` — HD-specific wrappers
 
 ### Layout D N=64 Bug (documented for NVIDIA):
 - tcgen05.mma with make_idesc(128, 64) skips TMEM cols 32-35, 48-51
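
Note on the MQA and `FmhaParams` entries under "What works": the diff does not show the struct itself, so the following is only a host-side sketch of how stride-based addressing could make MQA fall out of a zero head stride. Apart from `k_head_stride` and `v_head_stride`, every name here (`FmhaParamsSketch`, the `q_*` and `*_batch_stride` fields, `base_offset`, `q_base_offset`, `k_base_offset`) is hypothetical, and strides are assumed to be in elements.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical subset of a stride-based parameter struct. Only
// k_head_stride / v_head_stride are named in the notes; the remaining
// fields are assumptions made for illustration. Strides are in elements.
struct FmhaParamsSketch {
    std::int64_t q_batch_stride, q_head_stride;
    std::int64_t k_batch_stride, k_head_stride;
    std::int64_t v_batch_stride, v_head_stride;
};

// Element offset of the tile owned by one CTA. In a grid=(1, n_h, batch)
// launch, head would come from blockIdx.y and batch from blockIdx.z.
inline std::int64_t base_offset(std::int64_t batch_stride,
                                std::int64_t head_stride,
                                int batch, int head) {
    assert(batch >= 0 && head >= 0);
    return batch * batch_stride + head * head_stride;
}

inline std::int64_t q_base_offset(const FmhaParamsSketch& p, int batch, int head) {
    return base_offset(p.q_batch_stride, p.q_head_stride, batch, head);
}

// With k_head_stride == 0 every query head resolves to the same K tile,
// which is the MQA case; a nonzero stride gives ordinary MHA.
inline std::int64_t k_base_offset(const FmhaParamsSketch& p, int batch, int head) {
    return base_offset(p.k_batch_stride, p.k_head_stride, batch, head);
}
```

The point of the design is that the kernel body needs no MQA branch: the host selects the sharing pattern purely through the stride values it passes in.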
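
Note on the "LSE output" entry: the diff says the per-row LSE exists for multi-segment KV merge but does not show the merge. A minimal sketch of the standard log-sum-exp combination is given below, under the assumptions that each segment produces a softmax-normalized output row and that LSE is the natural log of the sum of exponentiated scores over that segment. `RowPartial` and `merge_partials` are hypothetical names, and the kernel may use a different log base or scaling.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical per-row result for one KV segment.
struct RowPartial {
    std::vector<float> out;  // output row, normalized within the segment
    float lse;               // ln(sum_j exp(score_j)) over the segment (assumed natural log)
};

// Combine two segments that cover disjoint KV ranges for the same query row.
// Sketch only: it does not handle the case where both lse values are -inf.
inline RowPartial merge_partials(const RowPartial& a, const RowPartial& b) {
    assert(a.out.size() == b.out.size());
    // Subtract the larger LSE before exponentiating to avoid overflow.
    const float m = std::max(a.lse, b.lse);
    const float lse = m + std::log(std::exp(a.lse - m) + std::exp(b.lse - m));
    // Each segment is reweighted by its share of the total softmax mass.
    const float wa = std::exp(a.lse - lse);
    const float wb = std::exp(b.lse - lse);
    RowPartial r;
    r.lse = lse;
    r.out.resize(a.out.size());
    for (std::size_t i = 0; i < a.out.size(); ++i) {
        r.out[i] = wa * a.out[i] + wb * b.out[i];
    }
    return r;
}
```

Because the merge is associative, more than two segments can be folded in any order by applying the same function repeatedly.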