DSV4 Inference Kernel
Production-grade Blackwell SM100 inference kernel for DeepSeek-V4-Pro NVFP4, written in CuTeDSL with a CUDA fallback path. Target hardware: NVIDIA B200 (180 GiB HBM3e).
This file is the durable reference — architecture, design choices, package layout, workflow, and hard-won lessons. If you're touching the kernel, read the "Lessons learned" section every time.
DSV4 is not MLA
This cannot be repeated enough. vLLM and some integrations misname DSV4's attention as MLA. It is fundamentally a different architecture. If you reason about this kernel as MLA + extras, you will make wrong decisions.
The differences that matter:
| | MLA (V2/V3) | V4 |
|---|---|---|
| Compression axis | feature/head dim (per-token latent) | sequence dim (multiple tokens collapsed into one entry) |
| Cache entries per token | one latent per token | one compressed entry per m tokens |
| Attention pattern | dense over all cached latents | hybrid: sparse top-k (CSA) + dense over heavily-compressed (HCA) + sliding window (SWA) |
| Compression rate | n/a (1:1) | m=4 for CSA, m'=128 for HCA |
| Selection | none — all tokens attended | lightning indexer + top-k for CSA |
| Output positional fix | n/a | inverse RoPE on each per-head output |
| Sink merge | n/a | per-head learnable attention sink merged via single softmax over [S_comp, S_swa + sink] |
Cache layout reflects this: per-layer state cache for SWA window + uncompressed tail (used for CSA/HCA compression), plus a classical paged cache holding compressed CSA/HCA entries, with block size = lcm(m, m') = 128 original tokens per block.
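The block-size arithmetic can be sketched as below. This is a minimal illustration only: `entries_per_block` is a hypothetical helper, not a function in the package, and the per-block entry counts are derived from the "one compressed entry per m tokens" rule in the table above.

```python
import math

CSA_M = 4      # CSA compression rate m
HCA_M = 128    # HCA compression rate m'

def entries_per_block(m_csa: int = CSA_M, m_hca: int = HCA_M) -> dict:
    """Hypothetical helper: paged-cache block size in original tokens and
    the number of compressed entries each branch holds per block."""
    block_tokens = math.lcm(m_csa, m_hca)   # requires Python 3.9+
    return {
        "block_tokens": block_tokens,
        "csa_entries": block_tokens // m_csa,
        "hca_entries": block_tokens // m_hca,
    }

layout = entries_per_block()
print(layout)
```

Under these assumptions one block of 128 original tokens holds 32 CSA entries and a single HCA entry.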
DSV4 architecture (paper-side reference)
The bits the kernel implements, with the choices we made for inference.
Per-layer attention type schedule
Flash (43 layers): layers 0-1 = SWA, layers 2..42 alternating CSA/HCA (CSA at layer 2)
Pro (61 layers): layers 0-1 = HCA, layers 2..60 alternating CSA/HCA (CSA at layer 2)
Frozen at construction time per LayerSpec so torch.compile constant-folds the dispatch. Validation in dsv4/model/layer_schedule.py:validate_schedule is loud — wrong schedule = silent garbage.
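A minimal sketch of the schedule described above, for checking layer counts by hand. `build_schedule` and `VARIANTS` are hypothetical names for illustration; the real logic lives in dsv4/model/layer_schedule.py and may be structured differently.

```python
from collections import Counter

# Layer count and the attention type of layers 0-1, per the schedule above.
VARIANTS = {"flash": (43, "SWA"), "pro": (61, "HCA")}

def build_schedule(variant: str) -> tuple:
    """Hypothetical helper: attention type for every layer of a variant."""
    n_layers, head_type = VARIANTS[variant]
    schedule = [head_type, head_type]            # layers 0-1
    for layer in range(2, n_layers):
        # Alternate CSA/HCA from layer 2 onward, with CSA on layer 2.
        schedule.append("CSA" if layer % 2 == 0 else "HCA")
    return tuple(schedule)                       # immutable once built

flash = build_schedule("flash")
pro = build_schedule("pro")
print(Counter(flash))
print(Counter(pro))
```

Returning a tuple mirrors the "frozen at construction time" idea: nothing downstream can mutate the dispatch table.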
Compressed Sparse Attention (CSA)
- Compresses every `m=4` KV entries into one via a token-level learned softmax with overlapping window (current m + previous m). See eq. 11–12 of the paper.
- Compressed sequence length is `n/m`.
- Lightning indexer scores each query against compressed blocks via weighted ReLU MQA logits (eq. 16). Top-k selector keeps `csa_top_k` blocks (512 Flash / 1024 Pro).
- Core attention is MQA over the selected blocks + a sliding window branch of `n_win=128` raw tokens.
- Partial RoPE on the last 64 dims of Q and the compressed K, with inverse RoPE on each per-head output so the per-token contribution carries the correct relative position.
- Per-head attention sink: learnable logit added to the softmax denominator (eq. 27). We merge sparse + SWA via the sink-bias-as-logit trick — see "Sink merge" below.
Heavily Compressed Attention (HCA)
- Same compressor concept as CSA but `m'=128`, no overlap, dense attention over the (very short) compressed sequence.
- No indexer.
- Same partial RoPE + inverse RoPE + sliding window + sink as CSA.
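To compare how many KV entries a single query attends to in each layer type, here is a rough sketch. It is an approximation under stated assumptions: `attended_entries` is a hypothetical helper, it ignores the uncompressed tail and the sink, and it uses floor division for the compressed length.

```python
def attended_entries(n_tokens: int, layer_type: str,
                     csa_top_k: int = 1024, n_win: int = 128) -> int:
    """Hypothetical helper: approximate KV entries attended per query."""
    window = min(n_win, n_tokens)                 # sliding-window branch
    if layer_type == "CSA":
        compressed = n_tokens // 4                # m = 4
        return min(csa_top_k, compressed) + window
    if layer_type == "HCA":
        return n_tokens // 128 + window           # m' = 128, dense
    if layer_type == "SWA":
        return window
    raise ValueError(f"unknown layer type: {layer_type}")

for kind in ("CSA", "HCA", "SWA"):
    print(kind, attended_entries(65536, kind))
```

At 65,536 tokens with the Pro top-k this gives 1152 entries for CSA, 640 for HCA, and 128 for SWA.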
Sliding Window Attention (SWA)
- First two layers of Flash. Pure local attention over the SWA window. No compressed branch, no indexer.
- Cache layout: ring buffer of size `n_win` per request in the state cache.
Manifold-Constrained Hyper-Connections (mHC)
- Replaces residual connections. Width-expanded residual stream `(T, n_hc=4, d)`.
- Per-token dynamic `A_l`, `B_l`, `C_l` mixing matrices generated by a fused 24-output prenorm projection (4 + 4² + 4).
- `A_l = σ(.)`, `C_l = 2σ(.)`, `B_l = SinkhornKnopp(exp(.), t_max=20)` to project onto the Birkhoff polytope.
- `pre_block`: `x_in = A_l @ X_l`; `post_block`: `X_next = B_l @ X_l + C_l ⊗ F_out`.
- `B_l` held in FP32 for the bmm precision; A/C cast to BF16.
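The Sinkhorn-Knopp step can be sketched in plain Python as below. This is an illustration, not the mHCLayer code: the function name and the row-then-column normalization order are assumptions, and the production path runs batched on the GPU.

```python
import math

def sinkhorn_knopp(logits, t_max=20):
    """Alternately normalize rows and columns of exp(logits) so the result
    approaches a doubly stochastic matrix (a point in the Birkhoff polytope)."""
    n = len(logits)
    mat = [[math.exp(v) for v in row] for row in logits]
    for _ in range(t_max):
        for row in mat:                              # row normalization
            s = sum(row)
            for j in range(n):
                row[j] /= s
        for j in range(n):                           # column normalization
            s = sum(mat[i][j] for i in range(n))
            for i in range(n):
                mat[i][j] /= s
    return mat

logits = [[0.1, -0.3, 0.5, 0.0],
          [-0.2, 0.4, 0.0, 0.3],
          [0.5, 0.1, -0.4, 0.2],
          [0.0, -0.1, 0.3, -0.5]]
b = sinkhorn_knopp(logits)
print([round(sum(row), 6) for row in b])
```

For well-conditioned 4×4 inputs like this one, 20 iterations leave both row and column sums within floating-point noise of 1.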
Router
- Two modes, frozen at construction by layer index:
  - Hash routing (layers 0–2): deterministic per-token-ID LUT lookup, uniform weights `1/k`.
  - Dense routing (layers 3+): `sqrt(softplus(X @ W_gate))` activation, plus learned `e_bias` for selection only. Top-k (k=6), renormalize on unbiased activations, multiply by `routed_scaling_factor`.
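The dense-routing steps can be sketched per token as follows. This is a scalar illustration under stated assumptions: `dense_route` is a hypothetical name, the input is the already-computed `X @ W_gate` row, and tie-breaking follows Python's stable sort rather than whatever the kernel does.

```python
import math

def dense_route(gate_logits, e_bias, k=6, routed_scaling_factor=1.0):
    """Hypothetical per-token dense routing: returns (expert ids, weights)."""
    # sqrt(softplus(x)); log1p(exp(x)) is adequate for moderate |x|.
    act = [math.sqrt(math.log1p(math.exp(x))) for x in gate_logits]
    # The bias influences which experts are selected, nothing else.
    biased = [a + b for a, b in zip(act, e_bias)]
    top = sorted(range(len(act)), key=lambda i: biased[i], reverse=True)[:k]
    # Weights use the unbiased activations, renormalized over the selection.
    denom = sum(act[i] for i in top)
    weights = [routed_scaling_factor * act[i] / denom for i in top]
    return top, weights

top, weights = dense_route([0.0, 1.0, 2.0, -1.0], [0.0, 0.0, 0.0, 5.0],
                           k=2, routed_scaling_factor=2.5)
print(top, weights)
```

In the example, expert 3 is selected only because of its bias, yet it receives the smaller weight because the weight is computed from its unbiased activation.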
MoE
- DeepSeekMoE: shared expert + N routed experts (Flash 256, Pro 384), 6 activated per token.
- L1 GEMM (gate + up interleaved at granularity 8) → SwiGLU → L2 GEMM (down).
- SwiGLU clamping per paper §4.2.3: gate capped at `swiglu_limit=10`, linear clamped to `[-limit, +limit]`.
- All weights NVFP4, FP8 E4M3 scales, 16-element microblocks.
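A scalar sketch of the clamped SwiGLU described above. It assumes the standard form `silu(gate) * up` and that the clamps are applied before the activation; `clamped_swiglu` is a hypothetical name, and the fused kernel operates on interleaved gate/up tiles rather than scalars.

```python
import math

SWIGLU_LIMIT = 10.0

def clamped_swiglu(gate: float, up: float, limit: float = SWIGLU_LIMIT) -> float:
    """Hypothetical scalar reference: cap the gate, clamp the linear branch."""
    gate = min(gate, limit)                  # gate: upper cap only
    up = max(-limit, min(up, limit))         # linear branch: symmetric clamp
    # Numerically stable sigmoid, so very negative gates do not overflow exp().
    if gate >= 0:
        sig = 1.0 / (1.0 + math.exp(-gate))
    else:
        e = math.exp(gate)
        sig = e / (1.0 + e)
    return gate * sig * up

print(clamped_swiglu(100.0, 100.0), clamped_swiglu(10.0, 10.0))
```

Both calls print the same value, since inputs beyond the limit are clamped to it first.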
Sink merge (D5c — key insight)
The paper writes the sink merge as a weighted combination of two separate softmax outputs. But because the sink is just an additive logit bias on one branch, the whole thing collapses to a single softmax over [S_comp, S_swa + attn_sink].
One pass, one kernel. No two-loop epilogue, no LSE arithmetic in the merge. This is why D5d (fused merge epilogue) is not needed.
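The single-softmax form can be sketched for one head and one query as below. This follows the description in this section rather than the paper's equations; `sink_merged_attention` is a hypothetical name and values are scalars to keep the example short.

```python
import math

def sink_merged_attention(s_comp, s_swa, v_comp, v_swa, attn_sink):
    """One softmax over [S_comp, S_swa + attn_sink], then a weighted sum."""
    logits = list(s_comp) + [s + attn_sink for s in s_swa]
    values = list(v_comp) + list(v_swa)
    m = max(logits)                               # subtract max for stability
    w = [math.exp(x - m) for x in logits]
    return sum(wi * vi for wi, vi in zip(w, values)) / sum(w)

# One compressed entry (value 1.0) and one SWA entry (value 3.0), equal logits.
no_sink = sink_merged_attention([0.0], [0.0], [1.0], [3.0], attn_sink=0.0)
with_sink = sink_merged_attention([0.0], [0.0], [1.0], [3.0],
                                  attn_sink=math.log(3.0))
print(no_sink, with_sink)
```

With a zero sink the two entries are weighted equally (output 2.0); a sink of log 3 makes the SWA entry three times as heavy (output 2.5).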
Package structure
dsv4/
├── kernels/ Pure GPU code
│ ├── attention/ Production FMHA — 6-warp TMA multi-tile (.cuh + C-API .cu + op.py + production.py)
│ │ production.py is the entry point used by single_shot_inference.py
│ ├── gemm/ NVFP4 MoE GEMM (grouped, fused_swiglu, dense, scheduler)
│ ├── compressor/ CSA/HCA production compressor (production_compress.py → compressor_reduce.cu)
│ ├── indexer/ CSA indexer (stub; live path is inline in single_shot_inference.py)
│ ├── router/ Dense router decode + activation_topk
│ ├── cuda/ Raw .cu kernels (loader.py compiles on demand)
│ └── cache/ (stub; SWA/flush kernels are in cuda/)
├── ops/ PyTorch ↔ kernel bridges
│ ├── quantize.py BF16 ↔ NVFP4, scale factor handling, QuantizedActivation
│ ├── layouts.py Scale swizzle, gate/up interleave, K-major, offsets
│ ├── gemm_runner.py Warmup, compile, run grouped/fused GEMMs
│ ├── custom_ops.py torch.library.custom_op registrations
│ ├── rope_cuda.py Forward + inverse RoPE (partial, last 64 dims)
│ └── router.py Router op bridge (dense + hash dispatch)
├── layers/ nn.Module-style components (used by single_shot_inference.py)
│ ├── linear.py Nvfp4Linear
│ ├── grouped_linear.py Nvfp4GroupedLinear (output projection)
│ ├── moe.py Nvfp4MoE (routed experts)
│ ├── shared_expert.py Nvfp4SharedExpert
│ ├── mhc.py mHCLayer (Sinkhorn-Knopp, residual mixing)
│ └── router.py Router (dense + hash modes)
├── model/
│ ├── config.py DSV4Config
│ └── sampler.py CUDASampler
├── reference/
│ └── single_shot_PYTORCH_REFERENCE.py PyTorch oracle for layer comparison tests
└── _archive/ Dead Lineage P code (model/dsv4.py, cache/*, layers/{attention,ffn,norm,embedding}, etc.)
Kept for reference; never imported by live code
Live path: single_shot_inference.py → dsv4/layers/* → dsv4/ops/* → dsv4/kernels/**
Attention path: production.py → fmha_multitile_op.py → fmha_multitile_capi.cu → fmha_6warp_tma_multirow_multitile.cuh
Archived (Lineage P): dsv4/model/dsv4.py, dsv4/cache/*, dsv4/layers/{attention,ffn,norm,embedding} — these were the vLLM/sglang integration surface but have 0 importers. See _archive/ if needed.
Workflow & test harness
The non-negotiables
- NEVER edit on the B200. Always: edit locally → commit → push → pull on B200 → test.
- NEVER raw SSH + direct command. Always use the test harness scripts. They handle: killing hung processes, deleting stale logs, screen sessions that survive SSH drops, timeouts for hung kernels, and GPU cleanup.
- ALWAYS verify hd=64 regression (cos ~0.999998) after every FMHA change. If it regresses, the change is wrong. Revert.
- NEVER touch drivers, kernels, firmware, or system packages on the B200.
- NEVER delete test files in `tests/unit/` without explicit approval.
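The hd=64 regression check boils down to a cosine similarity against a reference output. A minimal sketch, assuming flattened Python lists and a hypothetical pass threshold of 0.99999 (the document only states that the expected value is about 0.999998):

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity between two equal-length flat sequences."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def passes_hd64_gate(kernel_out, reference_out, min_cos=0.99999):
    """Hypothetical gate; min_cos is an assumed threshold, not a repo constant."""
    return cosine_similarity(kernel_out, reference_out) >= min_cos

print(passes_hd64_gate([1.0, 2.0, 3.0], [1.0, 2.0, 3.0]))
print(passes_hd64_gate([1.0, 0.0], [1.0, 0.1]))
```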
Two harnesses: Python and CUDA
| Harness | For | Script | Screen name | Log file |
|---|---|---|---|---|
| Python | test_*.py files | `fire_b200_test` | `kernel-test` | `/tmp/kernel-test.log` |
| CUDA | test_*.cu files | `fire_b200_cuda_test` | `cuda-test` | `/tmp/cuda-test.log` |
Both harnesses follow the same discipline:
- Kill everything first — old screen sessions, hanging GPU processes, stale binaries
- Delete all logs — never debug from a previous run's log
- Clean git + pull — no uncommitted B200 state
- Run in screen — survives SSH drops, has a timeout
- One test at a time — no parallel launches, ever
Python test
# From local machine — auto-pushes, runs, polls, dumps log
# DEFAULT timeout: 600s (10 min). Override with all 4 args:
~/.openclaw/workspace/fire_b200_test <test_file> [screen_name] [log_file] [timeout_sec]
# Examples:
~/.openclaw/workspace/fire_b200_test tests/unit/test_fmha_v3_stage_c.py
~/.openclaw/workspace/fire_b200_test tests/unit/test_degeneration_2_mhc_falsify.py kernel-test /tmp/kernel-test.log 1800
CUDA test
# From local machine — compiles with nvcc, runs, polls, dumps log
~/.openclaw/workspace/fire_b200_cuda_test tests/unit/test_fmha_sm100_standalone.cu
~/.openclaw/workspace/fire_b200_cuda_test tests/unit/test_tmem_minimal.cu 30 # custom timeout
Check on a running test
# Check CUDA test log + screen status
~/.openclaw/workspace/check_b200_cuda
~/.openclaw/workspace/check_b200_cuda kill # kill a hung test
# Check Python test — SSH to B200 and tail the log:
ssh root@<B200> tail -f /tmp/kernel-test.log
Manual B200 cycle (emergency only)
ssh root@<B200>
cd /root/dsv4-nvfp4-workspace/kernel && git pull
bash tests/run_test.sh tests/unit/test_<...>.py
bash tests/check_log.sh
⚠️ Test harness gotchas (READ THIS — cost real time)
- The timeout is the 4th argument, not the 2nd.
  - WRONG: `fire_b200_test test.py 1800` ← this makes `1800` the SCREEN NAME
  - RIGHT: `fire_b200_test test.py kernel-test /tmp/kernel-test.log 1800`
  - When you pass just a number as the 2nd arg, the screen gets a numeric name and the harness can't kill the old `kernel-test` screen on the next run.
  - Always pass all 4 args when you need a custom timeout.
- After a timeout, the harness kills the screen but NOT the GPU process.
  - The `timeout` command inside screen kills the shell, but CUDA processes survive.
  - Before re-running, check: `ssh root@<B200> nvidia-smi --query-compute-apps=pid --format=csv,noheader`
  - Kill stale processes: `kill -9 <pid>` for each GPU process listed
  - Or: `for pid in $(nvidia-smi --query-compute-apps=pid --format=csv,noheader); do kill -9 $pid; done`
- After an OOM or crash, stale GPU processes WILL be left behind.
  - Always check `nvidia-smi` before running a new test after a failure.
  - The harness kills `python.*test_` and `python.*inference` procs, but if the process name doesn't match the pattern, it survives.
- Single-shot tests MUST use the harness too. `single_shot_inference.py` is NOT a unit test, but it MUST be run via the harness.
  - WRONG: ssh to B200 and run `python single_shot_inference.py` directly
  - RIGHT: `fire_b200_test single_shot_inference.py kernel-test /tmp/kernel-test.log 1800 -- --max-tokens 512`
  - Extra args after `--` are passed to the Python script.
  - If the harness can't handle your use case, FIX THE HARNESS, don't bypass it.
- Weight loading + CuTeDSL compilation takes 5-10 minutes.
  - First FMHA call triggers JIT compile of CuTeDSL kernels.
  - This is EXPECTED. Do NOT kill the process because it "seems stuck".
  - Use 1800s (30 min) timeout for full-model tests.
- The screen name must match between runs.
  - The harness kills the old screen by name. If you used a different name last time, the old screen survives and holds GPU memory.
  - Always use `kernel-test` for Python tests and `cuda-test` for CUDA tests.
  - If you accidentally used a numeric screen name, clean up manually: `ssh root@<B200> screen -S <wrong_name> -X quit`
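The stale-process check above can also be scripted. The sketch below wraps the same `nvidia-smi` query quoted in the gotchas; `parse_compute_pids` and `list_gpu_pids` are hypothetical helpers, not part of the harness, and the script must run on the B200 itself.

```python
import subprocess

def parse_compute_pids(csv_text: str) -> list:
    """Parse the output of the pid query: one PID per line, no header."""
    return [int(line) for line in csv_text.splitlines() if line.strip().isdigit()]

def list_gpu_pids() -> list:
    """Hypothetical helper: PIDs that currently hold a compute context."""
    result = subprocess.run(
        ["nvidia-smi", "--query-compute-apps=pid", "--format=csv,noheader"],
        capture_output=True, text=True, check=True,
    )
    return parse_compute_pids(result.stdout)

# Parsing can be checked without a GPU:
print(parse_compute_pids("1234\n 5678 \n\n"))
```

An empty list from `list_gpu_pids()` means it is safe to start the next test.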
Environment
- B200 access: see `MEMORY.md` (not committed).
- venv: `source /root/dsv4-nvfp4-workspace/venv/bin/activate`
- PYTHONPATH: `/root/dsv4-nvfp4-workspace/kernel`
- Model: `/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4`
- vLLM (modified for Blackwell): `/root/dsv4-nvfp4-workspace/vllm`
- CUTLASS FMHA reference: `/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py`
- Local CUTLASS clone: `/home/openclaw/dev/cutlass`
CuTeDSL constraints (read every session)
These are surface-level traps. Get them wrong and the kernel silently produces garbage or NaN, or fails at JIT compile time with a "weakly congruent" error.
1. TMA partition tensors have 4 modes: `(((64,128),1), ?, KV_tiles, ?)`. `(None, 0, None, 0)` keeps mode 2 (KV tiles) free; `[None, kt]` indexes it. `(None, None, 0, 0)` silently pins mode 2 to 0, so multi-tile loads break invisibly.
2. `vectorize=True` loops accept only load/store/print. No `fmax`, no `cmpf`, no inner loops, no carry across iterations.
3. `.reduce(cute.ReductionOp.MAX)` reduces the entire C-fragment to a scalar (global, not per-row). Use a plain `range()` loop with `cute.arch.fmax` for per-row max.
4. `cute.arch.fmax` is impure for the vectorizer. Use it inside plain `range()`, never inside `vectorize=True`.
5. Hand-constructed TMEM atoms corrupt data on round-trip. Independently-built `Ld32x32bOp` + `St32x32bOp` atoms have addressing that doesn't match; even a NO-OP round-trip drops cos to ~0.97. Use paired atoms from `epilogue_tmem_copy_and_partition` / `epilogue_smem_copy_and_partition` for one-way trips.
6. CuTeDSL `if` blocks are separate MLIR regions. Variables defined inside one `if` are not visible in another, even when the condition is a compile-time constant. Define all variables unconditionally before any branching.
7. Guard dead code with `const_expr`. CuTeDSL compiles both branches of Python `if`. At hd=64, the SMEM-P or O-rescale code generates IR you don't need; without `const_expr`, MLIR chews on it.
8. `tma_partition` and `flat_divide` may not survive inside `if warp_idx` blocks. Construct partitioned tensors before warp branching, or in a regular Python helper function. (The MoE kernel calls `tma_partition` inside the epilogue warp's `if`, so this constraint may depend on context; print and verify.)
9. TMEM allocation must be a power of 2. Round up after summing column requirements.
10. `composition` vs `logical_divide` produce different layouts even when re-tiling the same tensor. `correction_rescale` uses `composition`, `correction_epilog` uses `logical_divide`. Copy atoms must match the tensor layout they were created with.
11. After every P store to TMEM, call `cute.arch.fence_view_async_tmem_store()`. Missing this produces NaN.
12. `St32x32bOp` must use Float32, not BFloat16. BFloat16 causes illegal memory access.
13. First PV must have `ACCUMULATE=False`. Otherwise adds uninitialized TMEM contents to the output.
14. `find_tmem_tensor_col_offset()` returns footprint size, not a safe offset. Never use it as a TMEM placement.
15. FMHA never trusts DLPack tensor layouts. Reconstruct V as `(hd, s_k)` MN-major inside CuTe via explicit `make_tensor` + `make_layout`.
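The power-of-2 rule for TMEM allocation can be sketched as a small host-side helper. `tmem_alloc_columns` is a hypothetical name and the column counts in the example are made up for illustration; any hardware minimum or maximum on the allocation is deliberately not modeled here.

```python
def tmem_alloc_columns(column_requirements) -> int:
    """Sum the per-tensor column needs, then round up to a power of 2."""
    total = sum(column_requirements)
    if total <= 0:
        raise ValueError("need at least one TMEM column")
    # (total - 1).bit_length() gives the exponent of the next power of 2.
    return 1 << (total - 1).bit_length()

# Illustrative numbers only: three accumulators needing 128 + 64 + 32 columns.
print(tmem_alloc_columns([128, 64, 32]))
```

The example sums to 224 columns, so the allocation rounds up to 256.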
Lessons learned (the gold — read every session)
These cost real days to learn. They are listed in order of how easy they are to repeat.
Layout & TMA
- TMA partition mode ordering (the bug that ate a whole day): see CuTeDSL constraint #1 above. The wrong slice produces "reasonable" wrong outputs — cos 0.7–0.9, never NaN — so you can ship it without knowing.
- Square hides bugs. (128,128) worked for every wrong approach to PV. Always test non-square shapes early.
- Print the shapes always. Reasoning about TMEM layouts or TMA mode counts without running `cute.printf(cute.shape(t))` inside `@cute.kernel` is how every multi-day debug starts. Shapes are ground truth.
- `qk_mma_tiler` K-dim must equal `head_dim`, not the MMA instruction's K sub-tile size. Hardcoding `qk_ik * 4 = 64` was the root cause of the hd>64 failure; the QK GEMM only computed half the dot product. Fix was one line; cos went from 0.78 to 0.999997 at hd=128.
TMEM
- Never assume TMEM round-trips are safe. Verify with a NO-OP test (load → store unchanged) before adding any logic. The hand-constructed atoms produce ~3% error even on NO-OP.
- FMHA P store uses QK C-fragment composition, not PV A-fragment. Two aliases of the same TMEM region. Mixing them up gives valid-looking garbage.
- Register bridge for P: FP32 backing (store partition) + BF16 view (QK-load layout). Do not skip the dual view.
- TMEM round-trip mismatch with `epilogue_tma_store`: `epilogue_tma_store` reads O from TMEM using `get_tmem_load_op`'s layout. Hand-built atoms read with a different layout. Round-tripping through hand-built atoms transcodes the data, leaving 3% error.
- The correction-epilog pattern is the fix. TMEM → registers (via paired t2r atom) → modify in registers → SMEM (via paired r2s atom) → GMEM (via TMA). One-way trip, no round-trip, no transcoding. The MoE kernel uses this and gets perfect results.
CuTeDSL & MLIR
- CuTeDSL `if` blocks create separate MLIR regions. Variables defined in `if not use_smem_p:` and read in another `if not use_smem_p:` inside a `for` inside an `if warp_idx < mma_warp_id:` are not visible. Define unconditionally before any branching.
- CuTeDSL compiles both branches of Python `if`. Wrap mode-specific dead code in `const_expr(condition)` to eliminate it. Critical for O rescale (`n_kv_tiles > 1`), LSE compute (`not normalize`), SMEM-P path.
- CuTeDSL MLIR backend cannot handle complex pipeline loops at hd=512. Both unrolled (Python `range`) and runtime (`cutlass.range unroll=1`) loops trigger exponential-or-worse optimizer time. Tracer is fast (~0.8s); MLIR optimizer chews for 3+ hours.
- Don't mix Python loops and pipeline ops. Python `for` unrolls at trace time: N copies of pipeline acquire/release + TMA + GEMM blow up the IR. Prefer `cutlass.range(unroll=1)` for pipeline loops.
Math & merging
- External k_sub merge is mathematically impossible. You cannot merge `softmax(Q_k0 @ K_k0^T) @ V` and `softmax(Q_k1 @ K_k1^T) @ V` into `softmax(Q @ K^T) @ V`. k_sub partitions are additive in logit space (`S = S_0 + S_1`); softmax is nonlinear. The D5 merge formula only works because sparse and SWA attend over different token sets (additive in weight space). In-kernel accumulation before softmax is the only correct approach for k_sub.
- D5 multi-tile KV merge IS valid. Per-segment LSE + the formula `O = Σ exp(lse_i) · O_i / Σ exp(lse_i)` works because each segment is a separate softmax over a separate token range. This is the Python KV merge workaround that ships today; the in-kernel single-launch version requires the correction-epilog fix.
- Sink merge = single softmax over `[S_comp, S_swa + attn_sink]`. The two-branch weighted merge formula in the paper is mathematically equivalent to adding `attn_sink` as a logit bias on the SWA positions and softmaxing once. One pass, one kernel. This obsoleted D5d.
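Both merge claims above can be checked numerically with a few lines of plain Python. This is a scalar-valued sketch with made-up logits; `attend` and `merge_segments` are hypothetical names, not functions from the package.

```python
import math

def attend(logits, values):
    """Softmax attention for one query; returns (output, log-sum-exp)."""
    m = max(logits)
    w = [math.exp(s - m) for s in logits]
    z = sum(w)
    return sum(wi * vi for wi, vi in zip(w, values)) / z, m + math.log(z)

def merge_segments(parts):
    """O = sum(exp(lse_i) * O_i) / sum(exp(lse_i)), shifted by max(lse)."""
    m = max(lse for _, lse in parts)
    w = [math.exp(lse - m) for _, lse in parts]
    return sum(wi * out for wi, (out, _) in zip(w, parts)) / sum(w)

# Case 1: KV split along the token axis. The LSE merge is exact.
logits, values = [0.3, -1.2, 2.0, 0.7], [1.0, 2.0, 3.0, 4.0]
full, _ = attend(logits, values)
kv_merged = merge_segments([attend(logits[:2], values[:2]),
                            attend(logits[2:], values[2:])])

# Case 2: k_sub split. Logits add (S = S_0 + S_1) over the SAME tokens,
# so the same formula does not recover the full softmax.
s0, s1, v = [1.0, 0.0], [0.0, 2.0], [0.0, 1.0]
ksub_full, _ = attend([a + b for a, b in zip(s0, s1)], v)
ksub_merged = merge_segments([attend(s0, v), attend(s1, v)])

print(abs(full - kv_merged), abs(ksub_full - ksub_merged))
```

The first difference is at floating-point noise level; the second is a few percent, which matches the claim that k_sub partitions cannot be merged after the softmax.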
Numerics
- Always test at hd=64 first. If the proven TMEM-P path regresses, nothing else matters.
- `St32x32bOp` must be Float32, not BFloat16. BFloat16 throws illegal memory access. (Yes, this is a CuTeDSL constraint; listing here because it's been forgotten more than once.)
- First PV `ACCUMULATE=False`. Otherwise sums uninitialized TMEM into the output and you see ~50% error.
Workflow
- Never edit on the B200. Edit locally, commit, push, pull, test. The B200 has no editor history; one bad save and the file is lost.
- Print shapes inside `@cute.kernel` at trace time. `print(f"tBgK shape: {cute.shape(tBgK)}")` runs at compile time, not runtime, and is your only window into the JIT's view of layouts. This is the single most useful debugging line in CuTeDSL.
SMEM budget
- `pv_n_tile` is the easiest SMEM knob. At hd > 256, reducing `pv_n_tile` from 256 to 128 halves sV and sC. Cost: 4 PV GEMM passes instead of 2 (PV is rarely the bottleneck). Simpler than SMEM overlap or Q tiling.
- `kv_stage` is the second-easiest. Drop to 1 when budget gets tight at hd > 128; lose double-buffering on K/V but free 64+ KB.
- SMEM budget at various hd (with `pv_n_tile=256` for hd≤256, `pv_n_tile=128` for hd>256, `kv_stage=2` for hd≤128 else 1):
| hd | sQ | sK | sV | sP | sC | Total | Limit |
|---|---|---|---|---|---|---|---|
| 64 | 32 KB | 32 KB | 32 KB | — | 32 KB | 128 KB | 232 KB |
| 128 | 32 KB | 32 KB | 32 KB | — | 32 KB | 128 KB | 232 KB |
| 256 | 64 KB | 64 KB | 64 KB | 0* | 32 KB | 224 KB | 232 KB |
| 512 | 64 KB | 64 KB | 32 KB | 0* | 32 KB | 192 KB | 232 KB |
*TMEM-P path: sP allocation skipped via const_expr conditional.
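The table totals can be re-derived with a short script, which is handy when trying a new tile configuration. The per-buffer sizes are copied from the table; `headroom_kb` is a hypothetical helper, and the "—" and "0*" cells for sP are both treated as 0 KB.

```python
SMEM_LIMIT_KB = 232

# Per-buffer sizes in KB, copied from the table above (sP counted as 0).
BUDGET_KB = {
    64:  {"sQ": 32, "sK": 32, "sV": 32, "sP": 0, "sC": 32},
    128: {"sQ": 32, "sK": 32, "sV": 32, "sP": 0, "sC": 32},
    256: {"sQ": 64, "sK": 64, "sV": 64, "sP": 0, "sC": 32},
    512: {"sQ": 64, "sK": 64, "sV": 32, "sP": 0, "sC": 32},
}

def headroom_kb(hd: int) -> int:
    """Hypothetical helper: KB left under the SMEM limit for a head dim."""
    return SMEM_LIMIT_KB - sum(BUDGET_KB[hd].values())

for hd in BUDGET_KB:
    print(hd, sum(BUDGET_KB[hd].values()), headroom_kb(hd))
```

The tightest configuration is hd=256, with 8 KB of headroom.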
Reference
- DeepSeek V4 paper: `DeepSeek_V4.pdf` in the repo root.
- DeepGEMM (V4-aligned reference kernels): https://github.com/deepseek-ai/DeepGEMM
- CUTLASS FMHA reference: `/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py` (B200) or `/home/openclaw/dev/cutlass` (local).
- Reference oracles: `dsv4/reference/` (PyTorch FP32; slow, never imported by production code).