Add STAGE_D.md: step-by-step runbook and todo list for D1-D5
This commit is contained in:
196
STAGE_D.md
Normal file
196
STAGE_D.md
Normal file
@@ -0,0 +1,196 @@
|
||||
# Stage D — Parameterized FMHA for DSV4
|
||||
|
||||
## ⚠️ IKEA INSTRUCTIONS — READ EVERY TIME BEFORE CODING
|
||||
|
||||
### The Workflow (DO NOT SKIP STEPS)
|
||||
|
||||
1. **Edit code in** `~/dev/nvfp4-megamoe-kernel/dsv4/kernels/attention/fmha.py` — this is the ONLY file for the FMHA kernel.
|
||||
2. **Commit and push:**
|
||||
```bash
|
||||
cd ~/dev/nvfp4-megamoe-kernel
|
||||
git add -A && git commit -m "description" && git push origin master
|
||||
```
|
||||
3. **Pull on B200:**
|
||||
```bash
|
||||
sshpass -p '6)Jr)B@dcX[mN?dx' ssh -o StrictHostKeyChecking=no root@45.76.247.107 \
|
||||
"cd /root/dsv4-nvfp4-workspace/kernel && git pull origin master"
|
||||
```
|
||||
4. **Test on B200:**
|
||||
```bash
|
||||
sshpass -p '6)Jr)B@dcX[mN?dx' ssh -o StrictHostKeyChecking=no root@45.76.247.107 \
|
||||
"cd /root/dsv4-nvfp4-workspace/kernel && source /root/dsv4-nvfp4-workspace/venv/bin/activate && python3 -c '...'"
|
||||
```
|
||||
5. **Regression check:** After every change, verify hd=64 cos 0.972537 still matches. If it doesn't, the change is WRONG. Revert.
|
||||
|
||||
### The Rules (BURNED INTO THIS FILE BECAUSE WE BURNED THEM INTO PRODUCTION)
|
||||
|
||||
- **NEVER edit files directly on the B200.** Edit locally, commit, push, pull, test. Every time.
|
||||
- **NEVER delete or modify the test files in `tests/unit/`.** They are the regression oracle.
|
||||
- **NEVER touch drivers, kernels, firmware, or system packages on the B200.**
|
||||
- **CuTeDSL variables defined in `if` blocks are NOT visible in other `if` blocks.** Even compile-time constants. Define all variables unconditionally before any branching.
|
||||
- **Always test at hd=64 FIRST.** If the proven path (TMEM-P) regresses, nothing else matters.
|
||||
- **`p_cols_fp32` uses `pv_mma_tiler[2]` (K-dim), NOT `pv_mma_tiler[1]` (N-dim).** We got this wrong twice.
|
||||
- **PV A-operand major mode is `OperandMajorMode.K` for TMEM-P.** Not `a_major` from Q.
|
||||
- **`tOrP0` uses 3-dim indexing `(None, None, kb)`, NOT 4-dim `(None, None, kb, 0)`.** The 4th mode was already sliced away by `tOrP_base[(None,None,None,0)]`.
|
||||
- **After every P store to TMEM, call `cute.arch.fence_view_async_tmem_store()`.** Missing this produces NaN.
|
||||
|
||||
---
|
||||
|
||||
## What We Have Now (Starting Point)
|
||||
|
||||
**File:** `dsv4/kernels/attention/fmha.py`
|
||||
**Class:** `FmhaKernel`
|
||||
**State:** Exact copy of Stage C test. Works at hd=64 only. cos 0.972537 at n=128.
|
||||
|
||||
**What it does:**
|
||||
- 6-warp kernel: warps 0-3 (softmax + epilogue), warp 4 (MMA), warp 5 (TMA)
|
||||
- QK GEMM → S in TMEM → online softmax → P stored to TMEM via register bridge → PV GEMM → O in TMEM
|
||||
- O rescale (per KV tile, kt>0) + O normalization (1/row_sum) via TMEM round-trip
|
||||
- Epilogue: TMEM → SMEM → GMEM via TMA store
|
||||
|
||||
**Hardcoded constant that must die:** `HEAD_DIM = 64` on line 18, used in 7 places.
|
||||
|
||||
---
|
||||
|
||||
## The Problem at hd>64
|
||||
|
||||
At hd=64, the QK C-fragment TMEM layout and the PV A-fragment TMEM layout agree — the same threads map to the same columns. P can be written to TMEM using the QK partition and read by PV using the same partition. This is why the register bridge (FP32 backing + BF16 view) works.
|
||||
|
||||
At hd=512, P is (128, 128) per KV tile (P's columns = number of KV positions, NOT head_dim). But the PV MMA expects P laid out with 512 columns in its A-operand. The QK C-fragment and PV A-fragment TMEM layouts **disagree** — different threads own different columns. The register bridge can't write P in a layout that PV can read.
|
||||
|
||||
**The fix: SMEM-P path.** P goes through SMEM instead of TMEM:
|
||||
1. Softmax computes P in registers (QK C-fragment partition)
|
||||
2. Write P to SMEM using the `p_smem_s` layout (PV A-operand SMEM layout)
|
||||
3. MMA warp reads P from SMEM via `tCrP = pv_mma.make_fragment_A(sP)`
|
||||
4. PV GEMM uses `tcgen05.OperandSource.SMEM` instead of `OperandSource.TMEM`
|
||||
|
||||
**The SMEM rendezvous:** SMEM is the meeting point. Softmax threads write at logical (row, col) addresses. MMA reads at the same addresses. A barrier in between. No cross-warp message passing needed — just write-to-address, barrier, read-from-address.
|
||||
|
||||
**The missing piece (the D1 work):** The register→SMEM copy. The softmax warps have P values in QK C-fragment partition. They need to write to SMEM with PV A-operand layout. This requires a `TiledCopy` that partitions threads by QK's C-fragment and targets the P SMEM layout.
|
||||
|
||||
```python
|
||||
# The correct approach:
|
||||
store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), Float32)
|
||||
tiled_p_copy = cute.make_tiled_copy_C(store_atom, qk_mma) # NOT pv_mma!
|
||||
# This gives threads partitioned by QK C-fragment, writing to the P SMEM layout
|
||||
```
|
||||
|
||||
Then: softmax threads write their P values through this copy → barrier → MMA reads from SMEM.
|
||||
|
||||
**Alternative (from the FlashMLA SM100 reference):** FlashMLA keeps P in TMEM at hd≤128 using `St32x32bOp` with QK C-fragment composition (same as our Stage C). At hd>128, they'd need the SMEM path. They don't support hd>128 yet.
|
||||
|
||||
---
|
||||
|
||||
## Stage D TODO List
|
||||
|
||||
### D1.0 — Replace `HEAD_DIM = 64` with constructor parameter ✅ (next step)
|
||||
|
||||
- [ ] Add `head_dim` to `FmhaKernel.__init__()`
|
||||
- [ ] Replace all 7 uses of `HEAD_DIM` with `self.head_dim`
|
||||
- [ ] Keep `use_smem_p=False` as default (TMEM-P path)
|
||||
- [ ] **Test:** hd=64, n=128 → cos 0.972537 (must match exactly)
|
||||
- [ ] **Test:** hd=64, n=256 → cos 0.792775 (must match exactly)
|
||||
- [ ] **DO NOT add SMEM-P code yet.** Just parameterize. Test first.
|
||||
|
||||
The 7 places `HEAD_DIM` is used:
|
||||
1. `__init__`: `1.0 / math.sqrt(HEAD_DIM)` → `1.0 / math.sqrt(head_dim)`
|
||||
2. `_setup`: `self.pv_mma_tiler = (128, HEAD_DIM, ...)` → `(128, self.head_dim, ...)`
|
||||
3. `_setup`: `self.cta_tile_shape_mnk = (..., HEAD_DIM, ...)` → `(..., self.head_dim, ...)`
|
||||
4. `__call__`: `cute.make_layout((HEAD_DIM, self.s_k, 1), stride=(1, HEAD_DIM, HEAD_DIM * self.s_k))`
|
||||
5. `__call__`: `pv_mma = ... (128, HEAD_DIM) ...`
|
||||
6. softmax: `n_corr_tiles = HEAD_DIM // corr_tile_size`
|
||||
7. (Check for any others: `grep HEAD_DIM dsv4/kernels/attention/fmha.py`)
|
||||
|
||||
### D1.1 — Add SMEM-P path behind `use_smem_p` flag
|
||||
|
||||
- [ ] Add `use_smem_p` to `__init__` (default: `head_dim > 64`)
|
||||
- [ ] In `_setup`: conditional TMEM layout (TMEM-P has `tmem_p0_offset=32`, SMEM-P has `tmem_p0_offset=-1` and `tmem_o0_offset=0`)
|
||||
- [ ] In `_setup`: allocate `p_smem_s` for SMEM-P (PV A-operand SMEM layout)
|
||||
- [ ] In `__call__`: `pv_mma` uses `OperandSource.SMEM` when `use_smem_p`, `OperandSource.TMEM` otherwise
|
||||
- [ ] In `__call__`: PV A-operand major mode is `a_major` for SMEM-P, `OperandMajorMode.K` for TMEM-P
|
||||
- [ ] **CuTeDSL scoping:** Define ALL variables unconditionally before any `if use_smem_p` blocks. Both `tOrP0` (TMEM) and `tCrP` (SMEM) must exist before the warp-branching starts.
|
||||
- [ ] **Test:** hd=64, n=128, `use_smem_p=False` → cos 0.972537 (regression)
|
||||
|
||||
### D1.2 — Implement register→SMEM copy for P (the hard part)
|
||||
|
||||
- [ ] Build `tiled_p_copy = cute.make_tiled_copy_C(store_atom, qk_mma)` — QK MMA partitions threads
|
||||
- [ ] Partition `sP` with `tiled_p_copy` as destination
|
||||
- [ ] In softmax warps: after computing P in registers, write to SMEM via `tiled_p_copy`
|
||||
- [ ] Add `p_smem_ready_bar` barrier: softmax arrives after write, MMA waits before PV GEMM
|
||||
- [ ] In MMA warp: read P from SMEM via `tCrP = pv_mma.make_fragment_A(sP)`
|
||||
- [ ] **Test:** hd=64, n=128, `use_smem_p=True` → compare against TMEM-P result (should be close)
|
||||
- [ ] **Test:** hd=128, n=128 → test against FP32 oracle
|
||||
- [ ] **Test:** hd=256, n=128 → test against FP32 oracle
|
||||
- [ ] **Test:** hd=512, n=128 → test against FP32 oracle (DSV4's real value)
|
||||
|
||||
### D1.3 — Multi-PV-tile for hd>256
|
||||
|
||||
- [ ] When `head_dim > 256`, the MMA instruction can only process 256 columns at a time
|
||||
- [ ] `pv_n_tile = min(head_dim, 256)`, `n_pv_tiles = head_dim // pv_n_tile`
|
||||
- [ ] Multiple PV GEMM passes per KV tile, accumulating O
|
||||
- [ ] V must be re-constructed with `v_n = pv_n_tile` per pass
|
||||
- [ ] This may require multiple kernel launches at Python level (or a loop inside the kernel)
|
||||
- [ ] **Test:** hd=512, n=128 → correct output against FP32 oracle
|
||||
|
||||
### D1.4 — Cleanup and regression
|
||||
|
||||
- [ ] Remove `HEAD_DIM = 64` constant entirely
|
||||
- [ ] Add `head_dim` as first constructor arg (no default — always explicit)
|
||||
- [ ] Default `use_smem_p=None` → auto-detect from `head_dim > 64`
|
||||
- [ ] Test matrix: hd ∈ {64, 128, 256, 512} × n ∈ {128, 256}
|
||||
- [ ] Update README status table: D1 → ✅ COMPLETE
|
||||
- [ ] Cross off D1.0–D1.4 in this file
|
||||
|
||||
---
|
||||
|
||||
## D2 — Multi-query grid with head packing (after D1)
|
||||
|
||||
- [ ] Grid changes from `(1, 1, 1)` to `(num_q_blocks, 1, batch)`
|
||||
- [ ] DSV4 is MQA: all 128 query heads share same K/V
|
||||
- [ ] Head axis folded into M dimension of Q tile
|
||||
- [ ] **Test:** batch=4, T=64, n_h=128, num_kv_heads=1
|
||||
|
||||
## D3 — SWA sequence length mask
|
||||
|
||||
- [ ] Add `swa_lens: [batch] int32` kernel input
|
||||
- [ ] Mask SWA-branch logits to `-inf` where `swa_idx >= swa_lens[b]`
|
||||
- [ ] **Test:** varying SWA fill levels
|
||||
|
||||
## D4 — Causal mask on SWA branch
|
||||
|
||||
- [ ] Add `is_causal: bool` constructor flag
|
||||
- [ ] Apply `swa_idx > q_pos` masking in SWA pass
|
||||
- [ ] Main path has NO mask (indexer enforces causality upstream)
|
||||
|
||||
## D5 — SWA + sink merge
|
||||
|
||||
- [ ] D5a: Emit un-normalized `o` + `lse` instead of normalized `o` (keep normalize as flag)
|
||||
- [ ] D5b: Run kernel twice externally (compressed_kv + swa_kv), merge in Python
|
||||
- [ ] D5c: Fuse two passes into one kernel launch (Q stays in SMEM)
|
||||
- [ ] D5d: Fuse sink merge into kernel epilogue
|
||||
|
||||
---
|
||||
|
||||
## Key References
|
||||
|
||||
| What | Where |
|
||||
|------|-------|
|
||||
| Working FMHA kernel (hd=64) | `dsv4/kernels/attention/fmha.py` — `FmhaKernel` |
|
||||
| Stage C test (oracle) | `tests/unit/test_fmha_v3_stage_c.py` — `FmhaV3StageCMulti` |
|
||||
| Stage A+B test | `tests/unit/test_fmha_v3.py` |
|
||||
| FlashMLA SM100 reference | `/root/dsv4-nvfp4-workspace/vllm/.deps/flashmla-src/csrc/cutlass/examples/python/CuTeDSL/blackwell/fmha.py` (on B200) |
|
||||
| CUTLASS FMHA reference | `/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py` (on B200) |
|
||||
| Sink merge spec | `dsv4/ops/decode_sparse.py` |
|
||||
| SWA decode | `dsv4/ops/decode_swa.py` |
|
||||
| Attention reference | `dsv4/reference/attention.py` |
|
||||
| CSA attention reference | `dsv4/reference/csa_attention.py` |
|
||||
|
||||
## B200 Environment
|
||||
|
||||
```
|
||||
Server: root@45.76.247.107 (password: 6)Jr)B@dcX[mN?dx)
|
||||
Kernel repo: /root/dsv4-nvfp4-workspace/kernel
|
||||
Venv: source /root/dsv4-nvfp4-workspace/venv/bin/activate
|
||||
PYTHONPATH: /root/dsv4-nvfp4-workspace/kernel
|
||||
Test command: python3 tests/unit/test_fmha_v3_stage_c.py
|
||||
```
|
||||
Reference in New Issue
Block a user