NVFP4 MegaMoE Kernel
Full NVFP4 inference pipeline for DeepSeek-V4 on NVIDIA Blackwell (SM100). MoE experts, shared experts, and attention projections run in native NVFP4 with zero dequantization overhead. Attention compute and the KV cache are still in progress (see Current Status).
⚠️ READ THIS FIRST — THE #1 RULE
YOU MUST BUILD YOUR OWN KERNELS. ALL OF THEM. DO NOT PATCH vLLM.
Mike was right — we need our own kernels. Not just for the NVFP4 GEMMs, but for the ENTIRE attention pipeline. The current approach of patching individual vLLM functions is a house of cards. Every patch leads to another crash, every workaround reveals three more broken things. FlashMLA, fp8_ds_mla, the fused C++ kernels, the Triton compressor, the indexer — they're all deeply coupled. You cannot swap one piece and expect the rest to work.
THE ONLY PATH FORWARD:
- Build CuTeDSL kernels for EVERYTHING — attention, KV cache, RoPE, the whole stack
- Test each kernel standalone on the B200 venv BEFORE touching the container
- Wire them together into a proper vLLM attention backend
- THEN and ONLY THEN test in the container
DO NOT:
- ❌ Try to patch vLLM's FlashMLA code to "work" on Blackwell
- ❌ Use pure PyTorch as a "temporary workaround" — it produces garbage
- ❌ Skip the KV cache write and hope for the best
- ❌ Assume you can mix our kernels with vLLM's existing attention backend
- ❌ Touch the container until ALL kernels pass standalone tests
DO:
- ✅ Build CuTeDSL kernels in cutedsl/
- ✅ Test each one in tests/ on the B200 venv
- ✅ Compare against BF16 reference (cosine >= 0.98 or it's broken)
- ✅ Wire them into a proper attention backend class
- ✅ Only test in the container once everything passes standalone
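The cosine gate used throughout this README can be sketched in a few lines. This is a minimal pure-Python version for illustration only; the real tests in tests/ presumably operate on GPU tensors, and the names `cosine_similarity` and `check_against_reference` are hypothetical, not taken from the repo.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length sequences of floats."""
    if len(a) != len(b):
        raise ValueError("length mismatch")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine is undefined for a zero vector")
    return dot / (norm_a * norm_b)


def check_against_reference(out, ref, threshold=0.98):
    """Return (passed, cosine). Any NaN in the kernel output is a failure."""
    if any(math.isnan(v) for v in out):
        return False, float("nan")
    cos = cosine_similarity(out, ref)
    return cos >= threshold, cos


if __name__ == "__main__":
    reference = [0.5, -1.25, 2.0, 0.75]
    kernel_out = [0.52, -1.20, 1.95, 0.80]  # made-up values for the demo
    print(check_against_reference(kernel_out, reference))
```

Checking for NaN before computing the cosine matters because a NaN output would otherwise make the `>=` comparison silently false without saying why.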
What This Is
A native NVFP4 inference stack for DeepSeek-V4:
MoE Experts — CuTeDSL ScaledGroupedGemmKernel ✅:
BF16 input → quantize to NVFP4
L1 GEMM: NVFP4 × NVFP4 → BF16 (gate + up)
SiLU(gate) * up → BF16 (only nonlinear — can't avoid BF16 here)
Re-quantize to NVFP4
L2 GEMM: NVFP4 × NVFP4 → BF16 (down_proj)
Scatter with routing weights → BF16 output
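The "quantize to NVFP4" and "re-quantize" steps can be approximated on the CPU with a fake-quant round trip. The sketch below assumes the public NVFP4 layout (FP4 E2M1 elements, one scale per block of 16 values) and simplifies it: the block scale is kept in full precision instead of FP8 E4M3, the per-tensor FP32 scale is omitted, and ties round toward the smaller magnitude. `fake_quant_nvfp4` is a hypothetical helper, not the function in cutedsl/bridge.py.

```python
import math
import random

E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)  # FP4 E2M1 magnitudes
BLOCK = 16  # number of values sharing one scale


def fake_quant_nvfp4(values):
    """Quantize to the E2M1 grid per 16-value block, then dequantize."""
    out = []
    for start in range(0, len(values), BLOCK):
        block = values[start:start + BLOCK]
        amax = max(abs(v) for v in block)
        # Map the block's largest magnitude onto the largest grid value (6.0).
        scale = amax / E2M1_GRID[-1] if amax > 0.0 else 1.0
        for v in block:
            target = abs(v) / scale
            mag = min(E2M1_GRID, key=lambda g: abs(target - g))
            out.append(math.copysign(mag * scale, v))
    return out


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / math.sqrt(sum(x * x for x in a) * sum(y * y for y in b))


if __name__ == "__main__":
    random.seed(0)
    x = [random.gauss(0.0, 1.0) for _ in range(256)]
    print(f"round-trip cosine: {cosine(x, fake_quant_nvfp4(x)):.4f}")
```

A round trip like this only bounds the quantization error of the activations; it says nothing about GEMM accumulation error, so it complements the kernel tests rather than replacing them.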
Attention Projections — CuTeDSL NVFP4 GEMM ✅:
- q_a_proj, q_b_proj, kv_proj, wo_b_proj: native NVFP4, cosine 0.995 vs BF16
- wo_a: BF16 BMM (o_a_proj weights are BF16 in checkpoint)
- All verified with tests/test_full_layer_b200.py
Shared Experts — CuTeDSL NVFP4 GEMM ✅:
- gate_up_proj, down_proj: native NVFP4, cosine 0.990 vs BF16
Attention Compute — 🔧 NEEDS CuTeDSL:
- Pure PyTorch SDPA produces garbage in the container
- FlashMLA is broken on Blackwell
- Must build CuTeDSL kernels for Q×K, attn×V, KV cache, RoPE
KV Cache — 🔧 NEEDS CuTeDSL:
- The fp8_ds_mla format is FlashMLA-specific (584 bytes per token)
- Must build our own NVFP4 KV cache with our own format
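Whatever byte layout the new cache ends up with, the paging logic around it can be prototyped independently of the format. The sketch below is a toy paged cache in pure Python that stores unquantized floats. It assumes the usual paged-attention convention `slot = block_id * block_size + offset`, and every name in it (`PagedKVCache`, `slots_for`, `block_table`) is illustrative rather than taken from this repo or from vLLM.

```python
class PagedKVCache:
    """Toy paged cache: num_blocks pages, block_size token slots per page."""

    def __init__(self, num_blocks, block_size, head_dim):
        self.block_size = block_size
        self.pages = [
            [[0.0] * head_dim for _ in range(block_size)]
            for _ in range(num_blocks)
        ]

    def write(self, slot_mapping, kv_rows):
        # Each flat slot index decodes to (page, offset inside the page).
        for slot, row in zip(slot_mapping, kv_rows):
            block_id, offset = divmod(slot, self.block_size)
            self.pages[block_id][offset] = list(row)

    def read(self, block_table, seq_len):
        # block_table lists the pages owned by one sequence, in logical order.
        rows = []
        for pos in range(seq_len):
            block_id = block_table[pos // self.block_size]
            rows.append(self.pages[block_id][pos % self.block_size])
        return rows


def slots_for(block_table, seq_len, block_size):
    """Flat slot index for every token position of one sequence."""
    return [
        block_table[pos // block_size] * block_size + pos % block_size
        for pos in range(seq_len)
    ]


if __name__ == "__main__":
    cache = PagedKVCache(num_blocks=4, block_size=4, head_dim=2)
    table = [3, 0]  # the sequence owns page 3, then page 0
    rows = [[float(p), float(-p)] for p in range(6)]
    cache.write(slots_for(table, 6, 4), rows)
    print(cache.read(table, 6) == rows)
```

Keeping the quantization out of this layer means the write-then-read test should be bit-exact, so any mismatch points at the slot arithmetic rather than at rounding.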
Architecture: DeepSeek-V4-Pro
CSA + HCA + mHC (NOT MLA — vLLM misnames it "MLA" in code):
- CSA (Compress Ratio 4): Compressed Sparse Attention — KV compressed 4x with overlap (coff=2). Indexer finds per-layer top-k.
- HCA (Compress Ratio 128): Heavily Compressed Attention — KV compressed 128x. Top-k indices pre-computed during metadata build.
- mHC: Manifold-Constrained Hyper-Connections — replaces standard residual connections. Learned mixing with Sinkhorn normalization.
- SWA: Sliding Window Attention — local window (compress_ratio=0, last layer only)
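For readers unfamiliar with the Sinkhorn normalization mentioned under mHC, here is a minimal sketch of the textbook algorithm: exponentiate the logits, then alternate row and column normalization until the matrix is approximately doubly stochastic. How DeepSeek-V4 actually parameterizes the mixing matrix (its size, iteration count, and where it sits in the residual stream) is not stated here, so `sinkhorn` and `n_iters` are purely illustrative.

```python
import math


def sinkhorn(logits, n_iters=20):
    """Push a square matrix of logits toward a doubly stochastic matrix."""
    n = len(logits)
    if any(len(row) != n for row in logits):
        raise ValueError("expected a square matrix")
    # Exponentiate so every entry is strictly positive.
    m = [[math.exp(v) for v in row] for row in logits]
    for _ in range(n_iters):
        # Row step: make every row sum to 1.
        m = [[v / sum(row) for v in row] for row in m]
        # Column step: make every column sum to 1.
        col_sums = [sum(m[i][j] for i in range(n)) for j in range(n)]
        m = [[m[i][j] / col_sums[j] for j in range(n)] for i in range(n)]
    return m


if __name__ == "__main__":
    mix = sinkhorn([[0.0, 1.0], [2.0, 0.5]])
    for row in mix:
        print([round(v, 4) for v in row])
```

Because the loop ends on a column step, column sums are exact up to rounding while row sums are only approximately 1; more iterations tighten the row error.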
Compress Ratios (from config.json):
Layer 0: 128 (HCA) Layer 1: 128 (HCA) Layer 2: 4 (CSA) Layer 3: 128 (HCA)
Layer 4: 4 (CSA) ...alternating 4/128... Layer 60: 0 (SWA)
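The ratio-to-layer-type mapping above is small enough to encode directly, which is handy when a test needs to pick the right attention path per layer. This sketch covers only the layers spelled out above and does not guess the elided middle of the schedule; `layer_kind` and `KNOWN_RATIOS` are hypothetical names.

```python
# compress_ratio value -> attention type, as described in this section
KIND_BY_RATIO = {0: "SWA", 4: "CSA", 128: "HCA"}

# Only the layers listed explicitly above; the "...alternating..." span is
# deliberately left out rather than reconstructed.
KNOWN_RATIOS = {0: 128, 1: 128, 2: 4, 3: 128, 4: 4, 60: 0}


def layer_kind(compress_ratio):
    """Map a layer's compress ratio to its attention type."""
    try:
        return KIND_BY_RATIO[compress_ratio]
    except KeyError:
        raise ValueError(f"unexpected compress ratio: {compress_ratio}") from None


if __name__ == "__main__":
    for layer, ratio in sorted(KNOWN_RATIOS.items()):
        print(f"layer {layer:2d}: ratio {ratio:3d} -> {layer_kind(ratio)}")
```

Raising on an unknown ratio is intentional: a silent default would hide a config mismatch until the output turns to garbage.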
Current Status
✅ Working (verified on B200 standalone tests)
| Component | Test | Cosine vs BF16 |
|---|---|---|
| CuTeDSL NVFP4 Linear (q_a, kv, q_b, wo_b) | test_full_layer_b200.py | 0.994+ |
| CuTeDSL NVFP4 MoE (L1 gate+up, SiLU, L2 down) | layertest.py | 0.988 |
| FP8 KV quantize/dequant | test_kv_cache_b200.py | 0.9997 |
| NVFP4 KV quantize/dequant | test_kv_cache_b200.py | 0.9943 |
| Paged KV cache read/write | test_kv_cache_b200.py | 1.0 |
| FP8 KV → full attention | test_kv_cache_b200.py | 0.9997 |
| CSA sparse attention (cr=4) | test_sparse_attn_b200.py | works, no NaN |
| HCA sparse attention (cr=128) | test_sparse_attn_b200.py | works, no NaN |
| Merged CSA+SWA attention | test_sparse_attn_b200.py | works, no NaN |
| Full attention pipeline (all layer types) | test_v4_attention_b200.py | 0.981–0.995 |
| RoPE (GPT-J) | test_v4_attention_b200.py | works |
| Inverse RoPE + o_a BMM | test_v4_attention_b200.py | works |
🔧 Needs CuTeDSL Kernels
- Attention Q×K^T: the BF16 matmul works standalone, but the NVFP4 GEMM is too lossy (cosine 0.86). Keep Q×K in BF16.
- KV Cache Write — need CuTeDSL kernel that does: RoPE → fp8 quant → paged cache insert
- KV Cache Read — need CuTeDSL kernel that does: paged cache read → fp8 dequant
- Fused Q-norm + RoPE — currently pure PyTorch (works, slow)
- Fused inverse RoPE + o_a BMM — currently pure PyTorch (works)
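As a reference for the two RoPE items above, GPT-J-style RoPE rotates adjacent pairs (x[2i], x[2i+1]), and inverse RoPE is the same rotation with the angle negated. The sketch below is a pure-Python reference under loud assumptions: the base of 10000, the absence of any frequency scaling, and rotating the whole vector are placeholders, since the model's actual rotary settings and the slice of the head dimension that gets rotated are not given here. `rope_gptj` is a hypothetical name.

```python
import math


def rope_gptj(x, pos, base=10000.0, inverse=False):
    """Rotate adjacent pairs (x[2i], x[2i+1]) by pos * base**(-2i/d).

    With inverse=True the angle is negated, which undoes the rotation.
    """
    d = len(x)
    if d % 2 != 0:
        raise ValueError("RoPE needs an even dimension")
    sign = -1.0 if inverse else 1.0
    out = [0.0] * d
    for i in range(d // 2):
        theta = sign * pos * base ** (-2.0 * i / d)
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[2 * i], x[2 * i + 1]
        out[2 * i] = a * c - b * s
        out[2 * i + 1] = a * s + b * c
    return out


if __name__ == "__main__":
    vec = [1.0, 0.0, 0.5, -0.5]
    rotated = rope_gptj(vec, pos=3)
    restored = rope_gptj(rotated, pos=3, inverse=True)
    print([round(v, 6) for v in restored])
```

A fused kernel can be validated against this in two ways: the forward output should match element-wise, and forward followed by inverse should return the input, which also catches pairing mistakes (adjacent pairs versus split halves).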
❌ Does NOT Work
- NVFP4 Q×K^T GEMM — cosine 0.86, too lossy for attention scores. Keep attention in BF16.
- Patching vLLM's FlashMLA path — house of cards. Don't do it.
- Pure PyTorch SDPA in the container — produces garbage because the KV cache isn't written and the pipeline is broken.
Container Status
The container builds and starts successfully. The server accepts requests and generates tokens. But the output is empty/garbage because the Blackwell attention path is broken. Multiple patches were applied to get this far (KV cache page sizes, FlashMLA alignment, softmax_scale, compressor cache), but the fundamental problem remains: you cannot half-ass the attention pipeline.
Test Files
| Test | What it does | Status |
|---|---|---|
| tests/test_full_layer_b200.py | All NVFP4 projections vs BF16 | ✅ 0.994+ |
| tests/layertest.py | MoE layer test | ✅ 0.988 |
| tests/cudagraph_test.py | CUDAGraph compatibility | ✅ PASS |
| tests/test_csa_attention_b200.py | Full attention with SDPA | ✅ 0.988 |
| tests/test_v4_attention_b200.py | All 3 layer types (SWA, C128A, C4A) | ✅ 0.981-0.995 |
| tests/test_kv_cache_b200.py | FP8/NVFP4 KV cache + paged cache | ✅ 0.9997 |
| tests/test_sparse_attn_b200.py | CSA/HCA sparse + SWA merged | ✅ works |
| tests/test_nvfp4_attn_gemm_b200.py | NVFP4 Q×K^T GEMM | ❌ 0.86 (too lossy) |
Project Structure
nvfp4-megamoe-kernel/
├── cutedsl/                                # CuTeDSL kernel + bridge layer
│   ├── bridge.py                           # Tensor layout conversion, quantization, kernel launch
│   ├── nvfp4_linear.py                     # CuTeDSLNvfp4Linear: NVFP4 GEMM runner
│   ├── moe_pipeline.py                     # Full MoE pipeline (L1→SiLU→L2→scatter)
│   ├── shared_expert_pipeline.py           # Shared expert pipeline
│   ├── csa_attention.py                    # CSA/HCA attention (BF16 SDPA, needs CuTeDSL)
│   ├── custom_ops.py                       # torch.autograd wrappers
│   └── kernel/moe/                         # NVIDIA's ScaledGroupedGemmKernel
├── vllm/                                   # vLLM integration
│   ├── nvfp4_cutedsl.py                    # CuTeDSLMoERunner
│   ├── cutedsl_quant_method.py             # CuTeDSLNvfp4LinearMethod
│   ├── kernels/linear/nvfp4/cutedsl.py     # vLLM kernel registration
│   └── patches/
│       ├── deepseek_v4_attention.py        # Attention patch (Blackwell dispatch)
│       ├── patch_kv_cache_utils.py         # KV cache page size fix
│       ├── patch_swa_cache.py              # SWA cache alignment fix
│       ├── patch_indexer_cache.py          # Indexer cache alignment fix
│       ├── patch_compressor_cache.py       # Compressor cache alignment fix
│       └── layers/
│           ├── csa_attention.py            # BF16 SDPA (TEMPORARY, needs CuTeDSL)
│           └── ...
├── tests/                                  # Standalone tests (run on B200 venv)
└── Dockerfile                              # Container build
Plan
Phase 1: MoE Kernel ✅ DONE
Phase 2: NVFP4 Linear Kernels ✅ DONE
Phase 3: vLLM Integration ✅ DONE (NVFP4 linear + MoE working)
Phase 4: CuTeDSL Attention Backend 🔧 NEXT — BUILD THE KERNELS
STOP. READ THIS.
Do NOT touch the vLLM container until ALL of these kernels pass standalone tests on the B200 venv. The container is a 14-minute build cycle. The venv gives you instant feedback. TEST FIRST.
Kernels to build (in order):
- KV Cache Write: BF16 KV → apply RoPE → quantize to fp8 → write to paged cache
  - Test: compare against BF16 reference (cosine >= 0.98 after dequant)
- KV Cache Read: paged cache → dequant fp8 → BF16 KV with RoPE
  - Test: write then read back, cosine >= 0.99
- BF16 Attention: Q (with RoPE) × K^T → softmax → attn × V
  - Keep this in BF16 (NVFP4 is too lossy for attention scores)
  - Handle CSA sparse gather (attend to top-k indexed positions)
  - Handle HCA sparse gather (attend to 1/128 positions)
  - Handle SWA (sliding window, full causal within window)
  - Test: compare against PyTorch SDPA reference (cosine >= 0.99)
- Full Attention Pipeline: KV cache read → attention → inverse RoPE → o_a BMM
  - Wire everything together
  - Test: compare against BF16 reference (cosine >= 0.98)
- vLLM Backend: Wrap as a proper AttentionBackend subclass
  - Override DeepseekSparseSWABackend on Blackwell
  - Handle the metadata, slot mapping, cache format
  - ONLY THEN test in the container
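The sparse-gather and sliding-window cases in the attention step share one reference shape: pick a set of KV positions, then run ordinary softmax attention over just those rows. The sketch below is a single-query, single-head reference in pure Python, meant as something to diff a kernel against on tiny inputs. It assumes the default 1/sqrt(d) scale and takes the top-k indices as given (the indexer itself is not modeled); `sparse_attention` and `sliding_window_indices` are hypothetical names.

```python
import math


def softmax(scores):
    peak = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def sparse_attention(q, keys, values, indices, scale=None):
    """Attend one query to the KV rows selected by `indices` only."""
    if not indices:
        raise ValueError("need at least one KV position to attend to")
    if scale is None:
        scale = 1.0 / math.sqrt(len(q))
    k_sel = [keys[i] for i in indices]
    v_sel = [values[i] for i in indices]
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_sel]
    probs = softmax(scores)
    dim = len(v_sel[0])
    return [sum(p * v[j] for p, v in zip(probs, v_sel)) for j in range(dim)]


def sliding_window_indices(pos, window):
    """Causal local window: the last `window` positions up to and including pos."""
    return list(range(max(0, pos - window + 1), pos + 1))


if __name__ == "__main__":
    keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
    values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
    q = [0.3, 0.7]
    print(sparse_attention(q, keys, values, indices=[0, 2]))
    print(sparse_attention(q, keys, values, sliding_window_indices(3, 2)))
```

Two cheap sanity checks fall out of this structure: passing every position as `indices` must reproduce dense attention, and passing a single index must return that value row unchanged.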
Phase 5: Production
- End-to-end benchmarking
- Optimize tile sizes
- Clean up