# NVFP4 MegaMoE Kernel

Full NVFP4 inference pipeline for DeepSeek-V4 on NVIDIA Blackwell (SM100). The MoE experts, shared experts, and attention projections run in native NVFP4 with zero dequantization overhead. Attention compute (Q×K^T, softmax, attn×V) stays in BF16 and still needs CuTeDSL kernels (see Current Status).

## ⚠️ READ THIS FIRST: THE #1 RULE

**YOU MUST BUILD YOUR OWN KERNELS. ALL OF THEM. DO NOT PATCH vLLM.**

Mike was right: we need our own kernels. Not just for the NVFP4 GEMMs, but for the **ENTIRE attention pipeline**.

The current approach of patching individual vLLM functions is a house of cards. Every patch leads to another crash, every workaround reveals three more broken things. FlashMLA, fp8_ds_mla, the fused C++ kernels, the Triton compressor, the indexer: they're all deeply coupled. You cannot swap one piece and expect the rest to work.

**THE ONLY PATH FORWARD:**

1. Build CuTeDSL kernels for EVERYTHING: attention, KV cache, RoPE, the whole stack
2. Test each kernel standalone on the B200 venv BEFORE touching the container
3. Wire them together into a proper vLLM attention backend
4. THEN and ONLY THEN test in the container

**DO NOT:**

- ❌ Try to patch vLLM's FlashMLA code to "work" on Blackwell
- ❌ Use pure PyTorch as a "temporary workaround"; it produces garbage
- ❌ Skip the KV cache write and hope for the best
- ❌ Assume you can mix our kernels with vLLM's existing attention backend
- ❌ Touch the container until ALL kernels pass standalone tests

**DO:**

- ✅ Build CuTeDSL kernels in `cutedsl/`
- ✅ Test each one in `tests/` on the B200 venv
- ✅ Compare against BF16 reference (cosine >= 0.98 or it's broken)
- ✅ Wire them into a proper attention backend class
- ✅ Only test in the container once everything passes standalone

---

## What This Is

A native NVFP4 inference stack for DeepSeek-V4:

**MoE Experts** (CuTeDSL ScaledGroupedGemmKernel ✅):

```
BF16 input → quantize to NVFP4
L1 GEMM: NVFP4 × NVFP4 → BF16 (gate + up)
SiLU(gate) * up → BF16 (only nonlinear; can't avoid BF16 here)
Re-quantize to NVFP4
L2 GEMM: NVFP4 × NVFP4 → BF16 (down_proj)
Scatter with routing weights → BF16 output
```

**Attention Projections** (CuTeDSL NVFP4 GEMM ✅):

- `q_a_proj`, `q_b_proj`, `kv_proj`, `wo_b_proj`: native NVFP4, cosine 0.995 vs BF16
- `wo_a`: BF16 BMM (o_a_proj weights are BF16 in the checkpoint)
- All verified with `tests/test_full_layer_b200.py`

**Shared Experts** (CuTeDSL NVFP4 GEMM ✅):

- `gate_up_proj`, `down_proj`: native NVFP4, cosine 0.990 vs BF16

**Attention Compute** (🔧 NEEDS CuTeDSL):

- Pure PyTorch SDPA produces garbage in the container
- FlashMLA is broken on Blackwell
- Must build CuTeDSL kernels for Q×K, attn×V, KV cache, RoPE

**KV Cache** (🔧 NEEDS CuTeDSL):

- The fp8_ds_mla format is FlashMLA-specific (584 bytes per token)
- Must build our own NVFP4 KV cache with our own format

## Architecture: DeepSeek-V4-Pro

**CSA + HCA + mHC** (NOT MLA; vLLM misnames it "MLA" in code):

- **CSA (Compress Ratio 4)**: Compressed Sparse Attention. KV compressed 4x with overlap (coff=2). Indexer finds per-layer top-k.
- **HCA (Compress Ratio 128)**: Heavily Compressed Attention. KV compressed 128x. Top-k indices pre-computed during metadata build.
- **mHC**: Manifold-Constrained Hyper-Connections. Replaces standard residual connections. Learned mixing with Sinkhorn normalization.
- **SWA**: Sliding Window Attention. Local window (compress_ratio=0, last layer only).

**Compress Ratios (from config.json):**

```
Layer 0:  128 (HCA)
Layer 1:  128 (HCA)
Layer 2:  4   (CSA)
Layer 3:  128 (HCA)
Layer 4:  4   (CSA)
...alternating 4/128...
Layer 60: 0   (SWA)
```

## Current Status

### ✅ Working (verified on B200 standalone tests)

| Component | Test | Cosine vs BF16 |
|-----------|------|----------------|
| CuTeDSL NVFP4 Linear (q_a, kv, q_b, wo_b) | `test_full_layer_b200.py` | 0.994+ |
| CuTeDSL NVFP4 MoE (L1 gate+up, SiLU, L2 down) | `layertest.py` | 0.988 |
| FP8 KV quantize/dequant | `test_kv_cache_b200.py` | 0.9997 |
| NVFP4 KV quantize/dequant | `test_kv_cache_b200.py` | 0.9943 |
| Paged KV cache read/write | `test_kv_cache_b200.py` | 1.0 |
| FP8 KV → full attention | `test_kv_cache_b200.py` | 0.9997 |
| CSA sparse attention (cr=4) | `test_sparse_attn_b200.py` | works, no NaN |
| HCA sparse attention (cr=128) | `test_sparse_attn_b200.py` | works, no NaN |
| Merged CSA+SWA attention | `test_sparse_attn_b200.py` | works, no NaN |
| Full attention pipeline (all layer types) | `test_v4_attention_b200.py` | 0.981-0.995 |
| RoPE (GPT-J) | `test_v4_attention_b200.py` | works |
| Inverse RoPE + o_a BMM | `test_v4_attention_b200.py` | works |

### 🔧 Needs CuTeDSL Kernels

1. **Attention Q×K^T**: BF16 matmul works standalone, but NVFP4 GEMM too lossy (cosine 0.86). Keep Q×K in BF16.
2. **KV Cache Write**: need CuTeDSL kernel that does: RoPE → fp8 quant → paged cache insert
3. **KV Cache Read**: need CuTeDSL kernel that does: paged cache read → fp8 dequant
4. **Fused Q-norm + RoPE**: currently pure PyTorch (works, slow)
5. **Fused inverse RoPE + o_a BMM**: currently pure PyTorch (works)

### ❌ Does NOT Work

- **NVFP4 Q×K^T GEMM**: cosine 0.86, too lossy for attention scores. Keep attention in BF16.
- **Patching vLLM's FlashMLA path**: house of cards. Don't do it.
- **Pure PyTorch SDPA in the container**: produces garbage because the KV cache isn't written and the pipeline is broken.

## Container Status

The container builds and starts successfully. The server accepts requests and generates tokens. But the output is empty/garbage because the Blackwell attention path is broken.

Multiple patches were applied to get this far (KV cache page sizes, FlashMLA alignment, softmax_scale, compressor cache), but the fundamental problem remains: **you cannot half-ass the attention pipeline**.

## Test Files

| Test | What it does | Status |
|------|-------------|--------|
| `tests/test_full_layer_b200.py` | All NVFP4 projections vs BF16 | ✅ 0.994+ |
| `tests/layertest.py` | MoE layer test | ✅ 0.988 |
| `tests/cudagraph_test.py` | CUDAGraph compatibility | ✅ PASS |
| `tests/test_csa_attention_b200.py` | Full attention with SDPA | ✅ 0.988 |
| `tests/test_v4_attention_b200.py` | All 3 layer types (SWA, C128A, C4A) | ✅ 0.981-0.995 |
| `tests/test_kv_cache_b200.py` | FP8/NVFP4 KV cache + paged cache | ✅ 0.9997 |
| `tests/test_sparse_attn_b200.py` | CSA/HCA sparse + SWA merged | ✅ works |
| `tests/test_nvfp4_attn_gemm_b200.py` | NVFP4 Q×K^T GEMM | ❌ 0.86 (too lossy) |

## Project Structure

```
nvfp4-megamoe-kernel/
├── cutedsl/                              # CuTeDSL kernel + bridge layer
│   ├── bridge.py                         # Tensor layout conversion, quantization, kernel launch
│   ├── nvfp4_linear.py                   # CuTeDSLNvfp4Linear: NVFP4 GEMM runner
│   ├── moe_pipeline.py                   # Full MoE pipeline (L1→SiLU→L2→scatter)
│   ├── shared_expert_pipeline.py         # Shared expert pipeline
│   ├── csa_attention.py                  # CSA/HCA attention (BF16 SDPA; needs CuTeDSL)
│   ├── custom_ops.py                     # torch.autograd wrappers
│   └── kernel/moe/                       # NVIDIA's ScaledGroupedGemmKernel
├── vllm/                                 # vLLM integration
│   ├── nvfp4_cutedsl.py                  # CuTeDSLMoERunner
│   ├── cutedsl_quant_method.py           # CuTeDSLNvfp4LinearMethod
│   ├── kernels/linear/nvfp4/cutedsl.py   # vLLM kernel registration
│   └── patches/
│       ├── deepseek_v4_attention.py      # Attention patch (Blackwell dispatch)
│       ├── patch_kv_cache_utils.py       # KV cache page size fix
│       ├── patch_swa_cache.py            # SWA cache alignment fix
│       ├── patch_indexer_cache.py        # Indexer cache alignment fix
│       ├── patch_compressor_cache.py     # Compressor cache alignment fix
│       └── layers/
│           ├── csa_attention.py          # BF16 SDPA (TEMPORARY; needs CuTeDSL)
│           └── ...
├── tests/                                # Standalone tests (run on B200 venv)
└── Dockerfile                            # Container build
```

## Plan

### Phase 1: MoE Kernel ✅ DONE

### Phase 2: NVFP4 Linear Kernels ✅ DONE

### Phase 3: vLLM Integration ✅ DONE (NVFP4 linear + MoE working)

### Phase 4: CuTeDSL Attention Backend 🔧 NEXT: BUILD THE KERNELS

**STOP. READ THIS.** Do NOT touch the vLLM container until ALL of these kernels pass standalone tests on the B200 venv. The container is a 14-minute build cycle. The venv gives you instant feedback. TEST FIRST.

**Kernels to build (in order):**

1. **KV Cache Write**: BF16 KV → apply RoPE → quantize to fp8 → write to paged cache
   - Test: compare against BF16 reference (cosine >= 0.98 after dequant)
2. **KV Cache Read**: paged cache → dequant fp8 → BF16 KV with RoPE
   - Test: write then read back, cosine >= 0.99
3. **BF16 Attention**: Q (with RoPE) × K^T → softmax → attn × V
   - Keep this in BF16 (NVFP4 is too lossy for attention scores)
   - Handle CSA sparse gather (attend to top-k indexed positions)
   - Handle HCA sparse gather (attend to 1/128 positions)
   - Handle SWA (sliding window, full causal within window)
   - Test: compare against PyTorch SDPA reference (cosine >= 0.99)
4. **Full Attention Pipeline**: KV cache read → attention → inverse RoPE → o_a BMM
   - Wire everything together
   - Test: compare against BF16 reference (cosine >= 0.98)
5. **vLLM Backend**: Wrap as a proper AttentionBackend subclass
   - Override `DeepseekSparseSWABackend` on Blackwell
   - Handle the metadata, slot mapping, cache format
   - ONLY THEN test in the container

### Phase 5: Production

- End-to-end benchmarking
- Optimize tile sizes
- Clean up
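
## Appendix: Sketch of the Cosine Gate

Every standalone test in this README is judged the same way: run the kernel, compare against a BF16 reference, and fail if the cosine similarity drops below the threshold (0.98, or 0.99 for the stricter checks). The snippet below is a minimal pure-Python sketch of that gate, not the code in `tests/`. The names `fake_quant_nvfp4`, `cosine`, and `passes_gate` are hypothetical. The quantizer is a deliberate simplification: it snaps values to the 4-bit E2M1 grid in blocks of 16, but keeps the per-block scale in full precision, whereas real NVFP4 stores that scale in FP8. The actual tests presumably operate on GPU tensors rather than Python lists.

```python
import math
import random

# Magnitudes representable by a 4-bit E2M1 float (the element type of NVFP4).
E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
BLOCK = 16  # NVFP4 shares one scale factor across 16 consecutive elements


def fake_quant_nvfp4(values):
    """Round-trip values through a simplified NVFP4-style block quantizer.

    Assumption: the per-block scale stays in full precision here; real
    NVFP4 stores it as FP8 (E4M3), which adds a little more error.
    """
    out = []
    for start in range(0, len(values), BLOCK):
        block = values[start:start + BLOCK]
        amax = max(abs(v) for v in block)
        # Map the largest magnitude in the block onto the top of the grid.
        scale = amax / E2M1_GRID[-1] if amax > 0 else 1.0
        for v in block:
            # Snap the scaled magnitude to the nearest representable value.
            mag = min(E2M1_GRID, key=lambda g: abs(abs(v) / scale - g))
            out.append(math.copysign(mag * scale, v))
    return out


def cosine(a, b):
    """Cosine similarity between two equal-length sequences."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm > 0 else 0.0


def passes_gate(reference, candidate, threshold=0.98):
    """Apply the README's rule: finite output and cosine >= threshold."""
    if any(not math.isfinite(x) for x in candidate):
        return False
    return cosine(reference, candidate) >= threshold


random.seed(0)
reference = [random.gauss(0.0, 1.0) for _ in range(4096)]
roundtrip = fake_quant_nvfp4(reference)
print(f"cosine = {cosine(reference, roundtrip):.4f}, "
      f"pass = {passes_gate(reference, roundtrip)}")
```

The gate rejects non-finite output before computing the cosine, which mirrors the "works, no NaN" entries in the status table. Cosine similarity is insensitive to a uniform scale error, so a test that also needs to catch a wrong `softmax_scale` or a missing dequant scale would have to add a magnitude check on top of this.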