# DSV4 Inference Kernel

## Architecture

DSV4 is **not MLA**. It uses **CSA (Compressed Sparse Attention, m=4)** and **HCA (Heavily Compressed Attention, m′=128)**. The KV latent has shape (T, 512) and is shared across all 128 heads. Sink weights merge the sparse and SWA attention branches.

vLLM labels this "MLA", but it is not MLA. The architecture is fundamentally different.

```
DSV4 inference pipeline — component status
==========================================

Legend:  [✓] built and tested
         [~] partial — reference or seam exists, native pending
         [✗] to build

                 ┌────────────────────────────────────┐
                 │ [✗] Embedding + mHC init           │
                 │     token embed + n_hc=4 streams   │
                 └─────────────────┬──────────────────┘
                                   │
                                   ▼
┌─ Transformer layer × L ──────────────────────────────────────────────┐
│ HCA on layers 0–1 of Pro, alternating CSA / HCA after                │
│                                                                      │
│ ┌─ Attention sub-block ────────────────────────────────────────────┐ │
│ │ [✓] Residual mHC      pre + post mix                             │ │
│ │ [~] Norms + RoPE      RMSNorm + partial RoPE                     │ │
│ │ [✓] Q / KV projection NVFP4 linears + LoRA                       │ │
│ │ [~] Token compressor  CSA m=4 / HCA m′=128                       │ │
│ │ [✗] Indexer + top-k   CSA only, FP4 QK                           │ │
│ │ [~] FMHA core         QK → online softmax → PV                   │ │
│ │                       + SWA branch + sink merge                  │ │
│ │ [✓] Output projection inv RoPE + wo_a grouped + wo_b             │ │
│ └──────────────────────────────────────────────────────────────────┘ │
│                                                                      │
│ ┌─ FFN sub-block ──────────────────────────────────────────────────┐ │
│ │ [✓] Residual mHC      pre + post mix                             │ │
│ │ [~] Pre-FFN norm      RMSNorm                                    │ │
│ │ [✗] Router            sqrt(softplus) + topk + hash               │ │
│ │ [✓] Routed MoE        fused SwiGLU L1 + L2                       │ │
│ │ [✓] Shared expert     NVFP4 single-group GEMM                    │ │
│ └──────────────────────────────────────────────────────────────────┘ │
└──────────────────────────────────┬───────────────────────────────────┘
                                   │
                                   ▼
┌──────────────────────────────────────────────────────────────────────┐
│ [✗] Final RMSNorm → [✗] LM head → [✗] MTP (depth=1) → [✗] Sampler    │
└──────────────────────────────────────────────────────────────────────┘

┌─ Supporting infrastructure ──────────────────────────────────────────┐
│ [✗] KV cache management                                              │
│     • state cache: SWA window + uncompressed tail per layer          │
│     • classical paged cache: lcm(m, m′) = 128 tokens per block       │
│     • heterogeneous layout per layer                                 │
└──────────────────────────────────────────────────────────────────────┘

Summary
-------
Built    [✓] : 6 — mHC ×2, Q/KV proj, output proj, routed MoE, shared expert
Partial  [~] : 4 — norms+RoPE, token compressor, FMHA core, pre-FFN norm
To build [✗] : 8 — embedding+init, indexer+top-k, router, final norm, LM head, MTP, sampler, KV cache
```

---

## Status (May 22, 2026 — 16:30 UTC)

| Stage | Status | Description |
|-------|--------|-------------|
| A | ✅ COMPLETE | Q@K^T via tcgen05.mma → TMEM → GMEM |
| B | ✅ COMPLETE | QK → identity softmax → P@V pipeline (TMEM alias, KV-tile interleaving) |
| C | ⚠️ SINGLE-TILE ONLY | Real online softmax works for n=128 (cosine 0.993-0.996). **Multi-tile (n>128) broken.** |
| C' | 🔨 IN PROGRESS | Multi-tile TMA indexing fix + correction warps. See the Stage C section below. |
| D | TODO | Full decode attention: paged KV cache, multi-query, causal mask |
| E | TODO | Production kernel: extract into dsv4/kernels/attention/, PyTorch custom op, vLLM bridge |

---

## Package Structure

```
dsv4/
├── kernels/                Pure GPU code (CuTeDSL @cute.jit, .cu files)
│   ├── gemm/               NVFP4 MoE GEMM kernels (grouped, fused_swiglu, dense, scheduler)
│   ├── attention/          FMHA kernel (stub — extraction is Stage E)
│   ├── compressor/         CSA/HCA token-level compressor
│   ├── decode/             Decode-time attention (sparse, SWA — future)
│   └── cuda/               Raw .cu files (deinterleave_quantize, sparse_topk_metadata)
├── ops/                    PyTorch ↔ kernel bridges
│   ├── quantize.py         BF16 ↔ NVFP4 conversion, scale factors
│   ├── layouts.py          Scale swizzle, gate/up interleave, K-major, offsets
│   ├── gemm_runner.py      Warmup, compile, run grouped/fused GEMMs
│   ├── custom_ops.py       torch.library.custom_op registrations
│   ├── decode_sparse.py    native_sparse_decode dispatcher
│   ├── decode_swa.py       native_swa_decode dispatcher
│   ├── rope.py             Forward + inverse RoPE
│   └── topk.py             Python wrapper for sparse_topk_metadata.cu
├── layers/                 nn.Module-style components
│   ├── linear.py           Nvfp4Linear
│   ├── grouped_linear.py   Nvfp4GroupedLinear
│   ├── moe.py              Nvfp4MoE
│   ├── shared_expert.py    Nvfp4SharedExpert
│   ├── mhc.py              mHCLayer
│   └── (stubs: attention, ffn, router, norm, embedding)
├── model/                  Model assembly (stubs — Phase 1)
├── cache/                  KV cache infra (stubs — Phase 3)
├── loader/                 Checkpoint I/O (stubs — Phase 1)
└── reference/              Slow PyTorch oracles (never imported by production code)
    ├── attention.py        RoPE, KV cache, causal attention, SWA
    ├── csa_attention.py    CSA/HCA sparse attention
    ├── compressor.py       Compressor PyTorch example
    └── moe_pipeline.py     MoE pipeline reference
```

**Mental model:** `kernels/` → `ops/` → `layers/` → `model/` (each package builds on the one to its left). `reference/` and `loader/` are sidecars outside this chain; production code never imports `reference/`.
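The `cache/` package is still a stub. As a starting point, the block-size rule from the supporting-infrastructure box (lcm(m, m′) = 128 tokens per block) can be sketched in plain Python. All names here (`block_tokens`, `entries_per_block`, `compressed_len`, `tail_len`, `num_blocks`) are hypothetical, not repo code, and the sketch assumes that the "uncompressed tail" in the state cache holds the tokens that do not yet fill a complete compression group.

```python
from math import lcm

CSA_M = 4      # CSA compression ratio m (assumed: tokens per compressed KV entry)
HCA_M = 128    # HCA compression ratio m′ (assumed: tokens per compressed KV entry)


def block_tokens(*ratios: int) -> int:
    """Tokens per paged-cache block: the smallest size that every ratio divides."""
    return lcm(*ratios)


def entries_per_block(ratio: int, tokens: int) -> int:
    """Compressed KV entries one block holds for a layer with this ratio."""
    assert tokens % ratio == 0, "a block must hold whole compression groups"
    return tokens // ratio


def compressed_len(num_tokens: int, ratio: int) -> int:
    """Complete compression groups so far (entries that live in the paged cache)."""
    return num_tokens // ratio


def tail_len(num_tokens: int, ratio: int) -> int:
    """Tokens still waiting for a full group (assumed to sit in the state cache)."""
    return num_tokens % ratio


def num_blocks(num_tokens: int, tokens_per_block: int) -> int:
    """Paged-cache blocks needed to cover num_tokens, rounding up."""
    return -(-num_tokens // tokens_per_block)


if __name__ == "__main__":
    B = block_tokens(CSA_M, HCA_M)
    print(B, entries_per_block(CSA_M, B), entries_per_block(HCA_M, B))  # 128 32 1
    T = 1000
    print(compressed_len(T, CSA_M), tail_len(T, CSA_M))  # 250 0
    print(compressed_len(T, HCA_M), tail_len(T, HCA_M), num_blocks(T, B))  # 7 104 8
```

Because 4 divides 128, the LCM is just m′ here; taking the LCM only changes the block size if the two ratios ever stop dividing each other. With 128-token blocks, a CSA layer stores 32 compressed entries per block and an HCA layer stores 1.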
---

## Active Test Files

### FMHA (Stages A/B/C) — in `tests/unit/`

| File | Stage | Status |
|------|-------|--------|
| `test_fmha_v3.py` | A+B | ✅ Full QK→identity softmax→PV, cosine 0.999999 |
| `test_fmha_v3_12w.py` | A+B | ✅ 12-warp QK→PV, cosine 0.999999 (n=128 only) |
| `test_fmha_v3_stage_c_full.py` | C | ✅ Real online softmax + O normalization, cosine 0.993-0.996 (n=128 only) |
| `test_fmha_v3_stage_c_min.py` | C | 🔨 Early 12-warp pipeline (broken pipeline state) |
| `test_pv64_with_softmax.py` | B | ✅ (128,64) PV, single AB pipeline |
| `test_128_128_vdiag.py` | A+B | ✅ (128,128) PV baseline |
| `test_qkonly.py` | A | ✅ QK with split Q/KV pipelines |
| `test_qk_softmax.py` | A+B | ✅ QK + identity softmax, no PV |

### MoE / GEMM — in `tests/unit/`

| File | What |
|------|------|
| `test_cutedsl.py` | NVFP4 grouped GEMM kernel |
| `cudagraph_test.py` | CUDA graph capture + replay |
| `layertest.py` | Per-layer correctness |
| `test_custom_op.py` | torch.library custom ops |
| `test_compile_custom_op.py` | Compile + warmup |
| `test_fp4_roundtrip.py` | BF16 → NVFP4 → BF16 roundtrip |
| `test_interleave.py` | Gate/up weight interleaving |
| `test_interleave_gemm.py` | Interleaved GEMM correctness |
| `test_fused_step1.py` | Fused SwiGLU GEMM |

### Archived Tests

`tests/archive/` contains ~190 debug files from Stages A/B. They are not maintained and can be deleted.

---

## Test Harness

Scripts in `tests/` for running tests on the B200 (`root@45.76.247.107`):

### `run_test.sh` — Run a test in a screen session

```bash
# On the B200:
cd /root/dsv4-nvfp4-workspace/kernel
bash tests/run_test.sh tests/unit/test_fmha_v3.py
```

What it does:

1. Kills any existing `kernel-test` screen and **SIGKILLs all child processes** (handles deadlocked GPU processes that ignore SIGHUP)
2. Deletes the old log file
3. Starts a new `screen -dmS kernel-test` running the test
4. Logs output to `/tmp/kernel-test.log`
5. Verifies the screen started

### `check_log.sh` — Check test progress

```bash
bash tests/check_log.sh
```

Shows the log contents and whether the screen is still running.

### Local → B200 workflow

```bash
# 1. Edit locally, commit, push
cd ~/dev/nvfp4-megamoe-kernel
git add -A && git commit -m "my change" && git push

# 2. SSH to B200, pull, run
ssh root@45.76.247.107
cd /root/dsv4-nvfp4-workspace/kernel && git pull
bash tests/run_test.sh tests/unit/test_fmha_v3_stage_c_full.py

# 3. Check results
bash tests/check_log.sh
```

### `fire_b200_test` — One-command local test runner

Lives in `~/.openclaw/workspace/fire_b200_test`, NOT in the repo (project-specific tooling).

```bash
# From your local machine, one command to push, run, and get results:
~/.openclaw/workspace/fire_b200_test tests/unit/test_fmha_v3.py
```

What it does:

1. Auto-commits and pushes any local changes
2. SSHes to the B200, pulls, and starts `run_test.sh` in a screen
3. Polls every 15s until the screen exits
4. Dumps the full test log to your terminal

**This is strictly for the DSV4 NVFP4 kernel project.** It hardcodes the B200 IP, repo paths, and git remote.

---

## Stage C: Online Softmax — SINGLE-TILE ONLY

### What We Have

**Working real softmax** for a single KV tile (n=128) in `test_fmha_v3_stage_c_full.py`: cosine 0.993-0.996.

**Multi-tile (n>128) is broken**; see the blocker below.

### Multi-Tile Blocker: TMA GMEM Tile Indexing

The TMA partition slices `tBgK`/`tVgV` with `(None, 0, None, 0)`. The free mode left after slicing is the GMEM iteration dimension, and a `kv_coord` variable indexes it.
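For reference, the intended semantics of the KV tile loop can be modeled on the host. The NumPy sketch below is not the kernel and not repo code (`full_attention`, `tiled_attention`, and `stuck_coord` are hypothetical names). It runs online softmax over KV tiles, advancing `kv_coord` once per iteration and rescaling O and the running row sum by `exp(old_max - new_max)`, which is the job the planned correction warps take on. Setting `stuck_coord=True` reloads tile 0 every iteration, mimicking the symptom of the TMA never leaving tile 0. The running max is kept per row, not reduced over the whole fragment.

```python
import numpy as np


def full_attention(q, k, v):
    """Oracle: softmax(Q K^T / sqrt(d)) V over all keys at once."""
    s = q @ k.T / np.sqrt(q.shape[1])
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v


def tiled_attention(q, k, v, tile_n=128, stuck_coord=False):
    """Online softmax over KV tiles of tile_n keys.

    kv_coord plays the role of the TMA GMEM tile coordinate. With
    stuck_coord=True it never advances, so every iteration reloads tile 0.
    """
    assert k.shape[0] % tile_n == 0, "sketch assumes n is a multiple of tile_n"
    m_rows, d = q.shape
    row_max = np.full((m_rows, 1), -np.inf)  # running per-row max
    row_sum = np.zeros((m_rows, 1))          # running per-row sum of exp
    o = np.zeros((m_rows, v.shape[1]))       # unnormalized O accumulator
    kv_coord = 0
    for _ in range(k.shape[0] // tile_n):
        lo = kv_coord * tile_n
        k_tile, v_tile = k[lo:lo + tile_n], v[lo:lo + tile_n]
        s = q @ k_tile.T / np.sqrt(d)                        # QK for this tile
        new_max = np.maximum(row_max, s.max(axis=1, keepdims=True))
        scale = np.exp(row_max - new_max)                    # correction for old state
        p = np.exp(s - new_max)
        row_sum = row_sum * scale + p.sum(axis=1, keepdims=True)
        o = o * scale + p @ v_tile                           # rescale O, then add PV
        row_max = new_max
        if not stuck_coord:
            kv_coord += 1                                    # advance to the next KV tile
    return o / row_sum                                       # final O normalization


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    q, k, v = (rng.standard_normal(s) for s in [(128, 64), (512, 64), (512, 64)])
    ref = full_attention(q, k, v)
    print(np.allclose(tiled_attention(q, k, v), ref))                   # expected: True
    print(np.allclose(tiled_attention(q, k, v, stuck_coord=True), ref))  # expected: False
```

In this model a stuck coordinate returns attention over tile 0 alone (the duplicated tiles cancel in the normalization), and for n=128 both modes agree. That is consistent with single-tile tests passing while every multi-tile output comes back identical.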
**Problem: the `kv_coord` increment is not propagating to the TMA at runtime.**

**Evidence (May 22):**

- `kv_coord = Int32(0)` + `kv_coord += 1` in a `cutlass.range` loop → all multi-tile outputs identical (TMA loads from tile 0 every iteration)
- `kv_coord = 0` (plain Python int) + `kv_coord += 1` → same broken result
- `kv_coord = Int32(1)` hardcoded → output **changes** (TMA CAN load from tile 1; the coordinate just isn't being updated dynamically)
- Pipeline handle `.count` also doesn't work (it's opaque pipeline state, not a GMEM coordinate)

**Suspected root cause:** CuTeDSL's JIT appears to constant-fold the `kv_coord += 1` increment, or otherwise fail to propagate it to the TMA coordinate at runtime. The CUTLASS reference uses the same pattern with a Python-int `kv_coord`; it is unclear why it works there but not here (possibly a different CuTeDSL version or loop structure).

**Debug shape info:**

- `tBgK` before slice: `(((64, 128), 1), Int32(?), Int32(?), Int32(?))` — modes 1, 2, 3 all dynamic
- `tVgV` before slice: `(((64, 128), 1), 1, N, 1)` — mode 2 grows with n (confirmed GMEM iteration mode)
- After `(None, 0, None, 0)`: both become `(((64, 128), 1), N_or_Int32(?))`, i.e. 2D

### Files

| File | Status | Notes |
|------|--------|-------|
| `test_fmha_v3_stage_c_full.py` | OK n=128 only | Working real softmax + O normalization |
| `fmha_v3_stage_c_example1.py` | BROKEN multi-tile | First fix attempt, TMA still loads tile 0 |
| `fmha_v3_stage_c_example2.py` | DEADLOCK | Combined K+V barrier, compiles but deadlocks |
| `test_fmha_v3_stage_c2.py` | DEADLOCK | 12-warp pipeline, compiles but deadlocks |
| `test_fmha_v3_12w.py` | OK n=128 only | Identity softmax baseline |

### Current Architecture (6-warp)

- Warps 0-3: Softmax + Epilogue
- Warp 4: MMA (QK, PV)
- Warp 5: TMA (Q/K/V load)

### Target Architecture (12-warp, production)

- Warps 0-3: Softmax
- Warps 4-7: Correction
- Warp 8: MMA
- Warp 9: TMA
- Warp 10: Epilogue
- Warp 11: Empty

### CuTeDSL Constraints (hard-won)

1. `vectorize=True` loops: ONLY load/store/print
2. `.reduce(cute.ReductionOp.MAX)`: reduces the ENTIRE C-fragment to a scalar (global max, not per-row)
3. `cute.arch.fmax`: impure for the vectorizer; use a plain `range()` loop
4. TMA `cute.copy` accepts pipeline state values as coordinates but NOT a Python int
5. `tBgK[(None, 0, None, 0)]` hardcodes GMEM iteration to tile 0
6. `softmax_done_bar` NamedBarrier is reusable across tiles

### Remaining for C' (Production Stage C)

1. Fix multi-tile TMA: combined K+V barrier or `kh.count // 2`
2. Fix the runtime deadlock in `fmha_v3_stage_c_example2.py` (`acc_pipe` + `final_o_bar` sync)
3. Cross-warp reduction for `row_max` and `row_sum`
4. Correction warps for multi-tile KV (online O rescale in TMEM)
5. 12-warp layout with separate softmax/correction/epilogue warps

### TMEM Layout

| Columns | Contents |
|---------|----------|
| 0-127 | S: QK accumulator, 128 FP32 |
| 32-95 | P: 64 FP32, aliased inside S's columns |
| 128-191 | O: PV accumulator, 64 FP32 |

---

## Key Lessons

1. **NEVER use `find_tmem_tensor_col_offset()` as TMEM placement.** It returns the footprint size, not a safe offset.
2. **FMHA never trusts DLPack tensor layouts.** Reconstruct V as (hd, s_k) MN-major inside CuTe.
3. **TMEM allocation must be a power of 2** (in columns, minimum 32).
4. **Square hides bugs.** (128,128) worked for every wrong approach. Always test non-square.
5. **St32x32bOp MUST use Float32**, NOT BFloat16. BFloat16 causes an illegal memory access.
6. **The first PV MMA must use ACCUMULATE=False.** Otherwise it adds uninitialized TMEM to the output.
7. **FMHA P store uses the QK C-fragment composition, NOT the PV A-fragment.** Two aliases, same TMEM.
8. **Register bridge: FP32 backing (store partition) + BF16 view (QK-load layout).** Do not skip this.

---

## Environment

- Server: root@45.76.247.107 (B200, 180 GiB HBM3e per GPU)
- venv: `source /root/dsv4-nvfp4-workspace/venv/bin/activate`
- PYTHONPATH: `/root/dsv4-nvfp4-workspace/kernel`
- Model: `/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4`
- vLLM repo: `/root/dsv4-nvfp4-workspace/vllm` (modified for Blackwell)
- CUTLASS FMHA reference: `/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py`
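A fresh shell on the B200 can be checked against the environment list with a small preflight script. This is a hypothetical `env_check.py` (not in the repo). The paths are copied from the list, and the venv check assumes the workspace venv's interpreter is the one running the script (inside an active venv, `sys.prefix` points at the venv directory).

```python
import os
import sys
from pathlib import Path

# Paths copied from the Environment list; adjust if the server layout changes.
WORKSPACE = Path("/root/dsv4-nvfp4-workspace")
KERNEL_DIR = WORKSPACE / "kernel"
VENV_DIR = WORKSPACE / "venv"
EXPECTED_PATHS = {
    "kernel repo": KERNEL_DIR,
    "venv": VENV_DIR,
    "model": Path("/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"),
    "vLLM repo": WORKSPACE / "vllm",
    "CUTLASS FMHA reference": Path(
        "/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py"
    ),
}


def missing_paths(expected: dict) -> list:
    """Names of expected files or directories that do not exist."""
    return [name for name, path in expected.items() if not Path(path).exists()]


def on_pythonpath(path, pythonpath=None) -> bool:
    """True if `path` is one of the PYTHONPATH entries (compared after normalization)."""
    if pythonpath is None:
        pythonpath = os.environ.get("PYTHONPATH", "")
    target = os.path.normpath(str(path))
    return any(os.path.normpath(p) == target for p in pythonpath.split(os.pathsep) if p)


def in_venv(venv_dir) -> bool:
    """True if the running interpreter belongs to `venv_dir`."""
    return Path(sys.prefix).resolve() == Path(venv_dir).resolve()


if __name__ == "__main__":
    problems = [f"missing {name}" for name in missing_paths(EXPECTED_PATHS)]
    if not on_pythonpath(KERNEL_DIR):
        problems.append(f"PYTHONPATH does not contain {KERNEL_DIR}")
    if not in_venv(VENV_DIR):
        problems.append(f"interpreter is not from {VENV_DIR}")
    print("\n".join(problems) if problems else "environment OK")
```

Run it with `python env_check.py` after activating the venv and exporting `PYTHONPATH`; it prints one line per problem, or "environment OK".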