# DSV4 Inference Kernel

Production-grade Blackwell SM100 inference kernel for **DeepSeek-V4-Pro NVFP4**, written in CuTeDSL with a CUDA fallback path. Target hardware: NVIDIA B200 (180 GiB HBM3e).

This file is the durable reference: architecture, design choices, package layout, workflow, and hard-won lessons. If you're touching the kernel, read the "Lessons learned" section every time.

---

## DSV4 is not MLA

This cannot be repeated enough. vLLM and some integrations misname DSV4's attention as MLA. It is fundamentally a different architecture. If you reason about this kernel as MLA + extras, you will make wrong decisions.

The differences that matter:

| | MLA (V2/V3) | V4 |
|---|---|---|
| Compression axis | feature/head dim (per-token latent) | **sequence dim** (multiple tokens collapsed into one entry) |
| Cache entries per token | one latent per token | one compressed entry per `m` tokens |
| Attention pattern | dense over all cached latents | hybrid: sparse top-k (CSA) + dense over heavily-compressed (HCA) + sliding window (SWA) |
| Compression rate | n/a (1:1) | m=4 for CSA, m'=128 for HCA |
| Selection | none (all tokens attended) | lightning indexer + top-k for CSA |
| Output positional fix | n/a | inverse RoPE on each per-head output |
| Sink merge | n/a | per-head learnable attention sink merged via single softmax over `[S_comp, S_swa + sink]` |

Cache layout reflects this: per-layer **state cache** for SWA window + uncompressed tail (used for CSA/HCA compression), plus a **classical paged cache** holding compressed CSA/HCA entries, with block size = `lcm(m, m') = 128` original tokens per block.

---

## DSV4 architecture (paper-side reference)

The bits the kernel implements, with the choices we made for inference.

### Per-layer attention type schedule

```
Flash (43 layers): layers 0-1 = SWA, layers 2..42 alternating CSA/HCA (CSA at layer 2)
Pro   (61 layers): layers 0-1 = HCA, layers 2..60 alternating CSA/HCA (CSA at layer 2)
```

Frozen at construction time per `LayerSpec` so torch.compile constant-folds the dispatch. Validation in `dsv4/model/layer_schedule.py:validate_schedule` is loud on purpose: a wrong schedule produces silent garbage.

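
The schedule above can be sketched in a few lines of Python. This is an illustration only, not the contents of `dsv4/model/layer_schedule.py`: the function name `build_schedule` and the string tags are hypothetical, and only the layer counts and the alternation rule come from the table above.

```python
def build_schedule(num_layers: int, head_type: str) -> list[str]:
    """Return the attention type for every layer index.

    Layers 0-1 use `head_type` ("SWA" for Flash, "HCA" for Pro).
    Layers 2 and up alternate CSA/HCA, starting with CSA at layer 2.
    """
    if num_layers < 2:
        raise ValueError("schedule needs at least two layers")
    schedule = [head_type, head_type]
    for layer in range(2, num_layers):
        # Even layer indices land on CSA because layer 2 is CSA.
        schedule.append("CSA" if layer % 2 == 0 else "HCA")
    return schedule


flash = build_schedule(43, "SWA")   # hypothetical call for the Flash variant
pro = build_schedule(61, "HCA")     # hypothetical call for the Pro variant
```

Building the list once and freezing it is what lets the per-layer dispatch become a compile-time constant.
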
### Compressed Sparse Attention (CSA)

- Compresses every `m=4` KV entries into one via a token-level learned softmax with overlapping window (current m + previous m). See eq. 11–12 of the paper.
- Compressed sequence length is `n/m`.
- **Lightning indexer** scores each query against compressed blocks via weighted ReLU MQA logits (eq. 16). Top-k selector keeps `csa_top_k` blocks (512 Flash / 1024 Pro).
- Core attention is MQA over the selected blocks + a sliding window branch of `n_win=128` raw tokens.
- Partial RoPE on the last 64 dims of Q and the compressed K, with **inverse RoPE on each per-head output** so the per-token contribution carries the correct relative position.
- Per-head attention sink: learnable logit added to the softmax denominator (eq. 27). We merge sparse + SWA via the sink-bias-as-logit trick; see "Sink merge" below.

### Heavily Compressed Attention (HCA)

- Same compressor concept as CSA but `m'=128`, no overlap, dense attention over the (very short) compressed sequence.
- No indexer.
- Same partial RoPE + inverse RoPE + sliding window + sink as CSA.

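
To make the compression rates concrete, the arithmetic below counts how many KV entries a single query attends to under CSA and HCA. It is a rough sketch under stated assumptions: each selected "block" is treated as one compressed entry, only complete compression groups are counted, and the uncompressed tail and the sink are ignored. The function name `attended_entries` is hypothetical.

```python
def attended_entries(n_tokens: int, kind: str, csa_top_k: int = 1024,
                     m_csa: int = 4, m_hca: int = 128, n_win: int = 128) -> int:
    """Approximate number of KV entries one query attends to."""
    window = min(n_win, n_tokens)              # sliding-window branch, raw tokens
    if kind == "CSA":
        compressed = n_tokens // m_csa         # compressed sequence length n/m
        return min(csa_top_k, compressed) + window
    if kind == "HCA":
        return n_tokens // m_hca + window      # dense over every compressed entry
    raise ValueError(f"unknown attention kind: {kind}")


# At a 1M-token context with the Pro top-k of 1024:
csa_cost = attended_entries(1_048_576, "CSA")
hca_cost = attended_entries(1_048_576, "HCA")
```

Under these assumptions CSA stays flat once the compressed sequence exceeds `csa_top_k`, while HCA grows linearly at 1/128 of the context length.
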
### Sliding Window Attention (SWA)

- First two layers of Flash. Pure local attention over the SWA window. No compressed branch, no indexer.
- Cache layout: ring buffer of size `n_win` per request in the state cache.

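
A ring buffer of size `n_win` can be pictured as below. This is a host-side toy in plain Python, not the state-cache kernel; the class name `SwaRing` and its methods are hypothetical.

```python
class SwaRing:
    """Fixed-size ring buffer holding the last n_win entries for one request."""

    def __init__(self, n_win: int = 128):
        self.n_win = n_win
        self.slots = [None] * n_win
        self.count = 0                          # tokens written so far

    def append(self, entry) -> None:
        # Position modulo n_win overwrites the oldest slot once the buffer is full.
        self.slots[self.count % self.n_win] = entry
        self.count += 1

    def window(self) -> list:
        """Entries in chronological order, oldest first."""
        if self.count <= self.n_win:
            return self.slots[:self.count]
        start = self.count % self.n_win         # slot holding the oldest entry
        return self.slots[start:] + self.slots[:start]
```

The write index is just `position % n_win`, so appending never moves data and memory per request stays constant.
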
### Manifold-Constrained Hyper-Connections (mHC)

- Replaces residual connections. Width-expanded residual stream `(T, n_hc=4, d)`.
- Per-token dynamic `A_l`, `B_l`, `C_l` mixing matrices generated by a fused 24-output prenorm projection (4 + 4² + 4).
- `A_l = σ(.)`, `C_l = 2σ(.)`, `B_l = SinkhornKnopp(exp(.), t_max=20)` to project onto the Birkhoff polytope.
- `pre_block`: `x_in = A_l @ X_l`; `post_block`: `X_next = B_l @ X_l + C_l ⊗ F_out`.
- `B_l` held in FP32 for the bmm precision; A/C cast to BF16.

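
The Sinkhorn-Knopp step for `B_l` can be sketched in plain Python for one token. Assumptions: rows are normalised before columns, and no extra numerical stabilisation is applied; the source only fixes `exp(.)` and `t_max=20`. The function name `sinkhorn_knopp` is illustrative, not the one in `dsv4/layers/mhc.py`.

```python
import math


def sinkhorn_knopp(logits, t_max=20):
    """Map a square matrix of logits to an approximately doubly-stochastic matrix."""
    n = len(logits)
    mat = [[math.exp(v) for v in row] for row in logits]    # strictly positive entries
    for _ in range(t_max):
        # Row pass: every row sums to 1.
        mat = [[v / sum(row) for v in row] for row in mat]
        # Column pass: every column sums to 1 (rows drift slightly, then re-converge).
        col_sums = [sum(mat[i][j] for i in range(n)) for j in range(n)]
        mat = [[mat[i][j] / col_sums[j] for j in range(n)] for i in range(n)]
    return mat


# Example 4x4 logits (n_hc = 4); the values are made up.
example_logits = [
    [0.2, -0.5, 0.1, 0.0],
    [1.0, 0.3, -0.2, 0.4],
    [-0.7, 0.6, 0.0, 0.2],
    [0.5, -0.1, 0.9, -0.3],
]
B = sinkhorn_knopp(example_logits)
```

Rows and columns of `B` both sum to roughly 1, which is what membership in the Birkhoff polytope means.
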
### Router

- Two modes, frozen at construction by layer index:
  - **Hash routing** (layers 0–2): deterministic per-token-ID LUT lookup, uniform weights `1/k`.
  - **Dense routing** (layers 3+): `sqrt(softplus(X @ W_gate))` activation, plus learned `e_bias` for *selection only*. Top-k (k=6), renormalize on unbiased activations, multiply by `routed_scaling_factor`.

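
The dense-routing rule is easy to get subtly wrong, because the bias picks the experts but must not leak into their weights. A minimal single-token sketch, assuming the gate logits `X @ W_gate` are already computed; `dense_route` is a hypothetical name and this is not the code in `dsv4/layers/router.py`.

```python
import math


def softplus(x: float) -> float:
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def dense_route(logits, e_bias, routed_scaling_factor, k=6):
    """Return (expert_ids, weights) for one token."""
    act = [math.sqrt(softplus(x)) for x in logits]           # unbiased activations
    biased = [a + b for a, b in zip(act, e_bias)]            # used for selection only
    top = sorted(range(len(act)), key=lambda i: biased[i], reverse=True)[:k]
    total = sum(act[i] for i in top)                         # renormalize without bias
    weights = [routed_scaling_factor * act[i] / total for i in top]
    return top, weights
```

With this shape the selected weights always sum to `routed_scaling_factor`, no matter how large `e_bias` is.
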
### MoE

- DeepSeekMoE: shared expert + N routed experts (Flash 256, Pro 384), 6 activated per token.
- L1 GEMM (gate + up interleaved at granularity 8) → SwiGLU → L2 GEMM (down).
- SwiGLU clamping per paper §4.2.3: gate capped at `swiglu_limit=10`, linear clamped to `[-limit, +limit]`.
- All weights NVFP4, FP8 E4M3 scales, 16-element microblocks.

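
The clamping rule reads as follows for a single (gate, up) pair. This is a scalar sketch of the rule as stated in the bullet above; it assumes the clamp is applied to the inputs of the activation, which the bullet does not spell out, and `clamped_swiglu` is a hypothetical name.

```python
import math


def silu(x: float) -> float:
    # x * sigmoid(x), written so that exp() never overflows.
    if x >= 0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)
    return x * e / (1.0 + e)


def clamped_swiglu(gate: float, up: float, limit: float = 10.0) -> float:
    gate = min(gate, limit)                    # gate is capped from above only
    up = max(-limit, min(up, limit))           # linear branch clamped on both sides
    return silu(gate) * up
```

Note the asymmetry: a very negative gate is left alone (SiLU already sends it toward zero), while the linear branch is bounded in both directions.
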
### Sink merge (D5c: key insight)

The paper writes the sink merge as a weighted combination of two separate softmax outputs. But because the sink is just an additive logit bias on one branch, the whole thing collapses to a **single softmax over `[S_comp, S_swa + attn_sink]`**.

One pass, one kernel. No two-loop epilogue, no LSE arithmetic in the merge. This is why D5d (fused merge epilogue) is not needed.

---

## Package structure

```
dsv4/
├── kernels/                 Pure GPU code
│   ├── attention/           Production FMHA: 6-warp TMA multi-tile (.cuh + C-API .cu + op.py + production.py)
│   │                        production.py is the entry point used by single_shot_inference.py
│   ├── gemm/                NVFP4 MoE GEMM (grouped, fused_swiglu, dense, scheduler)
│   ├── compressor/          CSA/HCA production compressor (production_compress.py → compressor_reduce.cu)
│   ├── indexer/             CSA indexer (stub; live path is inline in single_shot_inference.py)
│   ├── router/              Dense router decode + activation_topk
│   ├── cuda/                Raw .cu kernels (loader.py compiles on demand)
│   └── cache/               (stub; SWA/flush kernels are in cuda/)
├── ops/                     PyTorch ↔ kernel bridges
│   ├── quantize.py          BF16 ↔ NVFP4, scale factor handling, QuantizedActivation
│   ├── layouts.py           Scale swizzle, gate/up interleave, K-major, offsets
│   ├── gemm_runner.py       Warmup, compile, run grouped/fused GEMMs
│   ├── custom_ops.py        torch.library.custom_op registrations
│   ├── rope_cuda.py         Forward + inverse RoPE (partial, last 64 dims)
│   └── router.py            Router op bridge (dense + hash dispatch)
├── layers/                  nn.Module-style components (used by single_shot_inference.py)
│   ├── linear.py            Nvfp4Linear
│   ├── grouped_linear.py    Nvfp4GroupedLinear (output projection)
│   ├── moe.py               Nvfp4MoE (routed experts)
│   ├── shared_expert.py     Nvfp4SharedExpert
│   ├── mhc.py               mHCLayer (Sinkhorn-Knopp, residual mixing)
│   └── router.py            Router (dense + hash modes)
├── model/
│   ├── config.py            DSV4Config
│   └── sampler.py           CUDASampler
├── reference/
│   └── single_shot_PYTORCH_REFERENCE.py   PyTorch oracle for layer comparison tests
└── _archive/                Dead Lineage P code (model/dsv4.py, cache/*, layers/{attention,ffn,norm,embedding}, etc.)
                             Kept for reference; never imported by live code
```

**Live path:** `single_shot_inference.py` → `dsv4/layers/*` → `dsv4/ops/*` → `dsv4/kernels/**`

**Attention path:** `production.py` → `fmha_multitile_op.py` → `fmha_multitile_capi.cu` → `fmha_6warp_tma_multirow_multitile.cuh`

**Archived (Lineage P):** `dsv4/model/dsv4.py`, `dsv4/cache/*`, `dsv4/layers/{attention,ffn,norm,embedding}`. These were the vLLM/sglang integration surface but have 0 importers. See `_archive/` if needed.

---

## Workflow & test harness

### The non-negotiables

- **NEVER edit on the B200.** Always: edit locally → commit → push → pull on B200 → test.
- **NEVER raw SSH + direct command.** Always use the test harness scripts. They handle: killing hung processes, deleting stale logs, screen sessions that survive SSH drops, timeouts for hung kernels, and GPU cleanup.
- **ALWAYS verify hd=64 regression** (cos ~0.999998) after every FMHA change. If it regresses, the change is wrong. Revert.
- **NEVER touch drivers, kernels, firmware, or system packages** on the B200.
- **NEVER delete test files** in `tests/unit/` without explicit approval.

### Two harnesses: Python and CUDA

| Harness | For | Script | Screen name | Log file |
|---|---|---|---|---|
| Python | `test_*.py` files | `fire_b200_test` | `kernel-test` | `/tmp/kernel-test.log` |
| CUDA | `test_*.cu` files | `fire_b200_cuda_test` | `cuda-test` | `/tmp/cuda-test.log` |

Both harnesses follow the same discipline:

1. **Kill everything first**: old screen sessions, hanging GPU processes, stale binaries
2. **Delete all logs**: never debug from a previous run's log
3. **Clean git + pull**: no uncommitted B200 state
4. **Run in screen**: survives SSH drops, has a timeout
5. **One test at a time**: no parallel launches, ever

### Python test

```bash
# From local machine: auto-pushes, runs, polls, dumps log
# DEFAULT timeout: 600s (10 min). Override with all 4 args:
~/.openclaw/workspace/fire_b200_test <test_file> [screen_name] [log_file] [timeout_sec]

# Examples:
~/.openclaw/workspace/fire_b200_test tests/unit/test_fmha_v3_stage_c.py
~/.openclaw/workspace/fire_b200_test tests/unit/test_degeneration_2_mhc_falsify.py kernel-test /tmp/kernel-test.log 1800
```

### CUDA test

```bash
# From local machine: compiles with nvcc, runs, polls, dumps log
~/.openclaw/workspace/fire_b200_cuda_test tests/unit/test_fmha_sm100_standalone.cu
~/.openclaw/workspace/fire_b200_cuda_test tests/unit/test_tmem_minimal.cu 30  # custom timeout
```

### Check on a running test

```bash
# Check CUDA test log + screen status
~/.openclaw/workspace/check_b200_cuda
~/.openclaw/workspace/check_b200_cuda kill  # kill a hung test

# Check Python test: SSH to B200 and tail the log:
ssh root@<B200> tail -f /tmp/kernel-test.log
```

### Manual B200 cycle (emergency only)

```bash
ssh root@<B200>
cd /root/dsv4-nvfp4-workspace/kernel && git pull
bash tests/run_test.sh tests/unit/test_<...>.py
bash tests/check_log.sh
```

### ⚠️ Test harness gotchas (READ THIS: cost real time)

1. **The timeout is the 4th argument, not the 2nd.**
   - WRONG: `fire_b200_test test.py 1800` ← this makes `1800` the SCREEN NAME
   - RIGHT: `fire_b200_test test.py kernel-test /tmp/kernel-test.log 1800`
   - When you pass just a number as the 2nd arg, the screen gets a numeric name and the harness can't kill the old `kernel-test` screen on the next run.
   - **Always pass all 4 args** when you need a custom timeout.

2. **After a timeout, the harness kills the screen but NOT the GPU process.**
   - The `timeout` command inside screen kills the shell, but CUDA processes survive.
   - Before re-running, check: `ssh root@<B200> nvidia-smi --query-compute-apps=pid --format=csv,noheader`
   - Kill stale processes: `kill -9 <pid>` for each GPU process listed
   - Or: `for pid in $(nvidia-smi --query-compute-apps=pid --format=csv,noheader); do kill -9 $pid; done`

3. **After an OOM or crash, stale GPU processes WILL be left behind.**
   - Always check `nvidia-smi` before running a new test after a failure.
   - The harness kills `python.*test_` and `python.*inference` procs, but if the process name doesn't match the pattern, it survives.

4. **Single-shot tests MUST use the harness too.**
   - `single_shot_inference.py` is NOT a unit test, but it MUST be run via the harness.
   - WRONG: ssh to B200 and run `python single_shot_inference.py` directly
   - RIGHT: `fire_b200_test single_shot_inference.py kernel-test /tmp/kernel-test.log 1800 -- --max-tokens 512`
   - Extra args after `--` are passed to the Python script.
   - If the harness can't handle your use case, FIX THE HARNESS, don't bypass it.

5. **Weight loading + CuTeDSL compilation takes 5-10 minutes.**
   - First FMHA call triggers JIT compile of CuTeDSL kernels.
   - This is EXPECTED. Do NOT kill the process because it "seems stuck".
   - Use 1800s (30 min) timeout for full-model tests.

6. **The screen name must match between runs.**
   - The harness kills the old screen by name. If you used a different name last time, the old screen survives and holds GPU memory.
   - Always use `kernel-test` for Python tests and `cuda-test` for CUDA tests.
   - If you accidentally used a numeric screen name, clean up manually: `ssh root@<B200> screen -S <wrong_name> -X quit`

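
Gotchas 1 and 6 both come down to positional arguments, so one option is a thin wrapper that always emits all four. The helper below is hypothetical and not part of the harness: `fire_cmd` and its constants are illustrative, and only the argument order is taken from the usage line in "Python test".

```python
import shlex

DEFAULT_SCREEN = "kernel-test"            # screen name the Python harness expects
DEFAULT_LOG = "/tmp/kernel-test.log"


def fire_cmd(test_file, timeout_sec=600, script_args=()):
    """Build a fire_b200_test command line with all four positional args."""
    if not isinstance(timeout_sec, int) or timeout_sec <= 0:
        raise ValueError("timeout_sec must be a positive integer")
    cmd = ["fire_b200_test", test_file, DEFAULT_SCREEN, DEFAULT_LOG, str(timeout_sec)]
    if script_args:
        cmd += ["--", *script_args]       # everything after "--" goes to the script
    return shlex.join(cmd)


full_model = fire_cmd("single_shot_inference.py", 1800, ["--max-tokens", "512"])
```

Because the screen name and log path are fixed inside the helper, the timeout can never land in the screen-name slot.
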
### Environment

- **B200 access**: see `MEMORY.md` (not committed).
- **venv**: `source /root/dsv4-nvfp4-workspace/venv/bin/activate`
- **PYTHONPATH**: `/root/dsv4-nvfp4-workspace/kernel`
- **Model**: `/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4`
- **vLLM** (modified for Blackwell): `/root/dsv4-nvfp4-workspace/vllm`
- **CUTLASS FMHA reference**: `/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py`
- **Local CUTLASS clone**: `/home/openclaw/dev/cutlass`

---

## CuTeDSL constraints (read every session)

These are surface-level traps. Get them wrong and the kernel silently produces garbage or NaN, or the JIT fails with a "weakly congruent" error at compile time.

1. **TMA partition tensors have 4 modes**: `(((64,128),1), ?, KV_tiles, ?)`. `(None, 0, None, 0)` keeps mode 2 (KV tiles) free; `[None, kt]` indexes it. `(None, None, 0, 0)` silently pins mode 2 to 0, and multi-tile loads break invisibly.

2. **`vectorize=True` loops accept only load/store/print.** No `fmax`, no `cmpf`, no inner loops, no carry across iterations.

3. **`.reduce(cute.ReductionOp.MAX)` reduces the entire C-fragment to a scalar** (global, not per-row). Use a plain `range()` loop with `cute.arch.fmax` for per-row max.

4. **`cute.arch.fmax` is impure** for the vectorizer. Use it inside plain `range()`, never inside `vectorize=True`.

5. **Hand-constructed TMEM atoms corrupt data on round-trip.** Independently-built `Ld32x32bOp` + `St32x32bOp` atoms have addressing that doesn't match; even a NO-OP round-trip drops cos to ~0.97. Use paired atoms from `epilogue_tmem_copy_and_partition` / `epilogue_smem_copy_and_partition` for one-way trips.

6. **CuTeDSL `if` blocks are separate MLIR regions.** Variables defined inside one `if` are not visible in another, even when the condition is a compile-time constant. Define all variables unconditionally before any branching.

7. **Guard dead code with `const_expr`.** CuTeDSL compiles both branches of Python `if`. At hd=64, the SMEM-P or O-rescale code generates IR you don't need; without `const_expr`, MLIR chews on it.

8. **`tma_partition` and `flat_divide` may not survive inside `if warp_idx` blocks.** Construct partitioned tensors before warp branching, or in a regular Python helper function. (The MoE kernel calls `tma_partition` inside the epilogue warp's `if`, so this constraint may depend on context; print and verify.)

9. **TMEM allocation must be a power of 2.** Round up after summing column requirements.

10. **`composition` vs `logical_divide` produce different layouts** even when re-tiling the same tensor. `correction_rescale` uses `composition`, `correction_epilog` uses `logical_divide`. Copy atoms must match the tensor layout they were created with.

11. **After every P store to TMEM, call `cute.arch.fence_view_async_tmem_store()`.** Missing this produces NaN.

12. **`St32x32bOp` must use Float32, not BFloat16.** BFloat16 causes illegal memory access.

13. **First PV must have `ACCUMULATE=False`.** Otherwise it adds uninitialized TMEM contents to the output.

14. **`find_tmem_tensor_col_offset()` returns footprint size, not a safe offset.** Never use it as a TMEM placement.

15. **FMHA never trusts DLPack tensor layouts.** Reconstruct V as `(hd, s_k)` MN-major inside CuTe via explicit `make_tensor` + `make_layout`.

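
Constraint 9 amounts to a one-line round-up on the host side. The helper below is a plain-Python sketch with a hypothetical name (`tmem_alloc_cols`); the column counts in the example are made up, and hardware minimum/maximum column limits are deliberately not checked here.

```python
def tmem_alloc_cols(*col_requirements: int) -> int:
    """Sum per-tensor TMEM column needs, then round up to the next power of two."""
    total = sum(col_requirements)
    if total <= 0:
        raise ValueError("need at least one TMEM column")
    # (total - 1).bit_length() is the exponent of the smallest power of two >= total.
    return 1 << (total - 1).bit_length()


# Hypothetical footprints for three accumulators: 128 + 128 + 32 = 288 columns.
cols = tmem_alloc_cols(128, 128, 32)
```
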
---

## Lessons learned (the gold: read every session)

These cost real days to learn. They are listed roughly in order of how easy they are to repeat.

### Layout & TMA

- **TMA partition mode ordering** (the bug that ate a whole day): see CuTeDSL constraint #1 above. The wrong slice produces "reasonable" wrong outputs (cos 0.7–0.9, never NaN), so you can ship it without knowing.
- **Square hides bugs.** (128,128) worked for every wrong approach to PV. Always test non-square shapes early.
- **Print the shapes always.** Reasoning about TMEM layouts or TMA mode counts without running `cute.printf(cute.shape(t))` inside `@cute.kernel` is how every multi-day debug starts. Shapes are ground truth.
- **`qk_mma_tiler` K-dim must equal `head_dim`**, not the MMA instruction's K sub-tile size. Hardcoding `qk_ik * 4 = 64` was the root cause of the hd>64 failure; the QK GEMM only computed half the dot product. Fix was one line; cos went from 0.78 to 0.999997 at hd=128.

### TMEM

- **Never assume TMEM round-trips are safe.** Verify with a NO-OP test (load → store unchanged) before adding any logic. The hand-constructed atoms produce ~3% error even on NO-OP.
- **FMHA P store uses QK C-fragment composition, not PV A-fragment.** Two aliases of the same TMEM region. Mixing them up gives valid-looking garbage.
- **Register bridge for P: FP32 backing (store partition) + BF16 view (QK-load layout).** Do not skip the dual view.
- **TMEM round-trip mismatch with `epilogue_tma_store`**: `epilogue_tma_store` reads O from TMEM using `get_tmem_load_op`'s layout. Hand-built atoms read with a different layout. Round-tripping through hand-built atoms transcodes the data, leaving 3% error.
- **The correction-epilog pattern is the fix.** TMEM → registers (via paired t2r atom) → modify in registers → SMEM (via paired r2s atom) → GMEM (via TMA). One-way trip, no round-trip, no transcoding. The MoE kernel uses this and gets perfect results.

### CuTeDSL & MLIR

- **CuTeDSL `if` blocks create separate MLIR regions.** Variables defined in `if not use_smem_p:` and read in another `if not use_smem_p:` inside a `for` inside an `if warp_idx < mma_warp_id:` are not visible. Define unconditionally before any branching.
- **CuTeDSL compiles both branches of Python `if`.** Wrap mode-specific dead code in `const_expr(condition)` to eliminate it. Critical for O rescale (`n_kv_tiles > 1`), LSE compute (`not normalize`), SMEM-P path.
- **CuTeDSL MLIR backend cannot handle complex pipeline loops at hd=512.** Both unrolled (Python `range`) and runtime (`cutlass.range(unroll=1)`) loops trigger exponential-or-worse optimizer time. Tracer is fast (~0.8s); MLIR optimizer chews for 3+ hours.
- **Don't mix Python loops and pipeline ops.** Python `for` unrolls at trace time: N copies of pipeline acquire/release + TMA + GEMM blow up the IR. Prefer `cutlass.range(unroll=1)` for pipeline loops.

### Math & merging

- **External k_sub merge is mathematically impossible.** You cannot merge `softmax(Q_k0 @ K_k0^T) @ V` and `softmax(Q_k1 @ K_k1^T) @ V` into `softmax(Q @ K^T) @ V`. k_sub partitions are additive in **logit** space (`S = S_0 + S_1`); softmax is nonlinear. The D5 merge formula only works because sparse and SWA attend over **different token sets** (additive in weight space). In-kernel accumulation before softmax is the only correct approach for k_sub.
- **D5 multi-tile KV merge IS valid.** Per-segment LSE + the formula `O = Σ exp(lse_i) · O_i / Σ exp(lse_i)` works because each segment is a separate softmax over a separate token range. This is the Python KV merge workaround that ships today; the in-kernel single-launch version requires the correction-epilog fix.
- **Sink merge = single softmax over `[S_comp, S_swa + attn_sink]`.** The two-branch weighted merge formula in the paper is mathematically equivalent to adding `attn_sink` as a logit bias on the SWA positions and softmaxing once. One pass, one kernel. This obsoleted D5d.

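
The D5 multi-tile merge is easy to check numerically. The sketch below uses scalar values and plain Python lists instead of tensors; the names `attend` and `merge` are illustrative and this is not the shipped Python KV merge. Subtracting the running maximum before `exp` is an added stability step that the formula above leaves implicit.

```python
import math


def attend(scores, values):
    """Softmax-weighted average over one token segment; returns (output, lse)."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    denom = sum(exps)
    out = sum(e * v for e, v in zip(exps, values)) / denom
    return out, m + math.log(denom)            # lse = log(sum(exp(scores)))


def merge(segments):
    """O = sum(exp(lse_i) * O_i) / sum(exp(lse_i)) over disjoint segments."""
    m = max(lse for _, lse in segments)
    weights = [math.exp(lse - m) for _, lse in segments]
    return sum(w * out for w, (out, _) in zip(weights, segments)) / sum(weights)


scores = [0.3, -1.2, 2.0, 0.7, 1.5]
values = [1.0, 2.0, 3.0, 4.0, 5.0]
full, _ = attend(scores, values)
merged = merge([attend(scores[:2], values[:2]), attend(scores[2:], values[2:])])
```

`full` and `merged` agree to floating-point precision because the two segments cover disjoint token ranges. Splitting along the head dimension instead (k_sub) would split each logit, not the token set, and no such recombination exists.
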
### Numerics

- **Always test at hd=64 first.** If the proven TMEM-P path regresses, nothing else matters.
- **`St32x32bOp` must be Float32**, not BFloat16. BFloat16 throws illegal memory access. (Yes, this is a CuTeDSL constraint; it is listed here because it's been forgotten more than once.)
- **First PV `ACCUMULATE=False`.** Otherwise it sums uninitialized TMEM into the output and you see ~50% error.

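
The "cos" figures quoted throughout are cosine similarities between kernel output and a reference. A minimal sketch on flat Python lists, for illustration only: the real tests presumably compare tensors, and both the function names and the `0.99999` threshold are assumptions, not values taken from the test suite.

```python
import math


def cosine(a, b) -> float:
    """Cosine similarity between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("length mismatch")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine is undefined for a zero vector")
    return dot / (norm_a * norm_b)


def passes_regression(out, ref, threshold: float = 0.99999) -> bool:
    # Hypothetical gate: the threshold is a placeholder, tune it per test.
    return cosine(out, ref) >= threshold
```

Cosine similarity ignores overall scale, so a result that is off by a constant factor still scores 1.0; pair it with a magnitude check when scale matters.
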
### Workflow

- **Never edit on the B200.** Edit locally, commit, push, pull, test. The B200 has no editor history; one bad save and the file is lost.
- **Print shapes inside `@cute.kernel` at trace time.** `print(f"tBgK shape: {cute.shape(tBgK)}")` runs at compile time, not runtime, and is your only window into the JIT's view of layouts. This is the single most useful debugging line in CuTeDSL.

### SMEM budget

- **`pv_n_tile` is the easiest SMEM knob.** At hd > 256, reducing `pv_n_tile` from 256 to 128 halves sV and sC. Cost: 4 PV GEMM passes instead of 2 (PV is rarely the bottleneck). Simpler than SMEM overlap or Q tiling.
- **`kv_stage` is the second-easiest.** Drop to 1 when budget gets tight at hd > 128; lose double-buffering on K/V but free 64+ KB.
- **SMEM budget at various hd** (with `pv_n_tile=256` for hd≤256, `pv_n_tile=128` for hd>256, `kv_stage=2` for hd≤128 else 1):

| hd | sQ | sK | sV | sP | sC | Total | Limit |
|---:|---:|---:|---:|---:|---:|------:|------:|
| 64 | 32 KB | 32 KB | 32 KB | - | 32 KB | 128 KB | 232 KB |
| 128 | 32 KB | 32 KB | 32 KB | - | 32 KB | 128 KB | 232 KB |
| 256 | 64 KB | 64 KB | 64 KB | 0* | 32 KB | 224 KB | 232 KB |
| 512 | 64 KB | 64 KB | 32 KB | 0* | 32 KB | 192 KB | 232 KB |

\*TMEM-P path: sP allocation skipped via `const_expr` conditional.

---

## Reference

- **DeepSeek V4 paper**: `DeepSeek_V4.pdf` in the repo root.
- **DeepGEMM** (V4-aligned reference kernels): https://github.com/deepseek-ai/DeepGEMM
- **CUTLASS FMHA reference**: `/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py` (B200) or `/home/openclaw/dev/cutlass` (local).
- **Reference oracles**: `dsv4/reference/` (PyTorch FP32; slow, never imported by production code).