Stage D — Parameterized FMHA for DSV4

⚠️ IKEA INSTRUCTIONS — READ EVERY TIME BEFORE CODING

The Workflow (DO NOT SKIP STEPS)

  1. Edit code in ~/dev/nvfp4-megamoe-kernel/dsv4/kernels/attention/fmha.py — this is the ONLY file for the FMHA kernel.
  2. Commit and push:
    cd ~/dev/nvfp4-megamoe-kernel
    git add -A && git commit -m "description" && git push origin master
    
  3. Pull on B200:
    sshpass -p '<B200_PASSWORD>' ssh -o StrictHostKeyChecking=no root@45.76.247.107 \
      "cd /root/dsv4-nvfp4-workspace/kernel && git pull origin master"
    
  4. Test on B200:
    sshpass -p '<B200_PASSWORD>' ssh -o StrictHostKeyChecking=no root@45.76.247.107 \
      "cd /root/dsv4-nvfp4-workspace/kernel && source /root/dsv4-nvfp4-workspace/venv/bin/activate && python3 -c '...'"
    
  5. Regression check: After every change, verify that hd=64, n=128 still gives cos 0.972537. If it doesn't, the change is WRONG. Revert.
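The regression check in step 5 can be sketched as a small helper. This is a minimal sketch that assumes "cos" means the cosine similarity between the flattened kernel output and the flattened oracle output. The names `cosine_similarity` and `check_regression`, and the tolerance, are hypothetical and not part of the repo.

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity of two equal-length flat sequences of floats."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def check_regression(actual_cos, expected_cos=0.972537, tol=5e-7):
    """True when the measured value agrees with the recorded one to 6 decimals.

    The tolerance is an assumption: it treats the recorded figure as rounded
    to six decimal places.
    """
    return abs(actual_cos - expected_cos) <= tol

# Stand-in vectors; in practice these would be the kernel output and the oracle output.
out = [1.0, 2.0, 3.0]
ref = [1.0, 2.0, 3.5]
print(f"cos = {cosine_similarity(out, ref):.6f}")
print(check_regression(0.972537), check_regression(0.972100))
```

Comparing against the recorded value rather than a loose threshold matches the intent of step 5: any drift at hd=64 is treated as a bug, not as noise.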

The Rules (BURNED INTO THIS FILE BECAUSE WE BURNED THEM INTO PRODUCTION)

  • NEVER edit files directly on the B200. Edit locally, commit, push, pull, test. Every time.
  • NEVER delete or modify the test files in tests/unit/. They are the regression oracle.
  • NEVER touch drivers, kernels, firmware, or system packages on the B200.
  • CuTeDSL variables defined in if blocks are NOT visible in other if blocks. Even compile-time constants. Define all variables unconditionally before any branching.
  • Always test at hd=64 FIRST. If the proven path (TMEM-P) regresses, nothing else matters.
  • p_cols_fp32 uses pv_mma_tiler[2] (K-dim), NOT pv_mma_tiler[1] (N-dim). We got this wrong twice.
  • PV A-operand major mode is OperandMajorMode.K for TMEM-P. Not a_major from Q.
  • tOrP0 uses 3-dim indexing (None, None, kb), NOT 4-dim (None, None, kb, 0). The 4th mode was already sliced away by tOrP_base[(None,None,None,0)].
  • After every P store to TMEM, call cute.arch.fence_view_async_tmem_store(). Missing this produces NaN.

What We Have Now (Starting Point)

File: dsv4/kernels/attention/fmha.py
Class: FmhaKernel
State: Exact copy of Stage C test. Works at hd=64 only. cos 0.972537 at n=128.

What it does:

  • 6-warp kernel: warps 0-3 (softmax + epilogue), warp 4 (MMA), warp 5 (TMA)
  • QK GEMM → S in TMEM → online softmax → P stored to TMEM via register bridge → PV GEMM → O in TMEM
  • O rescale (per KV tile, kt>0) + O normalization (1/row_sum) via TMEM round-trip
  • Epilogue: TMEM → SMEM → GMEM via TMA store
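The online softmax, per-tile O rescale, and final 1/row_sum normalization listed above can be checked on the host with a plain-Python reference for a single query row. This is a sketch of the standard online-softmax recurrence, not the kernel's code. The function names and the toy data are made up for illustration.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def naive_attention_row(q, keys, values, scale):
    """Reference: full softmax over all KV positions, then weighted sum of V."""
    s = [scale * dot(q, k) for k in keys]
    m = max(s)
    w = [math.exp(x - m) for x in s]
    total = sum(w)
    return [sum(wj * v[c] for wj, v in zip(w, values)) / total
            for c in range(len(values[0]))]

def online_attention_row(q, keys, values, scale, kv_tile):
    """Same result, one KV tile at a time, with a running max and running sum."""
    row_max, row_sum = -math.inf, 0.0
    o = [0.0] * len(values[0])
    for kt, start in enumerate(range(0, len(keys), kv_tile)):
        k_tile = keys[start:start + kv_tile]
        v_tile = values[start:start + kv_tile]
        s = [scale * dot(q, k) for k in k_tile]
        new_max = max(row_max, max(s))
        p = [math.exp(x - new_max) for x in s]  # un-normalized P for this tile
        if kt > 0:
            # O rescale: earlier tiles were accumulated against the old max.
            correction = math.exp(row_max - new_max)
            row_sum *= correction
            o = [x * correction for x in o]
        row_sum += sum(p)
        o = [x + sum(pj * v[c] for pj, v in zip(p, v_tile))
             for c, x in enumerate(o)]
        row_max = new_max
    # O normalization: a single 1/row_sum multiply at the end.
    return [x / row_sum for x in o]

q = [0.5, -1.0]
keys = [[float(i % 3), 0.1 * i] for i in range(8)]
values = [[float(i), float(-i)] for i in range(8)]
scale = 1.0 / math.sqrt(len(q))
tiled = online_attention_row(q, keys, values, scale, kv_tile=3)
full = naive_attention_row(q, keys, values, scale)
print(max(abs(a - b) for a, b in zip(tiled, full)))
```

The `kt > 0` guard mirrors the note above that the rescale only applies from the second KV tile onward.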

Hardcoded constant that must die: HEAD_DIM = 64 on line 18, used in 7 places.


The Problem at hd>64

At hd=64, the QK C-fragment TMEM layout and the PV A-fragment TMEM layout agree — the same threads map to the same columns. P can be written to TMEM using the QK partition and read by PV using the same partition. This is why the register bridge (FP32 backing + BF16 view) works.

At hd=512, P is (128, 128) per KV tile (P's columns = number of KV positions, NOT head_dim). But the PV MMA expects P laid out with 512 columns in its A-operand. The QK C-fragment and PV A-fragment TMEM layouts disagree — different threads own different columns. The register bridge can't write P in a layout that PV can read.

The fix: SMEM-P path. P goes through SMEM instead of TMEM:

  1. Softmax computes P in registers (QK C-fragment partition)
  2. Write P to SMEM using the p_smem_s layout (PV A-operand SMEM layout)
  3. MMA warp reads P from SMEM via tCrP = pv_mma.make_fragment_A(sP)
  4. PV GEMM uses tcgen05.OperandSource.SMEM instead of OperandSource.TMEM

The SMEM rendezvous: SMEM is the meeting point. Softmax threads write at logical (row, col) addresses. MMA reads at the same addresses. A barrier in between. No cross-warp message passing needed — just write-to-address, barrier, read-from-address.
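As a host-side analogy only, the write, barrier, read pattern looks like this with ordinary Python threads. Nothing here is CuTeDSL. The list `shared_p` stands in for the P tile in SMEM, and the row ownership pattern is an arbitrary choice made for the illustration.

```python
import threading

ROWS, COLS, N_WRITERS = 8, 4, 4
shared_p = [[None] * COLS for _ in range(ROWS)]  # stands in for the P tile in SMEM
p_ready = threading.Barrier(N_WRITERS + 1)       # 4 writers + 1 reader
result = {}

def p_value(row, col):
    """Placeholder for a softmax output at a logical (row, col) address."""
    return row * COLS + col

def writer(wid):
    # Each writer owns an interleaved set of rows. The reader never needs to
    # know which writer produced which element.
    for row in range(wid, ROWS, N_WRITERS):
        for col in range(COLS):
            shared_p[row][col] = p_value(row, col)
    p_ready.wait()  # arrive after the write

def reader():
    p_ready.wait()  # wait until every writer has arrived
    result["row_sums"] = [sum(row) for row in shared_p]

threads = [threading.Thread(target=writer, args=(w,)) for w in range(N_WRITERS)]
threads.append(threading.Thread(target=reader))
for t in threads:
    t.start()
for t in threads:
    t.join()
print(result["row_sums"])
```

The point of the analogy is that the reader addresses data by (row, col) alone, so the writers' partitioning and the reader's partitioning are free to differ.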

The missing piece (the D1 work): The register→SMEM copy. The softmax warps have P values in QK C-fragment partition. They need to write to SMEM with PV A-operand layout. This requires a TiledCopy that partitions threads by QK's C-fragment and targets the P SMEM layout.

# The correct approach:
store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), Float32)
tiled_p_copy = cute.make_tiled_copy_C(store_atom, qk_mma)  # NOT pv_mma!
# This gives threads partitioned by QK C-fragment, writing to the P SMEM layout

Then: softmax threads write their P values through this copy → barrier → MMA reads from SMEM.

Alternative (from the FlashMLA SM100 reference): FlashMLA keeps P in TMEM at hd≤128 using St32x32bOp with QK C-fragment composition (same as our Stage C). At hd>128, they'd need the SMEM path. They don't support hd>128 yet.


Stage D TODO List

D1.0 — Replace HEAD_DIM = 64 with constructor parameter (next step)

  • Add head_dim to FmhaKernel.__init__()
  • Replace all 7 uses of HEAD_DIM with self.head_dim
  • Keep use_smem_p=False as default (TMEM-P path)
  • Test: hd=64, n=128 → cos 0.972537 (must match exactly)
  • Test: hd=64, n=256 → cos 0.792775 (must match exactly)
  • DO NOT add SMEM-P code yet. Just parameterize. Test first.

The 7 places HEAD_DIM is used:

  1. __init__: 1.0 / math.sqrt(HEAD_DIM) → 1.0 / math.sqrt(head_dim)
  2. _setup: self.pv_mma_tiler = (128, HEAD_DIM, ...) → (128, self.head_dim, ...)
  3. _setup: self.cta_tile_shape_mnk = (..., HEAD_DIM, ...) → (..., self.head_dim, ...)
  4. __call__: cute.make_layout((HEAD_DIM, self.s_k, 1), stride=(1, HEAD_DIM, HEAD_DIM * self.s_k))
  5. __call__: pv_mma = ... (128, HEAD_DIM) ...
  6. softmax: n_corr_tiles = HEAD_DIM // corr_tile_size
  7. (Check for any others: grep HEAD_DIM dsv4/kernels/attention/fmha.py)
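The values derived from HEAD_DIM in the list above can be sketched as a small plain-Python holder, to show what changes once head_dim is a parameter. `FmhaParamsSketch` is a hypothetical class, not FmhaKernel. The `corr_tile_size` value used below is a placeholder, since this file does not state the real one.

```python
import math

class FmhaParamsSketch:
    """Hypothetical holder for values that depended on HEAD_DIM."""

    def __init__(self, head_dim, s_k, corr_tile_size):
        if head_dim % corr_tile_size != 0:
            raise ValueError("head_dim must be a multiple of corr_tile_size")
        self.head_dim = head_dim
        self.s_k = s_k
        self.softmax_scale = 1.0 / math.sqrt(head_dim)     # item 1
        self.n_corr_tiles = head_dim // corr_tile_size     # item 6
        # Shape and stride from item 4, with HEAD_DIM replaced by head_dim.
        self.layout_shape = (head_dim, s_k, 1)
        self.layout_stride = (1, head_dim, head_dim * s_k)

    def offset(self, d, k, b=0):
        """Linear element offset of coordinate (d, k, b) under layout_stride."""
        sd, sk, sb = self.layout_stride
        return d * sd + k * sk + b * sb

# corr_tile_size=16 is an arbitrary placeholder.
params = FmhaParamsSketch(head_dim=64, s_k=128, corr_tile_size=16)
print(params.softmax_scale, params.n_corr_tiles, params.layout_stride)
print(params.offset(d=3, k=2))
```

Constructing this with head_dim=64 and comparing against the old constants is a cheap host-side check to run before the B200 regression test.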

D1.1 — Add SMEM-P path behind use_smem_p flag

  • Add use_smem_p to __init__ (default: head_dim > 64)
  • In _setup: conditional TMEM layout (TMEM-P has tmem_p0_offset=32, SMEM-P has tmem_p0_offset=-1 and tmem_o0_offset=0)
  • In _setup: allocate p_smem_s for SMEM-P (PV A-operand SMEM layout)
  • In __call__: pv_mma uses OperandSource.SMEM when use_smem_p, OperandSource.TMEM otherwise
  • In __call__: PV A-operand major mode is a_major for SMEM-P, OperandMajorMode.K for TMEM-P
  • CuTeDSL scoping: Define ALL variables unconditionally before any if use_smem_p blocks. Both tOrP0 (TMEM) and tCrP (SMEM) must exist before the warp-branching starts.
  • Test: hd=64, n=128, use_smem_p=False → cos 0.972537 (regression)
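The flag-dependent choices in this step can be written down as a plain-Python table, following the scoping rule by defining every field before the branch. `select_p_path` and the string labels are hypothetical stand-ins for the real CuTeDSL enums. The TMEM-P value of `tmem_o0_offset` is not stated in this file, so it is left as None here.

```python
def select_p_path(use_smem_p):
    """Return the per-path settings named in D1.1 as a plain dict."""
    # Define every field unconditionally first (TMEM-P values), then override.
    config = {
        "use_smem_p": use_smem_p,
        "tmem_p0_offset": 32,     # TMEM-P: P lives in TMEM
        "tmem_o0_offset": None,   # TMEM-P value not stated in this document
        "pv_a_source": "TMEM",    # stands in for OperandSource.TMEM
        "pv_a_major": "K",        # stands in for OperandMajorMode.K
    }
    if use_smem_p:
        config["tmem_p0_offset"] = -1    # no P region in TMEM
        config["tmem_o0_offset"] = 0
        config["pv_a_source"] = "SMEM"   # stands in for OperandSource.SMEM
        config["pv_a_major"] = "a_major"
    return config

tmem_path = select_p_path(use_smem_p=False)
smem_path = select_p_path(use_smem_p=True)
print(tmem_path)
print(smem_path)
```

Both dicts carry the same keys, which is the property the scoping rule asks for: nothing is introduced inside the branch.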

D1.2 — Implement register→SMEM copy for P (the hard part)

  • Build tiled_p_copy = cute.make_tiled_copy_C(store_atom, qk_mma) — QK MMA partitions threads
  • Partition sP with tiled_p_copy as destination
  • In softmax warps: after computing P in registers, write to SMEM via tiled_p_copy
  • Add p_smem_ready_bar barrier: softmax arrives after write, MMA waits before PV GEMM
  • In MMA warp: read P from SMEM via tCrP = pv_mma.make_fragment_A(sP)
  • Test: hd=64, n=128, use_smem_p=True → compare against TMEM-P result (should be close)
  • Test: hd=128, n=128 → test against FP32 oracle
  • Test: hd=256, n=128 → test against FP32 oracle
  • Test: hd=512, n=128 → test against FP32 oracle (DSV4's real value)

D1.3 — Multi-PV-tile for hd>256

  • When head_dim > 256, the MMA instruction can only process 256 columns at a time
  • pv_n_tile = min(head_dim, 256), n_pv_tiles = head_dim // pv_n_tile
  • Multiple PV GEMM passes per KV tile, accumulating O
  • V must be re-constructed with v_n = pv_n_tile per pass
  • This may require multiple kernel launches at Python level (or a loop inside the kernel)
  • Test: hd=512, n=128 → correct output against FP32 oracle
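The tiling arithmetic above, and the claim that several PV passes over column slices of V reproduce the full product, can be checked with a small plain-Python sketch. The helper names are hypothetical, and the matrices are tiny stand-ins rather than real tile sizes.

```python
def matmul(a, b):
    """Plain (M x K) @ (K x N) product on nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def pv_tiling(head_dim, max_n=256):
    """pv_n_tile and n_pv_tiles as defined in D1.3."""
    pv_n_tile = min(head_dim, max_n)
    if head_dim % pv_n_tile != 0:
        raise ValueError("head_dim must be a multiple of pv_n_tile")
    return pv_n_tile, head_dim // pv_n_tile

def pv_in_passes(p, v, pv_n_tile, o=None):
    """Accumulate P @ V into O one column slice of V at a time."""
    head_dim = len(v[0])
    if o is None:
        o = [[0.0] * head_dim for _ in p]
    for start in range(0, head_dim, pv_n_tile):
        v_slice = [row[start:start + pv_n_tile] for row in v]  # v_n = pv_n_tile
        o_slice = matmul(p, v_slice)
        for i, row in enumerate(o_slice):
            for c, val in enumerate(row):
                o[i][start + c] += val
    return o

print(pv_tiling(64), pv_tiling(256), pv_tiling(512))

# Toy sizes: 2 query rows, 3 KV positions, "head_dim" 4 split into slices of 2.
p = [[1, 2, 3], [4, 5, 6]]
v = [[1, 0, 2, 1], [0, 1, 1, 3], [2, 2, 0, 1]]
o_passes = pv_in_passes(p, v, pv_n_tile=2)
o_full = matmul(p, v)
print(o_passes == o_full)
```

Each pass only touches its own column range of O, so the passes are independent within a KV tile. The `+=` keeps the accumulation across KV tiles intact when an existing `o` is passed in.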

D1.4 — Cleanup and regression

  • Remove HEAD_DIM = 64 constant entirely
  • Add head_dim as first constructor arg (no default — always explicit)
  • Default use_smem_p=None → auto-detect from head_dim > 64
  • Test matrix: hd ∈ {64, 128, 256, 512} × n ∈ {128, 256}
  • Update README status table: D1 → COMPLETE
  • Cross off D1.0 through D1.4 in this file
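The auto-detect rule and the test matrix above can be enumerated with a few lines of plain Python. `resolve_use_smem_p` is a hypothetical helper name. Each case would still need to be run through the real kernel and compared against the oracle on the B200.

```python
from itertools import product

HEAD_DIMS = (64, 128, 256, 512)
SEQ_LENS = (128, 256)

def resolve_use_smem_p(head_dim, use_smem_p=None):
    """None means auto-detect from head_dim > 64; an explicit bool wins."""
    return head_dim > 64 if use_smem_p is None else use_smem_p

cases = [
    {"head_dim": hd, "n": n, "use_smem_p": resolve_use_smem_p(hd)}
    for hd, n in product(HEAD_DIMS, SEQ_LENS)
]
for case in cases:
    print(case)
```

Keeping the explicit override makes it possible to force the SMEM-P path at hd=64, which is the comparison D1.2 asks for.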

D2 — Multi-query grid with head packing (after D1)

  • Grid changes from (1, 1, 1) to (num_q_blocks, 1, batch)
  • DSV4 is MQA: all 128 query heads share same K/V
  • Head axis folded into M dimension of Q tile
  • Test: batch=4, T=64, n_h=128, num_kv_heads=1

D3 — SWA sequence length mask

  • Add swa_lens: [batch] int32 kernel input
  • Mask SWA-branch logits to -inf where swa_idx >= swa_lens[b]
  • Test: varying SWA fill levels
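A host-side reference for the length mask might look like the sketch below. It assumes logits are indexed as logits[b][q][swa_idx], and the function name is hypothetical.

```python
import math

def mask_swa_lengths(logits, swa_lens):
    """Set logits[b][q][swa_idx] to -inf wherever swa_idx >= swa_lens[b]."""
    masked = []
    for b, rows in enumerate(logits):
        limit = swa_lens[b]
        masked.append([
            [x if swa_idx < limit else -math.inf for swa_idx, x in enumerate(row)]
            for row in rows
        ])
    return masked

# Two batch entries with different SWA fill levels.
logits = [[[0.1, 0.2, 0.3, 0.4]], [[1.0, 2.0, 3.0, 4.0]]]
swa_lens = [2, 4]
masked = mask_swa_lengths(logits, swa_lens)
print(masked)
```

A fill level of 0 would mask the whole row, so a test with varying fill levels should decide explicitly whether that case is allowed.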

D4 — Causal mask on SWA branch

  • Add is_causal: bool constructor flag
  • Apply swa_idx > q_pos masking in SWA pass
  • Main path has NO mask (indexer enforces causality upstream)
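The causal rule on the SWA branch can be sketched for one query row as below. It assumes swa_idx and q_pos are expressed in the same position coordinates, which this file does not spell out. The function names are hypothetical.

```python
import math

def mask_swa_causal(row_logits, q_pos, is_causal=True):
    """Mask positions with swa_idx > q_pos when is_causal is set."""
    if not is_causal:
        return list(row_logits)
    return [x if swa_idx <= q_pos else -math.inf
            for swa_idx, x in enumerate(row_logits)]

def softmax(row):
    m = max(row)
    w = [math.exp(x - m) for x in row]  # exp(-inf) is 0.0, so masked slots vanish
    total = sum(w)
    return [x / total for x in w]

row = [0.5, 1.5, -0.5, 2.0]
weights = softmax(mask_swa_causal(row, q_pos=1))
print(weights)
```

Masked positions receive exactly zero weight after the softmax, so they cannot contribute to O.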

D5 — SWA + sink merge

  • D5a: Emit un-normalized o + lse instead of normalized o (keep normalize as flag)
  • D5b: Run kernel twice externally (compressed_kv + swa_kv), merge in Python
  • D5c: Fuse two passes into one kernel launch (Q stays in SMEM)
  • D5d: Fuse sink merge into kernel epilogue
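The Python-level merge in D5b can be sketched with the standard log-sum-exp combination of two partial attention results. For simplicity this sketch uses a normalized o plus lse per pass, which is an assumption and not necessarily the D5a output format. Sink handling is not shown, and the function names are hypothetical.

```python
import math

def attend(q, keys, values):
    """One query row over one KV set: returns (normalized o, lse)."""
    s = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(s)
    w = [math.exp(x - m) for x in s]
    total = sum(w)
    o = [sum(wj * v[c] for wj, v in zip(w, values)) / total
         for c in range(len(values[0]))]
    return o, m + math.log(total)

def merge(o1, lse1, o2, lse2):
    """Combine two partial results as if softmax had run over both KV sets."""
    m = max(lse1, lse2)
    lse = m + math.log(math.exp(lse1 - m) + math.exp(lse2 - m))
    w1, w2 = math.exp(lse1 - lse), math.exp(lse2 - lse)
    return [w1 * a + w2 * b for a, b in zip(o1, o2)], lse

q = [0.3, -0.7]
k_main = [[1.0, 0.0], [0.5, 0.5], [0.0, 2.0]]
v_main = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
k_swa = [[2.0, 1.0], [-1.0, 0.5]]
v_swa = [[-1.0, 0.0], [0.0, -2.0]]

o_merged, lse_merged = merge(*attend(q, k_main, v_main), *attend(q, k_swa, v_swa))
o_joint, lse_joint = attend(q, k_main + k_swa, v_main + v_swa)
print(max(abs(a - b) for a, b in zip(o_merged, o_joint)))
```

The comparison against a single pass over the concatenated KV gives an oracle for D5b before any fusion work in D5c and D5d.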

Key References

| What | Where |
| --- | --- |
| Working FMHA kernel (hd=64) | dsv4/kernels/attention/fmha.py → FmhaKernel |
| Stage C test (oracle) | tests/unit/test_fmha_v3_stage_c.py → FmhaV3StageCMulti |
| Stage A+B test | tests/unit/test_fmha_v3.py |
| FlashMLA SM100 reference | /root/dsv4-nvfp4-workspace/vllm/.deps/flashmla-src/csrc/cutlass/examples/python/CuTeDSL/blackwell/fmha.py (on B200) |
| CUTLASS FMHA reference | /root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py (on B200) |
| Sink merge spec | dsv4/ops/decode_sparse.py |
| SWA decode | dsv4/ops/decode_swa.py |
| Attention reference | dsv4/reference/attention.py |
| CSA attention reference | dsv4/reference/csa_attention.py |

B200 Environment

Server: root@45.76.247.107 (password: <B200_PASSWORD>)
Kernel repo: /root/dsv4-nvfp4-workspace/kernel
Venv: source /root/dsv4-nvfp4-workspace/venv/bin/activate
PYTHONPATH: /root/dsv4-nvfp4-workspace/kernel
Test command: python3 tests/unit/test_fmha_v3_stage_c.py