biondizzle 98e5b48470 Update all .md files with D5a/D5b progress, tOrP0 fix, LSE formula
- README.md: Updated Stage status table (D1 🟡, D5 🟢), D5 section with
  D5a/D5b results, tOrP0 bug fix docs, new CuTeDSL constraints #11-12
- STAGE_D1.3.md: Added progress update - TMEM-P works, SMEM-P still blocked,
  recommended next steps
- STAGE_D.md was already updated
2026-05-23 22:07:53 +00:00

# DSV4 Inference Kernel
## ⚠️⚠️⚠️ CRITICAL: TMA Partition Tensor Mode Ordering ⚠️⚠️⚠️
**THIS BUG COST US AN ENTIRE DAY. READ THIS. BURN IT INTO YOUR BRAIN.**
After `cpasync.tma_partition()`, the output GMEM tensor has **4 modes** (verified on B200):
```
tBgK shape: (((64, 128), 1), ?, KV_tiles, ?)
mode:       0                1  2         3
```
**Mode 2 is the GMEM tile dimension.** This is the dimension you index with `kt` to load different K/V tiles.
### THE WRONG WAY (what we did — silently loads from tile 0 forever):
```python
# ❌❌❌ (None,None,0,0) KEEPS MODES 0,1 FREE, SETS MODE 2 TO 0 ❌❌❌
# Mode 2 (the KV tile dim) gets collapsed to coordinate 0.
# TMA ALWAYS reads from tile 0.
tBgK = tBgK[(None, None, 0, 0)] # ← WRONG! Mode 2 pinned to 0!
# The copy "works" but kv_coord indexes mode 1 (inner GEMM K, not KV tiles).
cute.copy(tma_k, tBgK[(None, kv_coord)], ...) # ← kv_coord indexes wrong mode!
```
### THE RIGHT WAY (verified on B200 at n=128 and n=256):
```python
# ✅ (None,0,None,0) keeps modes 0 and 2 free → 2D tensor
# Mode 2 (KV tiles) survives as the second mode.
tBgK = tBgK[(None, 0, None, 0)]
# ✅ [None, kt] indexes the surviving mode 1 (originally mode 2 = KV tiles)
cute.copy(tma_k, tBgK[None, kt], ...)
# ^^ THIS IS THE KV TILE DIM
```
**Verified shapes on B200 (May 22, n=256, inside @cute.kernel):**
```
Before slice: tBgK = (((64,128),1), Int32(?), Int32(?), Int32(?)) — 4 modes
After (None,0,None,0): tBgK = (((64,128),1), Int32(?)) — 2 modes
```
### WHY THIS IS SO INSIDIOUS
1. **No error, no warning.** The slice `tBgK[(None,None,0,0)]` silently sets mode 2 to 0.
2. **Single-tile (n=128) works perfectly.** With only 1 KV tile, mode 2 is size 1, so the bug is invisible.
3. **Multi-tile tests produce "reasonable" output.** The TMA loads from tile 0 every time, so you get a valid (but wrong) attention computation. Cosine similarity is 0.7-0.9, not NaN.
4. **The strides are all 0.** Printing `tBgK.layout.stride` shows all zeros for TMA tensors. You can't detect the bug from strides alone.
5. **`cute.printf` shows `kv_coord=0`.** We thought the JIT was constant-folding the variable. It wasn't — the variable was fine, but it was indexing the wrong mode.
6. **The 8-mode theory was wrong.** We assumed tma_partition produced 8 TMA coordinate dimensions. It produces 4. The 8-None no-op slice fails with "weakly congruent" at JIT compile.
### THE LESSON
**PRINT THE SHAPES. ALWAYS.** Run `print(f"tBgK: shape={cute.shape(tBgK)}")` inside `@cute.kernel` at trace time. The shapes are your ground truth. Reasoning about mode counts without evidence is how we wasted a day.
**The correct pre-slice depends on which mode is the GMEM tile iteration axis.** For our `local_tile` + `partition_B` + `group_modes(0,3)` pattern, mode 2 is the KV tile axis. `(None,0,None,0)` keeps it free. `(None,None,0,0)` collapses it to 0.
```python
# ALWAYS verify the shape at trace time:
print(f"tBgK shape: {cute.shape(tBgK)}") # 4 modes
print(f"tBgK after slice: {cute.shape(tBgK[(None,0,None,0)])}") # 2 modes
# Then index the 2D tensor:
cute.copy(tma_k, tBgK[None, kt], ...)
```
**IF YOU USE (None,None,0,0) INSTEAD OF (None,0,None,0), MULTI-TILE TMA WILL BE SILENTLY BROKEN.**
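The slicing rule above can be modeled in plain Python as a sanity check. This is a minimal sketch, not CuTeDSL code: `surviving_modes`, `kv_tile_position`, and `KV_TILE_MODE` are hypothetical names, and the only assumption is the behavior described above, that a `None` entry keeps a mode free while an integer pins it.

```python
KV_TILE_MODE = 2  # per the verified shape above: mode 2 is the KV tile axis


def surviving_modes(coord):
    """Original mode indices that a slice coordinate leaves free (None)."""
    return [mode for mode, c in enumerate(coord) if c is None]


def kv_tile_position(coord):
    """Position of the KV tile mode in the sliced tensor, or None if pinned."""
    free = surviving_modes(coord)
    return free.index(KV_TILE_MODE) if KV_TILE_MODE in free else None


wrong = (None, None, 0, 0)
right = (None, 0, None, 0)

print(surviving_modes(wrong), kv_tile_position(wrong))  # [0, 1] None
print(surviving_modes(right), kv_tile_position(right))  # [0, 2] 1
```

With the wrong coordinate the KV tile mode does not survive at all, so no later index can reach it. With the right coordinate it lands at position 1, which is what `tBgK[None, kt]` indexes.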
---
## Architecture
DSV4 is **not MLA**. It uses **CSA (Compressed Sparse Attention, m=4)** and **HCA (Heavily Compressed Attention, m=128)**. KV latent is (T, 512) shared across all 128 heads. Sink weights merge sparse + SWA attention. vLLM misnames this as "MLA" — it is not. The architecture is fundamentally different.
```
DSV4 inference pipeline — component status
==========================================
Legend:
[✓] built and tested
[~] partial — reference or seam exists, native pending
[✗] to build
┌────────────────────────────────────┐
│ [✗] Embedding + mHC init │
│ token embed + n_hc=4 streams │
└────────────────┬───────────────────┘
┌─ Transformer layer × L ──────────────────────────────────────────────┐
│ HCA on layers 0-1 of Pro, alternating CSA / HCA after │
│ │
│ ┌─ Attention sub-block ──────────────────────────────────────────┐ │
│ │ [✓] Residual mHC pre + post mix │ │
│ │ [~] Norms + RoPE RMSNorm + partial RoPE │ │
│ │ [✓] Q / KV projection NVFP4 linears + LoRA │ │
│ │ [✓] Token compressor CSA m=4 / HCA m=128 │ │
│ │ [✓] Indexer + top-k CSA, FP32 dot + top-k │ │
│ │ [~] FMHA core QK → online softmax → PV │ │
│ │ + SWA branch + sink merge │ │
│ │ [✓] Output projection inv RoPE + wo_a grouped + wo_b │ │
│ └────────────────────────────────────────────────────────────────┘ │
│ │
│ ┌─ FFN sub-block ────────────────────────────────────────────────┐ │
│ │ [✓] Residual mHC pre + post mix │ │
│ │ [✓] Pre-FFN norm RMSNorm │ │
│ │ [✓] Router sqrt(softplus) + topk + hash │ │
│ │ [✓] Routed MoE fused SwiGLU L1 + L2 │ │
│ │ [✓] Shared expert NVFP4 single-group GEMM │ │
│ └────────────────────────────────────────────────────────────────┘ │
└──────────────────────────────────┬───────────────────────────────────┘
┌──────────────────────────────────────────────────────────────────────┐
│ [✗] Final RMSNorm → [✗] LM head → [✗] MTP (depth=1) → [✗] Sampler │
└──────────────────────────────────────────────────────────────────────┘
┌─ Supporting infrastructure ──────────────────────────────────────────┐
│ [✗] KV cache management │
│ • state cache: SWA window + uncompressed tail per layer │
│ • classical paged cache: lcm(m_CSA, m_HCA) = lcm(4, 128) = 128 tokens per block │
│ • heterogeneous layout per layer │
└──────────────────────────────────────────────────────────────────────┘
Summary
-------
Built [✓] : 10 - mHC ×2, Q/KV proj, output proj, routed MoE,
shared expert, token compressor, indexer+topk,
router, pre-FFN norm
Partial [~] : 2 - norms+RoPE, FMHA core
To build [✗] : 6 - embedding+init, final norm, LM head, MTP, sampler, KV cache
```
---
## Status (May 23, 2026 — 05:30 UTC)
| Stage | Status | Description |
|-------|--------|-------------|
| A | ✅ COMPLETE | Q@K^T via tcgen05.mma → TMEM → GMEM |
| B | ✅ COMPLETE | QK → identity softmax → P@V pipeline (TMEM alias, KV-tile interleaving) |
| C | ✅ MIGRATED TO MODULE | Real online softmax + normalize. n=128 cos 0.973. Migrated to `dsv4/kernels/attention/fmha.py` as `FmhaKernel`. TMEM layout mismatch still present (3% error). |
| D1 | 🟡 MOSTLY DONE | Parameterized HEAD_DIM. TMEM-P hd=64 works (cos 0.973). SMEM-P for hd>64 is a stub (make_tiled_copy_C rank mismatch). tOrP0 TMEM column offset bug fixed. |
| D2 | TODO | Multi-query grid with head packing (128 Q heads, MQA) |
| D3 | TODO | SWA sequence length mask (swa_lens per batch) |
| D4 | TODO | Causal mask on SWA branch only |
| D5 | 🟢 D5a+D5b DONE | D5a: normalize flag + LSE output (err=0.0). D5b: Python SWA+sink merge (cos 0.961). D5c/D5d: fused kernel merge TODO. |
| E1-E7 | TODO | Production extraction (class, custom op, cache, cleanup) |
---
## Package Structure
```
dsv4/
├── kernels/ Pure GPU code (CuTeDSL @cute.jit, .cu files)
│ ├── gemm/ NVFP4 MoE GEMM kernels (grouped, fused_swiglu, dense, scheduler)
│ ├── attention/ FMHA kernel — FmhaKernel (hd=64, TMEM-P proven; SMEM-P stub for hd>64)
│ ├── compressor/ CSA/HCA token-level compressor (CuTeDSL, 419 lines)
│ ├── indexer/ CSA indexer — score+topk (FP32 dot products, top-k selection)
│ ├── router/ Dense router decode kernel (warp-specialized persistent GEMM)
│ ├── cache/ Cache kernels — append_swa (write KV to split state cache layout)
│ ├── decode/ Decode-time attention (sparse, SWA — future)
│ └── cuda/ Raw .cu files (deinterleave_quantize, sparse_topk_metadata)
├── ops/ PyTorch ↔ kernel bridges
│ ├── quantize.py BF16 ↔ NVFP4 conversion, scale factors
│ ├── layouts.py Scale swizzle, gate/up interleave, K-major, offsets
│ ├── gemm_runner.py Warmup, compile, run grouped/fused GEMMs
│ ├── custom_ops.py torch.library.custom_op registrations
│ ├── decode_sparse.py native_sparse_decode dispatcher
│ ├── decode_swa.py native_swa_decode dispatcher
│ ├── rope.py Forward + inverse RoPE
│ ├── topk.py Python wrapper for sparse_topk_metadata.cu
│ ├── topk_select.py Top-k selection wrapper
│ └── router.py Router op bridge
├── layers/ nn.Module-style components
│ ├── linear.py Nvfp4Linear
│ ├── grouped_linear.py Nvfp4GroupedLinear
│ ├── moe.py Nvfp4MoE
│ ├── shared_expert.py Nvfp4SharedExpert
│ ├── mhc.py mHCLayer
│ ├── attention.py DSV4 attention sub-block (CSA/HCA/SWA variants, 245 lines)
│ ├── norm.py RMSNorm (PyTorch ref, fused kernel later)
│ ├── router.py Router — token-to-expert assignment (273 lines)
│ ├── embedding.py Token embedding + mHC init (stub)
│ └── ffn.py FFN sub-block
├── model/ Model assembly
│ ├── config.py Model config
│ ├── layer.py Transformer layer
│ ├── layer_schedule.py Layer scheduling
│ ├── mtp.py Multi-token prediction
│ ├── sampler.py Token sampler
│ └── dsv4.py Full model (stub — Phase 1)
├── cache/ KV cache infra
│ ├── allocator.py Cache memory allocator
│ ├── block_table.py Paged cache block table
│ ├── flush.py Cache flush
│ ├── handle.py Cache handle
│ ├── manager.py Cache manager
│ ├── paged_cache.py Paged KV cache
│ ├── prepare_forward.py Forward prep
│ ├── schema.py Cache schema
│ └── state_cache.py State cache (SWA ring buffer)
├── loader/ Checkpoint I/O
│ ├── hf_checkpoint.py HuggingFace checkpoint loader
│ └── layout_convert.py Weight layout conversion
└── reference/ Slow PyTorch oracles (never imported by production code)
├── attention.py RoPE, KV cache, causal attention, SWA
├── csa_attention.py CSA/HCA sparse attention
├── compressor.py Compressor PyTorch example
└── moe_pipeline.py MoE pipeline reference
```
**Mental model:** `kernels/` → `ops/` → `layers/` → `model/` (dependency flows left to right). `reference/` and `loader/` are sidecars.
---
## Active Test Files
### FMHA (Stages A/B/C/D1) — in `tests/unit/`
| File | Stage | Status |
|------|-------|--------|
| `test_fmha_v3.py` | A+B | ✅ Full QK→identity softmax→PV, cosine 0.999999 |
| `test_fmha_v3_12w.py` | A+B | ✅ 12-warp QK→PV, cosine 0.999999 |
| `test_fmha_v3_stage_c.py` | C | ✅ Real online softmax + normalize, n=128 cos 0.973. **Also in module as `FmhaKernel`.** |
| `test_fmha_v3_stage_d1.py` | D1 | 🟡 Parameterized hd, hd=64 PASS (cos 0.973), hd>64 FAIL (SMEM-P stub) |
| `test_fmha_v3_stage_d5b.py` | D5b | ✅ Python SWA+sink merge (cos 0.961, LSE err=0.0) |
| `test_d1_*.py` | D1 | 🔨 Debug/diagnostic variants (hd512, regression, sweep, raw, debug) |
| `test_paired_epilog.py` | C | ✅ Paired atom epilogue experiments |
| `test_pv64_with_softmax.py` | B | ✅ (128,64) PV, single AB pipeline |
| `test_128_128_vdiag.py` | A+B | ✅ (128,128) PV baseline |
| `test_qkonly.py` | A | ✅ QK with split Q/KV pipelines |
| `test_qk_softmax.py` | A+B | ✅ QK + identity softmax, no PV |
### MoE / GEMM — in `tests/unit/`
| File | What |
|------|------|
| `test_cutedsl.py` | NVFP4 grouped GEMM kernel |
| `cudagraph_test.py` | Cudagraph capture + replay |
| `layertest.py` | Per-layer correctness |
| `test_custom_op.py` | torch.library custom ops |
| `test_compile_custom_op.py` | Compile + warmup |
| `test_fp4_roundtrip.py` | BF16 → NVFP4 → BF16 roundtrip |
| `test_interleave.py` | Gate/up weight interleaving |
| `test_interleave_gemm.py` | Interleaved GEMM correctness |
| `test_fused_step1.py` | Fused SwiGLU GEMM |
---
## Test Harness
Scripts in `tests/` for running tests on the B200 (`root@45.76.247.107`):
### `run_test.sh` — Run a test in a screen session
```bash
# On the B200:
cd /root/dsv4-nvfp4-workspace/kernel
bash tests/run_test.sh tests/unit/test_fmha_v3.py
```
What it does:
1. Kills any existing `kernel-test` screen and **SIGKILLs all child processes** (handles deadlocked GPU procs that ignore SIGHUP)
2. Deletes the old log file
3. Starts a new `screen -dmS kernel-test` running the test
4. Logs output to `/tmp/kernel-test.log`
5. Verifies the screen started
### `check_log.sh` — Check test progress
```bash
bash tests/check_log.sh
```
Shows the log contents and whether the screen is still running.
### Local → B200 workflow
```bash
# 1. Edit locally, commit, push
cd ~/dev/nvfp4-megamoe-kernel
git add -A && git commit -m "my change" && git push
# 2. SSH to B200, pull, run
ssh root@45.76.247.107
cd /root/dsv4-nvfp4-workspace/kernel && git pull
bash tests/run_test.sh tests/unit/test_fmha_v3_stage_c_full.py
# 3. Check results
bash tests/check_log.sh
```
### `fire_b200_test` — One-command local test runner
Lives in `~/.openclaw/workspace/fire_b200_test` (NOT in the repo — project-specific tooling).
```bash
# From your local machine, one command to push, run, and get results:
~/.openclaw/workspace/fire_b200_test tests/unit/test_fmha_v3.py
```
What it does:
1. Auto-commits and pushes any local changes
2. SSHes to the B200, pulls, and starts `run_test.sh` in a screen
3. Polls every 15s until the screen exits
4. Dumps the full test log to your terminal
**This is strictly for the DSV4 NVFP4 kernel project.** It hardcodes the B200 IP, repo paths, and git remote.
---
## Stage C: Online Softmax — TMEM Layout Mismatch Issue
### Current Results (test_fmha_v3_stage_c.py)
| n | cos | Status |
|---|-----|--------|
| 128 | 0.973 | ⚠️ 3% error from TMEM layout mismatch |
| 256 | 0.793 | ⚠️ Two TMEM round-trips compound the error |
| 384+ | N/A | Pipeline doesn't cycle past 2 KV tiles |
### Root Cause: TMEM Layout Mismatch
The MMA instruction writes O to TMEM using the **C-fragment layout**. The `epilogue_tma_store` helper reads O from TMEM using `get_tmem_load_op`, which uses the **correct** C-fragment-compatible layout. **Raw PV output is perfect (cos 0.999998)** when `epilogue_tma_store` reads directly without any round-trip.
The problem appears when we do a **TMEM round-trip** (load O → modify → store back) using hand-constructed `Ld32x32bOp/St32x32bOp` atoms. These atoms use a different column mapping than the MMA's C-fragment layout, causing ~3% data corruption per round-trip. Both the NO-OP round-trip (previously used to "fix" layout) and the normalize round-trip (multiply by 1/row_sum) suffer from this error.
**Fix proven but not yet integrated:** The `epilogue_tmem_copy_and_partition` + `epilogue_smem_copy_and_partition` pattern from CUTLASS's `cutlass.utils.gemm.sm100` reads O from TMEM using the correct `get_tmem_load_op` layout and writes to SMEM using `get_smem_store_op`. This is a one-way trip (TMEM→reg→SMEM→GMEM) that eliminates the layout mismatch entirely. Integration requires proper `flat_divide` and `tma_partition` handling inside the kernel's warp-specific if blocks.
### Key Bug Fix: tOrP0 TMEM Column Offset (May 23)
The softmax warps store P at `tmem_p0_offset=32` FP32 columns (64 BF16 elements). PV MMA must read from the same offset. **`tOrP0` was missing this offset**, causing PV to read from TMEM column 0 (where S is) instead of column 32 (where P is). This was the root cause of NaN/zeros in D1 tests. Fixed with:
```python
if const_expr(self.tOrP0_offset > 0):
    tOrP0 = cute.make_tensor(tOrP.iterator + self.tOrP0_offset, tOrP.layout)
else:
    tOrP0 = tOrP
```
This must be a `const_expr` conditional rather than a plain Python `if`: CuTeDSL compiles both branches of a plain `if`, and `tOrP.iterator + 0` fails with an MLIR type error.
### Architecture (6-warp, current)
```
Warps 0-3: Softmax + Epilogue (row_max, row_sum, P store, O rescale, final normalize)
Warp 4: MMA (QK, PV)
Warp 5: TMA (Q/K/V load)
```
### TMEM Layout
```
Col 0-31: S (QK acc, 128 FP32 via Ld32x32bOp Repetition(32))
Col 32-95: P (64 FP32 via St32x32bOp Repetition(32), register bridge BF16 view)
Col 128+: O (PV acc, 64 FP32, rescale via Ld32x32bOp Repetition(16))
```
### Remaining for Multi-Tile Production
1. **Fix TMEM layout mismatch** — replace hand-constructed atom round-trips with correction_epilog pattern
2. **Pipeline state cycling for n≥384** — kv_stage=2 can only buffer 2 tiles
3. **12-warp layout** — separate softmax/correction/epilogue warps
4. **O rescale for kt > 0** — must also use paired atoms or correction_epilog
---
## CuTeDSL Constraints (hard-won)
1. **`vectorize=True` loops: ONLY load/store/print** — no fmax, no cmpf, no inner loops, no carry
2. **`.reduce(cute.ReductionOp.MAX)`:** reduces ENTIRE C-fragment to scalar — global max, not per-row
3. **`cute.arch.fmax`:** impure for vectorizer — use plain `range()` loop
4. **TMA partition tensors have 4 modes:** `(((64,128),1), ?, KV_tiles, ?)`. Slicing with `(None,0,None,0)` keeps mode 2 (KV tiles) free; `[None, kt]` then indexes it.
5. **`tBgK[(None, None, 0, 0)]` pins mode 2 to 0** — silently reads tile 0 forever. Use `(None,0,None,0)` instead.
6. **`softmax_done_bar` NamedBarrier is reusable** across tiles
7. **Hand-constructed TMEM atoms corrupt data on round-trip:** `Ld32x32bOp` + `St32x32bOp` built independently introduce ~3% error. Use `get_tmem_load_op` + `get_smem_store_op` paired atoms for one-way trips.
8. **CuTeDSL region isolation:** `flat_divide` and `tma_partition` can't be called inside `if warp_idx` blocks. Do partitioning outside `if` blocks or in regular (non-`@cute.kernel`) helper functions.
9. **`composition` vs `logical_divide`:** Both re-tile a tensor, but produce different layouts. The CUTLASS `correction_rescale` uses `composition`, `correction_epilog` uses `logical_divide`. The copy atoms must match the tensor layout they were created with.
10. **Variables in CuTeDSL `if` blocks are NOT visible in other `if` blocks.** Even when the condition is a compile-time constant (`self.use_smem_p`), CuTeDSL's MLIR lowering creates separate regions. Variables must be defined *unconditionally* before the first `if` that uses them. This applies across `if warp_idx == X` blocks, `for` loops, and nested branches. If a variable is set in `if not use_smem_p:` and read in another `if not use_smem_p:` inside a `for` loop inside an `if warp_idx < mma_warp_id:`, it won't be visible. Define all such variables before *any* branching.
11. **`tOrP0` MUST include the `tmem_p0_offset` column offset.** The softmax warps store P at `tmem_p0_offset=32` (FP32 columns = 64 BF16 elements). PV MMA must read from the same offset. Missing this causes NaN/zeros (MMA reads S from column 0, not P from column 32). Use `const_expr` conditional: `if const_expr(self.tOrP0_offset > 0): tOrP0 = cute.make_tensor(tOrP.iterator + self.tOrP0_offset, tOrP.layout) else: tOrP0 = tOrP`. Cannot use `tOrP.iterator + 0` (MLIR OpResult + int fails).
12. **LSE formula: `lse = ln(row_sum) + row_max * ln(2)`.** `row_max` is in the scale_log2 domain (`max(S * scale * log2(e))`), so multiplying by `ln(2)` converts it to the natural-log domain: `attn_max = row_max * ln(2)`. Verified: LSE err=0.000000.
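
The LSE formula in constraint 12 can be checked on the host with nothing but the standard library. The sketch below is an illustration, not the kernel code: `lse_from_log2_domain` and `lse_reference` are hypothetical names, and it assumes `row_sum` is accumulated as a sum of `2 ** (x - row_max)` terms in the same log2 domain as `row_max`.

```python
import math


def lse_from_log2_domain(scores, scale):
    """LSE computed the way the formula describes: log2-domain max and sum."""
    log2e = math.log2(math.e)
    scaled = [s * scale * log2e for s in scores]          # scale_log2 domain
    row_max = max(scaled)
    row_sum = sum(2.0 ** (x - row_max) for x in scaled)   # exp2-based softmax sum
    return math.log(row_sum) + row_max * math.log(2.0)


def lse_reference(scores, scale):
    """Textbook natural-log logsumexp of the scaled scores."""
    m = max(s * scale for s in scores)
    return m + math.log(sum(math.exp(s * scale - m) for s in scores))


scores = [0.5, -1.0, 2.0, 3.5]
err = abs(lse_from_log2_domain(scores, 0.125) - lse_reference(scores, 0.125))
print(f"LSE err = {err:.6f}")
```

The two agree because `2 ** (x * log2(e))` equals `exp(x)`, so the only correction needed is converting `row_max` back with `ln(2)`.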
---
## Key Lessons
1. **NEVER use `find_tmem_tensor_col_offset()` as TMEM placement.** It returns footprint size, not a safe offset.
2. **FMHA never trusts DLPack tensor layouts.** Reconstruct V as (hd, s_k) MN-major inside CuTe.
3. **TMEM allocation must be power of 2.**
4. **Square hides bugs.** (128,128) worked for every wrong approach. Always test non-square.
5. **St32x32bOp MUST use Float32**, NOT BFloat16. BFloat16 causes illegal memory access.
6. **First PV ACCUMULATE=False.** Otherwise adds uninitialized TMEM to output.
7. **FMHA P store uses QK C-fragment composition, NOT PV A-fragment.** Two aliases, same TMEM.
8. **Register bridge: FP32 backing (store partition) + BF16 view (QK-load layout).** Do not skip this.
9. **PRINT THE SHAPES. ALWAYS.** Reasoning about TMEM layouts without evidence is how we waste days.
10. **Never assume TMEM round-trips are safe.** Verify with NO-OP tests before adding logic.
---
## Stage D: Full Decode Attention (revised May 23)
### Key Insight: The Indexer Solves Paging Upstream
The indexer now hands the kernel `selected_kv: [T, top_k, head_dim] BF16` — a **dense, materialized, dequantized** K/V tile. FMHA sees a dense `[T, top_k, 512]` tile, exactly like Stage A/B's existing `k` and `v` inputs. **The kernel doesn't need to know it's sparse.** Paged TMA, scattered HBM reads, FP8 dequantization — all handled by `gather_selected_kv` upstream.
The SWA branch is the only "irregular" thing: it reads from the state cache's ring buffer with a position mask. SWA is small (`n_win=128` per query), so it's a separate fused branch with a sink-weighted merge.
**One FMHA kernel serves all three DSV4 attention types:**
- **CSA:** `compressed_kv` = top-k from indexer, `swa_kv` from cache → sink merge
- **HCA:** `compressed_kv` = all classical pool entries (gather-all mode), `swa_kv` from cache → sink merge
- **SWA-only (Flash layers 0-1):** `compressed_kv` = empty (`top_k=0`), only SWA runs. Sink merge degenerates to just `o_swa` after renormalization.
### Build Order
**D1 — Parameterize HEAD_DIM + SMEM-P** (~1 day, in progress)
Currently hardcoded at 64. Promote to constructor arg, thread through `_setup`. Test at 64, then 512 (DSV4's real value).
**Two P staging paths:**
- **TMEM-P** (hd≤64): P stored to TMEM via register bridge. PV reads from TMEM. Proven at cos 0.973.
- **SMEM-P** (hd>64): P stored to SMEM via PV A-operand layout. PV reads from SMEM. Avoids QK↔PV TMEM layout mismatch at large hd. **Register→SMEM copy needs `make_tiled_copy_C(store_atom, qk_mma)` to partition threads by QK C-fragment.** The SMEM rendezvous pattern: softmax writes P to SMEM at logical (row, col) addresses using `p_smem_s` layout, MMA warp reads from same SMEM. Barrier in between.
Risk at HEAD_DIM=512: TMEM column budget. `_setup` already does `find_tmem_tensor_col_offset(tOtO)` dynamically. Verify the total fits in 512 TMEM columns. If not, reduce `kv_stage` from 2 to 1 (lose K/V double-buffering) before sacrificing math.
Done when: identical result at HEAD_DIM=64 (regression), passes at HEAD_DIM=512 against FP32 oracle.
**D2 — Multi-query grid with head packing** (~1 day)
Grid changes from `(1, 1, 1)` to `(num_q_blocks, 1, batch)`. DSV4 is MQA — all `n_h=128` query heads share the same K/V. The query-head axis is folded into the M dimension of the Q tile: `M_tile = 128` covers `M = T * n_h` rows. At decode T is small (1-16), so packing heads into M fills the MMA. At prefill T=64, M is already 8192 with heads packed.
Done when: batch=4, T=64, n_h=128, num_kv_heads=1 produces correct attention against FP32 oracle.
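
The head-packing arithmetic for D2 can be sketched as follows. `head_packed_grid` is a hypothetical helper, and it assumes `T` is the per-batch query count and that the last M tile may be only partially filled.

```python
def head_packed_grid(T, n_h, batch, m_tile=128):
    """Grid shape (num_q_blocks, 1, batch) when query heads are folded into M."""
    m = T * n_h                       # rows after packing heads into M
    num_q_blocks = -(-m // m_tile)    # ceiling division
    return (num_q_blocks, 1, batch)


print(head_packed_grid(T=64, n_h=128, batch=4))  # prefill: M = 8192 -> (64, 1, 4)
print(head_packed_grid(T=1, n_h=128, batch=1))   # decode:  M = 128  -> (1, 1, 1)
```

At decode with `T=1`, the 128 heads alone fill exactly one 128-row M tile, which is why packing heads keeps the MMA busy even for a single token.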
**D3 — SWA sequence length mask** (~½ day)
The indexer's `top_k` is fixed (512 for Flash, 1024 for Pro). Compressed-K input is always `[T, top_k, head_dim]` with the same `top_k` at compile time.
What varies: the SWA window holds up to `n_win=128` tokens but starts with fewer. Add `swa_lens: [batch] int32` as kernel input. Mask SWA-branch logits to `-inf` where `swa_idx >= swa_lens[b]`.
Done when: batched input with varying SWA fill levels (some requests at position 50, some at 5000) produces correct masked output.
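
A host-side reference for the D3 mask could look like the sketch below. It is a simplification under stated assumptions: `mask_swa_logits` is a hypothetical name, and logits are shown as one row per batch entry, whereas the real kernel would map each query to its SWA slot via `request_ids`.

```python
NEG_INF = float("-inf")


def mask_swa_logits(logits, swa_lens):
    """Set logits[b][swa_idx] to -inf wherever swa_idx >= swa_lens[b]."""
    return [
        [x if swa_idx < swa_lens[b] else NEG_INF for swa_idx, x in enumerate(row)]
        for b, row in enumerate(logits)
    ]


logits = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
print(mask_swa_logits(logits, swa_lens=[3, 1]))
# [[1.0, 2.0, 3.0], [4.0, -inf, -inf]]
```

Because `exp(-inf)` is 0, masked slots contribute nothing to `row_sum` or to the PV accumulation, so a partially filled window behaves like a shorter one.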
**D4 — Causal mask on SWA branch** (~½ day)
The compressed K the indexer selects is already from `s < floor(t/m)` (paper eq. 17). The indexer enforces causality at selection time. FMHA sees only causally-valid candidates. **The main path has no mask.**
The SWA branch needs a causal mask within the window. Add `is_causal: bool` constructor flag, apply `swa_idx > q_pos` masking to `-inf` in the SWA pass.
Done when: prefill mode produces correct output with the causal mask applied to SWA.
**D5 — SWA + sink merge** (~2-3 days) ← D5a+D5b DONE (May 23), D5c/D5d remaining
Per `dsv4/ops/decode_sparse.py`:
```
o = (exp(lse_sparse) * o_sparse + exp(attn_sink) * exp(lse_swa) * o_swa)
/ (exp(lse_sparse) + exp(attn_sink) * exp(lse_swa))
```
With un-normalized O (D5a): `o_unnorm = o_norm * exp(lse)`, so:
```
o = (o_unnorm_sparse + exp(attn_sink) * o_unnorm_swa)
/ (exp(lse_sparse) + exp(attn_sink) * exp(lse_swa))
```
**D5a DONE (May 23):** `normalize` flag added to FmhaKernel. When False, emits un-normalized O + LSE. LSE formula: `lse = ln(row_sum) + row_max * ln(2)` (row_max in scale_log2 domain, multiply by ln(2) to convert). LSE err=0.000000 verified.
**D5b DONE (May 23):** Python SWA+sink merge works end-to-end at hd=64. Run FMHA twice (compressed KV + SWA KV, normalize=False), merge in Python. Merge cos 0.961, individual attention cos 0.963/0.960.
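
The two forms of the merge formula above can be compared in plain Python. This is a sketch for a single query row, not the code in `dsv4/ops/decode_sparse.py`: the function names are hypothetical, and the max-subtraction in the first function is an added numerical-stability step that the formula itself does not spell out.

```python
import math


def sink_merge(o_sparse, lse_sparse, o_swa, lse_swa, attn_sink):
    """Merge normalized outputs; weights are exp(lse) and exp(sink) * exp(lse)."""
    a = lse_sparse
    b = attn_sink + lse_swa
    m = max(a, b)                     # subtract the max before exponentiating
    w_sparse, w_swa = math.exp(a - m), math.exp(b - m)
    den = w_sparse + w_swa
    return [(w_sparse * x + w_swa * y) / den for x, y in zip(o_sparse, o_swa)]


def sink_merge_unnormalized(o_un_sparse, lse_sparse, o_un_swa, lse_swa, attn_sink):
    """Same merge, starting from un-normalized O (o_unnorm = o_norm * exp(lse))."""
    e_sink = math.exp(attn_sink)
    den = math.exp(lse_sparse) + e_sink * math.exp(lse_swa)
    return [(x + e_sink * y) / den for x, y in zip(o_un_sparse, o_un_swa)]


o_sparse, lse_sparse = [1.0, 2.0], 0.3
o_swa, lse_swa, sink = [-1.0, 0.5], -0.2, 0.1

merged = sink_merge(o_sparse, lse_sparse, o_swa, lse_swa, sink)
merged_un = sink_merge_unnormalized(
    [x * math.exp(lse_sparse) for x in o_sparse], lse_sparse,
    [y * math.exp(lse_swa) for y in o_swa], lse_swa, sink,
)
print(max(abs(p - q) for p, q in zip(merged, merged_un)))
```

The un-normalized form saves one multiply per element in the numerator, which is the motivation for the `normalize=False` path.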
Sub-steps remaining:
- **5c:** Fuse the two passes into one kernel launch. Q stays in SMEM, two MMA loops sequentially.
- **5d:** Fuse the merge into the kernel epilogue.
Done when: end-to-end kernel produces correct attention against FP32 oracle that does sparse+SWA+sink merge.
**~~D5 (old) paged TMA~~ — REMOVED.** The indexer + gather handles all paging upstream.
### Kernel Architecture (after D5)
```
Input: Q [T, n_h, 512], compressed_kv [T, top_k, 512], swa_kv [batch, n_win, 512]
swa_lens [batch], sink_logits [n_h], request_ids [T]
├─ Load Q to SMEM (once)
├─ Loop 1: compressed KV (top_k tokens)
│ QK → online softmax → PV → O_sparse, lse_sparse in TMEM
├─ Loop 2: SWA window (n_win tokens, masked by swa_lens)
│ QK → online softmax → PV → O_swa, lse_swa in TMEM
└─ Sink merge epilogue:
O = (exp(lse_sparse) * O_sparse + exp(sink) * exp(lse_swa) * O_swa)
/ (exp(lse_sparse) + exp(sink) * exp(lse_swa))
```
### Reference Files
- Sink merge spec: `dsv4/ops/decode_sparse.py` (formula)
- SWA decode: `dsv4/ops/decode_swa.py`
- Attention reference: `dsv4/reference/attention.py`
- CSA attention: `dsv4/reference/csa_attention.py`
### Stage C Note
When implementing D5a, Stage C's epilogue changes from "multiply by 1/row_sum" to "emit un-normalized o + lse". Defer this until D5. Through D1-D4, keep Stage C normalize as-is and test as standalone dense FMHA.
---
## Stage E: Production Extraction (revised May 23)
### E1 — File placement
`dsv4/kernels/attention/fmha.py`. Currently contains `FmhaKernel` (migrated from test, hd=64 TMEM-P). Will gain parameterized `head_dim` and SMEM-P path in D1. Constructor takes all dimensions and dtypes, no module-level constants.
### E2 — Constructor signature
```python
class FmhaKernel:
    def __init__(
        self,
        head_dim: int,                 # 512 for DSV4
        num_query_heads: int,          # 128 for Pro, 64 for Flash
        sliding_window: int,           # 128
        top_k: int,                    # 512 (Flash) or 1024 (Pro)
        q_dtype=BFloat16,
        kv_dtype=BFloat16,
        o_dtype=BFloat16,
        qk_acc_dtype=Float32,
        pv_acc_dtype=Float32,
        is_causal: bool = False,       # affects SWA mask only
        cta_group: tcgen05.CtaGroup = tcgen05.CtaGroup.ONE,
        cluster_shape_mn: tuple = (1, 1),
    ):
        ...  # signature only; body omitted
```
All architecture-level shapes from config flow into the constructor. No FMHA-internal magic numbers.
### E3 — Call signature
```python
def __call__(
    self,
    q: torch.Tensor,              # [T, n_h, head_dim] BF16
    compressed_kv: torch.Tensor,  # [T, top_k, head_dim] BF16, from indexer gather
    swa_kv: torch.Tensor,         # [batch, n_win, head_dim] BF16, from cache prep
    swa_lens: torch.Tensor,       # [batch] int32
    sink_logits: torch.Tensor,    # [n_h] FP32
    request_ids: torch.Tensor,    # [T] int32, maps query to its SWA slot
    o: torch.Tensor,              # [T, n_h, head_dim] BF16, preallocated
    stream: cuda.CUstream,
):
    ...  # signature only; body omitted
```
Notably absent: block_table, paged KV, inv_scale, FP8 dequant. All handled upstream.
### E4 — Kernel cache + warmup
Mirror `dsv4/ops/gemm_runner.py`'s `_compiled_kernel_cache`. Key on `(head_dim, num_query_heads, top_k, is_causal, ...)`. Pre-allocate at warmup, reuse at call. For DSV4, the cache has at most ~2 entries (Flash/Pro × causal/non).
### E5 — torch.library custom op
```python
@torch.library.custom_op("dsv4::sparse_fmha_with_swa", mutates_args=("o",))
def sparse_fmha_with_swa_op(
    q: torch.Tensor,
    compressed_kv: torch.Tensor,
    swa_kv: torch.Tensor,
    swa_lens: torch.Tensor,
    sink_logits: torch.Tensor,
    request_ids: torch.Tensor,
    o: torch.Tensor,
    runner_id: int,
) -> None:
    runner = get_runner(runner_id)
    runner._run_impl(q, compressed_kv, swa_kv, swa_lens, sink_logits, request_ids, o)
```
Mutates `o` (preallocated buffer). Consistent with cudagraphs.
### E6 — Reference parity hook
`dsv4/reference/attention.py` stays as the FP32 oracle. New test: `tests/unit/test_fmha_kernel.py`.
```python
import torch


def test_sparse_fmha_matches_spec(T=64, n_h=128, top_k=1024, n_win=128, hd=512):
    q = torch.randn(T, n_h, hd, dtype=torch.bfloat16, device='cuda')
    ck = torch.randn(T, top_k, hd, dtype=torch.bfloat16, device='cuda')
    # torch.bf16 does not exist; the dtype is torch.bfloat16
    swa = torch.randn(4, n_win, hd, dtype=torch.bfloat16, device='cuda')
    # Kernel inputs live on the GPU alongside q/ck/swa
    swa_lens = torch.tensor([128, 50, 128, 75], dtype=torch.int32, device='cuda')
    sink = torch.randn(n_h, device='cuda')
    req_ids = torch.randint(0, 4, (T,), dtype=torch.int32, device='cuda')
    # Oracle: pure FP32 spec
    o_sparse, lse_sparse = attention_with_lse_f32(q, ck, ck)
    o_swa, lse_swa = attention_swa_with_lse_f32(q, swa, swa, swa_lens, req_ids)
    e_sink = sink.exp()
    num = lse_sparse.exp().unsqueeze(-1) * o_sparse \
        + e_sink[None, :, None] * lse_swa.exp().unsqueeze(-1) * o_swa
    den = lse_sparse.exp() + e_sink[None, :] * lse_swa.exp()
    expected = num / den.unsqueeze(-1)
    # Kernel
    o = torch.empty_like(expected, dtype=torch.bfloat16)
    fmha = FmhaKernel(head_dim=hd, num_query_heads=n_h, sliding_window=n_win, top_k=top_k)
    fmha(q, ck, swa, swa_lens, sink, req_ids, o, stream=...)
    torch.testing.assert_close(o.float(), expected, atol=5e-3, rtol=5e-3)
```
### E7 — Cleanup
Delete all debug test files. `test_fmha_v3.py` becomes `dsv4/kernels/attention/fmha.py`. Only `tests/unit/test_fmha_kernel.py` remains as the attention test.
---
## Environment
- Server: root@45.76.247.107 (B200, 180 GiB HBM3e per GPU)
- venv: `source /root/dsv4-nvfp4-workspace/venv/bin/activate`
- PYTHONPATH: `/root/dsv4-nvfp4-workspace/kernel`
- Model: `/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4`
- vLLM repo: `/root/dsv4-nvfp4-workspace/vllm` (modified for Blackwell)
- CUTLASS FMHA reference: `/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py`
- Local CUTLASS clone: `/home/openclaw/dev/cutlass`