STAGE_D.md — FMHA Kernel Development
⚠️ IKEA INSTRUCTIONS — READ EVERY TIME BEFORE CODING
The Workflow (DO NOT SKIP STEPS)
- Edit code in `~/dev/nvfp4-megamoe-kernel/dsv4/kernels/attention/fmha.py`. This is the ONLY file for the FMHA kernel.
- Commit and push:
  `cd ~/dev/nvfp4-megamoe-kernel && git add -A && git commit -m "description" && git push origin master`
- Pull on B200:
  `sshpass -p '<B200_PASSWORD>' ssh -o StrictHostKeyChecking=no root@45.76.247.107 "cd /root/dsv4-nvfp4-workspace/kernel && git pull origin master"`
- Test on B200 using the test harness scripts (see the README.md "Test Harness" section).
- Regression check: After every change, verify hd=64 cos ~0.999998 still matches. If it doesn't, the change is WRONG. Revert.
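A minimal sketch of the regression gate, assuming kernel and reference outputs are available as flat lists of floats. The helper names and the 1e-6 tolerance are hypothetical; the real harness in README.md may compute the metric differently.

```python
import math

# Hypothetical regression gate; not the harness from README.md.
HD64_BASELINE_COS = 0.999998   # proven hd=64 result from the status table
TOLERANCE = 1e-6               # assumed slack for run-to-run noise


def cosine_similarity(a, b):
    """Cosine similarity between two equally sized flat vectors."""
    if len(a) != len(b):
        raise ValueError("shape mismatch between kernel output and reference")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def passes_hd64_regression(kernel_out, reference_out):
    """True if the kernel output still matches the reference at the hd=64 baseline."""
    return cosine_similarity(kernel_out, reference_out) >= HD64_BASELINE_COS - TOLERANCE
```

If `passes_hd64_regression` returns False after a change, the rule above applies: revert.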
The Rules (BURNED INTO THIS FILE)
- NEVER edit files directly on the B200. Edit locally, commit, push, pull, test. Every time.
- NEVER delete or modify the test files in `tests/unit/` without explicit approval.
- NEVER touch drivers, OS kernels, firmware, or system packages on the B200.
- CuTeDSL variables defined in `if` blocks are NOT visible in other `if` blocks. Define all variables unconditionally before any branching.
- Always test at hd=64 FIRST. If the proven path (TMEM-P) regresses, nothing else matters.
- After every P store to TMEM, call `cute.arch.fence_view_async_tmem_store()`. Missing this produces NaN.
- `tOrP0` MUST include the `tmem_p0_offset` column offset. Use `const_expr` for the conditional.
- PRINT THE SHAPES. ALWAYS. Reasoning about layouts without evidence is how we waste days.
Current Status (2026-05-24)
✅ WORKING
| hd | n=128 cos | LSE err | Path |
|---|---|---|---|
| 64 | 0.999998 | 0.000000 | TMEM-P |
| 128 | 0.999997 | 0.000000 | TMEM-P / SMEM-P |
| 256 | 0.999998 | 0.000000 | TMEM-P |
❌ KNOWN ISSUES
- hd=512: SMEM overflow (344KB requested > 232KB limit). sQ (128KB) + sK (128KB) + sV (64KB) = 320KB alone already exceeds the limit. Needs SMEM tiling or buffer overlap.
- O rescale (kt>0) uses hand-constructed TMEM atoms and may corrupt data for n>128 (multiple KV tiles). At n=128 there is only one KV tile (kt=0), so no rescale is needed.
- Kernel always outputs un-normalized O + LSE; there is no in-kernel normalization (this avoids the TMEM round-trip error). External normalization: O_norm = O_unnorm / row_sum.
Architecture
6-Warp Layout
Warps 0-3: Softmax + Epilogue (row_max, row_sum, P store, O rescale)
Warp 4: MMA (QK, PV)
Warp 5: TMA (Q/K/V load)
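A tiny sketch of the warp-role dispatch implied by the layout above, assuming a 192-thread CTA (6 warps × 32 threads) with warps numbered contiguously from thread 0. The function name is hypothetical. The four softmax/epilogue warps give 128 threads, presumably one per row of the 128-row M tile.

```python
WARP_SIZE = 32
NUM_WARPS = 6
THREADS_PER_CTA = NUM_WARPS * WARP_SIZE  # 192 threads


def warp_role(thread_idx):
    """Map a CTA-local thread index to its warp's role (hypothetical helper)."""
    warp_idx = thread_idx // WARP_SIZE
    if 0 <= warp_idx <= 3:
        return "softmax_epilogue"  # row_max, row_sum, P store, O rescale
    if warp_idx == 4:
        return "mma"               # QK and PV MMAs
    if warp_idx == 5:
        return "tma"               # Q/K/V loads
    raise ValueError(f"thread {thread_idx} is outside the {THREADS_PER_CTA}-thread CTA")
```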
Kernel Output
The kernel outputs un-normalized O + LSE via epilogue_tma_store:
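- O_unnorm = P @ V (sum over keys of P * V), where P = exp2(S * scale * log2(e) - row_max) and row_max is the row maximum of S * scale * log2(e) (log2 domain)
- LSE = ln(row_sum) + row_max * ln(2) (natural log)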
- External normalization: O_norm = O_unnorm / row_sum
- For D5 merge: use exp(LSE) directly in the merge formula
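A pure-Python reference for one query row, assuming the kernel works in the log2 domain (P computed with exp2 and row_max measured in log2 units, which is what the ln(2) factor in the LSE formula implies). It checks that O_unnorm / row_sum and the LSE agree with an ordinary natural-exp softmax. Function names are illustrative, not the kernel's API.

```python
import math

LOG2_E = math.log2(math.e)


def kernel_style_row(scores, v_rows, scale):
    """Mimic the assumed exp2-domain kernel math for one query row.

    scores: raw QK logits S (one per key); v_rows: one V row per key.
    Returns (o_unnorm, row_sum, lse).
    """
    s2 = [s * scale * LOG2_E for s in scores]  # assumed log2-domain logits
    row_max = max(s2)                          # in log2 units
    p = [2.0 ** (x - row_max) for x in s2]     # un-normalized P, each <= 1
    row_sum = sum(p)
    head_dim = len(v_rows[0])
    o_unnorm = [sum(pj * v[d] for pj, v in zip(p, v_rows)) for d in range(head_dim)]
    lse = math.log(row_sum) + row_max * math.log(2.0)
    return o_unnorm, row_sum, lse


def reference_row(scores, v_rows, scale):
    """Plain natural-exp softmax attention for one row: (normalized O, LSE)."""
    z = [s * scale for s in scores]
    m = max(z)
    e = [math.exp(x - m) for x in z]
    denom = sum(e)
    head_dim = len(v_rows[0])
    o = [sum(ej * v[d] for ej, v in zip(e, v_rows)) / denom for d in range(head_dim)]
    return o, m + math.log(denom)
```

Because row_max * ln(2) converts the log2-domain maximum back to natural-log units, the kernel's LSE can be compared directly against a standard logsumexp reference.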
TMEM Layout
Col 0-31: S (QK acc, 128 FP32 via Ld32x32bOp Repetition(32))
Col 32-95: P (64 FP32 via register bridge, BF16 view)
Col 128+: O (PV acc, 64+ FP32)
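The 64-column span for P follows from packing a 128-key row of BF16 values two per 32-bit TMEM column. The sketch below illustrates that column arithmetic only; it assumes the first element of each pair sits in the low 16 bits and converts by truncation (real kernels typically round to nearest even). All names are hypothetical.

```python
import struct

P_COL_START = 32  # from the TMEM layout above


def f32_to_bf16_bits(x):
    """Truncate an FP32 value to BF16 by keeping its upper 16 bits."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return bits >> 16


def bf16_bits_to_f32(b):
    """Widen 16 BF16 bits back to an FP32 value."""
    (x,) = struct.unpack("<f", struct.pack("<I", b << 16))
    return x


def pack_p_row(p_row):
    """Pack FP32 probabilities into 32-bit words, two BF16 per word (low half first)."""
    if len(p_row) % 2:
        raise ValueError("row length must be even")
    return [
        (f32_to_bf16_bits(p_row[i + 1]) << 16) | f32_to_bf16_bits(p_row[i])
        for i in range(0, len(p_row), 2)
    ]


def p_column_range(n_keys):
    """Inclusive TMEM column range occupied by one packed P row."""
    return P_COL_START, P_COL_START + n_keys // 2 - 1
```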
P Staging Paths
TMEM-P (hd≤64, also works at hd=128/256):
- P stored to TMEM via register bridge (FP32 backing + BF16 view)
- PV MMA reads P from TMEM via `tOrP0`
- Works because QK C-fragment and PV A-fragment TMEM layouts agree at tested head dims
SMEM-P (hd>64):
- P written to SMEM via coordinate-indexed store
- Uses `tTMEM_LOADcS` identity tensor to get (m, k) coordinates
- Maps to sP's subtile layout: `sP[(m_coord, k_sub), 0, (k_g1, k_g2)]`
- PV MMA reads P from SMEM via `tCrP = pv_mma.make_fragment_A(sP)`
- SMEM-P uses `OperandSource.SMEM` for PV MMA
Key Configuration
head_dim: constructor arg (64, 128, 256, 512)
pv_n_tile: min(head_dim, 256) # tcgen05 MMA max N=256
n_pv_tiles: head_dim // pv_n_tile
kv_stage: 1 if head_dim > 128 else 2 # Reduce SMEM at large hd
use_smem_p: head_dim > 64 # SMEM-P for hd>64
qk_mma_tiler: (128, 128, head_dim) # K-dim = head_dim (NOT hardcoded!)
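The derivation rules above can be mirrored in plain Python to see the configuration at each supported head_dim. This is an illustrative sketch, not the kernel's constructor; the dict keys simply reuse the names listed above.

```python
def fmha_config(head_dim):
    """Derive the tile configuration from head_dim using the rules listed above."""
    if head_dim not in (64, 128, 256, 512):
        raise ValueError(f"unsupported head_dim {head_dim}")
    pv_n_tile = min(head_dim, 256)  # tcgen05 MMA max N = 256
    return {
        "pv_n_tile": pv_n_tile,
        "n_pv_tiles": head_dim // pv_n_tile,
        "kv_stage": 1 if head_dim > 128 else 2,  # fewer stages at large hd to save SMEM
        "use_smem_p": head_dim > 64,
        "qk_mma_tiler": (128, 128, head_dim),    # K-dim must equal head_dim
    }


for hd in (64, 128, 256, 512):
    print(hd, fmha_config(hd))
```

Only hd=512 needs two PV N-tiles; every other head_dim fits in a single N=256 MMA.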
Critical Bug Fix: qk_mma_tiler K-dim (2026-05-24)
ROOT CAUSE of hd>64 failure: qk_mma_tiler K-dim was hardcoded to qk_ik * 4 = 64 instead of head_dim.
This caused the QK GEMM to compute only 64 of the head_dim dimensions at hd>64 (half at hd=128, a quarter at hd=256, an eighth at hd=512). The truncated QK dot products produced wrong attention scores.
Fix: `self.qk_mma_tiler = (128, 128, self.head_dim)` (a one-line change).
Impact: hd=128 went from cos 0.78 to 0.999997. hd=256 went from broken to 0.999998.
LESSON: The MMA tiler's K dimension must match the actual GEMM K dimension (head_dim), not the MMA instruction's K sub-tile size.
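A small numerical illustration of the bug, assuming i.i.d. Gaussian Q and K (synthetic data, not kernel output): scores computed over only the first 64 of 128 dimensions correlate only partially with the full scores. On such data the expected cosine is roughly sqrt(64/128) ≈ 0.7; this is not meant to reproduce the measured 0.78 output cosine.

```python
import math
import random


def qk_scores(q, keys, k_extent):
    """Dot products of q with each key, using only the first k_extent dims."""
    return [sum(q[d] * k[d] for d in range(k_extent)) for k in keys]


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


rng = random.Random(0)
head_dim, n_keys = 128, 128
q = [rng.gauss(0.0, 1.0) for _ in range(head_dim)]
keys = [[rng.gauss(0.0, 1.0) for _ in range(head_dim)] for _ in range(n_keys)]

full = qk_scores(q, keys, head_dim)  # tiler K = head_dim (the fix)
buggy = qk_scores(q, keys, 64)       # tiler K = qk_ik * 4 = 64 (the bug)
print(f"cos(full, buggy) = {cosine(full, buggy):.3f}")
```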
SMEM Budget at Various hd
| hd | sQ | sK (kv_stage=1) | sV (kv_stage=1) | sP (SMEM-P) | sC | Total | Limit | Status |
|---|---|---|---|---|---|---|---|---|
| 64 | 32KB | 32KB | 32KB | 32KB | 32KB | 160KB | 232KB | ✅ |
| 128 | 32KB | 32KB | 32KB | 32KB | 32KB | 160KB | 232KB | ✅ |
| 256 | 64KB | 64KB | 64KB | 0* | 32KB | 224KB | 232KB | ✅ |
| 512 | 128KB | 128KB | 64KB | 0* | 32KB | 352KB | 232KB | ❌ |
*TMEM-P path: sP allocation skipped (const_expr conditional)
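The budget arithmetic can be checked mechanically. The per-buffer sizes below are copied from the table (not derived); `tile_kb` is a hypothetical helper assuming 2-byte (BF16) elements, which reproduces the hd=256 and hd=512 rows (for example, sQ at hd=512 is 128 × 512 × 2 B = 128KB).

```python
SMEM_LIMIT_KB = 232

# Per-buffer SMEM in KB, copied from the table above (sP = 0 on the TMEM-P path).
BUDGET_KB = {
    64:  {"sQ": 32,  "sK": 32,  "sV": 32, "sP": 32, "sC": 32},
    128: {"sQ": 32,  "sK": 32,  "sV": 32, "sP": 32, "sC": 32},
    256: {"sQ": 64,  "sK": 64,  "sV": 64, "sP": 0,  "sC": 32},
    512: {"sQ": 128, "sK": 128, "sV": 64, "sP": 0,  "sC": 32},
}


def tile_kb(rows, cols, bytes_per_elem=2):
    """SMEM footprint of a rows x cols tile in KB (BF16 assumed by default)."""
    return rows * cols * bytes_per_elem // 1024


def check_budget(head_dim):
    """Return (total_kb, fits) for a row of the table."""
    total = sum(BUDGET_KB[head_dim].values())
    return total, total <= SMEM_LIMIT_KB
```

At hd=512 the total is 352KB, which overshoots the limit by 120KB.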
D1.5: Correction Epilogue (TMEM Round-Trip Error)
Issue: Hand-constructed Ld32x32bOp/St32x32bOp atoms don't preserve the C-fragment layout during TMEM round-trips (load→modify→store). Causes ~3% error per round-trip.
Current workaround: Kernel outputs un-normalized O + LSE. No in-kernel normalization needed. External normalization is exact.
Proper fix (future): Use CUTLASS epilogue_tmem_copy_and_partition + epilogue_smem_copy_and_partition pattern with paired atoms. One-way trip: TMEM → registers (normalize) → SMEM → GMEM.
Priority: MEDIUM. Not a correctness blocker (external normalization is exact). Would enable in-kernel normalization for D5c/D5d.
Build Order (Remaining)
D1.4 — hd=512 SMEM Budget ⚡ CURRENT
hd=512 needs sQ(128KB) + sK(128KB) + sV(64KB) = 320KB. Must reduce to fit 232KB.
Options:
- Tile Q along head_dim: Process Q in chunks of 256. Two Q sub-tiles per kernel.
- SMEM buffer overlap: sQ and sK/sV used at different times. After Q is consumed by MMA, reuse sQ's SMEM for K/V.
- Split the GEMM K dimension: Process K in sub-tiles ([0, 256) then [256, 512)). Each sub-tile fits in SMEM.
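Tiling Q along head_dim and splitting the GEMM K dimension both rely on the same identity: a QK dot product over head_dim equals the sum of partial dot products over head_dim chunks, so only one chunk of Q and K needs to be resident at a time. The sketch below shows this with integer-valued data so the sums are exact; the function name is hypothetical, and in FP32 the accumulation order can change results slightly.

```python
import random


def dot_split_k(q_row, k_row, chunk):
    """Accumulate a QK dot product over head_dim chunks of size `chunk`."""
    if len(q_row) != len(k_row) or len(q_row) % chunk:
        raise ValueError("head_dim must match and be a multiple of chunk")
    acc = 0.0
    for start in range(0, len(q_row), chunk):
        # Only q[start:start+chunk] and k[start:start+chunk] are needed here.
        q_chunk = q_row[start:start + chunk]
        k_chunk = k_row[start:start + chunk]
        acc += sum(a * b for a, b in zip(q_chunk, k_chunk))
    return acc


rng = random.Random(1)
q = [float(rng.randint(-8, 8)) for _ in range(512)]
k = [float(rng.randint(-8, 8)) for _ in range(512)]
full = sum(a * b for a, b in zip(q, k))
```

Note that softmax still needs the complete S row, so all K chunks of a KV tile must be accumulated into the S accumulator before P is computed.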
D2 — Multi-Query Grid with Head Packing
- Grid changes from `(1, 1, 1)` to `(num_q_blocks, 1, batch)`
- DSV4 is MQA: all 128 query heads share the same K/V
- Head axis folded into the M dimension: `M_tile = 128` covers `M = T * n_h` rows
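A sketch of the head-packing index math, assuming token-major packing (row m = t * n_h + h); the actual ordering in the kernel may differ, and all names here are hypothetical. Because the model is MQA, every packed row attends the same K/V, which is what makes folding heads into M possible.

```python
import math

M_TILE = 128  # rows per CTA tile


def num_q_blocks(seq_len, n_heads):
    """Number of 128-row M tiles needed for T * n_h packed rows."""
    return math.ceil(seq_len * n_heads / M_TILE)


def grid(seq_len, n_heads, batch):
    """Launch grid (num_q_blocks, 1, batch) as described above."""
    return (num_q_blocks(seq_len, n_heads), 1, batch)


def row_to_token_head(block_idx, row_in_tile, n_heads):
    """Map a tile row back to (token, head), assuming m = t * n_h + h."""
    m = block_idx * M_TILE + row_in_tile
    return divmod(m, n_heads)
```

With n_h = 128 each tile holds exactly one token's heads; with fewer heads a tile spans several tokens, so per-row token positions matter for the causal mask.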
D3 — SWA Sequence Length Mask
- Add `swa_lens: [batch] int32` kernel input
- Mask SWA-branch logits to `-inf` where `swa_idx >= swa_lens[b]`
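A reference for the length mask, assuming logits are laid out as `[batch][q_rows][swa_keys]` and that `swa_idx` is the key's index within the SWA window; helper names are hypothetical. Masked entries become `-inf`, so `exp` sends them to exactly 0 in the softmax.

```python
import math

NEG_INF = float("-inf")


def mask_swa_lengths(logits, swa_lens):
    """Set logits[b][i][j] to -inf where the SWA key index j >= swa_lens[b]."""
    return [
        [[s if j < swa_lens[b] else NEG_INF for j, s in enumerate(row)] for row in logits[b]]
        for b in range(len(logits))
    ]


def softmax(row):
    """Numerically stable softmax; exp(-inf) evaluates to 0.0."""
    m = max(row)
    e = [math.exp(x - m) for x in row]
    total = sum(e)
    return [x / total for x in e]
```

One caveat: if `swa_lens[b] == 0`, the whole row is `-inf`, row_max is `-inf`, and `-inf - (-inf)` is NaN, so such rows would need an explicit guard.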
D4 — Causal Mask on SWA Branch
- Add `is_causal: bool` constructor flag
- Apply `swa_idx > q_pos` masking to `-inf` in the SWA pass
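A reference for the causal rule applied to a block of query rows, assuming `swa_idx` and `q_pos` are expressed in the same position coordinates (an assumption; the kernel's indexing may differ). Each row carries its own `q_pos`, which matters once several tokens share a tile.

```python
NEG_INF = float("-inf")


def causal_mask_block(logits, swa_idx, q_positions):
    """Mask logits[i][j] to -inf where swa_idx[j] > q_positions[i].

    logits: one list of SWA logits per query row.
    swa_idx: key position for each logit column (shared by all rows).
    q_positions: query position for each row.
    """
    return [
        [NEG_INF if k_pos > q_pos else s for s, k_pos in zip(row, swa_idx)]
        for row, q_pos in zip(logits, q_positions)
    ]
```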
D5 — SWA + Sink Merge
- D5a ✅: Kernel outputs un-normalized O + LSE
- D5b ✅: Python merge works (cos 0.961 at hd=64)
- D5c: Fuse two passes into one kernel launch
- D5d: Fuse sink merge into kernel epilogue
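A sketch of a generic two-branch LSE merge, assuming the two passes cover disjoint key sets and each output has already been normalized (O_unnorm / row_sum). It weights each branch by exp(LSE_i - LSE_merged) instead of raw exp(LSE): mathematically equivalent, but it cannot overflow. This is not the D5b Python merge itself; names are hypothetical.

```python
import math


def partial_attention(logits, v_rows):
    """Normalized output and natural-log LSE over one key subset (logits pre-scaled)."""
    m = max(logits)
    e = [math.exp(s - m) for s in logits]
    denom = sum(e)
    head_dim = len(v_rows[0])
    o = [sum(w * v[d] for w, v in zip(e, v_rows)) / denom for d in range(head_dim)]
    return o, m + math.log(denom)


def merge_two(o_a, lse_a, o_b, lse_b):
    """Merge two normalized partial outputs using their LSEs."""
    # Stable logaddexp: log(e^a + e^b) = max + log1p(e^-|a-b|)
    lse = max(lse_a, lse_b) + math.log1p(math.exp(-abs(lse_a - lse_b)))
    w_a = math.exp(lse_a - lse)
    w_b = math.exp(lse_b - lse)
    return [w_a * x + w_b * y for x, y in zip(o_a, o_b)], lse
```

Splitting one key set in two and merging should reproduce full attention exactly (up to rounding), which makes this a useful oracle for D5c/D5d.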