First inference triggers Triton/TileLang kernel JIT compilation, which takes 2-3 minutes. The default 5-minute RPC timeout kills the engine before compilation finishes, so we raised it to 10 minutes via VLLM_RPC_TIMEOUT_MS so the first request survives. This is a stopgap: warming up the kernels during startup would be better, but CUDA graphs don't work well with grouped GEMMs and variable expert counts. We will investigate vLLM's warmup shape config later.
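A minimal sketch of applying the same 10-minute timeout from a Python launcher instead of the shell. It assumes the engine is started by this process after the variable is set and that VLLM_RPC_TIMEOUT_MS (the name used in the note above) is read in milliseconds. It only sets the variable; it does not implement the startup warmup that would make the workaround unnecessary.

```python
import os

# Assumptions: the vLLM engine is started by this process after this code runs,
# and VLLM_RPC_TIMEOUT_MS (name taken from the note above) is in milliseconds.
JIT_COMPILE_MINUTES = 3   # upper end of the observed first-request compile time
TIMEOUT_MINUTES = 10      # chosen timeout, leaving headroom over compilation


def minutes_to_ms(minutes: int) -> int:
    return minutes * 60 * 1000


if TIMEOUT_MINUTES <= JIT_COMPILE_MINUTES:
    raise ValueError("RPC timeout must exceed the JIT compile time")

# setdefault keeps any value the operator already exported in the shell.
os.environ.setdefault("VLLM_RPC_TIMEOUT_MS", str(minutes_to_ms(TIMEOUT_MINUTES)))
print(os.environ["VLLM_RPC_TIMEOUT_MS"])  # 600000 unless already set
```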
NVFP4 MegaMoE Kernel
Full NVFP4 inference pipeline for DeepSeek-V4 on NVIDIA Blackwell (SM100). MoE experts, shared experts, and most attention projections run in native NVFP4 with no dequantization step; the only exceptions are wo_a (converted to FP8 for fp8_einsum) and the compressor (BF16).
What This Is
A native NVFP4 inference stack for DeepSeek-V4:
MoE Experts (NVIDIA's CuTeDSL ScaledGroupedGemmKernel, integrated by our bridge layer and MoE pipeline):
- BF16 input → quantize to NVFP4
- L1 GEMM: NVFP4 × NVFP4 → BF16 (gate + up)
- SiLU(gate) * up → BF16 (the only nonlinear step, so it has to run in BF16)
- Re-quantize → NVFP4
- L2 GEMM: NVFP4 × NVFP4 → BF16 (down_proj)
- Scatter with routing weights → BF16 output
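The six steps can be cross-checked against a plain NumPy reference. This is a float sketch of the dataflow, not the CuTeDSL kernel: the quant argument is a hypothetical stand-in for NVFP4 quantize-dequantize (identity by default), and the [E, 2*I, H] gate-then-up layout of the fused L1 weight is an assumption. The cumulative per-expert offsets play the role that compute_expert_offsets() plays for the grouped GEMM.

```python
import numpy as np


def silu(x):
    return x / (1.0 + np.exp(-x))


def moe_reference(x, topk_ids, topk_weights, w13, w2, quant=lambda t: t):
    """Float reference for the expert path: group -> L1 -> SiLU*up -> L2 -> scatter.

    x:            [T, H] activations
    topk_ids:     [T, K] expert index for each (token, slot)
    topk_weights: [T, K] routing weight for each (token, slot)
    w13:          [E, 2*I, H] fused weight, gate rows first, then up rows (assumed)
    w2:           [E, H, I] down_proj
    quant:        hypothetical stand-in for NVFP4 quantize-dequantize
    """
    T, K = topk_ids.shape
    E, two_i, H = w13.shape
    inter = two_i // 2

    # Sort (token, slot) pairs by expert so each expert's rows are contiguous.
    flat_expert = topk_ids.reshape(-1)
    order = np.argsort(flat_expert, kind="stable")
    token_of_row = order // K
    counts = np.bincount(flat_expert, minlength=E)
    # Cumulative offsets: expert e owns rows offsets[e]:offsets[e + 1].
    offsets = np.concatenate(([0], np.cumsum(counts)))

    a = quant(x[token_of_row])  # step 1: quantize the gathered activations
    rows = np.zeros((len(order), H), dtype=np.float32)
    for e in range(E):
        lo, hi = offsets[e], offsets[e + 1]
        if lo == hi:
            continue  # no tokens routed to this expert
        h = a[lo:hi] @ quant(w13[e]).T            # step 2: L1 GEMM -> [n, 2I]
        act = silu(h[:, :inter]) * h[:, inter:]    # step 3: SiLU(gate) * up
        rows[lo:hi] = quant(act) @ quant(w2[e]).T  # steps 4-5: re-quantize, L2 GEMM
    # Step 6: weight each row by its routing weight and sum it back into its token.
    y = np.zeros((T, H), dtype=np.float32)
    np.add.at(y, token_of_row, rows * topk_weights.reshape(-1)[order][:, None])
    return y
```

Passing an NVFP4 fake-quantizer as quant shows where quantization error enters: once on the gathered input, once on the SiLU output, and once on each weight.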
Attention Projections (FlashInferCutlassNvFp4LinearKernel, vLLM built-in):
- wq_b, wo_b, fused_wqa_wkv: native NVFP4, no conversion
- wo_a: NVFP4→FP8 for fp8_einsum (the only attention weight that needs conversion)
- Compressor: BF16 (weight_loader stacking issue; small matmul)
Shared Experts (FlashInferCutlassNvFp4LinearKernel, vLLM built-in):
- gate_up_proj, down_proj: native NVFP4
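Both GEMM types use float4_e2m1fn_x2 for weights, float8_e4m3fn for block scales, and float32 for global scales. Outside the GEMM inputs and outputs, BF16 is used only for the SiLU activation, the final MoE scatter, and the compressor. The first two are inherent to the pipeline; the compressor is a known workaround slated for Phase 3.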
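To make the three storage types concrete, here is a NumPy fake-quantizer. The 16-element block size is the NVFP4 standard; the multiplier convention (value ≈ code × block_scale × global_scale, with global_scale = amax / (6 × 448)), the tie-breaking, and the function names are assumptions of this sketch, not the behavior of bridge.py's quantize_to_nvfp4(). Block scales are only clamped to 448, not rounded to E4M3.

```python
import numpy as np

E2M1_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=np.float32)
E2M1_MAX = 6.0    # largest FP4 (E2M1) magnitude
E4M3_MAX = 448.0  # largest float8_e4m3fn value
BLOCK = 16        # NVFP4 block size along K


def quantize_nvfp4(x):
    """Fake-quantize an [M, K] array (K divisible by 16).

    Returns (codes, block_scales, global_scale). Codes are signed E2M1 values kept
    as float32 for readability; real storage packs two 4-bit codes per byte.
    Assumed convention: x ~= codes * block_scale * global_scale.
    """
    m, k = x.shape
    blocks = x.reshape(m, k // BLOCK, BLOCK).astype(np.float32)
    amax = float(np.abs(blocks).max())
    global_scale = np.float32(amax / (E2M1_MAX * E4M3_MAX)) if amax > 0 else np.float32(1.0)
    block_amax = np.abs(blocks).max(axis=-1, keepdims=True)
    # Only the 448 clamp of float8_e4m3fn is modelled, not its rounding.
    block_scales = np.clip(block_amax / (E2M1_MAX * global_scale), 0.0, E4M3_MAX)
    denom = np.where(block_scales > 0, block_scales * global_scale, 1.0)
    scaled = blocks / denom  # each block now spans at most [-6, 6]
    # Round magnitudes to the nearest E2M1 point (ties go to the smaller value,
    # unlike hardware round-to-nearest-even).
    idx = np.abs(np.abs(scaled)[..., None] - E2M1_GRID).argmin(axis=-1)
    codes = np.sign(scaled) * E2M1_GRID[idx]
    return codes, block_scales, global_scale


def dequantize_nvfp4(codes, block_scales, global_scale):
    m, nb, _ = codes.shape
    return (codes * block_scales * global_scale).reshape(m, nb * BLOCK)
```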
How We Got Here
The C++ CUTLASS Kernel Was Broken
The original kernel was a C++ .cu file using CUTLASS's C++ API directly. It passed all the simple tests (uniform data → exact output; scale-factor (SF) remap verifier → 0 errors) but produced a cosine similarity of 0.05 on random data. After weeks of debugging the SF remap (8+ iterations, all producing the same 0.2 cosine against a wrong reference), we discovered:
- The BF16 reference comparison was wrong. Our Python dequantization didn't match CUTLASS's internal FP4 handling. A wrong reference is worse than no reference: we chased ghosts through 8+ SF remap rewrites because the 0.2 cosine was never about the remap.
- The C++ CUTLASS kernel misinterpreted FP4 data. Even with the SF remap verified correct (0 byte errors), the GEMM produced garbage with non-uniform data. The issue was in how CUTLASS's C++ API handles FP4 packing and tiling internally, something we couldn't easily debug or fix.
- The checkpoint input_scale was a red herring. We tried using the checkpoint's calibration scale as the activation normalization scale, and it saturated all block scales at 448.0 (the float8_e4m3fn maximum). input_scale is a calibration constant for alpha computation, not a normalization scale.
The CuTeDSL Kernel Works
NVIDIA's CuTeDSL approach (Python-based CUTLASS kernels compiled via MLIR → PTX) is what the CUTLASS team recommends for Blackwell. Their official MoE scaled grouped GEMM example (torch_scaled_grouped_mm.py) supports NVFP4 out of the box. We adapted it.
Results with real DeepSeek-V4 layer 0 weights:
- L1 GEMM alone: cosine 0.995
- Full MoE pipeline (L1→SiLU→L2→scatter): cosine 0.989
- Weight loading: 0% loss — direct uint8→float4_e2m1fn_x2 view-cast, bit-identical to checkpoint
- Activation quantization: ~1.1% cosine loss (dynamic BF16→NVFP4 quantization, inherent to the format)
- GEMM kernel: 0% loss (CuTeDSL is correct)
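All of the figures above are cosine similarities against a BF16 reference, and the percentage losses read as 100 × (1 - cosine). A sketch of the metric follows; the function names are illustrative, not taken from layertest.py. One caveat worth encoding: cosine is blind to a uniform scale error (such as a wrong global scale), so pairing it with a relative error check is cheap insurance.

```python
import numpy as np


def cosine_similarity(out, ref):
    """Cosine of the angle between two tensors, flattened, accumulated in float64."""
    a = np.asarray(out, dtype=np.float64).ravel()
    b = np.asarray(ref, dtype=np.float64).ravel()
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


def relative_l2_error(out, ref):
    """||out - ref|| / ||ref||; unlike cosine, this catches a global scale error."""
    a = np.asarray(out, dtype=np.float64).ravel()
    b = np.asarray(ref, dtype=np.float64).ravel()
    return float(np.linalg.norm(a - b) / np.linalg.norm(b))


def cosine_loss_pct(cos):
    """Reads a cosine as a percentage loss, e.g. 0.989 -> about 1.1%."""
    return 100.0 * (1.0 - cos)


ref = np.array([1.0, 2.0, 3.0])
print(cosine_similarity(2 * ref, ref))  # ~1.0: a 2x scale bug is invisible to cosine
print(relative_l2_error(2 * ref, ref))  # 1.0: the same bug is a 100% error here
```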
The 0.989 cosine is entirely from activation quantization. The weights are bit-identical to the checkpoint — no BF16 round-trip, no precision loss.
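Why the view-cast can be bit-identical: each checkpoint byte already holds two E2M1 codes, so reinterpreting the uint8 storage as float4_e2m1fn_x2 only changes how the bits are read. A small decoder makes the encoding explicit. The low-nibble-first order in unpack_byte is an assumption to verify against the kernel, and both helper names are hypothetical.

```python
def e2m1_to_float(nibble: int) -> float:
    """Decode one 4-bit E2M1 code: 1 sign bit, 2 exponent bits (bias 1), 1 mantissa bit."""
    sign = -1.0 if nibble & 0b1000 else 1.0
    exp = (nibble >> 1) & 0b11
    man = nibble & 0b1
    if exp == 0:  # subnormal range: 0 or 0.5
        return sign * 0.5 * man
    return sign * 2.0 ** (exp - 1) * (1.0 + 0.5 * man)


def unpack_byte(byte: int) -> tuple[float, float]:
    """Two FP4 values per byte; low nibble first is an assumed packing order."""
    return e2m1_to_float(byte & 0x0F), e2m1_to_float(byte >> 4)


print([e2m1_to_float(n) for n in range(8)])  # the eight non-negative E2M1 values
```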
The Dequant→Requant Anti-Pattern
Early versions dequantized all NVFP4 weights to BF16, then let vLLM's FlashInferCutlassNvFp4LinearKernel requantize them back to NVFP4 at inference time. This:
- Wasted 5 minutes on load doing NVFP4→BF16 conversion
- Lost precision on the double round-trip
- Caused vLLM to hang — the NVFP4 attention kernel expects native NVFP4 weights, not BF16 weights with an NVFP4 quant_method attached
The fix: keep everything in NVFP4. The checkpoint stores NVFP4. The kernels consume NVFP4. No conversion needed.
Key Lessons
- A wrong reference is worse than no reference — the 0.2 cosine against a broken BF16 dequant sent us chasing SF remap bugs for weeks
- The C++ CUTLASS API is a footgun for FP4 — CuTeDSL handles tensor layouts, tiling, and SF construction correctly by construction
- Test with real data early — uniform tests pass even with broken kernels; random data reveals real bugs
- Separate the GEMM from the pipeline: our layertest.py runs without vLLM, Docker, or tensor parallelism. It caught the kernel bug that vLLM's integration layers masked.
- Don't dequant what's already quantized: if the kernel expects NVFP4 and the checkpoint is NVFP4, leave it alone. No BF16 round-trips.
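A toy illustration of why uniform tests pass broken kernels: with all-ones inputs, a GEMM that reads B with the wrong strides still returns exactly the expected output, while random inputs expose it immediately. The "bug" here is simulated with a reshape; it is not the actual CUTLASS failure.

```python
import numpy as np


def gemm_correct(a, b_nk):
    """Reference: A [M, K] times B stored as [N, K] (K-major), giving [M, N]."""
    return a @ b_nk.T


def gemm_buggy(a, b_nk):
    """Simulated layout bug: B's [N, K] storage reinterpreted as [K, N]."""
    n, k = b_nk.shape
    return a @ b_nk.reshape(k, n)


M, N, K = 4, 6, 8
ones_a, ones_b = np.ones((M, K)), np.ones((N, K))
rng = np.random.default_rng(0)
rand_a, rand_b = rng.standard_normal((M, K)), rng.standard_normal((N, K))

print(np.array_equal(gemm_correct(ones_a, ones_b), gemm_buggy(ones_a, ones_b)))  # True
print(np.allclose(gemm_correct(rand_a, rand_b), gemm_buggy(rand_a, rand_b)))     # False
```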
Project Structure
nvfp4-megamoe-kernel/
├── cutedsl/ # CuTeDSL kernel + bridge layer
│ ├── bridge.py # Tensor layout conversion, quantization, kernel launch
│ ├── moe_pipeline.py # Full MoE pipeline (L1→SiLU→L2→scatter)
│ └── kernel/moe/ # NVIDIA's ScaledGroupedGemmKernel (untouched)
│     ├── torch_scaled_grouped_mm.py # The working kernel (3900 lines)
│     ├── moe_utils.py
│     ├── moe_persistent_scheduler.py
│     └── moe_sched_extension.py
├── vllm/ # vLLM integration
│ ├── nvfp4_cutedsl.py # CuTeDSLMoERunner — MoE kernel interface
│ └── patches/
│     ├── deepseek_v4.py # DeepSeek-V4 model patch (NVFP4 native)
│     └── deepseek_v4_attention.py # Attention patch (NVFP4 native)
├── src/nvfp4_megamoe_kernel/ # OLD Python pipeline (tagged the-last-of-cutlass)
├── tests/
│ ├── layertest.py # Layer 0 comparison: CuTeDSL vs BF16 (✅ cosine 0.989)
│ ├── test_cutedsl.py # Small standalone CuTeDSL test (✅ cosine 0.991)
│ ├── test_uniform_fp4.py # Uniform data GEMM test
│ ├── test_b_layout.py # B matrix column layout test
│ └── test_quick_rand.py # Quick random GEMM sanity check
└── reference/ # Reference files for study
The Bridge Layer (cutedsl/bridge.py)
Handles all tensor layout conversion from our pipeline to what the CuTeDSL kernel expects:
| Function | What it does |
|---|---|
| quantize_to_nvfp4() | BF16 → float4_e2m1fn_x2 + float8_e4m3fn block scales + float32 global scale |
| quantize_weight_to_nvfp4() | Same, but for weight matrices with K as the packed dimension |
| assemble_scales_2d_side() | Pad and swizzle activation scale factors (2Dx3D A side) |
| assemble_scales_3d_side() | Pad and swizzle weight scale factors (2Dx3D B side) |
| make_b_k_major() | Convert B tensor from N-major to K-major strides (required by kernel) |
| compute_expert_offsets() | Compute cumulative token offsets for grouped GEMM |
| run_nvfp4_grouped_gemm() | Full kernel launch (compile + run) |
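The two assemble_scales functions "pad and swizzle" scale factors. As an illustration only, the sketch below implements the 128×4-tile interleave described in the cuBLAS documentation for block-scaled GEMMs: rows are padded to a multiple of 128 and scale columns (K/16) to a multiple of 4, and inside each tile, entry (r, c) lands at offset (r mod 32)·16 + (r div 32)·4 + c. Whether the bridge functions emit exactly this order, and the name swizzle_block_scales, are assumptions.

```python
import numpy as np


def ceil_div(a, b):
    return -(-a // b)


def swizzle_block_scales(scales):
    """Pad an [M, K//16] scale matrix to multiples of (128, 4) and interleave
    each 128x4 tile into a 32x16 layout, returning a flat array."""
    rows, cols = scales.shape
    rb, cb = ceil_div(rows, 128), ceil_div(cols, 4)
    padded = np.zeros((rb * 128, cb * 4), dtype=scales.dtype)
    padded[:rows, :cols] = scales
    # [rb, 128, cb, 4] -> [rb, cb, 128, 4]: tiles ordered row-block major.
    tiles = padded.reshape(rb, 128, cb, 4).transpose(0, 2, 1, 3)
    # Split tile row r into (i, j) with r = 32*i + j, then move j outermost,
    # so (r, c) lands at offset j*16 + i*4 + c within the tile.
    tiles = tiles.reshape(rb, cb, 4, 32, 4).transpose(0, 1, 3, 2, 4)
    return tiles.reshape(-1)
```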
Running Tests
On the B200:
cd /root/nvfp4-megamoe-kernel/tests
source .venv/bin/activate
# Small standalone test
python3 test_cutedsl.py
# Full layer 0 comparison with real weights
python3 layertest.py
NVFP4 Coverage
| Component | Format | Kernel | Conversion? |
|---|---|---|---|
| MoE experts (L1+L2) | NVFP4 native | CuTeDSL ScaledGroupedGemm | No — direct uint8→float4 view-cast |
| Shared experts | NVFP4 native | FlashInferCutlassNvFp4 | No — stays native |
| wq_b, wo_b, fused_wqa_wkv | NVFP4 native | FlashInferCutlassNvFp4 | No — stays native |
| wo_a | NVFP4 → FP8 | fp8_einsum | Yes — fp8_einsum requires FP8 |
| Compressor | NVFP4 → BF16 | torch.mm | Yes — weight_loader stacking issue |
| KV cache | FP8 | FlashInfer MLA | N/A — FP8 is optimal for KV cache |
Plan
Phase 1: Kernel ✅ DONE
- CuTeDSL ScaledGroupedGemmKernel works with NVFP4
- Bridge layer handles all tensor layout conversion
- Full MoE pipeline (L1→SiLU→L2→scatter) produces cosine 0.989 vs BF16
Phase 2: vLLM Integration ✅ DONE
- CuTeDSLMoERunner wires CuTeDSL kernel into vLLM
- Weight loading: checkpoint uint8 → float4_e2m1fn_x2 view-cast (bit-preserving)
- Block scales (float8_e4m3fn) and global scales (float32) pass through directly
- L1 dual global scale handling: normalize to max(gate_gs, up_gs), fold ratio into block scales
- Attention projections stay native NVFP4 (FlashInferCutlassNvFp4LinearKernel)
- CuTeDSL kernel warmup during model load (prevents RPC timeout)
- Removed all debug prints and env var gates from vLLM serving path
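The L1 dual global scale step can be written out directly: gate_proj and up_proj arrive with separate float32 global scales, but the fused L1 GEMM takes one. Assuming the multiplier convention (value = code × block_scale × global_scale), using max(gate_gs, up_gs) as the shared scale means each half's block scales absorb a ratio of at most 1. The helper name is hypothetical.

```python
import numpy as np


def fold_dual_global_scales(gate_bs, gate_gs, up_bs, up_gs):
    """Give the fused gate+up weight a single global scale.

    Assumed convention: value = code * block_scale * global_scale.
    Each half's block scales are multiplied by gs / shared (<= 1), so the
    dequantized weights are unchanged in exact arithmetic.
    """
    shared = max(gate_gs, up_gs)
    return gate_bs * (gate_gs / shared), up_bs * (up_gs / shared), shared


rng = np.random.default_rng(0)
codes_gate = rng.choice([-6.0, -1.0, 0.0, 0.5, 3.0], size=(4, 16))
codes_up = rng.choice([-4.0, -0.5, 0.0, 1.5, 6.0], size=(4, 16))
gate_bs = rng.uniform(1.0, 448.0, size=(4, 1))
up_bs = rng.uniform(1.0, 448.0, size=(4, 1))
gate_gs, up_gs = 0.002, 0.0005

gate_bs2, up_bs2, shared = fold_dual_global_scales(gate_bs, gate_gs, up_bs, up_gs)
print(shared)  # 0.002
```

Because the ratio is at most 1, the rescaled block scales never exceed the 448 float8_e4m3fn ceiling, but small ones can lose precision when rounded back to E4M3, which this sketch does not model.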
Phase 3: Optimization
- Replace wo_a FP8 conversion with native NVFP4 GEMM (eliminate last dequant)
- Fix compressor weight_loader so it stays NVFP4 native
- Explore larger tile sizes for better occupancy
- Profile end-to-end inference on full model
Phase 4: Production
- Clean up old C++ kernel code (tagged the-last-of-cutlass)
- Add proper error handling and logging
- Benchmark vs BF16 baseline