CURRENT_ISSUE.md — FMHA 6-Warp Specialization
Status: Milestone 4 IN PROGRESS (multi-row softmax for prefill T>1)
Milestone 5 ✅ DONE — multi-head grid launch
Milestone 4: T≤32 PASSING (cos 0.999996+), T>32 BLOCKED on TMEM row read
CRITICAL BUG FIXED: Q/K SMEM canonical layout used full_d instead of local d (0..15)
What works:
- 6-warp kernel: Warps 0-3 softmax/epilogue, Warp 4 MMA, Warp 5 data staging
- All HD values: HD=16/64/128/256 pass with cos 0.999997+
- Warp role separation: MMA and data loading on separate warps
- CTA-wide sync: __syncthreads() between phases
- Multi-head grid launch: grid=(1, n_h, batch), each CTA handles one head
- MQA: k_head_stride=0 / v_head_stride=0 for shared KV heads
- Batched: blockIdx.z for batch dimension
- LSE output: per-row LSE for multi-segment KV merge
- FmhaParams struct: stride-based tensor addressing, future-proof for GQA
- Multi-row softmax T≤32: cos 0.999996+ with per-lane per-row softmax (no wmax/wsum)
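The per-row LSE output exists so that attention computed over separate KV segments can be combined afterwards. A minimal host-side sketch of the standard log-sum-exp merge for one row is below. It assumes natural-log LSE values and that each segment's output row is already normalized by its own softmax denominator; `SegmentRow` and `merge_segments` are hypothetical names, not part of the kernel.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// One attention output row from one KV segment (hypothetical type).
// o   : output row, already normalized by this segment's softmax sum
// lse : log(sum(exp(scores))) over this segment, natural log assumed
struct SegmentRow {
    std::vector<float> o;
    float lse;
};

// Merge two segment results for the same query row.
// Assumes both LSE values are finite (neither segment is empty).
inline SegmentRow merge_segments(const SegmentRow& a, const SegmentRow& b) {
    assert(a.o.size() == b.o.size());
    const float m = std::max(a.lse, b.lse);   // subtract the max for stability
    const float wa = std::exp(a.lse - m);     // relative weight of segment a
    const float wb = std::exp(b.lse - m);     // relative weight of segment b
    const float denom = wa + wb;

    SegmentRow out;
    out.o.resize(a.o.size());
    for (std::size_t i = 0; i < a.o.size(); ++i) {
        out.o[i] = (wa * a.o[i] + wb * b.o[i]) / denom;
    }
    out.lse = m + std::log(denom);            // LSE over the union of both segments
    return out;
}
```

If the kernel stores LSE in base 2 (common when the softmax uses exp2), the same structure applies with `std::exp2` and `std::log2`.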
Architecture:
Warp 0-3 (tid 0-127): Softmax + correction + epilogue
- Read S from TMEM → softmax → write P to SMEM
- After PV: read O from TMEM → BF16 → GMEM
- T=1 decode: only warp 0 processes row 0
Warp 4 (tid 128-159): MMA
- tcgen05.mma SS for QK (N=128) and PV (N=16 sub-tiles)
- TMEM alloc/dealloc
Warp 5 (tid 160-191): Data staging
- Load Q/K/V from GMEM to SMEM (canonical layout)
- Fill sPk from s_p_vals
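The tid ranges above imply a simple role dispatch on the warp index. A host-side sketch of that mapping, assuming a 192-thread CTA and 32-thread warps; `WarpRole` and `role_for_thread` are illustrative names and may not match the kernel source:

```cpp
#include <cassert>

// Roles taken from the architecture list; the enum itself is hypothetical.
enum class WarpRole { Softmax, Mma, Staging };

constexpr int kWarpSize = 32;
constexpr int kThreadsPerCta = 6 * kWarpSize;  // 192 threads, tid 0..191

// Map a thread index to its warp role:
//   warps 0-3 (tid 0-127)   -> softmax / correction / epilogue
//   warp 4    (tid 128-159) -> MMA
//   warp 5    (tid 160-191) -> data staging
inline WarpRole role_for_thread(int tid) {
    assert(tid >= 0 && tid < kThreadsPerCta);
    const int warp = tid / kWarpSize;
    if (warp < 4) return WarpRole::Softmax;
    if (warp == 4) return WarpRole::Mma;
    return WarpRole::Staging;
}
```

In device code the same branch would key on `threadIdx.x / 32`, which keeps each branch warp-uniform.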
Next milestones:
- TMA loads (Milestone 2): Replace direct GMEM reads with cp.async.bulk.tensor
- Requires CUtensorMap creation on host
- mbarrier synchronization
- BLOCKED: cuTensorMapEncodeTiled 2D/3D/5D returns INVALID_VALUE on B200 driver v580.126.20
- Alternative: Study CuTeDSL's TMA descriptor creation source code
- Pipeline overlap (Milestone 3): Double-buffer K/V loads
- Load next K/V while computing current QK
- mbarrier producer-consumer sync between warp 5 and warp 4
- Depends on TMA loads (Milestone 2)
- Multi-row softmax (Milestone 4): Process all 128 rows (prefill T>1) 🚧 IN PROGRESS
- T≤32: WORKING — warp 0, lane l handles row l, 32x32b.x8 TMEM reads
- T>32: BLOCKED — 32x32b.x8 only reads rows 0-31
- NEXT: Use 16x256b.x1 TMEM reads (reads all 128 rows per column)
- Each of the 4 softmax warps w (0..3) handles rows `[w*32, (w+1)*32) ∩ [0, T)`
- Per-lane row assignment in 16x256b: lane j gets rows j*4+0..3
- No cross-warp reduction needed (disjoint row sets)
- KEY LESSON: Q/K SMEM canonical positions MUST use local d (0..15), NOT full_d. The UMMA descriptor always reads from the start of sQ0/sK0, not from an offset.
- Multi-head launch (Milestone 5): grid=(1, n_h, batch) ✅ DONE
- Production integration (Milestone 6): Hook into production.py
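The planned T>32 row split for Milestone 4 (softmax warp w owns rows w*32 up to (w+1)*32, clipped to T) can be sketched as plain index arithmetic. This covers only the warp-level partition, not the TMEM read shape or the per-lane assignment; `RowRange` and `rows_for_warp` are hypothetical helpers.

```cpp
#include <algorithm>
#include <cassert>

// Half-open row range [begin, end); empty when begin == end.
struct RowRange {
    int begin;
    int end;
};

// Rows owned by softmax warp w (0..3) for a tile with T valid rows (1..128).
// Ranges of different warps are disjoint, so no cross-warp reduction is needed.
inline RowRange rows_for_warp(int w, int T) {
    assert(w >= 0 && w < 4);
    assert(T >= 1 && T <= 128);
    const int begin = std::min(w * 32, T);        // clip start to T
    const int end = std::min((w + 1) * 32, T);    // clip end to T
    return RowRange{begin, end};
}
```

For T=1 decode this leaves warp 0 with row 0 and warps 1-3 with empty ranges, matching the existing decode behavior.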
Files:
- dsv4/kernels/attention/fmha_6warp.cuh: 6-warp kernel (single-head)
- dsv4/kernels/attention/fmha_6warp_multihead.cuh: Multi-head grid launch kernel
- tests/unit/test_fmha_6warp_multihead.cu: Multi-head test harness
- tests/unit/test_fmha_6warp_multihead_hd{16,64,128,256}.cu: HD-specific wrappers
Layout D N=64 Bug (documented for NVIDIA):
- tcgen05.mma with make_idesc(128, 64) skips TMEM cols 32-35, 48-51
- Workaround: N=16 sub-tiles with TMEM offset n*16
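The workaround amounts to issuing one N=16 MMA per sub-tile and stepping the TMEM column offset by 16 each time. A sketch of the offset sequence, assuming the PV output width equals HD and that offsets are relative to the O accumulator's base column; `pv_subtile_offsets` is a hypothetical helper:

```cpp
#include <cassert>
#include <vector>

constexpr int kSubTileN = 16;  // N per PV MMA sub-tile, from the workaround above

// TMEM column offsets (relative to the O base column) for each N=16 sub-tile.
// Assumes the PV output width is hd and hd is a multiple of 16.
inline std::vector<int> pv_subtile_offsets(int hd) {
    assert(hd > 0 && hd % kSubTileN == 0);
    std::vector<int> offsets;
    for (int n = 0; n < hd / kSubTileN; ++n) {
        offsets.push_back(n * kSubTileN);  // sub-tile n starts at column n*16
    }
    return offsets;
}
```

For HD=64 this gives four sub-tiles at columns 0, 16, 32, 48, so the columns skipped by the N=64 path (32-35, 48-51) are each the start of a separately issued sub-tile.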
Mike, I've hit a wall with TMA. Here's the situation:
- Raw cuTensorMapEncodeTiled doesn't work for ANY multi-dimensional descriptor on the B200. 1D works; 2D/3D/5D all return INVALID_VALUE. This might be a driver issue (v580.126.20) or a parameter format I haven't figured out.
- CuTeDSL's TMA works perfectly (the existing FMHA kernel uses it), but I can't easily mix it with raw CUDA inline PTX.
- The CuTeDSL FMHA with pv_n_tile=16 crashes (illegal memory access) and needs CuTeDSL debugging.
What I recommend: focus on what works. The raw CUDA 6-warp kernel is fully working at all HD values with direct GMEM reads. The TMA optimization can wait for:
- A driver update that fixes cuTensorMapEncodeTiled for multi-dimensional descriptors, OR
- Debugging the CuTeDSL pv_n_tile=16 crash (CuTeDSL-specific, different skill set), OR
- Using CuTeDSL's TMA in a separate staging kernel
Should I:
- A) Continue debugging the CuTeDSL pv_n_tile=16 crash (might take a while, but gives us TMA + correct PV)
- B) Move on to multi-head launch (more immediate production impact, no TMA needed)
- C) Try to find the correct cuTensorMapEncodeTiled parameters by studying the CuTeDSL source code that creates the descriptors