NVFP4 MegaMoE Kernel
Native NVFP4 inference stack for DeepSeek-V4 on NVIDIA Blackwell (SM100). CuTeDSL kernels for the entire model — MoE experts, shared experts, attention projections — running in native NVFP4 with zero dequantization overhead.
⚠️ THE #1 RULE
WE OWN ALL OUR KERNELS. WE DO NOT PATCH vLLM.
vLLM's internal kernels (FlashMLA, fp8_ds_mla, fused compressor, Triton indexer) are deeply coupled. You cannot swap one piece and expect the rest to work. We build our own CuTeDSL kernels, test standalone, then wire into vLLM as an attention backend.
Repository Layout
This repo (nvfp4-megamoe-kernel): The kernel library — CuTeDSL kernels, bridge layer, standalone tests.
vLLM fork (vllm-deepseekv4-nvfp4): The vLLM integration — model definition, weight loading, attention backend. Lives at /root/dsv4-nvfp4-workspace/vllm on the B200.
Workspace (/root/dsv4-nvfp4-workspace):
- kernel/: clone of this repo
- vllm/: clone of the vLLM fork
What We Have
✅ CuTeDSL NVFP4 Grouped GEMM (the building block)
ScaledGroupedGemmKernel in cutedsl/kernel/moe/torch_scaled_grouped_mm.py — a production-grade NVFP4 grouped GEMM kernel:
- 2D×3D scenario: A(M,K) × B(E,K,N) → C(M,N)
- Block-scaled: per-16-element FP8 scales on both A and B sides
- Global scales (per-expert) for full dynamic range
- Persistent scheduler, TMA pipelining, SMEM swizzle
- CUDAGraph-safe (workspace pre-allocated, no runtime allocations)
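The pure-Python sketch below shows what the 2D×3D contract computes once quantization is set aside. Rows of A are grouped contiguously by expert, and each group is multiplied by that expert's K×N slice of B. The function name and the `group_sizes` argument are hypothetical. The real kernel's group-layout arguments may differ, and block scales and global scales are not modeled.

```python
def grouped_mm_reference(a, b, group_sizes):
    """Unquantized 2D x 3D grouped GEMM: C[rows of group e] = A[rows] @ B[e].

    a: M x K nested lists, rows already sorted by expert
    b: E x K x N nested lists, one K x N weight per expert
    group_sizes: hypothetical per-expert row counts, summing to M
    """
    assert len(group_sizes) == len(b) and sum(group_sizes) == len(a)
    c = []
    row = 0
    for e, count in enumerate(group_sizes):
        w = b[e]
        k_dim, n_dim = len(w), len(w[0])
        for r in a[row:row + count]:
            # One output row: dot product of the activation row with each column of B[e].
            c.append([sum(r[k] * w[k][n] for k in range(k_dim)) for n in range(n_dim)])
        row += count
    return c


# Two experts: expert 0 is the identity, expert 1 doubles its input.
a = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
b = [[[1.0, 0.0], [0.0, 1.0]], [[2.0, 0.0], [0.0, 2.0]]]
print(grouped_mm_reference(a, b, [1, 2]))  # [[1.0, 2.0], [6.0, 8.0], [10.0, 12.0]]
```

An expert that receives no tokens simply has a group size of 0 and contributes no rows.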
✅ Bridge Layer (cutedsl/bridge.py)
- quantize_to_nvfp4(): BF16 → NVFP4 with global scale
- quantize_activation_nvfp4(): cudagraph-safe quantize (pre-computed gs)
- quantize_weight_to_nvfp4(): weight quantization (along K dim)
- interleave_l1_weights() / deinterleave_l1_weights(): gate/up interleave at granularity 8 BF16
- make_b_k_major(): B tensor stride conversion
- assemble_scales_2d_side() / assemble_scales_3d_side(): scale assembly + swizzle
- warmup_compilation() / warmup_fused_swiglu_compilation(): eager JIT compilation
- run_nvfp4_grouped_gemm() / run_fused_swiglu_grouped_gemm(): kernel entry points
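The minimal pure-Python model below is a reference for what the quantize helpers produce: 16-element blocks, a per-block scale and a per-tensor global scale (gs). It makes three assumptions. It uses the common NVFP4 convention gs = 448 × 6 / amax (E4M3 max × E2M1 max). It keeps block scales as Python floats instead of casting them to FP8 E4M3. It breaks rounding ties toward the smaller magnitude rather than using round-to-nearest-even. It sketches the math and is not a bit-exact stand-in for quantize_to_nvfp4(). All names are hypothetical.

```python
import math

FP4_MAX = 6.0          # largest E2M1 magnitude
FP8_E4M3_MAX = 448.0   # largest E4M3 magnitude
BLOCK = 16             # NVFP4 block size
E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def round_to_e2m1(v):
    # Nearest E2M1 value; ties go to the smaller magnitude (hardware uses RNE).
    mag = min(E2M1_GRID, key=lambda g: abs(abs(v) - g))
    return math.copysign(mag, v)


def quantize_nvfp4_reference(x, gs=None):
    """Quantize a flat list (length divisible by 16) -> (fp4_values, block_scales, gs).

    Pass a pre-computed gs to mimic the cudagraph-safe activation path.
    """
    assert len(x) % BLOCK == 0
    if gs is None:
        amax = max(abs(v) for v in x)
        gs = FP8_E4M3_MAX * FP4_MAX / amax if amax > 0 else 1.0
    values, scales = [], []
    for i in range(0, len(x), BLOCK):
        block = x[i:i + BLOCK]
        # Block scale in FP8 space (at most 448 with this gs); a real
        # implementation casts it to E4M3 here.
        s = max(abs(v) for v in block) / FP4_MAX * gs
        scales.append(s)
        values.extend([round_to_e2m1(v * gs / s) if s > 0 else 0.0 for v in block])
    return values, scales, gs


def dequantize_nvfp4_reference(values, scales, gs):
    # value * block_scale / gs recovers the original magnitude.
    return [v * scales[i // BLOCK] / gs for i, v in enumerate(values)]


codes, block_scales, global_scale = quantize_nvfp4_reference([6.0, 3.0, 1.5, 0.0] * 4)
print(global_scale, block_scales)  # 448.0 [448.0]
```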
✅ MoE Runner (cutedsl/runner.py)
CuTeDSLMoERunner — runs the MoE forward pass:
- Quantize input BF16 → NVFP4 (using pre-computed gs)
- L1 GEMM: NVFP4 × NVFP4 → BF16 (gate+up interleaved, de-interleave then split)
- SiLU(gate) * up → BF16 (PyTorch — being replaced by fused kernel)
- Re-quantize BF16 → NVFP4
- L2 GEMM: NVFP4 × NVFP4 → BF16 (down_proj)
- Scatter with routing weights
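The runner's steps can be checked against a slow, unquantized reference. The sketch below is hypothetical: its names, its shapes and its plain [gate | up] weight layout (no interleave) are assumptions. It runs each routed token through L1, SwiGLU and L2, then accumulates the routing-weighted outputs. It also counts tokens per expert with a plain tally, which gives the same counts as torch.bincount in the runner.

```python
import math


def silu(v):
    # x * sigmoid(x), with sigmoid written via tanh to avoid exp overflow.
    return v * 0.5 * (1.0 + math.tanh(0.5 * v))


def matvec(w, v):
    # w: K x N nested lists, v: length-K vector -> length-N vector
    return [sum(v[k] * w[k][n] for k in range(len(v))) for n in range(len(w[0]))]


def moe_forward_reference(x, topk_ids, topk_weights, w13, w2, num_experts):
    """Unquantized per-token MoE forward.

    x: T x H activations; topk_ids / topk_weights: T x top_k
    w13[e]: H x 2I with columns [gate | up] (not interleaved); w2[e]: I x H
    """
    counts = [0] * num_experts   # tokens per expert (torch.bincount in the runner)
    for ids in topk_ids:
        for e in ids:
            counts[e] += 1
    out = [[0.0] * len(row) for row in x]
    for t, (ids, wts) in enumerate(zip(topk_ids, topk_weights)):
        for e, rw in zip(ids, wts):
            h = matvec(w13[e], x[t])                                   # L1 GEMM -> 2I
            inter = len(h) // 2
            act = [silu(g) * u for g, u in zip(h[:inter], h[inter:])]  # SwiGLU
            y = matvec(w2[e], act)                                     # L2 GEMM (down_proj) -> H
            out[t] = [o + rw * yv for o, yv in zip(out[t], y)]         # weighted scatter
    return out, counts


x = [[1.0]]
w13 = [[[0.0, 3.0]], [[2.0, 1.0]]]   # expert 0: gate = 0, so its SwiGLU output is 0
w2 = [[[5.0]], [[1.0]]]
out, counts = moe_forward_reference(x, [[0, 1]], [[0.5, 0.5]], w13, w2, num_experts=2)
print(counts)  # [1, 1]
```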
✅ NVFP4 Linear (cutedsl/nvfp4_linear.py)
CuTeDSLNvfp4Linear — single-expert NVFP4 GEMM for shared experts and attention projections.
✅ Fused SwiGLU Kernel (Stage 1: BF16 output)
fused_swiglu_grouped_mm.py — extends ScaledGroupedGemmKernel with a fused SwiGLU epilogue:
- Weight interleave: L1 gate/up weights interleaved at granularity 8 BF16
- epi_tile=(128, 8): each 8-wide subtile is pure gate or pure up
- Subtile-level pairing: even subtiles = gate (compute SiLU, save to register buffer), odd subtiles = up (load SiLU(gate) from buffer, compute silu(gate)*up)
- Stage 1 DONE: BF16 output with SwiGLU, cosine similarity 0.977 vs BF16 reference
- Stage 2 NEXT: NVFP4 quantize in epilogue, direct FP4 TMA store for L2
Correctness Bugs Fixed (May 20, 2026)
All 5 bugs fixed, committed, pushed:
| Bug | Issue | Fix |
|---|---|---|
| 1 | _needs_token_refill myth: cute.compile doesn't corrupt GPU memory | Removed hack, added warmup_compilation(), pre-allocated workspace per cache entry |
| 2 | Dequantize→requantize supposedly lossy | Verified 100% byte-identical round-trip. Deprecated prepare_weights_from_dequantized |
| 3 | clamp(min=1e-8) on zero blocks gives nonzero FP8 scale | Detect zero blocks, force FP8 scale to exact 0 |
| 4 | Underflow blocks (amax < 6×2⁻⁹) get nonzero FP4 from div-by-tiny-number | Detect underflow blocks, zero x_norm before division |
| 5 | Expert counting materializes 18M bool tensor | torch.bincount replaces O(n×E) comparison |
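Bugs 3 and 4 come down to two guards in the block-scale path. The sketch below models them on a single 16-element block and applies the table's threshold amax < 6 × 2⁻⁹ directly to amax (2⁻⁹ is the smallest positive E4M3 subnormal). It is a simplification: it ignores the global scale and FP8 rounding, and the function name is hypothetical.

```python
import math

FP4_MAX = 6.0
E4M3_MIN_SUBNORMAL = 2.0 ** -9
UNDERFLOW_AMAX = FP4_MAX * E4M3_MIN_SUBNORMAL   # 6 x 2^-9
E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def round_to_e2m1(v):
    # Nearest E2M1 value; ties go to the smaller magnitude.
    mag = min(E2M1_GRID, key=lambda g: abs(abs(v) - g))
    return math.copysign(mag, v)


def quantize_block_guarded(block):
    """Return (block_scale, fp4_values) for one 16-element block."""
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        # Bug 3 guard: clamp(min=1e-8) would give a tiny nonzero scale here.
        # Force the scale to exact 0 instead.
        return 0.0, [0.0] * len(block)
    scale = amax / FP4_MAX   # unrounded; a real kernel casts this to E4M3
    if amax < UNDERFLOW_AMAX:
        # Bug 4 guard: the scale underflows in E4M3, and dividing by it would
        # turn tiny inputs into nonzero FP4 values. Zero them instead.
        return scale, [0.0] * len(block)
    return scale, [round_to_e2m1(v / scale) for v in block]


print(quantize_block_guarded([1e-4] * 16)[1][:4])  # [0.0, 0.0, 0.0, 0.0]
```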
Fused SwiGLU — How It Works
The Problem
The L1 GEMM produces (M, 2×intermediate) BF16 output with gate and up columns side by side. SwiGLU needs silu(gate)*up, producing (M, intermediate). In the unfused path, this requires:
- ~580MB BF16 write to GMEM (L1 output)
- ~290MB BF16 read back (for gate/up split + SiLU)
- 3 kernel launches + 12 quantize ops
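The ~580MB and ~290MB figures are consistent with MiB arithmetic under one assumption: M = 49,152 routed rows (for example 8,192 tokens × top-6). The source does not state the M behind these numbers. Under that assumption the full L1 output is 576 MiB, and one half of it (gate or up alone) is 288 MiB.

```python
# Assumption: M = 49,152 routed rows (8,192 tokens x top-6); the source does not state M.
tokens, top_k, intermediate = 8192, 6, 3072
bf16_bytes = 2

m = tokens * top_k                                   # rows entering the L1 GEMM
l1_out_bytes = m * 2 * intermediate * bf16_bytes     # (M, 2*intermediate) BF16 output
half_bytes = l1_out_bytes // 2                       # gate half or up half alone

print(m, l1_out_bytes / 2**20, half_bytes / 2**20)   # 49152 576.0 288.0
```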
The Solution: Granularity-8 Weight Interleave + Subtile Pairing
Key insight: With interleave_l1_weights(), gate and up weight columns are interleaved at granularity 8 BF16. In the GEMM output, every 8 BF16 columns alternate: [gate₀-₇, up₀-₇, gate₈-₁₅, up₈-₁₅, ...].
With epi_tile_n=8, each epilogue subtile covers exactly 8 BF16 N-columns. So each subtile is pure gate or pure up — no mixing. Even subtile indices = gate, odd = up.
The epilogue loop processes gate/up pairs:
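```python
# Epilogue pseudocode. alloc_register_buffer, load_accumulator, silu,
# store_to_smem and tma_store_to_gmem are placeholders for the kernel's
# epilogue steps, not real CuTeDSL calls.
silu_gate_buf = alloc_register_buffer()      # initialized before the loop
for subtile_idx in range(subtile_cnt):
    acc_vec = load_accumulator(subtile_idx)  # one 8-column subtile
    if subtile_idx % 2 == 0:                 # even subtile: pure gate
        silu_result = silu(acc_vec)
        silu_gate_buf = silu_result          # save to register buffer
        acc_vec_bf16 = silu_result
    else:                                    # odd subtile: pure up
        gate_vals = silu_gate_buf            # SiLU(gate) from previous iteration
        acc_vec_bf16 = gate_vals * acc_vec   # SwiGLU
    store_to_smem(acc_vec_bf16)
    tma_store_to_gmem()
```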
No runtime conditional affects tensor structure. The silu_gate_buf is a register buffer initialized before the loop. Both branches produce acc_vec_bf16 of the same type.
The output has interleaved [silu(gate), silu(gate)*up] at granularity 8. De-interleave recovers the standard [silu(gate) | silu(gate)*up] layout. The up columns contain the SwiGLU result.
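The pairing can be checked end to end on the CPU. The pure-Python sketch below builds one interleaved L1 output row and runs a loop shaped like the epilogue above over 8-wide subtiles. It then de-interleaves the result and compares the up half against silu(gate)*up. Interleaving the weight columns permutes the GEMM output columns in the same way, so the sketch interleaves an already-computed gate/up row directly. All names are hypothetical.

```python
import math

G = 8  # interleave granularity == epi_tile_n


def silu(v):
    return v * 0.5 * (1.0 + math.tanh(0.5 * v))


def interleave(gate, up, g=G):
    # [gate0-7, up0-7, gate8-15, up8-15, ...]
    row = []
    for i in range(0, len(gate), g):
        row += gate[i:i + g] + up[i:i + g]
    return row


def fused_epilogue(row, g=G):
    """Mimic the subtile loop on one interleaved L1 output row."""
    out = []
    silu_gate_buf = [0.0] * g          # initialized before the loop
    for idx in range(len(row) // g):
        acc = row[idx * g:(idx + 1) * g]
        if idx % 2 == 0:               # even subtile: gate
            silu_gate_buf = [silu(v) for v in acc]
            out += silu_gate_buf
        else:                          # odd subtile: up
            out += [s * u for s, u in zip(silu_gate_buf, acc)]
    return out


def deinterleave(row, g=G):
    # Inverse of interleave(): returns (first half, second half).
    first, second = [], []
    for i in range(0, len(row), 2 * g):
        first += row[i:i + g]
        second += row[i + g:i + 2 * g]
    return first, second


gate = [float(i - 8) for i in range(16)]
up = [0.5 * i for i in range(16)]
silu_gate, swiglu = deinterleave(fused_epilogue(interleave(gate, up)))
print(swiglu == [silu(a) * b for a, b in zip(gate, up)])  # True
```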
The //2 Bug in interleave_l1_weights
The original function had g = granularity_bf16 // 2, which is correct for K-axis interleave: FP4 byte-packing stores 2 values per byte along K, so one packed element spans 2 BF16 columns. But we interleave along N, where each N-column is exactly 1 BF16 column. The //2 was a leftover: with granularity_bf16 = 8 it silently set g = 4, so gate and up alternated every 4 columns instead of every 8. Fixed: g = granularity_bf16 (no //2).
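The effect of the stray //2 is easy to see at the column-index level. The sketch below does not use interleave_l1_weights() itself. Instead, hypothetical helpers build the interleaved N-axis column order for a toy intermediate size of 64 and check whether every 8-wide epilogue subtile is pure gate or pure up.

```python
def interleaved_column_order(intermediate, g):
    """Source column ids after gate/up interleave along N.

    Columns 0..I-1 are gate, I..2I-1 are up; assumes I is divisible by g.
    """
    order = []
    for start in range(0, intermediate, g):
        order += range(start, start + g)                                # gate block
        order += range(intermediate + start, intermediate + start + g)  # up block
    return order


def subtiles_are_pure(order, intermediate, epi_tile_n=8):
    # True if every epi_tile_n-wide subtile holds only gate or only up columns.
    for i in range(0, len(order), epi_tile_n):
        roles = {col < intermediate for col in order[i:i + epi_tile_n]}
        if len(roles) != 1:
            return False
    return True


granularity_bf16 = 8
buggy = interleaved_column_order(64, granularity_bf16 // 2)  # old code: g = 4
fixed = interleaved_column_order(64, granularity_bf16)       # fixed:    g = 8
print(subtiles_are_pure(buggy, 64), subtiles_are_pure(fixed, 64))  # False True
```

With g = 4, every 8-wide subtile holds 4 gate and 4 up columns, so no parity rule can pair them correctly.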
CuTeDSL Runtime Conditionals
CuTeDSL does support runtime conditionals on register tensors — the rule is that both branches must produce the same tensor type (shape, layout, dtype). The earlier "blocked by type system" framing was wrong. The real issue was that the old code applied SiLU to ALL positions (just SiLU, not SwiGLU) and used is_gate_subtile < num_gate_subtiles which doesn't work with interleaved weights. With epi_tile_n=8 and subtile-level pairing, the conditional is clean: both branches produce acc_vec_bf16 of the same BF16 type.
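A small index-level sketch shows why the old check failed once the weights are interleaved. It models the old rule as "the first half of the subtiles are gate", an approximation of the is_gate_subtile < num_gate_subtiles test. It then compares that rule with the parity rule used by the fused kernel, for a hypothetical tile with 8 subtiles.

```python
def role_old(idx, subtile_cnt):
    # Old rule (modeled): first half of the subtiles is gate, as for a [gate | up] layout.
    return "gate" if idx < subtile_cnt // 2 else "up"


def role_interleaved(idx):
    # Fused-kernel rule with granularity-8 interleave and epi_tile_n = 8.
    return "gate" if idx % 2 == 0 else "up"


subtile_cnt = 8  # hypothetical subtiles per tile
wrong = [i for i in range(subtile_cnt) if role_old(i, subtile_cnt) != role_interleaved(i)]
print(wrong)  # [1, 3, 4, 6]
```

Half of the subtiles get the wrong role, which is why the pairing has to follow subtile parity.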
Fused SwiGLU — Remaining Steps
| Step | What | Status |
|---|---|---|
| 1 | Wire fused kernel into pipeline (skip BF16 GMEM round-trip) | 🔄 In progress |
| 2 | NVFP4 quantize in epilogue (per-16-element amax, FP8 SF, FP4 pack) | 🔨 Next |
| 3 | FP4 TMA store to padded L2 buffer | Not started |
| 4 | FP8 SF TMA store through blockscaled layout | Not started |
| 5 | End-to-end test with fused pipeline | Not started |
DeepSeek-V4 Architecture Notes
NOT MLA. DeepSeek-V4 uses:
- CSA (Compressed Sparse Attention, cr=4): KV compressed 4x, indexer finds top-k
- HCA (Heavily Compressed Attention, cr=128): KV compressed 128x, pre-computed indices
- SWA: Standard sliding window (window=128, last layer only)
- mHC: Manifold-Constrained Hyper-Connections — replaces residual connections
- 384 experts, top-6, intermediate=3072
Compression ratios by layer: alternating 128 and 4, with layer 60 = 0 (SWA).
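The per-layer schedule can be written down as a small config helper. This is a sketch under stated assumptions: 61 layers (inferred from "layer 60 = 0") and layer 0 starting at 128. The helper names are hypothetical, and the real model config may express the schedule differently.

```python
ATTENTION_BY_RATIO = {4: "CSA", 128: "HCA", 0: "SWA"}  # from the notes above


def compress_ratios(num_layers=61, first=128):
    """Alternating 128/4 per layer, last layer 0 (SWA).

    Assumptions: 61 layers (layer 60 is last) and layer 0 uses `first`.
    """
    other = 4 if first == 128 else 128
    ratios = [first if i % 2 == 0 else other for i in range(num_layers - 1)]
    return ratios + [0]


ratios = compress_ratios()
print(len(ratios), ratios[:4], ratios[-1], ATTENTION_BY_RATIO[ratios[-1]])
# 61 [128, 4, 128, 4] 0 SWA
```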
File Structure
```
cutedsl/
├── bridge.py                 # Quantization, layout, kernel launch
├── nvfp4_linear.py           # Single-expert NVFP4 GEMM runner
├── runner.py                 # MoE grouped GEMM runner
├── blackwell_attention.py    # KV cache + attention (standalone)
├── csa_attention.py          # CSA/HCA attention
├── custom_ops.py             # torch.autograd wrappers
├── moe_pipeline.py           # Standalone test pipeline
└── kernel/moe/
    ├── torch_scaled_grouped_mm.py    # ScaledGroupedGemmKernel (the GEMM)
    └── fused_swiglu_grouped_mm.py    # FusedSwiGLUScaledGroupedGemmKernel
tests/
├── layertest.py                      # MoE layer test (PASS, 0.988 cosine)
├── cudagraph_test.py                 # CUDAGraph test (PASS)
├── test_full_layer_b200.py           # All NVFP4 projections (PASS, 0.994+)
├── test_v4_attention_b200.py         # All 3 attention types (PASS)
├── test_kv_cache_b200.py             # KV cache (PASS, 0.9997)
├── test_sparse_attn_b200.py          # CSA/HCA (PASS)
├── test_decode_attention_b200.py     # Prefill+decode (PASS, 0.9998)
└── ...
```
Key Lessons (Things We Fucked Up)
- ⛔ NEVER assume CuTeDSL GPU tensors survive JIT compilation. cute.compile zeroes GPU memory. Keep index/mapping tensors on CPU.
- ⛔ NEVER nuke working code without understanding why it exists. The cudagraph-safe functions exist because vLLM REQUIRES cudagraph.
- ⛔ NEVER fabricate facts from MEMORY.md. Verify what "works" means before citing it.
- ⛔ NEVER quantize a padded buffer and slice the output. Quantize compact data, then scatter it into the padded layout.
- ⛔ Silent weight drops are deadly. vLLM's `if name not in params_dict: continue` skips weights with no warning. Replace it with a hard RuntimeError.
- ⛔ NVFP4 is NOT suitable for attention Q×K^T. Per-element dot products are too sensitive. Keep attention in BF16.
- ⛔ NEVER touch drivers, kernels, firmware, or system packages on the B200. The cluster costs millions. Always confirm with Mike.
- ⛔ CuTeDSL `if` branches must produce the same tensor type. Both branches must yield identical (shape, layout, dtype). Initialize variables before the `if`; using values defined only inside a branch is not supported.
- ⛔ The `//2` in interleave was a K-axis leftover. FP4 packing is along K, not N. When interleaving along N, `g = granularity_bf16` (no `//2`). The bug silently gave granularity 4 instead of 8, which would have produced wrong register-level pairing.
- ⛔ "SiLU on all positions" is NOT SwiGLU. SwiGLU pairs silu(gate)*up. Applying SiLU to the full (M, 2×intermediate) output is just SiLU and produces wrong results. The pairing must be explicit.
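The fix for silent weight drops can be sketched as a strict loader loop. Everything here is hypothetical: params_dict is a plain dict, and item assignment stands in for vLLM's per-parameter weight loader. The only point is that an unmatched checkpoint name raises an error instead of hitting `continue`.

```python
def load_weights_strict(params_dict, weights):
    """Load (name, tensor) pairs; raise on any name with no matching parameter.

    Hypothetical sketch: params_dict is a plain dict and item assignment stands in
    for vLLM's per-parameter weight loader.
    """
    loaded = []
    for name, tensor in weights:
        if name not in params_dict:
            # vLLM's loop does `continue` here, silently dropping the weight.
            raise RuntimeError(f"checkpoint weight {name!r} has no matching parameter")
        params_dict[name] = tensor
        loaded.append(name)
    return loaded


params = {"layers.0.mlp.w13": None}
print(load_weights_strict(params, [("layers.0.mlp.w13", "W")]))  # ['layers.0.mlp.w13']
```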