# DSV4 Inference Kernel
## Architecture
DSV4 is **not MLA**. It uses **CSA (Compressed Sparse Attention, m=4)** and **HCA (Heavily Compressed Attention, m=128)**. The KV latent has shape (T, 512) and is shared across all 128 heads. Sink weights merge the sparse and SWA attention branches. vLLM labels this architecture "MLA", but that name is wrong: the architecture is fundamentally different.
```
DSV4 inference pipeline — component status
==========================================
Legend:
[✓] built and tested
[~] partial — reference or seam exists, native pending
[✗] to build
┌────────────────────────────────────┐
│ [✗] Embedding + mHC init │
│ token embed + n_hc=4 streams │
└────────────────┬───────────────────┘
┌─ Transformer layer × L ──────────────────────────────────────────────┐
│ HCA on layers 0-1 of Pro, alternating CSA / HCA after │
│ │
│ ┌─ Attention sub-block ──────────────────────────────────────────┐ │
│ │ [✓] Residual mHC pre + post mix │ │
│ │ [~] Norms + RoPE RMSNorm + partial RoPE │ │
│ │ [✓] Q / KV projection NVFP4 linears + LoRA │ │
│ │ [~] Token compressor CSA m=4 / HCA m=128 │ │
│ │ [✗] Indexer + top-k CSA only, FP4 QK │ │
│ │ [~] FMHA core QK → online softmax → PV │ │
│ │ + SWA branch + sink merge │ │
│ │ [✓] Output projection inv RoPE + wo_a grouped + wo_b │ │
│ └────────────────────────────────────────────────────────────────┘ │
│ │
│ ┌─ FFN sub-block ────────────────────────────────────────────────┐ │
│ │ [✓] Residual mHC pre + post mix │ │
│ │ [~] Pre-FFN norm RMSNorm │ │
│ │ [✗] Router sqrt(softplus) + topk + hash │ │
│ │ [✓] Routed MoE fused SwiGLU L1 + L2 │ │
│ │ [✓] Shared expert NVFP4 single-group GEMM │ │
│ └────────────────────────────────────────────────────────────────┘ │
└──────────────────────────────────┬───────────────────────────────────┘
┌──────────────────────────────────────────────────────────────────────┐
│ [✗] Final RMSNorm → [✗] LM head → [✗] MTP (depth=1) → [✗] Sampler │
└──────────────────────────────────────────────────────────────────────┘
┌─ Supporting infrastructure ──────────────────────────────────────────┐
│ [✗] KV cache management │
│ • state cache: SWA window + uncompressed tail per layer │
│ • classical paged cache: lcm(m_CSA, m_HCA) = 128 tokens per block │
│ • heterogeneous layout per layer │
└──────────────────────────────────────────────────────────────────────┘
Summary
-------
Built [✓] : 6 — mHC ×2, Q/KV proj, output proj, routed MoE,
shared expert
Partial [~] : 4 — norms+RoPE, token compressor, FMHA core,
pre-FFN norm
To build [✗] : 8 — embedding+init, indexer+top-k, router,
final norm, LM head, MTP, sampler, KV cache
```
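The paged-cache block size in the diagram can be checked from the two compression ratios given in the Architecture paragraph. The sketch below assumes that `m` is the number of tokens folded into one compressed KV entry and that the block size is the least common multiple of the CSA and HCA ratios; the variable names are illustrative and do not come from the repo.

```python
import math

CSA_M = 4    # assumed: tokens per compressed KV entry on CSA layers
HCA_M = 128  # assumed: tokens per compressed KV entry on HCA layers

# Smallest token count that both compression ratios divide evenly.
block_tokens = math.lcm(CSA_M, HCA_M)

# Compressed entries that one block would hold for each layer type.
csa_entries_per_block = block_tokens // CSA_M
hca_entries_per_block = block_tokens // HCA_M

print(block_tokens, csa_entries_per_block, hca_entries_per_block)  # prints: 128 32 1
```

Under these assumptions one block covers 32 CSA entries but only a single HCA entry, which is one reason the per-layer layout ends up heterogeneous.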
---
## Status (May 21, 2026 — 17:30 UTC)
| Stage | Status | Description |
|-------|--------|-------------|
| A | ✅ COMPLETE | Q@K^T via tcgen05.mma → TMEM → GMEM |
| B | ✅ COMPLETE | QK → identity softmax → P@V pipeline (TMEM alias, KV-tile interleaving) |
| C | 🔨 IN PROGRESS | Real softmax: row max, exp, rescale, row sum (kernel written, needs test harness) |
| D | TODO | Full decode attention: paged KV cache, multi-query, causal mask |
| E | TODO | Production kernel: extract into dsv4/kernels/attention/, PyTorch custom op, vLLM bridge |
---
## Package Structure
```
dsv4/
├── kernels/ Pure GPU code (CuTeDSL @cute.jit, .cu files)
│ ├── gemm/ NVFP4 MoE GEMM kernels (grouped, fused_swiglu, dense, scheduler)
│ ├── attention/ FMHA kernel (stub — extraction is Stage E)
│ ├── compressor/ CSA/HCA token-level compressor
│ ├── decode/ Decode-time attention (sparse, SWA — future)
│ └── cuda/ Raw .cu files (deinterleave_quantize, sparse_topk_metadata)
├── ops/ PyTorch ↔ kernel bridges
│ ├── quantize.py BF16 ↔ NVFP4 conversion, scale factors
│ ├── layouts.py Scale swizzle, gate/up interleave, K-major, offsets
│ ├── gemm_runner.py Warmup, compile, run grouped/fused GEMMs
│ ├── custom_ops.py torch.library.custom_op registrations
│ ├── decode_sparse.py native_sparse_decode dispatcher
│ ├── decode_swa.py native_swa_decode dispatcher
│ ├── rope.py Forward + inverse RoPE
│ └── topk.py Python wrapper for sparse_topk_metadata.cu
├── layers/ nn.Module-style components
│ ├── linear.py Nvfp4Linear
│ ├── grouped_linear.py Nvfp4GroupedLinear
│ ├── moe.py Nvfp4MoE
│ ├── shared_expert.py Nvfp4SharedExpert
│ ├── mhc.py mHCLayer
│ └── (stubs: attention, ffn, router, norm, embedding)
├── model/ Model assembly (stubs — Phase 1)
├── cache/ KV cache infra (stubs — Phase 3)
├── loader/ Checkpoint I/O (stubs — Phase 1)
└── reference/ Slow PyTorch oracles (never imported by production code)
├── attention.py RoPE, KV cache, causal attention, SWA
├── csa_attention.py CSA/HCA sparse attention
├── compressor.py Compressor PyTorch example
└── moe_pipeline.py MoE pipeline reference
```
**Mental model:** `kernels/` → `ops/` → `layers/` → `model/` (dependency flows left to right). `reference/` and `loader/` are sidecars.
---
## Active Test Files
### FMHA (Stages A/B/C) — in `tests/unit/`
| File | Stage | Status |
|------|-------|--------|
| `test_fmha_v3.py` | A+B | ✅ Full QK→softmax→PV, cosine 0.999999 |
| `test_fmha_v3_softmax.py` | C | 🔨 Online softmax kernel (needs test harness) |
| `test_pv64_with_softmax.py` | B | ✅ (128,64) PV, single AB pipeline |
| `test_128_128_vdiag.py` | A+B | ✅ (128,128) PV baseline |
| `test_qkonly.py` | A | ✅ QK with split Q/KV pipelines |
| `test_qk_softmax.py` | A+B | ✅ QK + identity softmax, no PV |
### MoE / GEMM — in `tests/unit/`
| File | What |
|------|------|
| `test_cutedsl.py` | NVFP4 grouped GEMM kernel |
| `cudagraph_test.py` | Cudagraph capture + replay |
| `layertest.py` | Per-layer correctness |
| `test_custom_op.py` | torch.library custom ops |
| `test_compile_custom_op.py` | Compile + warmup |
| `test_fp4_roundtrip.py` | BF16 → NVFP4 → BF16 roundtrip |
| `test_interleave.py` | Gate/up weight interleaving |
| `test_interleave_gemm.py` | Interleaved GEMM correctness |
| `test_fused_step1.py` | Fused SwiGLU GEMM |
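
For orientation, the BF16 → NVFP4 → BF16 roundtrip that `test_fp4_roundtrip.py` covers can be pictured with a simplified pure-Python sketch. This is not the code in `dsv4/ops/quantize.py`. It assumes the standard NVFP4 element format (FP4 E2M1 magnitudes, one scale per 16-element block) and, as a loud simplification, keeps the block scale as a plain Python float instead of rounding it to FP8 E4M3 and applying a per-tensor FP32 scale. `quantize_block` and `dequantize_block` are hypothetical names.

```python
# Magnitudes representable by FP4 E2M1 (the sign is handled separately).
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
BLOCK = 16  # NVFP4 shares one scale across 16 elements

def quantize_block(block):
    """Return (scale, codes); each code is +/- one of E2M1_VALUES."""
    amax = max(abs(x) for x in block)
    # Map the largest magnitude onto 6.0, the largest E2M1 value.
    scale = amax / 6.0 if amax > 0 else 1.0
    codes = []
    for x in block:
        # Nearest representable magnitude (ties go to the smaller value here,
        # which is a simplification of hardware rounding).
        mag = min(E2M1_VALUES, key=lambda v: abs(abs(x) / scale - v))
        codes.append(-mag if x < 0 else mag)
    return scale, codes

def dequantize_block(scale, codes):
    """Invert quantize_block up to rounding error."""
    return [c * scale for c in codes]

block = [0.1 * i - 0.8 for i in range(BLOCK)]
scale, codes = quantize_block(block)
restored = dequantize_block(scale, codes)
max_err = max(abs(a - b) for a, b in zip(block, restored))
```

The worst-case error per element is half of the widest E2M1 gap (between 4 and 6) times the block scale, so `max_err` stays at or below `scale`.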
### Archived Tests
`tests/archive/` contains ~190 debug files from Stages A/B. Not maintained. Can be deleted.
---
## Stage C: Online Softmax
### What We Have
Identity softmax in `test_fmha_v3.py`: load S as FP32 → convert to BF16 → store as P. This proves the TMEM pipeline works.
### What We Are Building
Online softmax in `test_fmha_v3_softmax.py` (kernel written, no test runner yet):
```
For each KV tile:
1. QK → S (FP32 in TMEM)
2. tile_max = max(S[j,:])
3. new_max = max(old_max, tile_max)
4. O *= exp2((old_max - new_max) * scale) ← TMEM rescale
5. P = exp2((S - new_max) * scale) ← scale = 1/sqrt(d) * log2(e)
6. Store P to TMEM (FMHA pattern)
7. row_sum = row_sum * exp2((old_max - new_max) * scale) + sum(P)
8. PV: O += P @ V
After all tiles:
9. O /= row_sum ← final TMEM normalization
```
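
The tile loop above can be sanity-checked on the CPU before a GPU test harness exists. The following is a minimal pure-Python sketch for a single query row, not the kernel itself; `naive_attention` and `online_attention` are hypothetical names. The rescale factor is written in the same exp2 domain as P, with the `1/sqrt(d) * log2(e)` scale folded in, so that the result matches plain softmax.

```python
import math

def naive_attention(q, k, v):
    """Reference: softmax(q . k^T / sqrt(d)) @ v for one query row."""
    d = len(q)
    s = [sum(qi * ki for qi, ki in zip(q, row)) / math.sqrt(d) for row in k]
    m = max(s)
    p = [math.exp(x - m) for x in s]
    z = sum(p)
    return [sum(pj * vj[c] for pj, vj in zip(p, v)) / z for c in range(len(v[0]))]

def online_attention(q, k, v, tile=2):
    """Same result, visiting KV in tiles with a running max and running sum."""
    d = len(q)
    scale = math.log2(math.e) / math.sqrt(d)  # folds 1/sqrt(d) and log2(e)
    row_max = -math.inf
    row_sum = 0.0
    o = [0.0] * len(v[0])
    for start in range(0, len(k), tile):
        k_tile, v_tile = k[start:start + tile], v[start:start + tile]
        s = [sum(qi * ki for qi, ki in zip(q, row)) for row in k_tile]  # step 1
        new_max = max(row_max, max(s))                                  # steps 2-3
        alpha = 2.0 ** ((row_max - new_max) * scale)  # 0.0 on the first tile
        o = [x * alpha for x in o]                                      # step 4
        p = [2.0 ** ((x - new_max) * scale) for x in s]                 # step 5
        row_sum = row_sum * alpha + sum(p)                              # step 7
        for pj, vj in zip(p, v_tile):                                   # step 8
            o = [oc + pj * vc for oc, vc in zip(o, vj)]
        row_max = new_max
    return [x / row_sum for x in o]                                     # step 9

q = [0.5, -1.0, 2.0, 0.25]
k = [[0.1 * (i + j) - 0.3 for j in range(4)] for i in range(5)]
v = [[float(i), 1.0, -0.5 * i] for i in range(5)]
ref = naive_attention(q, k, v)
out = online_attention(q, k, v, tile=2)
assert all(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12) for a, b in zip(out, ref))
```

Five KV rows with `tile=2` leave a ragged last tile, which exercises the same "non-square" caution listed under Key Lessons.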
### Key Implementation Details
- **Row max:** `tTMEM_LOADrS.load().reduce(cute.ReductionOp.MAX, row_max, 0)` per tile
- **O rescale:** Load O from TMEM, multiply by `exp2((old_max - new_max) * scale)`, store back (16-col tiles via `Ld32x32b/St32x32b`)
- **P computation:** `exp2((S - row_max) * scale)` where `scale = 1/sqrt(HEAD_DIM) * log2(e)`
- **Row sum:** Packed `f32x2` reduction using `cute.arch.add_packed_f32x2` (4 unroll, 2-wide)
- **Final norm:** Load O, multiply by `1/row_sum`, store (same TMEM load/store path)
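
One way to read "4 unroll, 2-wide" in the row-sum bullet is as four independent two-lane accumulators that are folded together at the end. The sketch below emulates that access pattern in plain Python. It is an interpretation for illustration only, it does not call `cute.arch.add_packed_f32x2`, and `packed_row_sum` is a hypothetical name.

```python
def packed_row_sum(p_row, unroll=4):
    """Sum a row using `unroll` independent 2-lane accumulators."""
    lanes = 2
    step = unroll * lanes
    assert len(p_row) % step == 0, "row length must be a multiple of unroll * 2"
    # Each accumulator stands in for one packed f32x2 register.
    acc = [[0.0, 0.0] for _ in range(unroll)]
    for base in range(0, len(p_row), step):
        for u in range(unroll):
            i = base + u * lanes
            acc[u][0] += p_row[i]      # low lane
            acc[u][1] += p_row[i + 1]  # high lane
    # Fold the accumulators, then the two lanes.
    lo = sum(a[0] for a in acc)
    hi = sum(a[1] for a in acc)
    return lo + hi

row = [float(i) for i in range(64)]
total = packed_row_sum(row)
print(total)  # prints: 2016.0
```

Independent accumulators shorten the dependency chain between additions. The summation order differs from a sequential loop, so results for non-integer inputs should be compared with a tolerance rather than for exact equality.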
### TMEM Layout (Current — Stage B)
```
Col: 0 32 64 96 128 192 256
|---- S ----|---- P ----| |---- O ----|
| QK acc | Softmax P | (gap) | PV acc |
| 128 FP32 | 64 FP32 | 32 col | 64 FP32 |
```
For Stage C, row_max/row_sum are per-thread FP32 scalars (not in TMEM). Future stages may need TMEM-backed state for wider tiles.
---
## Stage E: Production Kernel Extraction
When ready, extract from `test_fmha_v3.py` → `dsv4/kernels/attention/fmha.py`:
1. Clean `FmhaKernel` class with `@cute.jit __call__`, no hardcoded dimensions
2. Add real softmax (Stage C)
3. Add paged KV cache (Stage D)
4. Wrap as `torch.library.custom_op` in `dsv4/ops/`
5. Integrate with vLLM
---
## Key Lessons
1. **NEVER use `find_tmem_tensor_col_offset()` as TMEM placement.** It returns footprint size, not a safe offset.
2. **FMHA never trusts DLPack tensor layouts.** Reconstruct V as (hd, s_k) MN-major inside CuTe.
3. **TMEM allocation size (in columns) must be a power of 2.**
4. **Square hides bugs.** (128,128) worked for every wrong approach. Always test non-square.
5. **St32x32bOp MUST use Float32**, NOT BFloat16. BFloat16 causes illegal memory access.
6. **First PV ACCUMULATE=False.** Otherwise adds uninitialized TMEM to output.
7. **FMHA P store uses QK C-fragment composition, NOT PV A-fragment.** Two aliases, same TMEM.
8. **Register bridge: FP32 backing (store partition) + BF16 view (QK-load layout).** Do not skip this.
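
Lesson 3 can be enforced with a small host-side helper before any allocation is issued. This is a sketch, not code from the repo: `tmem_alloc_cols` is a hypothetical name, and the 32 to 512 column range is an assumption taken from the PTX `tcgen05.alloc` limits.

```python
TMEM_MIN_COLS = 32   # assumed minimum allocation (PTX tcgen05.alloc)
TMEM_MAX_COLS = 512  # assumed maximum allocation (PTX tcgen05.alloc)

def tmem_alloc_cols(needed_cols):
    """Round a column requirement up to the next legal power-of-2 size."""
    if not 0 < needed_cols <= TMEM_MAX_COLS:
        raise ValueError(f"cannot allocate {needed_cols} TMEM columns")
    cols = TMEM_MIN_COLS
    while cols < needed_cols:
        cols *= 2
    return cols

print(tmem_alloc_cols(160))  # prints: 256
print(tmem_alloc_cols(256))  # prints: 256
```

Rounding up is only about the allocation size. Where S, P, and O sit inside the allocation still has to be laid out explicitly, as lesson 1 warns.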
---
## Environment
- Server: root@45.76.247.107 (B200, 180 GiB HBM3e per GPU)
- venv: `source /root/dsv4-nvfp4-workspace/venv/bin/activate`
- PYTHONPATH: `/root/dsv4-nvfp4-workspace/kernel`
- Model: `/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4`
- vLLM repo: `/root/dsv4-nvfp4-workspace/vllm` (modified for Blackwell)
- CUTLASS FMHA reference: `/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py`