biondizzle e5370140cb docs: update README with full NVFP4 coverage, dequant anti-pattern, v2 status
- Added NVFP4 coverage table (what's native, what's converted, why)
- Documented the dequant→requant anti-pattern that caused vLLM hangs
- Updated plan: Phase 2 done, Phase 3 targets remaining conversions
- Removed stale REWRITE_PLAN reference
- Updated project structure (nvfp4_cutedsl.py, removed old refs)
2026-05-16 05:43:33 +00:00


NVFP4 MegaMoE Kernel

Full NVFP4 inference pipeline for DeepSeek-V4 on NVIDIA Blackwell (SM100). MoE experts, shared experts, and most attention projections run in native NVFP4 with no dequantization; the only converted weights are wo_a (NVFP4→FP8) and the compressor (BF16).

What This Is

A native NVFP4 inference stack for DeepSeek-V4:

MoE Experts — CuTeDSL ScaledGroupedGemmKernel (our work):

  BF16 input → quantize to NVFP4
  L1 GEMM: NVFP4 × NVFP4 → BF16 (gate + up)
  SiLU(gate) * up → BF16 (only nonlinear — can't avoid BF16 here)
  Re-quantize → NVFP4
  L2 GEMM: NVFP4 × NVFP4 → BF16 (down_proj)
  Scatter with routing weights → BF16 output
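
The data flow above can be written as a plain-Python reference, which is handy when checking a kernel against known-good math. This is a minimal sketch, not the CuTeDSL pipeline: the names `expert_forward` and `moe_forward` are hypothetical, everything runs in full precision, and the comments only mark where the NVFP4 quantization steps would sit.

```python
import math

def matvec(w, x):
    """Multiply a row-major matrix (list of rows) by a vector."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]

def silu(v):
    return v / (1.0 + math.exp(-v))

def expert_forward(x, w_gate, w_up, w_down):
    # (the real pipeline quantizes x to NVFP4 here)
    gate = matvec(w_gate, x)                  # L1 GEMM, gate half
    up = matvec(w_up, x)                      # L1 GEMM, up half
    hidden = [silu(g) * u for g, u in zip(gate, up)]
    # (the real pipeline re-quantizes hidden to NVFP4 here)
    return matvec(w_down, hidden)             # L2 GEMM (down_proj)

def moe_forward(tokens, routes, experts):
    """routes[t] is a list of (expert_id, routing_weight) pairs for token t.

    experts[e] is a (w_gate, w_up, w_down) tuple of row-major matrices.
    """
    out = []
    for x, route in zip(tokens, routes):
        acc = [0.0] * len(experts[0][2])      # one entry per w_down row
        for expert_id, weight in route:
            y = expert_forward(x, *experts[expert_id])
            acc = [a + weight * yi for a, yi in zip(acc, y)]
        out.append(acc)
    return out

identity = [[1.0, 0.0], [0.0, 1.0]]
result = moe_forward([[1.0, 2.0]], [[(0, 0.5)]], [(identity, identity, identity)])
print(result)
```

With identity weights the output for each element is `0.5 * silu(x) * x`, which makes the routing-weight scatter easy to verify by hand.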

Attention Projections — FlashInferCutlassNvFp4LinearKernel (vLLM built-in):

  • wq_b, wo_b, fused_wqa_wkv — native NVFP4, no conversion
  • wo_a — NVFP4→FP8 for fp8_einsum (only attention weight that needs conversion)
  • Compressor — BF16 (weight_loader stacking issue, small matmul)

Shared Experts — FlashInferCutlassNvFp4LinearKernel (vLLM built-in):

  • gate_up_proj, down_proj — native NVFP4

Both GEMM types use float4_e2m1fn_x2 for weights, float8_e4m3fn for block scales, float32 for global scales. BF16 is used only for SiLU activation, the final MoE scatter, and the compressor — the minimum possible.
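
To make the storage format concrete, here is a minimal pure-Python sketch of quantizing one block to E2M1 codes and packing two codes per byte, which is what the `_x2` in float4_e2m1fn_x2 refers to. It is an illustration under assumptions, not the code in bridge.py: NVFP4 groups values into 16-element blocks, but here the block scale is kept as a plain Python float rather than rounded to float8_e4m3fn, the global scale is omitted, the low-nibble-first packing order is assumed, and ties round toward the smaller magnitude, which may differ from hardware rounding.

```python
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # magnitudes for codes 0..7

def encode_e2m1(v):
    """Return the 4-bit code (sign in bit 3) of the nearest E2M1 value."""
    mag = min(abs(v), 6.0)
    code = min(range(8), key=lambda i: abs(E2M1_VALUES[i] - mag))
    return code | (0x8 if v < 0 else 0)

def decode_e2m1(code):
    mag = E2M1_VALUES[code & 0x7]
    return -mag if code & 0x8 else mag

def quantize_block(block):
    """Quantize one even-length block; returns (packed bytes, block scale)."""
    amax = max(abs(v) for v in block)
    scale = amax / 6.0 if amax > 0 else 1.0   # simplification: not rounded to FP8
    codes = [encode_e2m1(v / scale) for v in block]
    # Two codes per byte; first element in the low nibble (assumed order).
    packed = bytes(codes[i] | (codes[i + 1] << 4) for i in range(0, len(codes), 2))
    return packed, scale

def dequantize_block(packed, scale):
    out = []
    for byte in packed:
        out.append(decode_e2m1(byte & 0xF) * scale)
        out.append(decode_e2m1(byte >> 4) * scale)
    return out

block = [6.0, -3.0, 1.5, 0.0] + [0.0] * 12    # one 16-element block
packed, scale = quantize_block(block)
restored = dequantize_block(packed, scale)
```

Because the packed bytes already are the FP4 payload, a checkpoint that stores them as uint8 can be reinterpreted without touching a single bit, which is why weight loading loses nothing.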

How We Got Here

The C++ CUTLASS Kernel Was Broken

The original kernel was a C++ .cu file using CUTLASS's C++ API directly. It passed all the simple tests (uniform data → exact output, scale-factor (SF) remap verifier → 0 errors) but produced a cosine similarity of 0.05 with real random data. After weeks of debugging the SF remap (8+ iterations, all producing the same 0.2 cosine against a wrong reference), we discovered:

  1. The BF16 reference comparison was wrong — our Python dequantization didn't match CUTLASS's internal FP4 handling. A wrong reference is worse than no reference. We chased ghosts through 8+ SF remap rewrites because the 0.2 cosine was never about the remap.

  2. The C++ CUTLASS kernel misinterpreted FP4 data — even with SF remap verified correct (0 byte errors), the GEMM produced garbage with non-uniform data. The issue was in how CUTLASS's C++ API handles FP4 packing/tiling internally — something we couldn't easily debug or fix.

  3. The checkpoint input_scale was a red herring: we tried using the checkpoint's calibration scale as the activation normalization scale. It saturated every block scale to 448.0, the largest finite float8_e4m3fn value, so all per-block information was lost. The input_scale is a calibration constant for alpha computation, not a normalization scale.
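
The saturation in item 3 is easy to reproduce on paper. The sketch below assumes a common two-level convention in which a value is reconstructed as code × block scale × global scale; the function names and the `bad` constant are hypothetical stand-ins, not values from the checkpoint. It shows one plausible way a mismatched global scale pushes every block scale past the FP8 maximum.

```python
FP8_E4M3_MAX = 448.0   # largest finite float8_e4m3fn value
E2M1_MAX = 6.0         # largest FP4 (E2M1) magnitude

def block_scale(block_amax, global_scale):
    """Block scale before FP8 rounding; only the clamp is modelled."""
    return min(block_amax / E2M1_MAX / global_scale, FP8_E4M3_MAX)

def dynamic_global_scale(tensor_amax):
    # Chosen so the largest block scale lands at the FP8 maximum.
    return tensor_amax / (E2M1_MAX * FP8_E4M3_MAX)

block_amaxes = [0.5, 2.0, 8.0]

good = dynamic_global_scale(max(block_amaxes))
good_scales = [block_scale(a, good) for a in block_amaxes]

bad = 1e-6   # hypothetical constant that is unrelated to the tensor's range
bad_scales = [block_scale(a, bad) for a in block_amaxes]

print(good_scales)   # distinct per-block scales
print(bad_scales)    # every entry clamps to 448.0
```

Once every block scale is pinned at the same value, the blocks can no longer adapt to their local range, which matches the symptom described above.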

The CuTeDSL Kernel Works

NVIDIA's CuTeDSL approach (Python-based CUTLASS kernels compiled via MLIR → PTX) is what the CUTLASS team recommends for Blackwell. Their official MoE scaled grouped GEMM example (torch_scaled_grouped_mm.py) supports NVFP4 out of the box. We adapted it.

Results with real DeepSeek-V4 layer 0 weights:

  • L1 GEMM alone: cosine 0.995
  • Full MoE pipeline (L1→SiLU→L2→scatter): cosine 0.989
  • Weight loading: 0% loss — direct uint8→float4_e2m1fn_x2 view-cast, bit-identical to checkpoint
  • Activation quantization: ~1.1% cosine loss (dynamic BF16→NVFP4 — inherent to the format, unavoidable)
  • GEMM kernel: 0% loss (CuTeDSL is correct)

The gap between 1.0 and the 0.989 cosine comes entirely from activation quantization. The weights are bit-identical to the checkpoint: no BF16 round-trip, no precision loss.
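
For readers reproducing these numbers, the cosine metric is the usual normalized dot product between the flattened kernel output and the flattened reference. A minimal sketch follows; the function name is ours and the tests in this repo may compute it differently.

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length flat vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine is undefined for a zero vector")
    return dot / (norm_a * norm_b)

same_direction = cosine_similarity([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])
orthogonal = cosine_similarity([1.0, 0.0], [0.0, 1.0])
```

Note that cosine is scale-invariant: an output that is off by a constant factor (for example a wrong global scale) still scores close to 1.0, so it is worth checking magnitudes alongside it.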

The Dequant→Requant Anti-Pattern

Early versions dequantized all NVFP4 weights to BF16, then let vLLM's FlashInferCutlassNvFp4LinearKernel requantize them back to NVFP4 at inference time. This:

  • Wasted 5 minutes on load doing NVFP4→BF16 conversion
  • Lost precision on the double round-trip
  • Caused vLLM to hang — the NVFP4 attention kernel expects native NVFP4 weights, not BF16 weights with an NVFP4 quant_method attached

The fix: keep everything in NVFP4. The checkpoint stores NVFP4. The kernels consume NVFP4. No conversion needed.

Key Lessons

  1. A wrong reference is worse than no reference — the 0.2 cosine against a broken BF16 dequant sent us chasing SF remap bugs for weeks
  2. The C++ CUTLASS API is a footgun for FP4 — CuTeDSL handles tensor layouts, tiling, and SF construction correctly by construction
  3. Test with real data early — uniform tests pass even with broken kernels; random data reveals real bugs
  4. Separate the GEMM from the pipeline — our layertest.py runs without vLLM, Docker, or tensor parallelism. It caught the kernel bug that vLLM's integration layers masked.
  5. Don't dequant what's already quantized — if the kernel expects NVFP4 and the checkpoint is NVFP4, leave it alone. No BF16 round-trips.
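
Lesson 3 can be demonstrated in a few lines. The sketch below uses a deliberately broken GEMM that reads B's K index in reverse, a hypothetical stand-in for a layout bug and not the actual CUTLASS failure. With all-ones inputs the bug is invisible; with random inputs it shows immediately.

```python
import random

def gemm(a, b):
    """Reference GEMM: a is M×K, b is K×N, both lists of rows."""
    k, n = len(b), len(b[0])
    return [[sum(row[i] * b[i][j] for i in range(k)) for j in range(n)] for row in a]

def broken_gemm(a, b):
    """Same math, but walks B's K index backwards (simulated layout bug)."""
    return gemm(a, b[::-1])

def make(rows, cols, fill):
    return [[fill() for _ in range(cols)] for _ in range(rows)]

def ones():
    return 1.0

# Uniform data: every product is 1.0, so the order of K does not matter.
a_uniform, b_uniform = make(4, 8, ones), make(8, 4, ones)
uniform_ok = broken_gemm(a_uniform, b_uniform) == gemm(a_uniform, b_uniform)

# Random data: pairing the wrong K elements changes every dot product.
rng = random.Random(0)
a_random, b_random = make(4, 8, rng.random), make(8, 4, rng.random)
random_ok = broken_gemm(a_random, b_random) == gemm(a_random, b_random)

print("uniform test passes:", uniform_ok)
print("random test passes:", random_ok)
```

Any bug that only permutes or mis-pairs elements is invisible to a uniform test, which is why the random-data check belongs in the very first test file.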

Project Structure

nvfp4-megamoe-kernel/
├── cutedsl/                          # CuTeDSL kernel + bridge layer
│   ├── bridge.py                     # Tensor layout conversion, quantization, kernel launch
│   ├── moe_pipeline.py              # Full MoE pipeline (L1→SiLU→L2→scatter)
│   └── kernel/moe/                   # NVIDIA's ScaledGroupedGemmKernel (untouched)
│       ├── torch_scaled_grouped_mm.py   # The working kernel (3900 lines)
│       ├── moe_utils.py
│       ├── moe_persistent_scheduler.py
│       └── moe_sched_extension.py
├── vllm/                             # vLLM integration
│   ├── nvfp4_cutedsl.py             # CuTeDSLMoERunner — MoE kernel interface
│   └── patches/
│       ├── deepseek_v4.py           # DeepSeek-V4 model patch (NVFP4 native)
│       └── deepseek_v4_attention.py # Attention patch (NVFP4 native)
├── src/nvfp4_megamoe_kernel/         # OLD Python pipeline (tagged the-last-of-cutlass)
├── tests/
│   ├── layertest.py                 # Layer 0 comparison: CuTeDSL vs BF16 (✅ cosine 0.989)
│   ├── test_cutedsl.py              # Small standalone CuTeDSL test (✅ cosine 0.991)
│   ├── test_uniform_fp4.py          # Uniform data GEMM test
│   ├── test_b_layout.py             # B matrix column layout test
│   └── test_quick_rand.py           # Quick random GEMM sanity check
└── reference/                        # Reference files for study

The Bridge Layer (cutedsl/bridge.py)

Handles all tensor layout conversion from our pipeline to what the CuTeDSL kernel expects:

| Function | What it does |
|---|---|
| quantize_to_nvfp4() | BF16 → float4_e2m1fn_x2 + float8_e4m3fn block scales + float32 global scale |
| quantize_weight_to_nvfp4() | Same, but for weight matrices with K as the packed dimension |
| assemble_scales_2d_side() | Pad and swizzle activation scale factors (2Dx3D A side) |
| assemble_scales_3d_side() | Pad and swizzle weight scale factors (2Dx3D B side) |
| make_b_k_major() | Convert B tensor from N-major to K-major strides (required by kernel) |
| compute_expert_offsets() | Compute cumulative token offsets for grouped GEMM |
| run_nvfp4_grouped_gemm() | Full kernel launch (compile + run) |
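
As an illustration of the offsets a grouped GEMM needs, here is a minimal sketch of the cumulative-offset idea. It is not the signature of compute_expert_offsets() in bridge.py, which this README does not show; the function name, the argument list, and the leading zero in the result are all assumptions.

```python
from itertools import accumulate

def expert_offsets_sketch(expert_ids, num_experts):
    """Return offsets so expert e owns rows offsets[e]:offsets[e + 1]."""
    counts = [0] * num_experts
    for e in expert_ids:
        counts[e] += 1
    return [0] + list(accumulate(counts))

def sort_tokens_by_expert(expert_ids):
    """Token order that makes each expert's rows contiguous (stable sort)."""
    return sorted(range(len(expert_ids)), key=lambda t: expert_ids[t])

expert_ids = [0, 2, 2, 1, 0]           # expert chosen for each routed token
offsets = expert_offsets_sketch(expert_ids, num_experts=4)
order = sort_tokens_by_expert(expert_ids)
print(offsets, order)
```

An expert that receives no tokens simply gets an empty range (two equal consecutive offsets), so the kernel can skip it without special casing.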

Running Tests

On the B200:

cd /root/nvfp4-megamoe-kernel/tests
source .venv/bin/activate

# Small standalone test
python3 test_cutedsl.py

# Full layer 0 comparison with real weights
python3 layertest.py

NVFP4 Coverage

| Component | Format | Kernel | Conversion? |
|---|---|---|---|
| MoE experts (L1+L2) | NVFP4 native | CuTeDSL ScaledGroupedGemm | No (direct uint8→float4 view-cast) |
| Shared experts | NVFP4 native | FlashInferCutlassNvFp4 | No (stays native) |
| wq_b, wo_b, fused_wqa_wkv | NVFP4 native | FlashInferCutlassNvFp4 | No (stays native) |
| wo_a | NVFP4 → FP8 | fp8_einsum | Yes (fp8_einsum requires FP8) |
| Compressor | NVFP4 → BF16 | torch.mm | Yes (weight_loader stacking issue) |
| KV cache | FP8 | FlashInfer MLA | N/A (FP8 is optimal for KV cache) |

Plan

Phase 1: Kernel (DONE)

  • CuTeDSL ScaledGroupedGemmKernel works with NVFP4
  • Bridge layer handles all tensor layout conversion
  • Full MoE pipeline (L1→SiLU→L2→scatter) produces cosine 0.989 vs BF16

Phase 2: vLLM Integration (DONE)

  • CuTeDSLMoERunner wires CuTeDSL kernel into vLLM
  • Weight loading: checkpoint uint8 → float4_e2m1fn_x2 view-cast (bit-preserving)
  • Block scales (float8_e4m3fn) and global scales (float32) pass through directly
  • L1 dual global scale handling: normalize to max(gate_gs, up_gs), fold ratio into block scales
  • Attention projections stay native NVFP4 (FlashInferCutlassNvFp4LinearKernel)
  • CuTeDSL kernel warmup during model load (prevents RPC timeout)
  • Removed all debug prints and env var gates from vLLM serving path
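
The dual global scale bullet deserves a worked example. The sketch below assumes a weight is reconstructed as code × block scale × global scale, and that gate and up arrive with separate global scales but must share one in the fused L1 GEMM. The function name is hypothetical, and the real code also has to re-round the folded block scales to float8_e4m3fn, which this sketch skips.

```python
def fold_global_scales(gate_block_scales, gate_gs, up_block_scales, up_gs):
    """Rewrite block scales so both halves can share max(gate_gs, up_gs)."""
    shared_gs = max(gate_gs, up_gs)
    # Each ratio is <= 1, so folded block scales never grow past their originals.
    gate_folded = [s * (gate_gs / shared_gs) for s in gate_block_scales]
    up_folded = [s * (up_gs / shared_gs) for s in up_block_scales]
    return gate_folded, up_folded, shared_gs

gate_folded, up_folded, shared_gs = fold_global_scales(
    gate_block_scales=[2.0, 4.0], gate_gs=0.5,
    up_block_scales=[8.0], up_gs=0.25,
)
# The effective scale (block scale × global scale) is unchanged for both halves.
print(gate_folded, up_folded, shared_gs)
```

Normalizing to the larger global scale means the ratios only ever shrink block scales, so folding cannot overflow the FP8 range; very small ratios could still lose precision when rounded back to FP8.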

Phase 3: Optimization

  • Replace wo_a FP8 conversion with native NVFP4 GEMM (eliminate last dequant)
  • Fix compressor weight_loader so it stays NVFP4 native
  • Explore larger tile sizes for better occupancy
  • Profile end-to-end inference on full model

Phase 4: Production

  • Clean up old C++ kernel code (tagged the-last-of-cutlass)
  • Add proper error handling and logging
  • Benchmark vs BF16 baseline