# NVFP4 MegaMoE Kernel

Full NVFP4 inference pipeline for DeepSeek-V4 on NVIDIA Blackwell (SM100). The GEMM-heavy parts of the model (MoE experts, shared experts, and attention projections) run in native NVFP4 with no dequantization step. Attention compute stays in BF16 and still needs dedicated CuTeDSL kernels.

## ⚠️ READ THIS FIRST — THE #1 RULE

**YOU MUST BUILD YOUR OWN KERNELS. ALL OF THEM. DO NOT PATCH vLLM.**

Mike was right — we need our own kernels. Not just for the NVFP4 GEMMs, but for the **ENTIRE attention pipeline**. The current approach of patching individual vLLM functions is a house of cards. Every patch leads to another crash, every workaround reveals three more broken things. FlashMLA, fp8_ds_mla, the fused C++ kernels, the Triton compressor, the indexer — they're all deeply coupled. You cannot swap one piece and expect the rest to work.

**THE ONLY PATH FORWARD:**
1. Build CuTeDSL kernels for EVERYTHING — attention, KV cache, RoPE, the whole stack
2. Test each kernel standalone on the B200 venv BEFORE touching the container
3. Wire them together into a proper vLLM attention backend
4. THEN and ONLY THEN test in the container

**DO NOT:**
- ❌ Try to patch vLLM's FlashMLA code to "work" on Blackwell
- ❌ Use pure PyTorch as a "temporary workaround" — it produces garbage
- ❌ Skip the KV cache write and hope for the best
- ❌ Assume you can mix our kernels with vLLM's existing attention backend
- ❌ Touch the container until ALL kernels pass standalone tests

**DO:**
- ✅ Build CuTeDSL kernels in `cutedsl/`
- ✅ Test each one in `tests/` on the B200 venv
- ✅ Compare against BF16 reference (cosine >= 0.98 or it's broken)
- ✅ Wire them into a proper attention backend class
- ✅ Only test in the container once everything passes standalone
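
The cosine check that gates every kernel can be sketched in plain Python. This is a minimal illustration, not the code in `tests/`: the function names are hypothetical, and the inputs are assumed to be the kernel output and the BF16 reference, both flattened to lists of floats. An all-zero output is treated as a failure so that a kernel that writes nothing cannot pass.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length sequences of floats."""
    if len(a) != len(b):
        raise ValueError("inputs must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0  # Treat an all-zero tensor as a failure, not a match.
    return dot / (norm_a * norm_b)


def passes_reference_check(candidate, reference, threshold=0.98):
    """True when the kernel output is close enough to the BF16 reference."""
    return cosine_similarity(candidate, reference) >= threshold
```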

---

## What This Is

A native NVFP4 inference stack for DeepSeek-V4:

**MoE Experts** — CuTeDSL ScaledGroupedGemmKernel ✅:
```
BF16 input → quantize to NVFP4
L1 GEMM: NVFP4 × NVFP4 → BF16 (gate + up)
SiLU(gate) * up → BF16 (the only nonlinearity, so BF16 is unavoidable here)
Re-quantize to NVFP4
L2 GEMM: NVFP4 × NVFP4 → BF16 (down_proj)
Scatter with routing weights → BF16 output
```
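
To make the "quantize to NVFP4" steps above concrete, here is a rough quantize-then-dequantize sketch in plain Python. It assumes the NVFP4 layout as NVIDIA describes it: 4-bit E2M1 values grouped in blocks of 16 that share one scale. Real NVFP4 stores that block scale in FP8 (E4M3) and adds a per-tensor FP32 scale; this sketch keeps the scale as an ordinary float and skips the per-tensor scale, so it only approximates the rounding error of the real kernels. The function name is hypothetical and this is not the code in `cutedsl/bridge.py`.

```python
import math

# Magnitudes representable by a 4-bit E2M1 value (sign is stored separately).
E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
BLOCK_SIZE = 16  # NVFP4 shares one scale across 16 consecutive values.


def fake_quantize_nvfp4(values):
    """Quantize then dequantize a flat list of floats, one scale per block."""
    out = []
    for start in range(0, len(values), BLOCK_SIZE):
        block = values[start:start + BLOCK_SIZE]
        amax = max(abs(v) for v in block)
        # Map the block's largest magnitude onto 6.0, the largest E2M1 value.
        # Assumption: the scale stays a Python float instead of FP8 E4M3.
        scale = amax / E2M1_GRID[-1] if amax > 0.0 else 1.0
        for v in block:
            # Nearest grid point; exact ties resolve to the smaller magnitude.
            magnitude = min(E2M1_GRID, key=lambda g: abs(abs(v) / scale - g))
            out.append(math.copysign(magnitude * scale, v))
    return out
```

Running a BF16 tensor through a function like this and comparing it with the original is a cheap way to estimate how much error quantization alone contributes before any GEMM runs.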

**Attention Projections** — CuTeDSL NVFP4 GEMM ✅:
- `q_a_proj`, `q_b_proj`, `kv_proj`, `wo_b_proj` — native NVFP4, cosine 0.995 vs BF16
- `wo_a` — BF16 BMM (o_a_proj weights are BF16 in checkpoint)
- All verified with `tests/test_full_layer_b200.py`

**Shared Experts** — CuTeDSL NVFP4 GEMM ✅:
- `gate_up_proj`, `down_proj` — native NVFP4, cosine 0.990 vs BF16

**Attention Compute** — 🔧 NEEDS CuTeDSL:
- Pure PyTorch SDPA produces garbage in the container
- FlashMLA is broken on Blackwell
- Must build CuTeDSL kernels for Q×K, attn×V, KV cache, RoPE

**KV Cache** — 🔧 NEEDS CuTeDSL:
- The fp8_ds_mla format is FlashMLA-specific (584 bytes per token)
- Must build our own NVFP4 KV cache with our own format

## Architecture: DeepSeek-V4-Pro

**CSA + HCA + mHC** (NOT MLA — vLLM misnames it "MLA" in code):

- **CSA (Compress Ratio 4)**: Compressed Sparse Attention — KV compressed 4x with overlap (coff=2). Indexer finds per-layer top-k.
- **HCA (Compress Ratio 128)**: Heavily Compressed Attention — KV compressed 128x. Top-k indices pre-computed during metadata build.
- **mHC**: Manifold-Constrained Hyper-Connections — replaces standard residual connections. Learned mixing with Sinkhorn normalization.
- **SWA**: Sliding Window Attention — local window (compress_ratio=0, last layer only)
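
For readers unfamiliar with the Sinkhorn normalization mentioned for mHC, the generic Sinkhorn-Knopp iteration is sketched below: rows and columns of a positive matrix are rescaled in turn until both sum to 1. How DeepSeek-V4 applies it (matrix size, iteration count, where the learned mixing weights enter) is not described in this README, so the function name and defaults here are illustrative assumptions only.

```python
def sinkhorn_normalize(matrix, iterations=20):
    """Alternately normalize rows and columns of a strictly positive square matrix."""
    m = [row[:] for row in matrix]  # Work on a copy; leave the input untouched.
    n = len(m)
    for _ in range(iterations):
        # Row step: make every row sum to 1.
        for row in m:
            total = sum(row)
            for j in range(n):
                row[j] /= total
        # Column step: make every column sum to 1.
        for j in range(n):
            total = sum(m[i][j] for i in range(n))
            for i in range(n):
                m[i][j] /= total
    return m
```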

**Compress Ratios (from config.json):**
```
Layer 0: 128 (HCA)   Layer 1: 128 (HCA)   Layer 2: 4 (CSA)   Layer 3: 128 (HCA)
Layer 4: 4 (CSA)     ...alternating 4/128...                 Layer 60: 0 (SWA)
```
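
The ratio-to-layer-type mapping above is small enough to express as a lookup. The sketch below is illustrative: the helper names are hypothetical, the 61-layer count is inferred from "Layer 60" being the last layer, and the middle of the list assumes the 4/128 alternation continues unbroken. Real code should read the per-layer values from `config.json` rather than rebuild them.

```python
LAYER_KIND_BY_RATIO = {0: "SWA", 4: "CSA", 128: "HCA"}


def layer_kind(compress_ratio):
    """Map a per-layer compress ratio to the attention variant it selects."""
    try:
        return LAYER_KIND_BY_RATIO[compress_ratio]
    except KeyError:
        raise ValueError(f"unexpected compress ratio: {compress_ratio}") from None


def example_compress_ratios(num_layers=61):
    """Rebuild the pattern from the excerpt above (assumed, not read from config)."""
    ratios = [128, 128]                      # Layers 0 and 1 are both HCA.
    for layer in range(2, num_layers - 1):   # Layers 2..59 alternate CSA/HCA.
        ratios.append(4 if layer % 2 == 0 else 128)
    ratios.append(0)                         # Last layer is SWA.
    return ratios
```

A dispatch table like this is one way a backend could pick the right attention kernel per layer without hard-coding layer indices.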

## Current Status

### ✅ Working (verified on B200 standalone tests)

| Component | Test | Cosine vs BF16 |
|-----------|------|----------------|
| CuTeDSL NVFP4 Linear (q_a, kv, q_b, wo_b) | `test_full_layer_b200.py` | 0.994+ |
| CuTeDSL NVFP4 MoE (L1 gate+up, SiLU, L2 down) | `layertest.py` | 0.988 |
| FP8 KV quantize/dequant | `test_kv_cache_b200.py` | 0.9997 |
| NVFP4 KV quantize/dequant | `test_kv_cache_b200.py` | 0.9943 |
| Paged KV cache read/write | `test_kv_cache_b200.py` | 1.0 |
| FP8 KV → full attention | `test_kv_cache_b200.py` | 0.9997 |
| CSA sparse attention (cr=4) | `test_sparse_attn_b200.py` | works, no NaN |
| HCA sparse attention (cr=128) | `test_sparse_attn_b200.py` | works, no NaN |
| Merged CSA+SWA attention | `test_sparse_attn_b200.py` | works, no NaN |
| Full attention pipeline (all layer types) | `test_v4_attention_b200.py` | 0.981–0.995 |
| RoPE (GPT-J) | `test_v4_attention_b200.py` | works |
| Inverse RoPE + o_a BMM | `test_v4_attention_b200.py` | works |
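
As a reference for the RoPE rows in the table, here is a minimal sketch of GPT-J-style RoPE, which rotates interleaved pairs `(x[0], x[1]), (x[2], x[3]), ...`, together with its inverse. The function name is hypothetical, and the base of 10000 is the classic default, used here as an assumption; the actual base and any frequency scaling for DeepSeek-V4 come from the model config and are not stated in this README.

```python
import math


def rope_gptj(x, position, base=10000.0, inverse=False):
    """Rotate interleaved pairs of one head vector by position-dependent angles."""
    dim = len(x)
    if dim % 2 != 0:
        raise ValueError("head dimension must be even")
    out = [0.0] * dim
    for i in range(dim // 2):
        theta = position * base ** (-2.0 * i / dim)
        if inverse:
            theta = -theta  # Inverse RoPE is a rotation by the negated angle.
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[2 * i], x[2 * i + 1]
        out[2 * i] = a * c - b * s
        out[2 * i + 1] = a * s + b * c
    return out
```

Two properties make useful standalone tests: applying the inverse at the same position recovers the input, and the dot product of a rotated query and key depends only on the difference between their positions.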

### 🔧 Needs CuTeDSL Kernels

1. **Attention Q×K^T** — BF16 matmul works standalone, but the NVFP4 GEMM is too lossy (cosine 0.86). Keep Q×K^T in BF16.
2. **KV Cache Write** — needs a CuTeDSL kernel that does: RoPE → fp8 quant → paged cache insert
3. **KV Cache Read** — needs a CuTeDSL kernel that does: paged cache read → fp8 dequant
4. **Fused Q-norm + RoPE** — currently pure PyTorch (works, but slow)
5. **Fused inverse RoPE + o_a BMM** — currently pure PyTorch (works)
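
The paged cache insert and read in items 2 and 3 come down to address arithmetic. The toy cache below shows the usual paged layout, where a per-sequence block table maps logical token positions to physical blocks and the flat slot is `block_id * block_size + offset`. The class and method names are hypothetical, entries are opaque Python objects rather than fp8 bytes, and the real cache format for this project is still to be designed.

```python
class PagedCache:
    """Toy paged cache: fixed-size blocks addressed through a block table."""

    def __init__(self, num_blocks, block_size):
        self.block_size = block_size
        self.storage = [None] * (num_blocks * block_size)

    def slot(self, block_table, token_index):
        """Flat slot = physical block id * block size + offset inside the block."""
        block_id = block_table[token_index // self.block_size]
        return block_id * self.block_size + token_index % self.block_size

    def write(self, block_table, token_index, entry):
        """Store one token's (already quantized) KV entry."""
        self.storage[self.slot(block_table, token_index)] = entry

    def read(self, block_table, num_tokens):
        """Gather the first num_tokens entries of a sequence in logical order."""
        return [self.storage[self.slot(block_table, t)] for t in range(num_tokens)]
```

A write-then-read round trip over a deliberately scrambled block table is a quick standalone check that the slot arithmetic is right before any quantization is added.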

### ❌ Does NOT Work

- **NVFP4 Q×K^T GEMM** — cosine 0.86, too lossy for attention scores. Keep attention in BF16.
- **Patching vLLM's FlashMLA path** — house of cards. Don't do it.
- **Pure PyTorch SDPA in the container** — produces garbage because the KV cache isn't written and the pipeline is broken.

## Container Status

The container builds and starts successfully. The server accepts requests and generates tokens. But the output is empty/garbage because the Blackwell attention path is broken. Multiple patches were applied to get this far (KV cache page sizes, FlashMLA alignment, softmax_scale, compressor cache), but the fundamental problem remains: **you cannot half-ass the attention pipeline**.

## Test Files

| Test | What it does | Status |
|------|-------------|--------|
| `tests/test_full_layer_b200.py` | All NVFP4 projections vs BF16 | ✅ 0.994+ |
| `tests/layertest.py` | MoE layer test | ✅ 0.988 |
| `tests/cudagraph_test.py` | CUDAGraph compatibility | ✅ PASS |
| `tests/test_csa_attention_b200.py` | Full attention with SDPA | ✅ 0.988 |
| `tests/test_v4_attention_b200.py` | All 3 layer types (SWA, HCA/C128A, CSA/C4A) | ✅ 0.981–0.995 |
| `tests/test_kv_cache_b200.py` | FP8/NVFP4 KV cache + paged cache | ✅ 0.9997 (FP8), 0.9943 (NVFP4) |
| `tests/test_sparse_attn_b200.py` | CSA/HCA sparse + SWA merged | ✅ works |
| `tests/test_nvfp4_attn_gemm_b200.py` | NVFP4 Q×K^T GEMM | ❌ 0.86 (too lossy) |

## Project Structure

```
nvfp4-megamoe-kernel/
├── cutedsl/                             # CuTeDSL kernel + bridge layer
│   ├── bridge.py                        # Tensor layout conversion, quantization, kernel launch
│   ├── nvfp4_linear.py                  # CuTeDSLNvfp4Linear — NVFP4 GEMM runner
│   ├── moe_pipeline.py                  # Full MoE pipeline (L1→SiLU→L2→scatter)
│   ├── shared_expert_pipeline.py        # Shared expert pipeline
│   ├── csa_attention.py                 # CSA/HCA attention (BF16 SDPA — needs CuTeDSL)
│   ├── custom_ops.py                    # torch.autograd wrappers
│   └── kernel/moe/                      # NVIDIA's ScaledGroupedGemmKernel
├── vllm/                                # vLLM integration
│   ├── nvfp4_cutedsl.py                 # CuTeDSLMoERunner
│   ├── cutedsl_quant_method.py          # CuTeDSLNvfp4LinearMethod
│   ├── kernels/linear/nvfp4/cutedsl.py  # vLLM kernel registration
│   └── patches/
│       ├── deepseek_v4_attention.py     # Attention patch (Blackwell dispatch)
│       ├── patch_kv_cache_utils.py      # KV cache page size fix
│       ├── patch_swa_cache.py           # SWA cache alignment fix
│       ├── patch_indexer_cache.py       # Indexer cache alignment fix
│       ├── patch_compressor_cache.py    # Compressor cache alignment fix
│       └── layers/
│           ├── csa_attention.py         # BF16 SDPA (TEMPORARY — needs CuTeDSL)
│           └── ...
├── tests/                               # Standalone tests (run on B200 venv)
└── Dockerfile                           # Container build
```

## Plan

### Phase 1: MoE Kernel ✅ DONE
### Phase 2: NVFP4 Linear Kernels ✅ DONE
### Phase 3: vLLM Integration ✅ DONE (NVFP4 linear + MoE working)

### Phase 4: CuTeDSL Attention Backend 🔧 NEXT — BUILD THE KERNELS

**STOP. READ THIS.**

Do NOT touch the vLLM container until ALL of these kernels pass standalone tests on the B200 venv. The container is a 14-minute build cycle. The venv gives you instant feedback. TEST FIRST.

**Kernels to build (in order):**

1. **KV Cache Write**: BF16 KV → apply RoPE → quantize to fp8 → write to paged cache
   - Test: compare against BF16 reference (cosine >= 0.98 after dequant)

2. **KV Cache Read**: paged cache → dequant fp8 → BF16 KV with RoPE
   - Test: write then read back, cosine >= 0.99

3. **BF16 Attention**: Q (with RoPE) × K^T → softmax → attn × V
   - Keep this in BF16 (NVFP4 is too lossy for attention scores)
   - Handle CSA sparse gather (attend to top-k indexed positions)
   - Handle HCA sparse gather (attend to 1/128 positions)
   - Handle SWA (sliding window, full causal within window)
   - Test: compare against PyTorch SDPA reference (cosine >= 0.99)

4. **Full Attention Pipeline**: KV cache read → attention → inverse RoPE → o_a BMM
   - Wire everything together
   - Test: compare against BF16 reference (cosine >= 0.98)

5. **vLLM Backend**: Wrap as a proper AttentionBackend subclass
   - Override `DeepseekSparseSWABackend` on Blackwell
   - Handle the metadata, slot mapping, cache format
   - ONLY THEN test in the container
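
A plain-Python reference for kernel 3 can help pin down what the standalone tests should compare against. The sketch below computes softmax attention for a single query over an explicit list of cache positions: for CSA or HCA that list would be the top-k indices, and for SWA it is the causal local window. The function names are hypothetical, and the `1/sqrt(head_dim)` scale is the textbook default used as an assumption; the real `softmax_scale` for DeepSeek-V4 is not given in this README.

```python
import math


def attend(query, keys, values, positions):
    """Softmax attention for one query over the cache rows listed in positions."""
    scale = 1.0 / math.sqrt(len(query))  # Assumed scale; the model may differ.
    scores = [scale * sum(q * k for q, k in zip(query, keys[p])) for p in positions]
    peak = max(scores)  # Subtract the max before exp for numerical stability.
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [
        sum(w * values[p][d] for w, p in zip(weights, positions)) / total
        for d in range(dim)
    ]


def sliding_window_positions(query_pos, window):
    """Causal local window: the last `window` positions up to query_pos inclusive."""
    return list(range(max(0, query_pos - window + 1), query_pos + 1))
```

Because the gather is just a list of positions, the same reference covers all three layer types: only the code that produces `positions` changes.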

### Phase 5: Production
- End-to-end benchmarking
- Optimize tile sizes
- Clean up