Milestone 5 COMPLETE: multi-head FMHA grid launch verified on B200
HD=16/64/128/256 all pass across MHA (4+8 heads), MQA, and batched modes, with cos 0.999997+ and LSE matching the reference. Updated CURRENT_ISSUE.md.
@@ -1,12 +1,17 @@
 # CURRENT_ISSUE.md — FMHA 6-Warp Specialization

-## Status: Milestone 1 COMPLETE ✅ (cos 0.999997+ at HD=16/64/128/256)
+## Status: Milestone 5 COMPLETE ✅ (multi-head grid launch with MHA/MQA/batch)

 ### What works:
 - **6-warp kernel**: Warps 0-3 softmax/epilogue, Warp 4 MMA, Warp 5 data staging
 - **All HD values**: HD=16/64/128/256 pass with cos 0.999997+
 - **Warp role separation**: MMA and data loading on separate warps
 - **CTA-wide sync**: __syncthreads() between phases
+- **Multi-head grid launch**: grid=(1, n_h, batch), each CTA handles one head
+- **MQA**: k_head_stride=0 / v_head_stride=0 for shared KV heads
+- **Batched**: blockIdx.z for batch dimension
+- **LSE output**: per-row LSE for multi-segment KV merge
+- **FmhaParams struct**: stride-based tensor addressing, future-proof for GQA
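
The stride-based addressing in the bullets above can be sketched on the host as follows. This is a minimal illustration, not the kernel's actual code: only `k_head_stride` and `v_head_stride` are named in the notes, so the struct name `FmhaParamsSketch`, the remaining fields, the contiguous `[batch, head, token, hd]` layout, and the helpers `make_params` and `head_offsets` are all assumptions. `head` stands in for `blockIdx.y` and `batch` for `blockIdx.z`.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical parameter block (element strides, not bytes).
// Only k_head_stride / v_head_stride appear in the notes; the rest is assumed.
struct FmhaParamsSketch {
    int64_t q_batch_stride, q_head_stride;
    int64_t k_batch_stride, k_head_stride;  // 0 => every query head reads the same K head
    int64_t v_batch_stride, v_head_stride;  // 0 => every query head reads the same V head
};

struct HeadOffsets { int64_t q, k, v; };

// Assumed contiguous [batch, head, token, hd] layout; kv_heads == 1 models MQA.
inline FmhaParamsSketch make_params(int n_h, int kv_heads, int t_q, int t_kv, int hd) {
    FmhaParamsSketch p{};
    p.q_head_stride  = int64_t(t_q) * hd;
    p.q_batch_stride = p.q_head_stride * n_h;
    const int64_t kv_head = int64_t(t_kv) * hd;
    p.k_head_stride  = p.v_head_stride  = (kv_heads == 1) ? 0 : kv_head;
    p.k_batch_stride = p.v_batch_stride = kv_head * kv_heads;
    return p;
}

// Base element offsets for one CTA of a grid=(1, n_h, batch) launch.
// head plays the role of blockIdx.y, batch the role of blockIdx.z.
inline HeadOffsets head_offsets(const FmhaParamsSketch& p, int head, int batch) {
    return {
        batch * p.q_batch_stride + head * p.q_head_stride,
        batch * p.k_batch_stride + head * p.k_head_stride,
        batch * p.v_batch_stride + head * p.v_head_stride,
    };
}
```

With a head stride of zero, the MHA and MQA cases share one code path: the multiplication by `head` simply contributes nothing for K and V. GQA would presumably also need a query-head to KV-head mapping before the stride is applied; the notes do not say how that is planned.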

 ### Architecture:
 ```
@@ -26,17 +31,23 @@ Warp 5 (tid 160-191): Data staging
 1. **TMA loads** (Milestone 2): Replace direct GMEM reads with cp.async.bulk.tensor
    - Requires CUtensorMap creation on host
    - mbarrier synchronization
+   - BLOCKED: cuTensorMapEncodeTiled 2D/3D/5D returns INVALID_VALUE on B200 driver v580.126.20
+   - Alternative: Study CuTeDSL's TMA descriptor creation source code
 2. **Pipeline overlap** (Milestone 3): Double-buffer K/V loads
    - Load next K/V while computing current QK
    - mbarrier producer-consumer sync between warp 5 and warp 4
+   - Depends on TMA loads (Milestone 2)
 3. **Multi-row softmax** (Milestone 4): Process all 128 rows (prefill T>1)
-4. **Multi-head launch** (Milestone 5): grid=(1, n_h, batch)
+   - All 4 softmax warps process rows in parallel
+   - Warp w handles rows [w*32, (w+1)*32) ∩ [0, T)
+4. ~~**Multi-head launch** (Milestone 5): grid=(1, n_h, batch)~~ ✅ DONE
 5. **Production integration** (Milestone 6): Hook into production.py
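
The row split planned for Milestone 4, "Warp w handles rows [w*32, (w+1)*32) ∩ [0, T)", can be written out as a small host-side sketch. The names `RowRange`, `softmax_rows`, `kRowsPerWarp`, and `kSoftmaxWarps` are hypothetical; the 32-row width and the four softmax warps come from the notes above.

```cpp
#include <algorithm>
#include <cassert>

struct RowRange { int begin; int end; };  // half-open interval [begin, end)

constexpr int kRowsPerWarp  = 32;  // from the "[w*32, (w+1)*32)" note
constexpr int kSoftmaxWarps = 4;   // warps 0-3 do softmax/epilogue

// Rows owned by softmax warp `warp` when only T query rows are valid.
// Clamping both ends to T yields an empty range (begin == end) for idle warps.
inline RowRange softmax_rows(int warp, int T) {
    const int begin = std::min(warp * kRowsPerWarp, T);
    const int end   = std::min((warp + 1) * kRowsPerWarp, T);
    return {begin, end};
}

// Total rows covered by all softmax warps; equals T for 0 <= T <= 128.
inline int covered_rows(int T) {
    int total = 0;
    for (int w = 0; w < kSoftmaxWarps; ++w) {
        const RowRange r = softmax_rows(w, T);
        total += r.end - r.begin;
    }
    return total;
}
```

For decode (T=1) only warp 0 has work and warps 1-3 receive empty ranges, while for a full prefill tile (T=128) each warp owns exactly 32 rows.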

 ### Files:
-- `dsv4/kernels/attention/fmha_6warp.cuh` — 6-warp kernel
-- `tests/unit/test_fmha_6warp.cu` — Test harness
-- `tests/unit/test_fmha_6warp_hd{16,64,128,256}.cu` — HD-specific wrappers
+- `dsv4/kernels/attention/fmha_6warp.cuh` — 6-warp kernel (single-head)
+- `dsv4/kernels/attention/fmha_6warp_multihead.cuh` — Multi-head grid launch kernel
+- `tests/unit/test_fmha_6warp_multihead.cu` — Multi-head test harness
+- `tests/unit/test_fmha_6warp_multihead_hd{16,64,128,256}.cu` — HD-specific wrappers

 ### Layout D N=64 Bug (documented for NVIDIA):
 - tcgen05.mma with make_idesc(128, 64) skips TMEM cols 32-35, 48-51