# DSV4 Inference Kernel

## Architecture

DSV4 is **not MLA**. It uses **CSA (Compressed Sparse Attention, m=4)** and **HCA (Heavily Compressed Attention, m′=128)**. The KV latent has shape (T, 512) and is shared across all 128 heads. Sink weights merge the sparse and SWA attention branches. vLLM labels this architecture "MLA", but that name is wrong: the architecture is fundamentally different.

```
DSV4 inference pipeline: component status
=========================================

Legend:
  [✓] built and tested
  [~] partial: reference or seam exists, native pending
  [✗] to build


                  ┌────────────────────────────────────┐
                  │ [✗] Embedding + mHC init           │
                  │     token embed + n_hc=4 streams   │
                  └────────────────┬───────────────────┘
                                   │
                                   ▼
┌─ Transformer layer × L ──────────────────────────────────────────────┐
│  HCA on layers 0–1 of Pro, alternating CSA / HCA after               │
│                                                                      │
│  ┌─ Attention sub-block ──────────────────────────────────────────┐  │
│  │ [✓] Residual            mHC pre + post mix                     │  │
│  │ [~] Norms + RoPE        RMSNorm + partial RoPE                 │  │
│  │ [✓] Q / KV projection   NVFP4 linears + LoRA                   │  │
│  │ [~] Token compressor    CSA m=4 / HCA m′=128                   │  │
│  │ [✗] Indexer + top-k     CSA only, FP4 QK                       │  │
│  │ [~] FMHA core           QK → online softmax → PV               │  │
│  │                         + SWA branch + sink merge              │  │
│  │ [✓] Output projection   inv RoPE + wo_a grouped + wo_b         │  │
│  └────────────────────────────────────────────────────────────────┘  │
│                                                                      │
│  ┌─ FFN sub-block ────────────────────────────────────────────────┐  │
│  │ [✓] Residual            mHC pre + post mix                     │  │
│  │ [~] Pre-FFN norm        RMSNorm                                │  │
│  │ [✗] Router              sqrt(softplus) + topk + hash           │  │
│  │ [✓] Routed MoE          fused SwiGLU L1 + L2                   │  │
│  │ [✓] Shared expert       NVFP4 single-group GEMM                │  │
│  └────────────────────────────────────────────────────────────────┘  │
└──────────────────────────────────┬───────────────────────────────────┘
                                   │
                                   ▼
┌──────────────────────────────────────────────────────────────────────┐
│  [✗] Final RMSNorm → [✗] LM head → [✗] MTP (depth=1) → [✗] Sampler   │
└──────────────────────────────────────────────────────────────────────┘

┌─ Supporting infrastructure ──────────────────────────────────────────┐
│  [✗] KV cache management                                             │
│      • state cache: SWA window + uncompressed tail per layer         │
│      • classical paged cache: lcm(m, m′) = 128 tokens per block      │
│      • heterogeneous layout per layer                                │
└──────────────────────────────────────────────────────────────────────┘


Summary
-------
  Built    [✓] : 6 (mHC ×2, Q/KV proj, output proj, routed MoE, shared expert)
  Partial  [~] : 4 (norms+RoPE, token compressor, FMHA core, pre-FFN norm)
  To build [✗] : 8 (embedding+init, indexer+top-k, router, final norm,
                    LM head, MTP, sampler, KV cache)
```

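
The paged-cache block size quoted in the diagram follows from the two compression ratios. Below is a minimal sketch of that arithmetic, assuming each compressed entry summarizes exactly m (CSA) or m′ (HCA) consecutive tokens. The names `CSA_M`, `HCA_M`, `BLOCK_TOKENS`, and `entries_per_block` are hypothetical and do not come from the `dsv4` package.

```python
import math

CSA_M = 4      # CSA compression ratio m (from the diagram)
HCA_M = 128    # HCA compression ratio m' (from the diagram)

# Smallest token count that both ratios divide evenly.
BLOCK_TOKENS = math.lcm(CSA_M, HCA_M)


def entries_per_block(ratio: int, block_tokens: int = BLOCK_TOKENS) -> int:
    """Compressed entries per block, ASSUMING one entry per `ratio` tokens."""
    if block_tokens % ratio != 0:
        raise ValueError("block size must be a multiple of the compression ratio")
    return block_tokens // ratio


print(BLOCK_TOKENS)               # 128
print(entries_per_block(CSA_M))   # 32
print(entries_per_block(HCA_M))   # 1
```
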
---

## Status (May 22, 2026, 09:40 UTC)

| Stage | Status | Description |
|-------|--------|-------------|
| A | ✅ COMPLETE | Q@K^T via tcgen05.mma → TMEM → GMEM |
| B | ✅ COMPLETE | QK → identity softmax → P@V pipeline (TMEM alias, KV-tile interleaving) |
| C | ✅ WORKING | Real online softmax: row_max (fmax), exp2 scaling, P store, row_sum, O normalization. Cosine 0.993–0.996 |
| C' | 🔨 NEXT | Cross-warp reduction, correction warps, 12-warp production pipeline, multi-tile KV |
| D | TODO | Full decode attention: paged KV cache, multi-query, causal mask |
| E | TODO | Production kernel: extract into dsv4/kernels/attention/, PyTorch custom op, vLLM bridge |

---

## Package Structure

```
dsv4/
├── kernels/               Pure GPU code (CuTeDSL @cute.jit, .cu files)
│   ├── gemm/              NVFP4 MoE GEMM kernels (grouped, fused_swiglu, dense, scheduler)
│   ├── attention/         FMHA kernel (stub; extraction is Stage E)
│   ├── compressor/        CSA/HCA token-level compressor
│   ├── decode/            Decode-time attention (sparse, SWA; future)
│   └── cuda/              Raw .cu files (deinterleave_quantize, sparse_topk_metadata)
├── ops/                   PyTorch ↔ kernel bridges
│   ├── quantize.py        BF16 ↔ NVFP4 conversion, scale factors
│   ├── layouts.py         Scale swizzle, gate/up interleave, K-major, offsets
│   ├── gemm_runner.py     Warmup, compile, run grouped/fused GEMMs
│   ├── custom_ops.py      torch.library.custom_op registrations
│   ├── decode_sparse.py   native_sparse_decode dispatcher
│   ├── decode_swa.py      native_swa_decode dispatcher
│   ├── rope.py            Forward + inverse RoPE
│   └── topk.py            Python wrapper for sparse_topk_metadata.cu
├── layers/                nn.Module-style components
│   ├── linear.py          Nvfp4Linear
│   ├── grouped_linear.py  Nvfp4GroupedLinear
│   ├── moe.py             Nvfp4MoE
│   ├── shared_expert.py   Nvfp4SharedExpert
│   ├── mhc.py             mHCLayer
│   └── (stubs: attention, ffn, router, norm, embedding)
├── model/                 Model assembly (stubs, Phase 1)
├── cache/                 KV cache infra (stubs, Phase 3)
├── loader/                Checkpoint I/O (stubs, Phase 1)
└── reference/             Slow PyTorch oracles (never imported by production code)
    ├── attention.py       RoPE, KV cache, causal attention, SWA
    ├── csa_attention.py   CSA/HCA sparse attention
    ├── compressor.py      Compressor PyTorch example
    └── moe_pipeline.py    MoE pipeline reference
```

**Mental model:** `kernels/` → `ops/` → `layers/` → `model/` (dependency flows left to right). `reference/` and `loader/` are sidecars.

---

## Active Test Files

### FMHA (Stages A/B/C), in `tests/unit/`

| File | Stage | Status |
|------|-------|--------|
| `test_fmha_v3.py` | A+B | ✅ Full QK→identity softmax→PV, cosine 0.999999 |
| `test_fmha_v3_12w.py` | A+B | ✅ 12-warp QK→PV, cosine 0.999999 |
| `test_fmha_v3_stage_c_full.py` | C | ✅ Real online softmax + O normalization, cosine 0.993–0.996 |
| `test_fmha_v3_stage_c_min.py` | C | 🔨 Early 12-warp pipeline (broken pipeline state) |
| `test_pv64_with_softmax.py` | B | ✅ (128,64) PV, single AB pipeline |
| `test_128_128_vdiag.py` | A+B | ✅ (128,128) PV baseline |
| `test_qkonly.py` | A | ✅ QK with split Q/KV pipelines |
| `test_qk_softmax.py` | A+B | ✅ QK + identity softmax, no PV |

### MoE / GEMM, in `tests/unit/`

| File | What |
|------|------|
| `test_cutedsl.py` | NVFP4 grouped GEMM kernel |
| `cudagraph_test.py` | Cudagraph capture + replay |
| `layertest.py` | Per-layer correctness |
| `test_custom_op.py` | torch.library custom ops |
| `test_compile_custom_op.py` | Compile + warmup |
| `test_fp4_roundtrip.py` | BF16 → NVFP4 → BF16 roundtrip |
| `test_interleave.py` | Gate/up weight interleaving |
| `test_interleave_gemm.py` | Interleaved GEMM correctness |
| `test_fused_step1.py` | Fused SwiGLU GEMM |

### Archived Tests

`tests/archive/` contains ~190 debug files from Stages A/B. They are not maintained and can be deleted.

---

## Stage C: Online Softmax (WORKING)

### What We Have

**Working real softmax** in `test_fmha_v3_stage_c_full.py`: cosine 0.993–0.996 across 3 seeds.

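
The cosine figures quoted throughout compare kernel output against a reference result. The sketch below shows one way such a metric can be computed in plain Python, assuming both outputs are flattened to equal-length float sequences. The function name `cosine_similarity` is illustrative, and the actual test files may compute the metric differently (for example with PyTorch).

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two flattened outputs (1.0 = same direction)."""
    if len(a) != len(b):
        raise ValueError("inputs must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)


# Made-up numbers: a reference output and a slightly perturbed "kernel" output.
reference = [1.0, 2.0, 3.0, 4.0]
kernel_out = [1.0, 2.0, 3.0, 4.5]
print(cosine_similarity(reference, kernel_out))
```

Cosine similarity ignores overall scale, so a kernel that is off by a constant factor would still score 1.0. An absolute or relative error check is a useful companion.
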
### Current Architecture (6-warp)

| Warps | Role | Work |
|-------|------|------|
| 0-3 | Softmax + Epilogue | load S, real softmax, P store, O normalize, epilogue |
| 4 | MMA | QK→S, PV→O |
| 5 | TMA | Q/K/V load |

### Target Architecture (12-warp, production)

| Warps | Role | Work |
|-------|------|------|
| 0-3 | Softmax | S→softmax→P, broadcast vec=[old_max, new_max] |
| 4-7 | Correction | O rescale (TMEM), final normalization, SMEM write |
| 8 | MMA | QK→S, PV→O with pipeline chaining |
| 9 | TMA | Q/K/V load |
| 10 | Epilogue | O SMEM→GMEM via TMA |
| 11 | Empty | tmem dealloc mbar init |

Pipeline chain: MMA → Softmax → Correction → Epilogue (plus MMA → Correction)

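
The 12-warp target layout can be written down as a small lookup, which is handy for sanity-checking role dispatch in host-side tests. This is a plain-Python sketch of the table above, not kernel code. `WarpRole`, `warp_role`, and the constants are hypothetical names, and the 32-threads-per-warp figure is the standard CUDA warp size.

```python
from enum import Enum


class WarpRole(Enum):
    SOFTMAX = "softmax"
    CORRECTION = "correction"
    MMA = "mma"
    TMA = "tma"
    EPILOGUE = "epilogue"
    EMPTY = "empty"


NUM_WARPS = 12
THREADS_PER_WARP = 32  # CUDA warp size

# Single-warp roles from the target table.
_SINGLE_WARP_ROLES = {
    8: WarpRole.MMA,
    9: WarpRole.TMA,
    10: WarpRole.EPILOGUE,
    11: WarpRole.EMPTY,
}


def warp_role(warp_idx: int) -> WarpRole:
    """Map a warp index (0..11) to its role in the 12-warp target layout."""
    if not 0 <= warp_idx < NUM_WARPS:
        raise ValueError(f"warp index {warp_idx} out of range")
    if warp_idx < 4:
        return WarpRole.SOFTMAX
    if warp_idx < 8:
        return WarpRole.CORRECTION
    return _SINGLE_WARP_ROLES[warp_idx]


threads_per_cta = NUM_WARPS * THREADS_PER_WARP
print(threads_per_cta)                              # 384
print([warp_role(w).value for w in range(NUM_WARPS)])
```

Under these assumptions the production kernel launches 384 threads per CTA, against 192 for the current 6-warp layout.
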
### CuTeDSL Constraints (hard-won)

1. `vectorize=True` loops: ONLY load/store/print. No fmax, no cmpf, no inner loops, no carry variables.
2. `.reduce(cute.ReductionOp.MAX)` reduces the ENTIRE C-fragment to a scalar: a global max, not a per-row max. Use `cute.arch.fmax` element-wise instead.
3. Dynamic control flow: variables need initial values BEFORE the flow starts.
4. `cute.arch.fmax` is impure for the vectorizer: use a plain `range()` loop.
5. Carry variables (row_max, row_sum) cannot use `vectorize=True`.

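
Constraint 2 is about more than shape bookkeeping. The plain-Python sketch below (not CuTeDSL) shows what happens numerically when one global maximum stands in for the per-row maximum: rows whose scores sit far below the global max underflow to zero, and their normalization then divides by zero. The score values are made up, and Python floats are FP64, so the gap needed to trigger underflow here is far larger than it would be in FP32 on the GPU.

```python
def exp2(x: float) -> float:
    """2**x, the same base the kernel's exp2 scaling uses."""
    return 2.0 ** x


# Two query rows (already in log2 units) with very different score ranges.
scores = [
    [0.0, -1.0, -2.0],
    [-2000.0, -2001.0, -2002.0],
]


def row_sums(score_rows, row_maxes):
    """Softmax denominators, subtracting row_maxes[i] from every score in row i."""
    return [
        sum(exp2(s - m) for s in row)
        for row, m in zip(score_rows, row_maxes)
    ]


global_max = max(max(row) for row in scores)    # what a whole-fragment reduce returns
per_row_max = [max(row) for row in scores]      # what softmax actually needs

sums_global = row_sums(scores, [global_max] * len(scores))
sums_per_row = row_sums(scores, per_row_max)

print(sums_global)    # [1.75, 0.0]   second row underflowed entirely
print(sums_per_row)   # [1.75, 1.75]
```

With the per-row max, every row has at least one term equal to 1.0, so its denominator can never be zero.
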
### Remaining for C' (Production Stage C)

1. Cross-warp reduction for row_max and row_sum
2. Correction warps for multi-tile KV (online O rescale in TMEM)
3. 12-warp layout with separate softmax/correction/epilogue warps
4. Per-row O normalization

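
Items 2 and 4 are the standard online-softmax recurrence. As a reference for what the correction warps have to compute, here is a minimal plain-Python sketch for a single query row, assuming the scores are already scaled and the KV sequence is visited in fixed-size tiles. It is an oracle-style illustration with hypothetical function names, not the kernel's data flow.

```python
import math

LOG2_E = math.log2(math.e)  # exp(x) == 2 ** (x * LOG2_E)


def online_softmax_pv(scores, values, tile):
    """Attention output for ONE query row, visiting KV in tiles of `tile` keys."""
    head_dim = len(values[0])
    row_max = -math.inf       # running max, in log2 units
    row_sum = 0.0             # running softmax denominator
    acc = [0.0] * head_dim    # unnormalized O accumulator

    for start in range(0, len(scores), tile):
        s_tile = [s * LOG2_E for s in scores[start:start + tile]]
        v_tile = values[start:start + tile]

        new_max = max(row_max, max(s_tile))
        # Rescale factor for everything accumulated under the old max.
        # On the first tile row_max is -inf, so the factor is exactly 0.0.
        correction = 2.0 ** (row_max - new_max)

        p = [2.0 ** (s - new_max) for s in s_tile]
        row_sum = row_sum * correction + sum(p)
        acc = [
            a * correction + sum(pj * v[d] for pj, v in zip(p, v_tile))
            for d, a in enumerate(acc)
        ]
        row_max = new_max

    return [a / row_sum for a in acc]  # final per-row O normalization


def reference_softmax_pv(scores, values):
    """Non-tiled softmax(scores) @ values, used as the oracle."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    return [
        sum(wj * v[d] for wj, v in zip(w, values)) / total
        for d in range(len(values[0]))
    ]


# Made-up data: 6 keys, head_dim 2, processed 2 keys at a time.
scores = [0.5, -1.0, 2.0, 0.0, 3.0, -2.0]
values = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 0.0], [0.0, 2.0], [1.0, -1.0]]
tiled = online_softmax_pv(scores, values, tile=2)
full = reference_softmax_pv(scores, values)
print(max(abs(a - b) for a, b in zip(tiled, full)))
```

The `correction` factor is the quantity the softmax warps would publish through the `[old_max, new_max]` vec so that the correction warps can rescale O.
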
### TMEM Layout

| Columns | Contents |
|---------|----------|
| 0-127 | S (QK acc, 128 FP32) |
| 32-95 | P (Softmax, 64 FP32), aliased onto part of the S range |
| 128+ | O (PV acc, 64 FP32) |

Row_max and row_sum are per-thread FP32 scalars. The correction warps will use a TMEM-backed vec buffer.

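
The column ranges above, together with the power-of-2 allocation rule, can be checked with a few lines of host-side Python. This is a sketch under stated assumptions: `TmemRegion`, `overlaps`, and `alloc_columns` are hypothetical helpers, and the 32-column minimum allocation is an assumption about the hardware allocator rather than something this document states.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TmemRegion:
    name: str
    start: int   # first TMEM column
    width: int   # number of columns

    @property
    def end(self) -> int:
        """One past the last column (exclusive)."""
        return self.start + self.width


S = TmemRegion("S", 0, 128)    # QK accumulator
P = TmemRegion("P", 32, 64)    # softmax output, placed inside S
O = TmemRegion("O", 128, 64)   # PV accumulator


def overlaps(a: TmemRegion, b: TmemRegion) -> bool:
    return a.start < b.end and b.start < a.end


def alloc_columns(regions, minimum: int = 32) -> int:
    """Smallest power-of-2 column count covering all regions (minimum is ASSUMED)."""
    needed = max(r.end for r in regions)
    cols = minimum
    while cols < needed:
        cols *= 2
    return cols


print(overlaps(S, P))             # True  (intentional alias)
print(overlaps(S, O))             # False
print(overlaps(P, O))             # False
print(alloc_columns([S, P, O]))   # 256   (192 columns used, rounded up)
```
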

---

## Stage E: Production Kernel Extraction

When ready, extract from `test_fmha_v3.py` → `dsv4/kernels/attention/fmha.py`:
1. Clean `FmhaKernel` class with `@cute.jit __call__`, no hardcoded dimensions
2. Add real softmax (Stage C)
3. Add paged KV cache (Stage D)
4. Wrap as `torch.library.custom_op` in `dsv4/ops/`
5. Integrate with vLLM

---

## Key Lessons

1. **NEVER use `find_tmem_tensor_col_offset()` as TMEM placement.** It returns footprint size, not a safe offset.
2. **FMHA never trusts DLPack tensor layouts.** Reconstruct V as (hd, s_k) MN-major inside CuTe.
3. **TMEM allocation must be power of 2.**
4. **Square hides bugs.** (128,128) worked for every wrong approach. Always test non-square.
5. **St32x32bOp MUST use Float32**, NOT BFloat16. BFloat16 causes illegal memory access.
6. **First PV ACCUMULATE=False.** Otherwise adds uninitialized TMEM to output.
7. **FMHA P store uses QK C-fragment composition, NOT PV A-fragment.** Two aliases, same TMEM.
8. **Register bridge: FP32 backing (store partition) + BF16 view (QK-load layout).** Do not skip this.

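
Lesson 4 is easy to reproduce outside the kernel. The toy example below, in plain Python with made-up helper names, transposes a row-major matrix with the wrong stride. The bug is invisible on a square shape and shows up at once on a non-square one. The shapes are chosen so the buggy index stays in bounds; on other shapes the same bug raises `IndexError` instead of returning wrong data.

```python
def transpose_ok(flat, rows, cols):
    """Transpose a row-major (rows, cols) matrix stored as a flat list."""
    return [flat[r * cols + c] for c in range(cols) for r in range(rows)]


def transpose_buggy(flat, rows, cols):
    """Same, but with a stride bug: uses `rows` where `cols` is required."""
    return [flat[r * rows + c] for c in range(cols) for r in range(rows)]


def bug_is_hidden(rows, cols):
    """True when the buggy version happens to match the correct one."""
    flat = list(range(rows * cols))
    return transpose_buggy(flat, rows, cols) == transpose_ok(flat, rows, cols)


print(bug_is_hidden(4, 4))   # True   square shape: the bug passes
print(bug_is_hidden(2, 4))   # False  non-square shape: the bug is caught
```
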

---

## Environment

- Server: root@45.76.247.107 (B200, 180 GiB HBM3e per GPU)
- venv: `source /root/dsv4-nvfp4-workspace/venv/bin/activate`
- PYTHONPATH: `/root/dsv4-nvfp4-workspace/kernel`
- Model: `/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4`
- vLLM repo: `/root/dsv4-nvfp4-workspace/vllm` (modified for Blackwell)
- CUTLASS FMHA reference: `/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py`