NVFP4 MegaMoE Kernel
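Full NVFP4 inference pipeline for DeepSeek-V4 on NVIDIA Blackwell (SM100). The goal is for the entire model (MoE experts, shared experts, attention projections, and attention compute) to run in native NVFP4 with zero dequantization overhead. MoE experts, shared experts, and attention projections already do; attention compute and the KV cache write still rely on temporary workarounds (see Current Status).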
What This Is
A native NVFP4 inference stack for DeepSeek-V4:
MoE Experts — CuTeDSL ScaledGroupedGemmKernel ✅:
BF16 input → quantize to NVFP4
L1 GEMM: NVFP4 × NVFP4 → BF16 (gate + up)
SiLU(gate) * up → BF16 (the only nonlinear step; BF16 is unavoidable here)
Re-quantize → NVFP4
L2 GEMM: NVFP4 × NVFP4 → BF16 (down_proj)
Scatter with routing weights → BF16 output
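The quantize, SiLU, and re-quantize steps above can be pictured with a small reference sketch. It is plain Python, not the CuTeDSL kernel: it assumes the usual NVFP4 layout (E2M1 values with one shared scale per 16 elements), keeps the block scale in full precision instead of FP8, and breaks rounding ties toward the smaller magnitude. The function names are hypothetical.

```python
import math

# Magnitudes representable by the 4-bit E2M1 format used by NVFP4.
E2M1_VALUES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
BLOCK = 16  # NVFP4 shares one scale across 16 consecutive elements


def fake_quant_nvfp4(x):
    """Quantize then dequantize a flat list of floats, block by block."""
    out = []
    for start in range(0, len(x), BLOCK):
        block = x[start:start + BLOCK]
        amax = max(abs(v) for v in block)
        # Simplification: real kernels store this scale in FP8, not full precision.
        scale = amax / 6.0 if amax > 0 else 1.0
        for v in block:
            mag = min(E2M1_VALUES, key=lambda q: abs(abs(v) / scale - q))
            out.append(math.copysign(mag * scale, v))
    return out


def silu(v):
    return v / (1.0 + math.exp(-v))


def swiglu_requant(gate, up):
    """SiLU(gate) * up in full precision, then re-quantize for the L2 GEMM."""
    act = [silu(g) * u for g, u in zip(gate, up)]
    return fake_quant_nvfp4(act)
```

A reference like this is mainly useful as a ground truth when checking kernel output with a cosine metric; it says nothing about how the fused kernel schedules the work.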
Attention Projections — CuTeDSL NVFP4 GEMM ✅:
- q_a_proj, q_b_proj, kv_proj, wo_b_proj: native NVFP4, cosine 0.995 vs BF16
- wo_a: BF16 BMM (o_a_proj weights are BF16 in checkpoint)
- compressor.kv_proj, compressor.gate_proj: native NVFP4, cosine 0.995 vs BF16
- All verified with tests/test_full_layer_b200.py
Shared Experts — CuTeDSL NVFP4 GEMM ✅:
- gate_up_proj, down_proj: native NVFP4, cosine 0.990 vs BF16
Attention Compute — NEEDS CuTeDSL NVFP4 🔧:
- Currently using pure PyTorch SDPA as a TEMPORARY workaround
- Q×K and attn×V are activation×activation matmuls that CAN be NVFP4
- FlashMLA (vLLM's CUDA kernel) is broken on Blackwell
- Plan: CuTeDSL NVFP4 attention kernel — quantize Q/K to NVFP4, use CuTeDSL GEMMs
KV Cache Write — NEEDS CuTeDSL NVFP4 🔧:
- The SWA KV cache uses the fp8_ds_mla packed format (37376 bytes per slot, not 512)
- C++ kernel fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert is broken on Blackwell
- Currently skipped in Blackwell path (works for prefill, breaks decode)
- Plan: NVFP4 quant + paged cache insert in CuTeDSL
Architecture: DeepSeek-V4-Pro
CSA + HCA + mHC (NOT MLA — vLLM misnames it "MLA" in code):
- CSA (Compress Ratio 4): Compressed Sparse Attention — KV compressed 4x with overlap (coff=2). Indexer finds per-layer top-k.
- HCA (Compress Ratio 128): Heavily Compressed Attention — KV compressed 128x. Top-k indices pre-computed during metadata build.
- mHC: Manifold-Constrained Hyper-Connections — replaces standard residual connections. Learned mixing with Sinkhorn normalization.
- SWA: Sliding Window Attention — local window (compress_ratio=0, last layer only)
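For readers unfamiliar with Sinkhorn normalization (mentioned for mHC above), here is a minimal sketch of the generic algorithm: alternately rescale rows and columns of a strictly positive matrix until both sum to 1. How mHC parameterizes and applies the mixing matrix is not described here, so the function name and iteration count are assumptions.

```python
def sinkhorn(matrix, iterations=20):
    """Alternately normalize rows and columns of a strictly positive matrix."""
    m = [row[:] for row in matrix]
    for _ in range(iterations):
        # Row step: every row sums to 1.
        m = [[v / sum(row) for v in row] for row in m]
        # Column step: every column sums to 1.
        col_sums = [sum(col) for col in zip(*m)]
        m = [[v / s for v, s in zip(row, col_sums)] for row in m]
    return m
```

The result approaches a doubly stochastic matrix, which keeps a learned mixing of residual streams from amplifying or shrinking the overall signal.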
Compress Ratios (from config.json):
Layer 0: 128 (HCA) Layer 1: 128 (HCA) Layer 2: 4 (CSA) Layer 3: 128 (HCA)
Layer 4: 4 (CSA) ...alternating 4/128... Layer 60: 0 (SWA)
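The ratio-to-attention-type mapping above can be written as a small lookup. This sketch only contains the layers listed explicitly; the elided layers are deliberately left out, and the helper name is hypothetical.

```python
# Compress ratios quoted above; layers the listing elides are not filled in.
KNOWN_COMPRESS_RATIOS = {0: 128, 1: 128, 2: 4, 3: 128, 4: 4, 60: 0}


def attention_kind(compress_ratio):
    """Map a per-layer compress ratio to the attention variant it selects."""
    if compress_ratio == 0:
        return "SWA"
    if compress_ratio == 4:
        return "CSA"
    if compress_ratio == 128:
        return "HCA"
    raise ValueError(f"unexpected compress ratio: {compress_ratio}")
```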
Checkpoint Key Names (different from vLLM's internal names):
q_a_proj, q_b_proj, kv_proj (NOT fused_wqa_wkv, wq_b)
q_a_norm (NOT q_norm)
attn_hc.fn/base/scale (MHC attention)
ffn_hc.fn/base/scale (MHC FFN)
compressor.kv_proj, compressor.gate_proj (CSA/HCA compressor)
compressor.position_bias
sinks (attn_sink)
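Because the checkpoint names differ from vLLM's internal names, a quick sanity check on loaded keys can catch a mismatch early. The sketch below flags any key that contains one of the vLLM-internal names listed above; the dotted key prefixes in the usage example are hypothetical.

```python
# Names vLLM uses internally that should NOT appear in the checkpoint.
VLLM_INTERNAL_NAMES = ("fused_wqa_wkv", "wq_b", "q_norm")


def find_internal_names(keys):
    """Return the keys whose dotted path contains a vLLM-internal component."""
    flagged = []
    for key in keys:
        parts = key.split(".")
        # Compare whole components so that q_a_norm is not mistaken for q_norm.
        if any(part in VLLM_INTERNAL_NAMES for part in parts):
            flagged.append(key)
    return flagged
```

For example, `find_internal_names(["layers.0.attn.q_a_norm.weight", "layers.0.attn.wq_b.weight"])` flags only the second key.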
Current Status: Attention + KV Cache Need CuTeDSL 🔧
What works (verified on B200):
- CuTeDSL NVFP4 linear kernels: cosine 0.989–0.999 vs BF16 ✅
- CuTeDSL NVFP4 MoE: cosine 0.988 ✅
- Full attention path with PyTorch SDPA: cosine 0.988 vs BF16 ✅
- MHC, RMS norm, RoPE (BF16), wo_a BMM, shared experts ✅
- Compressor + indexer (Triton, works on SM100) ✅
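The cosine figures quoted throughout compare a flattened NVFP4 output against the BF16 reference. A minimal version of that metric, written here over plain Python lists rather than tensors, looks like this (the function name is an assumption, not the test suite's helper):

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)
```

Cosine similarity ignores overall magnitude, so a result near 1.0 shows the direction matches but does not by itself rule out a global scaling error.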
What's broken:
- FlashMLA CUDA kernel → garbage on Blackwell → model outputs immediate EOS
- fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert C++ kernel → crashes on Blackwell
- Pure PyTorch SDPA is a TEMPORARY workaround: must replace with CuTeDSL NVFP4
What needs to be built:
1. CuTeDSL NVFP4 Attention Kernel
- Quantize Q and K to NVFP4 per-head
- Use CuTeDSL GEMM for Q×K and attn×V
- Support prefill (batched) and decode (single-token) paths
- Handle CSA sparse gather (attend to top-k positions only)
- This is exactly what FlashMLA does with FP8 — we just use NVFP4 instead
- Test first: build standalone test in tests/ with real weights
2. CuTeDSL NVFP4 KV Cache Insert
- Replace C++ fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert
- Per-head RMS norm on Q + GPT-J RoPE on Q + RoPE on KV + NVFP4 quant + paged cache write
- The SWA cache uses the fp8_ds_mla packed format: row width = 37376 bytes
- Layout: [nope_dim FP8 values | rope_dim FP8 values | UE8M0 scale blocks]
- NOT just [head_dim] — it's a packed FP8 format with interleaved scales
- Option A: Understand and write the fp8_ds_mla format from CuTeDSL
- Option B: Use our own NVFP4 cache format (simpler, more efficient, but diverges from vLLM)
3. CuTeDSL Fused RoPE + Norm Kernel
- Currently pure PyTorch (works, but slow)
- Fuse: Q norm → RoPE → NVFP4 quant → all in one pass
- Same for KV side: RoPE → NVFP4 quant → cache write
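As a reference for what the fused kernel has to reproduce on the Q side, here is an unfused sketch of RMS norm followed by GPT-J style RoPE (adjacent pairs rotated together). The learned norm weight is omitted, and `eps` and `base` are placeholder values, not taken from the model config.

```python
import math


def rms_norm(x, eps=1e-6):
    """Scale x to unit root-mean-square (learned weight omitted)."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]


def rope_gptj(x, position, base=10000.0):
    """GPT-J style RoPE: rotate each adjacent pair (x[2i], x[2i+1])."""
    dim = len(x)
    out = [0.0] * dim
    for i in range(dim // 2):
        theta = position * base ** (-2.0 * i / dim)
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[2 * i], x[2 * i + 1]
        out[2 * i] = a * c - b * s
        out[2 * i + 1] = a * s + b * c
    return out
```

Each rotation preserves the vector norm, so the RoPE step never changes the amax that a following quantization step sees by more than the per-pair mixing.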
4. CuTeDSL CSA Sparse Gather
- Currently torch.gather (slow, not GPU-optimal)
- CuTeDSL can do the gather + GEMM in one fused operation
- The whole point of CSA is sparse KV access — we should do it right
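The gather-then-attend pattern can be sketched for a single query as follows. This is a plain Python reference, not the planned fused kernel: it assumes standard scaled dot-product attention over the selected positions and ignores the attention sinks and any compression details.

```python
import math


def sparse_attention(q, keys, values, topk_idx):
    """Single-query attention restricted to the positions in topk_idx."""
    # Gather step: pick only the top-k key/value rows.
    k_sel = [keys[i] for i in topk_idx]
    v_sel = [values[i] for i in topk_idx]
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_sel]
    # Subtract the max score for a numerically stable softmax.
    peak = max(scores)
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    dim = len(v_sel[0])
    return [sum(w * v[d] for w, v in zip(weights, v_sel)) / total for d in range(dim)]
```

A fused kernel would avoid materializing `k_sel` and `v_sel`, which is where the unfused gather spends its memory bandwidth.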
vLLM Integration
The Blackwell detection and dispatch is in vllm/patches/deepseek_v4_attention.py:
- attention_impl() detects SM100+ → _attention_impl_blackwell()
- Currently uses pure PyTorch (SDPA): must replace with CuTeDSL
- The dispatch is INSIDE the torch.ops.vllm.deepseek_v4_attention custom op boundary (important for torch.compile)
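The SM100+ check can be illustrated with a small sketch. It takes the `(major, minor)` tuple that `torch.cuda.get_device_capability()` returns, so it runs without a GPU; the helper names and the `"default"` label are hypothetical and do not mirror the actual patch.

```python
def is_blackwell(capability):
    """capability is a (major, minor) tuple, e.g. from torch.cuda.get_device_capability()."""
    major, _minor = capability
    # SM100 corresponds to compute capability 10.0.
    return major >= 10


def select_attention_path(capability):
    """Return the name of the attention implementation to dispatch to."""
    return "_attention_impl_blackwell" if is_blackwell(capability) else "default"
```

Keeping the check on a plain tuple makes the dispatch decision easy to unit-test away from the hardware.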
Config issues:
- quant_method: modelopt → vLLM uses ModelOpt's NVFP4 handler
- Our CuTeDSL IS registered (via register_cutedsl_kernel.py) and forced with VLLM_NVFP4_GEMM_BACKEND=cutedsl
- FlashMLA hard assertion in DeepseekV4MLAAttention.__init__: patched with _is_blackwell flag
- kv_cache_scheme: {"num_bits": 8, "type": "float"} → FP8 KV cache → FlashMLA (broken on Blackwell)
Key discovery: the warmup global scale (gs) is irrelevant. The CuTeDSL runner recomputes the activation global scale internally on every call, so changing the warmup value by 10x has no meaningful effect on the output (cosine 0.9993). The input_scale from the checkpoint is NOT the activation global scale; it is a calibration constant.
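To make the point concrete, here is a toy model of a runner that derives the global scale from the input on every call. The formula (FP8 max times FP4 max, divided by the tensor's amax) is a common NVFP4 convention and is an assumption here; the actual CuTeDSL runner may compute it differently. Function names are hypothetical.

```python
FP4_MAX = 6.0    # largest E2M1 magnitude
FP8_MAX = 448.0  # largest E4M3 magnitude


def activation_global_scale(x):
    """Derive the global scale from the activations themselves."""
    amax = max(abs(v) for v in x)
    return (FP8_MAX * FP4_MAX) / amax if amax > 0 else 1.0


def effective_global_scale(x, warmup_gs=None):
    # warmup_gs is accepted but never read, mirroring the behavior described above.
    return activation_global_scale(x)
```

Under this model any value passed as `warmup_gs` yields the same scale, which is consistent with the 10x experiment leaving the output unchanged.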
Test Files
| Test | What it does | Status |
|---|---|---|
| tests/test_full_layer_b200.py | All NVFP4 projections vs BF16 (layer 0) | ✅ All pass (0.989–0.999) |
| tests/test_model_forward_b200.py | Warmup gs vs dynamic gs diagnostic | ✅ Warmup gs irrelevant |
| tests/test_csa_attention_b200.py | Full attention path with SDPA | ✅ cosine 0.988 |
| tests/layertest.py | MoE layer test | ✅ cosine 0.988 |
| tests/cudagraph_test.py | CUDAGraph compatibility | ✅ PASS |
| tests/test_shared_expert.py | Shared expert standalone | ✅ cosine 0.990 |
Project Structure
nvfp4-megamoe-kernel/
├── cutedsl/  # CuTeDSL kernel + bridge layer
│   ├── bridge.py  # Tensor layout conversion, quantization, kernel launch
│   ├── nvfp4_linear.py  # CuTeDSLNvfp4Linear: NVFP4 GEMM runner
│   ├── moe_pipeline.py  # Full MoE pipeline (L1→SiLU→L2→scatter)
│   ├── shared_expert_pipeline.py  # Shared expert pipeline (1-expert MoE variant)
│   ├── csa_attention.py  # CSA/HCA attention (currently SDPA, needs CuTeDSL)
│   ├── custom_ops.py  # torch.autograd wrappers for compile boundary
│   └── kernel/moe/  # NVIDIA's ScaledGroupedGemmKernel
├── vllm/  # vLLM integration
│   ├── nvfp4_cutedsl.py  # CuTeDSLMoERunner: cudagraph-safe MoE kernel
│   ├── cutedsl_quant_method.py  # CuTeDSLNvfp4LinearMethod: vLLM quant method
│   ├── kernels/linear/nvfp4/cutedsl.py  # CuTeDSLNvFp4LinearKernel: vLLM kernel registration
│   └── patches/
│       ├── deepseek_v4.py  # Model patch (NVFP4 native, MHC, MoE)
│       ├── deepseek_v4_attention.py  # Attention patch (Blackwell dispatch)
│       ├── layers/
│       │   ├── mhc.py  # MHC pure PyTorch (replaces TileLang)
│       │   ├── csa_attention.py  # CSA attention (TEMPORARY, needs CuTeDSL)
│       │   └── deepseek_compressor.py  # Compressor (Triton, works on SM100)
│       └── fused_moe/experts/cutedsl_moe.py  # MoE CuTeDSL integration
├── tests/  # Standalone tests (run on B200 outside container)
└── Dockerfile  # Container build
Plan
Phase 1: MoE Kernel ✅ DONE
- CuTeDSL ScaledGroupedGemmKernel with NVFP4
- Full pipeline: cosine 0.988, cudagraph-safe
Phase 2: NVFP4 Linear Kernels ✅ DONE
- All attention projections: cosine 0.995
- Shared experts: cosine 0.990
- Compressor projections: cosine 0.995
Phase 3: vLLM Integration ✅ DONE (with PyTorch fallback)
- CuTeDSL kernels registered and working for all NVFP4 linear layers
- Blackwell dispatch in attention_impl
- MHC pure PyTorch
- MoE CuTeDSL
Phase 4: CuTeDSL NVFP4 Attention 🔧 NEXT
- Replace pure PyTorch SDPA with CuTeDSL NVFP4 GEMMs for Q×K and attn×V
- NVFP4 KV cache insert (replace C++ kernel)
- Fused RoPE + norm + quant kernel
- CSA sparse gather in CuTeDSL
- Test each component standalone before integrating into vLLM
Phase 5: Production
- End-to-end benchmarking
- Optimize tile sizes for occupancy
- Clean up old C++ kernel code