# DSV4 Inference Kernel
Production-grade Blackwell SM100 inference kernel for **DeepSeek-V4-Pro NVFP4**, written in CuTeDSL with a CUDA fallback path. Target hardware: NVIDIA B200 (180 GiB HBM3e).
For what's done, what's blocked, and what's next, see **ROADMAP.md**. This file is the durable reference — architecture, design choices, package layout, workflow, and hard-won lessons. If you're touching the kernel, read the "Lessons learned" section every time.
---
## DSV4 is not MLA
This cannot be repeated enough. vLLM and some integrations misname DSV4's attention as MLA. It is fundamentally a different architecture. If you reason about this kernel as MLA + extras, you will make wrong decisions.
The differences that matter:
| | MLA (V2/V3) | V4 |
|---|---|---|
| Compression axis | feature/head dim (per-token latent) | **sequence dim** (multiple tokens collapsed into one entry) |
| Cache entries per token | one latent per token | one compressed entry per `m` tokens |
| Attention pattern | dense over all cached latents | hybrid: sparse top-k (CSA) + dense over heavily-compressed (HCA) + sliding window (SWA) |
| Compression rate | n/a (1:1) | m=4 for CSA, m'=128 for HCA |
| Selection | none — all tokens attended | lightning indexer + top-k for CSA |
| Output positional fix | n/a | inverse RoPE on each per-head output |
| Sink merge | n/a | per-head learnable attention sink merged via single softmax over `[S_comp, S_swa + sink]` |
Cache layout reflects this: per-layer **state cache** for SWA window + uncompressed tail (used for CSA/HCA compression), plus a **classical paged cache** holding compressed CSA/HCA entries, with block size = `lcm(m, m') = 128` original tokens per block.
---
## DSV4 architecture (paper-side reference)
The bits the kernel implements, with the choices we made for inference.
### Per-layer attention type schedule
```
Flash (43 layers): layers 0-1 = SWA, layers 2..42 alternating CSA/HCA (CSA at layer 2)
Pro (61 layers): layers 0-1 = HCA, layers 2..60 alternating CSA/HCA (CSA at layer 2)
```
Frozen at construction time per `LayerSpec` so torch.compile constant-folds the dispatch. Validation in `dsv4/model/layer_schedule.py:validate_schedule` fails loudly, because a wrong schedule otherwise produces silent garbage.
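As a rough illustration of the schedule above, here is a minimal sketch that builds and checks both variants. The names `AttentionType`, `build_schedule`, and `validate_schedule` mirror `layer_schedule.py`, but the signatures and bodies here are assumptions made for this example, not the real module.

```python
from enum import Enum


class AttentionType(Enum):
    SWA = "swa"
    CSA = "csa"
    HCA = "hca"


def build_schedule(n_layers, first_two):
    """Layers 0-1 use `first_two`; layers 2+ alternate CSA/HCA, CSA first."""
    schedule = []
    for layer in range(n_layers):
        if layer < 2:
            schedule.append(first_two)
        elif (layer - 2) % 2 == 0:
            schedule.append(AttentionType.CSA)
        else:
            schedule.append(AttentionType.HCA)
    return tuple(schedule)  # immutable, i.e. frozen at construction time


def validate_schedule(schedule, n_layers, first_two):
    """Raise instead of letting a wrong schedule run (hypothetical signature)."""
    if len(schedule) != n_layers:
        raise ValueError(f"expected {n_layers} layers, got {len(schedule)}")
    if tuple(schedule) != build_schedule(n_layers, first_two):
        raise ValueError("schedule does not match the documented pattern")


flash = build_schedule(43, AttentionType.SWA)
pro = build_schedule(61, AttentionType.HCA)
validate_schedule(flash, 43, AttentionType.SWA)
validate_schedule(pro, 61, AttentionType.HCA)
```

Returning a tuple keeps the per-layer type a compile-time constant from the caller's point of view, which is the property the constant-folding relies on.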
### Compressed Sparse Attention (CSA)
- Compresses every `m=4` KV entries into one via a token-level learned softmax with overlapping window (current m + previous m). See eq. 11–12 of the paper.
- Compressed sequence length is `n/m`.
- **Lightning indexer** scores each query against compressed blocks via weighted ReLU MQA logits (eq. 16). Top-k selector keeps `csa_top_k` blocks (512 Flash / 1024 Pro).
- Core attention is MQA over the selected blocks + a sliding window branch of `n_win=128` raw tokens.
- Partial RoPE on the last 64 dims of Q and the compressed K, with **inverse RoPE on each per-head output** so the per-token contribution carries the correct relative position.
- Per-head attention sink: learnable logit added to the softmax denominator (eq. 27). We merge sparse + SWA via the sink-bias-as-logit trick — see "Sink merge" below.
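
The indexer and top-k selection can be sketched in plain Python. This assumes "weighted ReLU MQA logits" means a per-head weighted sum of `ReLU(q_h · k)` against a single shared key per compressed block; the exact form of eq. 16 is not reproduced here, and `indexer_scores` / `top_k_blocks` are hypothetical names.

```python
def indexer_scores(q_heads, head_weights, k_blocks):
    """Score one query against every compressed block.

    q_heads:      list of per-head query vectors
    head_weights: one scalar weight per indexer head
    k_blocks:     one shared (MQA) key vector per compressed block
    """
    scores = []
    for k in k_blocks:
        score = 0.0
        for q, w in zip(q_heads, head_weights):
            dot = sum(a * b for a, b in zip(q, k))
            score += w * max(dot, 0.0)  # weighted ReLU logit
        scores.append(score)
    return scores


def top_k_blocks(scores, k):
    """Indices of the k highest-scoring blocks, returned in ascending order."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return sorted(order[:k])


q_heads = [[1.0, 0.0], [0.0, 1.0]]
head_weights = [0.5, 2.0]
k_blocks = [[1.0, 1.0], [-1.0, 2.0], [3.0, -1.0], [0.0, 0.0]]

scores = indexer_scores(q_heads, head_weights, k_blocks)
selected = top_k_blocks(scores, k=2)
```

In the real kernel `k` would be `csa_top_k`; the toy value of 2 only keeps the example readable.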
### Heavily Compressed Attention (HCA)
- Same compressor concept as CSA but `m'=128`, no overlap, dense attention over the (very short) compressed sequence.
- No indexer.
- Same partial RoPE + inverse RoPE + sliding window + sink as CSA.
### Sliding Window Attention (SWA)
- First two layers of Flash. Pure local attention over the SWA window. No compressed branch, no indexer.
- Cache layout: ring buffer of size `n_win` per request in the state cache.
### Manifold-Constrained Hyper-Connections (mHC)
- Replaces residual connections. Width-expanded residual stream `(T, n_hc=4, d)`.
- Per-token dynamic `A_l`, `B_l`, `C_l` mixing matrices generated by a fused 24-output prenorm projection (4 + 4² + 4).
- `A_l = σ(.)`, `C_l = 2σ(.)`, `B_l = SinkhornKnopp(exp(.), t_max=20)` to project onto the Birkhoff polytope.
- `pre_block`: `x_in = A_l @ X_l`; `post_block`: `X_next = B_l @ X_l + C_l ⊗ F_out`.
- `B_l` held in FP32 for the bmm precision; A/C cast to BF16.
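
A minimal pure-Python sketch of the mixing step, assuming the 24 projection outputs are split as 4 (A), 16 (B), 4 (C) in that order. The split order, the helper names, and the toy inputs are assumptions for illustration only.

```python
import math

N_HC = 4


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def sinkhorn_knopp(logits, t_max=20):
    """Alternate row/column normalization of exp(logits) toward a doubly stochastic matrix."""
    n = len(logits)
    b = [[math.exp(v) for v in row] for row in logits]
    for _ in range(t_max):
        row_sums = [sum(row) for row in b]
        b = [[b[i][j] / row_sums[i] for j in range(n)] for i in range(n)]
        col_sums = [sum(b[i][j] for i in range(n)) for j in range(n)]
        b = [[b[i][j] / col_sums[j] for j in range(n)] for i in range(n)]
    return b


def pre_block(a, x):
    """x_in = A_l @ X_l: collapse the n_hc streams into one (d,) input."""
    return [sum(a[i] * x[i][c] for i in range(len(a))) for c in range(len(x[0]))]


def post_block(b, c, x, f_out):
    """X_next = B_l @ X_l + C_l (outer) F_out."""
    n, d = len(x), len(f_out)
    return [[sum(b[i][j] * x[j][col] for j in range(n)) + c[i] * f_out[col]
             for col in range(d)] for i in range(n)]


raw = [0.1 * ((i * 7) % 11 - 5) for i in range(24)]  # stand-in for the fused projection
A = [sigmoid(v) for v in raw[:N_HC]]
B = sinkhorn_knopp([raw[4 + N_HC * i: 4 + N_HC * (i + 1)] for i in range(N_HC)])
C = [2.0 * sigmoid(v) for v in raw[20:]]

X = [[1.0, 2.0], [0.0, 1.0], [-1.0, 0.5], [2.0, -1.0]]  # (n_hc=4, d=2)
x_in = pre_block(A, X)
X_next = post_block(B, C, X, f_out=[0.25, -0.75])
```

After the loop both the row sums and the column sums of `B` are close to 1, which is what membership in the Birkhoff polytope means.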
### Router
- Two modes, frozen at construction by layer index:
  - **Hash routing** (layers 0–2): deterministic per-token-ID LUT lookup, uniform weights `1/k`.
  - **Dense routing** (layers 3+): `sqrt(softplus(X @ W_gate))` activation, plus learned `e_bias` for *selection only*. Top-k (k=6), renormalize on unbiased activations, multiply by `routed_scaling_factor`.
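
The two routing modes can be sketched for a single token as follows. `dense_route` and `hash_route` are hypothetical helper names, the inputs are toy values, and the real router is a batched GPU kernel. The point of the sketch is that `e_bias` changes which experts are picked but never enters the weights.

```python
import math


def dense_route(logits, e_bias, k, routed_scaling_factor):
    """logits = X @ W_gate for one token; returns (expert ids, weights)."""
    # sqrt(softplus(x)); log1p(exp(x)) is fine for the small toy logits used here
    act = [math.sqrt(math.log1p(math.exp(x))) for x in logits]
    biased = [a + b for a, b in zip(act, e_bias)]  # used for selection only
    chosen = sorted(range(len(act)), key=lambda i: biased[i], reverse=True)[:k]
    chosen.sort()
    total = sum(act[i] for i in chosen)  # renormalize on unbiased activations
    weights = [routed_scaling_factor * act[i] / total for i in chosen]
    return chosen, weights


def hash_route(token_id, lut, k):
    """Deterministic lookup by token ID with uniform weights 1/k."""
    return list(lut[token_id]), [1.0 / k] * k


chosen, weights = dense_route(
    logits=[0.0, 1.0, 2.0, -1.0],
    e_bias=[5.0, 0.0, 0.0, 0.0],
    k=2,
    routed_scaling_factor=2.5,
)
hash_ids, hash_w = hash_route(7, lut={7: (3, 1)}, k=2)
```

Expert 0 is selected only because of its bias, yet its weight stays smaller than expert 2's because the weights come from the unbiased activations.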
### MoE
- DeepSeekMoE: shared expert + N routed experts (Flash 256, Pro 384), 6 activated per token.
- L1 GEMM (gate + up interleaved at granularity 8) → SwiGLU → L2 GEMM (down).
- SwiGLU clamping per paper §4.2.3: gate capped at `swiglu_limit=10`, linear clamped to `[-limit, +limit]`.
- All weights NVFP4, FP8 E4M3 scales, 16-element microblocks.
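
A scalar sketch of the L1 epilogue math. It assumes "capped" means the gate is clamped from above only, and that granularity-8 interleaving means alternating groups of 8 gate columns and 8 up columns. Both are readings of the bullets above, and the helper names are hypothetical.

```python
import math


def silu(x):
    """Numerically stable x * sigmoid(x)."""
    if x >= 0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)
    return x * e / (1.0 + e)


def clamped_swiglu(gate, up, limit=10.0):
    gate = min(gate, limit)           # gate: capped from above
    up = max(-limit, min(up, limit))  # linear branch: clamped to [-limit, +limit]
    return silu(gate) * up


def split_gate_up(row, granularity=8):
    """De-interleave one L1 output row into its gate and up halves."""
    gate, up = [], []
    for start in range(0, len(row), 2 * granularity):
        gate.extend(row[start:start + granularity])
        up.extend(row[start + granularity:start + 2 * granularity])
    return gate, up


l1_row = [float(i) for i in range(32)]  # stand-in for one row of L1 GEMM output
gate_cols, up_cols = split_gate_up(l1_row)
activated = [clamped_swiglu(g, u) for g, u in zip(gate_cols, up_cols)]
```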
### Sink merge (D5c — key insight)
The paper writes the sink merge as a weighted combination of two separate softmax outputs. But because the sink is just an additive logit bias on one branch, the whole thing collapses to a **single softmax over `[S_comp, S_swa + attn_sink]`**.
One pass, one kernel. No two-loop epilogue, no LSE arithmetic in the merge. This is why D5d (fused merge epilogue) is not needed.
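The collapse can be checked numerically. The sketch below assumes the two-branch form weights each branch's softmax output by `exp(lse)` of that branch; the paper's exact notation is not reproduced here. It uses scalar values per key to keep the arithmetic visible.

```python
import math


def lse(logits):
    m = max(logits)
    return m + math.log(sum(math.exp(x - m) for x in logits))


def attend(logits, values):
    """softmax(logits) @ values, with scalar values for brevity."""
    z = lse(logits)
    return sum(math.exp(x - z) * v for x, v in zip(logits, values))


s_comp, v_comp = [0.3, -1.2, 2.0], [1.0, 2.0, 3.0]
s_swa, v_swa = [0.5, 1.5], [4.0, 5.0]
attn_sink = 0.7
s_swa_biased = [s + attn_sink for s in s_swa]

# Two-branch form: separate softmaxes, merged with exp(lse) weights.
w_comp, w_swa = math.exp(lse(s_comp)), math.exp(lse(s_swa_biased))
two_branch = (w_comp * attend(s_comp, v_comp)
              + w_swa * attend(s_swa_biased, v_swa)) / (w_comp + w_swa)

# Single-softmax form over the concatenated logits [S_comp, S_swa + attn_sink].
single = attend(s_comp + s_swa_biased, v_comp + v_swa)
```

The two results agree to floating-point precision, which is why no separate merge epilogue is required.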
---
## Our kernel design choices
### Attention kernel (FmhaKernel)
**6-warp specialization.** Warps 0–3 handle softmax + correction + epilogue. Warp 4 is the MMA warp (QK + PV). Warp 5 is the TMA warp (Q/K/V loads, output store via pipeline).
**P staging — two paths.**
- **TMEM-P** (hd ≤ 64): P stored to TMEM via register bridge (FP32 backing + BF16 view). PV reads P from TMEM. Used at the small head dims where QK C-fragment and PV A-fragment TMEM layouts agree.
- **SMEM-P** (hd > 64): P written to SMEM via coordinate-indexed store using `tTMEM_LOADcS` to map register indices to `(m, k)` then into `sP`'s subtile layout. PV reads P from SMEM with `OperandSource.SMEM`. Required because the QK ↔ PV TMEM layout disagreement at hd > 64 corrupts the round-trip.
**Un-normalized O + LSE output.** The kernel emits raw `sum(P · V)` and `lse = ln(row_sum) + row_max · ln(2)`. External code (or the next kernel pass) divides. This composes — D5 merge, multi-tile rescale, and the inverse-RoPE → wo_a fuse all rely on it.
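A one-row sketch of that output contract, assuming the softmax runs in the log2 domain with `scale_log2 = scale_softmax * log2(e)` and that `row_max` is tracked in that domain. `fmha_row_unnormalized` is a hypothetical name, and it returns `row_sum` only so the example can show the external divide.

```python
import math

LOG2_E = math.log2(math.e)


def fmha_row_unnormalized(scores, values, scale_softmax):
    scale_log2 = scale_softmax * LOG2_E
    s2 = [s * scale_log2 for s in scores]   # logits in log2 space
    row_max = max(s2)
    p = [2.0 ** (x - row_max) for x in s2]  # un-normalized probabilities
    row_sum = sum(p)
    o_raw = sum(w * v for w, v in zip(p, values))           # raw sum(P * V)
    lse = math.log(row_sum) + row_max * math.log(2.0)       # natural-log LSE
    return o_raw, row_sum, lse


scores, values, scale = [1.0, 3.0, -2.0], [0.5, -1.0, 2.0], 0.125
o_raw, row_sum, lse = fmha_row_unnormalized(scores, values, scale)

# Reference: ordinary natural-log softmax over the scaled scores.
ref_exp = [math.exp(s * scale) for s in scores]
ref_lse = math.log(sum(ref_exp))
ref_out = sum(e * v for e, v in zip(ref_exp, values)) / sum(ref_exp)
o_normalized = o_raw / row_sum  # the divide left to external code
```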
**Per-head launch for multi-head.** Python loop dispatches the single-CTA kernel once per head. Multi-CTA grid using `flat_divide` + `tma_partition` is the next refactor (see ROADMAP); the path is unblocked once the correction-epilog rewrite lands.
**Head-packed M dimension for decode.** Q reshaped to `(n_h * T, hd, 1)`, all heads' rows packed into the 128-row M tile. Per-row softmax. At Pro decode (T=1, n_h=128) the M tile fits exactly.
**K-dim sub-tiling at hd > 256.** When `head_dim > 256` (MMA instruction K-dim limit), Q and K split into `n_k_sub_tiles = head_dim / 256` chunks along head_dim. QK accumulates in TMEM across sub-tiles (additive in logit space). The PV path uses `pv_n_tile = 128` for hd > 256 to keep sV+sC within the 232 KB SMEM budget.
**Sink bias as logit modification.** D3 (SWA length mask), D4 (causal mask on SWA), and D5c (attention sink) all live in the same post-QK, pre-softmax in-register code. They read `tTMEM_LOADcS` to get `(m, k)` coordinates and modify `tTMEM_LOADrS` before the row-max reduction. The sink bias is added in the raw-logit domain as `attn_sink / scale_softmax`, then the existing `* scale_log2` multiply converts to log2 space.
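The domain bookkeeping can be verified with scalars. The sketch assumes `scale_log2 = scale_softmax * log2(e)`; the function name is hypothetical.

```python
import math

LOG2_E = math.log2(math.e)


def biased_log2_logit(raw_qk, attn_sink, scale_softmax):
    """Add the sink in the raw-logit domain, then apply the usual log2 scaling."""
    scale_log2 = scale_softmax * LOG2_E
    return (raw_qk + attn_sink / scale_softmax) * scale_log2


raw_qk, attn_sink = 3.0, 0.7
scale_softmax = 1.0 / math.sqrt(128.0)

got = biased_log2_logit(raw_qk, attn_sink, scale_softmax)
# What we want in the end: (scaled logit + sink), expressed in log2 space.
want = (raw_qk * scale_softmax + attn_sink) * LOG2_E
```

Dividing by `scale_softmax` first means the existing multiply restores `attn_sink` at full strength, so the softmax code path needs no extra branch.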
### MoE kernel (FusedSwiGLUScaledGroupedGemmKernel)
**7-warp specialization.** Warps 0–3 epilogue (TMEM → registers → SMEM → GMEM with global scale, SwiGLU, clamp). Warp 4 MMA (`tcgen05.mma.block_scale` with SFA/SFB in TMEM). Warp 5 TMA load (A, B, SFA, SFB). Warp 6 scheduler (`MoEStaticPersistentTileScheduler`).
**One-way TMEM → registers → SMEM → GMEM epilogue.** Uses `epilogue_tmem_copy_and_partition` + `epilogue_smem_copy_and_partition` (CUTLASS helpers, paired atoms). The SwiGLU + clamping math runs in registers between the t2r and r2s copies. No TMEM round-trip. This is the same pattern FMHA needs to adopt to fix the D1.5 blocker — see ROADMAP.
**Subtile-level gate/up pairing.** With granularity-8 interleaved L1 weights and `epi_tile_n=8`, even subtiles are gate and odd subtiles are up. `silu_gate_buf` register tensor carries the SiLU result across the subtile-pair boundary.
**`use_2cta_instrs` conditional** on `tokens_sum ≥ 256` and even `cluster_m`. Decode (small M) stays 1-CTA; prefill/batched gets 2-CTA UMMA with multicast B (1.7–1.9× throughput).
### Heterogeneous KV cache
- **State cache** per request: fixed-size block holding `(n_win SWA KV)` and `(uncompressed tail tokens awaiting compression)`. One block per request, lifetime managed by request scheduling.
- **Classical paged cache** per request: variable blocks holding `(k1 CSA compressed entries, k2 HCA compressed entries)` per layer. `k1 = lcm(m, m') / m = 32`, `k2 = lcm(m, m') / m' = 1`. Block covers 128 original tokens.
- Different layers can produce different KV cache sizes (CSA vs HCA vs SWA-only). The state cache + classical-pool split keeps PagedAttention-style alignment intact for the compressed pool.
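
The block arithmetic above is small enough to check directly. `compressed_split` is a hypothetical helper showing how a sequence length divides into finished compressed entries (paged cache) and a leftover tail (state cache).

```python
import math

M_CSA, M_HCA = 4, 128

# lcm(m, m') original tokens per paged-cache block
BLOCK_TOKENS = M_CSA * M_HCA // math.gcd(M_CSA, M_HCA)
K1 = BLOCK_TOKENS // M_CSA  # CSA compressed entries per block
K2 = BLOCK_TOKENS // M_HCA  # HCA compressed entries per block


def compressed_split(n_tokens, m):
    """Return (finished compressed entries, uncompressed tail tokens)."""
    return n_tokens // m, n_tokens % m


csa_entries, csa_tail = compressed_split(1000, M_CSA)
hca_entries, hca_tail = compressed_split(1000, M_HCA)
```

For a 1000-token request the HCA side still has 104 raw tokens waiting in the state cache, while the CSA side has none.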
### NVFP4 throughout
- **Weights**: NVFP4 (FP8 E4M3 scales, 16-element microblocks). Verified: `sf_dtype`, TMA element type, MMA kind (`mxf4nvf4`) all correct.
- **Activations**: BF16 today, FP4 after NVFP4-1.x epilogue fusion lands (see ROADMAP).
- **KV cache**: BF16 today; the FP8 (RoPE in BF16, NoPE in FP8) split per paper §2.3.4 is on the roadmap as NVFP4-2.
- **Indexer keys**: stored FP4 in the cache today, but scored with a scalar CUDA-core kernel. Tensor-core FP4 scoring (paper §5.2.1) is a Stage F priority.
---
## Package structure
```
dsv4/
├── kernels/ Pure GPU code (CuTeDSL @cute.jit, .cu files)
│ ├── attention/ FMHA — FmhaKernel (hd=64/128/256 proven, hd=512 MLIR-blocked)
│ ├── gemm/ NVFP4 MoE GEMM (grouped, fused_swiglu, dense, scheduler)
│ ├── compressor/ CSA/HCA token-level compressor (CuTeDSL)
│ ├── indexer/ CSA indexer score+topk (FP32 scalar today; tensor-core FP4 on roadmap)
│ ├── router/ Dense router decode kernel (warp-specialized persistent GEMM)
│ ├── cache/ append_swa (writes KV to state cache)
│ ├── decode/ Decode-time attention (future)
│ └── cuda/ Raw .cu (deinterleave_quantize, sparse_topk_metadata, etc.)
├── ops/ PyTorch ↔ kernel bridges
│ ├── quantize.py BF16 ↔ NVFP4, scale factor handling
│ ├── layouts.py Scale swizzle, gate/up interleave, K-major, offsets
│ ├── gemm_runner.py Warmup, compile, run grouped/fused GEMMs
│ ├── custom_ops.py torch.library.custom_op registrations
│ ├── decode_sparse.py native_sparse_decode dispatcher
│ ├── rope.py Forward + inverse RoPE (partial, last 64 dims)
│ ├── topk.py Sparse top-k metadata wrapper
│ └── router.py Router op bridge
├── layers/ nn.Module-style components
│ ├── linear.py Nvfp4Linear
│ ├── grouped_linear.py Nvfp4GroupedLinear (output projection)
│ ├── moe.py Nvfp4MoE (routed experts)
│ ├── shared_expert.py Nvfp4SharedExpert
│ ├── mhc.py mHCLayer (Sinkhorn-Knopp, residual mixing)
│ ├── attention.py AttentionSubBlock (CSA/HCA/SWA variants by LayerSpec)
│ ├── norm.py RMSNorm
│ ├── router.py Router (dense + hash modes)
│ ├── embedding.py Token embedding + mHC init
│ └── ffn.py FFN sub-block
├── model/ Model assembly
│ ├── config.py DSV4Config
│ ├── layer.py TransformerLayer
│ ├── layer_schedule.py LayerSpec, AttentionType, build_schedule, validate_schedule
│ ├── mtp.py Multi-token prediction
│ ├── sampler.py Token sampler
│ └── dsv4.py Full model
├── cache/ KV cache infra
│ ├── allocator.py Memory allocator
│ ├── block_table.py Paged cache block table
│ ├── manager.py Cache manager
│ ├── paged_cache.py Classical paged cache (CSA/HCA)
│ ├── state_cache.py State cache (SWA + uncompressed tail)
│ ├── schema.py, handle.py, flush.py, prepare_forward.py
├── loader/ Checkpoint I/O
│ ├── hf_checkpoint.py
│ └── layout_convert.py
└── reference/ Slow PyTorch oracles (never imported by production code)
├── attention.py, csa_attention.py, compressor.py, moe_pipeline.py
```
**Dependency arrow:** `kernels/` → `ops/` → `layers/` → `model/`. `reference/` and `loader/` are sidecars.
---
## Workflow & test harness
### The non-negotiables
- **NEVER edit on the B200.** Always: edit locally → commit → push → pull on B200 → test.
- **NEVER raw SSH + direct command.** Always use the test harness scripts. They handle: killing hung processes, deleting stale logs, screen sessions that survive SSH drops, timeouts for hung kernels, and GPU cleanup.
- **ALWAYS verify hd=64 regression** (cos ~0.999998) after every FMHA change. If it regresses, the change is wrong. Revert.
- **NEVER touch drivers, kernels, firmware, or system packages** on the B200.
- **NEVER delete test files** in `tests/unit/` without explicit approval.
### Two harnesses: Python and CUDA
| Harness | For | Script | Screen name | Log file |
|---|---|---|---|---|
| Python | `test_*.py` files | `fire_b200_test` | `kernel-test` | `/tmp/kernel-test.log` |
| CUDA | `test_*.cu` files | `fire_b200_cuda_test` | `cuda-test` | `/tmp/cuda-test.log` |
Both harnesses follow the same discipline:
1. **Kill everything first** — old screen sessions, hanging GPU processes, stale binaries
2. **Delete all logs** — never debug from a previous run's log
3. **Clean git + pull** — no uncommitted B200 state
4. **Run in screen** — survives SSH drops, has a timeout
5. **One test at a time** — no parallel launches, ever
### Python test (one command)
```bash
# From local machine — auto-pushes, runs, polls, dumps log
~/.openclaw/workspace/fire_b200_test tests/unit/test_fmha_v3_stage_c.py
```
### CUDA test (one command)
```bash
# From local machine — compiles with nvcc, runs, polls, dumps log
# Default timeout: 60s. Pass a second arg for custom timeout.
~/.openclaw/workspace/fire_b200_cuda_test tests/unit/test_fmha_sm100_standalone.cu
~/.openclaw/workspace/fire_b200_cuda_test tests/unit/test_tmem_minimal.cu 30
```
### Check on a running CUDA test
```bash
# Show current log + screen status
~/.openclaw/workspace/check_b200_cuda
# Kill a hung test + show the log
~/.openclaw/workspace/check_b200_cuda kill
```
### Manual B200 cycle (emergency only)
```bash
ssh root@<B200>
cd /root/dsv4-nvfp4-workspace/kernel && git pull
bash tests/run_test.sh tests/unit/test_<...>.py
bash tests/check_log.sh
```
`run_test.sh` kills any prior `kernel-test` screen (with SIGKILL on stuck GPU procs), deletes the old log, starts a fresh `screen -dmS kernel-test`, and logs to `/tmp/kernel-test.log`.
### Environment
- **B200 access**: see `MEMORY.md` (not committed).
- **venv**: `source /root/dsv4-nvfp4-workspace/venv/bin/activate`
- **PYTHONPATH**: `/root/dsv4-nvfp4-workspace/kernel`
- **Model**: `/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4`
- **vLLM** (modified for Blackwell): `/root/dsv4-nvfp4-workspace/vllm`
- **CUTLASS FMHA reference**: `/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py`
- **Local CUTLASS clone**: `/home/openclaw/dev/cutlass`
---
## CuTeDSL constraints (read every session)
These are surface-level traps. Get them wrong and the kernel silently produces garbage or NaN, or fails at JIT compile time with a "weakly congruent" error.
1. **TMA partition tensors have 4 modes**: `(((64,128),1), ?, KV_tiles, ?)`. `(None, 0, None, 0)` keeps mode 2 (KV tiles) free; `[None, kt]` indexes it. `(None, None, 0, 0)` silently pins mode 2 to 0 — multi-tile loads break invisibly.
2. **`vectorize=True` loops accept only load/store/print.** No `fmax`, no `cmpf`, no inner loops, no carry across iterations.
3. **`.reduce(cute.ReductionOp.MAX)` reduces the entire C-fragment to a scalar** — global, not per-row. Use a plain `range()` loop with `cute.arch.fmax` for per-row max.
4. **`cute.arch.fmax` is impure** for the vectorizer. Use it inside plain `range()`, never inside `vectorize=True`.
5. **Hand-constructed TMEM atoms corrupt data on round-trip.** Independently-built `Ld32x32bOp` + `St32x32bOp` atoms have addressing that doesn't match — even a NO-OP round-trip drops cos to ~0.97. Use paired atoms from `epilogue_tmem_copy_and_partition` / `epilogue_smem_copy_and_partition` for one-way trips. This is the D1.5 blocker in ROADMAP.
6. **CuTeDSL `if` blocks are separate MLIR regions.** Variables defined inside one `if` are not visible in another, even when the condition is a compile-time constant. Define all variables unconditionally before any branching.
7. **Guard dead code with `const_expr`.** CuTeDSL compiles both branches of Python `if`. At hd=64, the SMEM-P or O-rescale code generates IR you don't need; without `const_expr`, MLIR chews on it.
8. **`tma_partition` and `flat_divide` may not survive inside `if warp_idx` blocks.** Construct partitioned tensors before warp branching, or in a regular Python helper function. (The MoE kernel calls `tma_partition` inside the epilogue warp's `if`, so this constraint may depend on context — print and verify.)
9. **TMEM allocation must be a power of 2.** Round up after summing column requirements.
10. **`composition` vs `logical_divide` produce different layouts** even when re-tiling the same tensor. `correction_rescale` uses `composition`, `correction_epilog` uses `logical_divide`. Copy atoms must match the tensor layout they were created with.
11. **After every P store to TMEM, call `cute.arch.fence_view_async_tmem_store()`.** Missing this produces NaN.
12. **`St32x32bOp` must use Float32, not BFloat16.** BFloat16 causes illegal memory access.
13. **First PV must have `ACCUMULATE=False`.** Otherwise adds uninitialized TMEM contents to the output.
14. **`find_tmem_tensor_col_offset()` returns footprint size, not a safe offset.** Never use it as a TMEM placement.
15. **FMHA never trusts DLPack tensor layouts.** Reconstruct V as `(hd, s_k)` MN-major inside CuTe via explicit `make_tensor` + `make_layout`.
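
Constraint 9 reduces to a small helper. The column counts below are made-up placeholders, not the kernel's real TMEM footprint, and `round_up_pow2` is a hypothetical name.

```python
def round_up_pow2(n_cols):
    """Smallest power of two that is >= n_cols."""
    if n_cols < 1:
        raise ValueError("column count must be positive")
    return 1 << (n_cols - 1).bit_length()


# Hypothetical per-buffer column requirements, summed first and rounded once.
column_requirements = {"S": 128, "P": 64, "O": 128}
needed = sum(column_requirements.values())
tmem_alloc_cols = round_up_pow2(needed)
```

Rounding happens once on the total; rounding each buffer separately would still leave a non-power-of-2 sum in cases like this one.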
---
## Lessons learned (the gold — read every session)
These cost real days to learn. They are ordered by how easy each mistake is to repeat, easiest first.
### Layout & TMA
- **TMA partition mode ordering** (the bug that ate a whole day): see CuTeDSL constraint #1 above. The wrong slice produces "reasonable" wrong outputs (cos 0.7–0.9, never NaN), so you can ship it without knowing.
- **Square hides bugs.** (128,128) worked for every wrong approach to PV. Always test non-square shapes early.
- **Print the shapes always.** Reasoning about TMEM layouts or TMA mode counts without running `cute.printf(cute.shape(t))` inside `@cute.kernel` is how every multi-day debug starts. Shapes are ground truth.
- **`qk_mma_tiler` K-dim must equal `head_dim`**, not the MMA instruction's K sub-tile size. Hardcoding `qk_ik * 4 = 64` was the root cause of the hd>64 failure; the QK GEMM only computed half the dot product. Fix was one line; cos went from 0.78 to 0.999997 at hd=128.
### TMEM
- **Never assume TMEM round-trips are safe.** Verify with a NO-OP test (load → store unchanged) before adding any logic. The hand-constructed atoms produce ~3% error even on NO-OP.
- **FMHA P store uses QK C-fragment composition, not PV A-fragment.** Two aliases of the same TMEM region. Mixing them up gives valid-looking garbage.
- **Register bridge for P: FP32 backing (store partition) + BF16 view (QK-load layout).** Do not skip the dual view.
- **TMEM round-trip mismatch with `epilogue_tma_store`**: `epilogue_tma_store` reads O from TMEM using `get_tmem_load_op`'s layout. Hand-built atoms read with a different layout. Round-tripping through hand-built atoms transcodes the data, leaving 3% error.
- **The correction-epilog pattern is the fix.** TMEM → registers (via paired t2r atom) → modify in registers → SMEM (via paired r2s atom) → GMEM (via TMA). One-way trip, no round-trip, no transcoding. The MoE kernel uses this and gets perfect results. See ROADMAP.
### CuTeDSL & MLIR
- **CuTeDSL `if` blocks create separate MLIR regions.** Variables defined in `if not use_smem_p:` and read in another `if not use_smem_p:` inside a `for` inside an `if warp_idx < mma_warp_id:` are not visible. Define unconditionally before any branching.
- **CuTeDSL compiles both branches of Python `if`.** Wrap mode-specific dead code in `const_expr(condition)` to eliminate it. Critical for O rescale (`n_kv_tiles > 1`), LSE compute (`not normalize`), SMEM-P path.
- **CuTeDSL MLIR backend cannot handle complex pipeline loops at hd=512.** Both unrolled (Python `range`) and runtime (`cutlass.range(unroll=1)`) loops trigger exponential-or-worse optimizer time. Tracer is fast (~0.8s); MLIR optimizer chews for 3+ hours. Workaround options in ROADMAP.
- **Don't mix Python loops and pipeline ops.** Python `for` unrolls at trace time — N copies of pipeline acquire/release + TMA + GEMM blow up the IR. Prefer `cutlass.range(unroll=1)` for pipeline loops.
### Math & merging
- **External k_sub merge is mathematically impossible.** You cannot merge `softmax(Q_k0 @ K_k0^T) @ V` and `softmax(Q_k1 @ K_k1^T) @ V` into `softmax(Q @ K^T) @ V`. k_sub partitions are additive in **logit** space (`S = S_0 + S_1`); softmax is nonlinear. The D5 merge formula only works because sparse and SWA attend over **different token sets** (additive in weight space). In-kernel accumulation before softmax is the only correct approach for k_sub.
- **D5 multi-tile KV merge IS valid.** Per-segment LSE + the formula `O = Σ exp(lse_i) · O_i / Σ exp(lse_i)` works because each segment is a separate softmax over a separate token range. This is the Python KV merge workaround that ships today; the in-kernel single-launch version requires the correction-epilog fix.
- **Sink merge = single softmax over `[S_comp, S_swa + attn_sink]`.** The two-branch weighted merge formula in the paper is mathematically equivalent to adding `attn_sink` as a logit bias on the SWA positions and softmaxing once. One pass, one kernel. This obsoleted D5d.
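
A small numerical check of the first two bullets, using scalar values per key. The LSE-weighted merge reproduces full attention when the keys are split into disjoint tiles, and visibly fails when the same keys are split along head_dim into additive logit parts. All names and numbers are illustrative.

```python
import math


def lse(logits):
    m = max(logits)
    return m + math.log(sum(math.exp(x - m) for x in logits))


def attend(logits, values):
    z = lse(logits)
    return sum(math.exp(x - z) * v for x, v in zip(logits, values))


def lse_merge(parts):
    """O = sum(exp(lse_i) * O_i) / sum(exp(lse_i)) over (logits, values) parts."""
    weights = [math.exp(lse(s)) for s, _ in parts]
    outs = [attend(s, v) for s, v in parts]
    return sum(w * o for w, o in zip(weights, outs)) / sum(weights)


# k_sub split: the same four keys, logits are the SUM of two partial dot products.
s0 = [0.5, 0.1, -0.3, 0.2]
s1 = [-0.3, 1.0, -0.1, 0.7]
s = [a + b for a, b in zip(s0, s1)]
v = [1.0, -2.0, 0.5, 3.0]

full = attend(s, v)
tile_merged = lse_merge([(s[:2], v[:2]), (s[2:], v[2:])])  # disjoint KV tiles
ksub_merged = lse_merge([(s0, v), (s1, v)])                # k_sub partitions
```

`tile_merged` matches `full` to rounding error; `ksub_merged` does not, because the partial logits must be added before the softmax is applied.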
### Numerics
- **Always test at hd=64 first.** If the proven TMEM-P path regresses, nothing else matters.
- **`St32x32bOp` must be Float32**, not BFloat16. BFloat16 throws illegal memory access. (Yes, this is a CuTeDSL constraint — listing here because it's been forgotten more than once.)
- **First PV `ACCUMULATE=False`.** Otherwise sums uninitialized TMEM into the output and you see ~50% error.
### Workflow
- **Never edit on the B200.** Edit locally, commit, push, pull, test. The B200 has no editor history; one bad save and the file is lost.
- **Print shapes inside `@cute.kernel` at trace time.** `print(f"tBgK shape: {cute.shape(tBgK)}")` runs at compile time, not runtime, and is your only window into the JIT's view of layouts. This is the single most useful debugging line in CuTeDSL.
### SMEM budget
- **`pv_n_tile` is the easiest SMEM knob.** At hd > 256, reducing `pv_n_tile` from 256 to 128 halves sV and sC. Cost: 4 PV GEMM passes instead of 2 (PV is rarely the bottleneck). Simpler than SMEM overlap or Q tiling.
- **`kv_stage` is the second-easiest.** Drop to 1 when budget gets tight at hd > 128; lose double-buffering on K/V but free 64+ KB.
- **SMEM budget at various hd** (with `pv_n_tile=256` for hd≤256, `pv_n_tile=128` for hd>256, `kv_stage=2` for hd≤128 else 1):
| hd | sQ | sK | sV | sP | sC | Total | Limit |
|---:|---:|---:|---:|---:|---:|------:|------:|
| 64 | 32 KB | 32 KB | 32 KB | — | 32 KB | 128 KB | 232 KB |
| 128 | 32 KB | 32 KB | 32 KB | — | 32 KB | 128 KB | 232 KB |
| 256 | 64 KB | 64 KB | 64 KB | 0* | 32 KB | 224 KB | 232 KB |
| 512 | 64 KB | 64 KB | 32 KB | 0* | 32 KB | 192 KB | 232 KB |
*TMEM-P path: sP allocation skipped via `const_expr` conditional.
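The totals in the table can be re-derived with a few lines, which is handy when trying a new `pv_n_tile` or `kv_stage` value. The per-buffer sizes are copied from the table; a "—" entry is treated as 0 KB, and the dictionary layout is just a convenience for this sketch.

```python
SMEM_LIMIT_KB = 232

# hd -> per-buffer SMEM in KB, taken from the table above
budget_kb = {
    64:  {"sQ": 32, "sK": 32, "sV": 32, "sP": 0, "sC": 32},
    128: {"sQ": 32, "sK": 32, "sV": 32, "sP": 0, "sC": 32},
    256: {"sQ": 64, "sK": 64, "sV": 64, "sP": 0, "sC": 32},
    512: {"sQ": 64, "sK": 64, "sV": 32, "sP": 0, "sC": 32},
}

totals = {hd: sum(parts.values()) for hd, parts in budget_kb.items()}
headroom = {hd: SMEM_LIMIT_KB - total for hd, total in totals.items()}
over_budget = [hd for hd, left in headroom.items() if left < 0]
```

The tightest row is hd=256 with 8 KB of headroom, so any extra staging buffer at that size needs one of the knobs above.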
---
## Reference
- **DeepSeek V4 paper**: `DeepSeek_V4.pdf` in the repo root.
- **DeepGEMM** (V4-aligned reference kernels): https://github.com/deepseek-ai/DeepGEMM
- **CUTLASS FMHA reference**: `/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py` (B200) or `/home/openclaw/dev/cutlass` (local).
- **Reference oracles**: `dsv4/reference/` (PyTorch FP32 — slow, never imported by production code).