# STAGE_D.md: FMHA Kernel Development

## ⚠️ IKEA INSTRUCTIONS: READ EVERY TIME BEFORE CODING

### The Workflow (DO NOT SKIP STEPS)

1. **Edit code in** `~/dev/nvfp4-megamoe-kernel/dsv4/kernels/attention/fmha.py`. This is the ONLY file for the FMHA kernel.
2. **Commit and push:**
   ```bash
   cd ~/dev/nvfp4-megamoe-kernel
   git add -A && git commit -m "description" && git push origin master
   ```
3. **Pull on B200:**
   ```bash
   sshpass -p '' ssh -o StrictHostKeyChecking=no root@45.76.247.107 \
     "cd /root/dsv4-nvfp4-workspace/kernel && git pull origin master"
   ```
4. **Test on B200 using the test harness scripts.** See README.md "Test Harness" section.
5. **Regression check:** After every change, verify hd=64 cos ~0.999998 still matches. If it doesn't, the change is WRONG. Revert.

### The Rules (BURNED INTO THIS FILE)

- **NEVER edit files directly on the B200.** Edit locally, commit, push, pull, test. Every time.
- **NEVER delete or modify the test files in `tests/unit/`** without explicit approval.
- **NEVER touch drivers, kernels, firmware, or system packages on the B200.**
- **CuTeDSL variables defined in `if` blocks are NOT visible in other `if` blocks.** Define all variables unconditionally before any branching.
- **Always test at hd=64 FIRST.** If the proven path (TMEM-P) regresses, nothing else matters.
- **After every P store to TMEM, call `cute.arch.fence_view_async_tmem_store()`.** Missing this produces NaN.
- **`tOrP0` MUST include the `tmem_p0_offset` column offset.** Use `const_expr` for the conditional.
- **PRINT THE SHAPES. ALWAYS.** Reasoning about layouts without evidence is how we waste days.
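
The regression check in step 5 of the workflow could be scripted along these lines. This is a minimal sketch, not the test harness itself: `cosine_similarity`, `passes_regression`, and the `min_cos` threshold are hypothetical names and values chosen for illustration.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length flat sequences of floats."""
    if len(a) != len(b):
        raise ValueError("inputs must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def passes_regression(kernel_out, reference_out, min_cos=0.99999):
    """True if the kernel output is close enough to the reference.

    min_cos is an assumed threshold, not a value taken from the test harness.
    """
    return cosine_similarity(kernel_out, reference_out) >= min_cos


reference = [0.5, -1.25, 2.0, 0.75]
good = [x * (1.0 + 1e-4) for x in reference]   # uniform scale: cos stays ~1
bad = [0.5, -1.25, 2.0, -0.75]                 # one sign flip: cos drops

print(passes_regression(good, reference))  # accepted
print(passes_regression(bad, reference))   # rejected
```

Cosine similarity ignores a uniform scale error (the `good` case above), so it is worth checking alongside a magnitude-sensitive metric such as the LSE error reported in the status tables.
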
---

## Current Status (2026-05-26, 18:40 UTC)

### ✅ WORKING

| hd | n=128 cos | LSE err | Path | SMEM |
|---:|----------:|--------:|------|-----:|
| 64 | 0.999998 | 0.000000 | TMEM-P | 128KB |
| 128 | 0.999997 | 0.000000 | TMEM-P / SMEM-P | 128KB |
| 256 | 0.999998 | 0.000000 | TMEM-P | 224KB |

### ✅ D5 COMPLETE (May 26)

| Test | Config | cos | Status |
|------|--------|-----|--------|
| D5c single-tile | n_comp=64, n_swa=64, sink=0.5 | 0.999996 | ✅ |
| D5c causal | n_comp=64, n_swa=64, sink=0.3, causal | 0.999996 | ✅ |
| D5c multi-tile | n_comp=96, n_swa=160, s_k=256, Python KV merge | 0.999996 | ✅ |
| D3 regression | In-kernel mask, s_k=128 | 0.999996 | ✅ |
| D4 regression | Causal mask, s_k=128 | 0.999996 | ✅ |

### ❌ KNOWN ISSUES

- **hd=512: MLIR compilation hangs.** SMEM budget fixed (192KB ✅), kernel structure correct (tracer 0.8s), but MLIR→PTX backend optimizer cannot process the IR in reasonable time (>3 hours). This is a CuTeDSL/MLIR toolchain limitation.
- **D1.5 O rescale (multi-KV-tile): TMEM round-trip corruption.** Hand-constructed Ld32x32bOp/St32x32bOp atoms corrupt data on round-trip (even NO-OP). Workaround: Python KV merge (cos 0.999994). Fix requires correction epilog pattern (one-way TMEM→regs→SMEM→GMEM).
- **D2 multi-CTA grid: flat_divide + epilogue_tma_store layout mismatch.** Requires full tma_partition refactor into kernel. Head-packed per-head launch works (cos 0.999995).
---

## Architecture

### 6-Warp Layout

```
Warps 0-3: Softmax + Epilogue (row_max, row_sum, P store, O rescale)
Warp 4:    MMA (QK, PV)
Warp 5:    TMA (Q/K/V load)
```

### Kernel Output

The kernel outputs **un-normalized O + LSE** via `epilogue_tma_store`:

- O_unnorm = sum(P * V) where P = exp(S * scale - row_max)
- LSE = ln(row_sum) + row_max * ln(2)
- External normalization: O_norm = O_unnorm / row_sum
- For D5 merge: use exp(LSE) directly in the merge formula

### TMEM Layout

```
Col 0-31:   S (QK acc, 128 FP32 via Ld32x32bOp Repetition(32))
Col 32-95:  P (64 FP32 via register bridge, BF16 view)
Col 128+:   O (PV acc, 64+ FP32)
```

### P Staging Paths

**TMEM-P (hd≤64, also works at hd=128/256):**

- P stored to TMEM via register bridge (FP32 backing + BF16 view)
- PV MMA reads P from TMEM via `tOrP0`
- Works because QK C-fragment and PV A-fragment TMEM layouts agree at tested head dims

**SMEM-P (hd>64):**

- P written to SMEM via coordinate-indexed store
- Uses `tTMEM_LOADcS` identity tensor to get (m, k) coordinates
- Maps to sP's subtile layout: `sP[(m_coord, k_sub), 0, (k_g1, k_g2)]`
- PV MMA reads P from SMEM via `tCrP = pv_mma.make_fragment_A(sP)`
- SMEM-P uses `OperandSource.SMEM` for PV MMA

### Key Configuration

```python
head_dim = 128                        # constructor arg (64, 128, 256, 512); 128 shown as an example
pv_n_tile = min(head_dim, 256)        # tcgen05 MMA max N=256
n_pv_tiles = head_dim // pv_n_tile
kv_stage = 1 if head_dim > 128 else 2 # Reduce SMEM at large hd
use_smem_p = head_dim > 64            # SMEM-P for hd>64
qk_mma_tiler = (128, 128, head_dim)   # K-dim = head_dim (NOT hardcoded!)
```

---

## Critical Bug Fix: qk_mma_tiler K-dim (2026-05-24)

**ROOT CAUSE of hd>64 failure:** `qk_mma_tiler` K-dim was hardcoded to `qk_ik * 4 = 64` instead of `head_dim`. This caused the QK GEMM to compute only 64 of 128 (or 256, 512) dimensions at hd>64. The QK dot products were half the correct length (or shorter), producing wrong attention scores.

**Fix:** `self.qk_mma_tiler = (128, 128, self.head_dim)`, a one-line change.

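
A toy reproduction of the truncated-K bug described above, in plain Python rather than CuTeDSL. `qk_score` and its `tiler_k` argument are hypothetical stand-ins for the QK GEMM and the tiler's K dimension; the sketch only shows that stopping at 64 dimensions drops part of the dot product.

```python
import random


def qk_score(q, k, tiler_k):
    """Dot product of q and k over the first tiler_k dimensions only."""
    return sum(q[i] * k[i] for i in range(tiler_k))


random.seed(0)
head_dim = 128
q = [random.gauss(0.0, 1.0) for _ in range(head_dim)]
k = [random.gauss(0.0, 1.0) for _ in range(head_dim)]

correct = qk_score(q, k, head_dim)   # tiler K-dim = head_dim (the fix)
buggy = qk_score(q, k, 64)           # tiler K-dim hardcoded to 64 (the bug)

# The contribution of dimensions 64..127 that the buggy tiler never sees.
dropped = sum(q[i] * k[i] for i in range(64, head_dim))

print(correct, buggy, dropped)
```

At hd=64 the two calls coincide, which is consistent with the bug only showing up at larger head dims.
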

**Impact:** hd=128 went from cos 0.78 to 0.999997. hd=256 went from broken to 0.999998.

**LESSON:** The MMA tiler's K dimension must match the actual GEMM K dimension (head_dim), not the MMA instruction's K sub-tile size.

---

## Lessons Learned (2026-05-24)

### 1. CuTeDSL MLIR Backend Cannot Handle Complex Pipeline Loops

The MLIR→PTX backend optimizer has exponential-or-worse behavior for kernels with TMA pipeline acquire/release inside loops. Both unrolled (Python `range`) and runtime (`cutlass.range(unroll=1)`) loops trigger this. The Python tracer is fast (0.8s) because it just generates IR. The MLIR optimizer then chews on that IR for hours.

**Workaround:** keep pipeline loops as simple as possible. Consider raw CUDA C++ for complex kernels.

### 2. External k_sub Merge is Mathematically Impossible

You CANNOT merge the outputs of two attention calls that compute softmax(Q_k0 @ K_k0^T)@V and softmax(Q_k1 @ K_k1^T)@V into softmax(Q @ K^T)@V. The k_sub segments are additive in LOGIT space (S = S_0 + S_1), but softmax is nonlinear.

The D5 merge formula works because sparse and SWA attend over DIFFERENT token sets (additive in weight space). k_sub attends over the SAME tokens with PARTIAL dot products. These are fundamentally different operations.

**The only correct approach is in-kernel accumulation (S_0 + S_1 before softmax).**

### 3. pv_n_tile Reduction is the Easiest SMEM Knob

At hd>256, reducing `pv_n_tile` from 256 to 128 shrinks sV and sC by 2× each. The cost is 4 PV GEMM passes instead of 2. But PV is typically not the bottleneck. This is simpler than SMEM overlap (which requires CuTeDSL SmemAllocator changes) or Q tiling (which adds pipeline complexity).

### 4. Guard Dead Code with const_expr

CuTeDSL compiles BOTH branches of Python `if` statements, generating IR for code that will never execute at a given head_dim. Use `const_expr(condition)` to eliminate dead code at compile time.
This is critical for:

- O rescale code (only needed when n_kv_tiles > 1)
- LSE computation (only needed when normalize=False)
- SMEM-P path (only needed when use_smem_p=True)

### 5. Don't Mix Python Loops and CuTeDSL Pipeline Operations

Python `for` loops unroll at trace time, creating N copies of the loop body in the IR. For pipeline acquire/release + TMA copy + GEMM, each copy is substantial. `cutlass.range(unroll=1)` creates a runtime loop with one copy of the body.

**For pipeline operations, prefer `cutlass.range(unroll=1)` to reduce IR size**, even though the MLIR optimizer may still struggle with it.

### 6. The k_tile Parameter is the Key to hd=512

At hd=512, the kernel splits Q and K into sub-tiles of size `k_tile=256` along the head_dim. Each sub-tile is loaded via TMA, processed by MMA, and accumulated. `n_k_sub_tiles = head_dim // k_tile = 2`. The k_tile parameter controls the sub-tile size and the number of iterations.

**k_tile must be ≤ 256** (MMA instruction K-dim limit) and must evenly divide head_dim.

---

## SMEM Budget at Various hd

| hd | sQ | sK (kv_stage=1) | sV (pv_n_tile) | sP (SMEM-P) | sC | Total | Limit | Status |
|---:|----:|----:|----:|----:|----:|------:|------:|--------|
| 64 | 32KB | 32KB | 32KB (256) | n/a | 32KB | 128KB | 232KB | ✅ |
| 128 | 32KB | 32KB | 32KB (256) | n/a | 32KB | 128KB | 232KB | ✅ |
| 256 | 64KB | 64KB | 64KB (256) | 0* | 32KB | 224KB | 232KB | ✅ |
| 512 | 64KB | 64KB | 32KB (128) | 0* | 32KB | 192KB | 232KB | ⚠️ Fits but MLIR hangs |

\*TMEM-P path: sP allocation skipped (const_expr conditional)

pv_n_tile shown in parens; hd>256 uses pv_n_tile=128 (4 PV GEMM passes) to fit SMEM

---

## D1.5: Multi-KV-Tile O Rescale, TMEM Round-Trip Fundamentally Broken (2026-05-26)

**Root cause:** `Ld32x32bOp` and `St32x32bOp` have DIFFERENT column mappings at the hardware level. No layout transformation can fix this.

**Investigation (2026-05-26):**

- Attempted paired-atom approach using `epilogue_tmem_copy_and_partition` for TMEM load + `retile_to_S()` for TMEM store; `retile_to_S()` doesn't exist on TiledCopy
- Even if API worked, the round-trip is fundamentally broken: load and store atoms have different column mappings
- Full correction epilogue (one-way TMEM→REGS→SMEM→GMEM) causes 20+ minute MLIR compilation hang
- The compilation hang is from `transform_partitioned_tensor_layout` + `flat_divide` + `tma_partition` inside the softmax warp block

**Production path:** Python KV merge (cos 0.999998 for s_k up to 1024).

**Future fix:** Restructure PV to accumulate into REGS/SMEM instead of TMEM. Enables one-way correction epilogue. Major refactor (1-2 days).

---

## Build Order (Remaining)

### D1.4: hd=512 ⚡ CURRENT (BLOCKED)

**Problem:** hd=512 exceeds the MMA instruction's max K-dim (256). Must split Q and K into 2 sub-tiles along head_dim (k_tile=256, n_k_sub_tiles=2). The QK dot product is S = Q_k0 @ K_k0^T + Q_k1 @ K_k1^T (additive in logit space).

**SMEM budget: SOLVED.** pv_n_tile=128 for hd>256 reduces sV from 64KB→32KB, sC from 64KB→32KB. Total 192KB ✅.

**Compilation: BLOCKED.** The CuTeDSL MLIR→PTX backend optimizer cannot compile the hd=512 kernel in reasonable time. Both Python `range()` (unrolled IR) and `cutlass.range(unroll=1)` (runtime loop) produce IR that the optimizer chews on for 3+ hours without finishing. The Python tracer completes in 0.8s; the kernel is structurally correct. This is a toolchain limitation.

**External merge: IMPOSSIBLE.** The D5 online softmax merge formula assumes separate attention distributions over different token sets (additive in weight space). k_sub segments are additive in LOGIT space (S = S_0 + S_1), not weight space. You cannot recover softmax(S_0 + S_1)@V from softmax(S_0)@V and softmax(S_1)@V. In-kernel accumulation before softmax is the only correct approach.

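
The difference between the two merge situations can be checked numerically. The sketch below is plain Python for a single query row with a scalar value per token (a simplification). All helper names are hypothetical, and `merge_disjoint` assumes normalized per-branch outputs plus their LSE, which may differ from how the kernel's un-normalized output is merged in practice.

```python
import math


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(scores, values):
    """softmax(scores) @ values for one query row, scalar value per token."""
    return sum(w * v for w, v in zip(softmax(scores), values))


def lse(scores):
    m = max(scores)
    return m + math.log(sum(math.exp(s - m) for s in scores))


def merge_disjoint(o_a, lse_a, o_b, lse_b):
    """Combine normalized outputs computed over two DISJOINT token sets."""
    m = max(lse_a, lse_b)
    w_a, w_b = math.exp(lse_a - m), math.exp(lse_b - m)
    return (w_a * o_a + w_b * o_b) / (w_a + w_b)


# Case 1: different token sets (the D5 situation). The merge is exact.
scores = [0.3, -1.2, 0.8, 2.0, -0.5, 0.1]
values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
full = attend(scores, values)
merged = merge_disjoint(attend(scores[:3], values[:3]), lse(scores[:3]),
                        attend(scores[3:], values[3:]), lse(scores[3:]))

# Case 2: same tokens, partial dot products (the k_sub situation).
# Two different S_0 rows give the same partial output and the same LSE...
v = [0.0, 1.0, 2.0]
s0_a = [math.log(p) for p in (0.25, 0.50, 0.25)]
s0_b = [math.log(p) for p in (0.40, 0.20, 0.40)]
s1 = [math.log(p) for p in (0.60, 0.30, 0.10)]
# ...but the true outputs softmax(S_0 + S_1) @ V differ.
full_a = attend([x + y for x, y in zip(s0_a, s1)], v)
full_b = attend([x + y for x, y in zip(s0_b, s1)], v)

print(abs(full - merged))          # ~0: disjoint merge matches full attention
print(attend(s0_a, v), attend(s0_b, v), full_a, full_b)
```

In case 2 an external merge would receive identical inputs (same partial outputs, same LSE values) for both `s0_a` and `s0_b`, yet the correct answers differ, so no function of those inputs can be right in both cases.
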
**Bug fixes applied along the way:**

- LSE type mismatch (BF16 vs FP32 when normalize=True) → guarded with `const_expr(not self.normalize)`
- O rescale IR explosion at n=128 → guarded with `const_expr(n_kv_tiles > 1)`
- k_sub tracer IR explosion → replaced hardcoded `if k_sub==0/1` with Python `range()` loop
- External merge test (cos 0.617) → confirmed mathematically impossible, deleted approach

**Possible paths forward (priority order):**

1. **Pre-compile hd=512 kernel offline.** Accept 1-2 hour compilation during build. Cache the cubin. This works if the MLIR optimizer eventually finishes (it might just be slow, not stuck, but 3+ hours is excessive even for pre-compilation).
2. **Add no-softmax mode to the kernel.** Output raw S (QK scores) without softmax. Call twice for k_sub=0 and k_sub=1. Accumulate S_0+S_1 in Python. Apply softmax once. This requires modifying the softmax warp to optionally skip normalization and output S to GMEM instead of P to TMEM/SMEM.
3. **Write hd=512 kernel in CUTLASS C++.** Bypass CuTeDSL's MLIR backend entirely. Use raw CUTLASS C++ with tcgen05 MMA intrinsics. More work but compilation is fast (seconds).
4. **Report CuTeDSL MLIR optimizer bug.** The optimizer should handle this IR in reasonable time. File an issue with NVIDIA.

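
Option 2 above (no-softmax mode) could look roughly like this on the host side. This is a sketch under assumptions: `raw_scores` is a hypothetical stand-in for a kernel call that returns raw QK scores for one head_dim slice, the softmax scale is omitted, the PV product is done in plain Python here, and the sizes are toy values standing in for hd=512 and k_tile=256.

```python
import math


def raw_scores(q, keys, lo, hi):
    """Stand-in for a hypothetical no-softmax kernel call on head_dim slice [lo, hi)."""
    return [sum(q[d] * k[d] for d in range(lo, hi)) for k in keys]


def softmax_pv(scores, values):
    """Single softmax over accumulated logits, then weighted sum of value rows."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    dim = len(values[0])
    return [sum(e * v[d] for e, v in zip(exps, values)) / total for d in range(dim)]


head_dim, k_tile = 8, 4           # toy stand-ins for hd=512, k_tile=256
q = [0.1 * (d + 1) for d in range(head_dim)]
keys = [[0.05 * (t + 1) * (d - 3) for d in range(head_dim)] for t in range(4)]
values = [[float(t), float(t * t)] for t in range(4)]

s0 = raw_scores(q, keys, 0, k_tile)            # k_sub = 0
s1 = raw_scores(q, keys, k_tile, head_dim)     # k_sub = 1
s = [a + b for a, b in zip(s0, s1)]            # accumulate in logit space
out = softmax_pv(s, values)                    # softmax applied once

# Reference: one pass over the full head_dim.
reference = softmax_pv(raw_scores(q, keys, 0, head_dim), values)
print(out, reference)
```

The accumulation happens before the nonlinearity, which is why this path stays correct where the external output merge does not.
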
### D2: Multi-Query Grid with Head Packing ✅ (per-head launch)

- Head-packed M-dimension launch: Q reshaped to (n_h*T, hd, 1), kernel treats each row independently
- cos 0.999995 for n_h=1-128 at hd=64, n_h=2-8 at hd=128, n_h=2 at hd=256
- Multi-CTA grid (flat_divide) BLOCKED; see Known Issues

### D3: SWA Sequence Length Mask ✅

- In-kernel post-QK masking via tTMEM_LOADcS coordinates
- swa_len as Int32 scalar (runtime, not compile-time)
- Offset by n_comp for D5c: mask positions >= n_comp + swa_len

### D4: Causal Mask on SWA Branch ✅

- SWA-relative position (kv_pos - n_comp) > m_coord → -inf
- Combined with D3 via OR logic

### D5: SWA + Sink Merge ✅ (May 26)

**Key insight:** Sink merge = single softmax over [S_comp, S_swa + attn_sink]. One pass, one kernel. D5d NOT NEEDED.

- **D5a ✅:** normalize flag + LSE + row_sums output
- **D5b ✅:** Per-row LSE + Python KV merge (cos 0.999994)
- **D5c ✅:** Sink bias as logit modification (cos 0.999996 single-tile AND multi-tile)
- **D5d:** NOT NEEDED; sink bias approach supersedes fused merge epilogue
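
The D3/D4 masks and the D5 sink bias combine into one logit row per query. The sketch below is a plain-Python reference for a single row, following the mask conditions listed above. `merged_logits` and its argument names are hypothetical, and it assumes at least one compressed position so the row is never fully masked.

```python
import math

NEG_INF = float("-inf")


def merged_logits(s_comp, s_swa, attn_sink, swa_len, m_coord, causal):
    """One logit row over [S_comp, S_swa + attn_sink] with D3/D4 masking."""
    n_comp = len(s_comp)
    row = list(s_comp)
    for j, s in enumerate(s_swa):
        kv_pos = n_comp + j
        past_len = kv_pos >= n_comp + swa_len             # D3: beyond swa_len
        future = causal and (kv_pos - n_comp) > m_coord   # D4: causal on SWA
        row.append(NEG_INF if (past_len or future) else s + attn_sink)
    return row


def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]   # exp(-inf) == 0.0 for masked slots
    total = sum(exps)
    return [e / total for e in exps]


s_comp = [0.2, -0.4]
s_swa = [1.0, 0.5, -0.3, 0.7]
row = merged_logits(s_comp, s_swa, attn_sink=0.5, swa_len=3, m_coord=1, causal=True)
weights = softmax(row)

print(row)      # last two SWA slots are masked to -inf
print(weights)  # masked slots receive zero weight
```

Because the sink bias is applied to the logits before a single softmax, no separate merge step is needed, which matches the reason D5d was dropped.
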