Stage D — Parameterized FMHA for DSV4
⚠️ IKEA INSTRUCTIONS — READ EVERY TIME BEFORE CODING
The Workflow (DO NOT SKIP STEPS)
- Edit code in `~/dev/nvfp4-megamoe-kernel/dsv4/kernels/attention/fmha.py`. This is the ONLY file for the FMHA kernel.
- Commit and push:

      cd ~/dev/nvfp4-megamoe-kernel
      git add -A && git commit -m "description" && git push origin master

- Pull on B200:

      sshpass -p '<B200_PASSWORD>' ssh -o StrictHostKeyChecking=no root@45.76.247.107 \
        "cd /root/dsv4-nvfp4-workspace/kernel && git pull origin master"

- Test on B200:

      sshpass -p '<B200_PASSWORD>' ssh -o StrictHostKeyChecking=no root@45.76.247.107 \
        "cd /root/dsv4-nvfp4-workspace/kernel && source /root/dsv4-nvfp4-workspace/venv/bin/activate && python3 -c '...'"

- Regression check: After every change, verify that hd=64 still gives cos 0.972537. If it doesn't, the change is WRONG. Revert.
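The regression check in the last step can be sketched in plain Python. This is a minimal illustration, not the test harness in `tests/unit/`: the helper names `cosine_similarity` and `regression_ok` and the tolerance are assumptions; only the expected value 0.972537 comes from this file.

```python
import math

EXPECTED_COS = 0.972537  # hd=64, n=128 value recorded in this file
TOLERANCE = 1e-6         # assumption: compare at the six decimals that are quoted


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length flat sequences of floats."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def regression_ok(kernel_out, reference_out, expected=EXPECTED_COS, tol=TOLERANCE):
    """True when the measured cosine matches the recorded value within tol."""
    return abs(cosine_similarity(kernel_out, reference_out) - expected) <= tol
```

Both arguments are assumed to be flattened output tensors. A `False` result means the change should be reverted, per the rule above.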
The Rules (BURNED INTO THIS FILE BECAUSE WE BURNED THEM INTO PRODUCTION)
- NEVER edit files directly on the B200. Edit locally, commit, push, pull, test. Every time.
- NEVER delete or modify the test files in `tests/unit/`. They are the regression oracle.
- NEVER touch drivers, kernels, firmware, or system packages on the B200.
- CuTeDSL variables defined in `if` blocks are NOT visible in other `if` blocks. Even compile-time constants. Define all variables unconditionally before any branching.
- Always test at hd=64 FIRST. If the proven path (TMEM-P) regresses, nothing else matters.
- `p_cols_fp32` uses `pv_mma_tiler[2]` (K-dim), NOT `pv_mma_tiler[1]` (N-dim). We got this wrong twice.
- PV A-operand major mode is `OperandMajorMode.K` for TMEM-P. Not `a_major` from Q.
- `tOrP0` uses 3-dim indexing `(None, None, kb)`, NOT 4-dim `(None, None, kb, 0)`. The 4th mode was already sliced away by `tOrP_base[(None,None,None,0)]`.
- After every P store to TMEM, call `cute.arch.fence_view_async_tmem_store()`. Missing this produces NaN.
What We Have Now (Starting Point)
File: dsv4/kernels/attention/fmha.py
Class: FmhaKernel
State: Exact copy of Stage C test. Works at hd=64 only. cos 0.972537 at n=128.
What it does:
- 6-warp kernel: warps 0-3 (softmax + epilogue), warp 4 (MMA), warp 5 (TMA)
- QK GEMM → S in TMEM → online softmax → P stored to TMEM via register bridge → PV GEMM → O in TMEM
- O rescale (per KV tile, kt>0) + O normalization (1/row_sum) via TMEM round-trip
- Epilogue: TMEM → SMEM → GMEM via TMA store
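The online softmax, per-tile O rescale, and final 1/row_sum normalization listed above can be sketched for a single query row in plain Python. This is a reference-style illustration only. It does not model TMEM, warps, or the register bridge, and the function names are hypothetical.

```python
import math


def attention_reference(q, k, v, scale):
    """Direct softmax(q . k * scale) @ v for one query row."""
    s = [scale * sum(a * b for a, b in zip(q, row)) for row in k]
    m = max(s)
    p = [math.exp(x - m) for x in s]
    total = sum(p)
    return [sum(p[j] * v[j][d] for j in range(len(v))) / total
            for d in range(len(v[0]))]


def attention_online(q, k, v, scale, kv_tile):
    """Same result, walking K/V in tiles with a running max and row sum."""
    row_max = -math.inf
    row_sum = 0.0
    o = [0.0] * len(v[0])
    for start in range(0, len(k), kv_tile):
        k_t = k[start:start + kv_tile]
        v_t = v[start:start + kv_tile]
        s = [scale * sum(a * b for a, b in zip(q, row)) for row in k_t]
        new_max = max(row_max, max(s))
        # Rescale what was accumulated under the old max. On the first tile
        # row_max is -inf, so the factor is 0.0 and o is still all zeros.
        correction = math.exp(row_max - new_max)
        p = [math.exp(x - new_max) for x in s]
        row_sum = row_sum * correction + sum(p)
        o = [o_d * correction + sum(p[j] * v_t[j][d] for j in range(len(v_t)))
             for d, o_d in enumerate(o)]
        row_max = new_max
    return [o_d / row_sum for o_d in o]  # final 1/row_sum normalization
```

The tiled version only ever holds one KV tile of scores at a time, which is why the kernel needs the rescale step for every tile after the first.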
Hardcoded constant that must die: HEAD_DIM = 64 on line 18, used in 7 places.
The Problem at hd>64
At hd=64, the QK C-fragment TMEM layout and the PV A-fragment TMEM layout agree — the same threads map to the same columns. P can be written to TMEM using the QK partition and read by PV using the same partition. This is why the register bridge (FP32 backing + BF16 view) works.
At hd=512, P is (128, 128) per KV tile (P's columns = number of KV positions, NOT head_dim). But the PV MMA expects P laid out with 512 columns in its A-operand. The QK C-fragment and PV A-fragment TMEM layouts disagree — different threads own different columns. The register bridge can't write P in a layout that PV can read.
The fix: SMEM-P path. P goes through SMEM instead of TMEM:
- Softmax computes P in registers (QK C-fragment partition)
- Write P to SMEM using the `p_smem_s` layout (PV A-operand SMEM layout)
- MMA warp reads P from SMEM via `tCrP = pv_mma.make_fragment_A(sP)`
- PV GEMM uses `tcgen05.OperandSource.SMEM` instead of `OperandSource.TMEM`
The SMEM rendezvous: SMEM is the meeting point. Softmax threads write at logical (row, col) addresses. MMA reads at the same addresses. A barrier in between. No cross-warp message passing needed — just write-to-address, barrier, read-from-address.
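The write, barrier, read pattern can be illustrated on the host with Python threads. This is only an analogy under stated assumptions: a list of lists stands in for the P tile in SMEM, `threading.Barrier` stands in for the kernel barrier, and one writer thread per row stands in for the softmax warps. None of these names exist in the kernel.

```python
import threading

ROWS, COLS = 4, 8   # one writer thread per row stands in for the softmax warps


def run_rendezvous():
    shared_p = [[None] * COLS for _ in range(ROWS)]   # stand-in for the P tile in SMEM
    ready = threading.Barrier(ROWS + 1)                # all writers plus one reader
    seen = []

    def softmax_writer(row):
        for col in range(COLS):
            shared_p[row][col] = row * COLS + col      # write at logical (row, col)
        ready.wait()                                   # arrive after the write

    def mma_reader():
        ready.wait()                                   # released once every writer arrived
        seen.extend(x for r in shared_p for x in r)    # read the same addresses

    threads = [threading.Thread(target=softmax_writer, args=(r,)) for r in range(ROWS)]
    threads.append(threading.Thread(target=mma_reader))
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return seen
```

The reader never sees a partially written tile because it cannot pass the barrier until every writer has finished its row.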
The missing piece (the D1 work): The register→SMEM copy. The softmax warps have P values in QK C-fragment partition. They need to write to SMEM with PV A-operand layout. This requires a TiledCopy that partitions threads by QK's C-fragment and targets the P SMEM layout.
# The correct approach:
store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), Float32)
tiled_p_copy = cute.make_tiled_copy_C(store_atom, qk_mma) # NOT pv_mma!
# This gives threads partitioned by QK C-fragment, writing to the P SMEM layout
Then: softmax threads write their P values through this copy → barrier → MMA reads from SMEM.
Alternative (from the FlashMLA SM100 reference): FlashMLA keeps P in TMEM at hd≤128 using St32x32bOp with QK C-fragment composition (same as our Stage C). At hd>128, they'd need the SMEM path. They don't support hd>128 yet.
Stage D TODO List
D1.0 — Replace HEAD_DIM = 64 with constructor parameter ✅ (next step)
- Add `head_dim` to `FmhaKernel.__init__()`
- Replace all 7 uses of `HEAD_DIM` with `self.head_dim`
- Keep `use_smem_p=False` as default (TMEM-P path)
- Test: hd=64, n=128 → cos 0.972537 (must match exactly)
- Test: hd=64, n=256 → cos 0.792775 (must match exactly)
- DO NOT add SMEM-P code yet. Just parameterize. Test first.
The 7 places HEAD_DIM is used:
- `__init__`: `1.0 / math.sqrt(HEAD_DIM)` → `1.0 / math.sqrt(head_dim)`
- `_setup`: `self.pv_mma_tiler = (128, HEAD_DIM, ...)` → `(128, self.head_dim, ...)`
- `_setup`: `self.cta_tile_shape_mnk = (..., HEAD_DIM, ...)` → `(..., self.head_dim, ...)`
- `__call__`: `cute.make_layout((HEAD_DIM, self.s_k, 1), stride=(1, HEAD_DIM, HEAD_DIM * self.s_k))`
- `__call__`: `pv_mma = ... (128, HEAD_DIM) ...`
- softmax: `n_corr_tiles = HEAD_DIM // corr_tile_size`
- (Check for any others: `grep HEAD_DIM dsv4/kernels/attention/fmha.py`)
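The parameterization pattern can be sketched with a stand-in class that derives the `head_dim`-dependent values from a constructor argument. `FmhaParamsSketch` is hypothetical and is not `FmhaKernel`. `corr_tile_size` is passed in here only because its real definition is not shown in this file, and the elided tiler entries (`...`) are deliberately left out.

```python
import math


class FmhaParamsSketch:
    """Hypothetical stand-in for the head_dim-dependent values in FmhaKernel."""

    def __init__(self, head_dim, s_k, corr_tile_size):
        # corr_tile_size is an assumed argument; the real kernel defines it elsewhere.
        if head_dim % corr_tile_size != 0:
            raise ValueError("head_dim must be a multiple of corr_tile_size")
        self.head_dim = head_dim
        self.s_k = s_k
        self.softmax_scale = 1.0 / math.sqrt(head_dim)
        self.n_corr_tiles = head_dim // corr_tile_size
        # Shape and stride as in the cute.make_layout call listed above.
        self.layout_shape = (head_dim, s_k, 1)
        self.layout_stride = (1, head_dim, head_dim * s_k)
```

For example, `FmhaParamsSketch(64, 128, 32)` gives a scale of 0.125 and a stride of `(1, 64, 8192)`, which is what the hardcoded constant produces today.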
D1.1 — Add SMEM-P path behind use_smem_p flag
- Add `use_smem_p` to `__init__` (default: `head_dim > 64`)
- In `_setup`: conditional TMEM layout (TMEM-P has `tmem_p0_offset=32`, SMEM-P has `tmem_p0_offset=-1` and `tmem_o0_offset=0`)
- In `_setup`: allocate `p_smem_s` for SMEM-P (PV A-operand SMEM layout)
- In `__call__`: `pv_mma` uses `OperandSource.SMEM` when `use_smem_p`, `OperandSource.TMEM` otherwise
- In `__call__`: PV A-operand major mode is `a_major` for SMEM-P, `OperandMajorMode.K` for TMEM-P
- CuTeDSL scoping: Define ALL variables unconditionally before any `if use_smem_p` blocks. Both `tOrP0` (TMEM) and `tCrP` (SMEM) must exist before the warp-branching starts.
- Test: hd=64, n=128, `use_smem_p=False` → cos 0.972537 (regression)
D1.2 — Implement register→SMEM copy for P (the hard part)
- Build `tiled_p_copy = cute.make_tiled_copy_C(store_atom, qk_mma)`. The QK MMA partitions the threads.
- Partition `sP` with `tiled_p_copy` as destination
- In softmax warps: after computing P in registers, write to SMEM via `tiled_p_copy`
- Add `p_smem_ready_bar` barrier: softmax arrives after write, MMA waits before PV GEMM
- In MMA warp: read P from SMEM via `tCrP = pv_mma.make_fragment_A(sP)`
- Test: hd=64, n=128, `use_smem_p=True` → compare against TMEM-P result (should be close)
- Test: hd=128, n=128 → test against FP32 oracle
- Test: hd=256, n=128 → test against FP32 oracle
- Test: hd=512, n=128 → test against FP32 oracle (DSV4's real value)
D1.3 — Multi-PV-tile for hd>256
- When `head_dim > 256`, the MMA instruction can only process 256 columns at a time
- `pv_n_tile = min(head_dim, 256)`, `n_pv_tiles = head_dim // pv_n_tile`
- Multiple PV GEMM passes per KV tile, accumulating O
- V must be re-constructed with `v_n = pv_n_tile` per pass
- This may require multiple kernel launches at Python level (or a loop inside the kernel)
- Test: hd=512, n=128 → correct output against FP32 oracle
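The tiling arithmetic above, and the fact that column-tiled passes reproduce the full P @ V product, can be checked with a small plain-Python sketch. The function names are hypothetical, and the divisibility check is an added assumption (the list above only gives the integer division).

```python
MAX_PV_N = 256  # column limit per PV pass, as stated in the list above


def pv_tiling(head_dim):
    """Return (pv_n_tile, n_pv_tiles) for a given head_dim."""
    pv_n_tile = min(head_dim, MAX_PV_N)
    if head_dim % pv_n_tile != 0:
        raise ValueError("head_dim must be a multiple of pv_n_tile")
    return pv_n_tile, head_dim // pv_n_tile


def pv_gemm_tiled(p, v, head_dim):
    """O = P @ V, computed one pv_n_tile-wide column block of V per pass."""
    pv_n_tile, n_pv_tiles = pv_tiling(head_dim)
    o = [[0.0] * head_dim for _ in p]
    for t in range(n_pv_tiles):
        c0 = t * pv_n_tile                      # first V column of this pass
        for i, p_row in enumerate(p):
            for c in range(c0, c0 + pv_n_tile):
                o[i][c] += sum(p_row[j] * v[j][c] for j in range(len(v)))
    return o
```

Each pass writes a disjoint column block of O, so within one KV tile the passes do not interact. The accumulation across KV tiles is unchanged from the single-pass case.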
D1.4 — Cleanup and regression
- Remove `HEAD_DIM = 64` constant entirely
- Add `head_dim` as first constructor arg (no default, always explicit)
- Default `use_smem_p=None` → auto-detect from `head_dim > 64`
- Test matrix: hd ∈ {64, 128, 256, 512} × n ∈ {128, 256}
- Update README status table: D1 → ✅ COMPLETE
- Cross off D1.0–D1.4 in this file
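The `use_smem_p=None` auto-detect rule and the D1.4 test matrix can be written down as a short sketch. `resolve_use_smem_p` and `TEST_MATRIX` are hypothetical names for illustration; only the threshold and the hd and n values come from the list above.

```python
from itertools import product


def resolve_use_smem_p(head_dim, use_smem_p=None):
    """None means auto-detect: SMEM-P for head_dim > 64, TMEM-P otherwise."""
    if use_smem_p is None:
        return head_dim > 64
    return bool(use_smem_p)


HEAD_DIMS = (64, 128, 256, 512)
SEQ_LENS = (128, 256)

# One (head_dim, n, use_smem_p) entry per cell of the test matrix.
TEST_MATRIX = [(hd, n, resolve_use_smem_p(hd)) for hd, n in product(HEAD_DIMS, SEQ_LENS)]
```

An explicit `True` or `False` still overrides the auto-detect, which keeps the hd=64 SMEM-P comparison from D1.2 expressible.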
D2 — Multi-query grid with head packing (after D1)
- Grid changes from `(1, 1, 1)` to `(num_q_blocks, 1, batch)`
- DSV4 is MQA: all 128 query heads share same K/V
- Head axis folded into M dimension of Q tile
- Test: batch=4, T=64, n_h=128, num_kv_heads=1
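The grid arithmetic for head packing can be sketched as follows. The 128-row query tile and the token-major, head-minor packing order are assumptions made for the illustration. The real kernel may fold heads differently.

```python
M_TILE = 128  # assumed query-tile height, matching the 128-row tiles used in this file


def packed_grid(batch, num_tokens, num_q_heads, m_tile=M_TILE):
    """Grid (num_q_blocks, 1, batch) when heads are folded into the M dimension."""
    packed_rows = num_tokens * num_q_heads
    num_q_blocks = (packed_rows + m_tile - 1) // m_tile   # ceiling division
    return (num_q_blocks, 1, batch)


def unpack_row(packed_row, num_q_heads):
    """Map a packed M index back to (token, head), assuming token-major order."""
    return divmod(packed_row, num_q_heads)
```

With the test configuration above (batch=4, T=64, n_h=128), this gives 8192 packed rows and a grid of `(64, 1, 4)`.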
D3 — SWA sequence length mask
- Add `swa_lens: [batch] int32` kernel input
- Mask SWA-branch logits to `-inf` where `swa_idx >= swa_lens[b]`
- Test: varying SWA fill levels
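The length mask can be sketched on nested Python lists. The `logits[b][q][swa_idx]` indexing and the function name are assumptions for illustration; only the `swa_idx >= swa_lens[b]` condition comes from the list above.

```python
import math


def mask_swa_lengths(logits, swa_lens):
    """Set logits[b][q][swa_idx] to -inf wherever swa_idx >= swa_lens[b]."""
    return [
        [[x if idx < swa_lens[b] else -math.inf for idx, x in enumerate(row)]
         for row in batch_logits]
        for b, batch_logits in enumerate(logits)
    ]
```

A batch entry with `swa_lens[b] == 0` ends up fully masked, so a test over varying fill levels probably wants to include that case and check how the softmax handles a row of all `-inf`.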
D4 — Causal mask on SWA branch
- Add `is_causal: bool` constructor flag
- Apply `swa_idx > q_pos` masking in SWA pass
- Main path has NO mask (indexer enforces causality upstream)
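The causal condition can be sketched for one query row. This assumes `swa_idx` and `q_pos` are expressed in the same position space, which this file does not spell out. The function name is hypothetical.

```python
import math


def mask_causal(row_logits, q_pos, is_causal=True):
    """Mask entries with swa_idx > q_pos; pass through when is_causal is False."""
    if not is_causal:
        return list(row_logits)
    return [x if swa_idx <= q_pos else -math.inf
            for swa_idx, x in enumerate(row_logits)]
```

For `q_pos = 1` and four SWA positions, the last two logits become `-inf` and the first two are kept.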
D5 — SWA + sink merge
- D5a: Emit un-normalized `o` + `lse` instead of normalized `o` (keep normalize as flag)
- D5b: Run kernel twice externally (compressed_kv + swa_kv), merge in Python
- D5c: Fuse two passes into one kernel launch (Q stays in SMEM)
- D5d: Fuse sink merge into kernel epilogue
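The Python-level merge in D5b can be sketched with the usual log-sum-exp combination. This sketch assumes each pass returns a normalized output together with its `lse`. The un-normalized variant from D5a would carry the row-sum factor differently, and the sink term is not modeled. `attend` and `merge` are hypothetical helper names.

```python
import math


def attend(q, k, v, scale):
    """Return (normalized output, lse) for one query row over one K/V set."""
    s = [scale * sum(a * b for a, b in zip(q, row)) for row in k]
    m = max(s)
    p = [math.exp(x - m) for x in s]
    total = sum(p)
    o = [sum(p[j] * v[j][d] for j in range(len(v))) / total
         for d in range(len(v[0]))]
    return o, m + math.log(total)


def merge(o_a, lse_a, o_b, lse_b):
    """Combine two normalized partial outputs using their log-sum-exp values."""
    m = max(lse_a, lse_b)
    w_a = math.exp(lse_a - m)   # relative softmax mass of branch a
    w_b = math.exp(lse_b - m)   # relative softmax mass of branch b
    return [(w_a * x + w_b * y) / (w_a + w_b) for x, y in zip(o_a, o_b)]
```

Merging the two branches this way gives the same result as one softmax over the concatenated K/V, which is a convenient oracle for D5c and D5d.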
Key References
| What | Where |
|---|---|
| Working FMHA kernel (hd=64) | dsv4/kernels/attention/fmha.py — FmhaKernel |
| Stage C test (oracle) | tests/unit/test_fmha_v3_stage_c.py — FmhaV3StageCMulti |
| Stage A+B test | tests/unit/test_fmha_v3.py |
| FlashMLA SM100 reference | /root/dsv4-nvfp4-workspace/vllm/.deps/flashmla-src/csrc/cutlass/examples/python/CuTeDSL/blackwell/fmha.py (on B200) |
| CUTLASS FMHA reference | /root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py (on B200) |
| Sink merge spec | dsv4/ops/decode_sparse.py |
| SWA decode | dsv4/ops/decode_swa.py |
| Attention reference | dsv4/reference/attention.py |
| CSA attention reference | dsv4/reference/csa_attention.py |
B200 Environment
Server: root@45.76.247.107 (password: <B200_PASSWORD>)
Kernel repo: /root/dsv4-nvfp4-workspace/kernel
Venv: source /root/dsv4-nvfp4-workspace/venv/bin/activate
PYTHONPATH: /root/dsv4-nvfp4-workspace/kernel
Test command: python3 tests/unit/test_fmha_v3_stage_c.py