Production-grade Blackwell SM100 inference kernel for **DeepSeek-V4-Pro NVFP4**, written in CuTeDSL with a CUDA fallback path. Target hardware: NVIDIA B200 (180 GiB HBM3e).
This file is the durable reference — architecture, design choices, package layout, workflow, and hard-won lessons. If you're touching the kernel, read the "Lessons learned" section every time.
This cannot be repeated enough. vLLM and some integrations misname DSV4's attention as MLA. It is fundamentally a different architecture. If you reason about this kernel as MLA + extras, you will make wrong decisions.
Frozen at construction time per `LayerSpec` so `torch.compile` constant-folds the dispatch. Validation in `dsv4/model/layer_schedule.py:validate_schedule` fails loudly, because a wrong schedule otherwise produces silent garbage.
- Compresses every `m=4` KV entries into one via a token-level learned softmax with an overlapping window (current m + previous m). See eq. 11–12 of the paper.
- Compressed sequence length is `n/m`.
- **Lightning indexer** scores each query against compressed blocks via weighted ReLU MQA logits (eq. 16). Top-k selector keeps `csa_top_k` blocks (512 Flash / 1024 Pro).
- Core attention is MQA over the selected blocks + a sliding window branch of `n_win=128` raw tokens.
- Partial RoPE on the last 64 dims of Q and the compressed K, with **inverse RoPE on each per-head output** so the per-token contribution carries the correct relative position.
- Per-head attention sink: learnable logit added to the softmax denominator (eq. 27). We merge sparse + SWA via the sink-bias-as-logit trick — see "Sink merge" below.
- DeepSeekMoE: shared expert + N routed experts (Flash 256, Pro 384), 6 activated per token.
- L1 GEMM (gate + up interleaved at granularity 8) → SwiGLU → L2 GEMM (down).
- SwiGLU clamping per paper §4.2.3: gate capped at `swiglu_limit=10`, linear clamped to `[-limit, +limit]`.
- All weights NVFP4, FP8 E4M3 scales, 16-element microblocks.
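The SwiGLU clamping rule above can be sketched in plain Python. This is a minimal reference under stated assumptions, not the kernel's code: the name `clamped_swiglu` is hypothetical, "gate capped" is read as an upper bound only, and the activation is assumed to be `silu(gate) * up`.

```python
import math

SWIGLU_LIMIT = 10.0  # the swiglu_limit value quoted in the text


def silu(x: float) -> float:
    """SiLU (x * sigmoid(x)), written to avoid overflow for large |x|."""
    if x >= 0.0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)
    return x * e / (1.0 + e)


def clamped_swiglu(gate, up, limit=SWIGLU_LIMIT):
    """Element-wise clamped SwiGLU over two equal-length lists of floats."""
    out = []
    for g, u in zip(gate, up):
        g = min(g, limit)               # gate: upper cap only (assumption)
        u = max(-limit, min(u, limit))  # linear branch: symmetric clamp
        out.append(silu(g) * u)
    return out


gate = [-3.0, 0.5, 25.0]
up = [2.0, -40.0, 40.0]
result = clamped_swiglu(gate, up)
print(result)
```

A reference like this is mainly useful as a golden model when checking the fused L1 epilogue: out-of-range inputs (the second and third elements here) should match the clamped values, not the raw ones.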
### Sink merge (D5c — key insight)
The paper writes the sink merge as a weighted combination of two separate softmax outputs. But because the sink is just an additive logit bias on one branch, the whole thing collapses to a **single softmax over `[S_comp, S_swa + attn_sink]`**.
One pass, one kernel. No two-loop epilogue, no LSE arithmetic in the merge. This is why D5d (fused merge epilogue) is not needed.
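The collapse to a single softmax can be checked numerically. The sketch below assumes the two-branch form weights each branch's softmax output by its exponentiated log-sum-exp, with `attn_sink` shifting the SWA branch; the paper's exact expression is not reproduced here. All names and toy numbers are illustrative, and values are scalars for brevity.

```python
import math


def lse(logits):
    """Numerically stable log-sum-exp."""
    m = max(logits)
    return m + math.log(sum(math.exp(x - m) for x in logits))


def attend(logits, values):
    """softmax(logits) @ values, with one scalar value per token."""
    z = lse(logits)
    return sum(math.exp(x - z) * v for x, v in zip(logits, values))


# Toy logits and values for the compressed and sliding-window branches.
s_comp, v_comp = [0.3, -1.2, 2.0], [1.0, 2.0, 3.0]
s_swa, v_swa = [0.7, 0.1], [-1.0, 4.0]
attn_sink = 0.5

# Single pass: one softmax over [S_comp, S_swa + attn_sink].
biased_swa = [s + attn_sink for s in s_swa]
single = attend(s_comp + biased_swa, v_comp + v_swa)

# Two-branch form (assumed): separate softmaxes, merged by exp(lse),
# with the sink added to the SWA branch's log-sum-exp.
o_comp, o_swa = attend(s_comp, v_comp), attend(s_swa, v_swa)
l_comp, l_swa = lse(s_comp), lse(s_swa) + attn_sink
m = max(l_comp, l_swa)
w_comp, w_swa = math.exp(l_comp - m), math.exp(l_swa - m)
two_branch = (w_comp * o_comp + w_swa * o_swa) / (w_comp + w_swa)

print(single, two_branch)
```

The two results agree to floating-point precision because a uniform bias on one branch leaves that branch's own softmax unchanged and only rescales its total weight.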
**Archived (Lineage P):** `dsv4/model/dsv4.py`, `dsv4/cache/*`, `dsv4/layers/{attention,ffn,norm,embedding}` — these were the vLLM/sglang integration surface but have 0 importers. See `_archive/` if needed.
- **NEVER run commands directly over raw SSH.** Always use the test harness scripts. They handle killing hung processes, deleting stale logs, screen sessions that survive SSH drops, timeouts for hung kernels, and GPU cleanup.
3. **`.reduce(cute.ReductionOp.MAX)` reduces the entire C-fragment to a scalar.** The reduction is global, not per-row. Use a plain `range()` loop with `cute.arch.fmax` for per-row max.
5. **Hand-constructed TMEM atoms corrupt data on round-trip.** Independently-built `Ld32x32bOp` + `St32x32bOp` atoms have addressing that doesn't match; even a NO-OP round-trip drops cos to ~0.97. Use paired atoms from `epilogue_tmem_copy_and_partition` / `epilogue_smem_copy_and_partition` for one-way trips.
6. **CuTeDSL `if` blocks are separate MLIR regions.** Variables defined inside one `if` are not visible in another, even when the condition is a compile-time constant. Define all variables unconditionally before any branching.
7. **Guard dead code with `const_expr`.** CuTeDSL compiles both branches of Python `if`. At hd=64, the SMEM-P or O-rescale code generates IR you don't need; without `const_expr`, MLIR chews on it.
8. **`tma_partition` and `flat_divide` may not survive inside `if warp_idx` blocks.** Construct partitioned tensors before warp branching, or in a regular Python helper function. (The MoE kernel calls `tma_partition` inside the epilogue warp's `if`, so this constraint may depend on context; print and verify.)
10. **`composition` vs `logical_divide` produce different layouts** even when re-tiling the same tensor. `correction_rescale` uses `composition`, `correction_epilog` uses `logical_divide`. Copy atoms must match the tensor layout they were created with.
- **TMA partition mode ordering** (the bug that ate a whole day): see CuTeDSL constraint #1 above. The wrong slice produces "reasonable" wrong outputs — cos 0.7–0.9, never NaN — so you can ship it without knowing.
- **Square hides bugs.** (128,128) worked for every wrong approach to PV. Always test non-square shapes early.
- **Print the shapes always.** Reasoning about TMEM layouts or TMA mode counts without running `cute.printf(cute.shape(t))` inside `@cute.kernel` is how every multi-day debug starts. Shapes are ground truth.
- **`qk_mma_tiler` K-dim must equal `head_dim`**, not the MMA instruction's K sub-tile size. Hardcoding `qk_ik * 4 = 64` was the root cause of the hd>64 failure; the QK GEMM only computed half the dot product. Fix was one line; cos went from 0.78 to 0.999997 at hd=128.
- **Never assume TMEM round-trips are safe.** Verify with a NO-OP test (load → store unchanged) before adding any logic. The hand-constructed atoms produce ~3% error even on NO-OP.
- **FMHA P store uses QK C-fragment composition, not PV A-fragment.** Two aliases of the same TMEM region. Mixing them up gives valid-looking garbage.
- **Register bridge for P: FP32 backing (store partition) + BF16 view (QK-load layout).** Do not skip the dual view.
- **TMEM round-trip mismatch with `epilogue_tma_store`**: `epilogue_tma_store` reads O from TMEM using `get_tmem_load_op`'s layout. Hand-built atoms read with a different layout. Round-tripping through hand-built atoms transcodes the data, leaving 3% error.
- **CuTeDSL `if` blocks create separate MLIR regions.** Variables defined in `if not use_smem_p:` and read in another `if not use_smem_p:` inside a `for` inside an `if warp_idx < mma_warp_id:` are not visible. Define unconditionally before any branching.
- **CuTeDSL compiles both branches of Python `if`.** Wrap mode-specific dead code in `const_expr(condition)` to eliminate it. Critical for O rescale (`n_kv_tiles > 1`), LSE compute (`not normalize`), SMEM-P path.
- **Don't mix Python loops and pipeline ops.** Python `for` unrolls at trace time — N copies of pipeline acquire/release + TMA + GEMM blow up the IR. Prefer `cutlass.range(unroll=1)` for pipeline loops.
- **External k_sub merge is mathematically impossible.** You cannot merge `softmax(Q_k0 @ K_k0^T) @ V` and `softmax(Q_k1 @ K_k1^T) @ V` into `softmax(Q @ K^T) @ V`. k_sub partitions are additive in **logit** space (`S = S_0 + S_1`); softmax is nonlinear. The D5 merge formula only works because sparse and SWA attend over **different token sets** (additive in weight space). In-kernel accumulation before softmax is the only correct approach for k_sub.
- **D5 multi-tile KV merge IS valid.** Per-segment LSE + the formula `O = Σ exp(lse_i) · O_i / Σ exp(lse_i)` works because each segment is a separate softmax over a separate token range. This is the Python KV merge workaround that ships today; the in-kernel single-launch version requires the correction-epilog fix.
- **Sink merge = single softmax over `[S_comp, S_swa + attn_sink]`.** The two-branch weighted merge formula in the paper is mathematically equivalent to adding `attn_sink` as a logit bias on the SWA positions and softmaxing once. One pass, one kernel. This obsoleted D5d.
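The merge claims above (per-segment LSE merge is valid, external k_sub merge is not) can be illustrated with a small numerical sketch. Names and numbers are illustrative and values are scalars. The second half only shows the LSE-weighted formula failing for k_sub; it does not by itself prove that no other formula exists.

```python
import math


def lse(logits):
    """Numerically stable log-sum-exp."""
    m = max(logits)
    return m + math.log(sum(math.exp(x - m) for x in logits))


def attend(logits, values):
    """softmax(logits) @ values, with one scalar value per token."""
    z = lse(logits)
    return sum(math.exp(x - z) * v for x, v in zip(logits, values))


def lse_merge(outs, lses):
    """O = sum(exp(lse_i) * O_i) / sum(exp(lse_i)), shifted for stability."""
    m = max(lses)
    w = [math.exp(l - m) for l in lses]
    return sum(wi * oi for wi, oi in zip(w, outs)) / sum(w)


values = [1.0, -2.0, 0.5, 3.0]

# Case 1: KV split into disjoint token ranges. The merge is exact.
logits = [0.2, 1.5, -0.7, 0.9]
full = attend(logits, values)
segs = [(logits[:2], values[:2]), (logits[2:], values[2:])]
merged = lse_merge([attend(s, v) for s, v in segs], [lse(s) for s, _ in segs])

# Case 2: k_sub split. Both halves see the same tokens and S = S_0 + S_1.
s0 = [0.5, -1.0, 2.0, 0.1]
s1 = [-0.3, 2.5, -1.5, 0.8]
exact = attend([a + b for a, b in zip(s0, s1)], values)
wrong = lse_merge([attend(s0, values), attend(s1, values)], [lse(s0), lse(s1)])

print(full, merged)   # agree
print(exact, wrong)   # disagree
```

Case 1 matches to floating-point precision; case 2 does not, which is consistent with accumulating the k_sub partial logits in-kernel before the softmax.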
- **Always test at hd=64 first.** If the proven TMEM-P path regresses, nothing else matters.
- **`St32x32bOp` must be Float32**, not BFloat16. BFloat16 throws illegal memory access. (Yes, this is a CuTeDSL constraint — listing here because it's been forgotten more than once.)
- **First PV `ACCUMULATE=False`.** Otherwise the first PV GEMM sums uninitialized TMEM into the output and you see ~50% error.
- **Never edit on the B200.** Edit locally, commit, push, pull, test. The B200 has no editor history; one bad save and the file is lost.
- **Print shapes inside `@cute.kernel` at trace time.** `print(f"tBgK shape: {cute.shape(tBgK)}")` runs at compile time, not runtime, and is your only window into the JIT's view of layouts. This is the single most useful debugging line in CuTeDSL.
- **`pv_n_tile` is the easiest SMEM knob.** At hd > 256, reducing `pv_n_tile` from 256 to 128 halves sV and sC. Cost: 4 PV GEMM passes instead of 2 (PV is rarely the bottleneck). Simpler than SMEM overlap or Q tiling.
- **`kv_stage` is the second-easiest.** Drop to 1 when budget gets tight at hd > 128; lose double-buffering on K/V but free 64+ KB.
- **SMEM budget at various hd** (with `pv_n_tile=256` for hd≤256, `pv_n_tile=128` for hd>256, `kv_stage=2` for hd≤128 else 1):