
STAGE_D.md — FMHA Kernel Development

⚠️ IKEA INSTRUCTIONS — READ EVERY TIME BEFORE CODING

The Workflow (DO NOT SKIP STEPS)

  1. Edit code in ~/dev/nvfp4-megamoe-kernel/dsv4/kernels/attention/fmha.py — this is the ONLY file for the FMHA kernel.
  2. Commit and push:
    cd ~/dev/nvfp4-megamoe-kernel
    git add -A && git commit -m "description" && git push origin master
    
  3. Pull on B200:
    sshpass -p '<B200_PASSWORD>' ssh -o StrictHostKeyChecking=no root@45.76.247.107 \
      "cd /root/dsv4-nvfp4-workspace/kernel && git pull origin master"
    
  4. Test on B200 using the test harness scripts — see README.md "Test Harness" section.
  5. Regression check: after every change, verify that hd=64 still reaches cos ~0.999998. If it doesn't, the change is WRONG. Revert.
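The pass/fail rule in step 5 can be written as a small check. This is a minimal sketch. `cosine_similarity`, `hd64_regressed` and the tolerance are hypothetical names and values, not part of the test harness described in README.md.

```python
import math
from typing import Sequence

HD64_REFERENCE_COS = 0.999998  # recorded in "Current Status" (hd=64, n=128)
COS_TOLERANCE = 1e-5           # hypothetical slack; tune to the harness's noise


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between flattened kernel output and reference output."""
    if len(a) != len(b):
        raise ValueError(f"shape mismatch: {len(a)} vs {len(b)}")
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    if norm == 0.0:
        raise ValueError("zero-norm output")
    return dot / norm


def hd64_regressed(cos: float) -> bool:
    """True if hd=64 dropped below the recorded value, i.e. revert the change."""
    return cos < HD64_REFERENCE_COS - COS_TOLERANCE
```

Pass the flattened kernel O and the flattened reference O to `cosine_similarity`, then feed the result to `hd64_regressed`.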

The Rules (BURNED INTO THIS FILE)

  • NEVER edit files directly on the B200. Edit locally, commit, push, pull, test. Every time.
  • NEVER delete or modify the test files in tests/unit/ without explicit approval.
  • NEVER touch drivers, kernels, firmware, or system packages on the B200.
  • CuTeDSL variables defined in if blocks are NOT visible in other if blocks. Define all variables unconditionally before any branching.
  • Always test at hd=64 FIRST. If the proven path (TMEM-P) regresses, nothing else matters.
  • After every P store to TMEM, call cute.arch.fence_view_async_tmem_store(). Missing this produces NaN.
  • tOrP0 MUST include the tmem_p0_offset column offset. Use const_expr for the conditional.
  • PRINT THE SHAPES. ALWAYS. Reasoning about layouts without evidence is how we waste days.

Current Status (2026-05-24)

WORKING

| hd  | cos (n=128) | LSE err  | Path            |
|-----|-------------|----------|-----------------|
| 64  | 0.999998    | 0.000000 | TMEM-P          |
| 128 | 0.999997    | 0.000000 | TMEM-P / SMEM-P |
| 256 | 0.999998    | 0.000000 | TMEM-P          |

KNOWN ISSUES

  • hd=512: SMEM overflow (344KB > 232KB). sQ (128KB) + sK (128KB) + sV (64KB) = 320KB already exceeds the limit on its own. Needs SMEM tiling or buffer overlap.
  • O rescale (kt>0): uses hand-constructed TMEM atoms and may corrupt data for n>128 (multiple KV tiles). At n=128 there is a single KV tile (kt=0), so no rescale is needed.
  • The kernel always outputs un-normalized O + LSE. Skipping in-kernel normalization avoids the TMEM round-trip error (see D1.5). Normalize externally: O_norm = O_unnorm / row_sum.

Architecture

6-Warp Layout

Warps 0-3: Softmax + Epilogue (row_max, row_sum, P store, O rescale)
Warp 4: MMA (QK, PV)
Warp 5: TMA (Q/K/V load)

Kernel Output

The kernel outputs un-normalized O + LSE via epilogue_tma_store:

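  • O_unnorm = sum(P * V) where P = exp2(S * scale * log2(e) - row_max), with row_max tracked in the same base-2 units (this is what makes the ln(2) term in LSE correct)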
  • LSE = ln(row_sum) + row_max * ln(2)
  • External normalization: O_norm = O_unnorm / row_sum
  • For D5 merge: use exp(LSE) directly in the merge formula
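Below is a single-row Python sketch of these formulas, usable as a host-side reference when checking the kernel's O_unnorm and LSE. It assumes the base-2 convention that the ln(2) term implies: scores are pre-multiplied by scale * log2(e) and row_max is in base-2 units. `row_outputs` and `normalize` are hypothetical helpers, not code from fmha.py.

```python
import math
from typing import List, Tuple

LOG2E = 1.0 / math.log(2.0)  # log2(e)


def row_outputs(s_row: List[float], v: List[List[float]], scale: float
                ) -> Tuple[List[float], float, float]:
    """Return (O_unnorm, row_sum, LSE) for one query row of S = Q K^T.

    Assumed base-2 convention: s2 = S * scale * log2(e), row_max = max(s2),
    P = 2**(s2 - row_max). Then LSE (natural log) = ln(row_sum) + row_max * ln(2).
    """
    s2 = [s * scale * LOG2E for s in s_row]   # base-2 scaled scores
    row_max = max(s2)
    p = [2.0 ** (x - row_max) for x in s2]    # un-normalized probabilities
    row_sum = sum(p)
    o_unnorm = [sum(p_j * v_j[d] for p_j, v_j in zip(p, v)) for d in range(len(v[0]))]
    lse = math.log(row_sum) + row_max * math.log(2.0)
    return o_unnorm, row_sum, lse


def normalize(o_unnorm: List[float], row_sum: float) -> List[float]:
    """External normalization: O_norm = O_unnorm / row_sum."""
    return [x / row_sum for x in o_unnorm]
```

Under this convention row_sum = exp(LSE - row_max * ln(2)). Recovering row_sum from LSE alone therefore also requires row_max.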

TMEM Layout

Col 0-31:   S (QK acc, 128 FP32 via Ld32x32bOp Repetition(32))
Col 32-95:  P (64 FP32 via register bridge, BF16 view)
Col 128+:   O (PV acc, 64+ FP32)

P Staging Paths

TMEM-P (hd≤64, also works at hd=128/256):

  • P stored to TMEM via register bridge (FP32 backing + BF16 view)
  • PV MMA reads P from TMEM via tOrP0
  • Works because QK C-fragment and PV A-fragment TMEM layouts agree at tested head dims

SMEM-P (hd>64):

  • P written to SMEM via coordinate-indexed store
  • Uses tTMEM_LOADcS identity tensor to get (m, k) coordinates
  • Maps to sP's subtile layout: sP[(m_coord, k_sub), 0, (k_g1, k_g2)]
  • PV MMA reads P from SMEM via tCrP = pv_mma.make_fragment_A(sP)
  • SMEM-P uses OperandSource.SMEM for PV MMA
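These notes don't give the extents of k_sub, k_g1 and k_g2. The sketch below assumes, purely for illustration, a 128 x 128 P tile with k split as 16 x 4 x 2 in CuTe's colexicographic (leftmost-fastest) order. The 16 is chosen to match qk_ik, since qk_ik * 4 = 64. Under that assumption, the coordinate-indexed store reduces to a mixed-radix decomposition of k. A round-trip check like this gives cheap evidence before trusting a layout. `K_SUB`, `K_G1`, `K_G2` and both functions are hypothetical.

```python
from typing import Tuple

TILE_M = 128  # P tile rows
K_SUB = 16    # hypothetical: assumed MMA instruction K for BF16
K_G1 = 4      # hypothetical: MMA K-steps per 64-element group
K_G2 = 2      # hypothetical: groups along K (16 * 4 * 2 = 128)

SpIndex = Tuple[Tuple[int, int], int, Tuple[int, int]]


def p_coord_to_sp_index(m: int, k: int) -> SpIndex:
    """Map an (m, k) coordinate from the identity tensor to sP[(m, k_sub), 0, (k_g1, k_g2)]."""
    if not (0 <= m < TILE_M and 0 <= k < K_SUB * K_G1 * K_G2):
        raise IndexError((m, k))
    k_sub = k % K_SUB                 # fastest-varying K sub-index
    k_g1 = (k // K_SUB) % K_G1        # next K level
    k_g2 = k // (K_SUB * K_G1)        # slowest K level
    return (m, k_sub), 0, (k_g1, k_g2)


def sp_index_to_p_coord(idx: SpIndex) -> Tuple[int, int]:
    """Inverse mapping, used to check that the decomposition round-trips."""
    (m, k_sub), _, (k_g1, k_g2) = idx
    return m, k_sub + K_SUB * (k_g1 + K_G1 * k_g2)
```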

Key Configuration

head_dim:     constructor arg (64, 128, 256, 512)
pv_n_tile:    min(head_dim, 256)  # tcgen05 MMA max N=256
n_pv_tiles:   head_dim // pv_n_tile
kv_stage:     1 if head_dim > 128 else 2  # Reduce SMEM at large hd
use_smem_p:   head_dim > 64  # SMEM-P for hd>64
qk_mma_tiler: (128, 128, head_dim)  # K-dim = head_dim (NOT hardcoded!)
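The same derivation rules, written as plain Python so they can be sanity-checked per head_dim. `FmhaConfig` is a hypothetical name for illustration, not the class in fmha.py.

```python
from dataclasses import dataclass
from typing import Tuple

TCGEN05_MAX_N = 256  # tcgen05 MMA max N, per the notes above


@dataclass(frozen=True)
class FmhaConfig:
    head_dim: int

    def __post_init__(self) -> None:
        if self.head_dim not in (64, 128, 256, 512):
            raise ValueError(f"unsupported head_dim {self.head_dim}")

    @property
    def pv_n_tile(self) -> int:
        return min(self.head_dim, TCGEN05_MAX_N)

    @property
    def n_pv_tiles(self) -> int:
        return self.head_dim // self.pv_n_tile

    @property
    def kv_stage(self) -> int:
        return 1 if self.head_dim > 128 else 2  # fewer stages to save SMEM at large hd

    @property
    def use_smem_p(self) -> bool:
        return self.head_dim > 64

    @property
    def qk_mma_tiler(self) -> Tuple[int, int, int]:
        # K-dim is the full head_dim, never the MMA instruction's K sub-tile.
        return (128, 128, self.head_dim)
```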

Critical Bug Fix: qk_mma_tiler K-dim (2026-05-24)

ROOT CAUSE of hd>64 failure: qk_mma_tiler K-dim was hardcoded to qk_ik * 4 = 64 instead of head_dim.

This caused the QK GEMM to contract over only 64 of the 128 (or 256, 512) dimensions at hd>64. Each QK dot product covered half (hd=128), a quarter (hd=256) or an eighth (hd=512) of its true length, producing wrong attention scores.

Fix: self.qk_mma_tiler = (128, 128, self.head_dim) — one line change.

Impact: hd=128 went from cos 0.78 to 0.999997. hd=256 went from broken to 0.999998.

LESSON: The MMA tiler's K dimension must match the actual GEMM K dimension (head_dim), not the MMA instruction's K sub-tile size.
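A small pure-Python experiment shows why a truncated K-dim is so damaging. It computes S from only the first 64 of 128 features, which is what the old tiler effectively did at hd=128, and compares the result with full-length attention. The sizes, seed and random data are arbitrary, so the resulting cosine is not expected to match the kernel's measured 0.78 exactly.

```python
import math
import random


def attention(q, k, v, scale, k_dim):
    """Softmax attention whose Q.K dot products use only the first k_dim features."""
    out = []
    for qi in q:
        s = [scale * sum(qi[d] * kj[d] for d in range(k_dim)) for kj in k]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        z = sum(p)
        out.append([sum(p[j] * v[j][d] for j in range(len(v))) / z
                    for d in range(len(v[0]))])
    return out


def flat_cos(a, b):
    """Cosine similarity of two row-major matrices, flattened."""
    fa = [x for row in a for x in row]
    fb = [x for row in b for x in row]
    dot = sum(x * y for x, y in zip(fa, fb))
    return dot / math.sqrt(sum(x * x for x in fa) * sum(y * y for y in fb))


rng = random.Random(0)
HD, M, N = 128, 32, 64
q = [[rng.gauss(0, 1) for _ in range(HD)] for _ in range(M)]
k = [[rng.gauss(0, 1) for _ in range(HD)] for _ in range(N)]
v = [[rng.gauss(0, 1) for _ in range(HD)] for _ in range(N)]
scale = 1.0 / math.sqrt(HD)

ref = attention(q, k, v, scale, k_dim=HD)    # tiler K-dim = head_dim (fixed)
buggy = attention(q, k, v, scale, k_dim=64)  # tiler K-dim hardcoded to 64 (old bug)
cos_buggy = flat_cos(ref, buggy)
```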


SMEM Budget at Various hd

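| hd  | sQ    | sK (kv_stage=1) | sV (kv_stage=1) | sP (SMEM-P) | sC   | Total | Limit | Status   |
|-----|-------|-----------------|-----------------|-------------|------|-------|-------|----------|
| 64  | 32KB  | 32KB            | 32KB            | 32KB        | 32KB | 160KB | 232KB | fits     |
| 128 | 32KB  | 32KB            | 32KB            | 32KB        | 32KB | 160KB | 232KB | fits     |
| 256 | 64KB  | 64KB            | 64KB            | 0*          | 32KB | 224KB | 232KB | fits     |
| 512 | 128KB | 128KB           | 64KB            | 0*          | 32KB | 352KB | 232KB | overflow |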

*TMEM-P path: sP allocation skipped (const_expr conditional)


D1.5: Correction Epilogue (TMEM Round-Trip Error)

Issue: Hand-constructed Ld32x32bOp/St32x32bOp atoms don't preserve the C-fragment layout during TMEM round-trips (load→modify→store). Causes ~3% error per round-trip.

Current workaround: Kernel outputs un-normalized O + LSE. No in-kernel normalization needed. External normalization is exact.

Proper fix (future): Use CUTLASS epilogue_tmem_copy_and_partition + epilogue_smem_copy_and_partition pattern with paired atoms. One-way trip: TMEM → registers (normalize) → SMEM → GMEM.

Priority: MEDIUM. Not a correctness blocker (external normalization is exact). Would enable in-kernel normalization for D5c/D5d.


Build Order (Remaining)

D1.4 — hd=512 SMEM Budget CURRENT

hd=512 needs sQ (128KB) + sK (128KB) + sV (64KB) = 320KB, or 352KB once sC (32KB) is included. Must reduce to fit 232KB.

Options:

  1. Tile Q along head_dim: Process Q in chunks of 256. Two Q sub-tiles per kernel.
  2. SMEM buffer overlap: sQ and sK/sV used at different times. After Q is consumed by MMA, reuse sQ's SMEM for K/V.
  3. Split the GEMM K dimension: process K in sub-tiles (K = 0..255, then K = 256..511). Each sub-tile fits SMEM.
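The budget for option 3 can be estimated with a short calculator. This is a sketch under assumed tile shapes: BF16 Q and K tiles of 128 x k, a V tile of 128 x pv_n_tile, no sP (TMEM-P), and sC = 32KB. Those shapes reproduce the kv_stage=1 rows of the budget table (224KB at hd=256, 352KB at hd=512). `smem_budget_kb` is a hypothetical helper.

```python
from typing import Dict, Optional, Union

KB = 1024
BF16_BYTES = 2
TILE_M = 128          # Q rows per tile
TILE_N = 128          # K/V rows per KV tile
SC_KB = 32            # epilogue staging buffer, taken from the budget table
SMEM_LIMIT_KB = 232


def smem_budget_kb(head_dim: int, k_chunk: Optional[int] = None) -> Dict[str, Union[int, bool]]:
    """SMEM footprint in KB for the kv_stage=1, TMEM-P configuration.

    k_chunk models option 3: only a k_chunk-wide slice of Q and K is
    resident at a time (None means the full head_dim).
    """
    k = head_dim if k_chunk is None else k_chunk
    pv_n_tile = min(head_dim, 256)
    parts: Dict[str, Union[int, bool]] = {
        "sQ": TILE_M * k * BF16_BYTES // KB,
        "sK": TILE_N * k * BF16_BYTES // KB,          # kv_stage = 1
        "sV": TILE_N * pv_n_tile * BF16_BYTES // KB,  # one PV N-tile
        "sP": 0,                                      # TMEM-P: sP allocation skipped
        "sC": SC_KB,
    }
    total = sum(parts.values())
    parts["total"] = total
    parts["fits"] = total <= SMEM_LIMIT_KB
    return parts
```

With k_chunk=256, hd=512 drops to the same 224KB as hd=256. The QK accumulator then has to sum two partial products, S = Q[:, :256] @ K[:, :256].T + Q[:, 256:] @ K[:, 256:].T, before softmax reads it.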

D2 — Multi-Query Grid with Head Packing

  • Grid changes from (1, 1, 1) to (num_q_blocks, 1, batch)
  • DSV4 is MQA: all 128 query heads share the same K/V
  • Head axis folded into the M dimension: M = T * n_h rows, tiled in blocks of M_tile = 128
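A sketch of the grid arithmetic for D2. It assumes, hypothetically, that packed rows are ordered head-fastest (row = t * n_h + h). With DSV4's 128 query heads, each 128-row M tile then holds exactly one token's heads. Because the model is MQA, every row in the tile reads the same K/V. `grid_shape` and `packed_row_to_token_head` are illustrative names.

```python
from typing import Tuple

M_TILE = 128


def grid_shape(seq_len: int, n_heads: int, batch: int) -> Tuple[int, int, int]:
    """Grid (num_q_blocks, 1, batch) when heads are folded into M = T * n_h."""
    m_total = seq_len * n_heads
    num_q_blocks = (m_total + M_TILE - 1) // M_TILE  # ceil-divide into M tiles
    return (num_q_blocks, 1, batch)


def packed_row_to_token_head(row: int, n_heads: int) -> Tuple[int, int]:
    """Hypothetical head-fastest packing: row = t * n_heads + h -> (t, h)."""
    return divmod(row, n_heads)
```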

D3 — SWA Sequence Length Mask

  • Add swa_lens: [batch] int32 kernel input
  • Mask SWA-branch logits to -inf where swa_idx >= swa_lens[b]
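A host-side reference for the D3 mask, usable for checking kernel output. The [batch][row][swa_idx] nesting of `logits` and the function name are assumptions for illustration.

```python
import math
from typing import List

NEG_INF = -math.inf


def apply_swa_len_mask(logits: List[List[List[float]]], swa_lens: List[int]) -> None:
    """In place: logits[b][row][swa_idx] = -inf wherever swa_idx >= swa_lens[b]."""
    for b, per_batch in enumerate(logits):
        valid = swa_lens[b]
        for row in per_batch:
            for swa_idx in range(valid, len(row)):
                row[swa_idx] = NEG_INF
```

If swa_lens[b] can be 0, every logit in the row becomes -inf. A softmax that subtracts row_max then computes -inf - (-inf) = NaN, so such rows would need a guard.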

D4 — Causal Mask on SWA Branch

  • Add is_causal: bool constructor flag
  • In the SWA pass, mask logits to -inf where swa_idx > q_pos
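A reference version of the D4 causal mask for one batch element. It assumes, hypothetically, that each query row's q_pos and swa_idx live in the same position space. The function name and list layout are illustrative.

```python
import math
from typing import List


def apply_causal_mask(logits: List[List[float]], q_pos: List[int]) -> None:
    """In place: logits[row][swa_idx] = -inf wherever swa_idx > q_pos[row]."""
    for row, qp in zip(logits, q_pos):
        for swa_idx in range(len(row)):
            if swa_idx > qp:
                row[swa_idx] = -math.inf
```

Call it only when is_causal is set. The kernel would fold the same comparison into the SWA pass's logit masking.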

D5 — SWA + Sink Merge

  • D5a: Kernel outputs un-normalized O + LSE
  • D5b: Python merge works (cos 0.961 at hd=64)
  • D5c: Fuse two passes into one kernel launch
  • D5d: Fuse sink merge into kernel epilogue
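A sketch of the standard LSE merge of two (or more) partial attention results for one query row. It assumes each partial O_i has already been normalized (O_unnorm / row_sum). Subtracting the max LSE before exponentiating is mathematically identical to using exp(LSE) directly, but it avoids overflow. This covers only the SWA/sink output merge, not any sink-specific logic. `merge_lse` is a hypothetical name.

```python
import math
from typing import List, Sequence, Tuple


def merge_lse(parts: Sequence[Tuple[List[float], float]]) -> Tuple[List[float], float]:
    """Merge normalized partial outputs (O_i, LSE_i) for one query row.

    O = sum_i exp(LSE_i - LSE) * O_i, where LSE = logsumexp_i(LSE_i).
    """
    m = max(lse for _, lse in parts)                  # stabilizer
    weights = [math.exp(lse - m) for _, lse in parts]
    total = sum(weights)
    hd = len(parts[0][0])
    merged = [sum(w * o[d] for w, (o, _) in zip(weights, parts)) / total
              for d in range(hd)]
    return merged, m + math.log(total)
```

Splitting a row's keys into two sets, attending over each separately and merging gives the same result as attending over all keys at once.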