DSV4 Inference Kernel

⚠️⚠️⚠️ CRITICAL: TMA Partition Tensor Mode Ordering ⚠️⚠️⚠️

THIS BUG COST US AN ENTIRE DAY. READ THIS. BURN IT INTO YOUR BRAIN.

After cpasync.tma_partition(), the output GMEM tensor has 8 modes, NOT 4:

tBgK shape: (1, 1, 1, 1, n_kv_tiles, 1, 1, 1)
             0  1  2  3  4           5  6  7

Mode 4 (0-indexed) is the GMEM tile dimension: the one you index with kt to load different K/V tiles.

THE WRONG WAY (what we did — silently loads from tile 0 forever):

# ❌❌❌ THIS ONLY ADDRESSES 4 OF THE 8 MODES ❌❌❌
# Mode 4 (the GMEM tile dim) is NEVER touched by this slice.
# It stays fixed at 0. The TMA ALWAYS reads from tile 0.
tBgK = tBgK[(None, None, 0, 0)]  # ← WRONG! Only 4 modes addressed!

# The copy "works", but kv_coord is meaningless: it indexes mode 1 of the sliced view,
# which has size 1, so every kv_coord resolves to the same GMEM tile (tile 0).
cute.copy(tma_k, tBgK[(None, kv_coord)], ...)  # ← kv_coord is ignored!

THE RIGHT WAY (what actually works):

# ✅ Keep ALL 8 modes. Do NOT pre-slice the TMA GMEM tensor.
# (An all-None slice is an identity view. It is written out here only to make the rank explicit.)
tBgK = tBgK[(None, None, None, None, None, None, None, None)]

# ✅ Index mode 4 (the GMEM tile dim) in the copy call
cute.copy(tma_k, tBgK[None, None, None, None, kt, None, None, None], ...)
#                                             ^^ THIS IS MODE 4: THE GMEM TILE DIM

WHY THIS IS SO INSIDIOUS

  1. No error, no warning. The slice tBgK[(None,None,0,0)] silently drops modes 4-7.
  2. Single-tile (n=128) works perfectly. With only 1 KV tile, mode 4 is size 1, so the bug is invisible.
  3. All multi-tile tests produce "reasonable" output. The TMA loads from tile 0 every time, so you get a valid (but wrong) attention computation. Cosine similarity is 0.7-0.9, not NaN.
  4. The strides are all 0. Printing tBgK.layout.stride shows all zeros for TMA tensors. You can't detect the bug from strides alone.
  5. cute.printf shows kv_coord=0. We thought the JIT was constant-folding the variable. It wasn't — the variable was fine, but it was indexing the wrong mode.

THE LESSON

ALWAYS print cute.shape(tBgK) after tma_partition. Count the modes. Know which mode is your GMEM tile dimension. Index it explicitly. NEVER pre-slice TMA partition tensors with fewer dimensions than they actually have.

# After tma_partition, ALWAYS verify the shape:
print(f"tBgK shape: {cute.shape(tBgK)}")  # Should show 8 modes
print(f"n_kv_tiles at mode 4: {cute.size(tBgK, mode=[4])}")

# Use full indexing in cute.copy, not a pre-sliced 2D view
cute.copy(tma_k, tBgK[None, None, None, None, kt, None, None, None], ...)

IF YOU SKIP THIS AND PRE-SLICE TO 4 MODES, MULTI-TILE TMA WILL BE SILENTLY BROKEN.
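You can turn "count the modes" into a hard check. The helper below is a minimal host-side sketch in plain Python. It takes a shape tuple, such as the one printed by `cute.shape(tBgK)`, and assumes every mode except the tile mode has extent 1, as in the shape shown above. The name `find_gmem_tile_mode` is hypothetical. The helper refuses `n_kv_tiles < 2` on purpose, because a single tile is exactly the case that hides the bug.

```python
from typing import Sequence

def find_gmem_tile_mode(shape: Sequence[int], n_kv_tiles: int, expected_rank: int = 8) -> int:
    """Return the index of the one mode whose extent equals n_kv_tiles.

    Hypothetical helper. Assumes a flat shape like (1, 1, 1, 1, n_kv_tiles, 1, 1, 1)
    where every other mode has extent 1.
    """
    if len(shape) != expected_rank:
        raise ValueError(f"expected {expected_rank} modes, got {len(shape)}: {tuple(shape)}")
    if n_kv_tiles < 2:
        # With a single tile every mode has extent 1, so the tile mode is ambiguous.
        raise ValueError("test with n_kv_tiles >= 2; a single tile hides the tile mode")
    candidates = [i for i, extent in enumerate(shape) if extent == n_kv_tiles]
    if len(candidates) != 1:
        raise ValueError(f"expected exactly one mode of extent {n_kv_tiles}, found {candidates}")
    return candidates[0]

# n=256 with 128-row KV tiles -> 2 tiles.
print(find_gmem_tile_mode((1, 1, 1, 1, 2, 1, 1, 1), n_kv_tiles=2))  # → 4
```

A 4-mode shape, which is what the bad pre-slice leaves you with, raises `ValueError` instead of quietly returning a wrong mode.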


Architecture

DSV4 is not MLA. It uses CSA (Compressed Sparse Attention, m=4) and HCA (Heavily Compressed Attention, m=128). The KV latent is (T, 512) and is shared across all 128 heads. Sink weights merge the sparse and SWA attention branches. vLLM misnames this as "MLA", but the architecture is fundamentally different.
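This README does not spell out the exact sink-weight formula. One standard way to merge two attention branches exactly is through each branch's log-sum-exp (LSE). The NumPy sketch below takes that approach. It **assumes** the sink is a single extra logit that enters the softmax denominator with a zero value vector (a gpt-oss-style sink). That assumption, the function names, and the random stand-ins for the sparse and SWA key sets are illustrative only, not the DSV4 definition. The softmax scale and top-k selection are omitted. The check compares the merge against one softmax over the concatenated keys plus the sink.

```python
import numpy as np

def branch_attention(q, k, v):
    """Softmax attention over one key set; returns (normalized output, per-row log-sum-exp)."""
    s = q @ k.T                                   # (n_q, n_k) logits; softmax scale omitted
    m = s.max(axis=1, keepdims=True)
    p = np.exp(s - m)
    l = p.sum(axis=1, keepdims=True)
    return (p @ v) / l, (m + np.log(l))[:, 0]

def merge_with_sink(o_a, lse_a, o_b, lse_b, sink_logit):
    """Exact merge of two normalized branches, as if one softmax ran over both key sets.

    ASSUMPTION: the sink is one extra logit in the denominator with a zero value vector.
    """
    m = np.maximum(np.maximum(lse_a, lse_b), sink_logit)   # per-row stabilizer
    w_a, w_b = np.exp(lse_a - m), np.exp(lse_b - m)          # each branch's total softmax mass
    w_sink = np.exp(sink_logit - m)
    return (w_a[:, None] * o_a + w_b[:, None] * o_b) / (w_a + w_b + w_sink)[:, None]

rng = np.random.default_rng(0)
q = rng.standard_normal((4, 16))
k_sparse, v_sparse = rng.standard_normal((8, 16)), rng.standard_normal((8, 16))  # stand-in: selected compressed KV
k_swa, v_swa = rng.standard_normal((6, 16)), rng.standard_normal((6, 16))        # stand-in: sliding window
sink = 0.5

merged = merge_with_sink(*branch_attention(q, k_sparse, v_sparse),
                         *branch_attention(q, k_swa, v_swa), sink)

# Reference: one softmax over [sparse keys, SWA keys, sink], with a zero value row for the sink.
logits = np.concatenate([q @ k_sparse.T, q @ k_swa.T, np.full((4, 1), sink)], axis=1)
weights = np.exp(logits - logits.max(axis=1, keepdims=True))
weights /= weights.sum(axis=1, keepdims=True)
ref = weights @ np.concatenate([v_sparse, v_swa, np.zeros((1, 16))])
print(np.allclose(merged, ref))  # → True
```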

DSV4 inference pipeline — component status
==========================================

Legend:
 [✓] built and tested
 [~] partial — reference or seam exists, native pending
 [✗] to build


 ┌────────────────────────────────────┐
 │ [✗] Embedding + mHC init          │
 │ token embed + n_hc=4 streams      │
 └────────────────┬───────────────────┘
                  │
                  ▼
┌─ Transformer layer × L ──────────────────────────────────────────────┐
│ HCA on layers 0-1 of Pro, alternating CSA / HCA after              │
│                                                                      │
│ ┌─ Attention sub-block ──────────────────────────────────────────┐  │
│ │ [✓] Residual mHC pre + post mix                               │  │
│ │ [~] Norms + RoPE             RMSNorm + partial RoPE           │  │
│ │ [✓] Q / KV projection        NVFP4 linears + LoRA             │  │
│ │ [~] Token compressor         CSA m=4 / HCA m=128             │  │
│ │ [✗] Indexer + top-k          CSA only, FP4 QK                 │  │
│ │ [~] FMHA core                QK → online softmax → PV         │  │
│ │                              + SWA branch + sink merge         │  │
│ │ [✓] Output projection        inv RoPE + wo_a grouped + wo_b   │  │
│ └────────────────────────────────────────────────────────────────┘  │
│                                                                      │
│ ┌─ FFN sub-block ────────────────────────────────────────────────┐  │
│ │ [✓] Residual mHC pre + post mix                               │  │
│ │ [~] Pre-FFN norm              RMSNorm                          │  │
│ │ [✗] Router                    sqrt(softplus) + topk + hash     │  │
│ │ [✓] Routed MoE               fused SwiGLU L1 + L2             │  │
│ │ [✓] Shared expert            NVFP4 single-group GEMM          │  │
│ └────────────────────────────────────────────────────────────────┘  │
└──────────────────────────────────┬───────────────────────────────────┘
                                  │
                                  ▼
┌──────────────────────────────────────────────────────────────────────┐
│ [✗] Final RMSNorm → [✗] LM head → [✗] MTP (depth=1) → [✗] Sampler │
└──────────────────────────────────────────────────────────────────────┘

┌─ Supporting infrastructure ──────────────────────────────────────────┐
│ [✗] KV cache management                                             │
│ • state cache: SWA window + uncompressed tail per layer             │
│ • classical paged cache: lcm(4, 128) = 128 tokens per block         │
│ • heterogeneous layout per layer                                    │
└──────────────────────────────────────────────────────────────────────┘
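The 128-token block size is the least common multiple of the two compression ratios. It is, presumably, the smallest block that holds a whole number of compressed entries for both CSA and HCA. Here is a quick check of that arithmetic. It **assumes** m means "tokens folded into one compressed entry", and the constant and function names are illustrative.

```python
import math

# ASSUMPTION: m is the number of tokens folded into one compressed KV entry.
CSA_M = 4
HCA_M = 128

BLOCK_TOKENS = math.lcm(CSA_M, HCA_M)          # smallest block both ratios divide evenly
ENTRIES_PER_BLOCK = {"CSA": BLOCK_TOKENS // CSA_M, "HCA": BLOCK_TOKENS // HCA_M}

def blocks_needed(num_tokens: int) -> int:
    """Paged blocks required to hold num_tokens tokens (ceiling division)."""
    return -(-num_tokens // BLOCK_TOKENS)

print(BLOCK_TOKENS, ENTRIES_PER_BLOCK, blocks_needed(1000))  # → 128 {'CSA': 32, 'HCA': 1} 8
```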


Summary
-------
 Built  [✓] : 6 — mHC ×2, Q/KV proj, output proj, routed MoE,
               shared expert
 Partial [~] : 4 — norms+RoPE, token compressor, FMHA core,
               pre-FFN norm
 To build [✗] : 8 — embedding+init, indexer+top-k, router,
               final norm, LM head, MTP, sampler, KV cache

Status (May 22, 2026 — 16:30 UTC)

Stage | Status | Description
--- | --- | ---
A | COMPLETE | Q@K^T via tcgen05.mma → TMEM → GMEM
B | COMPLETE | QK → identity softmax → P@V pipeline (TMEM alias, KV-tile interleaving)
C | ⚠️ MULTI-TILE IN PROGRESS | Single-tile cos 0.999998. TMA fix: n=256 cos 0.9956. Need O rescale + pipeline cycling.
C' | 🔨 IN PROGRESS | Multi-tile TMA indexing fix + correction warps. See below.
D | TODO | Full decode attention: paged KV cache, multi-query, causal mask
E | TODO | Production kernel: extract into dsv4/kernels/attention/, PyTorch custom op, vLLM bridge

Package Structure

dsv4/
├── kernels/          Pure GPU code (CuTeDSL @cute.jit, .cu files)
│   ├── gemm/           NVFP4 MoE GEMM kernels (grouped, fused_swiglu, dense, scheduler)
│   ├── attention/      FMHA kernel (stub — extraction is Stage E)
│   ├── compressor/     CSA/HCA token-level compressor
│   ├── decode/         Decode-time attention (sparse, SWA — future)
│   └── cuda/           Raw .cu files (deinterleave_quantize, sparse_topk_metadata)
├── ops/              PyTorch ↔ kernel bridges
│   ├── quantize.py      BF16 ↔ NVFP4 conversion, scale factors
│   ├── layouts.py       Scale swizzle, gate/up interleave, K-major, offsets
│   ├── gemm_runner.py   Warmup, compile, run grouped/fused GEMMs
│   ├── custom_ops.py    torch.library.custom_op registrations
│   ├── decode_sparse.py native_sparse_decode dispatcher
│   ├── decode_swa.py    native_swa_decode dispatcher
│   ├── rope.py          Forward + inverse RoPE
│   └── topk.py          Python wrapper for sparse_topk_metadata.cu
├── layers/           nn.Module-style components
│   ├── linear.py        Nvfp4Linear
│   ├── grouped_linear.py Nvfp4GroupedLinear
│   ├── moe.py           Nvfp4MoE
│   ├── shared_expert.py Nvfp4SharedExpert
│   ├── mhc.py           mHCLayer
│   └── (stubs: attention, ffn, router, norm, embedding)
├── model/            Model assembly (stubs — Phase 1)
├── cache/            KV cache infra (stubs — Phase 3)
├── loader/           Checkpoint I/O (stubs — Phase 1)
└── reference/        Slow PyTorch oracles (never imported by production code)
    ├── attention.py     RoPE, KV cache, causal attention, SWA
    ├── csa_attention.py CSA/HCA sparse attention
    ├── compressor.py    Compressor PyTorch example
    └── moe_pipeline.py  MoE pipeline reference

Mental model: kernels/ → ops/ → layers/ → model/ (dependency flows left to right). reference/ and loader/ are sidecars.
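The layering rule, plus "reference/ is never imported by production code", can be checked mechanically. The sketch below uses the standard-library `ast` module to flag absolute imports that point right-to-left in the chain, or into `reference/`. It is not an existing repo tool. The rank table, the function names, and the choice to skip relative imports and the `cache/`/`loader/` packages are all assumptions about how you might encode the rule.

```python
from __future__ import annotations

import ast
from pathlib import Path

PACKAGE = "dsv4"
# Lower rank = further left in kernels/ -> ops/ -> layers/ -> model/.
LAYER_RANK = {"kernels": 0, "ops": 1, "layers": 2, "model": 3}
FORBIDDEN = {"reference"}  # slow oracles: never imported by production code

def imported_subpackages(source: str) -> set[str]:
    """Names like 'ops' for every absolute import of dsv4.<name>[...] in source."""
    found: set[str] = set()
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Import):
            modules = [alias.name for alias in node.names]
        elif isinstance(node, ast.ImportFrom) and node.level == 0 and node.module:
            if node.module == PACKAGE:            # from dsv4 import ops
                modules = [f"{PACKAGE}.{alias.name}" for alias in node.names]
            else:                                 # from dsv4.ops.rope import ...
                modules = [node.module]
        else:
            continue                              # relative imports are not checked here
        for module in modules:
            parts = module.split(".")
            if parts[0] == PACKAGE and len(parts) > 1:
                found.add(parts[1])
    return found

def find_layer_violations(package_root: Path) -> list[str]:
    """List 'file: layer -> target' for imports that point right-to-left or into reference/."""
    violations = []
    for path in sorted(package_root.rglob("*.py")):
        rel = path.relative_to(package_root)
        layer = rel.parts[0]
        if layer not in LAYER_RANK:
            continue  # reference/, loader/, cache/ and top-level modules are out of scope
        for target in sorted(imported_subpackages(path.read_text())):
            if target in FORBIDDEN or LAYER_RANK.get(target, -1) > LAYER_RANK[layer]:
                violations.append(f"{rel.as_posix()}: {layer} -> {target}")
    return violations
```

From the repo root, `find_layer_violations(Path("dsv4"))` should return an empty list. A CI step could fail on anything else.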


Active Test Files

FMHA (Stages A/B/C) — in tests/unit/

File | Stage | Status
--- | --- | ---
test_fmha_v3.py | A+B | Full QK→identity softmax→PV, cosine 0.999999
test_fmha_v3_12w.py | A+B | 12-warp QK→PV, cosine 0.999999
test_fmha_v3_stage_c_full.py | C | Real online softmax + O normalization, cosine 0.993-0.996
test_fmha_v3_stage_c_min.py | C | 🔨 Early 12-warp pipeline (broken pipeline state)
test_pv64_with_softmax.py | B | (128,64) PV, single AB pipeline
test_128_128_vdiag.py | A+B | (128,128) PV baseline
test_qkonly.py | A | QK with split Q/KV pipelines
test_qk_softmax.py | A+B | QK + identity softmax, no PV

MoE / GEMM — in tests/unit/

File | What
--- | ---
test_cutedsl.py | NVFP4 grouped GEMM kernel
cudagraph_test.py | Cudagraph capture + replay
layertest.py | Per-layer correctness
test_custom_op.py | torch.library custom ops
test_compile_custom_op.py | Compile + warmup
test_fp4_roundtrip.py | BF16 → NVFP4 → BF16 roundtrip
test_interleave.py | Gate/up weight interleaving
test_interleave_gemm.py | Interleaved GEMM correctness
test_fused_step1.py | Fused SwiGLU GEMM

Archived Tests

tests/archive/ contains ~190 debug files from Stages A/B. Not maintained. Can be deleted.


Test Harness

Scripts in tests/ for running tests on the B200 (root@45.76.247.107):

run_test.sh — Run a test in a screen session

# On the B200:
cd /root/dsv4-nvfp4-workspace/kernel
bash tests/run_test.sh tests/unit/test_fmha_v3.py

What it does:

  1. Kills any existing kernel-test screen and SIGKILLs all child processes (handles deadlocked GPU procs that ignore SIGHUP)
  2. Deletes the old log file
  3. Starts a new screen -dmS kernel-test running the test
  4. Logs output to /tmp/kernel-test.log
  5. Verifies the screen started

check_log.sh — Check test progress

bash tests/check_log.sh

Shows the log contents and whether the screen is still running.

Local → B200 workflow

# 1. Edit locally, commit, push
cd ~/dev/nvfp4-megamoe-kernel
git add -A && git commit -m "my change" && git push

# 2. SSH to B200, pull, run
ssh root@45.76.247.107
cd /root/dsv4-nvfp4-workspace/kernel && git pull
bash tests/run_test.sh tests/unit/test_fmha_v3_stage_c_full.py

# 3. Check results
bash tests/check_log.sh

fire_b200_test — One-command local test runner

Lives in ~/.openclaw/workspace/fire_b200_test (NOT in the repo — project-specific tooling).

# From your local machine, one command to push, run, and get results:
~/.openclaw/workspace/fire_b200_test tests/unit/test_fmha_v3.py

What it does:

  1. Auto-commits and pushes any local changes
  2. SSHes to the B200, pulls, and starts run_test.sh in a screen
  3. Polls every 15s until the screen exits
  4. Dumps the full test log to your terminal

This is strictly for the DSV4 NVFP4 kernel project. It hardcodes the B200 IP, repo paths, and git remote.
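Step 3 (poll until the screen exits) is easy to get subtly wrong, for example by treating a failed SSH probe as "finished". `fire_b200_test` itself is not shown here, so the following is only a minimal Python sketch of such a loop. The function name, the injectable `is_running` probe, and the timeout are hypothetical.

```python
import time
from typing import Callable, Optional

def wait_for_screen(
    is_running: Callable[[], Optional[bool]],
    interval_s: float = 15.0,
    timeout_s: float = 3600.0,
    sleep: Callable[[float], None] = time.sleep,
) -> int:
    """Poll until is_running() returns False; return the number of polls made.

    is_running returns True (still running), False (screen exited) or None
    (probe failed, e.g. SSH error); None is retried rather than treated as done.
    """
    polls = 0
    waited = 0.0
    while True:
        polls += 1
        if is_running() is False:
            return polls
        if waited >= timeout_s:
            raise TimeoutError(f"kernel-test still running after {waited:.0f}s")
        sleep(interval_s)
        waited += interval_s
```

On the real setup, `is_running` could run `ssh root@45.76.247.107 screen -ls` through `subprocess.run(..., capture_output=True, text=True, timeout=30)`. It would return `"kernel-test" in result.stdout`, and `None` if ssh exits with status 255 or the call times out.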


Stage C: Online Softmax — Multi-Tile In Progress

What We Have

Working real softmax for single KV tile (n=128): cosine 0.999998. Multi-tile TMA indexing fixed (n=256 cosine 0.9956) — was a layout bug, NOT a JIT bug. Remaining: O rescale between tiles, pipeline state cycling for n≥384, correction warps.

Multi-Tile TMA Fix (RESOLVED — was a LAYOUT bug, not a JIT bug)

After cpasync.tma_partition(), the output GMEM tensor has 8 modes, not 4:

tBgK shape: (1, 1, 1, 1, n_kv_tiles, 1, 1, 1)
             0  1  2  3  4           5  6  7

Mode 4 is the GMEM tile dimension. Our old pre-slice tBgK[(None, None, 0, 0)] only addressed 4 modes — modes 4-7 were implicitly collapsed to coordinate 0, so TMA always read tile 0. The bug looked like "JIT constant-folding" but was purely a layout error.

The fix: Do not pre-slice. Index all 8 modes explicitly in cute.copy, putting kt at mode 4:

cute.copy(tma_k, tBgK[None, None, None, None, kt, None, None, None], ...)

Results after fix:

  • n=128: cos 0.999998
  • n=256: cos 0.9956 (lower because no O rescale yet)

Remaining for Multi-Tile

  1. O rescale between tiles: O *= exp2(old_max - new_max) — needed for n=256+ to hit 0.9999
  2. Pipeline state cycling for n≥384 (3+ tiles with 2 pipeline stages)
  3. Correction warps for production (separate softmax/correction/epilogue)
  4. 12-warp layout
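The rescale in step 1 is the standard online-softmax (FlashAttention-style) update. The NumPy oracle below processes K/V in 128-row tiles. It keeps the running row max in log2 units, so the correction factor is exactly `exp2(old_max - new_max)` as written above. It is a slow host-side sketch for checking kernel output, not the kernel, and the function names and `tile_n` parameter are illustrative. The final divide by the running row sum is the "final normalize" step.

```python
import numpy as np

LOG2E = 1.4426950408889634  # log2(e): exp(x) == exp2(x * LOG2E)

def online_softmax_attention(q, k, v, tile_n=128, scale=None):
    """softmax(scale * q @ k.T) @ v, one KV tile at a time, rescaling O between tiles."""
    n_q, d = q.shape
    scale = 1.0 / np.sqrt(d) if scale is None else scale
    row_max = np.full(n_q, -np.inf)            # running row max, log2 units
    row_sum = np.zeros(n_q)                    # running sum of exp2(s - row_max)
    o = np.zeros((n_q, v.shape[1]))            # unnormalized O accumulator
    for start in range(0, k.shape[0], tile_n):
        s = (q @ k[start:start + tile_n].T) * (scale * LOG2E)   # S tile in log2 units
        new_max = np.maximum(row_max, s.max(axis=1))
        correction = np.exp2(row_max - new_max)    # exp2(old_max - new_max); 0 on the first tile
        p = np.exp2(s - new_max[:, None])
        row_sum = row_sum * correction + p.sum(axis=1)
        o = o * correction[:, None] + p @ v[start:start + tile_n]  # O rescale, then accumulate PV
        row_max = new_max
    return o / row_sum[:, None]                # final normalize

def reference_attention(q, k, v):
    """Plain full-softmax attention for comparison."""
    s = (q @ k.T) / np.sqrt(q.shape[1])
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v

rng = np.random.default_rng(0)
q = rng.standard_normal((128, 64))
k = rng.standard_normal((384, 64))   # n=384: 3 KV tiles, so the rescale runs twice
v = rng.standard_normal((384, 64))
print(np.allclose(online_softmax_attention(q, k, v), reference_attention(q, k, v)))  # → True
```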

Files

File | Status | Notes
--- | --- | ---
fmha_v3_stage_c_example10.py | 🔨 CURRENT | 8-mode TMA, combined K+V pipeline, O rescale, final normalize
test_fmha_v3_stage_c_full.py | OK n=128 | Working real softmax + O normalization
fmha_v3_stage_c_example1.py | BROKEN multi-tile | First fix attempt, TMA still loads tile 0
fmha_v3_stage_c_example2.py | DEADLOCK | Combined K+V barrier, compiles but deadlocks
test_fmha_v3_stage_c2.py | DEADLOCK | 12-warp pipeline, compiles but deadlocks
test_fmha_v3_12w.py | OK n=128 only | Identity softmax baseline

Current Architecture (6-warp)

Warps 0-3: Softmax + Epilogue, Warp 4: MMA (QK, PV), Warp 5: TMA (Q/K/V load)

Target Architecture (12-warp, production)

Warps 0-3: Softmax, Warps 4-7: Correction, Warp 8: MMA, Warp 9: TMA, Warp 10: Epilogue, Warp 11: Empty
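Both layouts can be written as a warp-index → role table, which makes the CTA thread counts explicit: 6 warps is 192 threads and 12 warps is 384. This is only a host-side description in plain Python with hypothetical names. Inside the kernel, the equivalent is a branch on the warp index.

```python
WARP_SIZE = 32

# Current layout: 6 warps.
ROLES_6W = {**{w: "softmax+epilogue" for w in range(4)}, 4: "mma", 5: "tma"}

# Target production layout: 12 warps.
ROLES_12W = {
    **{w: "softmax" for w in range(4)},
    **{w: "correction" for w in range(4, 8)},
    8: "mma",
    9: "tma",
    10: "epilogue",
    11: "empty",
}

def warp_role(layout, thread_idx):
    """Role of the warp that owns thread_idx (warp index = thread_idx // 32)."""
    return layout[thread_idx // WARP_SIZE]

for name, layout in (("6-warp", ROLES_6W), ("12-warp", ROLES_12W)):
    print(name, len(layout) * WARP_SIZE, "threads")
# → 6-warp 192 threads
# → 12-warp 384 threads
```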

CuTeDSL Constraints (hard-won)

  1. vectorize=True loops: ONLY load/store/print
  2. .reduce(cute.ReductionOp.MAX): reduces ENTIRE C-fragment to scalar — global max, not per-row
  3. cute.arch.fmax: impure for vectorizer — use plain range() loop
  4. tBgK/tVgV have 8 modes after tma_partition — mode 4 is GMEM tile dim, must index all 8 explicitly
  5. tBgK[(None, 0, None, 0)] hardcodes GMEM iteration to tile 0
  6. softmax_done_bar NamedBarrier is reusable across tiles

Remaining for C' (Production Stage C)

  1. Fix multi-tile TMA — combined K+V barrier or kh.count // 2
  2. Fix runtime deadlock in example2 (acc_pipe + final_o_bar sync)
  3. Cross-warp reduction for row_max and row_sum
  4. Correction warps for multi-tile KV (online O rescale in TMEM)
  5. 12-warp layout with separate softmax/correction/epilogue warps
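For item 3, each participant holds the max and exp-sum of its own slice of a row. The partial pairs then merge associatively. The NumPy sketch below works in log2 units to match `exp2(old_max - new_max)`. `merge_row_stats` is a hypothetical name, and the 4-way column split is only a stand-in for however the kernel actually divides a row across warps.

```python
import numpy as np

def merge_row_stats(max_a, sum_a, max_b, sum_b):
    """Combine two partial (row_max, row_sum) pairs in log2 units.

    Each sum_x is sum(exp2(s - max_x)) over that participant's slice of the row.
    """
    m = np.maximum(max_a, max_b)
    return m, sum_a * np.exp2(max_a - m) + sum_b * np.exp2(max_b - m)

rng = np.random.default_rng(1)
s = rng.standard_normal((128, 128)) * 4.0       # one 128x128 S tile, log2 units
slices = np.split(s, 4, axis=1)                 # stand-in: 4 partial owners per row
parts = [(x.max(axis=1), np.exp2(x - x.max(axis=1, keepdims=True)).sum(axis=1)) for x in slices]

row_max, row_sum = parts[0]
for m_i, l_i in parts[1:]:
    row_max, row_sum = merge_row_stats(row_max, row_sum, m_i, l_i)

full_max = s.max(axis=1)
full_sum = np.exp2(s - full_max[:, None]).sum(axis=1)
print(np.allclose(row_max, full_max), np.allclose(row_sum, full_sum))  # → True True
```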

TMEM Layout

Col 0-127: S (QK acc, 128 FP32) | Col 32-95: P (64 FP32) | Col 128+: O (PV acc, 64 FP32)
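This layout implies a column budget. P's 64 FP32 columns presumably hold 128 BF16 values per row, packed two per 32-bit column, inside S's columns (the "two aliases, same TMEM" lesson below). O's 64 FP32 columns starting at 128 end at column 191. The sketch checks that P stays inside S and that O does not overlap S. It then rounds the footprint up to a power-of-two allocation. The 32-column minimum and 512-column maximum are the `tcgen05.alloc` limits assumed from the PTX ISA, so verify them there. The region table and helper names are illustrative.

```python
# Half-open [start, end) ranges of 32-bit TMEM columns, per the layout above.
REGIONS = {
    "S": (0, 128),    # QK accumulator, 128 FP32 columns
    "P": (32, 96),    # softmax output, aliased inside S
    "O": (128, 192),  # PV accumulator, 64 FP32 columns
}

def overlaps(a, b):
    return a[0] < b[1] and b[0] < a[1]

def tmem_alloc_columns(regions, min_cols=32, max_cols=512):
    """Smallest power-of-two column count covering every region (limits are assumptions)."""
    used = max(end for _, end in regions.values())
    cols = min_cols
    while cols < used:
        cols *= 2
    if cols > max_cols:
        raise ValueError(f"needs {used} columns, more than {max_cols}")
    return cols

s_cols, p_cols, o_cols = REGIONS["S"], REGIONS["P"], REGIONS["O"]
assert s_cols[0] <= p_cols[0] and p_cols[1] <= s_cols[1], "P must alias inside S"
assert not overlaps(s_cols, o_cols), "O must not overlap S"
print(tmem_alloc_columns(REGIONS))  # → 256
```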


Key Lessons

  1. NEVER use find_tmem_tensor_col_offset() as TMEM placement. It returns footprint size, not a safe offset.
  2. FMHA never trusts DLPack tensor layouts. Reconstruct V as (hd, s_k) MN-major inside CuTe.
  3. TMEM allocation size (in columns) must be a power of 2.
  4. Square hides bugs. (128,128) worked for every wrong approach. Always test non-square.
  5. St32x32bOp MUST use Float32, NOT BFloat16. BFloat16 causes illegal memory access.
  6. First PV ACCUMULATE=False. Otherwise adds uninitialized TMEM to output.
  7. FMHA P store uses QK C-fragment composition, NOT PV A-fragment. Two aliases, same TMEM.
  8. Register bridge: FP32 backing (store partition) + BF16 view (QK-load layout). Do not skip this.
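Lessons 2 and 4 can both be seen through NumPy strides. CuTe's GEMM convention is C(M, N) = A(M, K) · B(N, K), so for O = P·V the B operand is indexed (N, K) = (hd, s_k). A row-major (s_k, hd) V buffer viewed as (hd, s_k) has unit stride along hd. That makes it MN-major, and no copy is needed. The sketch also shows why square shapes hide mistakes. When hd == s_k, a wrongly oriented V still has a legal shape, so only the values come out wrong. NumPy models only the layout arithmetic here, and the helper name is hypothetical.

```python
import numpy as np

def v_as_b_operand(v):
    """View row-major V (s_k, hd) as the (N, K) = (hd, s_k) B operand without copying."""
    b = v.T
    # Unit stride along N (= hd): the MN-major layout (hd, s_k):(1, hd).
    assert b.strides == (v.itemsize, v.shape[1] * v.itemsize)
    return b

rng = np.random.default_rng(0)
s_q, s_k, hd = 128, 128, 64
p = rng.standard_normal((s_q, s_k)).astype(np.float32)
v = rng.standard_normal((s_k, hd)).astype(np.float32)

b = v_as_b_operand(v)
o = p @ b.T                 # C(M, N) = sum_k A(M, k) * B(N, k)
assert np.allclose(o, p @ v)

# Non-square (hd != s_k): treating the (hd, s_k) view as (K, N) fails loudly.
try:
    p @ b
except ValueError:
    print("non-square: shape error caught")

# Square (hd == s_k): the same mistake runs and silently gives a wrong answer.
v_sq = rng.standard_normal((s_k, s_k)).astype(np.float32)
wrong = p @ v_as_b_operand(v_sq)
print("square: no error, max abs diff", float(np.abs(wrong - p @ v_sq).max()))
```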

Environment

  • Server: root@45.76.247.107 (B200, 180 GiB HBM3e per GPU)
  • venv: source /root/dsv4-nvfp4-workspace/venv/bin/activate
  • PYTHONPATH: /root/dsv4-nvfp4-workspace/kernel
  • Model: /root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4
  • vLLM repo: /root/dsv4-nvfp4-workspace/vllm (modified for Blackwell)
  • CUTLASS FMHA reference: /root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py