# DSV4 Inference Kernel

## ⚠️⚠️⚠️ CRITICAL: TMA Partition Tensor Mode Ordering ⚠️⚠️⚠️

**THIS BUG COST US AN ENTIRE DAY. READ THIS. BURN IT INTO YOUR BRAIN.**

After `cpasync.tma_partition()`, the output GMEM tensor has **4 modes** (verified on B200):

```
tBgK shape: (((64, 128), 1),  ?,  KV_tiles,  ?)
mode:       0                 1   2          3
```

**Mode 2 is the GMEM tile dimension**: the dimension you index with `kt` to load different K/V tiles.

### THE WRONG WAY (what we did, silently loads from tile 0 forever):

```python
# ❌❌❌ (None,None,0,0) KEEPS MODES 0,1 FREE, SETS MODE 2 TO 0 ❌❌❌
# Mode 2 (the KV tile dim) gets collapsed to coordinate 0.
# TMA ALWAYS reads from tile 0.
tBgK = tBgK[(None, None, 0, 0)]  # ← WRONG! Mode 2 pinned to 0!

# The copy "works" but kv_coord indexes mode 1 (inner GEMM K, not KV tiles).
cute.copy(tma_k, tBgK[(None, kv_coord)], ...)  # ← kv_coord indexes wrong mode!
```

### THE RIGHT WAY (verified on B200 at n=128 and n=256):

```python
# ✅ (None,0,None,0) keeps modes 0 and 2 free → 2D tensor
# Mode 2 (KV tiles) survives as the second mode.
tBgK = tBgK[(None, 0, None, 0)]

# ✅ [None, kt] indexes the surviving mode 1 (originally mode 2 = KV tiles)
cute.copy(tma_k, tBgK[None, kt], ...)
#                           ^^ THIS IS THE KV TILE DIM
```

**Verified shapes on B200 (May 22, n=256, inside @cute.kernel):**
```
Before slice:           tBgK = (((64,128),1), Int32(?), Int32(?), Int32(?))   (4 modes)
After (None,0,None,0):  tBgK = (((64,128),1), Int32(?))                       (2 modes)
```

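
The difference between the two pre-slices can be reproduced without a GPU. The following is a plain-Python sketch, not CuTeDSL: `pre_slice` is a hypothetical helper that mimics the slicing convention (an `int` pins a mode, `None` keeps it free) and returns the full 4-mode coordinate that a later index would touch. It assumes, as verified above, that mode 2 is the KV tile axis.

```python
ALL = "all"  # stand-in for "every element of mode 0" (the (64,128) tile)
KV_TILE_MODE = 2  # mode 2 is the KV tile axis


def pre_slice(spec):
    """Mimic tensor[spec]: ints pin a mode, None keeps it free.

    Returns a function that maps coordinates of the surviving (free)
    modes back to a full coordinate over the original 4 modes.
    """
    free_modes = [mode for mode, s in enumerate(spec) if s is None]

    def index(*coords):
        if len(coords) != len(free_modes):
            raise ValueError("one coordinate per surviving mode is required")
        full = list(spec)
        for mode, coord in zip(free_modes, coords):
            full[mode] = coord
        return tuple(full)

    return index


wrong = pre_slice((None, None, 0, 0))  # pins mode 2 to 0
right = pre_slice((None, 0, None, 0))  # keeps mode 2 free

# Which KV tile does each variant touch as kt advances?
wrong_tiles = [wrong(ALL, kt)[KV_TILE_MODE] for kt in range(4)]
right_tiles = [right(ALL, kt)[KV_TILE_MODE] for kt in range(4)]

print("wrong slice loads KV tiles:", wrong_tiles)
print("right slice loads KV tiles:", right_tiles)
```
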
### WHY THIS IS SO INSIDIOUS

1. **No error, no warning.** The slice `tBgK[(None,None,0,0)]` silently sets mode 2 to 0.
2. **Single-tile (n=128) works perfectly.** With only 1 KV tile, mode 2 is size 1, so the bug is invisible.
3. **Multi-tile tests produce "reasonable" output.** The TMA loads from tile 0 every time, so you get a valid (but wrong) attention computation. Cosine similarity is 0.7-0.9, not NaN.
4. **The strides are all 0.** Printing `tBgK.layout.stride` shows all zeros for TMA tensors. You can't detect the bug from strides alone.
5. **`cute.printf` shows `kv_coord=0`.** We thought the JIT was constant-folding the variable. It wasn't: the variable was fine, but it was indexing the wrong mode.
6. **The 8-mode theory was wrong.** We assumed tma_partition produced 8 TMA coordinate dimensions. It produces 4. The 8-None no-op slice fails with "weakly congruent" at JIT compile.

### THE LESSON

**PRINT THE SHAPES. ALWAYS.** Run `print(f"tBgK: shape={cute.shape(tBgK)}")` inside `@cute.kernel` at trace time. The shapes are your ground truth. Reasoning about mode counts without evidence is how we wasted a day.

**The correct pre-slice depends on which mode is the GMEM tile iteration axis.** For our `local_tile` + `partition_B` + `group_modes(0,3)` pattern, mode 2 is the KV tile axis. `(None,0,None,0)` keeps it free. `(None,None,0,0)` collapses it to 0.

```python
# ALWAYS verify the shape at trace time:
print(f"tBgK shape: {cute.shape(tBgK)}")  # 4 modes
print(f"tBgK after slice: {cute.shape(tBgK[(None,0,None,0)])}")  # 2 modes

# Then index the 2D tensor:
cute.copy(tma_k, tBgK[None, kt], ...)
```

**IF YOU USE (None,None,0,0) INSTEAD OF (None,0,None,0), MULTI-TILE TMA WILL BE SILENTLY BROKEN.**

---

## Architecture

DSV4 is **not MLA**. It uses **CSA (Compressed Sparse Attention, m=4)** and **HCA (Heavily Compressed Attention, m′=128)**. The KV latent is (T, 512), shared across all 128 heads. Sink weights merge the sparse and SWA attention branches. vLLM labels this "MLA", but the architecture is fundamentally different.

```
DSV4 inference pipeline: component status
=========================================

Legend:
  [✓] built and tested
  [~] partial: reference or seam exists, native pending
  [✗] to build


                  ┌────────────────────────────────────┐
                  │ [✗] Embedding + mHC init           │
                  │     token embed + n_hc=4 streams   │
                  └────────────────┬───────────────────┘
                                   │
                                   ▼
┌─ Transformer layer × L ──────────────────────────────────────────────┐
│  HCA on layers 0–1 of Pro, alternating CSA / HCA after               │
│                                                                      │
│  ┌─ Attention sub-block ──────────────────────────────────────────┐  │
│  │ [✓] Residual            mHC pre + post mix                     │  │
│  │ [~] Norms + RoPE        RMSNorm + partial RoPE                 │  │
│  │ [✓] Q / KV projection   NVFP4 linears + LoRA                   │  │
│  │ [✓] Token compressor    CSA m=4 / HCA m′=128                   │  │
│  │ [✓] Indexer + top-k     CSA, FP32 dot + top-k                  │  │
│  │ [~] FMHA core           QK → online softmax → PV               │  │
│  │                         + SWA branch + sink merge              │  │
│  │ [✓] Output projection   inv RoPE + wo_a grouped + wo_b         │  │
│  └────────────────────────────────────────────────────────────────┘  │
│                                                                      │
│  ┌─ FFN sub-block ────────────────────────────────────────────────┐  │
│  │ [✓] Residual            mHC pre + post mix                     │  │
│  │ [✓] Pre-FFN norm        RMSNorm                                │  │
│  │ [✓] Router              sqrt(softplus) + topk + hash           │  │
│  │ [✓] Routed MoE          fused SwiGLU L1 + L2                   │  │
│  │ [✓] Shared expert       NVFP4 single-group GEMM                │  │
│  └────────────────────────────────────────────────────────────────┘  │
└──────────────────────────────────┬───────────────────────────────────┘
                                   │
                                   ▼
┌──────────────────────────────────────────────────────────────────────┐
│  [✗] Final RMSNorm → [✗] LM head → [✗] MTP (depth=1) → [✗] Sampler   │
└──────────────────────────────────────────────────────────────────────┘

┌─ Supporting infrastructure ──────────────────────────────────────────┐
│  [✗] KV cache management                                             │
│      • state cache: SWA window + uncompressed tail per layer         │
│      • classical paged cache: lcm(m, m′) = 128 tokens per block      │
│      • heterogeneous layout per layer                                │
└──────────────────────────────────────────────────────────────────────┘


Summary
-------
Built    [✓] : 10  (mHC ×2, Q/KV proj, output proj, routed MoE,
                    shared expert, token compressor, indexer+topk,
                    router, pre-FFN norm)
Partial  [~] :  2  (norms+RoPE, FMHA core)
To build [✗] :  6  (embedding+init, final norm, LM head, MTP, sampler, KV cache)
```

---

## Status (May 23, 2026, 05:30 UTC)

| Stage | Status | Description |
|-------|--------|-------------|
| A | ✅ COMPLETE | Q@K^T via tcgen05.mma → TMEM → GMEM |
| B | ✅ COMPLETE | QK → identity softmax → P@V pipeline (TMEM alias, KV-tile interleaving) |
| C | ✅ MIGRATED TO MODULE | Real online softmax + normalize. n=128 cos 0.973. Migrated to `dsv4/kernels/attention/fmha.py` as `FmhaKernel`. TMEM layout mismatch still present (3% error). |
| D1 | 🟡 MOSTLY DONE | Parameterized HEAD_DIM. TMEM-P hd=64 works (cos 0.973). SMEM-P for hd>64 is a stub (make_tiled_copy_C rank mismatch). tOrP0 TMEM column offset bug fixed. |
| D2 | TODO | Multi-query grid with head packing (128 Q heads, MQA) |
| D3 | TODO | SWA sequence length mask (swa_lens per batch) |
| D4 | TODO | Causal mask on SWA branch only |
| D5 | 🟢 D5a+D5b DONE | D5a: normalize flag + LSE output (err=0.0). D5b: Python SWA+sink merge (cos 0.961). D5c/D5d: fused kernel merge TODO. |
| E1-E7 | TODO | Production extraction (class, custom op, cache, cleanup) |

---

## Package Structure

```
dsv4/
├── kernels/                Pure GPU code (CuTeDSL @cute.jit, .cu files)
│   ├── gemm/               NVFP4 MoE GEMM kernels (grouped, fused_swiglu, dense, scheduler)
│   ├── attention/          FMHA kernel: FmhaKernel (hd=64, TMEM-P proven; SMEM-P stub for hd>64)
│   ├── compressor/         CSA/HCA token-level compressor (CuTeDSL, 419 lines)
│   ├── indexer/            CSA indexer: score+topk (FP32 dot products, top-k selection)
│   ├── router/             Dense router decode kernel (warp-specialized persistent GEMM)
│   ├── cache/              Cache kernels: append_swa (write KV to split state cache layout)
│   ├── decode/             Decode-time attention (sparse, SWA; future)
│   └── cuda/               Raw .cu files (deinterleave_quantize, sparse_topk_metadata)
├── ops/                    PyTorch ↔ kernel bridges
│   ├── quantize.py         BF16 ↔ NVFP4 conversion, scale factors
│   ├── layouts.py          Scale swizzle, gate/up interleave, K-major, offsets
│   ├── gemm_runner.py      Warmup, compile, run grouped/fused GEMMs
│   ├── custom_ops.py       torch.library.custom_op registrations
│   ├── decode_sparse.py    native_sparse_decode dispatcher
│   ├── decode_swa.py       native_swa_decode dispatcher
│   ├── rope.py             Forward + inverse RoPE
│   ├── topk.py             Python wrapper for sparse_topk_metadata.cu
│   ├── topk_select.py      Top-k selection wrapper
│   └── router.py           Router op bridge
├── layers/                 nn.Module-style components
│   ├── linear.py           Nvfp4Linear
│   ├── grouped_linear.py   Nvfp4GroupedLinear
│   ├── moe.py              Nvfp4MoE
│   ├── shared_expert.py    Nvfp4SharedExpert
│   ├── mhc.py              mHCLayer
│   ├── attention.py        DSV4 attention sub-block (CSA/HCA/SWA variants, 245 lines)
│   ├── norm.py             RMSNorm (PyTorch ref, fused kernel later)
│   ├── router.py           Router: token-to-expert assignment (273 lines)
│   ├── embedding.py        Token embedding + mHC init (stub)
│   └── ffn.py              FFN sub-block
├── model/                  Model assembly
│   ├── config.py           Model config
│   ├── layer.py            Transformer layer
│   ├── layer_schedule.py   Layer scheduling
│   ├── mtp.py              Multi-token prediction
│   ├── sampler.py          Token sampler
│   └── dsv4.py             Full model (stub, Phase 1)
├── cache/                  KV cache infra
│   ├── allocator.py        Cache memory allocator
│   ├── block_table.py      Paged cache block table
│   ├── flush.py            Cache flush
│   ├── handle.py           Cache handle
│   ├── manager.py          Cache manager
│   ├── paged_cache.py      Paged KV cache
│   ├── prepare_forward.py  Forward prep
│   ├── schema.py           Cache schema
│   └── state_cache.py      State cache (SWA ring buffer)
├── loader/                 Checkpoint I/O
│   ├── hf_checkpoint.py    HuggingFace checkpoint loader
│   └── layout_convert.py   Weight layout conversion
└── reference/              Slow PyTorch oracles (never imported by production code)
    ├── attention.py        RoPE, KV cache, causal attention, SWA
    ├── csa_attention.py    CSA/HCA sparse attention
    ├── compressor.py       Compressor PyTorch example
    └── moe_pipeline.py     MoE pipeline reference
```

**Mental model:** `kernels/` → `ops/` → `layers/` → `model/` (dependency flows left to right). `reference/` and `loader/` are sidecars.

---

## Active Test Files

### FMHA (Stages A/B/C/D1), in `tests/unit/`

| File | Stage | Status |
|------|-------|--------|
| `test_fmha_v3.py` | A+B | ✅ Full QK→identity softmax→PV, cosine 0.999999 |
| `test_fmha_v3_12w.py` | A+B | ✅ 12-warp QK→PV, cosine 0.999999 |
| `test_fmha_v3_stage_c.py` | C | ✅ Real online softmax + normalize, n=128 cos 0.973. **Also in module as `FmhaKernel`.** |
| `test_fmha_v3_stage_d1.py` | D1 | 🟡 Parameterized hd, hd=64 PASS (cos 0.973), hd>64 FAIL (SMEM-P stub) |
| `test_fmha_v3_stage_d5b.py` | D5b | ✅ Python SWA+sink merge (cos 0.961, LSE err=0.0) |
| `test_d1_*.py` | D1 | 🔨 Debug/diagnostic variants (hd512, regression, sweep, raw, debug) |
| `test_paired_epilog.py` | C | ✅ Paired atom epilogue experiments |
| `test_pv64_with_softmax.py` | B | ✅ (128,64) PV, single AB pipeline |
| `test_128_128_vdiag.py` | A+B | ✅ (128,128) PV baseline |
| `test_qkonly.py` | A | ✅ QK with split Q/KV pipelines |
| `test_qk_softmax.py` | A+B | ✅ QK + identity softmax, no PV |

### MoE / GEMM, in `tests/unit/`

| File | What |
|------|------|
| `test_cutedsl.py` | NVFP4 grouped GEMM kernel |
| `cudagraph_test.py` | Cudagraph capture + replay |
| `layertest.py` | Per-layer correctness |
| `test_custom_op.py` | torch.library custom ops |
| `test_compile_custom_op.py` | Compile + warmup |
| `test_fp4_roundtrip.py` | BF16 → NVFP4 → BF16 roundtrip |
| `test_interleave.py` | Gate/up weight interleaving |
| `test_interleave_gemm.py` | Interleaved GEMM correctness |
| `test_fused_step1.py` | Fused SwiGLU GEMM |

---

## Test Harness

Scripts in `tests/` for running tests on the B200 (`root@45.76.247.107`):

### `run_test.sh`: Run a test in a screen session

```bash
# On the B200:
cd /root/dsv4-nvfp4-workspace/kernel
bash tests/run_test.sh tests/unit/test_fmha_v3.py
```

What it does:
1. Kills any existing `kernel-test` screen and **SIGKILLs all child processes** (handles deadlocked GPU procs that ignore SIGHUP)
2. Deletes the old log file
3. Starts a new `screen -dmS kernel-test` running the test
4. Logs output to `/tmp/kernel-test.log`
5. Verifies the screen started

### `check_log.sh`: Check test progress

```bash
bash tests/check_log.sh
```

Shows the log contents and whether the screen is still running.

### Local → B200 workflow

```bash
# 1. Edit locally, commit, push
cd ~/dev/nvfp4-megamoe-kernel
git add -A && git commit -m "my change" && git push

# 2. SSH to B200, pull, run
ssh root@45.76.247.107
cd /root/dsv4-nvfp4-workspace/kernel && git pull
bash tests/run_test.sh tests/unit/test_fmha_v3_stage_c_full.py

# 3. Check results
bash tests/check_log.sh
```

### `fire_b200_test`: One-command local test runner

Lives in `~/.openclaw/workspace/fire_b200_test` (NOT in the repo; it is project-specific tooling).

```bash
# From your local machine, one command to push, run, and get results:
~/.openclaw/workspace/fire_b200_test tests/unit/test_fmha_v3.py
```

What it does:
1. Auto-commits and pushes any local changes
2. SSHes to the B200, pulls, and starts `run_test.sh` in a screen
3. Polls every 15s until the screen exits
4. Dumps the full test log to your terminal

**This is strictly for the DSV4 NVFP4 kernel project.** It hardcodes the B200 IP, repo paths, and git remote.

---

## Stage C: Online Softmax, TMEM Layout Mismatch Issue

### Current Results (`test_fmha_v3_stage_c.py`)

| n | cos | Status |
|---|-----|--------|
| 128 | 0.973 | ⚠️ 3% error from TMEM layout mismatch |
| 256 | 0.793 | ⚠️ Two TMEM round-trips compound the error |
| 384+ | N/A | Pipeline doesn't cycle past 2 KV tiles |

### Root Cause: TMEM Layout Mismatch

The MMA instruction writes O to TMEM using the **C-fragment layout**. The `epilogue_tma_store` helper reads O from TMEM using `get_tmem_load_op`, which uses the **correct** C-fragment-compatible layout. **Raw PV output is perfect (cos 0.999998)** when `epilogue_tma_store` reads directly without any round-trip.

The problem appears when we do a **TMEM round-trip** (load O → modify → store back) using hand-constructed `Ld32x32bOp`/`St32x32bOp` atoms. These atoms use a different column mapping than the MMA's C-fragment layout, causing ~3% data corruption per round-trip. Both the NO-OP round-trip (previously used to "fix" layout) and the normalize round-trip (multiply by 1/row_sum) suffer from this error.

**Fix proven but not yet integrated:** The `epilogue_tmem_copy_and_partition` + `epilogue_smem_copy_and_partition` pattern from CUTLASS's `cutlass.utils.gemm.sm100` reads O from TMEM using the correct `get_tmem_load_op` layout and writes to SMEM using `get_smem_store_op`. This is a one-way trip (TMEM→reg→SMEM→GMEM) that eliminates the layout mismatch entirely. Integration requires proper `flat_divide` and `tma_partition` handling inside the kernel's warp-specific `if` blocks.

### Key Bug Fix: tOrP0 TMEM Column Offset (May 23)

The softmax warps store P at `tmem_p0_offset=32` FP32 columns (64 BF16 elements). PV MMA must read from the same offset. **`tOrP0` was missing this offset**, causing PV to read from TMEM column 0 (where S is) instead of column 32 (where P is). This was the root cause of NaN/zeros in D1 tests. Fixed with:
```python
if const_expr(self.tOrP0_offset > 0):
    tOrP0 = cute.make_tensor(tOrP.iterator + self.tOrP0_offset, tOrP.layout)
else:
    tOrP0 = tOrP
```
This must be a `const_expr` conditional (not a plain Python `if`): CuTeDSL compiles both branches, and `tOrP.iterator + 0` fails with an MLIR type error.

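
The effect of the missing offset can be pictured with a toy model of the TMEM columns. This is an illustrative plain-Python sketch, not kernel code: the column ranges follow the TMEM layout given in this README (S in columns 0-31, P starting at column 32), and `read_operand` is a hypothetical stand-in for the PV MMA reading its P operand.

```python
TMEM_P0_OFFSET = 32  # FP32 columns; P starts here (value taken from the text)

# Toy TMEM: label each FP32 column with the tensor that lives there.
tmem = ["S"] * 32 + ["P"] * 64


def read_operand(base, offset, num_cols):
    """Return the labels of the columns an operand view would read."""
    start = base + offset
    return tmem[start:start + num_cols]


# Bug: tOrP0 built without the offset, so PV reads the S columns.
buggy = read_operand(base=0, offset=0, num_cols=32)
# Fix: tOrP0 is tOrP shifted by tmem_p0_offset, so PV reads the P columns.
fixed = read_operand(base=0, offset=TMEM_P0_OFFSET, num_cols=32)

print("without offset:", sorted(set(buggy)))
print("with offset:   ", sorted(set(fixed)))
```
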
### Architecture (6-warp, current)

```
Warps 0-3:  Softmax + Epilogue (row_max, row_sum, P store, O rescale, final normalize)
Warp 4:     MMA (QK, PV)
Warp 5:     TMA (Q/K/V load)
```

### TMEM Layout

```
Col 0-31:   S (QK acc, 128 FP32 via Ld32x32bOp Repetition(32))
Col 32-95:  P (64 FP32 via St32x32bOp Repetition(32), register bridge BF16 view)
Col 128+:   O (PV acc, 64 FP32, rescale via Ld32x32bOp Repetition(16))
```

### Remaining for Multi-Tile Production

1. **Fix TMEM layout mismatch**: replace hand-constructed atom round-trips with the correction_epilog pattern
2. **Pipeline state cycling for n≥384**: kv_stage=2 can only buffer 2 tiles
3. **12-warp layout**: separate softmax/correction/epilogue warps
4. **O rescale for kt > 0**: must also use paired atoms or correction_epilog

---

## CuTeDSL Constraints (hard-won)

1. **`vectorize=True` loops: ONLY load/store/print.** No fmax, no cmpf, no inner loops, no carry.
2. **`.reduce(cute.ReductionOp.MAX)`:** reduces the ENTIRE C-fragment to a scalar. It is a global max, not per-row.
3. **`cute.arch.fmax`:** impure for the vectorizer. Use a plain `range()` loop.
4. **TMA partition tensors have 4 modes:** `(((64,128),1), ?, KV_tiles, ?)`. `(None,0,None,0)` keeps mode 2 (KV tiles) free, and `[None, kt]` indexes it.
5. **`tBgK[(None, None, 0, 0)]` pins mode 2 to 0** and silently reads tile 0 forever. Use `(None,0,None,0)` instead.
6. **`softmax_done_bar` NamedBarrier is reusable** across tiles.
7. **Hand-constructed TMEM atoms corrupt data on round-trip:** `Ld32x32bOp` + `St32x32bOp` built independently introduce ~3% error. Use `get_tmem_load_op` + `get_smem_store_op` paired atoms for one-way trips.
8. **CuTeDSL region isolation:** `flat_divide` and `tma_partition` can't be called inside `if warp_idx` blocks. Do partitioning outside `if` blocks or in regular (non-`@cute.kernel`) helper functions.
9. **`composition` vs `logical_divide`:** Both re-tile a tensor, but produce different layouts. The CUTLASS `correction_rescale` uses `composition`, `correction_epilog` uses `logical_divide`. The copy atoms must match the tensor layout they were created with.
10. **Variables in CuTeDSL `if` blocks are NOT visible in other `if` blocks.** Even when the condition is a compile-time constant (`self.use_smem_p`), CuTeDSL's MLIR lowering creates separate regions. Variables must be defined *unconditionally* before the first `if` that uses them. This applies across `if warp_idx == X` blocks, `for` loops, and nested branches. If a variable is set in `if not use_smem_p:` and read in another `if not use_smem_p:` inside a `for` loop inside an `if warp_idx < mma_warp_id:`, it won't be visible. Define all such variables before *any* branching.
11. **`tOrP0` MUST include the `tmem_p0_offset` column offset.** The softmax warps store P at `tmem_p0_offset=32` (FP32 columns = 64 BF16 elements). PV MMA must read from the same offset. Missing this causes NaN/zeros (MMA reads S from column 0, not P from column 32). Use a `const_expr` conditional: `if const_expr(self.tOrP0_offset > 0): tOrP0 = cute.make_tensor(tOrP.iterator + self.tOrP0_offset, tOrP.layout) else: tOrP0 = tOrP`. `tOrP.iterator + 0` cannot be used (MLIR OpResult + int fails).
12. **LSE formula: `lse = ln(row_sum) + row_max * ln(2)`.** `row_max` is in the scale_log2 domain (`max(S * scale * log2(e))`), so it is multiplied by `ln(2)` to convert it to the natural-log domain: `attn_max = row_max * ln(2)`. Verified: LSE err=0.000000.

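
Constraint 12 can be checked numerically on a single row. A minimal sketch in plain Python, assuming the kernel keeps `row_max` and `row_sum` in the base-2 domain as described (`row_max = max(S * scale * log2(e))`, `row_sum = sum(2^(s_log2 - row_max))`); the function names and sample values are illustrative only.

```python
import math

LOG2_E = math.log2(math.e)
LN_2 = math.log(2.0)


def lse_from_log2_stats(scores, scale):
    """LSE of scale*scores, computed the way a base-2 online softmax would."""
    s_log2 = [s * scale * LOG2_E for s in scores]  # scale_log2 domain
    row_max = max(s_log2)
    row_sum = sum(2.0 ** (x - row_max) for x in s_log2)
    # ln(row_sum) is already natural log; row_max needs the ln(2) factor.
    return math.log(row_sum) + row_max * LN_2


def lse_reference(scores, scale):
    """Direct natural-log definition: ln(sum(exp(scale * s)))."""
    return math.log(sum(math.exp(s * scale) for s in scores))


scores = [0.5, -1.25, 3.0, 2.0]
scale = 0.125
kernel_style = lse_from_log2_stats(scores, scale)
reference = lse_reference(scores, scale)
print("abs error:", abs(kernel_style - reference))
```
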

---

## Key Lessons

1. **NEVER use `find_tmem_tensor_col_offset()` as TMEM placement.** It returns footprint size, not a safe offset.
2. **FMHA never trusts DLPack tensor layouts.** Reconstruct V as (hd, s_k) MN-major inside CuTe.
3. **TMEM allocation must be power of 2.**
4. **Square hides bugs.** (128,128) worked for every wrong approach. Always test non-square.
5. **St32x32bOp MUST use Float32**, NOT BFloat16. BFloat16 causes illegal memory access.
6. **First PV ACCUMULATE=False.** Otherwise adds uninitialized TMEM to output.
7. **FMHA P store uses QK C-fragment composition, NOT PV A-fragment.** Two aliases, same TMEM.
8. **Register bridge: FP32 backing (store partition) + BF16 view (QK-load layout).** Do not skip this.
9. **PRINT THE SHAPES. ALWAYS.** Reasoning about TMEM layouts without evidence is how we waste days.
10. **Never assume TMEM round-trips are safe.** Verify with NO-OP tests before adding logic.

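
Lessons 1 and 3 suggest computing the TMEM allocation explicitly instead of trusting a helper's return value. A minimal sketch, assuming a 512-column TMEM budget (the figure used in Stage D1) and a minimum allocation of 32 columns (an assumption, not stated in this README); `tmem_alloc_cols` is a hypothetical helper, and the column counts in the usage line come from the TMEM layout above (O starts at column 128 and holds 64 FP32 columns).

```python
TMEM_MAX_COLS = 512  # TMEM column budget referenced in Stage D1
TMEM_MIN_COLS = 32   # assumed minimum allocation size


def tmem_alloc_cols(required_cols):
    """Round a column requirement up to a power of two within the budget."""
    if required_cols <= 0:
        raise ValueError("required_cols must be positive")
    cols = TMEM_MIN_COLS
    while cols < required_cols:
        cols *= 2
    if cols > TMEM_MAX_COLS:
        raise ValueError(
            f"{required_cols} columns do not fit in {TMEM_MAX_COLS}"
        )
    return cols


# O starts at column 128 and needs 64 FP32 columns: 192 columns in use.
print(tmem_alloc_cols(128 + 64))
```
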

---

## Stage D: Full Decode Attention (revised May 23)

### Key Insight: The Indexer Solves Paging Upstream

The indexer now hands the kernel `selected_kv: [T, top_k, head_dim] BF16`, a **dense, materialized, dequantized** K/V tile. FMHA sees a dense `[T, top_k, 512]` tile, exactly like Stage A/B's existing `k` and `v` inputs. **The kernel doesn't need to know it's sparse.** Paged TMA, scattered HBM reads, and FP8 dequantization are all handled by `gather_selected_kv` upstream.

The SWA branch is the only "irregular" thing: it reads from the state cache's ring buffer with a position mask. SWA is small (`n_win=128` per query), so it's a separate fused branch with a sink-weighted merge.

**One FMHA kernel serves all three DSV4 attention types:**
- **CSA:** `compressed_kv` = top-k from indexer, `swa_kv` from cache → sink merge
- **HCA:** `compressed_kv` = all classical pool entries (gather-all mode), `swa_kv` from cache → sink merge
- **SWA-only (Flash layers 0-1):** `compressed_kv` = empty (`top_k=0`), only SWA runs. Sink merge degenerates to just `o_swa` after renormalization.

### Build Order

**D1: Parameterize HEAD_DIM + SMEM-P** (~1 day, in progress)

Currently hardcoded at 64. Promote to constructor arg, thread through `_setup`. Test at 64, then 512 (DSV4's real value).

**Two P staging paths:**
- **TMEM-P** (hd≤64): P stored to TMEM via register bridge. PV reads from TMEM. Proven at cos 0.973.
- **SMEM-P** (hd>64): P stored to SMEM via PV A-operand layout. PV reads from SMEM. Avoids QK↔PV TMEM layout mismatch at large hd. **Register→SMEM copy needs `make_tiled_copy_C(store_atom, qk_mma)` to partition threads by QK C-fragment.** The SMEM rendezvous pattern: softmax writes P to SMEM at logical (row, col) addresses using the `p_smem_s` layout, the MMA warp reads from the same SMEM, with a barrier in between.

Risk at HEAD_DIM=512: TMEM column budget. `_setup` already does `find_tmem_tensor_col_offset(tOtO)` dynamically. Verify the total fits in 512 TMEM columns. If not, reduce `kv_stage` from 2 to 1 (lose K/V double-buffering) before sacrificing math.

Done when: identical result at HEAD_DIM=64 (regression), passes at HEAD_DIM=512 against FP32 oracle.

**D2: Multi-query grid with head packing** (~1 day)

Grid changes from `(1, 1, 1)` to `(num_q_blocks, 1, batch)`. DSV4 is MQA: all `n_h=128` query heads share the same K/V. The query-head axis is folded into the M dimension of the Q tile: `M_tile = 128` covers `M = T * n_h` rows. At decode T is small (1-16), so packing heads into M fills the MMA. At prefill T=64, M is already 8192 with heads packed.

Done when: batch=4, T=64, n_h=128, num_kv_heads=1 produces correct attention against FP32 oracle.

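
The tile arithmetic behind head packing, as a minimal sketch. It assumes the packed M dimension is simply `T * n_h` rows split into `M_tile = 128` row blocks, with `T` counted per batch entry; `num_q_blocks` and `grid` are hypothetical helper names, not part of the kernel.

```python
import math

M_TILE = 128  # rows per Q tile


def num_q_blocks(num_tokens, num_q_heads, m_tile=M_TILE):
    """Number of Q tiles when query heads are folded into the M dimension."""
    packed_rows = num_tokens * num_q_heads
    return math.ceil(packed_rows / m_tile)


def grid(num_tokens, num_q_heads, batch):
    """Launch grid in the (num_q_blocks, 1, batch) form described above."""
    return (num_q_blocks(num_tokens, num_q_heads), 1, batch)


decode_grid = grid(num_tokens=1, num_q_heads=128, batch=4)    # decode, T=1
prefill_grid = grid(num_tokens=64, num_q_heads=128, batch=4)  # prefill, T=64
print("decode grid: ", decode_grid)
print("prefill grid:", prefill_grid)
```
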

**D3: SWA sequence length mask** (~½ day)

The indexer's `top_k` is fixed (512 for Flash, 1024 for Pro). Compressed-K input is always `[T, top_k, head_dim]` with the same `top_k` at compile time.

What varies: the SWA window holds up to `n_win=128` tokens but starts with fewer. Add `swa_lens: [batch] int32` as kernel input. Mask SWA-branch logits to `-inf` where `swa_idx >= swa_lens[b]`.

Done when: batched input with varying SWA fill levels (some requests at position 50, some at 5000) produces correct masked output.

**D4: Causal mask on SWA branch** (~½ day)

The compressed K the indexer selects is already from `s < floor(t/m)` (paper eq. 17). The indexer enforces causality at selection time. FMHA sees only causally-valid candidates. **The main path has no mask.**

The SWA branch needs a causal mask within the window. Add `is_causal: bool` constructor flag, apply `swa_idx > q_pos` masking to `-inf` in the SWA pass.

Done when: prefill mode produces correct output with the causal mask applied to SWA.

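
Both SWA masks (the D3 length mask and the D4 causal mask) reduce to setting selected logits to `-inf` before softmax. A minimal plain-Python sketch for one query row; it assumes `key_pos[i]` holds the absolute position of window slot `i`, which this README does not specify, and `mask_swa_logits` is a hypothetical name.

```python
NEG_INF = float("-inf")


def mask_swa_logits(logits, swa_len, key_pos=None, q_pos=None, is_causal=False):
    """Mask one row of SWA logits.

    D3: slots at index >= swa_len are not filled yet and are masked.
    D4: with is_causal, slots whose position is after the query are masked.
    """
    masked = []
    for idx, logit in enumerate(logits):
        unfilled = idx >= swa_len
        # key_pos is only consulted when the causal mask is enabled.
        future = is_causal and key_pos[idx] > q_pos
        masked.append(NEG_INF if (unfilled or future) else logit)
    return masked


logits = [0.1, 0.2, 0.3, 0.4]

# Window holds 3 valid tokens; the last slot is still empty.
length_only = mask_swa_logits(logits, swa_len=3)

# Same window with slot positions 10..13 and the query at position 11.
length_and_causal = mask_swa_logits(
    logits, swa_len=3, key_pos=[10, 11, 12, 13], q_pos=11, is_causal=True
)
print(length_only)
print(length_and_causal)
```
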

**D5: SWA + sink merge** (~2-3 days) ← D5a+D5b DONE (May 23), D5c/D5d remaining

Per `dsv4/ops/decode_sparse.py`:
```
o = (exp(lse_sparse) * o_sparse + exp(attn_sink) * exp(lse_swa) * o_swa)
    / (exp(lse_sparse) + exp(attn_sink) * exp(lse_swa))
```

With un-normalized O (D5a): `o_unnorm = o_norm * exp(lse)`, so:
```
o = (o_unnorm_sparse + exp(attn_sink) * o_unnorm_swa)
    / (exp(lse_sparse) + exp(attn_sink) * exp(lse_swa))
```

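
The two forms above are algebraically identical, which can be checked on scalars. A minimal sketch in plain Python with made-up values for a single output element; `merge_normalized` and `merge_unnormalized` are illustrative names, not functions from `dsv4/ops/decode_sparse.py`.

```python
import math


def merge_normalized(o_sparse, lse_sparse, o_swa, lse_swa, attn_sink):
    """Sink merge from normalized branch outputs and their LSEs."""
    w_sparse = math.exp(lse_sparse)
    w_swa = math.exp(attn_sink) * math.exp(lse_swa)
    return (w_sparse * o_sparse + w_swa * o_swa) / (w_sparse + w_swa)


def merge_unnormalized(o_un_sparse, lse_sparse, o_un_swa, lse_swa, attn_sink):
    """Same merge when each branch emits o_unnorm = o_norm * exp(lse)."""
    num = o_un_sparse + math.exp(attn_sink) * o_un_swa
    den = math.exp(lse_sparse) + math.exp(attn_sink) * math.exp(lse_swa)
    return num / den


# Made-up per-branch results for one output element.
o_sparse, lse_sparse = 0.75, 1.5
o_swa, lse_swa = -0.25, 0.5
attn_sink = -0.3

merged_norm = merge_normalized(o_sparse, lse_sparse, o_swa, lse_swa, attn_sink)
merged_unnorm = merge_unnormalized(
    o_sparse * math.exp(lse_sparse), lse_sparse,
    o_swa * math.exp(lse_swa), lse_swa,
    attn_sink,
)
print(merged_norm, merged_unnorm)
```
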

**D5a DONE (May 23):** `normalize` flag added to FmhaKernel. When False, emits un-normalized O + LSE. LSE formula: `lse = ln(row_sum) + row_max * ln(2)` (row_max in scale_log2 domain, multiply by ln(2) to convert). LSE err=0.000000 verified.

**D5b DONE (May 23):** Python SWA+sink merge works end-to-end at hd=64. Run FMHA twice (compressed KV + SWA KV, normalize=False), merge in Python. Merge cos 0.961, individual attention cos 0.963/0.960.

Sub-steps remaining:
- **5c:** Fuse the two passes into one kernel launch. Q stays in SMEM, two MMA loops sequentially.
- **5d:** Fuse the merge into the kernel epilogue.

Done when: end-to-end kernel produces correct attention against FP32 oracle that does sparse+SWA+sink merge.

**~~D5 (old) paged TMA~~: REMOVED.** The indexer + gather handles all paging upstream.

### Kernel Architecture (after D5)

```
Input: Q [T, n_h, 512], compressed_kv [T, top_k, 512], swa_kv [batch, n_win, 512]
       swa_lens [batch], sink_logits [n_h], request_ids [T]
  │
  ├─ Load Q to SMEM (once)
  │
  ├─ Loop 1: compressed KV (top_k tokens)
  │    QK → online softmax → PV → O_sparse, lse_sparse in TMEM
  │
  ├─ Loop 2: SWA window (n_win tokens, masked by swa_lens)
  │    QK → online softmax → PV → O_swa, lse_swa in TMEM
  │
  └─ Sink merge epilogue:
       O = (exp(lse_sparse) * O_sparse + exp(sink) * exp(lse_swa) * O_swa)
           / (exp(lse_sparse) + exp(sink) * exp(lse_swa))
```

### Reference Files

- Sink merge spec: `dsv4/ops/decode_sparse.py` (formula)
- SWA decode: `dsv4/ops/decode_swa.py`
- Attention reference: `dsv4/reference/attention.py`
- CSA attention: `dsv4/reference/csa_attention.py`

### Stage C Note

When implementing D5a, Stage C's epilogue changes from "multiply by 1/row_sum" to "emit un-normalized o + lse". Defer this until D5. Through D1-D4, keep Stage C normalize as-is and test as standalone dense FMHA.

---

## Stage E: Production Extraction (revised May 23)

### E1: File placement

`dsv4/kernels/attention/fmha.py`. Currently contains `FmhaKernel` (migrated from test, hd=64 TMEM-P). Will gain parameterized `head_dim` and SMEM-P path in D1. Constructor takes all dimensions and dtypes, no module-level constants.

### E2: Constructor signature

```python
class FmhaKernel:
    def __init__(
        self,
        head_dim: int,          # 512 for DSV4
        num_query_heads: int,   # 128 for Pro, 64 for Flash
        sliding_window: int,    # 128
        top_k: int,             # 512 (Flash) or 1024 (Pro)
        q_dtype=BFloat16,
        kv_dtype=BFloat16,
        o_dtype=BFloat16,
        qk_acc_dtype=Float32,
        pv_acc_dtype=Float32,
        is_causal: bool = False,  # affects SWA mask only
        cta_group: tcgen05.CtaGroup = tcgen05.CtaGroup.ONE,
        cluster_shape_mn: tuple = (1, 1),
    ):
        ...
```

All architecture-level shapes from config flow into the constructor. No FMHA-internal magic numbers.

### E3: Call signature

```python
def __call__(
    self,
    q: torch.Tensor,              # [T, n_h, head_dim] BF16
    compressed_kv: torch.Tensor,  # [T, top_k, head_dim] BF16, from indexer gather
    swa_kv: torch.Tensor,         # [batch, n_win, head_dim] BF16, from cache prep
    swa_lens: torch.Tensor,       # [batch] int32
    sink_logits: torch.Tensor,    # [n_h] FP32
    request_ids: torch.Tensor,    # [T] int32, maps query to its SWA slot
    o: torch.Tensor,              # [T, n_h, head_dim] BF16, preallocated
    stream: cuda.CUstream,
):
    ...
```

Notably absent: block_table, paged KV, inv_scale, FP8 dequant. All handled upstream.

### E4: Kernel cache + warmup

Mirror `dsv4/ops/gemm_runner.py`'s `_compiled_kernel_cache`. Key on `(head_dim, num_query_heads, top_k, is_causal, ...)`. Pre-allocate at warmup, reuse at call. For DSV4, the cache has at most 4 entries (Flash/Pro × causal/non-causal).

### E5: torch.library custom op

```python
import torch

@torch.library.custom_op("dsv4::sparse_fmha_with_swa", mutates_args=("o",))
def sparse_fmha_with_swa_op(
    q: torch.Tensor,
    compressed_kv: torch.Tensor,
    swa_kv: torch.Tensor,
    swa_lens: torch.Tensor,
    sink_logits: torch.Tensor,
    request_ids: torch.Tensor,
    o: torch.Tensor,
    runner_id: int,
) -> None:
    runner = get_runner(runner_id)
    runner._run_impl(q, compressed_kv, swa_kv, swa_lens, sink_logits, request_ids, o)
```

Mutates `o` (preallocated buffer). Consistent with cudagraphs.

### E6: Reference parity hook

`dsv4/reference/attention.py` stays as the FP32 oracle. New test: `tests/unit/test_fmha_kernel.py`.

```python
import torch

from dsv4.kernels.attention.fmha import FmhaKernel
# attention_with_lse_f32 / attention_swa_with_lse_f32 are the FP32 oracles
# (expected to come from dsv4/reference/attention.py).


def test_sparse_fmha_matches_spec(T=64, n_h=128, top_k=1024, n_win=128, hd=512):
    # All kernel inputs live on the GPU; 4 requests share the SWA cache.
    q = torch.randn(T, n_h, hd, dtype=torch.bfloat16, device='cuda')
    ck = torch.randn(T, top_k, hd, dtype=torch.bfloat16, device='cuda')
    swa = torch.randn(4, n_win, hd, dtype=torch.bfloat16, device='cuda')
    swa_lens = torch.tensor([128, 50, 128, 75], dtype=torch.int32, device='cuda')
    sink = torch.randn(n_h, device='cuda')
    req_ids = torch.randint(0, 4, (T,), dtype=torch.int32, device='cuda')

    # Oracle: pure FP32 spec
    o_sparse, lse_sparse = attention_with_lse_f32(q, ck, ck)
    o_swa, lse_swa = attention_swa_with_lse_f32(q, swa, swa, swa_lens, req_ids)
    e_sink = sink.exp()
    num = lse_sparse.exp().unsqueeze(-1) * o_sparse \
        + e_sink[None, :, None] * lse_swa.exp().unsqueeze(-1) * o_swa
    den = lse_sparse.exp() + e_sink[None, :] * lse_swa.exp()
    expected = num / den.unsqueeze(-1)

    # Kernel
    o = torch.empty_like(expected, dtype=torch.bfloat16)
    fmha = FmhaKernel(head_dim=hd, num_query_heads=n_h, sliding_window=n_win, top_k=top_k)
    fmha(q, ck, swa, swa_lens, sink, req_ids, o, stream=...)

    torch.testing.assert_close(o.float(), expected, atol=5e-3, rtol=5e-3)
```

### E7: Cleanup

Delete all debug test files. `test_fmha_v3.py` becomes `dsv4/kernels/attention/fmha.py`. Only `tests/unit/test_fmha_kernel.py` remains as the attention test.

---

## Environment

- Server: root@45.76.247.107 (B200, 180 GiB HBM3e per GPU)
- venv: `source /root/dsv4-nvfp4-workspace/venv/bin/activate`
- PYTHONPATH: `/root/dsv4-nvfp4-workspace/kernel`
- Model: `/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4`
- vLLM repo: `/root/dsv4-nvfp4-workspace/vllm` (modified for Blackwell)
- CUTLASS FMHA reference: `/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py`
- Local CUTLASS clone: `/home/openclaw/dev/cutlass`