Update CURRENT_ISSUE: 6-warp Milestone 1 complete
@@ -1,36 +1,43 @@
# CURRENT_ISSUE.md — FMHA Raw CUDA Stage D
# CURRENT_ISSUE.md — FMHA 6-Warp Specialization

## Status: HD=16/64/128/256 ALL PASS ✅ (cos 0.999997+)
## Status: Milestone 1 COMPLETE ✅ (cos 0.999997+ at HD=16/64/128/256)

### What works:
- **HD=16**: Full pipeline QK(SS) → softmax → PV(SS, N=16) → epilogue. **Cosine 0.999998**.
- **HD=64**: Full pipeline with 4 N=16 PV sub-tiles. **Cosine 0.999997**.
- **HD=128**: Full pipeline with 8 N=16 PV sub-tiles. **Cosine 0.999997**.
- **HD=256**: Full pipeline with 16 N=16 PV sub-tiles. **Cosine 0.999997**.
- **6-warp kernel**: Warps 0-3 softmax/epilogue, Warp 4 MMA, Warp 5 data staging
- **All HD values**: HD=16/64/128/256 pass with cos 0.999997+
- **Warp role separation**: MMA and data loading on separate warps
- **CTA-wide sync**: `__syncthreads()` between phases
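The "cos" figures in this list read as the cosine similarity between the kernel output and a reference result. Assuming that interpretation, the metric can be sketched on the host as follows. The helper name `cosine_similarity` and its signature are hypothetical and are not taken from the test harness.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Cosine similarity between a kernel output and a reference, accumulated in
// double so the metric itself adds as little rounding noise as possible.
// Hypothetical helper: name and signature are assumptions for illustration.
double cosine_similarity(const std::vector<float>& out,
                         const std::vector<float>& ref) {
    assert(out.size() == ref.size() && !out.empty());
    double dot = 0.0, norm_out = 0.0, norm_ref = 0.0;
    for (std::size_t i = 0; i < out.size(); ++i) {
        dot      += static_cast<double>(out[i]) * ref[i];
        norm_out += static_cast<double>(out[i]) * out[i];
        norm_ref += static_cast<double>(ref[i]) * ref[i];
    }
    // Both vectors must be non-zero for the ratio to be defined.
    assert(norm_out > 0.0 && norm_ref > 0.0);
    return dot / (std::sqrt(norm_out) * std::sqrt(norm_ref));
}
```

A pass criterion in this style would compare the returned value against a threshold such as 0.999997; the exact threshold used by the tests is not stated here.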

### Key architectural decisions:
1. **PV via SS MMA with N=16 sub-tiles** (NOT N=64 or N=128). The `tcgen05.mma` with `make_idesc(128, N)` for N≠16,128 has a Layout D bug that skips TMEM columns (8 missing columns at N=64: cols 32-35 and 48-51).
2. **Q/K loaded one K-tile at a time** (reusing single SMEM buffer). This keeps SMEM at ~25KB regardless of HD.
3. **TMEM offset `tb + n*16`** for each PV N-sub-tile. The MMA with N=16 writes 16 columns starting at the C operand address.
4. **MMA scale is 1.0** for both QK and PV (NOT 0.5 as initially assumed).
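Decisions 1 and 3 together imply HD/16 PV MMAs per tile, each writing 16 TMEM columns at `tb + n*16`. The arithmetic can be sketched on the host as below. The function names are hypothetical, and `tb` is assumed to be the base TMEM column of the output accumulator.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

constexpr int kPvSubTileN = 16;  // N per PV MMA, from decision 1

// Number of N=16 PV sub-tiles needed to cover one head dimension.
constexpr int pv_num_subtiles(int head_dim) { return head_dim / kPvSubTileN; }

// TMEM column for sub-tile n, following the `tb + n*16` rule from decision 3.
// `tb` is assumed to be the base column of the O accumulator.
constexpr std::uint32_t pv_tmem_col(std::uint32_t tb, int n) {
    return tb + static_cast<std::uint32_t>(n) * kPvSubTileN;
}

// Lists the C-operand column that each of the HD/16 PV MMAs would target.
std::vector<std::uint32_t> pv_tmem_cols(std::uint32_t tb, int head_dim) {
    assert(head_dim % kPvSubTileN == 0);
    std::vector<std::uint32_t> cols;
    for (int n = 0; n < pv_num_subtiles(head_dim); ++n)
        cols.push_back(pv_tmem_col(tb, n));
    return cols;
}
```

For HD=64 this yields four offsets (0, 16, 32, 48 relative to `tb`), matching the "4 N=16 PV sub-tiles" entry under "What works".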
### Architecture:
```
Warp 0-3 (tid 0-127): Softmax + correction + epilogue
  - Read S from TMEM → softmax → write P to SMEM
  - After PV: read O from TMEM → BF16 → GMEM
  - T=1 decode: only warp 0 processes row 0
Warp 4 (tid 128-159): MMA
  - tcgen05.mma SS for QK (N=128) and PV (N=16 sub-tiles)
  - TMEM alloc/dealloc
Warp 5 (tid 160-191): Data staging
  - Load Q/K/V from GMEM to SMEM (canonical layout)
  - Fill sPk from s_p_vals
```
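The warp layout above amounts to a branch on the warp index. A minimal sketch of that mapping follows; `WarpRole` and `role_of` are hypothetical names, and the code is plain C++ so it can be checked on the host. Inside the kernel the same expression would be evaluated on `threadIdx.x`, with whatever `__device__` qualifiers the build requires.

```cpp
#include <cassert>

enum class WarpRole { Softmax, Mma, Staging };

constexpr int kWarpSize   = 32;
constexpr int kNumWarps   = 6;
constexpr int kCtaThreads = kNumWarps * kWarpSize;  // 192 threads per CTA

// Maps a thread index within the CTA to its role in the 6-warp layout.
constexpr WarpRole role_of(int tid) {
    const int warp = tid / kWarpSize;
    if (warp < 4)  return WarpRole::Softmax;  // tid 0-127
    if (warp == 4) return WarpRole::Mma;      // tid 128-159
    return WarpRole::Staging;                 // tid 160-191
}

static_assert(kCtaThreads == 192, "6 warps of 32 threads");
```

Because all three roles meet at `__syncthreads()` between phases, every thread of the CTA has to reach each barrier regardless of which branch it took.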

### Next milestones:
1. **TMA loads** (Milestone 2): Replace direct GMEM reads with `cp.async.bulk.tensor`
   - Requires CUtensorMap creation on host
   - mbarrier synchronization
2. **Pipeline overlap** (Milestone 3): Double-buffer K/V loads
   - Load next K/V while computing current QK
   - mbarrier producer-consumer sync between warp 5 and warp 4
3. **Multi-row softmax** (Milestone 4): Process all 128 rows (prefill T>1)
4. **Multi-head launch** (Milestone 5): grid=(1, n_h, batch)
5. **Production integration** (Milestone 6): Hook into production.py
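Milestone 3 is not implemented yet, so the following only sketches the bookkeeping a double-buffered K/V ring typically needs. It assumes two SMEM stages and one mbarrier per stage whose phase parity flips each time the stage is reused. `StageSlot` and `slot_for` are hypothetical names.

```cpp
#include <cassert>

constexpr int kStages = 2;  // double buffering (assumed stage count)

struct StageSlot {
    int stage;   // which SMEM buffer this K/V iteration uses
    int parity;  // barrier phase parity to wait on for that buffer
};

// Iteration i uses buffer i % kStages; the parity flips every time the ring
// wraps, giving the sequence 0,0,1,1,0,0,... for two stages.
constexpr StageSlot slot_for(int kv_iter) {
    return StageSlot{kv_iter % kStages, (kv_iter / kStages) & 1};
}
```

Under these assumptions warp 5 would fill `slot_for(i + 1).stage` while warp 4 consumes `slot_for(i).stage`, so the load of the next K/V tile overlaps the current QK.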
### Files:
- `tests/unit/test_fmha_gen.cu` — Generalized kernel + test (parameterized on HD_VAL)
- `tests/unit/test_fmha_hd{16,64,128,256}_gen.cu` — Wrapper files
- `tests/unit/test_fmha_hd64_n16_v2.cu` — Standalone HD=64 test (first to pass)
- `tests/unit/test_tmem_zero_pv.cu` — TMEM Layout D N=64 bug proof
- `tests/unit/test_tmem_all_lanes.cu` — All-lane TMEM dump for N=64
### Next steps:
1. **Production kernel**: Extract the generalized kernel into `fmha_sm100_tc.cuh` with proper 6-warp specialization
2. **Prefill T>1**: Fill all 128 rows of sPk (not just row 0), enable multi-row softmax
3. **Multi-head**: Per-head launch or head-packed M dimension
4. **CUDA graph**: Integration with production.py
5. **Benchmarking**: Measure actual throughput vs CuTeDSL kernel
- `dsv4/kernels/attention/fmha_6warp.cuh` — 6-warp kernel
- `tests/unit/test_fmha_6warp.cu` — Test harness
- `tests/unit/test_fmha_6warp_hd{16,64,128,256}.cu` — HD-specific wrappers
### Layout D N=64 Bug (documented for NVIDIA):
- `tcgen05.mma.cta_group::1.kind::f16` with `make_idesc(128, 64)` writes to only 56 out of 64 expected TMEM columns
- Missing columns: 32, 33, 34, 35, 48, 49, 50, 51
- These columns contain zero after the MMA, causing output positions d=32-35 and d=48-51 to be wrong
- Workaround: use N=16 sub-tiles (4 × `make_idesc(128, 16)`) with different TMEM offsets
- This bug does NOT affect N=16 or N=128 — only intermediate N values
- `tcgen05.mma` with `make_idesc(128, 64)` skips TMEM cols 32-35, 48-51
- Workaround: N=16 sub-tiles with TMEM offset `n*16`
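One way to detect skipped columns like these is to run the MMA with inputs that make every output column non-zero, dump one TMEM row, and list the columns that are still zero. The host-side check below is a sketch under that assumption. `zero_columns` and `synthetic_n64_row` are hypothetical names, and the synthetic row only mimics the reported pattern; it is not captured hardware output.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Returns the indices of the columns in one dumped TMEM row that are exactly
// zero. Assumes the inputs were chosen so every column should be non-zero.
std::vector<int> zero_columns(const std::vector<float>& row) {
    std::vector<int> missing;
    for (std::size_t c = 0; c < row.size(); ++c)
        if (row[c] == 0.0f) missing.push_back(static_cast<int>(c));
    return missing;
}

// Synthetic 64-column row that mimics the reported N=64 pattern.
std::vector<float> synthetic_n64_row() {
    std::vector<float> row(64, 1.0f);
    for (int c : {32, 33, 34, 35, 48, 49, 50, 51})
        row[static_cast<std::size_t>(c)] = 0.0f;
    return row;
}
```

Running the same check on the four N=16 sub-tile outputs would be expected to return an empty list if the workaround covers all 64 columns.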