The `tma_partition` output has 8 TMA coordinate dimensions, not 4. The Python-visible shape shows 4 modes, but the TMA descriptor uses 8 coordinates. Without the 8-None no-op pre-slice, modes 4-7 are collapsed and the GMEM tile axis (mode 4) is pinned to 0.

Pattern that works (confirmed on B200 at n=256 in the diag test):

```python
tBgK = tBgK[(None, None, None, None, None, None, None, None)]  # open all 8 coord dims
cute.copy(tma_k, tBgK[None, None, None, None, kt, None, None, None], ...)
```

The old 4-mode indexing `tBgK[(None, None, kt, 0)]` fails with `rank mismatch: got 2 and 1` because slicing a 4-mode tensor produces the wrong rank for the TMA coordinate space. This pattern matches the working diag test `test_fmha_v3_diag.py` exactly.
DSV4 Inference Kernel
⚠️⚠️⚠️ CRITICAL: TMA Partition Tensor Coordinate Space ⚠️⚠️⚠️
THIS BUG COST US AN ENTIRE DAY. READ THIS. BURN IT INTO YOUR BRAIN.
After cpasync.tma_partition(), the output GMEM tensor has 8 TMA coordinate dimensions:
```
tBgK TMA coord space: (1, 1, 1, 1, n_kv_tiles, 1, 1, 1)
mode:                  0  1  2  3  4           5  6  7
```
Mode 4 is the GMEM tile dimension: the dimension you index with `kt` to load different K/V tiles.
The Python-visible shape only shows 4 modes, but the TMA coordinate space is 8-dimensional. You MUST apply an 8-None no-op pre-slice to open the full coordinate space before indexing.
THE WRONG WAY (what we did; silently loads from tile 0 forever):
```python
# ❌❌❌ 4-MODE PRE-SLICE COLLAPSES THE GMEM TILE AXIS ❌❌❌
# The (None, None, 0, 0) slice only addresses 4 of the 8 TMA coord dims.
# Modes 4-7 get collapsed to coordinate 0. TMA ALWAYS reads tile 0.
tBgK = tBgK[(None, None, 0, 0)]  # ← WRONG!
# The copy "works", but kv_coord indexes mode 1 (size 1), so every
# kv_coord resolves to the same GMEM tile (tile 0).
cute.copy(tma_k, tBgK[(None, kv_coord)], ...)  # ← kv_coord is ignored!
```
THE RIGHT WAY (what actually works; confirmed on B200 at n=256):
```python
# ✅ 8-None no-op pre-slice opens the full TMA coordinate space
tBgK = tBgK[(None, None, None, None, None, None, None, None)]
# ✅ Index mode 4 (the GMEM tile dim) in the copy call
cute.copy(tma_k, tBgK[None, None, None, None, kt, None, None, None], ...)  # kt sits in MODE 4, THE GMEM TILE DIM
```
WHY THIS IS SO INSIDIOUS
- No error, no warning. The slice `tBgK[(None, None, 0, 0)]` silently collapses modes 4-7.
- Single-tile (n=128) works perfectly. With only 1 KV tile, mode 4 is size 1, so the bug is invisible.
- All multi-tile tests produce "reasonable" output. The TMA loads from tile 0 every time, so you get a valid (but wrong) attention computation. Cosine similarity is 0.7-0.9, not NaN.
- The strides are all 0. Printing `tBgK.layout.stride` shows all zeros for TMA tensors. You can't detect the bug from strides alone.
- `cute.printf` shows `kv_coord=0`. We thought the JIT was constant-folding the variable. It wasn't: the variable was fine, but it was indexing the wrong mode.
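To see why the symptom was plausible numbers rather than NaNs, here is a toy pure-Python sketch (not kernel code; the data are made up for illustration). One query attends over two single-key KV tiles; the buggy path mimics a copy that ignores `kt` and reads tile 0 every time.

```python
import math
from typing import List, Sequence

def attention(q: Sequence[float], keys: List[Sequence[float]],
              values: List[Sequence[float]]) -> List[float]:
    """Reference softmax(q . K^T) @ V for a single query row."""
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    d = len(values[0])
    return [sum(w * v[c] for w, v in zip(weights, values)) / total for c in range(d)]

def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

# Toy data: 2 KV tiles holding one key/value each (hypothetical values).
q = [1.0, 0.0]
k_tiles = [[[0.0, 0.0]], [[0.0, 0.0]]]
v_tiles = [[[1.0, 0.0]], [[0.0, 1.0]]]
n_tiles = len(k_tiles)

# Correct: tile kt is loaded for kt = 0 .. n_tiles-1.
good = attention(q, [k for kt in range(n_tiles) for k in k_tiles[kt]],
                 [v for kt in range(n_tiles) for v in v_tiles[kt]])
# Buggy: the copy ignores kt and always reads tile 0.
bad = attention(q, [k for _ in range(n_tiles) for k in k_tiles[0]],
                [v for _ in range(n_tiles) for v in v_tiles[0]])

print(good, bad, round(cosine(good, bad), 4))  # [0.5, 0.5] [1.0, 0.0] 0.7071
```

The buggy path still computes a proper softmax over real keys, so its output is a valid weighted average of values, just the wrong one. That is why the cosine drops well below 1 without ever producing NaN.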
THE LESSON
TMA tensors use a coordinate space, not a pointer space. The TMA instruction at PTX level takes integer coordinates (crd0, crd1, crd2, ...), not pointers. CuTeDSL's tma_partition produces a tensor whose layout maps logical coordinates to TMA coordinate tuples. When you pre-slice with fewer dimensions than the TMA descriptor expects, the extra coordinate dimensions get collapsed to 0.
The 8-None no-op pre-slice is mandatory for multi-tile TMA. Without it, the GMEM tile axis (mode 4) is invisible and unindexable.
```python
# After tma_partition, ALWAYS apply the 8-None no-op pre-slice:
tBgK = tBgK[(None, None, None, None, None, None, None, None)]
tVgV = tVgV[(None, None, None, None, None, None, None, None)]
# Then index mode 4 in cute.copy:
cute.copy(tma_k, tBgK[None, None, None, None, kt, None, None, None], ...)
```
IF YOU SKIP THE 8-NONE PRE-SLICE, MULTI-TILE TMA WILL BE SILENTLY BROKEN.
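The coordinate bookkeeping described in THE LESSON can be captured as a toy Python model. This is not CuTeDSL and does not reproduce CuTe's real slicing rules; it only encodes the behavior this document observed (modes a pre-slice does not address collapse to 0). `resolve_tma_coord` is a hypothetical helper.

```python
from typing import Optional, Sequence, Tuple

TMA_RANK = 8   # TMA coord dims after tma_partition (per this document)
TILE_MODE = 4  # mode 4 = GMEM KV-tile axis

Idx = Optional[int]  # None keeps a mode open, an int pins it

def resolve_tma_coord(pre_slice: Sequence[Idx], index: Sequence[Idx]) -> Tuple[int, ...]:
    """Toy model: which 8-D TMA coordinate a copy ends up using.

    `pre_slice` addresses the first len(pre_slice) modes; every mode it does
    not address collapses to 0 (the failure mode described above).
    `index` is then applied, in order, to the modes the pre-slice left open.
    """
    coord = [0] * TMA_RANK
    open_modes = []
    for mode in range(TMA_RANK):
        if mode < len(pre_slice) and pre_slice[mode] is None:
            open_modes.append(mode)        # still indexable later
        elif mode < len(pre_slice):
            coord[mode] = pre_slice[mode]  # pinned by the pre-slice
        # else: never addressed, stays collapsed at 0
    for mode, value in zip(open_modes, index):
        if value is not None:
            coord[mode] = value
    return tuple(coord)

kt = 3
old = resolve_tma_coord((None, None, 0, 0), (None, kt))  # kt lands on mode 1
new = resolve_tma_coord((None,) * 8, (None, None, None, None, kt, None, None, None))
print(old[TILE_MODE], new[TILE_MODE])  # 0 3
```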
Architecture
DSV4 is not MLA. It uses CSA (Compressed Sparse Attention, m=4) and HCA (Heavily Compressed Attention, m′=128). The KV latent is (T, 512), shared across all 128 heads. Sink weights merge the sparse and SWA attention branches. vLLM misnames this as "MLA"; it is not. The architecture is fundamentally different.
```
DSV4 inference pipeline: component status
=========================================

Legend:
  [✓] built and tested
  [~] partial: reference or seam exists, native pending
  [✗] to build

                 ┌────────────────────────────────────┐
                 │ [✗] Embedding + mHC init           │
                 │     token embed + n_hc=4 streams   │
                 └──────────────────┬─────────────────┘
                                    │
                                    ▼
┌─ Transformer layer × L ──────────────────────────────────────────────┐
│ HCA on layers 0–1 of Pro, alternating CSA / HCA after                │
│                                                                      │
│  ┌─ Attention sub-block ──────────────────────────────────────────┐  │
│  │ [✓] Residual mHC          pre + post mix                       │  │
│  │ [~] Norms + RoPE          RMSNorm + partial RoPE               │  │
│  │ [✓] Q / KV projection     NVFP4 linears + LoRA                 │  │
│  │ [~] Token compressor      CSA m=4 / HCA m′=128                 │  │
│  │ [✗] Indexer + top-k       CSA only, FP4 QK                     │  │
│  │ [~] FMHA core             QK → online softmax → PV             │  │
│  │                           + SWA branch + sink merge            │  │
│  │ [✓] Output projection     inv RoPE + wo_a grouped + wo_b       │  │
│  └────────────────────────────────────────────────────────────────┘  │
│                                                                      │
│  ┌─ FFN sub-block ────────────────────────────────────────────────┐  │
│  │ [✓] Residual mHC          pre + post mix                       │  │
│  │ [~] Pre-FFN norm          RMSNorm                              │  │
│  │ [✗] Router                sqrt(softplus) + topk + hash         │  │
│  │ [✓] Routed MoE            fused SwiGLU L1 + L2                 │  │
│  │ [✓] Shared expert         NVFP4 single-group GEMM              │  │
│  └────────────────────────────────────────────────────────────────┘  │
└───────────────────────────────────┬──────────────────────────────────┘
                                    │
                                    ▼
┌──────────────────────────────────────────────────────────────────────┐
│ [✗] Final RMSNorm → [✗] LM head → [✗] MTP (depth=1) → [✗] Sampler    │
└──────────────────────────────────────────────────────────────────────┘

┌─ Supporting infrastructure ──────────────────────────────────────────┐
│ [✗] KV cache management                                              │
│     • state cache: SWA window + uncompressed tail per layer          │
│     • classical paged cache: lcm(m, m′) = 128 tokens per block       │
│     • heterogeneous layout per layer                                 │
└──────────────────────────────────────────────────────────────────────┘

Summary
-------
Built     [✓] : 6   mHC ×2, Q/KV proj, output proj, routed MoE,
                    shared expert
Partial   [~] : 4   norms+RoPE, token compressor, FMHA core,
                    pre-FFN norm
To build  [✗] : 8   embedding+init, indexer+top-k, router,
                    final norm, LM head, MTP, sampler, KV cache
```
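A quick sanity check of the block-size arithmetic in the KV cache box, as a sketch. With m = 4 and m′ = 128, lcm(m, m′) = 128, so every block boundary is also a compression-group boundary for both ratios (presumably the reason for using the lcm), and one block holds 32 CSA entries or 1 HCA entry. The helper names are hypothetical.

```python
import math

CSA_M = 4    # CSA compression ratio m
HCA_M = 128  # HCA compression ratio m′

# Block size for the classical paged cache: lcm(m, m′) tokens.
BLOCK_TOKENS = math.lcm(CSA_M, HCA_M)

def compressed_entries_per_block(m: int) -> int:
    """Compressed KV entries one block holds for a layer with ratio m."""
    assert BLOCK_TOKENS % m == 0  # every compression group fits inside one block
    return BLOCK_TOKENS // m

def blocks_needed(num_tokens: int) -> int:
    """Blocks needed to cover num_tokens (ceiling division)."""
    return -(-num_tokens // BLOCK_TOKENS)

print(BLOCK_TOKENS, compressed_entries_per_block(CSA_M), compressed_entries_per_block(HCA_M))  # 128 32 1
print(blocks_needed(1000))  # 8
```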
Status (May 22, 2026 — 16:30 UTC)
| Stage | Status | Description |
|---|---|---|
| A | ✅ COMPLETE | Q@K^T via tcgen05.mma → TMEM → GMEM |
| B | ✅ COMPLETE | QK → identity softmax → P@V pipeline (TMEM alias, KV-tile interleaving) |
| C | ⚠️ MULTI-TILE IN PROGRESS | Single-tile cos 0.999998. TMA fix: n=256 cos 0.9956. Need O rescale + pipeline cycling. |
| C' | 🔨 IN PROGRESS | Multi-tile TMA indexing fix + correction warps. See below. |
| D | TODO | Full decode attention: paged KV cache, multi-query, causal mask |
| E | TODO | Production kernel: extract into dsv4/kernels/attention/, PyTorch custom op, vLLM bridge |
Package Structure
```
dsv4/
├── kernels/               Pure GPU code (CuTeDSL @cute.jit, .cu files)
│   ├── gemm/              NVFP4 MoE GEMM kernels (grouped, fused_swiglu, dense, scheduler)
│   ├── attention/         FMHA kernel (stub; extraction is Stage E)
│   ├── compressor/        CSA/HCA token-level compressor
│   ├── decode/            Decode-time attention (sparse, SWA; future)
│   └── cuda/              Raw .cu files (deinterleave_quantize, sparse_topk_metadata)
├── ops/                   PyTorch ↔ kernel bridges
│   ├── quantize.py        BF16 ↔ NVFP4 conversion, scale factors
│   ├── layouts.py         Scale swizzle, gate/up interleave, K-major, offsets
│   ├── gemm_runner.py     Warmup, compile, run grouped/fused GEMMs
│   ├── custom_ops.py      torch.library.custom_op registrations
│   ├── decode_sparse.py   native_sparse_decode dispatcher
│   ├── decode_swa.py      native_swa_decode dispatcher
│   ├── rope.py            Forward + inverse RoPE
│   └── topk.py            Python wrapper for sparse_topk_metadata.cu
├── layers/                nn.Module-style components
│   ├── linear.py          Nvfp4Linear
│   ├── grouped_linear.py  Nvfp4GroupedLinear
│   ├── moe.py             Nvfp4MoE
│   ├── shared_expert.py   Nvfp4SharedExpert
│   ├── mhc.py             mHCLayer
│   └── (stubs: attention, ffn, router, norm, embedding)
├── model/                 Model assembly (stubs, Phase 1)
├── cache/                 KV cache infra (stubs, Phase 3)
├── loader/                Checkpoint I/O (stubs, Phase 1)
└── reference/             Slow PyTorch oracles (never imported by production code)
    ├── attention.py       RoPE, KV cache, causal attention, SWA
    ├── csa_attention.py   CSA/HCA sparse attention
    ├── compressor.py      Compressor PyTorch example
    └── moe_pipeline.py    MoE pipeline reference
```
Mental model: kernels/ → ops/ → layers/ → model/ (dependency flows left to right). reference/ and loader/ are sidecars.
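One way to keep the dependency direction honest is a small import check. The sketch below is hypothetical tooling (not in the repo): it parses a file with Python's `ast` module and flags absolute `dsv4.*` imports that point rightward in kernels → ops → layers → model, or that pull `dsv4.reference` into production code. `forbidden_imports` and its arguments are illustrative names; relative imports are deliberately skipped.

```python
import ast
from typing import List

# Dependency flows left to right: a package may import only packages to its left.
LAYERS = ["kernels", "ops", "layers", "model"]

def forbidden_imports(pkg: str, source: str) -> List[str]:
    """Return absolute dsv4.* imports in `source` (a file under dsv4/<pkg>/) that
    break the layering rule or import dsv4.reference from production code."""
    rank = LAYERS.index(pkg) if pkg in LAYERS else None
    bad = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Import):
            names = [alias.name for alias in node.names]
        elif isinstance(node, ast.ImportFrom) and node.level == 0 and node.module:
            if node.module == "dsv4":  # `from dsv4 import ops` -> "dsv4.ops"
                names = [f"dsv4.{alias.name}" for alias in node.names]
            else:
                names = [node.module]
        else:
            continue  # relative imports and other statements
        for name in names:
            parts = name.split(".")
            if parts[0] != "dsv4" or len(parts) < 2:
                continue
            target = parts[1]
            if target == "reference" and pkg != "reference":
                bad.append(name)
            elif rank is not None and target in LAYERS and LAYERS.index(target) > rank:
                bad.append(name)
    return bad

print(forbidden_imports("kernels", "from dsv4.ops import quantize"))  # ['dsv4.ops']
print(forbidden_imports("layers", "import dsv4.ops.quantize"))        # []
```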
Active Test Files
FMHA (Stages A/B/C) — in tests/unit/
| File | Stage | Status |
|---|---|---|
| test_fmha_v3.py | A+B | ✅ Full QK→identity softmax→PV, cosine 0.999999 |
| test_fmha_v3_12w.py | A+B | ✅ 12-warp QK→PV, cosine 0.999999 |
| test_fmha_v3_stage_c_full.py | C | ✅ Real online softmax + O normalization, cosine 0.993-0.996 |
| test_fmha_v3_stage_c_min.py | C | 🔨 Early 12-warp pipeline (broken pipeline state) |
| test_pv64_with_softmax.py | B | ✅ (128,64) PV, single AB pipeline |
| test_128_128_vdiag.py | A+B | ✅ (128,128) PV baseline |
| test_qkonly.py | A | ✅ QK with split Q/KV pipelines |
| test_qk_softmax.py | A+B | ✅ QK + identity softmax, no PV |
MoE / GEMM — in tests/unit/
| File | What |
|---|---|
| test_cutedsl.py | NVFP4 grouped GEMM kernel |
| cudagraph_test.py | Cudagraph capture + replay |
| layertest.py | Per-layer correctness |
| test_custom_op.py | torch.library custom ops |
| test_compile_custom_op.py | Compile + warmup |
| test_fp4_roundtrip.py | BF16 → NVFP4 → BF16 roundtrip |
| test_interleave.py | Gate/up weight interleaving |
| test_interleave_gemm.py | Interleaved GEMM correctness |
| test_fused_step1.py | Fused SwiGLU GEMM |
Archived Tests
tests/archive/ contains ~190 debug files from Stages A/B. Not maintained. Can be deleted.
Test Harness
Scripts in tests/ for running tests on the B200 (root@45.76.247.107):
run_test.sh — Run a test in a screen session
```bash
# On the B200:
cd /root/dsv4-nvfp4-workspace/kernel
bash tests/run_test.sh tests/unit/test_fmha_v3.py
```
What it does:
- Kills any existing `kernel-test` screen and SIGKILLs all child processes (handles deadlocked GPU procs that ignore SIGHUP)
- Deletes the old log file
- Starts a new `screen -dmS kernel-test` running the test
- Logs output to `/tmp/kernel-test.log`
- Verifies the screen started
check_log.sh — Check test progress
```bash
bash tests/check_log.sh
```
Shows the log contents and whether the screen is still running.
Local → B200 workflow
```bash
# 1. Edit locally, commit, push
cd ~/dev/nvfp4-megamoe-kernel
git add -A && git commit -m "my change" && git push

# 2. SSH to B200, pull, run
ssh root@45.76.247.107
cd /root/dsv4-nvfp4-workspace/kernel && git pull
bash tests/run_test.sh tests/unit/test_fmha_v3_stage_c_full.py

# 3. Check results
bash tests/check_log.sh
```
fire_b200_test — One-command local test runner
Lives in ~/.openclaw/workspace/fire_b200_test (NOT in the repo — project-specific tooling).
```bash
# From your local machine, one command to push, run, and get results:
~/.openclaw/workspace/fire_b200_test tests/unit/test_fmha_v3.py
```
What it does:
- Auto-commits and pushes any local changes
- SSHes to the B200, pulls, and starts `run_test.sh` in a screen
- Polls every 15 s until the screen exits
- Dumps the full test log to your terminal
This is strictly for the DSV4 NVFP4 kernel project. It hardcodes the B200 IP, repo paths, and git remote.
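For reference, the poll-until-exit step can be sketched in Python. This is not the actual `fire_b200_test` implementation: it assumes the session is named `kernel-test` as in `run_test.sh`, runs `screen -ls` locally (the real tool does this on the B200 over SSH), and uses hypothetical helper names. The sleep and poll functions are injectable so the loop can be exercised without a real screen session.

```python
import re
import subprocess
import time
from typing import Callable

SESSION = "kernel-test"

def screen_running(ls_output: str, session: str = SESSION) -> bool:
    """True if `screen -ls` output lists a session named <pid>.<session>."""
    pattern = rf"^\s*\d+\.{re.escape(session)}\s"
    return re.search(pattern, ls_output, re.MULTILINE) is not None

def screen_ls() -> str:
    # `screen -ls` can exit non-zero (e.g. with no sessions), so no check=True;
    # merge stdout and stderr to be safe about where the listing goes.
    result = subprocess.run(["screen", "-ls"], capture_output=True, text=True)
    return result.stdout + result.stderr

def wait_for_screen(poll: Callable[[], str] = screen_ls,
                    interval: float = 15.0,
                    sleep: Callable[[float], None] = time.sleep) -> int:
    """Poll every `interval` seconds until the session is gone; return the poll count."""
    polls = 1
    while screen_running(poll()):
        sleep(interval)
        polls += 1
    return polls
```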
Stage C: Online Softmax — Multi-Tile In Progress
What We Have
Working real softmax for single KV tile (n=128): cosine 0.999998. Multi-tile TMA indexing fixed (n=256 cosine 0.9956) — was a layout bug, NOT a JIT bug. Remaining: O rescale between tiles, pipeline state cycling for n≥384, correction warps.
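For intuition about "pipeline state cycling", here is a minimal sketch of the usual index/phase bookkeeping for a ring of barrier-guarded stages (the scheme CUTLASS-style pipeline states follow, as we understand it). With 2 stages, n=256 uses stages 0 and 1 once each; n≥384 reuses stage 0, so the phase bit a waiter checks has to flip. `pipeline_states` is a hypothetical helper, not a CuTeDSL API, and real producer/consumer sides start from different phase conventions; this tracks one side.

```python
from typing import Iterator, Tuple

def pipeline_states(num_tiles: int, num_stages: int) -> Iterator[Tuple[int, int]]:
    """Yield (stage index, phase bit) for KV tiles 0 .. num_tiles-1.

    The stage index wraps modulo num_stages; the phase bit flips on each wrap,
    which lets a waiter tell the 3rd tile's arrival in stage 0 apart from the 1st's.
    """
    index, phase = 0, 0
    for _ in range(num_tiles):
        yield index, phase
        index += 1
        if index == num_stages:  # wrap: reuse stage 0 with the opposite phase
            index, phase = 0, phase ^ 1

print(list(pipeline_states(4, 2)))  # [(0, 0), (1, 0), (0, 1), (1, 1)]
```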
Multi-Tile TMA Fix (RESOLVED — was a LAYOUT bug, not a JIT bug)
After cpasync.tma_partition(), the output GMEM tensor has 8 TMA coordinate dimensions:
```
tBgK TMA coord space: (1, 1, 1, 1, n_kv_tiles, 1, 1, 1)
mode:                  0  1  2  3  4           5  6  7
```
Mode 4 is the GMEM tile dimension. Our old pre-slice tBgK[(None, None, 0, 0)] collapsed modes 4-7 to coordinate 0, so TMA always read tile 0. The bug looked like "JIT constant-folding" but was purely a layout error.
The fix: 8-None no-op pre-slice + 8-mode indexing with kt at mode 4:
```python
tBgK = tBgK[(None, None, None, None, None, None, None, None)]
cute.copy(tma_k, tBgK[None, None, None, None, kt, None, None, None], ...)
```
Results after fix:
- n=128: cos 0.999998 ✅
- n=256: cos 0.9956 ✅ (lower because no O rescale yet)
Remaining for Multi-Tile
- O rescale between tiles: `O *= exp2(old_max - new_max)`, needed for n=256+ to hit 0.9999
- Pipeline state cycling for n≥384 (3+ tiles with 2 pipeline stages)
- Correction warps for production (separate softmax/correction/epilogue)
- 12-warp layout
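The O rescale in the first item, sketched in plain Python for one query row (a reference model, not the TMEM code). It assumes scores are already scaled into the log2 domain (softmax_scale · log2 e) so `exp2` yields the softmax weights; note that the running row sum needs the same `exp2(old_max - new_max)` factor as O. Function names and toy data are illustrative.

```python
import math
from typing import List, Sequence

def exp2(x: float) -> float:
    return 2.0 ** x  # 2 ** -inf == 0.0, which handles the first tile

def online_softmax_pv(score_tiles: Sequence[Sequence[float]],
                      v_tiles: Sequence[Sequence[Sequence[float]]]) -> List[float]:
    """One query row, processed one KV tile at a time.

    score_tiles[t][j]: log2-domain score of key j in tile t.
    v_tiles[t][j]:     value vector of key j in tile t.
    """
    d = len(v_tiles[0][0])
    row_max, row_sum = -math.inf, 0.0
    o = [0.0] * d                        # unnormalized output accumulator
    for scores, values in zip(score_tiles, v_tiles):
        new_max = max(row_max, max(scores))
        alpha = exp2(row_max - new_max)  # rescale factor for earlier tiles
        p = [exp2(s - new_max) for s in scores]
        row_sum = row_sum * alpha + sum(p)
        o = [o[c] * alpha + sum(pj * v[c] for pj, v in zip(p, values)) for c in range(d)]
        row_max = new_max
    return [oc / row_sum for oc in o]    # final normalization

def full_softmax_pv(scores: Sequence[float], values: Sequence[Sequence[float]]) -> List[float]:
    """Reference: softmax over all keys at once."""
    m = max(scores)
    p = [exp2(s - m) for s in scores]
    total = sum(p)
    return [sum(pj * v[c] for pj, v in zip(p, values)) / total for c in range(len(values[0]))]

# Three KV tiles of two keys each (toy values), head dim 2.
score_tiles = [[0.5, 2.0], [3.0, -1.0], [1.0, 1.5]]
v_tiles = [[[1.0, 0.0], [0.0, 1.0]], [[2.0, 2.0], [-1.0, 0.5]], [[0.5, -0.5], [1.0, 1.0]]]
tiled = online_softmax_pv(score_tiles, v_tiles)
ref = full_softmax_pv([s for t in score_tiles for s in t], [v for t in v_tiles for v in t])
print(max(abs(a - b) for a, b in zip(tiled, ref)) < 1e-12)  # True
```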
Files
| File | Status | Notes |
|---|---|---|
| fmha_v3_stage_c_example10.py | 🔨 CURRENT | 8-mode TMA, combined K+V pipeline, O rescale, final normalize |
| test_fmha_v3_stage_c_full.py | OK n=128 | Working real softmax + O normalization |
| fmha_v3_stage_c_example1.py | BROKEN multi-tile | First fix attempt, TMA still loads tile 0 |
| fmha_v3_stage_c_example2.py | DEADLOCK | Combined K+V barrier, compiles but deadlocks |
| test_fmha_v3_stage_c2.py | DEADLOCK | 12-warp pipeline, compiles but deadlocks |
| test_fmha_v3_12w.py | OK n=128 only | Identity softmax baseline |
Current Architecture (6-warp)
- Warps 0-3: Softmax + Epilogue
- Warp 4: MMA (QK, PV)
- Warp 5: TMA (Q/K/V load)
Target Architecture (12-warp, production)
Warps 0-3: Softmax, Warps 4-7: Correction, Warp 8: MMA, Warp 9: TMA, Warp 10: Epilogue, Warp 11: Empty
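Both layouts dispatch on the warp index. A small table-driven sketch in plain Python (hypothetical names, assuming a 1-D thread index; in the kernel this would be a branch on the warp index) makes the implied thread counts explicit: 6 warps = 192 threads, 12 warps = 384.

```python
from typing import Dict, Tuple

WARP_SIZE = 32

# Warp-role tables transcribed from the two layouts above.
ROLES_6W: Dict[int, str] = {**{w: "softmax+epilogue" for w in range(4)}, 4: "mma", 5: "tma"}
ROLES_12W: Dict[int, str] = {**{w: "softmax" for w in range(4)},
                             **{w: "correction" for w in range(4, 8)},
                             8: "mma", 9: "tma", 10: "epilogue", 11: "empty"}

def warp_role(thread_idx: int, roles: Dict[int, str]) -> Tuple[int, str]:
    """Map a CTA thread index to (warp index, role)."""
    warp = thread_idx // WARP_SIZE
    return warp, roles[warp]

print(len(ROLES_6W) * WARP_SIZE, len(ROLES_12W) * WARP_SIZE)  # 192 384
print(warp_role(130, ROLES_6W), warp_role(300, ROLES_12W))    # (4, 'mma') (9, 'tma')
```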
CuTeDSL Constraints (hard-won)
- `vectorize=True` loops: ONLY load/store/print.
- `.reduce(cute.ReductionOp.MAX)`: reduces the ENTIRE C-fragment to a scalar (a global max, not per-row).
- `cute.arch.fmax`: impure for the vectorizer; use a plain `range()` loop.
- `tBgK`/`tVgV` have 8 TMA coord dims after `tma_partition`: the 8-None no-op pre-slice is required, and mode 4 is the GMEM tile dim.
- `tBgK[(None, 0, None, 0)]` hardcodes GMEM iteration to tile 0.
- The `softmax_done_bar` NamedBarrier is reusable across tiles.
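Why a global max is a problem: softmax is shift-invariant, so subtracting any constant is mathematically fine, but a row whose scores sit far below the global max underflows to all zeros and its sum becomes 0. This pure-Python sketch computes the max one row at a time (as a plain loop would) and compares it with a single global max. It runs in double precision, so the gap has to be huge; FP32 underflows far sooner (its smallest subnormal is 2^-149), so on the GPU much smaller gaps are enough. `softmax_rows` is an illustrative name.

```python
import math
from typing import List, Sequence

def softmax_rows(rows: Sequence[Sequence[float]], per_row_max: bool) -> List[List[float]]:
    """Base-2 softmax of each row, subtracting either the row's own max
    or one global max over the whole fragment."""
    global_max = max(max(r) for r in rows)
    out = []
    for r in rows:
        m = max(r) if per_row_max else global_max
        p = [2.0 ** (x - m) for x in r]
        s = sum(p)
        # If every term underflowed, the row is unrecoverable: report NaN.
        out.append([pi / s for pi in p] if s > 0.0 else [math.nan] * len(r))
    return out

rows = [[0.0, -1.0], [-2000.0, -2001.0]]  # second row sits far below the global max
per_row = softmax_rows(rows, per_row_max=True)    # both rows: [2/3, 1/3]
glob = softmax_rows(rows, per_row_max=False)      # second row: [nan, nan], 2.0 ** -2000 underflows
print(per_row, glob)
```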
Remaining for C' (Production Stage C)
- Fix multi-tile TMA — combined K+V barrier or kh.count // 2
- Fix runtime deadlock in example2 (acc_pipe + final_o_bar sync)
- Cross-warp reduction for row_max and row_sum
- Correction warps for multi-tile KV (online O rescale in TMEM)
- 12-warp layout with separate softmax/correction/epilogue warps
TMEM Layout
Col 0-127: S (QK acc, 128 FP32) | Col 32-95: P (64 FP32) | Col 128+: O (PV acc, 64 FP32)
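The column plan can be checked with a few lines of plain Python. This sketch assumes O occupies exactly 64 columns starting at 128 and applies the document's power-of-2 allocation rule; the 32-column minimum comes from the PTX `tcgen05.alloc` requirements as we understand them. Names are illustrative.

```python
# TMEM column ranges from the layout above (half-open [start, stop)).
S = range(0, 128)    # QK accumulator, 128 FP32 columns
P = range(32, 96)    # P, aliased inside S
O = range(128, 192)  # PV accumulator, 64 FP32 columns ("Col 128+")

def overlaps(a: range, b: range) -> bool:
    return a.start < b.stop and b.start < a.stop

def tmem_alloc_columns(needed: int) -> int:
    """Round up to a power of two, with a floor of 32 columns."""
    cols = 32
    while cols < needed:
        cols *= 2
    return cols

needed = max(r.stop for r in (S, P, O))
print(needed, tmem_alloc_columns(needed))       # 192 256
print(P.start >= S.start and P.stop <= S.stop)  # True: P lives inside S
print(overlaps(S, O))                           # False: O never clobbers S
```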
Key Lessons
- NEVER use `find_tmem_tensor_col_offset()` as TMEM placement. It returns the footprint size, not a safe offset.
- FMHA never trusts DLPack tensor layouts. Reconstruct V as (hd, s_k) MN-major inside CuTe.
- TMEM allocation must be a power of 2.
- Square hides bugs. (128,128) worked for every wrong approach. Always test non-square.
- `St32x32bOp` MUST use Float32, NOT BFloat16. BFloat16 causes an illegal memory access.
- The first PV MMA must use ACCUMULATE=False. Otherwise it adds uninitialized TMEM to the output.
- The FMHA P store uses the QK C-fragment composition, NOT the PV A-fragment. Two aliases, same TMEM.
- Register bridge: FP32 backing (store partition) + BF16 view (QK-load layout). Do not skip this.
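A minimal illustration of "square hides bugs" (toy code, not from the kernel): using the wrong extent as a row stride is a no-op when rows == cols, so a square test passes, and only a non-square shape exposes the bug. The function names are illustrative.

```python
from typing import List, Sequence

def load_2d(buf: Sequence[int], rows: int, cols: int, ld: int) -> List[List[int]]:
    """Read a row-major rows x cols tile whose leading dimension (row stride) is ld."""
    return [[buf[i * ld + j] for j in range(cols)] for i in range(rows)]

def correct(buf: Sequence[int], rows: int, cols: int) -> List[List[int]]:
    return load_2d(buf, rows, cols, ld=cols)  # row stride = number of columns

def buggy(buf: Sequence[int], rows: int, cols: int) -> List[List[int]]:
    return load_2d(buf, rows, cols, ld=rows)  # bug: the wrong extent used as stride

sq = list(range(9))
print(correct(sq, 3, 3) == buggy(sq, 3, 3))    # True: the square shape hides the bug
rect = list(range(6))
print(correct(rect, 2, 3), buggy(rect, 2, 3))  # [[0, 1, 2], [3, 4, 5]] [[0, 1, 2], [2, 3, 4]]
```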
Environment
- Server: `root@45.76.247.107` (B200, 180 GiB HBM3e per GPU)
- venv: `source /root/dsv4-nvfp4-workspace/venv/bin/activate`
- PYTHONPATH: `/root/dsv4-nvfp4-workspace/kernel`
- Model: `/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4`
- vLLM repo: `/root/dsv4-nvfp4-workspace/vllm` (modified for Blackwell)
- CUTLASS FMHA reference: `/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py`