Milestone 5 COMPLETE: multi-head FMHA grid launch verified on B200

All HD=16/64/128/256 pass across MHA (4 and 8 heads), MQA, and batched modes.
cos 0.999997+, LSE matches reference. Updated CURRENT_ISSUE.md.
2026-05-28 19:35:06 +00:00
parent 3fd302e7a0
commit adc88613fa


@@ -1,12 +1,17 @@
# CURRENT_ISSUE.md — FMHA 6-Warp Specialization
## Status: Milestone 1 COMPLETE ✅ (cos 0.999997+ at HD=16/64/128/256)
## Status: Milestone 5 COMPLETE ✅ (multi-head grid launch with MHA/MQA/batch)
### What works:
- **6-warp kernel**: Warps 0-3 softmax/epilogue, Warp 4 MMA, Warp 5 data staging
- **All HD values**: HD=16/64/128/256 pass with cos 0.999997+
- **Warp role separation**: MMA and data loading on separate warps
- **CTA-wide sync**: __syncthreads() between phases
- **Multi-head grid launch**: grid=(1, n_h, batch), each CTA handles one head
- **MQA**: k_head_stride=0 / v_head_stride=0 for shared KV heads
- **Batched**: blockIdx.z for batch dimension
- **LSE output**: per-row LSE for multi-segment KV merge
- **FmhaParams struct**: stride-based tensor addressing, future-proof for GQA
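
The stride-based addressing and the MQA trick listed above can be sketched on the host as follows. This is a minimal illustration, not the actual `FmhaParams` definition: the struct `FmhaParamsSketch`, the helper `tile_offsets`, and every field name other than `k_head_stride` and `v_head_stride` are hypothetical. `head` stands in for `blockIdx.y` and `batch` for `blockIdx.z` in a grid=(1, n_h, batch) launch.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical parameter block (element strides). Only k_head_stride and
// v_head_stride are names taken from the status notes; the rest are assumed.
struct FmhaParamsSketch {
    std::int64_t q_batch_stride, q_head_stride;
    std::int64_t k_batch_stride, k_head_stride;  // 0 => all heads share K (MQA)
    std::int64_t v_batch_stride, v_head_stride;  // 0 => all heads share V (MQA)
};

// Element offsets of the Q, K, V tiles that one CTA works on.
struct TileOffsets {
    std::int64_t q, k, v;
};

// head plays the role of blockIdx.y, batch the role of blockIdx.z.
inline TileOffsets tile_offsets(const FmhaParamsSketch& p, int head, int batch) {
    return {
        batch * p.q_batch_stride + head * p.q_head_stride,
        batch * p.k_batch_stride + head * p.k_head_stride,
        batch * p.v_batch_stride + head * p.v_head_stride,
    };
}
```

With a head stride of zero, the `head * stride` term vanishes, so every CTA in the same batch reads the same K/V tile without any branch in the kernel.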
### Architecture:
```
@@ -26,17 +31,23 @@ Warp 5 (tid 160-191): Data staging
1. **TMA loads** (Milestone 2): Replace direct GMEM reads with cp.async.bulk.tensor
- Requires CUtensorMap creation on host
- mbarrier synchronization
- BLOCKED: cuTensorMapEncodeTiled 2D/3D/5D returns INVALID_VALUE on B200 driver v580.126.20
- Alternative: Study CuTeDSL's TMA descriptor creation source code
2. **Pipeline overlap** (Milestone 3): Double-buffer K/V loads
- Load next K/V while computing current QK
- mbarrier producer-consumer sync between warp 5 and warp 4
- Depends on TMA loads (Milestone 2)
3. **Multi-row softmax** (Milestone 4): Process all 128 rows (prefill T>1)
4. **Multi-head launch** (Milestone 5): grid=(1, n_h, batch)
- All 4 softmax warps process rows in parallel
- Warp w handles rows [w*32, (w+1)*32) ∩ [0, T)
4. ~~**Multi-head launch** (Milestone 5): grid=(1, n_h, batch)~~ ✅ DONE
5. **Production integration** (Milestone 6): Hook into production.py
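
The row split described for multi-row softmax, where warp w handles rows [w*32, (w+1)*32) ∩ [0, T), can be sketched as a small host-side helper. The names `RowRange` and `warp_rows` are hypothetical and only illustrate the clamping; the kernel may compute the range differently.

```cpp
#include <algorithm>
#include <cassert>

// Half-open row range [begin, end); empty when begin == end.
struct RowRange {
    int begin;
    int end;
};

// Rows owned by softmax warp w (0..3) when only the first T rows of the
// 128-row tile are valid. Both bounds are clamped to T, which yields
// [w*32, (w+1)*32) intersected with [0, T).
inline RowRange warp_rows(int w, int T) {
    const int kRowsPerWarp = 32;
    const int begin = std::min(w * kRowsPerWarp, T);
    const int end = std::min((w + 1) * kRowsPerWarp, T);
    return {begin, end};
}
```

For decode (T=1) only warp 0 gets a non-empty range, while for a full prefill tile (T=128) each of the four softmax warps owns 32 rows.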
### Files:
- `dsv4/kernels/attention/fmha_6warp.cuh` — 6-warp kernel
- `tests/unit/test_fmha_6warp.cu` — Test harness
- `tests/unit/test_fmha_6warp_hd{16,64,128,256}.cu` — HD-specific wrappers
- `dsv4/kernels/attention/fmha_6warp.cuh` — 6-warp kernel (single-head)
- `dsv4/kernels/attention/fmha_6warp_multihead.cuh` — Multi-head grid launch kernel
- `tests/unit/test_fmha_6warp_multihead.cu` — Multi-head test harness
- `tests/unit/test_fmha_6warp_multihead_hd{16,64,128,256}.cu` — HD-specific wrappers
### Layout D N=64 Bug (documented for NVIDIA):
- tcgen05.mma with make_idesc(128, 64) skips TMEM cols 32-35, 48-51