# nvfp4-megamoe-kernel

Native NVFP4 block-scaled MoE kernel for DeepSeek-V4-Pro on NVIDIA Blackwell (SM100). It replaces the broken `fp8_nvfp4_mega_moe` kernel from DeepGEMM with a working CUTLASS-based implementation that emits real `SM100_MMA_MXF4_SS` tensor core instructions.

---

## Architecture

DeepSeek-V4-Pro is a 256-expert MoE model. It runs with expert parallelism across 8 ranks (B200 GPUs), so each rank handles 32 experts. For each token, the router picks the top-6 experts.

### The MoE Forward Pass

```
Input hidden states (BF16)
        │
        ▼
┌─────────────────┐
│  Shared Experts │  ← BYPASSED (returning zeros; FlashInfer TF32 GEMM crashes)
│  (FlashInfer    │
│   CUTLASS)      │
└─────────────────┘
        │
        ▼
Staging Kernel (vLLM built-in)
  BF16 → packed E2M1 (int8) + UE4M3 block-16 scales (uint32)
  Writes to SymmBuffer.x / SymmBuffer.x_sf
        │
        ▼
Router (vLLM built-in)
  Writes topk_ids / topk_weights to SymmBuffer
        │
        ▼
┌─────────────────────────────────────────┐
│  nvfp4_mega_moe_full                    │  ← nvfp4_mega_moe.py
│                                         │
│  1. Read staged activation from buffer  │
│  2. L1 GEMM: gate_up_proj               │  ← CUTLASS NVFP4 block-scaled
│     E2M1 × E2M1 + UE4M3 scales          │     SM100_MMA_MXF4_SS PTX
│     → BF16 output (6144-wide)           │
│  3. SiLU(gate) * up (activation)        │
│  4. stage_activation: BF16 → FP4        │  ← simple absmax quantize (needs work)
│  5. L2 GEMM: down_proj                  │  ← CUTLASS NVFP4 block-scaled
│     E2M1 × E2M1 + UE4M3 scales          │     SM100_MMA_MXF4_SS PTX
│     → BF16 output (7168-wide)           │
│  6. Write to output tensor              │
└─────────────────────────────────────────┘
```

### vLLM Startup Sequence (how our code plugs in)

```
1. vLLM engine init
   └─ ModelOptNvFp4Config selected (NVFP4 quantization scheme)
      └─ FlashInferCutlassNvFp4LinearKernel for linear layers

2. Model construction
   └─ DeepseekV4ForCausalLM → DeepseekV4DecoderLayer → DeepseekV4MoE
      Each layer has: attention + MoE block
      MoE block has: shared experts + 256 routed experts

3. Weight loading
   └─ 95 safetensor shards loaded
   └─ weight, weight_scale, weight_scale_2 loaded per linear

4. process_weights_after_loading   ← THIS IS WHERE WE HOOK IN
   └─ ModelOptNvFp4LinearMethod swizzles/pads weights for CUTLASS
   └─ finalize_mega_moe_weights()
      └─ weight_transform.py: transform_nvfp4_weights_for_mega_moe()
         • Folds weight_scale_2 (global scale) into weight_scale (block scale)
         • UE4M3 block-16 scales: 4 values packed per uint32
         • Interleaves L1 (gate_up) weights for 2CTA UMMA
         • Returns ((l1_w, l1_sf), (l2_w, l2_sf)) per rank

5. SymmBuffer allocation
   └─ symm_buffer.py: get_symm_buffer_for_nvfp4_mega_moe()
      • Pre-allocates GPU buffers for:
        - x: int8 packed E2M1 activations
        - x_sf: uint32 packed UE4M3 activation scales
        - topk_idx: int32 expert indices
        - topk_weights: float32 routing weights
        - buffer: BF16 all-reduce buffer

6. Profile run (warmup)
   └─ First forward pass to allocate KV cache, etc.
   └─ This is where the CUTLASS GEMM first executes

7. Ready to serve
```

---

## File Map

```
nvfp4_megamoe_kernel/
├── __init__.py              # Public API exports
├── nvfp4_mega_moe.py        # Main kernel: nvfp4_mega_moe_full, nvfp4_mega_moe_l1/l2, stage_activation
├── weight_transform.py      # Weight prep: fold global scale, pack UE4M3, interleave L1
├── symm_buffer.py           # GPU buffer allocation for MoE dispatch
│
└── cutlass_nvfp4_gemm/      # CUTLASS CUDA extension (the actual hardware kernel)
    ├── cutlass_nvfp4_gemm.cu   # CUDA: CUTLASS GEMM + SF remap kernel
    ├── pytorch_binding.cpp     # PyTorch C++ binding (_C.forward)
    ├── kernel.py               # Python: cutlass_grouped_nvfp4_gemm (per-expert loop)
    ├── sf_layout.py            # CUTLASS SF interleaved layout math
    ├── setup.py                # Build config (nvcc, CUTLASS include paths)
    ├── build.sh                # Build script
    ├── test_gemm.py            # Standalone test
    └── README.md
```

### What each file does (in call order)

| File | When it runs | What it does |
|------|--------------|--------------|
| `weight_transform.py` | Once at startup (weight loading) | Takes raw NVFP4 checkpoint weights, folds global scales into block scales, packs UE4M3 scales into uint32, and interleaves the L1 gate_up weights. Output: `((l1_w, l1_sf), (l2_w, l2_sf))` |
| `symm_buffer.py` | Once at startup (buffer allocation) | Pre-allocates GPU tensors for activations, scales, routing data, and the all-reduce buffer. These persist across forward passes. |
| `nvfp4_mega_moe.py` | Every forward pass | Orchestrates the MoE: reads from the symm buffer → L1 GEMM → activation → re-quantize → L2 GEMM → output. Contains `stage_activation` (BF16→FP4 quantization between L1 and L2). |
| `cutlass_nvfp4_gemm/kernel.py` | Every forward pass (called by `nvfp4_mega_moe`) | Per-expert loop: gathers the tokens for each expert, calls the CUTLASS GEMM, and scatters the results with routing weights. |
| `cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu` | Every forward pass (CUDA kernel) | The actual CUTLASS kernel: native NVFP4 block-scaled GEMM plus a GPU-side scale factor remap (row-major → CUTLASS interleaved layout). |
| `cutlass_nvfp4_gemm/sf_layout.py` | Build time / reference | Documents the CUTLASS SfAtom layout. Currently unused at runtime (the remap is done in CUDA). |

---

## Data Formats

### Weights

- **Packed E2M1** (`int8`): 2 FP4 values per byte. Shape: `(E_per_rank, N, K//2)`, K-major layout.
- **UE4M3 block scales** (`float8_e4m3fn`): 1 scale per 16 FP4 values (group_size=16). Shape: `(E_per_rank, N, K//16)`.

### Activations (after staging kernel)

- **Packed E2M1** (`int8`): Shape: `(num_tokens, K//2)`.
- **UE4M3 scales** (`uint32`): 4 UE4M3 values packed per uint32, so each uint32 covers 4 × 16 = 64 elements along K. Shape: `(num_tokens, K//64)`.
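The activation format above can be made concrete with a small pure-Python reference. This is a sketch, not the staging kernel: it assumes round-to-nearest for both E2M1 and UE4M3, no per-tensor global scale, element `2i` in the low nibble of each byte, and little-endian packing of the four scale bytes into each uint32. `quantize_row`, `dequantize_row`, and the other helpers are hypothetical names.

```python
# Hypothetical reference helpers; packing order and rounding are assumptions.
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # codes 0..7
BLOCK = 16  # FP4 values sharing one UE4M3 scale


def e4m3_decode(code: int) -> float:
    """Decode a non-negative FP8 E4M3 (fn) byte: bias 7, subnormals when exp == 0."""
    exp, man = (code >> 3) & 0xF, code & 0x7
    if exp == 0:
        return (man / 8.0) * 2.0 ** -6
    return (1.0 + man / 8.0) * 2.0 ** (exp - 7)


# Every finite non-negative E4M3 code (0x7F is NaN in the fn variant).
_E4M3_TABLE = [(c, e4m3_decode(c)) for c in range(0x7F)]


def e4m3_encode(x: float) -> int:
    """Round a non-negative float to the nearest E4M3 code, saturating at 448."""
    return min(_E4M3_TABLE, key=lambda cv: abs(cv[1] - x))[0]


def e2m1_encode(x: float) -> int:
    """Round to the nearest E2M1 value, saturating at +-6; bit 3 is the sign."""
    mag = min(abs(x), 6.0)
    idx = min(range(8), key=lambda i: abs(E2M1_MAGNITUDES[i] - mag))
    return idx | (0x8 if x < 0 and idx else 0)


def quantize_row(row: list[float]) -> tuple[list[int], list[int]]:
    """Quantize one token row of length K (K % 64 == 0).

    Returns K//2 packed bytes (two E2M1 codes each) and
    K//64 uint32 words (four UE4M3 scale bytes each).
    """
    k = len(row)
    assert k % 64 == 0, "K must be a multiple of 64 (4 scales x 16 values)"
    codes: list[int] = []
    scale_bytes: list[int] = []
    for start in range(0, k, BLOCK):
        block = row[start:start + BLOCK]
        amax = max(abs(v) for v in block)
        sf = e4m3_encode(amax / 6.0)  # map the block absmax onto E2M1 max (6.0)
        scale = e4m3_decode(sf)
        scale_bytes.append(sf)
        if scale > 0:
            codes.extend(e2m1_encode(v / scale) for v in block)
        else:
            codes.extend([0] * len(block))
    # Two FP4 codes per byte: element 2i in the low nibble, 2i+1 in the high nibble.
    packed = [codes[i] | (codes[i + 1] << 4) for i in range(0, k, 2)]
    # Four scale bytes per uint32, little-endian (block 4j+b lands in byte b).
    words = [
        sum(scale_bytes[j + b] << (8 * b) for b in range(4))
        for j in range(0, len(scale_bytes), 4)
    ]
    return packed, words


def dequantize_row(packed: list[int], words: list[int]) -> list[float]:
    """Inverse of quantize_row, up to quantization error."""
    scale_bytes = [(w >> (8 * b)) & 0xFF for w in words for b in range(4)]
    out: list[float] = []
    for i, byte in enumerate(packed):
        scale = e4m3_decode(scale_bytes[(2 * i) // BLOCK])
        for code in (byte & 0xF, byte >> 4):
            mag = E2M1_MAGNITUDES[code & 0x7]
            out.append((-mag if code & 0x8 else mag) * scale)
    return out
```

For a row of 64 ones, each block scale rounds from 1/6 to the E4M3 value 0.171875 (code `0x23`) and every element becomes E2M1 code 7 (6.0), so the row dequantizes to 1.03125: the per-block scale rounding is where most of the error comes from.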
### GEMM dimensions (DeepSeek-V4-Pro)

Dimensions are M×N×K per expert, where M is the number of tokens routed to that expert.

- **L1 (gate_up_proj):** M×6144×7168 (per expert)
- **L2 (down_proj):** M×7168×3072 (per expert)
- 32 experts per rank (256 total / 8 ranks), top-6 routing

---

## Build & Deploy (B200)

```bash
# On the B200 host: CUTLASS must be cloned and mounted
cd /root/nvidia-meeting/deepseek-v4-quant/

# Rebuild the container (CUTLASS is host-mounted at /root/cutlass)
KERNEL_CACHE_BUSTER=$(date +%s) docker compose build --no-cache
docker compose up -d
```

The CUTLASS extension builds inside the container during `pip install` of the nvfp4-megamoe-kernel package. It needs:

- CUDA 13.0 toolkit (in the `vllm/vllm-openai:nightly` image)
- CUTLASS headers at `/root/cutlass/include/`
- CCCL headers at `/usr/local/cuda-13.0/targets/x86_64-linux/include/cccl/`
- A device with SM100 compute capability (B200)

---

## Known Issues

1. **Shared experts bypassed:** the FlashInfer/DeepGEMM TF32 GEMM crashes the vLLM worker, so the shared expert output is currently replaced with zeros. This produces garbage text.
2. **MoE dispatch is slow:** `cutlass_grouped_nvfp4_gemm` uses a Python loop over the 32 local experts with per-token scatter/gather. It needs a proper grouped GEMM, or at least CUDA-side dispatch.
3. **stage_activation is approximate:** it uses simple per-token absmax quantization for the L1→L2 re-quantization. It should use proper block-16 E2M1 + UE4M3 quantization matching vLLM's staging kernel.
4. **Scale factor remap adds overhead:** a GPU kernel remaps scale factors from row-major to the CUTLASS interleaved layout on every GEMM call. This should be precomputed during the weight transform.

---

## Environment Variables

| Variable | Default | Description |
|----------|---------|-------------|
| `MEGA_MOE_STATIC` | 0 | Set to 1 to skip the MoE kernel entirely (returns zeros) |
| `MEGA_MOE_DEBUG` | 0 | Set to 1 for verbose logging |
| `MEGA_MOE_USE_CUTLASS` | 1 | Use the CUTLASS path (always 1 now; TileLang was removed) |
| `SKIP_ATTENTION` | 0 | Skip attention layers (debug) |
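To show where the `MEGA_MOE_STATIC` short-circuit and the slow per-expert loop from the Known Issues sit relative to each other, here is a plain-Python sketch of the gather → GEMM → scatter dispatch. It is hypothetical throughout: `moe_dispatch` and `expert_fn` are illustrative names, `expert_fn` stands in for the per-expert CUTLASS L1/L2 GEMMs, and global expert ids are assumed to be assigned to ranks in contiguous blocks of `experts_per_rank`.

```python
import os
from typing import Callable, Sequence

Row = list[float]


def moe_dispatch(
    x: Sequence[Row],                         # (num_tokens, hidden)
    topk_ids: Sequence[Sequence[int]],        # (num_tokens, top_k) global expert ids
    topk_weights: Sequence[Sequence[float]],  # (num_tokens, top_k) routing weights
    expert_fn: Callable[[int, list[Row]], list[Row]],  # hypothetical GEMM stand-in
    rank: int,
    experts_per_rank: int,
) -> list[Row]:
    """Return this rank's partial MoE output (hypothetical sketch)."""
    hidden = len(x[0]) if x else 0
    out = [[0.0] * hidden for _ in x]
    # Assumed handling of MEGA_MOE_STATIC=1: skip the MoE and contribute zeros.
    if os.environ.get("MEGA_MOE_STATIC", "0") == "1":
        return out
    first = rank * experts_per_rank  # assumed contiguous expert-to-rank mapping
    for local in range(experts_per_rank):
        expert = first + local
        # Gather: (token index, routing weight) for every token routed here.
        hits = [
            (t, w)
            for t, (ids, ws) in enumerate(zip(topk_ids, topk_weights))
            for e, w in zip(ids, ws)
            if e == expert
        ]
        if not hits:
            continue
        ys = expert_fn(local, [list(x[t]) for t, _ in hits])
        # Scatter: weighted accumulate back into each token's output row.
        for (t, w), y in zip(hits, ys):
            out[t] = [o + w * v for o, v in zip(out[t], y)]
    return out
```

Each rank only accumulates the contributions of its own experts; the partial outputs still have to be summed across ranks, which is what the BF16 all-reduce `buffer` in the SymmBuffer is described as being for. The Python-level gather and scatter in this shape of loop is the overhead a grouped GEMM or CUDA-side dispatch would remove.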