STAGE_D.md — FMHA Kernel Development

⚠️ IKEA INSTRUCTIONS — READ EVERY TIME BEFORE CODING

The Workflow (DO NOT SKIP STEPS)

  1. Edit code in ~/dev/nvfp4-megamoe-kernel/dsv4/kernels/attention/fmha.py — this is the ONLY file for the FMHA kernel.
  2. Commit and push:
    cd ~/dev/nvfp4-megamoe-kernel
    git add -A && git commit -m "description" && git push origin master
    
  3. Pull on B200:
    sshpass -p '<B200_PASSWORD>' ssh -o StrictHostKeyChecking=no root@45.76.247.107 \
      "cd /root/dsv4-nvfp4-workspace/kernel && git pull origin master"
    
  4. Test on B200 using the test harness scripts — see README.md "Test Harness" section.
  5. Regression check: after every change, verify that hd=64 still gives cos ≈ 0.999998. If it doesn't, the change is WRONG. Revert.
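The regression check in step 5 can be sketched as a small helper. This is a minimal illustration, not the repo's test harness: the names cosine_similarity and passes_regression and the min_cos threshold are assumptions chosen for the example.

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity between two equal-length flat sequences of floats."""
    if len(a) != len(b):
        raise ValueError("length mismatch")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def passes_regression(kernel_out, reference_out, min_cos=0.99999):
    """True when the kernel output still matches the reference closely enough.

    min_cos is a hypothetical threshold, set just below the documented 0.999998.
    """
    return cosine_similarity(kernel_out, reference_out) >= min_cos

reference = [0.25, -1.5, 3.0, 0.75]
scaled = [x * 1.0001 for x in reference]   # uniform rescale of the reference
flipped = [0.25, 1.5, 3.0, 0.75]           # one sign flipped

print(passes_regression(scaled, reference))
print(passes_regression(flipped, reference))
```

Cosine similarity ignores a uniform scale error, which is one reason to keep checking the LSE error alongside it.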

The Rules (BURNED INTO THIS FILE)

  • NEVER edit files directly on the B200. Edit locally, commit, push, pull, test. Every time.
  • NEVER delete or modify the test files in tests/unit/ without explicit approval.
  • NEVER touch drivers, kernels, firmware, or system packages on the B200.
  • CuTeDSL variables defined in if blocks are NOT visible in other if blocks. Define all variables unconditionally before any branching.
  • Always test at hd=64 FIRST. If the proven path (TMEM-P) regresses, nothing else matters.
  • After every P store to TMEM, call cute.arch.fence_view_async_tmem_store(). Missing this produces NaN.
  • tOrP0 MUST include the tmem_p0_offset column offset. Use const_expr for the conditional.
  • PRINT THE SHAPES. ALWAYS. Reasoning about layouts without evidence is how we waste days.

Current Status (2026-05-24, 21:30 UTC)

WORKING

hd    n=128 cos   LSE err    Path              SMEM
64    0.999998    0.000000   TMEM-P            128KB
128   0.999997    0.000000   TMEM-P / SMEM-P   128KB
256   0.999998    0.000000   TMEM-P            224KB

KNOWN ISSUES

  • hd=512: MLIR compilation hangs. The SMEM budget is fixed (192KB) and the kernel structure is correct (the tracer finishes in 0.8s), but the MLIR→PTX backend optimizer cannot process the IR in reasonable time (>3 hours). Both unrolled range() loops and cutlass.range(unroll=1) runtime loops trigger this. This is a CuTeDSL/MLIR toolchain limitation.
  • External k_sub merge doesn't work. k_sub segments are additive in logit space (S = S_0 + S_1), not attention weight space. The D5 merge formula does not apply. In-kernel k_sub accumulation is the only correct approach.
  • O rescale (kt>0): Uses hand-constructed TMEM atoms. May corrupt data for n>128 (multi-KV-tile). At n=128 (1 KV tile, kt=0), no rescale needed. Guarded with const_expr(n_kv_tiles > 1).
  • Kernel always outputs un-normalized O + LSE. No in-kernel normalization (eliminates TMEM round-trip error). External normalization: O_norm = O_unnorm / row_sum.

Architecture

6-Warp Layout

Warps 0-3: Softmax + Epilogue (row_max, row_sum, P store, O rescale)
Warp 4: MMA (QK, PV)
Warp 5: TMA (Q/K/V load)

Kernel Output

The kernel outputs un-normalized O + LSE via epilogue_tma_store:

  • O_unnorm = sum(P * V) where P = exp(S * scale - row_max)
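  • LSE = ln(row_sum) + row_max * ln(2). The ln(2) factor applies when row_max is tracked as a base-2 exponent (P computed with exp2); with a natural exp, as P is written above, the term is row_max alone.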
  • External normalization: O_norm = O_unnorm / row_sum
  • For D5 merge: use exp(LSE) directly in the merge formula
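The output contract above can be checked with a small host-side reference. This is a sketch under stated assumptions: it assumes the softmax statistics are kept in the base-2 domain (which is what the ln(2) factor in the LSE formula implies), and kernel_row and external_finalize are hypothetical names, not functions from fmha.py.

```python
import math

LOG2_E = math.log2(math.e)

def kernel_row(scores, values, scale):
    """Reference for one query row: un-normalized O plus softmax statistics."""
    # Assumption: logits are converted to base-2 exponents before the max/exp step.
    s2 = [s * scale * LOG2_E for s in scores]
    row_max = max(s2)
    p = [2.0 ** (x - row_max) for x in s2]      # un-normalized weights
    row_sum = sum(p)
    head_dim = len(values[0])
    o_unnorm = [sum(p[j] * values[j][d] for j in range(len(p)))
                for d in range(head_dim)]
    return o_unnorm, row_sum, row_max

def external_finalize(o_unnorm, row_sum, row_max):
    """Host-side step: O_norm = O_unnorm / row_sum, plus the natural-log LSE."""
    o_norm = [o / row_sum for o in o_unnorm]
    lse = math.log(row_sum) + row_max * math.log(2.0)
    return o_norm, lse

scores = [0.5, -1.0, 2.0]
values = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
scale = 0.125

o_norm, lse = external_finalize(*kernel_row(scores, values, scale))

# Direct softmax reference for comparison.
w = [math.exp(s * scale) for s in scores]
z = sum(w)
ref_o = [sum(w[j] * values[j][d] for j in range(3)) / z for d in range(2)]
ref_lse = math.log(z)

print(max(abs(a - b) for a, b in zip(o_norm, ref_o)), abs(lse - ref_lse))
```

Both differences should be at floating-point rounding level, which is the sense in which external normalization is exact.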

TMEM Layout

Col 0-31:   S (QK acc, 128 FP32 via Ld32x32bOp Repetition(32))
Col 32-95:  P (64 FP32 via register bridge, BF16 view)
Col 128+:   O (PV acc, 64+ FP32)

P Staging Paths

TMEM-P (hd≤64, also works at hd=128/256):

  • P stored to TMEM via register bridge (FP32 backing + BF16 view)
  • PV MMA reads P from TMEM via tOrP0
  • Works because QK C-fragment and PV A-fragment TMEM layouts agree at tested head dims

SMEM-P (hd>64):

  • P written to SMEM via coordinate-indexed store
  • Uses tTMEM_LOADcS identity tensor to get (m, k) coordinates
  • Maps to sP's subtile layout: sP[(m_coord, k_sub), 0, (k_g1, k_g2)]
  • PV MMA reads P from SMEM via tCrP = pv_mma.make_fragment_A(sP)
  • SMEM-P uses OperandSource.SMEM for PV MMA

Key Configuration

head_dim:     constructor arg (64, 128, 256, 512)
pv_n_tile:    min(head_dim, 256)  # tcgen05 MMA max N=256; hd>256 uses 128 to fit SMEM
n_pv_tiles:   head_dim // pv_n_tile
kv_stage:     1 if head_dim > 128 else 2  # Reduce SMEM at large hd
use_smem_p:   head_dim > 64  # SMEM-P for hd>64
qk_mma_tiler: (128, 128, head_dim)  # K-dim = head_dim (NOT hardcoded!)

Critical Bug Fix: qk_mma_tiler K-dim (2026-05-24)

ROOT CAUSE of hd>64 failure: qk_mma_tiler K-dim was hardcoded to qk_ik * 4 = 64 instead of head_dim.

This caused the QK GEMM to compute only 64 of the 128 (or 256, 512) head dimensions at hd>64. Each QK dot product covered only 64/head_dim of its terms (half at hd=128, a quarter at hd=256), producing wrong attention scores.

Fix: self.qk_mma_tiler = (128, 128, self.head_dim) — one line change.

Impact: hd=128 went from cos 0.78 to 0.999997. hd=256 went from broken to 0.999998.

LESSON: The MMA tiler's K dimension must match the actual GEMM K dimension (head_dim), not the MMA instruction's K sub-tile size.
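A toy reproduction of the failure mode, in plain Python rather than CuTeDSL. The data is contrived so that the keys differ only in dimensions 64..127; dot_prefix is a hypothetical helper that mimics a GEMM whose K extent is capped.

```python
def dot_prefix(q, k, k_dim):
    """Dot product over the first k_dim elements only (a GEMM with K = k_dim)."""
    return sum(q[i] * k[i] for i in range(k_dim))

head_dim = 128
q = [1.0] * head_dim
# Four keys that are identical in dims 0..63 and differ only in dims 64..127.
keys = [[1.0] * 64 + [float(j)] * 64 for j in range(4)]

full = [dot_prefix(q, k, head_dim) for k in keys]   # K-dim = head_dim (fixed)
truncated = [dot_prefix(q, k, 64) for k in keys]    # K-dim hardcoded to 64 (bug)

print(full)
print(truncated)
```

With the truncated K-dim every key gets the same score, so the softmax would come out uniform no matter what the upper half of the head dimension holds.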


Lessons Learned (2026-05-24)

1. CuTeDSL MLIR Backend Cannot Handle Complex Pipeline Loops

The MLIR→PTX backend optimizer appears to scale exponentially or worse on kernels with TMA pipeline acquire/release inside loops. Both unrolled (Python range) and runtime (cutlass.range(unroll=1)) loops trigger this. The Python tracer is fast (0.8s) because it only generates IR. The MLIR optimizer then chews on that IR for hours. Workaround: keep pipeline loops as simple as possible. Consider raw CUDA C++ for complex kernels.

2. External k_sub Merge is Mathematically Impossible

You CANNOT merge the outputs of two attention calls that compute softmax(Q_k0 @ K_k0^T)@V and softmax(Q_k1 @ K_k1^T)@V into softmax(Q @ K^T)@V. The k_sub segments are additive in LOGIT space (S = S_0 + S_1), but softmax is nonlinear. The D5 merge formula works because sparse and SWA attend over DIFFERENT token sets (additive in weight space). k_sub attends over the SAME tokens with PARTIAL dot products. These are fundamentally different operations. The only correct approach is in-kernel accumulation (S_0 + S_1 before softmax).
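The difference between the two cases can be shown numerically. The sketch below is a plain-Python reference, not kernel code; attention and lse_merge are hypothetical helpers, and lse_merge is the standard LSE-weighted combination, assumed here to be what the D5 merge formula refers to.

```python
import math

def attention(q, keys, values):
    """One query row: return (normalized output, natural-log LSE)."""
    s = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(s)
    w = [math.exp(x - m) for x in s]
    z = sum(w)
    dim = len(values[0])
    out = [sum(w[j] * values[j][d] for j in range(len(w))) / z for d in range(dim)]
    return out, math.log(z) + m

def lse_merge(out_a, lse_a, out_b, lse_b):
    """Combine two normalized outputs with weights proportional to exp(LSE)."""
    m = max(lse_a, lse_b)
    wa, wb = math.exp(lse_a - m), math.exp(lse_b - m)
    return [(wa * a + wb * b) / (wa + wb) for a, b in zip(out_a, out_b)]

q = [1.0, -0.5, 0.25, 2.0]
keys = [[0.5, 1.0, -1.0, 0.0], [1.5, 0.0, 0.5, -0.5],
        [-1.0, 2.0, 1.0, 1.0], [0.0, -1.0, 2.0, 0.5]]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0], [-1.0, 3.0]]
full, _ = attention(q, keys, values)

# Case 1: two calls over DIFFERENT token sets (the sparse + SWA situation).
oa, la = attention(q, keys[:2], values[:2])
ob, lb = attention(q, keys[2:], values[2:])
token_err = max(abs(x - y) for x, y in zip(lse_merge(oa, la, ob, lb), full))

# Case 2: two calls over the SAME tokens with PARTIAL dot products (k_sub).
oa, la = attention(q[:2], [k[:2] for k in keys], values)
ob, lb = attention(q[2:], [k[2:] for k in keys], values)
ksub_err = max(abs(x - y) for x, y in zip(lse_merge(oa, la, ob, lb), full))

print(token_err, ksub_err)
```

The token-split error is at rounding level, while the k_sub error is of order one for this data, consistent with the failed external merge test.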

3. pv_n_tile Reduction is the Easiest SMEM Knob

At hd>256, reducing pv_n_tile from 256 to 128 shrinks sV and sC by 2× each. The cost is 4 PV GEMM passes instead of 2. But PV is typically not the bottleneck. This is simpler than SMEM overlap (which requires CuTeDSL SmemAllocator changes) or Q tiling (which adds pipeline complexity).
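The trade-off is simple arithmetic. The sketch below assumes a 128-row KV tile of 2-byte elements in a single stage; that layout is an assumption, chosen because it reproduces the 64KB and 32KB sV figures quoted in this file, and sv_bytes and pv_passes are hypothetical helper names.

```python
KIB = 1024

def sv_bytes(pv_n_tile, kv_rows=128, bytes_per_elem=2, stages=1):
    """SMEM footprint of one V tile (assumed layout: kv_rows x pv_n_tile)."""
    return kv_rows * pv_n_tile * bytes_per_elem * stages

def pv_passes(head_dim, pv_n_tile):
    """Number of PV GEMM passes needed to cover the full head dimension."""
    if head_dim % pv_n_tile != 0:
        raise ValueError("pv_n_tile must evenly divide head_dim")
    return head_dim // pv_n_tile

for pv_n_tile in (256, 128):
    print(pv_n_tile, sv_bytes(pv_n_tile) // KIB, "KiB,",
          pv_passes(512, pv_n_tile), "passes")
```

Halving pv_n_tile halves the tile footprint and doubles the pass count, so the knob trades SMEM for extra PV GEMM iterations.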

4. Guard Dead Code with const_expr

CuTeDSL compiles BOTH branches of Python if statements, generating IR for code that will never execute at a given head_dim. Use const_expr(condition) to eliminate dead code at compile time. This is critical for:

  • O rescale code (only needed when n_kv_tiles > 1)
  • LSE computation (only needed when normalize=False)
  • SMEM-P path (only needed when use_smem_p=True)

5. Don't Mix Python Loops and CuTeDSL Pipeline Operations

Python for loops unroll at trace time, creating N copies of the loop body in the IR. For pipeline acquire/release + TMA copy + GEMM, each copy is substantial. cutlass.range(unroll=1) creates a runtime loop with one copy of the body. For pipeline operations, prefer cutlass.range(unroll=1) to reduce IR size, even though the MLIR optimizer may still struggle with it.

6. The k_tile Parameter is the Key to hd=512

At hd=512, the kernel splits Q and K into sub-tiles of size k_tile=256 along the head_dim. Each sub-tile is loaded via TMA, processed by MMA, and accumulated. n_k_sub_tiles = head_dim // k_tile = 2. The k_tile parameter controls the sub-tile size and the number of iterations. k_tile must be ≤ 256 (MMA instruction K-dim limit) and must evenly divide head_dim.
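The sub-tile accumulation can be written as a plain-Python reference. This is a sketch of the arithmetic only (no TMA or MMA); split_k_scores is a hypothetical name, and the integer-valued test data is chosen so that the comparison with the direct dot product is exact.

```python
def split_k_scores(q, keys, k_tile):
    """Accumulate QK scores over head_dim sub-tiles of width k_tile."""
    head_dim = len(q)
    if k_tile > 256 or head_dim % k_tile != 0:
        raise ValueError("k_tile must be <= 256 and evenly divide head_dim")
    n_k_sub_tiles = head_dim // k_tile
    scores = [0.0] * len(keys)
    for k_sub in range(n_k_sub_tiles):
        lo, hi = k_sub * k_tile, (k_sub + 1) * k_tile
        for j, k in enumerate(keys):
            # Partial dot product for this sub-tile, added in logit space.
            scores[j] += sum(q[i] * k[i] for i in range(lo, hi))
    return scores, n_k_sub_tiles

head_dim = 512
q = [float(i % 7 - 3) for i in range(head_dim)]
keys = [[float((i * (j + 1)) % 5 - 2) for i in range(head_dim)] for j in range(3)]

tiled, n_k_sub_tiles = split_k_scores(q, keys, k_tile=256)
direct = [sum(a * b for a, b in zip(q, k)) for k in keys]

print(n_k_sub_tiles, tiled == direct)
```

The softmax is applied only after all sub-tiles have been summed, which is why this accumulation has to happen inside the kernel.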


SMEM Budget at Various hd

hd sQ sK (kv_stage=1) sV (pv_n_tile) sP (SMEM-P) sC Total Limit Status
64 32KB 32KB 32KB (256) 32KB 128KB 232KB
128 32KB 32KB 32KB (256) 32KB 128KB 232KB
256 64KB 64KB 64KB (256) 0* 32KB 224KB 232KB
512 64KB 64KB 32KB (128) 0* 32KB 192KB 232KB ⚠️ Fits but MLIR hangs

*TMEM-P path: sP allocation skipped (const_expr conditional).
pv_n_tile is shown in parentheses; hd>256 uses pv_n_tile=128 (4 PV GEMM passes) to fit SMEM.


D1.5: Correction Epilogue (TMEM Round-Trip Error)

Issue: Hand-constructed Ld32x32bOp/St32x32bOp atoms don't preserve the C-fragment layout during TMEM round-trips (load→modify→store). Causes ~3% error per round-trip.

Current workaround: Kernel outputs un-normalized O + LSE. No in-kernel normalization needed. External normalization is exact.

Proper fix (future): Use CUTLASS epilogue_tmem_copy_and_partition + epilogue_smem_copy_and_partition pattern with paired atoms. One-way trip: TMEM → registers (normalize) → SMEM → GMEM.

Priority: MEDIUM. Not a correctness blocker (external normalization is exact). Would enable in-kernel normalization for D5c/D5d.


Build Order (Remaining)

D1.4 — hd=512 CURRENT (BLOCKED)

Problem: hd=512 exceeds the MMA instruction's max K-dim (256). Must split Q and K into 2 sub-tiles along head_dim (k_tile=256, n_k_sub_tiles=2). The QK dot product is S = Q_k0 @ K_k0^T + Q_k1 @ K_k1^T (additive in logit space).

SMEM budget: SOLVED. pv_n_tile=128 for hd>256 reduces sV from 64KB→32KB, sC from 64KB→32KB. Total: 192KB.

Compilation: BLOCKED. The CuTeDSL MLIR→PTX backend optimizer cannot compile the hd=512 kernel in reasonable time. Both Python range() (unrolled IR) and cutlass.range(unroll=1) (runtime loop) produce IR that the optimizer chews on for 3+ hours without finishing. The Python tracer completes in 0.8s — the kernel is structurally correct. This is a toolchain limitation.

External merge: IMPOSSIBLE. The D5 online softmax merge formula assumes separate attention distributions over different token sets (additive in weight space). k_sub segments are additive in LOGIT space (S = S_0 + S_1), not weight space. You cannot recover softmax(S_0 + S_1)@V from softmax(S_0)@V and softmax(S_1)@V. In-kernel accumulation before softmax is the only correct approach.

Bug fixes applied along the way:

  • LSE type mismatch (BF16 vs FP32 when normalize=True) → guarded with const_expr(not self.normalize)
  • O rescale IR explosion at n=128 → guarded with const_expr(n_kv_tiles > 1)
  • k_sub tracer IR explosion → replaced hardcoded if k_sub==0/1 with Python range() loop
  • External merge test (cos 0.617) → confirmed mathematically impossible, deleted approach

Possible paths forward (priority order):

  1. Pre-compile hd=512 kernel offline. Accept 1-2 hour compilation during build. Cache the cubin. This works if the MLIR optimizer eventually finishes (it might just be slow, not stuck — but 3+ hours is excessive even for pre-compilation).
  2. Add no-softmax mode to the kernel. Output raw S (QK scores) without softmax. Call twice for k_sub=0 and k_sub=1. Accumulate S_0+S_1 in Python. Apply softmax once. This requires modifying the softmax warp to optionally skip normalization and output S to GMEM instead of P to TMEM/SMEM.
  3. Write hd=512 kernel in CUTLASS C++. Bypass CuTeDSL's MLIR backend entirely. Use raw CUTLASS C++ with tcgen05 MMA intrinsics. More work but compilation is fast (seconds).
  4. Report CuTeDSL MLIR optimizer bug. The optimizer should handle this IR in reasonable time. File an issue with NVIDIA.

D2 — Multi-Query Grid with Head Packing

  • Grid changes from (1, 1, 1) to (num_q_blocks, 1, batch)
  • DSV4 is MQA: all 128 query heads share same K/V
  • Head axis folded into M dimension: M_tile = 128 covers M = T * n_h rows
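The grid arithmetic for head packing might look like the sketch below. It is a planning aid under assumptions: the ceil-division for num_q_blocks and the token-major row order (row = t * n_h + h) are not stated in this file, and packed_grid and unpack_row are hypothetical names.

```python
def packed_grid(tokens, n_heads, batch, m_tile=128):
    """Launch grid when the head axis is folded into M (M = T * n_h rows)."""
    m_rows = tokens * n_heads
    num_q_blocks = (m_rows + m_tile - 1) // m_tile   # assumed ceil-division
    return (num_q_blocks, 1, batch)

def unpack_row(row, n_heads):
    """Map a packed M row back to (token, head); assumes row = t * n_h + h."""
    return divmod(row, n_heads)

print(packed_grid(tokens=1, n_heads=128, batch=4))
print(packed_grid(tokens=3, n_heads=128, batch=2))
print(unpack_row(130, n_heads=128))
```

With 128 query heads, a single decode token fills exactly one M_tile of 128 rows, so every query block reuses the same shared K/V.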

D3 — SWA Sequence Length Mask

  • Add swa_lens: [batch] int32 kernel input
  • Mask SWA-branch logits to -inf where swa_idx >= swa_lens[b]

D4 — Causal Mask on SWA Branch

  • Add is_causal: bool constructor flag
  • Apply swa_idx > q_pos masking to -inf in SWA pass
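The D3 and D4 masking rules can be combined in one host-side reference for a single query row. This is a sketch, not the planned kernel code: mask_swa_logits is a hypothetical name, and it assumes swa_idx and q_pos are expressed in the same position coordinates.

```python
NEG_INF = float("-inf")

def mask_swa_logits(logits, swa_len, q_pos=None):
    """Mask one row of SWA-branch logits.

    logits  : list indexed by swa_idx
    swa_len : valid length for this batch entry (swa_lens[b])
    q_pos   : query position; None disables the causal mask (is_causal=False)
    """
    out = []
    for swa_idx, s in enumerate(logits):
        masked = swa_idx >= swa_len                    # D3: length mask
        if q_pos is not None and swa_idx > q_pos:      # D4: causal mask
            masked = True
        out.append(NEG_INF if masked else s)
    return out

row = [1.0, 2.0, 3.0, 4.0]
print(mask_swa_logits(row, swa_len=3))
print(mask_swa_logits(row, swa_len=3, q_pos=1))
```

Masked positions contribute exp(-inf) = 0 to the row sum, so they fall out of both O and the LSE as long as at least one position in the row stays unmasked.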

D5 — SWA + Sink Merge

  • D5a: Kernel outputs un-normalized O + LSE
  • D5b: Python merge works (cos 0.961 at hd=64)
  • D5c: Fuse two passes into one kernel launch
  • D5d: Fuse sink merge into kernel epilogue