STAGE_D.md — FMHA Kernel Development
⚠️ IKEA INSTRUCTIONS — READ EVERY TIME BEFORE CODING
The Workflow (DO NOT SKIP STEPS)
- Edit code in `~/dev/nvfp4-megamoe-kernel/dsv4/kernels/attention/fmha.py`. This is the ONLY file for the FMHA kernel.
- Commit and push: `cd ~/dev/nvfp4-megamoe-kernel && git add -A && git commit -m "description" && git push origin master`
- Pull on B200: `sshpass -p '<B200_PASSWORD>' ssh -o StrictHostKeyChecking=no root@45.76.247.107 "cd /root/dsv4-nvfp4-workspace/kernel && git pull origin master"`
- Test on B200 using the test harness scripts (see the "Test Harness" section of README.md).
- Regression check: after every change, verify that hd=64 still reaches cos ≈ 0.999998. If it doesn't, the change is WRONG. Revert.
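The hd=64 regression gate can be expressed as a cosine-similarity check between the kernel output and a reference output. The following is a minimal plain-Python sketch. The function names and the 0.99999 threshold are assumptions chosen to sit just below the recorded 0.999998 baseline. They are not taken from the test harness.

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity between two equal-length flat sequences of floats."""
    if len(a) != len(b):
        raise ValueError("length mismatch")
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb)

# Hypothetical gate value: slightly below the recorded hd=64 baseline (0.999998).
HD64_COS_THRESHOLD = 0.99999

def hd64_regression_ok(kernel_out, reference_out, threshold=HD64_COS_THRESHOLD):
    """Return True if the kernel output still matches the reference closely enough."""
    return cosine_similarity(kernel_out, reference_out) >= threshold
```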
The Rules (BURNED INTO THIS FILE)
- NEVER edit files directly on the B200. Edit locally, commit, push, pull, test. Every time.
- NEVER delete or modify the test files in `tests/unit/` without explicit approval.
- NEVER touch drivers, kernels, firmware, or system packages on the B200.
- CuTeDSL variables defined in `if` blocks are NOT visible in other `if` blocks. Define all variables unconditionally before any branching.
- Always test at hd=64 FIRST. If the proven path (TMEM-P) regresses, nothing else matters.
- After every P store to TMEM, call `cute.arch.fence_view_async_tmem_store()`. Missing this produces NaN.
- `tOrP0` MUST include the `tmem_p0_offset` column offset. Use `const_expr` for the conditional.
- PRINT THE SHAPES. ALWAYS. Reasoning about layouts without evidence is how we waste days.
Current Status (2026-05-24, 21:30 UTC)
✅ WORKING
| hd | n=128 cos | LSE err | Path | SMEM |
|---|---|---|---|---|
| 64 | 0.999998 | 0.000000 | TMEM-P | 128KB |
| 128 | 0.999997 | 0.000000 | TMEM-P / SMEM-P | 128KB |
| 256 | 0.999998 | 0.000000 | TMEM-P | 224KB |
❌ KNOWN ISSUES
- hd=512: MLIR compilation hangs. The SMEM budget is fixed (192KB ✅) and the kernel structure is correct (tracer finishes in 0.8s), but the MLIR→PTX backend optimizer cannot process the IR in reasonable time (>3 hours). Both unrolled `range()` and runtime `cutlass.range(unroll=1)` loops trigger this. This is a CuTeDSL/MLIR toolchain limitation.
- External k_sub merge doesn't work. k_sub segments are additive in logit space (S = S_0 + S_1), not in attention-weight space, so the D5 merge formula does not apply. In-kernel k_sub accumulation is the only correct approach.
- O rescale (kt>0): uses hand-constructed TMEM atoms and may corrupt data for n>128 (multi-KV-tile). At n=128 (1 KV tile, kt=0), no rescale is needed. Guarded with `const_expr(n_kv_tiles > 1)`.
- The kernel always outputs un-normalized O + LSE. There is no in-kernel normalization, which eliminates the TMEM round-trip error. External normalization: `O_norm = O_unnorm / row_sum`.
Architecture
6-Warp Layout
Warps 0-3: Softmax + Epilogue (row_max, row_sum, P store, O rescale)
Warp 4: MMA (QK, PV)
Warp 5: TMA (Q/K/V load)
Kernel Output
The kernel outputs un-normalized O + LSE via epilogue_tma_store:
- O_unnorm = sum(P * V) where P = exp(S * scale - row_max)
- LSE = ln(row_sum) + row_max * ln(2)
- External normalization: O_norm = O_unnorm / row_sum
- For D5 merge: use exp(LSE) directly in the merge formula
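The relationship between these outputs can be checked on a single query row in plain Python. The ln(2) factor in the LSE formula suggests that row_max is tracked in base-2 units, that is, P = 2^(S·scale·log2(e) − row_max). The sketch below assumes this, but the assumption is not confirmed against the kernel source. `partial_attention` and `lse_merge_rows` are hypothetical names. The merge rescales exp(LSE) by the largest LSE for numerical stability, which leaves the result unchanged.

```python
import math

def partial_attention(scores, values, scale):
    """One query row: un-normalized O, row_sum and LSE, mimicking the kernel outputs.

    ASSUMPTION: row_max is kept in base-2 units, i.e. P = 2**(s*scale*log2(e) - row_max).
    scores: logits for this row; values: list of V rows (lists of floats).
    """
    log2e = 1.0 / math.log(2.0)
    t = [s * scale * log2e for s in scores]            # logits in base-2 units
    row_max = max(t)
    p = [2.0 ** (x - row_max) for x in t]              # un-normalized weights
    row_sum = sum(p)
    dim = len(values[0])
    o_unnorm = [sum(p[j] * values[j][d] for j in range(len(p))) for d in range(dim)]
    lse = math.log(row_sum) + row_max * math.log(2.0)  # natural-log LSE
    return o_unnorm, row_sum, lse

def lse_merge_rows(parts):
    """Merge partial results over DISJOINT token sets, weighting each by exp(LSE)."""
    m = max(lse for _, _, lse in parts)
    weights = [math.exp(lse - m) for _, _, lse in parts]  # exp(LSE), rescaled by exp(-m)
    total = sum(weights)
    dim = len(parts[0][0])
    out = [0.0] * dim
    for (o_unnorm, row_sum, _), w in zip(parts, weights):
        for d in range(dim):
            out[d] += (w / total) * (o_unnorm[d] / row_sum)  # O_norm = O_unnorm / row_sum
    return out
```

Splitting a row's tokens into two disjoint halves, computing each half separately, and merging the two results should reproduce full attention over the whole row.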
TMEM Layout
Col 0-31: S (QK acc, 128 FP32 via Ld32x32bOp Repetition(32))
Col 32-95: P (64 FP32 via register bridge, BF16 view)
Col 128+: O (PV acc, 64+ FP32)
P Staging Paths
TMEM-P (hd≤64, also works at hd=128/256):
- P stored to TMEM via register bridge (FP32 backing + BF16 view)
- PV MMA reads P from TMEM via `tOrP0`
- Works because QK C-fragment and PV A-fragment TMEM layouts agree at tested head dims
SMEM-P (hd>64):
- P written to SMEM via coordinate-indexed store
- Uses the `tTMEM_LOADcS` identity tensor to get (m, k) coordinates
- Maps to sP's subtile layout: `sP[(m_coord, k_sub), 0, (k_g1, k_g2)]`
- PV MMA reads P from SMEM via `tCrP = pv_mma.make_fragment_A(sP)`
- SMEM-P uses `OperandSource.SMEM` for PV MMA
Key Configuration
head_dim: constructor arg (64, 128, 256, 512)
pv_n_tile: 128 if head_dim > 256 else min(head_dim, 256)  # tcgen05 MMA max N=256; 128 at hd>256 to fit SMEM
n_pv_tiles: head_dim // pv_n_tile
kv_stage: 1 if head_dim > 128 else 2 # Reduce SMEM at large hd
use_smem_p: head_dim > 64 # SMEM-P for hd>64
qk_mma_tiler: (128, 128, head_dim) # K-dim = head_dim (NOT hardcoded!)
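The configuration rules can be collected into one function so that each head_dim is derived in a single place. The following is a sketch with a hypothetical function name, not the constructor in `fmha.py`. The `pv_n_tile=128` rule for head_dim > 256 follows the SMEM budget notes and Lesson 3 later in this file.

```python
def fmha_config(head_dim):
    """Derive the FMHA tile configuration from head_dim (sketch of the rules above)."""
    if head_dim not in (64, 128, 256, 512):
        raise ValueError(f"unsupported head_dim {head_dim}")
    # tcgen05 MMA max N = 256; hd > 256 drops to 128 to fit the SMEM budget.
    pv_n_tile = 128 if head_dim > 256 else min(head_dim, 256)
    return {
        "pv_n_tile": pv_n_tile,
        "n_pv_tiles": head_dim // pv_n_tile,
        "kv_stage": 1 if head_dim > 128 else 2,   # fewer stages at large hd
        "use_smem_p": head_dim > 64,
        "qk_mma_tiler": (128, 128, head_dim),     # K extent = full head_dim
    }
```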
Critical Bug Fix: qk_mma_tiler K-dim (2026-05-24)
ROOT CAUSE of hd>64 failure: qk_mma_tiler K-dim was hardcoded to qk_ik * 4 = 64 instead of head_dim.
This caused the QK GEMM to reduce over only the first 64 of the 128 (or 256, 512) head dimensions at hd>64. Each QK dot product therefore covered only a fraction of its terms (half at hd=128, a quarter at hd=256), which produced wrong attention scores.
Fix: self.qk_mma_tiler = (128, 128, self.head_dim) — one line change.
Impact: hd=128 went from cos 0.78 to 0.999997. hd=256 went from broken to 0.999998.
LESSON: The MMA tiler's K dimension must match the actual GEMM K dimension (head_dim), not the MMA instruction's K sub-tile size.
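The failure mode can be reproduced in plain Python. If the reduction covers only the first 64 terms, the scores are wrong whenever the remaining dimensions contribute anything. `qk_score` and `check_qk_tiler` are illustrative names, not kernel identifiers. A guard like `check_qk_tiler` is one possible way to catch the hardcoded value at construction time.

```python
def check_qk_tiler(qk_mma_tiler, head_dim):
    """Reject a QK tiler whose K extent does not cover the full head_dim."""
    if qk_mma_tiler[2] != head_dim:
        raise ValueError(f"qk_mma_tiler K={qk_mma_tiler[2]} != head_dim={head_dim}")

def qk_score(q, k, k_extent):
    """Dot product of one Q row and one K row over the first k_extent dims only."""
    return sum(q[d] * k[d] for d in range(k_extent))

head_dim = 128
q = [1.0] * head_dim
k = [1.0] * head_dim
full = qk_score(q, k, head_dim)   # K extent = head_dim (fixed): 128.0
truncated = qk_score(q, k, 64)    # K hardcoded to qk_ik * 4 = 64 (bug): 64.0
check_qk_tiler((128, 128, head_dim), head_dim)  # passes
```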
Lessons Learned (2026-05-24)
1. CuTeDSL MLIR Backend Cannot Handle Complex Pipeline Loops
The MLIR→PTX backend optimizer appears to scale extremely badly (compile times of hours) on kernels with TMA pipeline acquire/release inside loops. Both unrolled (Python `range`) and runtime (`cutlass.range(unroll=1)`) loops trigger this. The Python tracer is fast (0.8s) because it only generates IR; the MLIR optimizer then chews on that IR for hours. Workaround: keep pipeline loops as simple as possible. Consider raw CUDA C++ for complex kernels.
2. External k_sub Merge is Mathematically Impossible
You CANNOT merge the outputs of two attention calls that compute softmax(Q_k0 @ K_k0^T)@V and softmax(Q_k1 @ K_k1^T)@V into softmax(Q @ K^T)@V. The k_sub segments are additive in LOGIT space (S = S_0 + S_1), but softmax is nonlinear. The D5 merge formula works because sparse and SWA attend over DIFFERENT token sets (additive in weight space). k_sub attends over the SAME tokens with PARTIAL dot products. These are fundamentally different operations. The only correct approach is in-kernel accumulation (S_0 + S_1 before softmax).
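A two-token numeric example makes the argument concrete. The head_dim halves produce partial logits S_0 and S_1 over the SAME tokens. The true result is softmax(S_0 + S_1)@V. An exp(LSE)-weighted merge of the two partial softmaxes instead adds the weights exp(S_0) + exp(S_1), which gives a different answer. This is a plain-Python sketch with hypothetical helper names.

```python
import math

def softmax_attend(logits, values):
    """softmax(logits) @ values for one query row; returns (output, lse)."""
    m = max(logits)
    w = [math.exp(s - m) for s in logits]
    z = sum(w)
    dim = len(values[0])
    out = [sum(w[j] * values[j][d] for j in range(len(w))) / z for d in range(dim)]
    return out, m + math.log(z)

def lse_merge(parts):
    """D5-style merge: weight each normalized partial output by exp(LSE)."""
    m = max(lse for _, lse in parts)
    ws = [math.exp(lse - m) for _, lse in parts]
    dim = len(parts[0][0])
    return [sum(w * o[d] for (o, _), w in zip(parts, ws)) / sum(ws) for d in range(dim)]

V = [[1.0], [0.0]]   # output = attention weight on token 0
S0 = [2.0, 0.0]      # partial logits from head_dim half 0 (same 2 tokens)
S1 = [0.0, 1.0]      # partial logits from head_dim half 1

true_out, _ = softmax_attend([a + b for a, b in zip(S0, S1)], V)   # ~0.731
merged = lse_merge([softmax_attend(S0, V), softmax_attend(S1, V)])  # ~0.693
```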
3. pv_n_tile Reduction is the Easiest SMEM Knob
At hd>256, reducing pv_n_tile from 256 to 128 halves both sV and sC. The cost at hd=512 is 4 PV GEMM passes instead of 2, but PV is typically not the bottleneck. This is simpler than SMEM overlap (which requires CuTeDSL SmemAllocator changes) or Q tiling (which adds pipeline complexity).
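The following arithmetic sketch checks the SMEM savings. It assumes BF16 tiles, a 128-row sV tile (one KV tile), and a 128-row sC tile (one M tile). These assumptions reproduce the 64KB→32KB figures quoted for hd=512, but the helper names are hypothetical.

```python
BF16_BYTES = 2

def tile_kib(rows, cols, elem_bytes=BF16_BYTES):
    """Size of a rows x cols tile in KiB."""
    return rows * cols * elem_bytes // 1024

def pv_smem(head_dim, pv_n_tile, kv_tile=128, m_tile=128):
    """SMEM for one sV tile (kv_tile x pv_n_tile) and one sC tile (m_tile x pv_n_tile)."""
    return {
        "sV_KiB": tile_kib(kv_tile, pv_n_tile),
        "sC_KiB": tile_kib(m_tile, pv_n_tile),
        "pv_passes": head_dim // pv_n_tile,   # PV GEMM passes over head_dim
    }
```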
4. Guard Dead Code with const_expr
CuTeDSL compiles BOTH branches of Python if statements, generating IR for code that will never execute at a given head_dim. Use const_expr(condition) to eliminate dead code at compile time. This is critical for:
- O rescale code (only needed when n_kv_tiles > 1)
- LSE computation (only needed when normalize=False)
- SMEM-P path (only needed when use_smem_p=True)
5. Don't Mix Python Loops and CuTeDSL Pipeline Operations
Python for loops unroll at trace time, creating N copies of the loop body in the IR. For pipeline acquire/release + TMA copy + GEMM, each copy is substantial. cutlass.range(unroll=1) creates a runtime loop with one copy of the body. For pipeline operations, prefer cutlass.range(unroll=1) to reduce IR size, even though the MLIR optimizer may still struggle with it.
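The following toy model (not CuTeDSL) shows why trace-time unrolling inflates the IR. A tracer records one op per traced call. A Python `for` loop runs the body N times during tracing, while a runtime loop records the body once plus loop markers. `ToyTracer` and the op names are hypothetical.

```python
class ToyTracer:
    """Hypothetical stand-in for a DSL tracer: records one IR op per traced call."""

    def __init__(self):
        self.ir = []

    def emit(self, op):
        self.ir.append(op)

def pipeline_body(tr):
    # acquire -> TMA copy -> GEMM -> release: one pipeline iteration
    for op in ("acquire", "tma_copy", "gemm", "release"):
        tr.emit(op)

def trace_python_loop(n):
    tr = ToyTracer()
    for _ in range(n):                        # Python loop: body traced n times
        pipeline_body(tr)
    return tr.ir

def trace_runtime_loop(n):
    tr = ToyTracer()
    tr.emit(f"loop_begin(trip_count={n})")    # runtime loop: body traced once
    pipeline_body(tr)
    tr.emit("loop_end")
    return tr.ir
```

With 8 iterations, the unrolled trace holds 32 ops and the runtime-loop trace holds 6. The real IR per iteration is far larger than four ops.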
6. The k_tile Parameter is the Key to hd=512
At hd=512, the kernel splits Q and K into sub-tiles of size k_tile=256 along the head_dim. Each sub-tile is loaded via TMA, processed by MMA, and accumulated. n_k_sub_tiles = head_dim // k_tile = 2. The k_tile parameter controls the sub-tile size and the number of iterations. k_tile must be ≤ 256 (MMA instruction K-dim limit) and must evenly divide head_dim.
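The k_tile constraints and the in-kernel accumulation can be sketched in plain Python. The partial dot products of each sub-tile are summed into S before any softmax is applied. The function names are hypothetical, and the 256 limit is the MMA K-dim limit stated above.

```python
def n_k_sub_tiles(head_dim, k_tile=256, mma_k_limit=256):
    """Number of head_dim sub-tiles; enforces the k_tile constraints."""
    if k_tile > mma_k_limit:
        raise ValueError("k_tile exceeds the MMA K-dim limit")
    if head_dim % k_tile != 0:
        raise ValueError("k_tile must evenly divide head_dim")
    return head_dim // k_tile

def qk_score_ksub(q, k, k_tile=256):
    """S = sum over sub-tiles of partial dot products, accumulated BEFORE softmax."""
    s = 0.0
    for kt in range(n_k_sub_tiles(len(q), k_tile)):
        lo, hi = kt * k_tile, (kt + 1) * k_tile
        s += sum(q[d] * k[d] for d in range(lo, hi))   # partial score S_kt
    return s
```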
SMEM Budget at Various hd
| hd | sQ | sK (kv_stage=1) | sV (pv_n_tile) | sP (SMEM-P) | sC | Total | Limit | Status |
|---|---|---|---|---|---|---|---|---|
| 64 | 32KB | 32KB | 32KB (256) | — | 32KB | 128KB | 232KB | ✅ |
| 128 | 32KB | 32KB | 32KB (256) | — | 32KB | 128KB | 232KB | ✅ |
| 256 | 64KB | 64KB | 64KB (256) | 0* | 32KB | 224KB | 232KB | ✅ |
| 512 | 64KB | 64KB | 32KB (128) | 0* | 32KB | 192KB | 232KB | ⚠️ Fits but MLIR hangs |
*TMEM-P path: sP allocation skipped (`const_expr` conditional).
pv_n_tile is shown in parentheses; hd>256 uses pv_n_tile=128 (4 PV GEMM passes) to fit SMEM.
D1.5: Correction Epilogue (TMEM Round-Trip Error)
Issue: Hand-constructed Ld32x32bOp/St32x32bOp atoms don't preserve the C-fragment layout during TMEM round-trips (load→modify→store). Causes ~3% error per round-trip.
Current workaround: Kernel outputs un-normalized O + LSE. No in-kernel normalization needed. External normalization is exact.
Proper fix (future): Use CUTLASS epilogue_tmem_copy_and_partition + epilogue_smem_copy_and_partition pattern with paired atoms. One-way trip: TMEM → registers (normalize) → SMEM → GMEM.
Priority: MEDIUM. Not a correctness blocker (external normalization is exact). Would enable in-kernel normalization for D5c/D5d.
Build Order (Remaining)
D1.4 — hd=512 ⚡ CURRENT (BLOCKED)
Problem: hd=512 exceeds the MMA instruction's max K-dim (256). Must split Q and K into 2 sub-tiles along head_dim (k_tile=256, n_k_sub_tiles=2). The QK dot product is S = Q_k0 @ K_k0^T + Q_k1 @ K_k1^T (additive in logit space).
SMEM budget: SOLVED. pv_n_tile=128 for hd>256 reduces sV from 64KB→32KB, sC from 64KB→32KB. Total 192KB ✅.
Compilation: BLOCKED. The CuTeDSL MLIR→PTX backend optimizer cannot compile the hd=512 kernel in reasonable time. Both Python range() (unrolled IR) and cutlass.range(unroll=1) (runtime loop) produce IR that the optimizer chews on for 3+ hours without finishing. The Python tracer completes in 0.8s — the kernel is structurally correct. This is a toolchain limitation.
External merge: IMPOSSIBLE. The D5 online softmax merge formula assumes separate attention distributions over different token sets (additive in weight space). k_sub segments are additive in LOGIT space (S = S_0 + S_1), not weight space. You cannot recover softmax(S_0 + S_1)@V from softmax(S_0)@V and softmax(S_1)@V. In-kernel accumulation before softmax is the only correct approach.
Bug fixes applied along the way:
- LSE type mismatch (BF16 vs FP32 when normalize=True) → guarded with `const_expr(not self.normalize)`
- O rescale IR explosion at n=128 → guarded with `const_expr(n_kv_tiles > 1)`
- k_sub tracer IR explosion → replaced hardcoded `if k_sub==0/1` with a Python `range()` loop
- External merge test (cos 0.617) → confirmed mathematically impossible, approach deleted
Possible paths forward (priority order):
- Pre-compile the hd=512 kernel offline. Accept a 1-2 hour compilation during the build and cache the cubin. This only works if the MLIR optimizer eventually finishes (it may be slow rather than stuck). Runs have already exceeded 3 hours without finishing, which is excessive even for pre-compilation.
- Add a no-softmax mode to the kernel. Output raw S (QK scores) without softmax, call the kernel twice (k_sub=0 and k_sub=1), accumulate S_0 + S_1 in Python, and apply softmax once. This requires modifying the softmax warp to optionally skip softmax and write S to GMEM instead of P to TMEM/SMEM.
- Write hd=512 kernel in CUTLASS C++. Bypass CuTeDSL's MLIR backend entirely. Use raw CUTLASS C++ with tcgen05 MMA intrinsics. More work but compilation is fast (seconds).
- Report CuTeDSL MLIR optimizer bug. The optimizer should handle this IR in reasonable time. File an issue with NVIDIA.
D2 — Multi-Query Grid with Head Packing
- Grid changes from `(1, 1, 1)` to `(num_q_blocks, 1, batch)`
- DSV4 is MQA: all 128 query heads share the same K/V
- Head axis folded into the M dimension: `M_tile = 128` covers `M = T * n_h` rows
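The following sketch shows the grid arithmetic and one possible row packing for folding heads into M. The token-major layout (row = t * n_h + h, with a token's heads contiguous) is an assumption, because the plan does not fix the ordering. All names are hypothetical.

```python
M_TILE = 128

def num_q_blocks(T, n_h, m_tile=M_TILE):
    """Grid x-extent when all query heads are folded into the M dimension."""
    return -(-(T * n_h) // m_tile)    # ceil(T * n_h / m_tile)

def packed_row(t, h, n_h):
    """ASSUMED packing: token-major, heads contiguous (row = t * n_h + h)."""
    return t * n_h + h

def unpack_row(row, n_h):
    """Inverse of packed_row: recover (token, head) from a packed M row."""
    return divmod(row, n_h)
```

Under this packing with 128 heads, each 128-row M tile holds exactly one token's heads, and every row in the tile reads the same K/V.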
D3 — SWA Sequence Length Mask
- Add a `swa_lens: [batch] int32` kernel input
- Mask SWA-branch logits to `-inf` where `swa_idx >= swa_lens[b]`
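The following reference implementation of the length mask for one query row could serve as the Python-side expectation for a test. The function name is hypothetical. A fully masked row (swa_lens[b] == 0) would make row_max equal to -inf, so the kernel will likely need a guard for that case.

```python
NEG_INF = float("-inf")

def mask_swa_lengths(logits, swa_lens, b):
    """Set SWA-branch logits to -inf where swa_idx >= swa_lens[b] (one query row)."""
    return [s if swa_idx < swa_lens[b] else NEG_INF
            for swa_idx, s in enumerate(logits)]
```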
D4 — Causal Mask on SWA Branch
- Add an `is_causal: bool` constructor flag
- Apply `swa_idx > q_pos` masking to `-inf` in the SWA pass
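The following reference sketch implements the causal mask for one query row. It assumes that `swa_idx` and `q_pos` are in the same coordinate frame, as the rule above implies. Both masks only write -inf, so applying this mask before or after the D3 length mask gives the same result. The function name is hypothetical.

```python
def mask_causal(logits, q_pos, is_causal=True):
    """Set SWA-branch logits to -inf where swa_idx > q_pos (one query row)."""
    if not is_causal:
        return list(logits)
    return [s if swa_idx <= q_pos else float("-inf")
            for swa_idx, s in enumerate(logits)]
```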
D5 — SWA + Sink Merge
- D5a ✅: Kernel outputs un-normalized O + LSE
- D5b ✅: Python merge works (cos 0.961 at hd=64)
- D5c: Fuse two passes into one kernel launch
- D5d: Fuse sink merge into kernel epilogue