biondizzle ffa7842b58 Fix dense router: run GEMM in BF16, convert to FP32 only for activation
hidden_states.float() and gate_bf16.T.float() create new FP32 tensors
during CUDA graph capture, which is not graph-capturable.

Fix: run the linear in BF16 (Blackwell tensor cores handle BF16 natively),
then convert only the output logits to FP32 for numerical stability
in sqrt(softplus). The single logits.float() is graph-capturable
because it's a unary op with a pre-existing output buffer.

DSV4 Inference Kernel

Production-grade Blackwell SM100 inference kernel for DeepSeek-V4-Pro NVFP4, written in CuTeDSL with a CUDA fallback path. Target hardware: NVIDIA B200 (180 GiB HBM3e).

This file is the durable reference — architecture, design choices, package layout, workflow, and hard-won lessons. If you're touching the kernel, read the "Lessons learned" section every time.


DSV4 is not MLA

This cannot be repeated enough. vLLM and some integrations misname DSV4's attention as MLA. It is fundamentally a different architecture. If you reason about this kernel as MLA + extras, you will make wrong decisions.

The differences that matter:

| | MLA (V2/V3) | V4 |
|---|---|---|
| Compression axis | feature/head dim (per-token latent) | sequence dim (multiple tokens collapsed into one entry) |
| Cache entries per token | one latent per token | one compressed entry per m tokens |
| Attention pattern | dense over all cached latents | hybrid: sparse top-k (CSA) + dense over heavily-compressed (HCA) + sliding window (SWA) |
| Compression rate | n/a (1:1) | m=4 for CSA, m'=128 for HCA |
| Selection | none (all tokens attended) | lightning indexer + top-k for CSA |
| Output positional fix | n/a | inverse RoPE on each per-head output |
| Sink merge | n/a | per-head learnable attention sink, merged via a single softmax over [S_comp, S_swa + sink] |

Cache layout reflects this: per-layer state cache for SWA window + uncompressed tail (used for CSA/HCA compression), plus a classical paged cache holding compressed CSA/HCA entries, with block size = lcm(m, m') = 128 original tokens per block.
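To make the block arithmetic concrete, here is a minimal sketch of how an original token position maps to a slot in the compressed paged cache. The helper names (`entries_per_block`, `locate`) are hypothetical, and the sketch ignores the uncompressed tail, which stays in the state cache until a full group of m tokens exists.

```python
from math import lcm

M_CSA = 4    # original tokens per CSA compressed entry (m)
M_HCA = 128  # original tokens per HCA compressed entry (m')
BLOCK_TOKENS = lcm(M_CSA, M_HCA)  # original tokens covered by one paged block


def entries_per_block(m: int) -> int:
    """Compressed entries one paged block holds at compression rate m."""
    return BLOCK_TOKENS // m


def locate(token_pos: int, m: int) -> tuple[int, int]:
    """Return (block index, slot in block) of the compressed entry covering token_pos."""
    entry = token_pos // m  # index into the compressed sequence of length n // m
    per_block = entries_per_block(m)
    return entry // per_block, entry % per_block
```

Because the block size is a multiple of both rates, a CSA block (32 entries) and an HCA block (1 entry) with the same index are derived from the same 128-token span, so both caches can share one block table.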


DSV4 architecture (paper-side reference)

The bits the kernel implements, with the choices we made for inference.

Per-layer attention type schedule

Flash (43 layers): layers 0-1 = SWA, layers 2..42 alternating CSA/HCA (CSA at layer 2)
Pro   (61 layers): layers 0-1 = HCA, layers 2..60 alternating CSA/HCA (CSA at layer 2)

The schedule is frozen per LayerSpec at construction time so torch.compile can constant-fold the dispatch. Validation in dsv4/model/layer_schedule.py:validate_schedule fails loudly on purpose: a wrong schedule otherwise produces silent garbage.
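A minimal sketch of building and checking the schedule, assuming a simple frozen `LayerSpec(index, attn)` record. This is not the code in dsv4/model/layer_schedule.py; the function names only mirror it, and the real LayerSpec almost certainly carries more fields.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LayerSpec:  # hypothetical minimal record
    index: int
    attn: str  # "SWA", "CSA" or "HCA"


def build_schedule(n_layers: int, first_two: str) -> list[LayerSpec]:
    """Layers 0-1 use first_two; layers 2.. alternate CSA/HCA starting with CSA."""
    return [
        LayerSpec(i, first_two if i < 2 else ("CSA" if (i - 2) % 2 == 0 else "HCA"))
        for i in range(n_layers)
    ]


def validate_schedule(specs: list[LayerSpec], n_layers: int, first_two: str) -> None:
    """Raise on any deviation instead of letting a wrong dispatch run silently."""
    if len(specs) != n_layers:
        raise ValueError(f"expected {n_layers} layers, got {len(specs)}")
    for i, spec in enumerate(specs):
        if spec.index != i:
            raise ValueError(f"layer {i}: index field is {spec.index}")
        want = first_two if i < 2 else ("CSA" if i % 2 == 0 else "HCA")
        if spec.attn != want:
            raise ValueError(f"layer {i}: expected {want}, got {spec.attn}")


# Flash: 43 layers, SWA first. Pro: 61 layers, HCA first.
FLASH = build_schedule(43, "SWA")
PRO = build_schedule(61, "HCA")
```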

Compressed Sparse Attention (CSA)

  • Compresses every m=4 KV entries into one via a token-level learned softmax over an overlapping window (the current m tokens plus the previous m). See eqs. 11-12 of the paper.
  • Compressed sequence length is n/m.
  • Lightning indexer scores each query against compressed blocks via weighted ReLU MQA logits (eq. 16). Top-k selector keeps csa_top_k blocks (512 Flash / 1024 Pro).
  • Core attention is MQA over the selected blocks + a sliding window branch of n_win=128 raw tokens.
  • Partial RoPE on the last 64 dims of Q and the compressed K, with inverse RoPE on each per-head output so the per-token contribution carries the correct relative position.
  • Per-head attention sink: learnable logit added to the softmax denominator (eq. 27). We merge sparse + SWA via the sink-bias-as-logit trick — see "Sink merge" below.
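
The CSA data flow can be sketched in NumPy under simplifying assumptions: one learned compression logit per token (the paper's eqs. 11-12 may use richer per-channel weights and positional terms), and an indexer score of the DeepSeek-V3.2 lightning-indexer form I[t, s] = Σ_j w[t, j] · ReLU(q[t, j] · k[s]), which is our reading of "weighted ReLU MQA logits". Causal masking of future blocks and RoPE are omitted; all names are illustrative, not the kernel API.

```python
import numpy as np


def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()


def compress(kv, score, m, overlap):
    """Pool every m raw entries into one compressed entry.

    kv: (n, d) raw KV entries; score: (n,) learned per-token logits (assumed form).
    overlap=True (CSA): entry i pools tokens [(i-1)*m, (i+1)*m), i.e. the current
    m tokens plus the previous m. overlap=False (HCA-style): tokens [i*m, (i+1)*m).
    Returns (n // m, d); a trailing partial group stays uncompressed.
    """
    n, d = kv.shape
    out = np.empty((n // m, d), dtype=kv.dtype)
    for i in range(n // m):
        lo = max(0, (i - 1) * m) if overlap else i * m
        hi = (i + 1) * m
        out[i] = softmax(score[lo:hi]) @ kv[lo:hi]
    return out


def indexer_scores(q_idx, w_idx, k_comp):
    """Weighted-ReLU MQA score, assuming I[t, s] = sum_j w[t, j] * ReLU(q[t, j] . k[s]).

    q_idx: (T, H, d_i) indexer queries, w_idx: (T, H) per-head weights,
    k_comp: (S, d_i) compressed indexer keys shared across heads. Returns (T, S).
    """
    logits = np.einsum("thd,sd->ths", q_idx, k_comp)
    return np.einsum("th,ths->ts", w_idx, np.maximum(logits, 0.0))


def select_top_k(scores, k):
    """Indices of the k highest-scoring compressed blocks per query (unordered)."""
    k = min(k, scores.shape[-1])
    return np.argpartition(-scores, k - 1, axis=-1)[:, :k]
```

With overlap=False and m=128 the same `compress` gives the HCA pooling described next; HCA then attends densely over every compressed entry instead of calling `select_top_k`.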

Heavily Compressed Attention (HCA)

  • Same compressor concept as CSA but m'=128, no overlap, dense attention over the (very short) compressed sequence.
  • No indexer.
  • Same partial RoPE + inverse RoPE + sliding window + sink as CSA.

Sliding Window Attention (SWA)

  • First two layers of Flash. Pure local attention over the SWA window. No compressed branch, no indexer.
  • Cache layout: ring buffer of size n_win per request in the state cache.
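
A ring buffer of this kind reduces to modular indexing on the absolute token position. The sketch below is a plain-Python illustration with a hypothetical `SwaRingBuffer` class; the real state cache holds per-layer KV tensors in GPU memory, not Python objects.

```python
class SwaRingBuffer:
    """Keeps the last n_win entries of one request (hypothetical illustration)."""

    def __init__(self, n_win: int):
        self.n_win = n_win
        self.slots = [None] * n_win
        self.count = 0  # total entries ever written = next absolute position

    def append(self, entry) -> None:
        self.slots[self.count % self.n_win] = entry  # overwrite the oldest slot
        self.count += 1

    def window(self) -> list:
        """Entries in token order, oldest first."""
        if self.count <= self.n_win:
            return self.slots[: self.count]
        start = self.count % self.n_win  # slot holding the oldest surviving entry
        return self.slots[start:] + self.slots[:start]
```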

Manifold-Constrained Hyper-Connections (mHC)

  • Replaces residual connections. Width-expanded residual stream (T, n_hc=4, d).
  • Per-token dynamic A_l, B_l, C_l mixing matrices generated by a fused 24-output prenorm projection (4 + 4² + 4).
  • A_l = σ(.), C_l = 2σ(.), B_l = SinkhornKnopp(exp(.), t_max=20) to project onto the Birkhoff polytope.
  • pre_block: x_in = A_l @ X_l; post_block: X_next = B_l @ X_l + C_l ⊗ F_out.
  • B_l held in FP32 for the bmm precision; A/C cast to BF16.
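
The mixing step can be sketched in NumPy. Assumptions: the 24 projection outputs are ordered [A (4) | B (16) | C (4)], as the 4 + 4² + 4 count suggests; Sinkhorn-Knopp alternates row and column normalisation for t_max iterations; the FP32/BF16 precision split is not modelled. `mhc_mix` and its arguments are hypothetical names.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def sinkhorn_knopp(logits, t_max=20):
    """Alternate row/column normalisation of exp(logits), batched over leading dims.

    Subtracting the per-matrix max only avoids overflow: rescaling the whole
    matrix by a constant does not change the normalised result.
    """
    b = np.exp(logits - logits.max(axis=(-2, -1), keepdims=True))
    for _ in range(t_max):
        b = b / b.sum(axis=-1, keepdims=True)  # rows sum to 1
        b = b / b.sum(axis=-2, keepdims=True)  # columns sum to 1
    return b


def mhc_mix(proj, x_stream, block_fn, n_hc=4):
    """One mHC step. proj: (T, n_hc + n_hc**2 + n_hc), x_stream: (T, n_hc, d).

    Assumed proj layout: [A logits | B logits (row-major n_hc x n_hc) | C logits].
    """
    n_b = n_hc * n_hc
    a = sigmoid(proj[:, :n_hc])                                            # A_l, (T, n_hc)
    b = sinkhorn_knopp(proj[:, n_hc:n_hc + n_b].reshape(-1, n_hc, n_hc))  # B_l, (T, n_hc, n_hc)
    c = 2.0 * sigmoid(proj[:, n_hc + n_b:])                                # C_l, (T, n_hc)
    x_in = np.einsum("th,thd->td", a, x_stream)                            # pre_block: A_l @ X_l
    f_out = block_fn(x_in)                                                 # F_out, (T, d)
    # post_block: B_l @ X_l + C_l (outer product) F_out
    return np.einsum("tij,tjd->tid", b, x_stream) + c[:, :, None] * f_out[:, None, :]
```

Because B_l is (approximately) doubly stochastic, the residual mixing preserves the sum over the n_hc streams; only the C_l ⊗ F_out term injects new signal.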

Router

  • Two modes, frozen at construction by layer index:
    • Hash routing (layers 0-2): deterministic per-token-ID LUT lookup, uniform weights 1/k.
    • Dense routing (layers 3+): sqrt(softplus(X @ W_gate)) activation, plus learned e_bias for selection only. Top-k (k=6), renormalize on unbiased activations, multiply by routed_scaling_factor.
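
A reference-style NumPy sketch of both modes, following the bullets above: e_bias steers selection only, while the weights come from the unbiased sqrt(softplus) activations. `routed_scaling_factor=1.0` is a placeholder rather than the model's configured value, and the (vocab, k) LUT shape is an assumption. The production path runs the gate GEMM in BF16 and upcasts only the logits to FP32; NumPy has no BF16, so that step is reduced to a cast here.

```python
import numpy as np


def softplus(x):
    return np.logaddexp(0.0, x)  # log(1 + exp(x)) without overflow


def dense_route(x, w_gate, e_bias, k=6, routed_scaling_factor=1.0):
    """x: (T, D), w_gate: (D, E), e_bias: (E,) -> expert ids (T, k), weights (T, k).

    routed_scaling_factor=1.0 is a placeholder, not the model's value.
    """
    logits = (x @ w_gate).astype(np.float32)  # production: BF16 GEMM, FP32 logits
    act = np.sqrt(softplus(logits))           # unbiased activations
    # e_bias shifts the ranking only; it never enters the weights.
    idx = np.argsort(-(act + e_bias), axis=-1, kind="stable")[:, :k]
    w = np.take_along_axis(act, idx, axis=-1)
    w = w / w.sum(axis=-1, keepdims=True)     # renormalise over the selected experts
    return idx, w * routed_scaling_factor


def hash_route(token_ids, lut, k=6):
    """Hash layers: look up k expert ids per token id in a (vocab, k) LUT; weights 1/k."""
    idx = lut[token_ids]
    return idx, np.full(idx.shape, 1.0 / k, dtype=np.float32)
```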

MoE

  • DeepSeekMoE: shared expert + N routed experts (Flash 256, Pro 384), 6 activated per token.
  • L1 GEMM (gate + up interleaved at granularity 8) → SwiGLU → L2 GEMM (down).
  • SwiGLU clamping per paper §4.2.3: gate capped at swiglu_limit=10, linear clamped to [-limit, +limit].
  • All weights NVFP4, FP8 E4M3 scales, 16-element microblocks.
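
A sketch of the interleave and the clamped SwiGLU, assuming "interleaved at granularity 8" means the fused L1 weight stores rows in alternating blocks of 8 gate rows and 8 up rows. Check the actual order in dsv4/ops/layouts.py before relying on it. Function names are hypothetical, and NVFP4 quantization is not modelled.

```python
import numpy as np


def interleave_gate_up(w_gate, w_up, g=8):
    """(I, D) gate and up weights -> (2I, D) rows: g gate, g up, g gate, g up, ..."""
    i, d = w_gate.shape
    assert i % g == 0, "intermediate size must be a multiple of the granularity"
    blocks = np.stack([w_gate.reshape(i // g, g, d), w_up.reshape(i // g, g, d)], axis=1)
    return blocks.reshape(2 * i, d)


def split_gate_up(h, g=8):
    """Undo the interleave on L1 outputs: h (T, 2I) -> gate (T, I), up (T, I)."""
    t, two_i = h.shape
    blocks = h.reshape(t, two_i // (2 * g), 2, g)
    return blocks[:, :, 0, :].reshape(t, -1), blocks[:, :, 1, :].reshape(t, -1)


def clamped_swiglu(gate, up, limit=10.0):
    gate = np.minimum(gate, limit)            # gate: capped from above only
    up = np.clip(up, -limit, limit)           # linear: clamped to [-limit, +limit]
    return gate / (1.0 + np.exp(-gate)) * up  # SiLU(gate) * up
```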

Sink merge (D5c — key insight)

The paper writes the sink merge as a weighted combination of two separate softmax outputs. But because the sink is just an additive logit bias on one branch, the whole thing collapses to a single softmax over [S_comp, S_swa + attn_sink].

One pass, one kernel. No two-loop epilogue, no LSE arithmetic in the merge. This is why D5d (fused merge epilogue) is not needed.
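
The equivalence is easy to check numerically for one query and one head. The sketch assumes the two-branch merge weights each branch by its softmax mass, exp(lse_comp) and exp(lse_swa + attn_sink); that is our reading of the merge, not a transcription of the paper's equation.

```python
import numpy as np


def lse(s):
    m = s.max()
    return m + np.log(np.exp(s - m).sum())


def attend(s, v):
    """Softmax-weighted sum of v rows for one query's score vector s; also returns LSE."""
    l = lse(s)
    return np.exp(s - l) @ v, l


def two_branch_merge(s_comp, v_comp, s_swa, v_swa, sink):
    """Two softmaxes, then weight each branch by its mass (sink added to SWA's LSE)."""
    o_c, l_c = attend(s_comp, v_comp)
    o_s, l_s = attend(s_swa, v_swa)
    w_c = 1.0 / (1.0 + np.exp(l_s + sink - l_c))  # exp(l_c) / (exp(l_c) + exp(l_s + sink))
    return w_c * o_c + (1.0 - w_c) * o_s


def single_softmax_merge(s_comp, v_comp, s_swa, v_swa, sink):
    """One softmax over [S_comp, S_swa + sink]: the single-pass form."""
    o, _ = attend(np.concatenate([s_comp, s_swa + sink]), np.concatenate([v_comp, v_swa]))
    return o
```

The single-pass form needs no per-branch LSE bookkeeping, which is what lets one kernel produce the final output directly.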


Package structure

dsv4/
├── kernels/          Pure GPU code
│   ├── attention/      Production FMHA — 6-warp TMA multi-tile (.cuh + C-API .cu + op.py + production.py)
│   │                     production.py is the entry point used by single_shot_inference.py
│   ├── gemm/           NVFP4 MoE GEMM (grouped, fused_swiglu, dense, scheduler)
│   ├── compressor/     CSA/HCA production compressor (production_compress.py → compressor_reduce.cu)
│   ├── indexer/        CSA indexer (stub; live path is inline in single_shot_inference.py)
│   ├── router/         Dense router decode + activation_topk
│   ├── cuda/           Raw .cu kernels (loader.py compiles on demand)
│   └── cache/          (stub; SWA/flush kernels are in cuda/)
├── ops/              PyTorch ↔ kernel bridges
│   ├── quantize.py      BF16 ↔ NVFP4, scale factor handling, QuantizedActivation
│   ├── layouts.py       Scale swizzle, gate/up interleave, K-major, offsets
│   ├── gemm_runner.py   Warmup, compile, run grouped/fused GEMMs
│   ├── custom_ops.py    torch.library.custom_op registrations
│   ├── rope_cuda.py     Forward + inverse RoPE (partial, last 64 dims)
│   └── router.py        Router op bridge (dense + hash dispatch)
├── layers/           nn.Module-style components (used by single_shot_inference.py)
│   ├── linear.py        Nvfp4Linear
│   ├── grouped_linear.py Nvfp4GroupedLinear (output projection)
│   ├── moe.py           Nvfp4MoE (routed experts)
│   ├── shared_expert.py Nvfp4SharedExpert
│   ├── mhc.py           mHCLayer (Sinkhorn-Knopp, residual mixing)
│   └── router.py        Router (dense + hash modes)
├── model/
│   ├── config.py        DSV4Config
│   └── sampler.py       CUDASampler
├── reference/
│   └── single_shot_PYTORCH_REFERENCE.py  PyTorch oracle for layer comparison tests
└── _archive/         Dead Lineage P code (model/dsv4.py, cache/*, layers/{attention,ffn,norm,embedding}, etc.)
                      Kept for reference; never imported by live code

Live path: single_shot_inference.py → dsv4/layers/* → dsv4/ops/* → dsv4/kernels/**

Attention path: production.py → fmha_multitile_op.py → fmha_multitile_capi.cu → fmha_6warp_tma_multirow_multitile.cuh

Archived (Lineage P): dsv4/model/dsv4.py, dsv4/cache/*, dsv4/layers/{attention,ffn,norm,embedding} — these were the vLLM/sglang integration surface but have 0 importers. See _archive/ if needed.


Workflow & test harness

The non-negotiables

  • NEVER edit on the B200. Always: edit locally → commit → push → pull on B200 → test.
  • NEVER raw SSH + direct command. Always use the test harness scripts. They handle: killing hung processes, deleting stale logs, screen sessions that survive SSH drops, timeouts for hung kernels, and GPU cleanup.
  • ALWAYS verify hd=64 regression (cos ~0.999998) after every FMHA change. If it regresses, the change is wrong. Revert.
  • NEVER touch drivers, kernels, firmware, or system packages on the B200.
  • NEVER delete test files in tests/unit/ without explicit approval.
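
The cos figures in this file are cosine similarities between a kernel output and a reference output. A minimal checker, assuming the comparison flattens both tensors and computes in FP64; the harness's actual metric and tolerance may differ, and `tol` here is a hypothetical value.

```python
import numpy as np

HD64_COS_BASELINE = 0.999998  # hd=64 regression target quoted in this README


def cosine(out, ref) -> float:
    """Cosine similarity of two tensors, flattened and compared in FP64."""
    a = np.asarray(out, dtype=np.float64).ravel()
    b = np.asarray(ref, dtype=np.float64).ravel()
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


def check_hd64_regression(out, ref, tol: float = 1e-5) -> float:
    """Fail if cos drops below the baseline by more than tol (tol is a hypothetical choice)."""
    c = cosine(out, ref)
    if c < HD64_COS_BASELINE - tol:
        raise AssertionError(f"hd=64 regression: cos={c:.7f}, baseline {HD64_COS_BASELINE}")
    return c
```

Cosine similarity is scale-invariant, so an output that is off by a constant factor still scores 1.0; pair it with an absolute-error check when scale matters.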

Two harnesses: Python and CUDA

| Harness | For | Script | Screen name | Log file |
|---|---|---|---|---|
| Python | test_*.py files | fire_b200_test | kernel-test | /tmp/kernel-test.log |
| CUDA | test_*.cu files | fire_b200_cuda_test | cuda-test | /tmp/cuda-test.log |

Both harnesses follow the same discipline:

  1. Kill everything first — old screen sessions, hanging GPU processes, stale binaries
  2. Delete all logs — never debug from a previous run's log
  3. Clean git + pull — no uncommitted B200 state
  4. Run in screen — survives SSH drops, has a timeout
  5. One test at a time — no parallel launches, ever

Python test

# From local machine — auto-pushes, runs, polls, dumps log
# DEFAULT timeout: 600s (10 min). Override with all 4 args:
~/.openclaw/workspace/fire_b200_test <test_file> [screen_name] [log_file] [timeout_sec]

# Examples:
~/.openclaw/workspace/fire_b200_test tests/unit/test_fmha_v3_stage_c.py
~/.openclaw/workspace/fire_b200_test tests/unit/test_degeneration_2_mhc_falsify.py kernel-test /tmp/kernel-test.log 1800

CUDA test

# From local machine — compiles with nvcc, runs, polls, dumps log
~/.openclaw/workspace/fire_b200_cuda_test tests/unit/test_fmha_sm100_standalone.cu
~/.openclaw/workspace/fire_b200_cuda_test tests/unit/test_tmem_minimal.cu 30   # custom timeout

Check on a running test

# Check CUDA test log + screen status
~/.openclaw/workspace/check_b200_cuda
~/.openclaw/workspace/check_b200_cuda kill   # kill a hung test

# Check Python test — SSH to B200 and tail the log:
ssh root@<B200> tail -f /tmp/kernel-test.log

Manual B200 cycle (emergency only)

ssh root@<B200>
cd /root/dsv4-nvfp4-workspace/kernel && git pull
bash tests/run_test.sh tests/unit/test_<...>.py
bash tests/check_log.sh

⚠️ Test harness gotchas (READ THIS — cost real time)

  1. The timeout is the 4th argument, not the 2nd.

    • WRONG: fire_b200_test test.py 1800 ← this makes 1800 the SCREEN NAME
    • RIGHT: fire_b200_test test.py kernel-test /tmp/kernel-test.log 1800
    • When you pass just a number as the 2nd arg, the screen gets a numeric name and the harness can't kill the old kernel-test screen on the next run.
    • Always pass all 4 args when you need a custom timeout.
  2. After a timeout, the harness kills the screen but NOT the GPU process.

    • The timeout command inside screen kills the shell, but CUDA processes survive.
    • Before re-running, check: ssh root@<B200> nvidia-smi --query-compute-apps=pid --format=csv,noheader
    • Kill stale processes: kill -9 <pid> for each GPU process listed
    • Or: for pid in $(nvidia-smi --query-compute-apps=pid --format=csv,noheader); do kill -9 $pid; done
  3. After an OOM or crash, stale GPU processes WILL be left behind.

    • Always check nvidia-smi before running a new test after a failure.
    • The harness kills python.*test_ and python.*inference procs, but if the process name doesn't match the pattern, it survives.
  4. Single-shot tests MUST use the harness too.

    • single_shot_inference.py is NOT a unit test, but it MUST be run via the harness.
    • WRONG: ssh to B200 and run python single_shot_inference.py directly
    • RIGHT: fire_b200_test single_shot_inference.py kernel-test /tmp/kernel-test.log 1800 -- --max-tokens 512
    • Extra args after -- are passed to the Python script.
    • If the harness can't handle your use case, FIX THE HARNESS, don't bypass it.
  5. Weight loading + CuTeDSL compilation takes 5-10 minutes.

    • First FMHA call triggers JIT compile of CuTeDSL kernels.
    • This is EXPECTED. Do NOT kill the process because it "seems stuck".
    • Use 1800s (30 min) timeout for full-model tests.
  6. The screen name must match between runs.

    • The harness kills the old screen by name. If you used a different name last time, the old screen survives and holds GPU memory.
    • Always use kernel-test for Python tests and cuda-test for CUDA tests.
    • If you accidentally used a numeric screen name, clean up manually: ssh root@<B200> screen -S <wrong_name> -X quit

Environment

  • B200 access: see MEMORY.md (not committed).
  • venv: source /root/dsv4-nvfp4-workspace/venv/bin/activate
  • PYTHONPATH: /root/dsv4-nvfp4-workspace/kernel
  • Model: /root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4
  • vLLM (modified for Blackwell): /root/dsv4-nvfp4-workspace/vllm
  • CUTLASS FMHA reference: /root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py
  • Local CUTLASS clone: /home/openclaw/dev/cutlass

CuTeDSL constraints (read every session)

These are surface-level traps. Get them wrong and the kernel silently produces garbage or NaN, or fails JIT compilation with a "weakly congruent" error.

  1. TMA partition tensors have 4 modes: (((64,128),1), ?, KV_tiles, ?). (None, 0, None, 0) keeps mode 2 (KV tiles) free; [None, kt] indexes it. (None, None, 0, 0) silently pins mode 2 to 0 — multi-tile loads break invisibly.

  2. vectorize=True loops accept only load/store/print. No fmax, no cmpf, no inner loops, no carry across iterations.

  3. .reduce(cute.ReductionOp.MAX) reduces the entire C-fragment to a scalar — global, not per-row. Use a plain range() loop with cute.arch.fmax for per-row max.

  4. cute.arch.fmax is impure for the vectorizer. Use it inside plain range(), never inside vectorize=True.

  5. Hand-constructed TMEM atoms corrupt data on round-trip. Independently-built Ld32x32bOp + St32x32bOp atoms have addressing that doesn't match — even a NO-OP round-trip drops cos to ~0.97. Use paired atoms from epilogue_tmem_copy_and_partition / epilogue_smem_copy_and_partition for one-way trips.

  6. CuTeDSL if blocks are separate MLIR regions. Variables defined inside one if are not visible in another, even when the condition is a compile-time constant. Define all variables unconditionally before any branching.

  7. Guard dead code with const_expr. CuTeDSL compiles both branches of Python if. At hd=64, the SMEM-P or O-rescale code generates IR you don't need; without const_expr, MLIR chews on it.

  8. tma_partition and flat_divide may not survive inside if warp_idx blocks. Construct partitioned tensors before warp branching, or in a regular Python helper function. (The MoE kernel calls tma_partition inside the epilogue warp's if, so this constraint may depend on context — print and verify.)

  9. TMEM allocation must be a power of 2. Round up after summing column requirements.

  10. composition vs logical_divide produce different layouts even when re-tiling the same tensor. correction_rescale uses composition, correction_epilog uses logical_divide. Copy atoms must match the tensor layout they were created with.

  11. After every P store to TMEM, call cute.arch.fence_view_async_tmem_store(). Missing this produces NaN.

  12. St32x32bOp must use Float32, not BFloat16. BFloat16 causes illegal memory access.

  13. First PV must have ACCUMULATE=False. Otherwise adds uninitialized TMEM contents to the output.

  14. find_tmem_tensor_col_offset() returns footprint size, not a safe offset. Never use it as a TMEM placement.

  15. FMHA never trusts DLPack tensor layouts. Reconstruct V as (hd, s_k) MN-major inside CuTe via explicit make_tensor + make_layout.
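
Two of these constraints reduce to small host-side checks. The helper below rounds the summed TMEM column requirement (#9) up to a power of two, assuming tcgen05.alloc accepts power-of-two column counts from 32 to 512 as we read the PTX ISA; confirm against your toolkit. The second part is a NumPy analogy for #1 only (NumPy `:` plays the role of CuTe `None`); it is not CuTe code.

```python
import numpy as np


def tmem_alloc_cols(*col_requirements: int, min_cols: int = 32, max_cols: int = 512) -> int:
    """Sum per-buffer TMEM column needs and round up to a legal allocation size.

    Assumes tcgen05.alloc takes a power-of-two column count in [32, 512].
    """
    total = max(1, sum(col_requirements))
    cols = max(min_cols, 1 << (total - 1).bit_length())  # next power of two >= total
    if cols > max_cols:
        raise ValueError(f"need {total} TMEM columns, rounded to {cols} > {max_cols}")
    return cols


# NumPy analogy for constraint #1 (not CuTe code): ':' keeps a mode, an int pins it.
# Modes: (tile elements, ?, KV_tiles, ?), here with 3 KV tiles.
tma_view = np.arange(12).reshape(4, 1, 3, 1)
kv_free = tma_view[:, 0, :, 0]    # like (None, 0, None, 0): KV-tile mode stays free
kv_pinned = tma_view[:, :, 0, 0]  # like (None, None, 0, 0): KV-tile mode pinned to 0
tile_2 = kv_free[:, 2]            # like [None, kt] with kt = 2
```

NumPy at least raises IndexError if you index `kv_pinned[:, 2]`, because the remaining mode has size 1; per constraint #1, the CuTe equivalent gives no such warning.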


Lessons learned (the gold — read every session)

These cost real days to learn. They are ordered by how easy each mistake is to repeat.

Layout & TMA

  • TMA partition mode ordering (the bug that ate a whole day): see CuTeDSL constraint #1 above. The wrong slice produces "reasonable" wrong outputs (cos 0.7-0.9, never NaN), so you can ship it without knowing.
  • Square hides bugs. (128,128) worked for every wrong approach to PV. Always test non-square shapes early.
  • Print the shapes always. Reasoning about TMEM layouts or TMA mode counts without running cute.printf(cute.shape(t)) inside @cute.kernel is how every multi-day debug starts. Shapes are ground truth.
  • qk_mma_tiler K-dim must equal head_dim, not the MMA instruction's K sub-tile size. Hardcoding qk_ik * 4 = 64 was the root cause of the hd>64 failure; the QK GEMM only computed half the dot product. Fix was one line; cos went from 0.78 to 0.999997 at hd=128.

TMEM

  • Never assume TMEM round-trips are safe. Verify with a NO-OP test (load → store unchanged) before adding any logic. The hand-constructed atoms produce ~3% error even on NO-OP.
  • FMHA P store uses QK C-fragment composition, not PV A-fragment. Two aliases of the same TMEM region. Mixing them up gives valid-looking garbage.
  • Register bridge for P: FP32 backing (store partition) + BF16 view (QK-load layout). Do not skip the dual view.
  • TMEM round-trip mismatch with epilogue_tma_store: epilogue_tma_store reads O from TMEM using get_tmem_load_op's layout. Hand-built atoms read with a different layout. Round-tripping through hand-built atoms transcodes the data, leaving 3% error.
  • The correction-epilog pattern is the fix. TMEM → registers (via paired t2r atom) → modify in registers → SMEM (via paired r2s atom) → GMEM (via TMA). One-way trip, no round-trip, no transcoding. The MoE kernel uses this and gets perfect results.

CuTeDSL & MLIR

  • CuTeDSL if blocks create separate MLIR regions. Variables defined in if not use_smem_p: and read in another if not use_smem_p: inside a for inside an if warp_idx < mma_warp_id: are not visible. Define unconditionally before any branching.
  • CuTeDSL compiles both branches of Python if. Wrap mode-specific dead code in const_expr(condition) to eliminate it. Critical for O rescale (n_kv_tiles > 1), LSE compute (not normalize), SMEM-P path.
  • CuTeDSL MLIR backend cannot handle complex pipeline loops at hd=512. Both unrolled (Python range) and runtime (cutlass.range unroll=1) loops trigger exponential-or-worse optimizer time. Tracer is fast (~0.8s); MLIR optimizer chews for 3+ hours.
  • Don't mix Python loops and pipeline ops. Python for unrolls at trace time — N copies of pipeline acquire/release + TMA + GEMM blow up the IR. Prefer cutlass.range(unroll=1) for pipeline loops.

Math & merging

  • External k_sub merge is mathematically impossible. You cannot merge softmax(Q_k0 @ K_k0^T) @ V and softmax(Q_k1 @ K_k1^T) @ V into softmax(Q @ K^T) @ V. k_sub partitions are additive in logit space (S = S_0 + S_1); softmax is nonlinear. The D5 merge formula only works because sparse and SWA attend over different token sets (additive in weight space). In-kernel accumulation before softmax is the only correct approach for k_sub.
  • D5 multi-tile KV merge IS valid. Per-segment LSE + the formula O = Σ exp(lse_i) · O_i / Σ exp(lse_i) works because each segment is a separate softmax over a separate token range. This is the Python KV merge workaround that ships today; the in-kernel single-launch version requires the correction-epilog fix.
  • Sink merge = single softmax over [S_comp, S_swa + attn_sink]. The two-branch weighted merge formula in the paper is mathematically equivalent to adding attn_sink as a logit bias on the SWA positions and softmaxing once. One pass, one kernel. This obsoleted D5d.
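
Both merge claims can be checked numerically. In the NumPy sketch below, `kv_segment_merge` splits the tokens (rows of K/V) and combines segments with the LSE formula; `k_sub_lse_merge` applies the same formula to a split of the head dimension and does not recover the reference; `k_sub_accumulate` sums the partial logits before a single softmax, which is the in-kernel approach. All names are illustrative.

```python
import numpy as np


def softmax(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def lse(s):
    m = s.max(axis=-1, keepdims=True)
    return m + np.log(np.exp(s - m).sum(axis=-1, keepdims=True))


def attn(q, k, v):
    """Reference: softmax(Q @ K^T) @ V for q (Tq, hd), k (Tk, hd), v (Tk, dv)."""
    return softmax(q @ k.T) @ v


def lse_combine(outs, lses):
    """O = sum_i exp(lse_i) * O_i / sum_i exp(lse_i), computed stably."""
    l = np.concatenate(lses, axis=-1)
    w = np.exp(l - l.max(axis=-1, keepdims=True))
    w = w / w.sum(axis=-1, keepdims=True)
    return sum(w[:, i:i + 1] * o for i, o in enumerate(outs))


def kv_segment_merge(q, k, v, split):
    """Valid: segments are disjoint token ranges, each its own softmax."""
    outs, lses = [], []
    for sl in (slice(0, split), slice(split, None)):
        s = q @ k[sl].T
        outs.append(softmax(s) @ v[sl])
        lses.append(lse(s))
    return lse_combine(outs, lses)


def k_sub_lse_merge(q, k, v, split):
    """Invalid: partitions the head dim, so each part sees only partial logits."""
    outs, lses = [], []
    for sl in (slice(0, split), slice(split, None)):
        s = q[:, sl] @ k[:, sl].T
        outs.append(softmax(s) @ v)
        lses.append(lse(s))
    return lse_combine(outs, lses)


def k_sub_accumulate(q, k, v, split):
    """Valid: accumulate S = S_0 + S_1 before the single softmax."""
    s = q[:, :split] @ k[:, :split].T + q[:, split:] @ k[:, split:].T
    return softmax(s) @ v
```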

Numerics

  • Always test at hd=64 first. If the proven TMEM-P path regresses, nothing else matters.
  • St32x32bOp must be Float32, not BFloat16. BFloat16 throws illegal memory access. (Yes, this is a CuTeDSL constraint — listing here because it's been forgotten more than once.)
  • First PV ACCUMULATE=False. Otherwise sums uninitialized TMEM into the output and you see ~50% error.

Workflow

  • Never edit on the B200. Edit locally, commit, push, pull, test. The B200 has no editor history; one bad save and the file is lost.
  • Print shapes inside @cute.kernel at trace time. print(f"tBgK shape: {cute.shape(tBgK)}") runs at compile time, not runtime, and is your only window into the JIT's view of layouts. This is the single most useful debugging line in CuTeDSL.

SMEM budget

  • pv_n_tile is the easiest SMEM knob. At hd > 256, reducing pv_n_tile from 256 to 128 halves sV and sC. Cost: 4 PV GEMM passes instead of 2 (PV is rarely the bottleneck). Simpler than SMEM overlap or Q tiling.
  • kv_stage is the second-easiest. Drop to 1 when budget gets tight at hd > 128; lose double-buffering on K/V but free 64+ KB.
  • SMEM budget at various hd (with pv_n_tile=256 for hd≤256, pv_n_tile=128 for hd>256, kv_stage=2 for hd≤128 else 1):
hd sQ sK sV sP sC Total Limit
64 32 KB 32 KB 32 KB 32 KB 128 KB 232 KB
128 32 KB 32 KB 32 KB 32 KB 128 KB 232 KB
256 64 KB 64 KB 64 KB 0* 32 KB 224 KB 232 KB
512 64 KB 64 KB 32 KB 0* 32 KB 192 KB 232 KB

*TMEM-P path: sP allocation skipped via const_expr conditional.
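
The budget arithmetic for the hd=256 and hd=512 rows can be checked in a few lines. Buffer sizes are copied from the table; the what-if assumes, per the pv_n_tile bullet, that going back to pv_n_tile=256 at hd=512 doubles sV and sC. Names are hypothetical.

```python
SMEM_LIMIT_KB = 232

# Per-buffer SMEM in KB, copied from the hd=256 and hd=512 rows of the table.
BUDGET_KB = {
    256: {"sQ": 64, "sK": 64, "sV": 64, "sP": 0, "sC": 32},
    512: {"sQ": 64, "sK": 64, "sV": 32, "sP": 0, "sC": 32},
}

# What-if: hd=512 with pv_n_tile=256 doubles sV and sC.
HD512_PV256 = dict(BUDGET_KB[512], sV=64, sC=64)


def pv_passes(hd: int, pv_n_tile: int) -> int:
    """Number of PV GEMM passes needed to cover the head dim."""
    return -(-hd // pv_n_tile)  # ceiling division


def fits(buffers_kb: dict, limit_kb: int = SMEM_LIMIT_KB) -> tuple[int, bool]:
    total = sum(buffers_kb.values())
    return total, total <= limit_kb
```

Halving pv_n_tile at hd=512 trades 2 extra PV passes for 64 KB of SMEM, which is what brings the total from over the limit back under it.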


Reference

  • DeepSeek V4 paper: DeepSeek_V4.pdf in the repo root.
  • DeepGEMM (V4-aligned reference kernels): https://github.com/deepseek-ai/DeepGEMM
  • CUTLASS FMHA reference: /root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py (B200) or /home/openclaw/dev/cutlass (local).
  • Reference oracles: dsv4/reference/ (PyTorch FP32 — slow, never imported by production code).