2026-05-14 12:48:08 +00:00
# nvfp4-megamoe-kernel

Native NVFP4 block-scaled MoE kernel for DeepSeek-V4-Pro on NVIDIA Blackwell (SM100).

Replaces the broken `fp8_nvfp4_mega_moe` kernel from DeepGEMM with a working CUTLASS-based implementation that emits real `SM100_MMA_MXF4_SS` tensor core instructions.

---
## Architecture

DeepSeek-V4-Pro is a 256-expert MoE model with expert parallelism across 8 ranks (B200 GPUs). Each rank handles 32 experts. For each token, the router picks the top-6 experts.

### The MoE Forward Pass
```
Input hidden states (BF16)
│
▼
┌─────────────────┐
│ Shared Experts │ ← BYPASSED (returning zeros — FlashInfer TF32 GEMM crashes)
│ (FlashInfer │
│ CUTLASS) │
└─────────────────┘
│
▼
Staging Kernel (vLLM built-in)
BF16 → packed E2M1 (int8) + UE4M3 block-16 scales (uint32)
Writes to SymmBuffer.x / SymmBuffer.x_sf
│
▼
Router (vLLM built-in)
Writes topk_ids / topk_weights to SymmBuffer
│
▼
┌─────────────────────────────────────────┐
│ nvfp4_mega_moe_full │ ← nvfp4_mega_moe.py
│ │
│ 1. Read staged activation from buffer │
│ 2. L1 GEMM: gate_up_proj │ ← CUTLASS NVFP4 block-scaled
│ E2M1 × E2M1 + UE4M3 scales │ SM100_MMA_MXF4_SS PTX
│ → BF16 output (6144-wide) │
│ 3. SiLU(gate) * up (activation) │
│ 4. stage_activation: BF16 → FP4 │ ← simple absmax quantize (needs work)
│ 5. L2 GEMM: down_proj │ ← CUTLASS NVFP4 block-scaled
│ E2M1 × E2M1 + UE4M3 scales │ SM100_MMA_MXF4_SS PTX
│ → BF16 output (7168-wide) │
│ 6. Write to output tensor │
└─────────────────────────────────────────┘
```
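
You can cross-check the routed-expert math in the diagram (steps 2 to 6) against a plain floating-point reference. The sketch below uses NumPy and makes two assumptions. First, activations and weights are already dequantized. Second, `gate_up` stores all gate rows before all up rows. The real L1 weights are interleaved for 2CTA UMMA, so the second point is only an assumption. `moe_reference` is a hypothetical helper and is not part of `nvfp4_mega_moe.py`:

```python
import numpy as np

def silu(x: np.ndarray) -> np.ndarray:
    """SiLU / swish: x * sigmoid(x)."""
    return x / (1.0 + np.exp(-x))

def moe_reference(x, w1, w2, topk_ids, topk_weights):
    """Dequantized reference for the routed-expert part of the forward pass.

    x:            (T, H) activations, already dequantized to float
    w1:           (E, 2*I, H) gate_up weights, gate rows first (assumed order)
    w2:           (E, H, I) down_proj weights
    topk_ids:     (T, k) local expert ids
    topk_weights: (T, k) routing weights
    """
    T, H = x.shape
    inter = w2.shape[2]
    out = np.zeros((T, H), dtype=np.float32)
    for e in range(w1.shape[0]):
        # Gather: every (token, slot) pair routed to expert e.
        tok, slot = np.nonzero(topk_ids == e)
        if tok.size == 0:
            continue
        h = x[tok] @ w1[e].T                     # L1 GEMM -> (m, 2*I)
        act = silu(h[:, :inter]) * h[:, inter:]  # SiLU(gate) * up
        y = act @ w2[e].T                        # L2 GEMM -> (m, H)
        # Scatter-add, weighted by the router weight of that top-k slot.
        np.add.at(out, tok, topk_weights[tok, slot][:, None] * y)
    return out
```

Running the CUTLASS path and this reference on the same dequantized inputs, then comparing outputs, separates GEMM/layout bugs from quantization error.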
### vLLM Startup Sequence (how our code plugs in)
```
1. vLLM engine init
└─ ModelOptNvFp4Config selected (NVFP4 quantization scheme)
└─ FlashInferCutlassNvFp4LinearKernel for linear layers
2. Model construction
   └─ DeepseekV4ForCausalLM → DeepseekV4DecoderLayer → DeepseekV4MoE
Each layer has: attention + MoE block
MoE block has: shared experts + 256 routed experts
3. Weight loading
└─ 95 safetensor shards loaded
└─ weight, weight_scale, weight_scale_2 loaded per linear
4. process_weights_after_loading ← THIS IS WHERE WE HOOK IN
└─ ModelOptNvFp4LinearMethod swizzles/pads weights for CUTLASS
└─ finalize_mega_moe_weights()
└─ weight_transform.py: transform_nvfp4_weights_for_mega_moe()
• Folds weight_scale_2 (global scale) into weight_scale (block scale)
• UE4M3 block-16 scales: 4 values packed per uint32
• Interleaves L1 (gate_up) weights for 2CTA UMMA
• Returns ((l1_w, l1_sf), (l2_w, l2_sf)) per rank
5. SymmBuffer allocation
   └─ symm_buffer.py: get_symm_buffer_for_nvfp4_mega_moe()
• Pre-allocates GPU buffers for:
- x: int8 packed E2M1 activations
- x_sf: uint32 packed UE4M3 activation scales
- topk_idx: int32 expert indices
- topk_weights: float32 routing weights
- buffer: BF16 all-reduce buffer
6. Profile run (warmup)
└─ First forward pass to allocate KV cache, etc.
└─ This is where the CUTLASS GEMM first executes
7. Ready to serve
```
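
Step 4 folds `weight_scale_2` into `weight_scale`. The sketch below shows what that folding involves. It assumes the usual ModelOpt NVFP4 convention, where dequantization is `fp4 * weight_scale * weight_scale_2`, with `weight_scale` given as decoded UE4M3 floats and `weight_scale_2` as a scalar FP32 global scale. The product has to be re-rounded to UE4M3, presumably because the kernel consumes only 8-bit block scales. The function names are hypothetical and are not the ones in `weight_transform.py`:

```python
import numpy as np

def _e4m3_grid() -> np.ndarray:
    """All non-negative finite float8_e4m3fn values (bias 7, max 448), ascending."""
    vals = []
    for e in range(16):
        for m in range(8):
            if e == 15 and m == 7:
                continue  # e4m3fn uses this encoding for NaN
            if e == 0:
                vals.append(m / 8 * 2.0 ** -6)             # subnormals
            else:
                vals.append((1 + m / 8) * 2.0 ** (e - 7))  # normals
    return np.array(vals, dtype=np.float64)

E4M3_GRID = _e4m3_grid()

def round_to_e4m3(x):
    """Round non-negative values to the nearest UE4M3 value.

    Saturates at 448; exact ties go to the smaller value.
    """
    x = np.clip(np.asarray(x, dtype=np.float64), 0.0, 448.0)
    idx = np.clip(np.searchsorted(E4M3_GRID, x), 1, len(E4M3_GRID) - 1)
    lo, hi = E4M3_GRID[idx - 1], E4M3_GRID[idx]
    return np.where(x - lo <= hi - x, lo, hi)

def fold_global_scale(weight_scale, weight_scale_2):
    """Fold a per-tensor global scale into block-16 scales, re-rounded to UE4M3."""
    return round_to_e4m3(np.asarray(weight_scale, np.float64) * float(weight_scale_2))
```

Folded scales that fall below half of the smallest E4M3 subnormal (2^-9) round to zero. It is worth checking how many blocks hit that case for a given checkpoint.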

---
## File Map
```
nvfp4_megamoe_kernel/
├── __init__.py # Public API exports
├── nvfp4_mega_moe.py # Main kernel: nvfp4_mega_moe_full, nvfp4_mega_moe_l1/l2, stage_activation
├── weight_transform.py # Weight prep: fold global scale, pack UE4M3, interleave L1
├── symm_buffer.py # GPU buffer allocation for MoE dispatch
│
└── cutlass_nvfp4_gemm/ # CUTLASS CUDA extension (the actual hardware kernel)
├── cutlass_nvfp4_gemm.cu # CUDA: CUTLASS GEMM + SF remap kernel
├── pytorch_binding.cpp # PyTorch C++ binding (_C.forward)
├── kernel.py # Python: cutlass_grouped_nvfp4_gemm (per-expert loop)
├── sf_layout.py # CUTLASS SF interleaved layout math
├── setup.py # Build config (nvcc, CUTLASS include paths)
├── build.sh # Build script
├── test_gemm.py # Standalone test
└── README.md
```
### What each file does (in call order)
| File | When it runs | What it does |
|------|-------------|--------------|
| `weight_transform.py` | Once at startup (weight loading) | Takes raw NVFP4 checkpoint weights, folds global scales into block scales, packs UE4M3 into uint32, interleaves L1 gate_up weights. Output: `((l1_w, l1_sf), (l2_w, l2_sf))` |
| `symm_buffer.py` | Once at startup (buffer alloc) | Pre-allocates GPU tensors for activations, scales, routing data, and all-reduce. These persist across forward passes. |
| `nvfp4_mega_moe.py` | Every forward pass | Orchestrates the MoE: reads from symm buffer → L1 GEMM → activation → re-quantize → L2 GEMM → output. Contains `stage_activation` (BF16→FP4 quantize for L1→L2). |
| `cutlass_nvfp4_gemm/kernel.py` | Every forward pass (called by nvfp4_mega_moe) | Per-expert loop: gather tokens for each expert, call CUTLASS GEMM, scatter results with routing weights. |
| `cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu` | Every forward pass (CUDA kernel) | The actual CUTLASS kernel: native NVFP4 block-scaled GEMM + GPU-side scale factor remap (row-major → CUTLASS interleaved layout). |
| `cutlass_nvfp4_gemm/sf_layout.py` | Build time / reference | Documents the CUTLASS SfAtom layout. Currently unused at runtime (remap is in CUDA). |
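
The row-major → interleaved remap that `cutlass_nvfp4_gemm.cu` performs (and `sf_layout.py` documents) reduces to a closed-form byte offset. The pure-Python sketch below rests on our reading of CUTLASS's SM100 K-major scale-factor atom:

- Scales are grouped into tiles of 128 rows × 4 scales, 512 bytes per tile.
- Rows are split as (32, 4) with strides (16, 4).
- The four scales in a row are contiguous.
- Tiles are ordered K-fastest.

Verify this against `sf_layout.py` before relying on it. `sf_offset` and `remap_row_major` are hypothetical names:

```python
def sf_offset(row: int, k_sf: int, num_k_sf: int) -> int:
    """Byte offset of scale (row, k_sf) in the assumed interleaved SF layout.

    num_k_sf is the number of scale columns (K // 16), padded to a multiple of 4.
    """
    assert num_k_sf % 4 == 0
    tiles_k = num_k_sf // 4
    tile = (row // 128) * tiles_k + (k_sf // 4)       # tiles laid out K-fastest
    r = row % 128
    inner = (r % 32) * 16 + (r // 32) * 4 + (k_sf % 4)  # (32, 4) rows x 4 scales
    return tile * 512 + inner

def remap_row_major(sf_rows, num_rows_padded: int, num_k_sf: int) -> bytes:
    """Scatter row-major scale bytes into a zero-padded interleaved buffer.

    num_rows_padded must be a multiple of 128 (one SF tile of rows).
    """
    assert num_rows_padded % 128 == 0
    out = bytearray(num_rows_padded * num_k_sf)
    for m, row in enumerate(sf_rows):
        for k, byte in enumerate(row):
            out[sf_offset(m, k, num_k_sf)] = byte
    return bytes(out)
```

The offset depends only on tensor shapes. That means the weight-side remap could run once in `weight_transform.py` instead of on every GEMM call.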

---
## Data Formats
### Weights
- **Packed E2M1** (`int8`): 2 FP4 values per byte. Shape: `(E_per_rank, N, K//2)`, K-major layout.
- **UE4M3 block scales** (`float8_e4m3fn`): 1 scale per 16 FP4 values (group_size=16). Shape: `(E_per_rank, N, K//16)`.
### Activations (after staging kernel)
- **Packed E2M1** (`int8`): 2 FP4 values per byte. Shape: `(num_tokens, K//2)`.
- **UE4M3 scales** (`uint32`): 4 UE4M3 values packed per uint32, so one uint32 covers 64 activation values. Shape: `(num_tokens, K//64)`.
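
The NumPy sketch below makes these byte layouts concrete. It encodes already-scaled values to E2M1 codes, packs two codes per byte, and packs four UE4M3 scale bytes per uint32. Two orderings are assumptions to check against the staging kernel:

- Nibble order: element 2i goes in the low nibble.
- Byte order: the first scale goes in the lowest byte.

The helper names are hypothetical:

```python
import numpy as np

# E2M1 magnitudes indexed by the 3-bit exponent:mantissa code; bit 3 is the sign.
E2M1_VALUES = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def encode_e2m1(x):
    """Map values (already divided by their block scale) to 4-bit E2M1 codes.

    Nearest magnitude wins, ties go to the smaller magnitude, and values
    above 6 saturate to 6.
    """
    x = np.asarray(x, dtype=np.float64)
    mag = np.abs(x)
    codes = np.argmin(np.abs(mag[..., None] - E2M1_VALUES), axis=-1).astype(np.uint8)
    return codes | (np.signbit(x).astype(np.uint8) << 3)

def pack_e2m1(codes):
    """Pack pairs of 4-bit codes into int8 bytes, element 2i in the low nibble (assumed)."""
    codes = np.asarray(codes, dtype=np.uint8)
    return (codes[..., 0::2] | (codes[..., 1::2] << 4)).view(np.int8)

def pack_ue4m3_scales(scale_bytes):
    """Pack 4 UE4M3 scale bytes per little-endian uint32, first scale lowest (assumed)."""
    b = np.ascontiguousarray(scale_bytes, dtype=np.uint8)
    return b.view("<u4")
```

For K = 7168 this gives 3584 packed bytes and 112 uint32 scale words per token.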
### GEMM dimensions (DeepSeek-V4-Pro)
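- **L1 (gate_up_proj):** M × N × K = M × 6144 × 7168 per expert, where M is the number of tokens routed to that expert and N = 2 × 3072 (gate + up)
- **L2 (down_proj):** M × N × K = M × 7168 × 3072 per expert
- 32 experts per rank (256 total / 8 ranks), top-6 routing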
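
The arithmetic below gives a feel for the sizes involved. It counts routed-expert weight bytes (packed E2M1 plus one UE4M3 byte per 16 values) and GEMM FLOPs. It assumes 32 experts per rank (256 / 8) and ignores shared experts and any padding the CUTLASS scale-factor layout requires:

```python
HIDDEN, INTER = 7168, 3072        # from the shapes above (6144 = 2 * 3072)
EXPERTS, RANKS, TOPK = 256, 8, 6

def nvfp4_bytes(n: int, k: int) -> int:
    """Packed E2M1 bytes (K/2 per row) plus one UE4M3 scale byte per 16 values."""
    return n * k // 2 + n * k // 16

l1_bytes = nvfp4_bytes(2 * INTER, HIDDEN)   # gate_up_proj, one expert
l2_bytes = nvfp4_bytes(HIDDEN, INTER)       # down_proj, one expert
per_rank = (EXPERTS // RANKS) * (l1_bytes + l2_bytes)

def gemm_flops(m: int) -> int:
    """FLOPs (2*M*N*K) of one expert's L1 + L2 GEMMs on m routed rows."""
    return 2 * m * (2 * INTER) * HIDDEN + 2 * m * HIDDEN * INTER

def avg_rows_per_expert(num_tokens: int) -> float:
    """Expected rows per expert if routing were perfectly balanced."""
    return num_tokens * TOPK / EXPERTS

print(f"L1 per expert: {l1_bytes / 2**20:.3f} MiB")
print(f"L2 per expert: {l2_bytes / 2**20:.3f} MiB")
print(f"Per rank, per MoE layer: {per_rank / 2**30:.3f} GiB")
```

This works out to about 23.6 MiB per expert for L1, 11.8 MiB for L2, and roughly 1.11 GiB of routed-expert weights per rank for each MoE layer. With perfectly balanced routing, a batch of 256 tokens gives each expert only about 6 rows.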

---
## Build & Deploy (B200)
```bash
# On B200 host: CUTLASS must be cloned and mounted
cd /root/nvidia-meeting/deepseek-v4-quant/

# Rebuild container (CUTLASS is host-mounted at /root/cutlass)
KERNEL_CACHE_BUSTER=$(date +%s) docker compose build --no-cache
docker compose up -d
```
The CUTLASS extension builds inside the container during `pip install` of the nvfp4-megamoe-kernel package. It needs:
- CUDA 13.0 toolkit (in the vllm/vllm-openai:nightly image)
- CUTLASS headers at `/root/cutlass/include/`
- CCCL headers at `/usr/local/cuda-13.0/targets/x86_64-linux/include/cccl/`
- Device with SM100 compute capability (B200)

---
## Known Issues
1. **Shared experts bypassed**: the FlashInfer/DeepGEMM TF32 GEMM crashes the vLLM worker, so the shared-expert output is currently replaced with zeros. As a result, generated text is garbage.

2. **MoE dispatch is slow**: `cutlass_grouped_nvfp4_gemm` loops over the per-rank experts in Python with per-token scatter/gather. It needs a proper grouped GEMM or, at minimum, CUDA-side dispatch.

3. **stage_activation is approximate**: it uses simple per-token absmax quantization for the L1→L2 re-quantization. It should use proper E2M1 quantization matching vLLM's staging kernel.

4. **Scale factor remap adds overhead**: a GPU kernel remaps scale factors from row-major to the CUTLASS interleaved layout on every GEMM call. Weight scale factors should be remapped once during the weight transform. Activation scales change every forward pass, so avoiding their remap would mean having the staging kernel write the interleaved layout directly.
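
For issue 2, the usual first step toward a grouped GEMM is to replace per-expert masks with a single sort by expert id. That sort yields contiguous per-expert row ranges plus an offsets array. The NumPy sketch below assumes experts are partitioned contiguously across ranks, so rank r owns experts `[32r, 32r + 32)`. `group_by_expert` is hypothetical and is not part of `kernel.py`:

```python
import numpy as np

def group_by_expert(topk_ids, rank: int, experts_per_rank: int):
    """Group (token, slot) pairs by local expert with one stable sort.

    Returns (rows, slots, offsets). rows[offsets[e]:offsets[e + 1]] are the token
    indices routed to local expert e. slots gives the top-k slot each pair came
    from, which is needed to look up its routing weight on scatter.
    """
    ids = np.asarray(topk_ids)
    local = ids - rank * experts_per_rank       # assumes contiguous partition
    mask = (local >= 0) & (local < experts_per_rank)
    tok, slot = np.nonzero(mask)                # only experts owned by this rank
    exp = local[tok, slot]
    order = np.argsort(exp, kind="stable")      # keeps token order within an expert
    counts = np.bincount(exp, minlength=experts_per_rank)
    offsets = np.concatenate(([0], np.cumsum(counts)))
    return tok[order], slot[order], offsets
```

On the GPU the same steps map onto `torch.argsort(..., stable=True)`, `torch.bincount`, and `torch.cumsum`. The resulting offsets are the per-group problem sizes that a grouped GEMM consumes.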

---
## Environment Variables

| Variable | Default | Description |
|----------|---------|-------------|
| `MEGA_MOE_STATIC` | 0 | Set to 1 to skip the MoE kernel entirely (returns zeros) |
| `MEGA_MOE_DEBUG` | 0 | Set to 1 for verbose logging |
| `MEGA_MOE_USE_CUTLASS` | 1 | Use the CUTLASS path (always 1 now; the TileLang path was removed) |
| `SKIP_ATTENTION` | 0 | Skip attention layers (debug) |
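
These flags come from the process environment. The sketch below parses them in one place. It assumes only the literal string `1` enables a flag, which may differ from how the package actually parses them. `MegaMoeFlags` is a hypothetical name:

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional

def _flag(env: Mapping[str, str], name: str, default: str) -> bool:
    """A flag is on only when its value is exactly "1" (assumed convention)."""
    return env.get(name, default).strip() == "1"

@dataclass(frozen=True)
class MegaMoeFlags:
    static: bool          # MEGA_MOE_STATIC: skip the MoE kernel, return zeros
    debug: bool           # MEGA_MOE_DEBUG: verbose logging
    use_cutlass: bool     # MEGA_MOE_USE_CUTLASS: CUTLASS path (default on)
    skip_attention: bool  # SKIP_ATTENTION: debug-only attention bypass

    @classmethod
    def from_env(cls, env: Optional[Mapping[str, str]] = None) -> "MegaMoeFlags":
        env = os.environ if env is None else env
        return cls(
            static=_flag(env, "MEGA_MOE_STATIC", "0"),
            debug=_flag(env, "MEGA_MOE_DEBUG", "0"),
            use_cutlass=_flag(env, "MEGA_MOE_USE_CUTLASS", "1"),
            skip_attention=_flag(env, "SKIP_ATTENTION", "0"),
        )
```

Reading the flags once at import time, rather than calling `os.environ.get` inside the forward pass, keeps per-step overhead out of the hot path.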