NVFP4 MegaMoE Kernel
Full NVFP4 inference pipeline for DeepSeek-V4 on NVIDIA Blackwell (SM100). MoE experts, shared experts, and nearly all attention projections run in native NVFP4 with no dequantization. The only exceptions are wo_a (converted to FP8 for fp8_einsum) and the compressor (BF16).
What This Is
A native NVFP4 inference stack for DeepSeek-V4:
MoE Experts: CuTeDSL ScaledGroupedGemmKernel (NVIDIA's kernel, driven by our bridge layer and MoE pipeline):
- BF16 input → quantize to NVFP4
- L1 GEMM: NVFP4 × NVFP4 → BF16 (gate + up)
- SiLU(gate) * up → BF16 (the only nonlinear step, so it has to run in BF16)
- Re-quantize → NVFP4
- L2 GEMM: NVFP4 × NVFP4 → BF16 (down_proj)
- Scatter with routing weights → BF16 output
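The expert path above can be sketched in plain Python with the quantization steps left out. This is a hypothetical float reference (moe_forward and its argument layout are illustrative, not the repo's moe_pipeline.py). It groups (token, slot) pairs by expert, assumes the grouped GEMM receives the cumulative end offset of each expert's rows, applies SiLU(gate) * up, and scatters each down-projection back to its token scaled by the routing weight.

```python
import math


def silu(v: float) -> float:
    return v / (1.0 + math.exp(-v))


def matvec(w, x):
    """w: rows x cols nested list, x: length-cols vector -> length-rows vector."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def moe_forward(x, topk_ids, topk_weights, w_gate, w_up, w_down):
    """Unquantized float reference for the expert path (hypothetical helper)."""
    num_experts = len(w_gate)
    # One row per (token, slot) pair, stably sorted so each expert owns a contiguous range.
    rows = sorted(
        ((e, t, s) for t, ids in enumerate(topk_ids) for s, e in enumerate(ids)),
        key=lambda r: r[0],
    )
    counts = [0] * num_experts
    for e, _, _ in rows:
        counts[e] += 1
    offsets, total = [], 0
    for c in counts:  # cumulative end offset of each expert's rows
        total += c
        offsets.append(total)

    out = [[0.0] * len(x[0]) for _ in x]
    for e in range(num_experts):
        start = offsets[e - 1] if e > 0 else 0
        for _, t, s in rows[start:offsets[e]]:
            gate = matvec(w_gate[e], x[t])  # L1 GEMM, gate half (NVFP4 in the real kernel)
            up = matvec(w_up[e], x[t])      # L1 GEMM, up half
            h = [silu(g) * u for g, u in zip(gate, up)]  # the only nonlinear step
            y = matvec(w_down[e], h)        # L2 GEMM (h is re-quantized first in the real kernel)
            weight = topk_weights[t][s]
            out[t] = [o + weight * yi for o, yi in zip(out[t], y)]  # weighted scatter
    return out, offsets
```

Sorting rows by expert is what lets a single grouped GEMM launch treat each expert's tokens as one contiguous matrix slice.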
Attention Projections — FlashInferCutlassNvFp4LinearKernel (vLLM built-in):
- wq_b, wo_b, fused_wqa_wkv: native NVFP4, no conversion
- wo_a: NVFP4 → FP8 for fp8_einsum (the only attention weight that needs conversion)
- Compressor: BF16 (weight_loader stacking issue; small matmul)
Shared Experts — FlashInferCutlassNvFp4LinearKernel (vLLM built-in):
- gate_up_proj, down_proj: native NVFP4
Both GEMM types use float4_e2m1fn_x2 for packed FP4 data, float8_e4m3fn for block scales, and float32 for global scales. Apart from the GEMM outputs, BF16 is used only for the SiLU activation, the final MoE scatter, and the compressor.
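To make the three scale levels concrete, here is a pure-Python emulation of NVFP4 quantization. It is a sketch under stated assumptions: 16-element blocks, a multiplicative per-tensor global scale of amax / (6 × 448), and dequantization as code × block_scale × global_scale. The function names are hypothetical, and the repo's quantize_to_nvfp4() may use a different (e.g. reciprocal) global-scale convention or rounding. Note that block scales can never exceed 448, the float8_e4m3fn maximum.

```python
import math

E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)  # representable FP4 magnitudes
FP4_MAX = 6.0
FP8_E4M3_MAX = 448.0
BLOCK = 16  # elements sharing one float8_e4m3fn block scale (assumed)


def round_to_e4m3(x: float) -> float:
    """Round a non-negative float to the nearest float8_e4m3fn value, saturating at 448."""
    if x <= 0.0:
        return 0.0
    x = min(x, FP8_E4M3_MAX)
    if x < 2.0 ** -6:  # subnormal range: uniform step of 2**-9
        return round(x / 2.0 ** -9) * 2.0 ** -9
    _, exp = math.frexp(x)  # x = m * 2**exp with 0.5 <= m < 1
    step = 2.0 ** (exp - 4)  # 3 mantissa bits -> 8 steps per power of two
    return min(round(x / step) * step, FP8_E4M3_MAX)


def round_to_e2m1(x: float) -> float:
    """Round to the nearest FP4 (E2M1) value; magnitudes above 6 clamp to 6."""
    mag = min(E2M1_GRID, key=lambda g: abs(abs(x) - g))
    return math.copysign(mag, x)


def emulate_nvfp4_quantize(values):
    """Return (fp4 values, per-block scales, global scale) for a 1-D list of floats."""
    amax = max(abs(v) for v in values) or 1.0
    global_scale = amax / (FP4_MAX * FP8_E4M3_MAX)  # tensor amax maps to 6 * 448
    codes, block_scales = [], []
    for start in range(0, len(values), BLOCK):
        block = values[start:start + BLOCK]
        scale = round_to_e4m3(max(abs(v) for v in block) / FP4_MAX / global_scale)
        block_scales.append(scale)
        inv = 1.0 / (scale * global_scale) if scale > 0 else 0.0
        codes.extend(round_to_e2m1(v * inv) for v in block)
    return codes, block_scales, global_scale


def emulate_nvfp4_dequantize(codes, block_scales, global_scale):
    return [q * block_scales[i // BLOCK] * global_scale for i, q in enumerate(codes)]
```

The per-block FP8 scale absorbs local dynamic range, while the FP32 global scale keeps those block scales inside the representable FP8 range.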
How We Got Here
The C++ CUTLASS Kernel Was Broken
The original kernel was a C++ .cu file using CUTLASS's C++ API directly. It passed all the simple tests (uniform data → exact output, SF remap verifier → 0 errors) but produced a cosine of 0.05 on non-uniform random data. After weeks of debugging the SF remap (8+ iterations, all producing the same 0.2 cosine against a wrong reference), we discovered:
- The BF16 reference comparison was wrong: our Python dequantization didn't match CUTLASS's internal FP4 handling. A wrong reference is worse than no reference. We chased ghosts through 8+ SF remap rewrites because the 0.2 cosine was never about the remap.
- The C++ CUTLASS kernel misinterpreted FP4 data: even with the SF remap verified correct (0 byte errors), the GEMM produced garbage on non-uniform data. The problem lay in how FP4 packing/tiling is handled on the C++ API path, and we couldn't easily debug or fix it.
- The checkpoint input_scale was a red herring: we tried using the checkpoint's calibration scale as the activation normalization scale, and it saturated every block scale at 448.0 (the float8_e4m3fn maximum). The input_scale is a calibration constant for alpha computation, not a normalization scale.
The CuTeDSL Kernel Works
NVIDIA's CuTeDSL approach (Python-based CUTLASS kernels compiled via MLIR → PTX) is what the CUTLASS team recommends for Blackwell. Their official MoE scaled grouped GEMM example (torch_scaled_grouped_mm.py) supports NVFP4 out of the box. We adapted it.
Results with real DeepSeek-V4 layer 0 weights:
- L1 GEMM alone: cosine 0.995
- Full MoE pipeline (L1→SiLU→L2→scatter): cosine 0.989
- Weight loading: 0% loss — direct uint8→float4_e2m1fn_x2 view-cast, bit-identical to checkpoint
- Activation quantization: ~1.1% cosine loss (dynamic BF16→NVFP4 — inherent to the format, unavoidable)
- GEMM kernel: 0% loss (CuTeDSL is correct)
The 0.989 cosine is entirely from activation quantization. The weights are bit-identical to the checkpoint — no BF16 round-trip, no precision loss.
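Because the view-cast only reinterprets bytes, each checkpoint uint8 is consumed as two FP4 codes with no arithmetic at all. The sketch below decodes such a byte to show what the kernel sees. It assumes the first element of each pair lives in the low nibble and uses hypothetical helper names; in PyTorch the cast itself is a Tensor.view(dtype) reinterpretation. Decoding and re-encoding every possible byte round-trips exactly, which is the bit-identity property the weight-loading result relies on.

```python
import math

# E2M1 magnitudes indexed by the low 3 bits of a 4-bit code (bit 3 is the sign).
E2M1_VALUES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def decode_fp4_byte(byte: int) -> tuple[float, float]:
    """Unpack one uint8 into two E2M1 values (assumed: low nibble = first element)."""
    def decode(nibble: int) -> float:
        mag = E2M1_VALUES[nibble & 0x7]
        return -mag if nibble & 0x8 else mag  # code 0x8 decodes to -0.0
    return decode(byte & 0xF), decode(byte >> 4)


def encode_fp4_pair(first: float, second: float) -> int:
    """Inverse of decode_fp4_byte for values that are exactly representable."""
    def encode(v: float) -> int:
        sign = 0x8 if math.copysign(1.0, v) < 0 else 0  # preserves -0.0
        return sign | E2M1_VALUES.index(abs(v))
    return encode(first) | (encode(second) << 4)
```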
The Dequant→Requant Anti-Pattern
Early versions dequantized all NVFP4 weights to BF16, then let vLLM's FlashInferCutlassNvFp4LinearKernel requantize them back to NVFP4 at inference time. This:
- Wasted 5 minutes on load doing NVFP4→BF16 conversion
- Lost precision on the double round-trip
- Caused vLLM to hang — the NVFP4 attention kernel expects native NVFP4 weights, not BF16 weights with an NVFP4 quant_method attached
The fix: keep everything in NVFP4. The checkpoint stores NVFP4. The kernels consume NVFP4. No conversion needed.
Key Lessons
- A wrong reference is worse than no reference — the 0.2 cosine against a broken BF16 dequant sent us chasing SF remap bugs for weeks
- The C++ CUTLASS API is a footgun for FP4: CuTeDSL handled tensor layouts, tiling, and scale-factor construction correctly out of the box
- Test with real data early — uniform tests pass even with broken kernels; random data reveals real bugs
- Separate the GEMM from the pipeline: our layertest.py runs without vLLM, Docker, or tensor parallelism. It caught the kernel bug that vLLM's integration layers masked.
- Don't dequant what's already quantized: if the kernel expects NVFP4 and the checkpoint is NVFP4, leave it alone. No BF16 round-trips.
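The uniform-versus-random lesson can be reproduced with a toy GEMM. The hypothetical buggy_kernel below reads B's K dimension in reverse, a stand-in for a wrong stride or scale-factor swizzle. With all-ones inputs every output element is identical, so the bug scores a perfect cosine; with random Gaussian inputs it collapses.

```python
import math
import random


def matmul(a, b):
    """a: M x K, b: K x N (nested lists) -> M x N."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def cosine(x, y):
    fx = [v for row in x for v in row]
    fy = [v for row in y for v in row]
    dot = sum(p * q for p, q in zip(fx, fy))
    return dot / math.sqrt(sum(p * p for p in fx) * sum(q * q for q in fy))


def buggy_kernel(a, b):
    # Hypothetical layout bug: the kernel walks B's K dimension in the wrong order.
    return matmul(a, b[::-1])


def check(a, b):
    return cosine(buggy_kernel(a, b), matmul(a, b))


M, K, N = 8, 64, 8
uniform_a = [[1.0] * K for _ in range(M)]
uniform_b = [[1.0] * N for _ in range(K)]
rng = random.Random(0)
random_a = [[rng.gauss(0.0, 1.0) for _ in range(K)] for _ in range(M)]
random_b = [[rng.gauss(0.0, 1.0) for _ in range(N)] for _ in range(K)]

print(f"uniform data: cosine {check(uniform_a, uniform_b):.3f}")  # 1.000: the bug is invisible
print(f"random data:  cosine {check(random_a, random_b):.3f}")    # far below 1: the bug shows up
```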
Project Structure
```
nvfp4-megamoe-kernel/
├── cutedsl/                              # CuTeDSL kernel + bridge layer
│   ├── bridge.py                         # Tensor layout conversion, quantization, kernel launch
│   ├── moe_pipeline.py                   # Full MoE pipeline (L1→SiLU→L2→scatter)
│   └── kernel/moe/                       # NVIDIA's ScaledGroupedGemmKernel (untouched)
│       ├── torch_scaled_grouped_mm.py    # The working kernel (3900 lines)
│       ├── moe_utils.py
│       ├── moe_persistent_scheduler.py
│       └── moe_sched_extension.py
├── vllm/                                 # vLLM integration
│   ├── nvfp4_cutedsl.py                  # CuTeDSLMoERunner: MoE kernel interface
│   └── patches/
│       ├── deepseek_v4.py                # DeepSeek-V4 model patch (NVFP4 native)
│       └── deepseek_v4_attention.py      # Attention patch (NVFP4 native)
├── src/nvfp4_megamoe_kernel/             # OLD Python pipeline (tagged the-last-of-cutlass)
├── tests/
│   ├── layertest.py                      # Layer 0 comparison: CuTeDSL vs BF16 (✅ cosine 0.989)
│   ├── test_cutedsl.py                   # Small standalone CuTeDSL test (✅ cosine 0.991)
│   ├── test_uniform_fp4.py               # Uniform data GEMM test
│   ├── test_b_layout.py                  # B matrix column layout test
│   └── test_quick_rand.py                # Quick random GEMM sanity check
└── reference/                            # Reference files for study
```
The Bridge Layer (cutedsl/bridge.py)
Handles all tensor layout conversion from our pipeline to what the CuTeDSL kernel expects:
| Function | What it does |
|---|---|
| quantize_to_nvfp4() | BF16 → float4_e2m1fn_x2 + float8_e4m3fn block scales + float32 global scale |
| quantize_weight_to_nvfp4() | Same, but for weight matrices with K as the packed dimension |
| assemble_scales_2d_side() | Pad and swizzle activation scale factors (2Dx3D A side) |
| assemble_scales_3d_side() | Pad and swizzle weight scale factors (2Dx3D B side) |
| make_b_k_major() | Convert B tensor from N-major to K-major strides (required by kernel) |
| compute_expert_offsets() | Compute cumulative token offsets for grouped GEMM |
| run_nvfp4_grouped_gemm() | Full kernel launch (compile + run) |
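For the assemble_scales_2d_side() and assemble_scales_3d_side() rows, the pad-and-swizzle step can be sketched as below. This assumes the 128×4 scale-factor tile layout that cuBLASLt documents for Blackwell block-scaled GEMMs; the layout the CuTeDSL kernel actually expects is defined by NVIDIA's example, and swizzle_block_scales is a hypothetical name. Each tile covers 128 rows by 4 scale columns (64 FP4 values along K) and is stored as a 32 × 16 block.

```python
def ceil_div(a: int, b: int) -> int:
    return -(-a // b)


def swizzle_block_scales(scales):
    """Pad a rows x cols grid of block scales to multiples of (128, 4) and flatten it
    in the assumed 128x4-tile order (hypothetical helper).

    Tiles are stored row-tile-major, then column-tile. Inside a tile, scale (r, c)
    lands at (r % 32) * 16 + (r // 32) * 4 + c, i.e. a 32 x 16 block.
    Padding entries are zero.
    """
    rows, cols = len(scales), len(scales[0])
    row_tiles, col_tiles = ceil_div(rows, 128), ceil_div(cols, 4)
    out = [0.0] * (row_tiles * col_tiles * 512)
    for r in range(rows):
        for c in range(cols):
            tile = (r // 128) * col_tiles + c // 4
            rr, cc = r % 128, c % 4
            out[tile * 512 + (rr % 32) * 16 + (rr // 32) * 4 + cc] = scales[r][c]
    return out
```

Rows 0, 32, 64 and 96 of a tile end up adjacent in memory, which is exactly the kind of remap that a uniform-data test cannot distinguish from a plain row-major copy.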
Running Tests
On the B200:
```bash
cd /root/nvfp4-megamoe-kernel/tests
source .venv/bin/activate
# Small standalone test
python3 test_cutedsl.py
# Full layer 0 comparison with real weights
python3 layertest.py
```
NVFP4 Coverage
| Component | Format | Kernel | Conversion? |
|---|---|---|---|
| MoE experts (L1+L2) | NVFP4 native | CuTeDSL ScaledGroupedGemm | No — direct uint8→float4 view-cast |
| Shared experts | NVFP4 native | FlashInferCutlassNvFp4 | No — stays native |
| wq_b, wo_b, fused_wqa_wkv | NVFP4 native | FlashInferCutlassNvFp4 | No — stays native |
| wo_a | NVFP4 → FP8 | fp8_einsum | Yes — fp8_einsum requires FP8 |
| Compressor | NVFP4 → BF16 | torch.mm | Yes — weight_loader stacking issue |
| KV cache | FP8 | FlashInfer MLA | N/A — FP8 is optimal for KV cache |
Plan
Phase 1: Kernel ✅ DONE
- CuTeDSL ScaledGroupedGemmKernel works with NVFP4
- Bridge layer handles all tensor layout conversion
- Full MoE pipeline (L1→SiLU→L2→scatter) produces cosine 0.989 vs BF16
Phase 2: vLLM Integration ✅ DONE
- CuTeDSLMoERunner wires CuTeDSL kernel into vLLM
- Weight loading: checkpoint uint8 → float4_e2m1fn_x2 view-cast (bit-preserving)
- Block scales (float8_e4m3fn) and global scales (float32) pass through directly
- L1 dual global scale handling: normalize to max(gate_gs, up_gs), fold ratio into block scales
- Attention projections stay native NVFP4 (FlashInferCutlassNvFp4LinearKernel)
- CuTeDSL kernel warmup during model load (prevents RPC timeout)
- Removed all debug prints and env var gates from vLLM serving path
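The L1 dual global scale fold can be written in a few lines. This sketch assumes the multiplicative convention dequant = code × block_scale × global_scale; under that assumption, sharing max(gate_gs, up_gs) multiplies the other half's block scales by a ratio of at most 1, so they cannot overflow the FP8 range. If the checkpoint stores the reciprocal convention, the ratio flips. fold_dual_global_scales is a hypothetical name, and the folded scales still have to be rounded back to float8_e4m3fn.

```python
def fold_dual_global_scales(gate_sf, up_sf, gate_gs, up_gs):
    """Give the gate and up halves one shared global scale for the fused L1 GEMM.

    Assumes dequant = code * block_scale * global_scale, so each block scale is
    multiplied by (own_gs / shared_gs) <= 1 and every dequantized value is preserved.
    The returned block scales are plain floats; cast them to float8_e4m3fn afterwards.
    """
    shared = max(gate_gs, up_gs)
    gate_folded = [s * (gate_gs / shared) for s in gate_sf]
    up_folded = [s * (up_gs / shared) for s in up_sf]
    return gate_folded, up_folded, shared
```

For example, with gate_gs = 0.5 and up_gs = 0.25, an up-half code of 3.0 with block scale 8.0 dequantizes to 3.0 × 8.0 × 0.25 = 6.0 before the fold and to 3.0 × 4.0 × 0.5 = 6.0 after it.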
Phase 3: Optimization
- Replace wo_a FP8 conversion with native NVFP4 GEMM (eliminate last dequant)
- Fix compressor weight_loader so it stays NVFP4 native
- Explore larger tile sizes for better occupancy
- Profile end-to-end inference on full model
Phase 4: Production
- Clean up old C++ kernel code (tagged the-last-of-cutlass)
- Add proper error handling and logging
- Benchmark vs BF16 baseline