biondizzle 57d4cb714f docs: rewrite README.md with current project state
- Document all 5 correctness bug fixes
- Document fused SwiGLU epilogue progress (Step 1 PASS, Step 2 blocked)
- Document CuTeDSL runtime conditional limitation
- List remaining steps (amax shuffles, NVFP4 quantize, FP4/SF TMA stores)
- Document weight interleave and register layout
- Capture key lessons learned
- Update file structure and test inventory
2026-05-20 03:30:35 +00:00


NVFP4 MegaMoE Kernel

Native NVFP4 inference stack for DeepSeek-V4 on NVIDIA Blackwell (SM100). CuTeDSL kernels for the entire model — MoE experts, shared experts, attention projections — running in native NVFP4 with zero dequantization overhead.

⚠️ THE #1 RULE

WE OWN ALL OUR KERNELS. WE DO NOT PATCH vLLM.

vLLM's internal kernels (FlashMLA, fp8_ds_mla, fused compressor, Triton indexer) are deeply coupled. You cannot swap one piece and expect the rest to work. We build our own CuTeDSL kernels, test standalone, then wire into vLLM as an attention backend.


Repository Layout

This repo (nvfp4-megamoe-kernel): The kernel library — CuTeDSL kernels, bridge layer, standalone tests.

vLLM fork (vllm-deepseekv4-nvfp4): The vLLM integration — model definition, weight loading, attention backend. Lives at /root/dsv4-nvfp4-workspace/vllm on the B200.

Workspace (/root/dsv4-nvfp4-workspace):

  • kernel/ — clone of this repo
  • vllm/ — clone of the vLLM fork
  • FUSED_EPILOGUE_PLAN.md — fused SwiGLU epilogue plan
  • FUSED_EPILOGUE_STATUS.md — current status

What We Have

CuTeDSL NVFP4 Grouped GEMM (the building block)

ScaledGroupedGemmKernel in cutedsl/kernel/moe/torch_scaled_grouped_mm.py — a production-grade NVFP4 grouped GEMM kernel:

  • 2D×3D scenario: A(M,K) × B(E,K,N) → C(M,N)
  • Block-scaled: per-16-element FP8 scales on both A and B sides
  • Global scales (per-expert) for full dynamic range
  • Persistent scheduler, TMA pipelining, SMEM swizzle
  • CUDAGraph-safe (workspace pre-allocated, no runtime allocations)
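To pin down what the 2D×3D scenario computes, here is a pure-Python reference of the grouped GEMM semantics, with quantization and scales left out. The group_offsets convention (cumulative end row per expert) and the function names are assumptions for illustration. They are not the kernel's actual interface.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain (M, K) x (K, N) -> (M, N) product; returns [] when a has no rows."""
    k_dim, n_dim = len(b), len(b[0])
    return [[sum(row[k] * b[k][n] for k in range(k_dim)) for n in range(n_dim)] for row in a]


def grouped_mm_reference(a: Matrix, b: List[Matrix], group_offsets: List[int]) -> Matrix:
    """A(M, K) x B(E, K, N) -> C(M, N).

    Hypothetical convention: group_offsets[e] is the exclusive end row of
    expert e, so expert e owns rows [group_offsets[e-1], group_offsets[e]).
    Each expert's rows are multiplied by that expert's weight matrix b[e].
    """
    out: Matrix = []
    start = 0
    for e, end in enumerate(group_offsets):
        out.extend(matmul(a[start:end], b[e]))
        start = end
    return out
```

An expert with zero routed tokens simply contributes an empty slice, which is why the kernel needs a persistent scheduler rather than a fixed one-CTA-per-expert grid.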

Bridge Layer (cutedsl/bridge.py)

  • quantize_to_nvfp4() — BF16 → NVFP4 with global scale
  • quantize_activation_nvfp4() — cudagraph-safe quantize (pre-computed gs)
  • quantize_weight_to_nvfp4() — weight quantization (along K dim)
  • interleave_l1_weights() — gate/up interleave at granularity 8 BF16
  • make_b_k_major() — B tensor stride conversion
  • assemble_scales_2d_side() / assemble_scales_3d_side() — scale assembly + swizzle
  • warmup_compilation() — eager JIT compilation before first forward pass
  • run_nvfp4_grouped_gemm() — the main entry point
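As a reading aid for quantize_to_nvfp4(), here is a minimal pure-Python sketch of NVFP4 block quantization with a global scale. It assumes the common NVFP4 convention (global scale = 448·6 / tensor amax, per-16-element E4M3 scale = block amax / 6 × global scale). The helper names are hypothetical, E2M1 ties round toward the smaller magnitude here instead of round-to-nearest-even, and the real bridge function may handle scales differently.

```python
import math
from typing import List, Tuple

E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # FP4 E2M1 magnitudes
E2M1_MAX = 6.0
E4M3_MAX = 448.0
BLOCK = 16


def round_to_e4m3(x: float) -> float:
    """Round a non-negative float to the nearest FP8 E4M3 value, saturating at 448."""
    if x <= 0.0:
        return 0.0
    if x >= E4M3_MAX:
        return E4M3_MAX
    exp = max(math.floor(math.log2(x)), -6)  # -6 is the smallest normal exponent
    step = 2.0 ** (exp - 3)                  # 3 mantissa bits; 2^-9 in the subnormal range
    return min(round(x / step) * step, E4M3_MAX)


def round_to_e2m1(x: float) -> float:
    """Round to the nearest signed E2M1 value, saturating at +/-6."""
    mag = min(abs(x), E2M1_MAX)
    nearest = min(E2M1_VALUES, key=lambda v: abs(v - mag))
    return math.copysign(nearest, x)


def quantize_nvfp4(x: List[float]) -> Tuple[List[float], List[float], float]:
    """Return (E2M1 values, per-block E4M3 scales, global scale)."""
    assert len(x) % BLOCK == 0
    tensor_amax = max(abs(v) for v in x)
    gs = (E4M3_MAX * E2M1_MAX) / tensor_amax if tensor_amax > 0 else 1.0
    q: List[float] = []
    scales: List[float] = []
    for i in range(0, len(x), BLOCK):
        block = x[i:i + BLOCK]
        sf = round_to_e4m3(max(abs(v) for v in block) / E2M1_MAX * gs)
        scales.append(sf)
        if sf == 0.0:
            q.extend([0.0] * BLOCK)  # zero block: exact zeros, no division
            continue
        q.extend(round_to_e2m1(v * gs / sf) for v in block)
    return q, scales, gs


def dequantize_nvfp4(q: List[float], scales: List[float], gs: float) -> List[float]:
    return [v * scales[i // BLOCK] / gs for i, v in enumerate(q)]
```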

MoE Runner (cutedsl/runner.py)

CuTeDSLMoERunner — runs the MoE forward pass:

  1. Quantize input BF16 → NVFP4 (using pre-computed gs)
  2. L1 GEMM: NVFP4 × NVFP4 → BF16 (gate+up fused)
  3. SiLU(gate) * up → BF16 (PyTorch, not yet fused)
  4. Re-quantize BF16 → NVFP4
  5. L2 GEMM: NVFP4 × NVFP4 → BF16 (down_proj)
  6. Scatter with routing weights
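A pure-Python reference of the runner's data flow can help when checking kernel output. It is a sketch under stated assumptions: the quantize steps (1 and 4) are skipped, gate and up are assumed to occupy the first and second halves of the L1 output (no interleave), and tokens are processed one at a time instead of through a grouped GEMM. All names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def silu(v: float) -> float:
    """SiLU(v) = v / (1 + exp(-v)), arranged so exp() never overflows."""
    if v >= 0.0:
        return v / (1.0 + math.exp(-v))
    e = math.exp(v)
    return v * e / (1.0 + e)


def vecmat(x: Sequence[float], w: Matrix) -> List[float]:
    """x(K) @ w(K, N) -> (N)."""
    return [sum(x[k] * w[k][n] for k in range(len(w))) for n in range(len(w[0]))]


def moe_forward_reference(
    tokens: List[List[float]],
    topk: List[List[Tuple[int, float]]],  # per token: (expert id, routing weight)
    w1: List[Matrix],                     # w1[e]: (hidden, 2 * intermediate), gate then up
    w2: List[Matrix],                     # w2[e]: (intermediate, hidden)
) -> List[List[float]]:
    out = [[0.0] * len(x) for x in tokens]
    for t, x in enumerate(tokens):
        for e, weight in topk[t]:
            h = vecmat(x, w1[e])                                   # step 2: L1 GEMM
            inter = len(h) // 2
            act = [silu(h[i]) * h[inter + i] for i in range(inter)]  # step 3
            y = vecmat(act, w2[e])                                 # step 5: L2 GEMM
            for k, v in enumerate(y):                              # step 6: weighted scatter
                out[t][k] += weight * v
    return out
```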

NVFP4 Linear (cutedsl/nvfp4_linear.py)

CuTeDSLNvfp4Linear — single-expert NVFP4 GEMM for shared experts and attention projections.

Fused SwiGLU Kernel (in progress)

fused_swiglu_grouped_mm.py — extends ScaledGroupedGemmKernel with a fused SiLU epilogue:

  • Step 1 DONE: SiLU in registers validated (0.034% error vs PyTorch)
  • Step 2 BLOCKED: Gate/up pairing blocked by CuTeDSL type system (see below)

Correctness Bugs Fixed (May 20, 2026)

All 5 bugs are fixed, committed, and pushed:

| Bug | Issue | Fix |
|---|---|---|
| 1 | _needs_token_refill myth: cute.compile does not corrupt GPU memory | Removed the hack, added warmup_compilation(), pre-allocated workspace per cache entry |
| 2 | Dequantize→requantize was assumed to be lossy | Verified a 100% byte-identical round-trip; deprecated prepare_weights_from_dequantized |
| 3 | clamp(min=1e-8) on all-zero blocks produces a nonzero FP8 scale | Detect zero blocks and force the FP8 scale to exactly 0 |
| 4 | Underflow blocks (amax < 6×2⁻⁹) get nonzero FP4 values from dividing by a tiny scale | Detect underflow blocks and zero x_norm before the division |
| 5 | Expert counting materializes an 18M-element bool tensor | torch.bincount replaces the O(n×E) comparison |
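The fixes for bugs 3 and 4 come down to one guard before the division. The threshold 6×2⁻⁹ is the block amax at which amax / 6 drops below 2⁻⁹, the smallest positive FP8 E4M3 value. Below is a minimal pure-Python sketch contrasting the two patterns for a single 16-element block. It assumes the global scale has already been applied and skips FP8 rounding of the scale, and the function names are hypothetical rather than the bridge.py code.

```python
from typing import List, Tuple

E2M1_MAX = 6.0
E4M3_MIN_SUBNORMAL = 2.0 ** -9                 # smallest positive FP8 E4M3 value
UNDERFLOW_AMAX = E2M1_MAX * E4M3_MIN_SUBNORMAL  # 6 x 2^-9, the threshold from the table


def normalize_block_naive(block: List[float]) -> Tuple[float, List[float]]:
    """Buggy pattern: the clamp keeps the divisor tiny but nonzero."""
    amax = max(abs(v) for v in block)
    scale = max(amax / E2M1_MAX, 1e-8)  # like clamp(min=1e-8)
    return scale, [v / scale for v in block]


def normalize_block_guarded(block: List[float]) -> Tuple[float, List[float]]:
    """Fixed pattern: zero and underflow blocks get scale 0 and an all-zero x_norm."""
    amax = max(abs(v) for v in block)
    if amax < UNDERFLOW_AMAX:  # covers amax == 0 (bug 3) and underflow (bug 4)
        return 0.0, [0.0] * len(block)
    scale = amax / E2M1_MAX
    return scale, [v / scale for v in block]
```

For an underflow block such as sixteen copies of 1e-4, the naive path normalizes every element to 6.0 (the largest FP4 code) even though the scale it pairs with cannot be represented in E4M3. The guarded path emits exact zeros.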

Fused SwiGLU Epilogue — Current State

The Goal

Fuse SiLU(gate)*up + NVFP4 quantization into the L1 GEMM epilogue. This eliminates:

  • ~580MB BF16 write to GMEM
  • ~290MB BF16 read back
  • 3 kernel launches + 12 quantize ops

Expected benefit: ~30-40% latency reduction for the MoE block.
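The traffic figures scale linearly with the number of routed (token, expert) rows. The sketch below is a hypothetical model that assumes the L1 output is (rows, 2 × intermediate) BF16 (gate + up) and the SiLU(gate)*up result is (rows, intermediate) BF16. It only illustrates the arithmetic and is not a measurement.

```python
BF16_BYTES = 2


def unfused_intermediate_bytes(routed_rows: int, intermediate: int) -> dict:
    """Size of the BF16 intermediates the fused epilogue would avoid (hypothetical model).

    routed_rows counts (token, expert) pairs, i.e. tokens x top-k.
    """
    l1_output = routed_rows * 2 * intermediate * BF16_BYTES  # gate + up columns
    silu_output = routed_rows * intermediate * BF16_BYTES    # SiLU(gate) * up
    return {"l1_output_bytes": l1_output, "silu_output_bytes": silu_output}


per_row = unfused_intermediate_bytes(routed_rows=1, intermediate=3072)
```

With intermediate=3072, each routed row costs 12,288 bytes of L1 output and 6,144 bytes of SiLU output. Under this model the 2:1 ratio between ~580MB and ~290MB follows directly, and ~580MB corresponds to roughly 47k routed rows (taking MB as 10⁶ bytes).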

Step 1: SiLU in Registers — VALIDATED

cute.exp and element-wise FP32 ops work correctly on CuTe register tensors in the epilogue. SiLU(x) = x / (1+exp(-x)) produces 0.034% relative error vs PyTorch.
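For reference, here is a pure-Python sketch of the SiLU formula plus one way to measure relative error against a BF16-rounded result. Treating the metric as a norm-relative error is an assumption, since the README does not say how 0.034% was measured. to_bf16 is a hypothetical helper that rounds to the nearest BF16 value (round-to-nearest-even, NaN/Inf not handled).

```python
import math
import struct
from typing import Sequence


def silu(x: float) -> float:
    """SiLU(x) = x / (1 + exp(-x)), arranged so exp() never overflows."""
    if x >= 0.0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)  # x < 0, so e is in (0, 1)
    return x * e / (1.0 + e)


def to_bf16(x: float) -> float:
    """Round to the nearest BF16 value via the float32 bit pattern."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round-to-nearest-even on the dropped 16 bits
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def norm_relative_error(test: Sequence[float], ref: Sequence[float]) -> float:
    num = math.sqrt(sum((t - r) ** 2 for t, r in zip(test, ref)))
    den = math.sqrt(sum(r * r for r in ref))
    return num / den


xs = [i / 16.0 - 8.0 for i in range(257)]  # grid over [-8, 8]
ref = [silu(x) for x in xs]
err = norm_relative_error([to_bf16(y) for y in ref], ref)
```

BF16 rounding alone bounds this error by about 2⁻⁸ (≈0.39%), so a result well below that, like the 0.034% reported above, is consistent with FP32 math in registers.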

Step 2: Gate/Up Pairing — BLOCKED BY CUTEDSL TYPE SYSTEM

The problem: CuTeDSL compiles ALL subtile iterations into one kernel. Runtime conditionals (if is_gate_subtile) fail when they affect:

  • Register tensor assignment → DSLRuntimeError (type structure mismatch)
  • TMA store skipping → corrupted output
  • Mask blending on register tensors → wrong results

CuTeDSL requires that ALL code paths produce tensors with the same structure. Even when both branches nominally produce the same tensor type, the compiler cannot unify them if the branch condition is a runtime value.

What's Needed for Step 2

Option A: Paired subtile iteration. Instead of iterating over subtiles [0,1,2,3] and branching on each, iterate over gate/up pairs [(0,2), (1,3)]. For each pair, load both the gate and up accumulators, compute SiLU(gate)*up, and store the result. There are no runtime conditionals, because every iteration does the same thing. This requires restructuring the epilogue loop.
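The control-flow shape of Option A can be sketched in pure Python. This models only the loop structure, not CuTeDSL code. The pair table is a compile-time constant, so every iteration runs the same instructions on different indices. The layout (subtiles 0 and 1 hold gate, 2 and 3 hold the matching up columns) is an assumption taken from the pairs listed above.

```python
import math
from typing import List, Sequence, Tuple

# Static pair table: (gate subtile, up subtile). No per-subtile branch is needed.
SUBTILE_PAIRS: Tuple[Tuple[int, int], ...] = ((0, 2), (1, 3))


def silu(x: float) -> float:
    if x >= 0.0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)
    return x * e / (1.0 + e)


def epilogue_paired(subtiles: Sequence[Sequence[float]]) -> List[List[float]]:
    """Each iteration loads a gate and an up subtile and stores SiLU(gate) * up."""
    out: List[List[float]] = []
    for gate_idx, up_idx in SUBTILE_PAIRS:  # fixed trip count, fixed indices
        gate, up = subtiles[gate_idx], subtiles[up_idx]
        out.append([silu(g) * u for g, u in zip(gate, up)])
    return out
```

Because the loop emits one output subtile per pair, the result has half as many subtiles as the accumulator, matching the (M, intermediate) shape of the SiLU output.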

Option B: const_expr debug flag. Compile a separate kernel with debug_silu_bf16=True that writes post-SiLU BF16 to a (M, intermediate) side tensor. Validate, then add NVFP4 quantize + FP4/SF TMA stores. The production kernel (flag=False) skips the BF16 write.

Option C: Separate post-GEMM SiLU kernel. A small CUDA kernel that reads BF16 L1 output, applies SiLU(gate)*up, writes result. Adds one kernel launch but avoids the CuTeDSL type system constraint entirely.

Remaining Steps (after gate/up pairing)

| Step | What | Status |
|---|---|---|
| 3 | Per-16-element amax via warp shuffles | Not started |
| 4 | FP8 E4M3 scale + E2M1 round + nibble pack | Not started |
| 5 | FP4 TMA store to padded L2 buffer | Not started |
| 6 | FP8 SF TMA store through blockscaled layout | Not started |
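Steps 3 and 4 can be prototyped on the host before they are written in CuTeDSL. The pure-Python sketch below simulates an XOR-butterfly max reduction (the pattern __shfl_xor_sync implements on the GPU) and packs two E2M1 codes per byte. How many lanes share one 16-element block and the nibble order (first element in the low nibble) are assumptions, and all names are hypothetical.

```python
import math
from typing import List

E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)  # code = index, sign in bit 3


def butterfly_amax(lane_vals: List[float], group_lanes: int) -> List[float]:
    """Simulate an XOR-butterfly max across aligned groups of `group_lanes` lanes.

    lane_vals[i] is lane i's local |x| maximum. After log2(group_lanes) rounds,
    every lane in a group holds the group's amax.
    """
    assert group_lanes & (group_lanes - 1) == 0, "group size must be a power of two"
    assert len(lane_vals) % group_lanes == 0
    vals = list(lane_vals)
    offset = group_lanes // 2
    while offset >= 1:
        # Each lane reads its partner lane (i ^ offset), which stays inside the group.
        vals = [max(vals[i], vals[i ^ offset]) for i in range(len(vals))]
        offset //= 2
    return vals


def e2m1_code(v: float) -> int:
    """4-bit code for a value already rounded to E2M1."""
    code = E2M1_MAGNITUDES.index(abs(v))
    return (code | 0x8) if math.copysign(1.0, v) < 0 else code


def pack_nibbles(values: List[float]) -> bytes:
    """Pack pairs of E2M1 values into bytes, first element in the low nibble."""
    assert len(values) % 2 == 0
    return bytes(e2m1_code(lo) | (e2m1_code(hi) << 4) for lo, hi in zip(values[0::2], values[1::2]))
```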

Weight Interleave

Gate/up weights must be interleaved at a granularity of 8 elements (8 BF16 values, i.e. 4 bytes of packed FP4) for the fused epilogue. interleave_l1_weights() in bridge.py implements this. The pure-PyTorch invariant test passes; the kernel-level test is blocked by the same subtile iteration issue.
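The interleave and the kind of invariant a pure-Python test can check are sketched below on plain column lists. That the gate chunk comes first and that interleaving runs along the output (N) dimension are assumptions. The helper names are hypothetical and not the signature of interleave_l1_weights().

```python
from typing import List, Sequence, Tuple


def interleave_gate_up(gate_cols: Sequence, up_cols: Sequence, granularity: int = 8) -> List:
    """[g0..g7, u0..u7, g8..g15, u8..u15, ...] for granularity 8."""
    assert len(gate_cols) == len(up_cols)
    assert len(gate_cols) % granularity == 0
    out: List = []
    for i in range(0, len(gate_cols), granularity):
        out.extend(gate_cols[i:i + granularity])
        out.extend(up_cols[i:i + granularity])
    return out


def deinterleave_gate_up(cols: Sequence, granularity: int = 8) -> Tuple[List, List]:
    """Inverse of interleave_gate_up: split alternating chunks back into gate and up."""
    gate: List = []
    up: List = []
    for i in range(0, len(cols), 2 * granularity):
        gate.extend(cols[i:i + granularity])
        up.extend(cols[i + granularity:i + 2 * granularity])
    return gate, up
```

The round-trip property (deinterleave(interleave(g, u)) == (g, u)) is a cheap host-side check. It says nothing about whether the kernel's register layout matches, which is what the blocked kernel-level test is for.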

Register Layout (from DeepGEMM)

After SM100_TMEM_LOAD_16dp256b1x, the register fragment holds gate/up values in pairs:

  • (values[0], values[2]), (values[1], values[3])
  • (values[4], values[6]), (values[5], values[7])

Our CuTeDSL kernel uses tiled_copy_r2s.retile(), which may produce a different register layout. This still needs to be verified against the debug BF16 output.
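For that verification, it helps to compute what the DeepGEMM-style pairing would produce from a known fragment and compare it with the debug output. This pure-Python sketch assumes the first index of each pair is gate and the second is up. The pairing itself comes from the list above, and the names are hypothetical.

```python
import math
from typing import List, Sequence

# Pairing listed above; (gate, up) order within each pair is an assumption.
FRAGMENT_PAIRS = ((0, 2), (1, 3), (4, 6), (5, 7))


def silu(x: float) -> float:
    if x >= 0.0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)
    return x * e / (1.0 + e)


def swiglu_fragment(values: Sequence[float]) -> List[float]:
    """Apply SiLU(gate) * up to an 8-value register fragment, pair by pair."""
    assert len(values) == 8
    return [silu(values[g]) * values[u] for g, u in FRAGMENT_PAIRS]
```

If the retiled layout differs, the debug output will match a different pair table, and swapping FRAGMENT_PAIRS is a quick way to identify which one.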


DeepSeek-V4 Architecture Notes

NOT MLA. DeepSeek-V4 uses:

  • CSA (Compressed Sparse Attention, cr=4): KV compressed 4x, indexer finds top-k
  • HCA (Heavily Compressed Attention, cr=128): KV compressed 128x, pre-computed indices
  • SWA: Standard sliding window (window=128, last layer only)
  • mHC: Manifold-Constrained Hyper-Connections — replaces residual connections
  • 384 experts, top-6, intermediate=3072

Compress ratios by layer: alternating 128/4, layer 60 = 0 (SWA).
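A small helper to generate the per-layer schedule and the compressed KV length it implies. Several parts are assumptions: 61 layers (inferred from layer 60 being the last, SWA layer), which ratio comes first (check the model config), and floor division for a trailing partial group. CSA's top-k selection by the indexer is not modeled. The function names are hypothetical.

```python
from typing import List


def compress_ratios(num_layers: int = 61, first: int = 128, second: int = 4) -> List[int]:
    """Alternate `first`/`second` per layer; the last layer is SWA (ratio 0)."""
    ratios = [first if layer % 2 == 0 else second for layer in range(num_layers - 1)]
    return ratios + [0]


def kv_length(seq_len: int, ratio: int, window: int = 128) -> int:
    """KV entries a layer keeps: seq_len // ratio for CSA/HCA, the window for SWA."""
    if ratio == 0:
        return min(seq_len, window)
    return seq_len // ratio
```

At a 4096-token context, for example, a cr=4 layer keeps 1024 compressed entries, a cr=128 layer keeps 32, and the SWA layer keeps its 128-token window.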


File Structure

cutedsl/
├── bridge.py                          # Quantization, layout, kernel launch
├── nvfp4_linear.py                    # Single-expert NVFP4 GEMM runner
├── runner.py                          # MoE grouped GEMM runner
├── blackwell_attention.py             # KV cache + attention (standalone)
├── csa_attention.py                   # CSA/HCA attention
├── custom_ops.py                      # torch.autograd wrappers
├── moe_pipeline.py                    # Standalone test pipeline (deprecated path)
└── kernel/moe/
    ├── torch_scaled_grouped_mm.py     # ScaledGroupedGemmKernel (the GEMM)
    └── fused_swiglu_grouped_mm.py     # FusedSwiGLUScaledGroupedGemmKernel (WiP)

tests/
├── test_fused_step1.py               # SiLU validation (PASS)
├── test_fp4_roundtrip.py             # Checkpoint byte match (PASS)
├── test_interleave_gemm.py           # Weight interleave GEMM test (BLOCKED)
├── layertest.py                      # MoE layer test (PASS, 0.988 cosine)
├── cudagraph_test.py                  # CUDAGraph test (PASS)
├── test_full_layer_b200.py           # All NVFP4 projections (PASS, 0.994+)
├── test_v4_attention_b200.py         # All 3 attention types (PASS)
├── test_kv_cache_b200.py             # KV cache (PASS, 0.9997)
├── test_sparse_attn_b200.py          # CSA/HCA (PASS)
├── test_decode_attention_b200.py     # Prefill+decode (PASS, 0.9998)
└── ...

Key Lessons (Things We Fucked Up)

  1. NEVER assume CuTeDSL GPU tensors survive JIT compilation. cute.compile zeroes GPU memory. Keep index/mapping tensors on CPU. Always verify with .cpu().tolist() after JIT.

  2. NEVER nuke working code without understanding why it exists. The cudagraph-safe functions exist because vLLM REQUIRES cudagraph.

  3. NEVER fabricate facts from MEMORY.md. Verify what "works" means before citing it.

  4. NEVER quantize a padded buffer and slice the output. Quantize compact data, scatter into padded layout.

  5. Silent weight drops are deadly. vLLM's if name not in params_dict: continue skips weights with no warning. Replace with hard RuntimeError.

  6. NVFP4 is NOT suitable for attention Q×K^T. Per-element dot products are too sensitive. Keep attention in BF16.

  7. NEVER touch drivers, kernels, firmware, or system packages on the B200. The cluster costs millions. Always confirm with Mike.

  8. CuTeDSL runtime conditionals on register tensors are broken. Can't branch on runtime values when the branch affects tensor structure. Use const_expr flags or restructure the loop.
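A minimal sketch of lesson 5, with the silent continue replaced by a hard error. It also fails on a size mismatch and on parameters the checkpoint never provided, which goes beyond what the lesson states. The loader signature is hypothetical and is not vLLM's API. Plain lists stand in for tensors.

```python
from typing import Dict, Iterable, List, Set, Tuple


def load_weights_strict(
    params: Dict[str, List[float]],
    checkpoint: Iterable[Tuple[str, List[float]]],
) -> Set[str]:
    """Copy checkpoint values into params, failing loudly instead of skipping."""
    loaded: Set[str] = set()
    for name, value in checkpoint:
        if name not in params:
            # The silent variant would `continue` here and drop the weight.
            raise RuntimeError(f"checkpoint weight {name!r} has no matching parameter")
        if len(params[name]) != len(value):
            raise RuntimeError(f"size mismatch for {name!r}")
        params[name][:] = value
        loaded.add(name)
    missing = set(params) - loaded
    if missing:
        raise RuntimeError(f"parameters never loaded: {sorted(missing)}")
    return loaded
```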