From e98f5e4f9e574289daca29c2854d2fa5fea5dd4d Mon Sep 17 00:00:00 2001
From: biondizzle
Date: Sat, 23 May 2026 05:52:03 +0000
Subject: [PATCH] Add STAGE_D.md: step-by-step runbook and todo list for D1-D5

---
 STAGE_D.md | 196 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 196 insertions(+)
 create mode 100644 STAGE_D.md

diff --git a/STAGE_D.md b/STAGE_D.md
new file mode 100644
index 00000000..70ace41f
--- /dev/null
+++ b/STAGE_D.md
@@ -0,0 +1,196 @@
+# Stage D — Parameterized FMHA for DSV4
+
+## ⚠️ IKEA INSTRUCTIONS — READ EVERY TIME BEFORE CODING
+
+### The Workflow (DO NOT SKIP STEPS)
+
+1. **Edit code in** `~/dev/nvfp4-megamoe-kernel/dsv4/kernels/attention/fmha.py` — this is the ONLY file for the FMHA kernel.
+2. **Commit and push:**
+   ```bash
+   cd ~/dev/nvfp4-megamoe-kernel
+   git add -A && git commit -m "description" && git push origin master
+   ```
+3. **Pull on B200:**
+   ```bash
+   sshpass -p '6)Jr)B@dcX[mN?dx' ssh -o StrictHostKeyChecking=no root@45.76.247.107 \
+     "cd /root/dsv4-nvfp4-workspace/kernel && git pull origin master"
+   ```
+4. **Test on B200:**
+   ```bash
+   sshpass -p '6)Jr)B@dcX[mN?dx' ssh -o StrictHostKeyChecking=no root@45.76.247.107 \
+     "cd /root/dsv4-nvfp4-workspace/kernel && source /root/dsv4-nvfp4-workspace/venv/bin/activate && python3 -c '...'"
+   ```
+5. **Regression check:** After every change, verify hd=64 cos 0.972537 still matches. If it doesn't, the change is WRONG. Revert.
+
+### The Rules (BURNED INTO THIS FILE BECAUSE WE BURNED THEM INTO PRODUCTION)
+
+- **NEVER edit files directly on the B200.** Edit locally, commit, push, pull, test. Every time.
+- **NEVER delete or modify the test files in `tests/unit/`.** They are the regression oracle.
+- **NEVER touch drivers, kernels, firmware, or system packages on the B200.**
+- **CuTeDSL variables defined in `if` blocks are NOT visible in other `if` blocks.** Even compile-time constants. Define all variables unconditionally before any branching.
+- **Always test at hd=64 FIRST.** If the proven path (TMEM-P) regresses, nothing else matters.
+- **`p_cols_fp32` uses `pv_mma_tiler[2]` (K-dim), NOT `pv_mma_tiler[1]` (N-dim).** We got this wrong twice.
+- **PV A-operand major mode is `OperandMajorMode.K` for TMEM-P.** Not `a_major` from Q.
+- **`tOrP0` uses 3-dim indexing `(None, None, kb)`, NOT 4-dim `(None, None, kb, 0)`.** The 4th mode was already sliced away by `tOrP_base[(None,None,None,0)]`.
+- **After every P store to TMEM, call `cute.arch.fence_view_async_tmem_store()`.** Missing this produces NaN.
+
+---
+
+## What We Have Now (Starting Point)
+
+**File:** `dsv4/kernels/attention/fmha.py`
+**Class:** `FmhaKernel`
+**State:** Exact copy of Stage C test. Works at hd=64 only. cos 0.972537 at n=128.
+
+**What it does:**
+- 6-warp kernel: warps 0-3 (softmax + epilogue), warp 4 (MMA), warp 5 (TMA)
+- QK GEMM → S in TMEM → online softmax → P stored to TMEM via register bridge → PV GEMM → O in TMEM
+- O rescale (per KV tile, kt>0) + O normalization (1/row_sum) via TMEM round-trip
+- Epilogue: TMEM → SMEM → GMEM via TMA store
+
+**Hardcoded constant that must die:** `HEAD_DIM = 64` on line 18, used in 7 places.
+
+---
+
+## The Problem at hd>64
+
+At hd=64, the QK C-fragment TMEM layout and the PV A-fragment TMEM layout agree — the same threads map to the same columns. P can be written to TMEM using the QK partition and read by PV using the same partition. This is why the register bridge (FP32 backing + BF16 view) works.
+
+At hd=512, P is (128, 128) per KV tile (P's columns = number of KV positions, NOT head_dim). But the PV MMA expects P laid out with 512 columns in its A-operand. The QK C-fragment and PV A-fragment TMEM layouts **disagree** — different threads own different columns. The register bridge can't write P in a layout that PV can read.
+
+**The fix: SMEM-P path.** P goes through SMEM instead of TMEM:
+1. Softmax computes P in registers (QK C-fragment partition)
+2. Write P to SMEM using the `p_smem_s` layout (PV A-operand SMEM layout)
+3. MMA warp reads P from SMEM via `tCrP = pv_mma.make_fragment_A(sP)`
+4. PV GEMM uses `tcgen05.OperandSource.SMEM` instead of `OperandSource.TMEM`
+
+**The SMEM rendezvous:** SMEM is the meeting point. Softmax threads write at logical (row, col) addresses. MMA reads at the same addresses. A barrier in between. No cross-warp message passing needed — just write-to-address, barrier, read-from-address.
+
+**The missing piece (the D1 work):** The register→SMEM copy. The softmax warps have P values in QK C-fragment partition. They need to write to SMEM with PV A-operand layout. This requires a `TiledCopy` that partitions threads by QK's C-fragment and targets the P SMEM layout.
+
+```python
+# The correct approach:
+store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), Float32)
+tiled_p_copy = cute.make_tiled_copy_C(store_atom, qk_mma)  # NOT pv_mma!
+# This gives threads partitioned by QK C-fragment, writing to the P SMEM layout
+```
+
+Then: softmax threads write their P values through this copy → barrier → MMA reads from SMEM.
+
+**Alternative (from the FlashMLA SM100 reference):** FlashMLA keeps P in TMEM at hd≤128 using `St32x32bOp` with QK C-fragment composition (same as our Stage C). At hd>128, they'd need the SMEM path. They don't support hd>128 yet.
+
+---
+
+## Stage D TODO List
+
+### D1.0 — Replace `HEAD_DIM = 64` with constructor parameter ✅ (next step)
+
+- [ ] Add `head_dim` to `FmhaKernel.__init__()`
+- [ ] Replace all 7 uses of `HEAD_DIM` with `self.head_dim`
+- [ ] Keep `use_smem_p=False` as default (TMEM-P path)
+- [ ] **Test:** hd=64, n=128 → cos 0.972537 (must match exactly)
+- [ ] **Test:** hd=64, n=256 → cos 0.792775 (must match exactly)
+- [ ] **DO NOT add SMEM-P code yet.** Just parameterize. Test first.
+
+The 7 places `HEAD_DIM` is used:
+1. `__init__`: `1.0 / math.sqrt(HEAD_DIM)` → `1.0 / math.sqrt(head_dim)`
+2. `_setup`: `self.pv_mma_tiler = (128, HEAD_DIM, ...)` → `(128, self.head_dim, ...)`
+3. `_setup`: `self.cta_tile_shape_mnk = (..., HEAD_DIM, ...)` → `(..., self.head_dim, ...)`
+4. `__call__`: `cute.make_layout((HEAD_DIM, self.s_k, 1), stride=(1, HEAD_DIM, HEAD_DIM * self.s_k))`
+5. `__call__`: `pv_mma = ... (128, HEAD_DIM) ...`
+6. softmax: `n_corr_tiles = HEAD_DIM // corr_tile_size`
+7. (Check for any others: `grep HEAD_DIM dsv4/kernels/attention/fmha.py`)
+
+### D1.1 — Add SMEM-P path behind `use_smem_p` flag
+
+- [ ] Add `use_smem_p` to `__init__` (default: `head_dim > 64`)
+- [ ] In `_setup`: conditional TMEM layout (TMEM-P has `tmem_p0_offset=32`, SMEM-P has `tmem_p0_offset=-1` and `tmem_o0_offset=0`)
+- [ ] In `_setup`: allocate `p_smem_s` for SMEM-P (PV A-operand SMEM layout)
+- [ ] In `__call__`: `pv_mma` uses `OperandSource.SMEM` when `use_smem_p`, `OperandSource.TMEM` otherwise
+- [ ] In `__call__`: PV A-operand major mode is `a_major` for SMEM-P, `OperandMajorMode.K` for TMEM-P
+- [ ] **CuTeDSL scoping:** Define ALL variables unconditionally before any `if use_smem_p` blocks. Both `tOrP0` (TMEM) and `tCrP` (SMEM) must exist before the warp-branching starts.
+- [ ] **Test:** hd=64, n=128, `use_smem_p=False` → cos 0.972537 (regression)
+
+### D1.2 — Implement register→SMEM copy for P (the hard part)
+
+- [ ] Build `tiled_p_copy = cute.make_tiled_copy_C(store_atom, qk_mma)` — QK MMA partitions threads
+- [ ] Partition `sP` with `tiled_p_copy` as destination
+- [ ] In softmax warps: after computing P in registers, write to SMEM via `tiled_p_copy`
+- [ ] Add `p_smem_ready_bar` barrier: softmax arrives after write, MMA waits before PV GEMM
+- [ ] In MMA warp: read P from SMEM via `tCrP = pv_mma.make_fragment_A(sP)`
+- [ ] **Test:** hd=64, n=128, `use_smem_p=True` → compare against TMEM-P result (should be close)
+- [ ] **Test:** hd=128, n=128 → test against FP32 oracle
+- [ ] **Test:** hd=256, n=128 → test against FP32 oracle
+- [ ] **Test:** hd=512, n=128 → test against FP32 oracle (DSV4's real value)
+
+### D1.3 — Multi-PV-tile for hd>256
+
+- [ ] When `head_dim > 256`, the MMA instruction can only process 256 columns at a time
+- [ ] `pv_n_tile = min(head_dim, 256)`, `n_pv_tiles = head_dim // pv_n_tile`
+- [ ] Multiple PV GEMM passes per KV tile, accumulating O
+- [ ] V must be re-constructed with `v_n = pv_n_tile` per pass
+- [ ] This may require multiple kernel launches at Python level (or a loop inside the kernel)
+- [ ] **Test:** hd=512, n=128 → correct output against FP32 oracle
+
+### D1.4 — Cleanup and regression
+
+- [ ] Remove `HEAD_DIM = 64` constant entirely
+- [ ] Add `head_dim` as first constructor arg (no default — always explicit)
+- [ ] Default `use_smem_p=None` → auto-detect from `head_dim > 64`
+- [ ] Test matrix: hd ∈ {64, 128, 256, 512} × n ∈ {128, 256}
+- [ ] Update README status table: D1 → ✅ COMPLETE
+- [ ] Cross off D1.0–D1.4 in this file
+
+---
+
+## D2 — Multi-query grid with head packing (after D1)
+
+- [ ] Grid changes from `(1, 1, 1)` to `(num_q_blocks, 1, batch)`
+- [ ] DSV4 is MQA: all 128 query heads share same K/V
+- [ ] Head axis folded into M dimension of Q tile
+- [ ] **Test:** batch=4, T=64, n_h=128, num_kv_heads=1
+
+## D3 — SWA sequence length mask
+
+- [ ] Add `swa_lens: [batch] int32` kernel input
+- [ ] Mask SWA-branch logits to `-inf` where `swa_idx >= swa_lens[b]`
+- [ ] **Test:** varying SWA fill levels
+
+## D4 — Causal mask on SWA branch
+
+- [ ] Add `is_causal: bool` constructor flag
+- [ ] Apply `swa_idx > q_pos` masking in SWA pass
+- [ ] Main path has NO mask (indexer enforces causality upstream)
+
+## D5 — SWA + sink merge
+
+- [ ] D5a: Emit un-normalized `o` + `lse` instead of normalized `o` (keep normalize as flag)
+- [ ] D5b: Run kernel twice externally (compressed_kv + swa_kv), merge in Python
+- [ ] D5c: Fuse two passes into one kernel launch (Q stays in SMEM)
+- [ ] D5d: Fuse sink merge into kernel epilogue
+
+---
+
+## Key References
+
+| What | Where |
+|------|-------|
+| Working FMHA kernel (hd=64) | `dsv4/kernels/attention/fmha.py` — `FmhaKernel` |
+| Stage C test (oracle) | `tests/unit/test_fmha_v3_stage_c.py` — `FmhaV3StageCMulti` |
+| Stage A+B test | `tests/unit/test_fmha_v3.py` |
+| FlashMLA SM100 reference | `/root/dsv4-nvfp4-workspace/vllm/.deps/flashmla-src/csrc/cutlass/examples/python/CuTeDSL/blackwell/fmha.py` (on B200) |
+| CUTLASS FMHA reference | `/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py` (on B200) |
+| Sink merge spec | `dsv4/ops/decode_sparse.py` |
+| SWA decode | `dsv4/ops/decode_swa.py` |
+| Attention reference | `dsv4/reference/attention.py` |
+| CSA attention reference | `dsv4/reference/csa_attention.py` |
+
+## B200 Environment
+
+```
+Server: root@45.76.247.107 (password: 6)Jr)B@dcX[mN?dx)
+Kernel repo: /root/dsv4-nvfp4-workspace/kernel
+Venv: source /root/dsv4-nvfp4-workspace/venv/bin/activate
+PYTHONPATH: /root/dsv4-nvfp4-workspace/kernel
+Test command: python3 tests/unit/test_fmha_v3_stage_c.py
+```
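
Review note on D1.3 and D1.4: the tiling arithmetic and the `use_smem_p` auto-detect in the TODO list can be checked on the host before any kernel code exists. The following is a minimal Python sketch under stated assumptions. The helper `plan_pv` and the record `PvPlan` are hypothetical names that do not appear in `fmha.py`. The 256-column limit, `pv_n_tile = min(head_dim, 256)`, `n_pv_tiles = head_dim // pv_n_tile`, and the `head_dim > 64` default are taken from the runbook; the divisibility check is an added assumption, since the runbook only lists head dims that divide evenly.

```python
from dataclasses import dataclass
from typing import Optional

# Per-pass column limit for the PV MMA, as stated in D1.3 of the runbook.
MAX_PV_N = 256


@dataclass(frozen=True)
class PvPlan:
    """Hypothetical record describing how PV is tiled for one head_dim."""
    head_dim: int
    use_smem_p: bool
    pv_n_tile: int
    n_pv_tiles: int


def plan_pv(head_dim: int, use_smem_p: Optional[bool] = None) -> PvPlan:
    """Derive the PV tiling and the P path from head_dim (illustrative only)."""
    if head_dim <= 0:
        raise ValueError("head_dim must be positive")
    # D1.3: pv_n_tile = min(head_dim, 256), n_pv_tiles = head_dim // pv_n_tile
    pv_n_tile = min(head_dim, MAX_PV_N)
    if head_dim % pv_n_tile != 0:
        # Assumption: reject head dims that would leave a partial PV tile.
        raise ValueError("head_dim must be a multiple of pv_n_tile")
    # D1.4: use_smem_p=None means auto-detect from head_dim > 64.
    if use_smem_p is None:
        use_smem_p = head_dim > 64
    return PvPlan(head_dim, use_smem_p, pv_n_tile, head_dim // pv_n_tile)


if __name__ == "__main__":
    # Walk the head dims from the D1.4 test matrix.
    for hd in (64, 128, 256, 512):
        print(plan_pv(hd))
```

Passing `use_smem_p` explicitly overrides the auto-detect, which matches the D1.2 test that runs hd=64 with `use_smem_p=True` to compare against the TMEM-P result.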