biondizzle 3445bd24c1 feat: keep attention weights native NVFP4 — stop dequantizing to BF16
_convert_nvfp4_post_load() was converting wq_b, wo_b, fused_wqa_wkv
from NVFP4→BF16. These layers already have FlashInferCutlassNvFp4LinearKernel
registered as their quant_method — they CAN run native NVFP4.

Now only wo_a gets FP8 conversion (fp8_einsum requires FP8) and
compressor gets BF16 reconstruction (weight_loader issue).
Everything else stays NVFP4 native — Blackwell FP4 acceleration
for the full model, not just the MoE experts.

This also eliminates the 5-minute NVFP4→BF16 conversion loop.
2026-05-16 05:36:34 +00:00

NVFP4 MegaMoE Kernel

NVFP4 block-scaled Mixture-of-Experts kernel for DeepSeek-V4 on NVIDIA Blackwell (SM100). Uses CuTeDSL — NVIDIA's Python-based CUTLASS DSL — for a native NVFP4 pipeline that takes full advantage of Blackwell's TMA, MMA, and epilogue overlap.

What This Is

A fused MoE FFN kernel that runs the expert forward pass with both GEMMs in NVFP4:

BF16 input → quantize to NVFP4
  L1 GEMM: NVFP4 × NVFP4 → BF16 (gate + up)
  SiLU(gate) * up → BF16 (only nonlinear — can't avoid BF16 here)
  Re-quantize → NVFP4
  L2 GEMM: NVFP4 × NVFP4 → BF16 (down_proj)
  Scatter with routing weights → BF16 output
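For comparison, here is the same dataflow in plain floating point with no quantization. It is a sketch of a float baseline, like the BF16 reference the NVFP4 pipeline is compared against. It does not follow the grouped kernel's schedule: it loops per token instead of grouping tokens per expert. The [gate | up] row order of w13 and all helper names are assumptions, not the project's code.

```python
import math


def silu(v: float) -> float:
    return v / (1.0 + math.exp(-v))


def matvec(w: list[list[float]], x: list[float]) -> list[float]:
    # w is [out_features][in_features], one list per output row
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def expert_ffn(x, w13, w2):
    # L1: one fused GEMM yields [gate | up] (2 * intermediate outputs, order assumed)
    h = matvec(w13, x)
    inter = len(h) // 2
    gate, up = h[:inter], h[inter:]
    # SiLU(gate) * up: the only nonlinear step
    act = [silu(g) * u for g, u in zip(gate, up)]
    # L2: down_proj back to the hidden size
    return matvec(w2, act)


def moe_forward(tokens, experts, topk_ids, topk_weights):
    """Per-token float reference: sum over routed experts of weight * expert_ffn(token)."""
    out = []
    for t, x in enumerate(tokens):
        acc = [0.0] * len(x)
        for e, rw in zip(topk_ids[t], topk_weights[t]):
            w13, w2 = experts[e]
            y = expert_ffn(x, w13, w2)
            # scatter step: accumulate the routing-weighted expert output
            acc = [a + rw * yi for a, yi in zip(acc, y)]
        out.append(acc)
    return out
```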

Both GEMMs are fully NVFP4: A and B in float4_e2m1fn_x2, block scales in float8_e4m3fn, global scales in float32. BF16 appears only at the boundaries (the input, the GEMM outputs, the SiLU activation, and the final scatter), which is the minimum possible.
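The three-level format can be sketched in pure Python. This sketch assumes the commonly used NVFP4 recipe: 16-element blocks, a per-tensor global scale of 448·6/amax, block scales rounded to float8_e4m3fn, and dequantization as code × block_scale / global_scale. It returns Python floats rather than packed float4_e2m1fn_x2 bytes. E2M1 ties go to the smaller magnitude for simplicity, whereas hardware conversion rounds to nearest-even. All names are illustrative, not the bridge's quantize_to_nvfp4().

```python
import math

E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)  # non-negative FP4 values
FP4_MAX = 6.0    # largest E2M1 magnitude
FP8_MAX = 448.0  # largest float8_e4m3fn magnitude
BLOCK = 16       # elements sharing one block scale


def round_e2m1(x: float) -> float:
    # nearest grid point (values beyond +/-6 clip); ties go to the smaller magnitude
    mag = min(E2M1_GRID, key=lambda g: abs(abs(x) - g))
    return math.copysign(mag, x)


def round_e4m3(x: float) -> float:
    # saturating round-to-nearest-even of a non-negative x onto the E4M3 grid
    if x <= 0.0:
        return 0.0
    if x >= FP8_MAX:
        return FP8_MAX
    _, e = math.frexp(x)                # x = m * 2**e with 0.5 <= m < 1
    step = 2.0 ** (max(e - 1, -6) - 3)  # 3 mantissa bits; subnormals below 2**-6
    return min(round(x / step) * step, FP8_MAX)


def quantize_nvfp4(values):
    amax = max(abs(v) for v in values)
    global_scale = FP8_MAX * FP4_MAX / amax if amax > 0 else 1.0
    codes, block_scales = [], []
    for start in range(0, len(values), BLOCK):
        block = values[start:start + BLOCK]
        sf = round_e4m3(max(abs(v) for v in block) / FP4_MAX * global_scale)
        block_scales.append(sf)
        inv = global_scale / sf if sf > 0 else 0.0
        codes.extend(round_e2m1(v * inv) for v in block)
    return codes, block_scales, global_scale


def dequantize_nvfp4(codes, block_scales, global_scale):
    return [q * block_scales[i // BLOCK] / global_scale for i, q in enumerate(codes)]
```

Values already on the E2M1 grid (with amax 6) round-trip exactly. Arbitrary values lose a few percent per element, which is the activation-quantization loss discussed below.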

How We Got Here

The C++ CUTLASS Kernel Was Broken

The original kernel was a C++ .cu file using CUTLASS's C++ API directly. It passed all the simple tests (uniform data → exact output, scale-factor (SF) remap verifier → 0 errors) but produced a cosine of 0.05 on non-uniform random data. After weeks of debugging the SF remap (8+ iterations, all producing the same 0.2 cosine against a wrong reference), we discovered:

  1. The BF16 reference comparison was wrong — our Python dequantization didn't match CUTLASS's internal FP4 handling. A wrong reference is worse than no reference. We chased ghosts through 8+ SF remap rewrites because the 0.2 cosine was never about the remap.

  2. The C++ CUTLASS kernel misinterpreted FP4 data — even with SF remap verified correct (0 byte errors), the GEMM produced garbage with non-uniform data. The issue was in how CUTLASS's C++ API handles FP4 packing/tiling internally — something we couldn't easily debug or fix.

  3. The checkpoint input_scale was a red herring. We tried using the checkpoint's calibration scale as the activation normalization scale, and it saturated every block scale at 448.0, the float8_e4m3fn maximum. The input_scale is a calibration constant for alpha computation, not a normalization scale.
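The arithmetic behind the third point: a block scale is the block's amax divided by 6 (the E2M1 maximum), times the global scale, and float8_e4m3fn cannot hold anything above 448. This sketch assumes a global scale sized to the live activations (448·6/amax). Under that assumption, only the largest block reaches the limit, while a global scale that is far too large pushes every block past it. The function name is hypothetical.

```python
FP4_MAX = 6.0    # largest E2M1 magnitude
FP8_MAX = 448.0  # largest float8_e4m3fn magnitude


def block_scales(block_amaxes, global_scale):
    """Return (stored_scales, saturated_flags) for each block."""
    raw = [a / FP4_MAX * global_scale for a in block_amaxes]
    return [min(r, FP8_MAX) for r in raw], [r > FP8_MAX for r in raw]


block_amaxes = [6.0, 3.0, 1.5]
dynamic_gs = FP8_MAX * FP4_MAX / max(block_amaxes)  # sized to the live tensor: 448.0

print(block_scales(block_amaxes, dynamic_gs))        # ([448.0, 224.0, 112.0], [False, False, False])
print(block_scales(block_amaxes, dynamic_gs * 100))  # ([448.0, 448.0, 448.0], [True, True, True])
```

Once a scale is clamped, most elements of that block scale past ±6 and clip to the E2M1 maximum, so the block's information is gone.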

The CuTeDSL Kernel Works

NVIDIA's CuTeDSL approach (Python-based CUTLASS kernels compiled via MLIR → PTX) is what the CUTLASS team recommends for Blackwell. Their official MoE scaled grouped GEMM example (torch_scaled_grouped_mm.py) supports NVFP4 out of the box. We adapted it.

Results with real DeepSeek-V4 layer 0 weights:

  • L1 GEMM alone: cosine 0.995
  • Full MoE pipeline (L1→SiLU→L2→scatter): cosine 0.989
  • Weight loading: 0% loss — direct uint8→float4_e2m1fn_x2 view-cast, bit-identical to checkpoint
  • Activation quantization: ~1.1% cosine loss (dynamic BF16→NVFP4 — inherent to the format, unavoidable)
  • GEMM kernel: 0% loss (CuTeDSL is correct)

The 0.989 cosine is entirely from activation quantization. The weights are bit-identical to the checkpoint — no BF16 round-trip, no precision loss.
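The view-cast is lossless because each checkpoint byte already holds two E2M1 codes. Reinterpreting uint8 as float4_e2m1fn_x2 changes only the dtype label, not the bits. A hand-written reference decoder must agree with the kernel on nibble order, and a mismatch there is exactly how a reference goes wrong (the first of the Key Lessons). The minimal sketch below assumes the first element of each pair sits in the low nibble; verify that against the kernel before relying on it.

```python
E2M1_VALUES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)  # indexed by the 3 magnitude bits


def decode_e2m1(nibble: int) -> float:
    # bit 3 = sign, bits 2..1 = exponent, bit 0 = mantissa
    sign = -1.0 if nibble & 0x8 else 1.0
    return sign * E2M1_VALUES[nibble & 0x7]


def unpack_fp4x2(packed: bytes) -> list[float]:
    out = []
    for byte in packed:
        out.append(decode_e2m1(byte & 0x0F))  # element 2i: low nibble (assumed)
        out.append(decode_e2m1(byte >> 4))    # element 2i+1: high nibble
    return out


print(unpack_fp4x2(bytes([0x21, 0xF7])))  # [0.5, 1.0, 6.0, -6.0]
```

These are raw codes. The actual weight values still need each block's float8 scale and the tensor's float32 global scale applied.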

Key Lessons

  1. A wrong reference is worse than no reference — the 0.2 cosine against a broken BF16 dequant sent us chasing SF remap bugs for weeks
  2. The C++ CUTLASS API is a footgun for FP4 — CuTeDSL handles tensor layouts, tiling, and SF construction correctly by construction
  3. Test with real data early — uniform tests pass even with broken kernels; random data reveals real bugs
  4. Separate the GEMM from the pipeline — our layertest.py runs without vLLM, Docker, or tensor parallelism. It caught the kernel bug that vLLM's integration layers masked.
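Lesson 3 in miniature: a simulated layout bug that reads B as if it were stored transposed is invisible with uniform inputs, because a constant matrix equals its own transpose. Random inputs expose it immediately. This is a pure-Python sketch. buggy_matmul is a hypothetical stand-in for this class of layout error, not a reproduction of the actual CUTLASS C++ failure.

```python
import math
import random


def matmul(a, b):
    # a: m x k, b: k x n (lists of rows)
    return [[sum(a[i][p] * b[p][j] for p in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def buggy_matmul(a, b):
    # simulated layout bug: B is read with its strides swapped (square B only)
    bt = [list(col) for col in zip(*b)]
    return matmul(a, bt)


def cosine(x, y):
    fx = [v for row in x for v in row]
    fy = [v for row in y for v in row]
    dot = sum(p * q for p, q in zip(fx, fy))
    return dot / math.sqrt(sum(p * p for p in fx) * sum(q * q for q in fy))


n = 16
uniform_a = [[1.0] * n for _ in range(n)]
uniform_b = [[0.5] * n for _ in range(n)]
print(matmul(uniform_a, uniform_b) == buggy_matmul(uniform_a, uniform_b))  # True

rng = random.Random(0)
rand_a = [[rng.gauss(0.0, 1.0) for _ in range(n)] for _ in range(n)]
rand_b = [[rng.gauss(0.0, 1.0) for _ in range(n)] for _ in range(n)]
print(cosine(matmul(rand_a, rand_b), buggy_matmul(rand_a, rand_b)))  # far below 1.0
```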

Project Structure

nvfp4-megamoe-kernel/
├── cutedsl/                          # CuTeDSL kernel + bridge layer
│   ├── bridge.py                     # Tensor layout conversion, quantization, kernel launch
│   ├── moe_pipeline.py              # Full MoE pipeline (L1→SiLU→L2→scatter)
│   └── kernel/moe/                   # NVIDIA's ScaledGroupedGemmKernel (untouched)
│       ├── torch_scaled_grouped_mm.py   # The working kernel (3900 lines)
│       ├── moe_utils.py
│       ├── moe_persistent_scheduler.py
│       └── moe_sched_extension.py
├── src/nvfp4_megamoe_kernel/         # OLD Python pipeline (being replaced)
│   ├── nvfp4_mega_moe.py            # Old pipeline — calls broken C++ kernel
│   └── cutlass_nvfp4_gemm/          # OLD C++ CUTLASS extension (BROKEN)
├── vllm/                             # vLLM integration
│   └── patches/
│       └── deepseek_v4.py           # DeepSeek-V4 model patch
├── tests/
│   ├── layertest.py                 # Layer 0 comparison: CuTeDSL vs BF16 (✅ cosine 0.989)
│   ├── test_cutedsl.py              # Small standalone CuTeDSL test (✅ cosine 0.991)
│   ├── test_uniform_fp4.py          # Uniform data GEMM test
│   ├── test_b_layout.py             # B matrix column layout test
│   └── test_quick_rand.py           # Quick random GEMM sanity check
├── reference/                        # Reference files for study
└── REWRITE_PLAN.md                  # Original rewrite plan

The Bridge Layer (cutedsl/bridge.py)

Handles all tensor layout conversion from our pipeline to what the CuTeDSL kernel expects:

| Function | What it does |
|---|---|
| quantize_to_nvfp4() | BF16 → float4_e2m1fn_x2 + float8_e4m3fn block scales + float32 global scale |
| quantize_weight_to_nvfp4() | Same, but for weight matrices with K as the packed dimension |
| assemble_scales_2d_side() | Pad and swizzle activation scale factors (2Dx3D A side) |
| assemble_scales_3d_side() | Pad and swizzle weight scale factors (2Dx3D B side) |
| make_b_k_major() | Convert B tensor from N-major to K-major strides (required by kernel) |
| compute_expert_offsets() | Compute cumulative token offsets for grouped GEMM |
| run_nvfp4_grouped_gemm() | Full kernel launch (compile + run) |
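Two of these steps are simple enough to sketch in pure Python. All functions below are hypothetical re-implementations, not the bridge's code. expert_offsets assumes the grouped GEMM takes cumulative end offsets, matching the offs convention of PyTorch's torch._scaled_grouped_mm. swizzled_sf_offset assumes the 128×4 scale-factor tile layout that cuBLAS documents for block-scaled FP4/FP8 GEMMs, where each tile of 128 rows × 4 scale columns is stored as 32 lines of 16 bytes.

```python
from itertools import accumulate


def ceil_div(a: int, b: int) -> int:
    return -(-a // b)


def expert_offsets(tokens_per_expert: list[int]) -> list[int]:
    # cumulative end offset of each expert's rows in the token-sorted activations
    return list(accumulate(tokens_per_expert))


def padded_sf_shape(rows: int, sf_cols: int) -> tuple[int, int]:
    # scales are padded to whole 128 x 4 tiles (sf_cols = K // 16 for NVFP4)
    return ceil_div(rows, 128) * 128, ceil_div(sf_cols, 4) * 4


def swizzled_sf_offset(row: int, col: int, sf_cols: int) -> int:
    """Flat offset of scale (row, col) in the assumed 128x4-tile swizzle."""
    col_tiles = ceil_div(sf_cols, 4)
    tile = (row // 128) * col_tiles + col // 4  # tiles laid out row-major, 512 bytes each
    r = row % 128
    # rows r, r+32, r+64, r+96 share one 16-byte line, 4 bytes each
    return tile * 512 + (r % 32) * 16 + (r // 32) * 4 + col % 4


print(expert_offsets([3, 0, 5, 2]))  # [3, 3, 8, 10]
print(padded_sf_shape(200, 7))       # (256, 8)
```

Experts that receive no tokens still get an entry, so the offsets list always has one element per expert.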

Running Tests

On the B200:

cd /root/nvfp4-megamoe-kernel/tests
source .venv/bin/activate

# Small standalone test
python3 test_cutedsl.py

# Full layer 0 comparison with real weights
python3 layertest.py

Plan

Phase 1: Kernel (DONE)

  • CuTeDSL ScaledGroupedGemmKernel works with NVFP4
  • Bridge layer handles all tensor layout conversion
  • Full MoE pipeline (L1→SiLU→L2→scatter) produces cosine 0.989 vs BF16

Phase 2: vLLM Integration (IN PROGRESS)

  • Wire cutedsl/moe_pipeline.py into the vLLM DeepSeek-V4 model
  • Replace nvfp4_mega_moe_full() call with CuTeDSLMoERunner.run()
  • Weight loading: checkpoint uint8 → float4_e2m1fn_x2 view-cast (bit-preserving, no BF16 round-trip)
  • Block scales (float8_e4m3fn) and global scales (float32) pass through directly from checkpoint
  • L1 dual global scales: gate and up carry separate global scales; normalize both to max(gate_gs, up_gs) and fold each tensor's ratio to that max into its block scales
  • Remove C++ CUTLASS extension build from Dockerfile
  • Add CuTeDSL dependency to the Docker build
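The L1 dual-global-scale step can be checked with a little arithmetic. This sketch assumes each global scale is a multiplicative dequantization factor (w ≈ code × block_scale × global_scale). Under that assumption, gate and up can share one scale g = max(gate_gs, up_gs) if each block scale is multiplied by own_gs / g. That ratio is at most 1, so folded scales never exceed 448. The catch is that in the real pipeline they must be re-rounded to float8_e4m3fn, which is exact only when the ratio is a power of two. The helper is hypothetical and keeps scales as plain floats.

```python
def fold_global_scales(gate_sf, up_sf, gate_gs, up_gs):
    """Return (gate_sf', up_sf', shared_gs) so that sf' * shared_gs == sf * own_gs."""
    shared_gs = max(gate_gs, up_gs)
    gate_folded = [s * (gate_gs / shared_gs) for s in gate_sf]  # ratio <= 1
    up_folded = [s * (up_gs / shared_gs) for s in up_sf]        # ratio <= 1
    return gate_folded, up_folded, shared_gs


gate, up, g = fold_global_scales([448.0, 96.0], [8.0, 16.0], gate_gs=2.0, up_gs=0.5)
print(gate, up, g)  # [448.0, 96.0] [2.0, 4.0] 2.0
```

Taking the max rather than the min keeps every folded scale inside the FP8 range. The cost is that the tensor with the smaller global scale gets smaller block scales, which lose relative precision if they drop into E4M3's subnormal range.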

Phase 3: Optimization

  • Explore larger tile sizes for better occupancy
  • Profile end-to-end inference on full model

Phase 4: Production

  • Clean up debug artifacts
  • Remove old C++ kernel code
  • Add proper error handling and logging
  • Benchmark vs BF16 baseline