# DSV4 Inference Kernel

Production-grade Blackwell SM100 inference kernel for **DeepSeek-V4-Pro NVFP4**, written in CuTeDSL with a CUDA fallback path. Target hardware: NVIDIA B200 (180 GiB HBM3e).

For what's done, what's blocked, and what's next, see **ROADMAP.md**. This file is the durable reference: architecture, design choices, package layout, workflow, and hard-won lessons. If you're touching the kernel, read the "Lessons learned" section every time.

---

## DSV4 is not MLA

This cannot be repeated enough. vLLM and some integrations misname DSV4's attention as MLA. It is a fundamentally different architecture. If you reason about this kernel as MLA + extras, you will make wrong decisions.

The differences that matter:

| Aspect | MLA (V2/V3) | V4 |
|---|---|---|
| Compression axis | feature/head dim (per-token latent) | **sequence dim** (multiple tokens collapsed into one entry) |
| Cache entries per token | one latent per token | one compressed entry per `m` tokens |
| Attention pattern | dense over all cached latents | hybrid: sparse top-k (CSA) + dense over heavily-compressed entries (HCA) + sliding window (SWA) |
| Compression rate | n/a (1:1) | m=4 for CSA, m'=128 for HCA |
| Selection | none (all tokens attended) | lightning indexer + top-k for CSA |
| Output positional fix | n/a | inverse RoPE on each per-head output |
| Sink merge | n/a | per-head learnable attention sink merged via a single softmax over `[S_comp, S_swa + sink]` |

The cache layout reflects this: a per-layer **state cache** for the SWA window + uncompressed tail (used for CSA/HCA compression), plus a **classical paged cache** holding compressed CSA/HCA entries, with block size = `lcm(m, m') = 128` original tokens per block.

---

## DSV4 architecture (paper-side reference)

The bits the kernel implements, with the choices we made for inference.
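As a quick sanity check on the compression arithmetic in the comparison table above, here is a minimal pure-Python sketch that counts cache entries for one layer of each type over an `n`-token request. It is not repo code: `cache_entry_counts` is a hypothetical helper, and it assumes a compressed entry exists only once its whole group of `m` (or `m'`) tokens has arrived, with the remainder waiting in the state-cache tail.

```python
import math

M_CSA = 4    # CSA compression rate m
M_HCA = 128  # HCA compression rate m'
N_WIN = 128  # SWA window n_win


def cache_entry_counts(n_tokens: int) -> dict:
    """Hypothetical helper: per-request cache footprint for one layer of each type.

    Assumption: a compressed entry is emitted only once its full group of tokens
    is present; leftover tokens sit in the uncompressed tail of the state cache.
    """
    block_tokens = math.lcm(M_CSA, M_HCA)  # 128 original tokens per paged block
    return {
        "mla_latents": n_tokens,                # MLA: one latent per token
        "csa_entries": n_tokens // M_CSA,       # V4 CSA: one entry per m tokens
        "hca_entries": n_tokens // M_HCA,       # V4 HCA: one entry per m' tokens
        "csa_tail": n_tokens % M_CSA,           # raw tokens awaiting CSA compression
        "hca_tail": n_tokens % M_HCA,           # raw tokens awaiting HCA compression
        "swa_window": min(n_tokens, N_WIN),     # raw tokens held in the SWA ring buffer
        # Paged blocks spanned by n tokens (the last one may be partially filled).
        "paged_blocks": (n_tokens + block_tokens - 1) // block_tokens,
        "csa_per_block": block_tokens // M_CSA,  # k1 = 32
        "hca_per_block": block_tokens // M_HCA,  # k2 = 1
    }


if __name__ == "__main__":
    print(cache_entry_counts(1_000_000))
```

At a 1M-token context, MLA keeps a million latents per layer, while a CSA layer keeps 250,000 entries and an HCA layer fewer than 8,000; that ratio is why the compressed pool, not the SWA window, dominates cache sizing.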
### Per-layer attention type schedule

```
Flash (43 layers): layers 0-1 = SWA, layers 2..42 alternating CSA/HCA (CSA at layer 2)
Pro   (61 layers): layers 0-1 = HCA, layers 2..60 alternating CSA/HCA (CSA at layer 2)
```

The schedule is frozen at construction time per `LayerSpec` so torch.compile constant-folds the dispatch. Validation in `dsv4/model/layer_schedule.py:validate_schedule` fails loudly on purpose, because a wrong schedule produces silent garbage.

### Compressed Sparse Attention (CSA)

- Compresses every `m=4` KV entries into one via a token-level learned softmax over an overlapping window (current m + previous m). See eq. 11–12 of the paper.
- Compressed sequence length is `n/m`.
- **Lightning indexer** scores each query against compressed blocks via weighted ReLU MQA logits (eq. 16). The top-k selector keeps `csa_top_k` blocks (512 Flash / 1024 Pro).
- Core attention is MQA over the selected blocks + a sliding-window branch of `n_win=128` raw tokens.
- Partial RoPE on the last 64 dims of Q and the compressed K, with **inverse RoPE on each per-head output** so the per-token contribution carries the correct relative position.
- Per-head attention sink: a learnable logit added to the softmax denominator (eq. 27). We merge sparse + SWA via the sink-bias-as-logit trick; see "Sink merge" below.

### Heavily Compressed Attention (HCA)

- Same compressor concept as CSA but with `m'=128`, no overlap, and dense attention over the (very short) compressed sequence.
- No indexer.
- Same partial RoPE + inverse RoPE + sliding window + sink as CSA.

### Sliding Window Attention (SWA)

- First two layers of Flash. Pure local attention over the SWA window. No compressed branch, no indexer.
- Cache layout: ring buffer of size `n_win` per request in the state cache.

### Manifold-Constrained Hyper-Connections (mHC)

- Replaces residual connections. Width-expanded residual stream `(T, n_hc=4, d)`.
- Per-token dynamic `A_l`, `B_l`, `C_l` mixing matrices generated by a fused 24-output prenorm projection (4 + 4² + 4).
- `A_l = σ(.)`, `C_l = 2σ(.)`, `B_l = SinkhornKnopp(exp(.), t_max=20)` to project onto the Birkhoff polytope.
- `pre_block`: `x_in = A_l @ X_l`; `post_block`: `X_next = B_l @ X_l + C_l ⊗ F_out`.
- `B_l` is held in FP32 for the bmm precision; A/C are cast to BF16.

### Router

- Two modes, frozen at construction by layer index:
  - **Hash routing** (layers 0–2): deterministic per-token-ID LUT lookup, uniform weights `1/k`.
  - **Dense routing** (layers 3+): `sqrt(softplus(X @ W_gate))` activation, plus a learned `e_bias` used for *selection only*. Top-k (k=6), renormalize on the unbiased activations, multiply by `routed_scaling_factor`.

### MoE

- DeepSeekMoE: shared expert + N routed experts (Flash 256, Pro 384), 6 activated per token.
- L1 GEMM (gate + up interleaved at granularity 8) → SwiGLU → L2 GEMM (down).
- SwiGLU clamping per paper §4.2.3: gate capped at `swiglu_limit=10`, linear clamped to `[-limit, +limit]`.
- All weights NVFP4, FP8 E4M3 scales, 16-element microblocks.

### Sink merge (D5c: key insight)

The paper writes the sink merge as a weighted combination of two separate softmax outputs. But because the sink is just an additive logit bias on one branch, the whole thing collapses to a **single softmax over `[S_comp, S_swa + attn_sink]`**. One pass, one kernel. No two-loop epilogue, no LSE arithmetic in the merge. This is why D5d (fused merge epilogue) is not needed.

---

## Our kernel design choices

### Attention kernel (FmhaKernel)

**6-warp specialization.** Warps 0–3 handle softmax + correction + epilogue. Warp 4 is the MMA warp (QK + PV). Warp 5 is the TMA warp (Q/K/V loads, output store via pipeline).

**P staging: two paths.**

- **TMEM-P** (hd ≤ 64): P stored to TMEM via a register bridge (FP32 backing + BF16 view). PV reads P from TMEM. Used at the small head dims where the QK C-fragment and PV A-fragment TMEM layouts agree.
- **SMEM-P** (hd > 64): P written to SMEM via a coordinate-indexed store, using `tTMEM_LOADcS` to map register indices to `(m, k)` and then into `sP`'s subtile layout. PV reads P from SMEM with `OperandSource.SMEM`. Required because the QK ↔ PV TMEM layout disagreement at hd > 64 corrupts the round-trip.

**Un-normalized O + LSE output.** The kernel emits raw `sum(P · V)` and `lse = ln(row_sum) + row_max · ln(2)`, where `row_max` is the running max in the scaled log2 domain. External code (or the next kernel pass) divides. This composes: the D5 merge, the multi-tile rescale, and the inverse-RoPE → wo_a fuse all rely on it.

**Per-head launch for multi-head.** A Python loop dispatches the single-CTA kernel once per head. A multi-CTA grid using `flat_divide` + `tma_partition` is the next refactor (see ROADMAP); that path is unblocked once the correction-epilog rewrite lands.

**Head-packed M dimension for decode.** Q is reshaped to `(n_h * T, hd, 1)`, so all heads' rows are packed into the 128-row M tile, with a per-row softmax. At Pro decode (T=1, n_h=128), `n_h * T = 128` fills the M tile exactly.

**K-dim sub-tiling at hd > 256.** When `head_dim > 256` (the MMA instruction K-dim limit), Q and K are split into `n_k_sub_tiles = head_dim / 256` chunks along head_dim. QK accumulates in TMEM across sub-tiles (additive in logit space). The PV path uses `pv_n_tile = 128` for hd > 256 to keep sV + sC within the 232 KB SMEM budget.

**Sink bias as logit modification.** D3 (SWA length mask), D4 (causal mask on SWA), and D5c (attention sink) all live in the same post-QK, pre-softmax in-register code. They read `tTMEM_LOADcS` to get `(m, k)` coordinates and modify `tTMEM_LOADrS` before the row-max reduction. The sink bias is added in the raw-logit domain as `attn_sink / scale_softmax`; the existing `* scale_log2` multiply then converts it to log2 space.

### MoE kernel (FusedSwiGLUScaledGroupedGemmKernel)

**7-warp specialization.** Warps 0–3 epilogue (TMEM → registers → SMEM → GMEM with global scale, SwiGLU, clamp). Warp 4 MMA (`tcgen05.mma.block_scale` with SFA/SFB in TMEM).
Warp 5 TMA load (A, B, SFA, SFB). Warp 6 scheduler (`MoEStaticPersistentTileScheduler`).

**One-way TMEM → registers → SMEM → GMEM epilogue.** Uses `epilogue_tmem_copy_and_partition` + `epilogue_smem_copy_and_partition` (CUTLASS helpers, paired atoms). The SwiGLU + clamping math runs in registers between the t2r and r2s copies. No TMEM round-trip. This is the same pattern FMHA needs to adopt to fix the D1.5 blocker; see ROADMAP.

**Subtile-level gate/up pairing.** With granularity-8 interleaved L1 weights and `epi_tile_n=8`, even subtiles are gate and odd subtiles are up. The `silu_gate_buf` register tensor carries the SiLU result across the subtile-pair boundary.

**`use_2cta_instrs` is conditional** on `tokens_sum ≥ 256` and an even `cluster_m`. Decode (small M) stays 1-CTA; prefill/batched runs get 2-CTA UMMA with multicast B (1.7–1.9× throughput).

### Heterogeneous KV cache

- **State cache** per request: a fixed-size block holding `(n_win SWA KV)` and `(uncompressed tail tokens awaiting compression)`. One block per request, lifetime managed by request scheduling.
- **Classical paged cache** per request: variable blocks holding `(k1 CSA compressed entries, k2 HCA compressed entries)` per layer, with `k1 = lcm(m, m') / m = 32` and `k2 = lcm(m, m') / m' = 1`. Each block covers 128 original tokens.
- Different layers can produce different KV cache sizes (CSA vs HCA vs SWA-only). The state cache + classical-pool split keeps PagedAttention-style alignment intact for the compressed pool.

### NVFP4 throughout

- **Weights**: NVFP4 (FP8 E4M3 scales, 16-element microblocks). Verified: `sf_dtype`, TMA element type, and MMA kind (`mxf4nvf4`) are all correct.
- **Activations**: BF16 today, FP4 after NVFP4-1.x epilogue fusion lands (see ROADMAP).
- **KV cache**: BF16 today; the FP8 split (RoPE dims in BF16, NoPE dims in FP8) per paper §2.3.4 is on the roadmap as NVFP4-2.
- **Indexer keys**: stored as FP4 in the cache today, but scored with a scalar CUDA-core kernel. Tensor-core FP4 scoring (paper §5.2.1) is a Stage F priority.
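For intuition about what "NVFP4, FP8 E4M3 scales, 16-element microblocks" means numerically, here is a simplified pure-Python round-trip for one row. It is a sketch, not `dsv4/ops/quantize.py`: it keeps each per-block scale as a plain Python float instead of rounding it to FP8 E4M3, it omits the per-tensor FP32 global scale, and it rounds to the nearest E2M1 magnitude with ties going to the smaller value (a simplification; hardware rounding modes may differ).

```python
# E2M1 (FP4) representable magnitudes; the sign is kept separately.
E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
BLOCK = 16  # NVFP4 microblock: 16 elements share one scale factor


def quantize_block(xs):
    """Return (scale, fp4_values) for one 16-element microblock.

    Simplification: the scale stays a Python float here; real NVFP4 stores it
    as FP8 E4M3 and additionally applies a per-tensor global scale.
    """
    amax = max(abs(x) for x in xs)
    scale = amax / 6.0 if amax > 0 else 1.0  # map the block max onto FP4's max (6.0)
    q = []
    for x in xs:
        mag = abs(x) / scale
        # Nearest grid point; min() returns the first (smaller) candidate on ties.
        nearest = min(E2M1_GRID, key=lambda g: abs(g - mag))
        q.append(nearest if x >= 0 else -nearest)
    return scale, q


def quantize(row):
    assert len(row) % BLOCK == 0, "row length must be a multiple of 16"
    return [quantize_block(row[i:i + BLOCK]) for i in range(0, len(row), BLOCK)]


def dequantize(blocks):
    return [scale * v for scale, vals in blocks for v in vals]
```

Because every microblock carries its own scale, an outlier only costs precision inside its own 16 elements; the worst-case rounding error is half the largest grid gap (1.0) times that block's scale.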
---

## Package structure

```
dsv4/
├── kernels/               Pure GPU code (CuTeDSL @cute.jit, .cu files)
│   ├── attention/         FMHA: FmhaKernel (hd=64/128/256 proven, hd=512 MLIR-blocked)
│   ├── gemm/              NVFP4 MoE GEMM (grouped, fused_swiglu, dense, scheduler)
│   ├── compressor/        CSA/HCA token-level compressor (CuTeDSL)
│   ├── indexer/           CSA indexer score+topk (FP32 scalar today; tensor-core FP4 on roadmap)
│   ├── router/            Dense router decode kernel (warp-specialized persistent GEMM)
│   ├── cache/             append_swa (writes KV to state cache)
│   ├── decode/            Decode-time attention (future)
│   └── cuda/              Raw .cu (deinterleave_quantize, sparse_topk_metadata, etc.)
├── ops/                   PyTorch ↔ kernel bridges
│   ├── quantize.py        BF16 ↔ NVFP4, scale factor handling
│   ├── layouts.py         Scale swizzle, gate/up interleave, K-major, offsets
│   ├── gemm_runner.py     Warmup, compile, run grouped/fused GEMMs
│   ├── custom_ops.py      torch.library.custom_op registrations
│   ├── decode_sparse.py   native_sparse_decode dispatcher
│   ├── rope.py            Forward + inverse RoPE (partial, last 64 dims)
│   ├── topk.py            Sparse top-k metadata wrapper
│   └── router.py          Router op bridge
├── layers/                nn.Module-style components
│   ├── linear.py          Nvfp4Linear
│   ├── grouped_linear.py  Nvfp4GroupedLinear (output projection)
│   ├── moe.py             Nvfp4MoE (routed experts)
│   ├── shared_expert.py   Nvfp4SharedExpert
│   ├── mhc.py             mHCLayer (Sinkhorn-Knopp, residual mixing)
│   ├── attention.py       AttentionSubBlock (CSA/HCA/SWA variants by LayerSpec)
│   ├── norm.py            RMSNorm
│   ├── router.py          Router (dense + hash modes)
│   ├── embedding.py       Token embedding + mHC init
│   └── ffn.py             FFN sub-block
├── model/                 Model assembly
│   ├── config.py          DSV4Config
│   ├── layer.py           TransformerLayer
│   ├── layer_schedule.py  LayerSpec, AttentionType, build_schedule, validate_schedule
│   ├── mtp.py             Multi-token prediction
│   ├── sampler.py         Token sampler
│   └── dsv4.py            Full model
├── cache/                 KV cache infra
│   ├── allocator.py       Memory allocator
│   ├── block_table.py     Paged cache block table
│   ├── manager.py         Cache manager
│   ├── paged_cache.py     Classical paged cache (CSA/HCA)
│   ├── state_cache.py     State cache (SWA + uncompressed tail)
│   └── schema.py, handle.py, flush.py, prepare_forward.py
├── loader/                Checkpoint I/O
│   ├── hf_checkpoint.py
│   └── layout_convert.py
└── reference/             Slow PyTorch oracles (never imported by production code)
    └── attention.py, csa_attention.py, compressor.py, moe_pipeline.py
```

**Dependency arrow:** `kernels/` → `ops/` → `layers/` → `model/`. `reference/` and `loader/` are sidecars.

---

## Workflow & test harness

### The non-negotiables

- **NEVER edit on the B200.** Always: edit locally → commit → push → pull on B200 → test.
- **NEVER use raw SSH + a direct command.** Always use the test harness scripts. They handle killing hung processes, deleting stale logs, screen sessions that survive SSH drops, timeouts for hung kernels, and GPU cleanup.
- **ALWAYS verify the hd=64 regression** (cos ~0.999998) after every FMHA change. If it regresses, the change is wrong. Revert.
- **NEVER touch drivers, kernels, firmware, or system packages** on the B200.
- **NEVER delete test files** in `tests/unit/` without explicit approval.

### Two harnesses: Python and CUDA

| Harness | For | Script | Screen name | Log file |
|---|---|---|---|---|
| Python | `test_*.py` files | `fire_b200_test` | `kernel-test` | `/tmp/kernel-test.log` |
| CUDA | `test_*.cu` files | `fire_b200_cuda_test` | `cuda-test` | `/tmp/cuda-test.log` |

Both harnesses follow the same discipline:

1. **Kill everything first**: old screen sessions, hanging GPU processes, stale binaries.
2. **Delete all logs**: never debug from a previous run's log.
3. **Clean git + pull**: no uncommitted B200 state.
4. **Run in screen**: survives SSH drops, has a timeout.
5. **One test at a time**: no parallel launches, ever.

### Python test (one command)

```bash
# From local machine: auto-pushes, runs, polls, dumps log
~/.openclaw/workspace/fire_b200_test tests/unit/test_fmha_v3_stage_c.py
```

### CUDA test (one command)

```bash
# From local machine: compiles with nvcc, runs, polls, dumps log
# Default timeout: 60s. Pass a second arg for a custom timeout.
~/.openclaw/workspace/fire_b200_cuda_test tests/unit/test_fmha_sm100_standalone.cu
~/.openclaw/workspace/fire_b200_cuda_test tests/unit/test_tmem_minimal.cu 30
```

### Check on a running CUDA test

```bash
# Show current log + screen status
~/.openclaw/workspace/check_b200_cuda

# Kill a hung test + show the log
~/.openclaw/workspace/check_b200_cuda kill
```

### Manual B200 cycle (emergency only)

```bash
ssh root@
cd /root/dsv4-nvfp4-workspace/kernel && git pull
bash tests/run_test.sh tests/unit/test_<...>.py
bash tests/check_log.sh
```

`run_test.sh` kills any prior `kernel-test` screen (with SIGKILL on stuck GPU procs), deletes the old log, starts a fresh `screen -dmS kernel-test`, and logs to `/tmp/kernel-test.log`.

### Environment

- **B200 access**: see `MEMORY.md` (not committed).
- **venv**: `source /root/dsv4-nvfp4-workspace/venv/bin/activate`
- **PYTHONPATH**: `/root/dsv4-nvfp4-workspace/kernel`
- **Model**: `/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4`
- **vLLM** (modified for Blackwell): `/root/dsv4-nvfp4-workspace/vllm`
- **CUTLASS FMHA reference**: `/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py`
- **Local CUTLASS clone**: `/home/openclaw/dev/cutlass`

---

## CuTeDSL constraints (read every session)

These are surface-level traps. Get them wrong and the kernel silently produces garbage or NaN, or fails at JIT compile time with a "weakly congruent" error.

1. **TMA partition tensors have 4 modes**: `(((64,128),1), ?, KV_tiles, ?)`. `(None, 0, None, 0)` keeps mode 2 (KV tiles) free; `[None, kt]` indexes it.
   `(None, None, 0, 0)` silently pins mode 2 to 0, so multi-tile loads break invisibly.
2. **`vectorize=True` loops accept only load/store/print.** No `fmax`, no `cmpf`, no inner loops, no carry across iterations.
3. **`.reduce(cute.ReductionOp.MAX)` reduces the entire C-fragment to a scalar**: global, not per-row. Use a plain `range()` loop with `cute.arch.fmax` for the per-row max.
4. **`cute.arch.fmax` is impure** for the vectorizer. Use it inside plain `range()`, never inside `vectorize=True`.
5. **Hand-constructed TMEM atoms corrupt data on round-trip.** Independently built `Ld32x32bOp` + `St32x32bOp` atoms have addressing that doesn't match; even a NO-OP round-trip drops cos to ~0.97. Use paired atoms from `epilogue_tmem_copy_and_partition` / `epilogue_smem_copy_and_partition` for one-way trips. This is the D1.5 blocker in ROADMAP.
6. **CuTeDSL `if` blocks are separate MLIR regions.** Variables defined inside one `if` are not visible in another, even when the condition is a compile-time constant. Define all variables unconditionally before any branching.
7. **Guard dead code with `const_expr`.** CuTeDSL compiles both branches of a Python `if`. At hd=64, the SMEM-P or O-rescale code generates IR you don't need; without `const_expr`, MLIR still chews on it.
8. **`tma_partition` and `flat_divide` may not survive inside `if warp_idx` blocks.** Construct partitioned tensors before warp branching, or in a regular Python helper function. (The MoE kernel calls `tma_partition` inside the epilogue warp's `if`, so this constraint may depend on context; print and verify.)
9. **TMEM allocation must be a power of 2.** Round up after summing column requirements.
10. **`composition` vs `logical_divide` produce different layouts** even when re-tiling the same tensor. `correction_rescale` uses `composition`; `correction_epilog` uses `logical_divide`. Copy atoms must match the tensor layout they were created with.
11. **After every P store to TMEM, call `cute.arch.fence_view_async_tmem_store()`.** Missing this produces NaN.
12. **`St32x32bOp` must use Float32, not BFloat16.** BFloat16 causes an illegal memory access.
13. **The first PV must have `ACCUMULATE=False`.** Otherwise it adds uninitialized TMEM contents to the output.
14. **`find_tmem_tensor_col_offset()` returns the footprint size, not a safe offset.** Never use it as a TMEM placement.
15. **FMHA never trusts DLPack tensor layouts.** Reconstruct V as `(hd, s_k)` MN-major inside CuTe via explicit `make_tensor` + `make_layout`.

---

## Lessons learned (the gold: read every session)

These cost real days to learn. They are ordered by how easy they are to repeat.

### Layout & TMA

- **TMA partition mode ordering** (the bug that ate a whole day): see CuTeDSL constraint #1 above. The wrong slice produces "reasonable" wrong outputs (cos 0.7–0.9, never NaN), so you can ship it without knowing.
- **Square hides bugs.** (128,128) worked for every wrong approach to PV. Always test non-square shapes early.
- **Always print the shapes.** Reasoning about TMEM layouts or TMA mode counts without running `cute.printf(cute.shape(t))` inside `@cute.kernel` is how every multi-day debug starts. Shapes are ground truth.
- **`qk_mma_tiler` K-dim must equal `head_dim`**, not the MMA instruction's K sub-tile size. Hardcoding `qk_ik * 4 = 64` was the root cause of the hd>64 failure: the QK GEMM only computed half the dot product. The fix was one line; cos went from 0.78 to 0.999997 at hd=128.

### TMEM

- **Never assume TMEM round-trips are safe.** Verify with a NO-OP test (load → store unchanged) before adding any logic. The hand-constructed atoms produce ~3% error even on a NO-OP.
- **FMHA P store uses the QK C-fragment composition, not the PV A-fragment.** They are two aliases of the same TMEM region. Mixing them up gives valid-looking garbage.
- **Register bridge for P: FP32 backing (store partition) + BF16 view (QK-load layout).** Do not skip the dual view.
- **TMEM round-trip mismatch with `epilogue_tma_store`**: `epilogue_tma_store` reads O from TMEM using `get_tmem_load_op`'s layout. Hand-built atoms read with a different layout. Round-tripping through hand-built atoms transcodes the data, leaving ~3% error.
- **The correction-epilog pattern is the fix.** TMEM → registers (via paired t2r atom) → modify in registers → SMEM (via paired r2s atom) → GMEM (via TMA). One-way trip, no round-trip, no transcoding. The MoE kernel uses this and gets exact results. See ROADMAP.

### CuTeDSL & MLIR

- **CuTeDSL `if` blocks create separate MLIR regions.** Variables defined in `if not use_smem_p:` and read in another `if not use_smem_p:` inside a `for` inside an `if warp_idx < mma_warp_id:` are not visible. Define them unconditionally before any branching.
- **CuTeDSL compiles both branches of a Python `if`.** Wrap mode-specific dead code in `const_expr(condition)` to eliminate it. Critical for O rescale (`n_kv_tiles > 1`), LSE compute (`not normalize`), and the SMEM-P path.
- **The CuTeDSL MLIR backend cannot handle complex pipeline loops at hd=512.** Both unrolled (Python `range`) and runtime (`cutlass.range unroll=1`) loops trigger exponential-or-worse optimizer time. The tracer is fast (~0.8 s); the MLIR optimizer chews for 3+ hours. Workaround options are in ROADMAP.
- **Don't mix Python loops and pipeline ops.** A Python `for` unrolls at trace time: N copies of pipeline acquire/release + TMA + GEMM blow up the IR. Prefer `cutlass.range(unroll=1)` for pipeline loops.

### Math & merging

- **External k_sub merge is mathematically impossible.** You cannot merge `softmax(Q_k0 @ K_k0^T) @ V` and `softmax(Q_k1 @ K_k1^T) @ V` into `softmax(Q @ K^T) @ V`. k_sub partitions are additive in **logit** space (`S = S_0 + S_1`), and softmax is nonlinear. The D5 merge formula only works because sparse and SWA attend over **different token sets** (additive in weight space). In-kernel accumulation before softmax is the only correct approach for k_sub.
- **D5 multi-tile KV merge IS valid.** Per-segment LSE + the formula `O = Σ exp(lse_i) · O_i / Σ exp(lse_i)` (with each `O_i` normalized within its own segment) works because each segment is a separate softmax over a separate token range. This is the Python KV-merge workaround that ships today; the in-kernel single-launch version requires the correction-epilog fix.
- **Sink merge = single softmax over `[S_comp, S_swa + attn_sink]`.** The two-branch weighted merge formula in the paper is mathematically equivalent to adding `attn_sink` as a logit bias on the SWA positions and softmaxing once. One pass, one kernel. This obsoleted D5d.

### Numerics

- **Always test at hd=64 first.** If the proven TMEM-P path regresses, nothing else matters.
- **`St32x32bOp` must be Float32**, not BFloat16. BFloat16 throws an illegal memory access. (Yes, this is also CuTeDSL constraint #12; it is repeated here because it has been forgotten more than once.)
- **First PV `ACCUMULATE=False`.** Otherwise it sums uninitialized TMEM into the output and you see ~50% error.

### Workflow

- **Never edit on the B200.** Edit locally, commit, push, pull, test. The B200 has no editor history; one bad save and the file is lost.
- **Print shapes inside `@cute.kernel` at trace time.** `print(f"tBgK shape: {cute.shape(tBgK)}")` runs at compile time, not runtime, and is your only window into the JIT's view of layouts. This is the single most useful debugging line in CuTeDSL.

### SMEM budget

- **`pv_n_tile` is the easiest SMEM knob.** At hd > 256, reducing `pv_n_tile` from 256 to 128 halves sV and sC. Cost: 4 PV GEMM passes instead of 2 (PV is rarely the bottleneck). Simpler than SMEM overlap or Q tiling.
- **`kv_stage` is the second-easiest.** Drop it to 1 when the budget gets tight at hd > 128; you lose double-buffering on K/V but free 64+ KB.
- **SMEM budget at various hd** (with `pv_n_tile=256` for hd ≤ 256, `pv_n_tile=128` for hd > 256, and `kv_stage=2` for hd ≤ 128, else 1):

| hd | sQ | sK | sV | sP | sC | Total | Limit |
|---:|---:|---:|---:|---:|---:|------:|------:|
| 64 | 32 KB | 32 KB | 32 KB | — | 32 KB | 128 KB | 232 KB |
| 128 | 32 KB | 32 KB | 32 KB | — | 32 KB | 128 KB | 232 KB |
| 256 | 64 KB | 64 KB | 64 KB | 0* | 32 KB | 224 KB | 232 KB |
| 512 | 64 KB | 64 KB | 32 KB | 0* | 32 KB | 192 KB | 232 KB |

*TMEM-P path: sP allocation skipped via `const_expr` conditional.

---

## Reference

- **DeepSeek V4 paper**: `DeepSeek_V4.pdf` in the repo root.
- **DeepGEMM** (V4-aligned reference kernels): https://github.com/deepseek-ai/DeepGEMM
- **CUTLASS FMHA reference**: `/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py` (B200) or `/home/openclaw/dev/cutlass` (local).
- **Reference oracles**: `dsv4/reference/` (PyTorch FP32; slow, never imported by production code).
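Every regression check in this file is phrased as cosine similarity against a slow oracle. The pure-Python sketch below shows that style of oracle for a single query row; it is not the code in `dsv4/reference/`, and all function names are hypothetical. It computes attention with the sink folded in as a logit bias on the SWA positions (the single-softmax form from "Sink merge"), rebuilds the same result from two independently softmaxed branches via the LSE merge, and compares them by cosine. Modeling the paper's two-branch form as an LSE-weighted merge in which the SWA branch's LSE is shifted by `attn_sink` is an assumption of this sketch.

```python
import math
import random


def attend(scores, values):
    """Softmax-weighted average of value rows. Returns (normalized O, natural-log LSE)."""
    row_max = max(scores)
    weights = [math.exp(s - row_max) for s in scores]  # subtract the max for stability
    row_sum = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / row_sum for d in range(dim)]
    return out, math.log(row_sum) + row_max


def sink_single_softmax(s_comp, v_comp, s_swa, v_swa, attn_sink):
    """One softmax over [S_comp, S_swa + attn_sink]: the single-pass form."""
    return attend(s_comp + [s + attn_sink for s in s_swa], v_comp + v_swa)[0]


def lse_merge(outputs, lses):
    """O = sum_i exp(lse_i) * O_i / sum_i exp(lse_i), with each O_i already normalized."""
    m = max(lses)
    ws = [math.exp(l - m) for l in lses]  # shift by the max LSE to avoid overflow
    dim = len(outputs[0])
    return [sum(w * o[d] for w, o in zip(ws, outputs)) / sum(ws) for d in range(dim)]


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


if __name__ == "__main__":
    rng = random.Random(0)
    s_comp = [rng.gauss(0, 1) for _ in range(8)]
    s_swa = [rng.gauss(0, 1) for _ in range(4)]
    v_comp = [[rng.gauss(0, 1) for _ in range(3)] for _ in range(8)]
    v_swa = [[rng.gauss(0, 1) for _ in range(3)] for _ in range(4)]
    sink = 0.7
    single = sink_single_softmax(s_comp, v_comp, s_swa, v_swa, sink)
    o_c, lse_c = attend(s_comp, v_comp)
    o_s, lse_s = attend(s_swa, v_swa)
    merged = lse_merge([o_c, o_s], [lse_c, lse_s + sink])  # sink shifts the SWA branch's LSE
    print(f"cos = {cosine(single, merged):.9f}")
```

The two paths agree to floating-point precision because the comp and SWA branches cover disjoint token sets; the same `lse_merge` helper is what makes the multi-tile KV merge valid, while no such helper can recover a softmax from k_sub partial logits.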