# nvfp4-megamoe-kernel

Native NVFP4 block-scaled MoE kernel for DeepSeek-V4-Pro on NVIDIA Blackwell (SM100). Replaces the broken `fp8_nvfp4_mega_moe` kernel from DeepGEMM with a working CUTLASS-based implementation that emits real `SM100_MMA_MXF4_SS` tensor core instructions.

---

## Architecture

DeepSeek-V4-Pro is a 384-expert MoE model with expert parallelism across 8 ranks (B200 GPUs). Each rank handles 48 experts. For each token, the router picks the top-6 experts.

### The MoE Forward Pass

```
Input hidden states (BF16)
        │
        ▼
┌─────────────────┐
│ Shared Experts  │  ← vLLM native FlashInfer CUTLASS NVFP4 path
│ (gate + up →    │    (not our kernel)
│  SiLU * up →    │
│  down)          │
└─────────────────┘
        │
        ▼
Staging Kernel (vLLM built-in)
  BF16 → packed E2M1 (int8) + UE4M3 block-16 scales (uint32)
  Writes to SymmBuffer.x / SymmBuffer.x_sf
        │
        ▼
Router (vLLM built-in)
  Writes topk_ids / topk_weights to SymmBuffer
        │
        ▼
┌─────────────────────────────────────────────────┐
│ nvfp4_mega_moe_full                             │  ← nvfp4_mega_moe.py
│                                                 │
│ 1. Read staged activation from buffer           │
│ 2. Build slot mapping (token, topk) → local     │
│    expert, routing weight                       │
│ 3. L1 GEMM: gate_up_proj (slot-based)           │  ← CUTLASS NVFP4 block-scaled
│    E2M1 × E2M1 + UE4M3 scales                   │    SM100_MMA_MXF4_SS PTX
│    → BF16 per-slot output (6144-wide)           │
│ 4. SiLU(gate) * up PER SLOT                     │
│    (nonlinearity before combining paths)        │
│ 5. stage_activation: BF16 → FP4                 │  ← proper E2M1 quantization
│ 6. L2 GEMM: down_proj (slot-based)              │  ← CUTLASS NVFP4 block-scaled
│    E2M1 × E2M1 + UE4M3 scales                   │    SM100_MMA_MXF4_SS PTX
│    → BF16 per-slot output (7168-wide)           │
│ 7. Final scatter:                               │
│    y.index_add_(slot_token,                     │
│      slot_weight * l2_slots)                    │
│    Routing weight applied ONCE at scatter       │
└─────────────────────────────────────────────────┘
        │
        ▼
Cross-rank all-reduce (vLLM built-in)
```

### Slot-Based Dispatch

The kernel uses a **slot representation** instead of collapsing expert outputs early. A slot is one `(token, topk_expert)` pair.
For a batch of T tokens with top-6 routing, there are up to 6T slots (fewer if some experts are out of the local rank's range).

**Why slots?** Two bugs in the previous approach:

1. **SiLU after summing is mathematically wrong.** `silu(Σ wᵢ·gateᵢ) * (Σ wᵢ·upᵢ) ≠ Σ wᵢ·silu(gateᵢ)·upᵢ`. The nonlinearity must happen per-expert-path before combining.
2. **Routing weights applied twice.** The old grouped GEMM applied `topk_weights` in its scatter loop, and was called for both L1 and L2, which squared the weights.

The slot approach fixes both: SiLU+Mul happens per-slot, and routing weights are applied exactly once at the final `index_add_` scatter.

### SFB (Weight Scale Factors): Remapped Per-Call, NOT Cached

Weight scale factors (SFB) are remapped from row-major to CUTLASS interleaved layout on every GEMM call. This is a lightweight scatter kernel (~µs) and is NOT the bottleneck compared to the GEMM itself.

⚠️ **DO NOT ADD A PREPACK CACHE FOR SFB.** Previous attempts caused critical issues:

| Problem | Impact |
|---------|--------|
| **OOM** | ~1.75 GiB per prepacked tensor × 61 MoE layers × 2 (L1+L2) = ~214 GiB, which exceeds B200 capacity |
| **Peak memory 2×** | `torch.stack` held all expert tensors + final stack simultaneously before LRU eviction |
| **CUDA graph trap** | LRU eviction frees tensors that CUDA graphs still reference → use-after-free → silent corruption or crash |
| **M-dependent layout** | `prepack_sfb(M=128)` assumed SFB layout size is M-independent (never verified). If wrong, entire prepack is invalid |
| **Cross-layer cache collision** | Tag-based cache (`"l1"`/`"l2"`) returned layer N-1's data for layer N. Fixed with data_ptr key, but the cache itself was the root problem |

The per-call remap costs microseconds. The cache cost was hours of debugging. Don't repeat this mistake.

---

### vLLM Startup Sequence (how our code plugs in)

```
1. vLLM engine init
   └─ ModelOptNvFp4Config selected (NVFP4 quantization scheme)
   └─ FlashInferCutlassNvFp4LinearKernel for linear layers

2. Model construction
   └─ DeepseekV4ForCausalLM → DeepseekV4MoE → DeepseekV4DecoderLayer
      Each layer has: attention + MoE block
      MoE block has: shared experts + 384 routed experts

3. Weight loading
   └─ 95 safetensor shards loaded
   └─ weight, weight_scale, weight_scale_2 loaded per linear

4. process_weights_after_loading  ← THIS IS WHERE WE HOOK IN
   └─ ModelOptNvFp4LinearMethod swizzles/pads weights for CUTLASS
   └─ finalize_mega_moe_weights()
      └─ weight_transform.py: transform_nvfp4_weights_for_mega_moe()
         • Folds weight_scale_2 (global scale) into weight_scale (block scale)
         • UE4M3 block-16 scales: kept as float8_e4m3fn (not packed uint32)
         • Returns ((l1_w, l1_sf), (l2_w, l2_sf)) per rank

5. SymmBuffer allocation
   └─ symm_buffer.py: get_symm_buffer_for_nvfp4_mega_moe()
      • Pre-allocates GPU buffers for:
        - x: int8 packed E2M1 activations
        - x_sf: uint32 packed UE4M3 activation scales
        - topk_idx: int32 expert indices
        - topk_weights: float32 routing weights
        - buffer: BF16 all-reduce buffer

6. Profile run (warmup)
   └─ First forward pass to allocate KV cache, etc.
   └─ This is where the CUTLASS GEMM first executes
   └─ SFB weight scales remapped per-expert inside CUTLASS (no cache)

7. Ready to serve
```

---

## File Map

```
nvfp4_megamoe_kernel/
├── __init__.py              # Public API exports
├── nvfp4_mega_moe.py        # Main kernel: nvfp4_mega_moe_full, L1/L2, stage_activation
├── weight_transform.py      # Weight prep: fold global scale into UE4M3 block scales
├── symm_buffer.py           # GPU buffer allocation for MoE dispatch
│
└── cutlass_nvfp4_gemm/      # CUTLASS CUDA extension (the actual hardware kernel)
    ├── cutlass_nvfp4_gemm.cu  # CUDA: CUTLASS GEMM + SF remap + prepack SFB + prepacked-SFB GEMM path
    ├── pytorch_binding.cpp    # PyTorch C++ binding (forward, forward_prepacked_sfb, prepack_sfb)
    ├── kernel.py              # Python: cutlass_grouped_nvfp4_gemm (slot-based, per-expert loop)
    ├── sf_layout.py           # CUTLASS SF layout reference docs
    ├── setup.py               # Build config (nvcc, CUTLASS include paths)
    ├── build.sh               # Build script
    ├── test_gemm.py           # Standalone test
    └── README.md
```

### What each file does (in call order)

| File | When it runs | What it does |
|------|-------------|--------------|
| `weight_transform.py` | Once at startup (weight loading) | Takes raw NVFP4 checkpoint weights, folds global scales into block scales. Returns scales as `float8_e4m3fn` (not packed uint32). Output: `((l1_w, l1_sf), (l2_w, l2_sf))` |
| `symm_buffer.py` | Once at startup (buffer alloc) | Pre-allocates GPU tensors for activations, scales, routing data, and all-reduce. These persist across forward passes. |
| `nvfp4_mega_moe.py` | Every forward pass | Orchestrates the MoE: reads from symm buffer → build slot mapping → L1 GEMM → SiLU+Mul per-slot → re-quantize → L2 GEMM → final index_add_ scatter with routing weights. Contains `stage_activation` (BF16→FP4) and `unpack_ue4m3_u32`. NO prepack cache: SFB is remapped per-call inside CUTLASS. |
| `cutlass_nvfp4_gemm/kernel.py` | Every forward pass (called by nvfp4_mega_moe) | Slot-based per-expert loop: gather slots for each expert, call CUTLASS GEMM (SFB remapped inside C extension), write results to slot buffer. No routing weights: the caller handles the scatter. |
| `cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu` | Every forward pass (CUDA kernel) | The actual CUTLASS kernel: native NVFP4 block-scaled GEMM + GPU-side SFA and SFB remap. |
| `cutlass_nvfp4_gemm/sf_layout.py` | Reference only | Documents the CUTLASS SfAtom layout. Not used at runtime (remap is in CUDA). |

---

## Data Formats

### Weights

- **Packed E2M1** (`int8`): 2 FP4 values per byte. Shape: `(E_per_rank, N, K//2)`, K-major layout.
- **UE4M3 block scales** (`float8_e4m3fn`): 1 scale per 16 FP4 values (group_size=16). Shape: `(E_per_rank, N, K//16)`. Returned as `float8_e4m3fn` from `weight_transform.py`, NOT packed uint32. The CUTLASS GEMM consumes float8 directly.

### Activations (after staging kernel)

- **Packed E2M1** (`int8`): Shape: `(num_tokens, K//2)`.
- **UE4M3 scales** (`uint32`): 4 UE4M3 values packed per uint32. Shape: `(num_tokens, K//64)`. Unpacked to `float8_e4m3fn` via `unpack_ue4m3_u32` before reaching the CUTLASS GEMM.

### GEMM dimensions (DeepSeek-V4-Pro)

- **L1 (gate_up_proj):** M×6144×7168 (per expert)
- **L2 (down_proj):** M×7168×3072 (per expert)
- 48 experts per rank (384 total / 8 ranks), top-6 routing

---

## CUTLASS Scale Factor Remap

CUTLASS's `Sm1xxBlockScaledConfig` expects scale factors in a specific interleaved layout, not simple row-major. The SfAtom is:

```
Atom Shape:  Shape<Shape<32, 4>, Shape<16, 4>>
Atom Stride: Stride<Stride<16, 4>, Stride<0, 1>>
Tiling:      Step<_2, _1>  (M tiled with step 2, K with step 1)
```

Our source data is row-major `(M, K_sf)` where `K_sf = K / 16`. The remap kernel (`remap_sf_to_cutlass_kernel` in `cutlass_nvfp4_gemm.cu`) converts from row-major to CUTLASS's interleaved layout.

### How the remap works

The kernel iterates over CUTLASS destination indices, uses `cute::idx2crd` to get the hierarchical coordinate, then `cute::flatten` to get a flat tuple of 8 sub-indices. From those, we extract logical `(m, k_sf)` and read from the row-major source.
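
As a CPU-side cross-check, the same mapping can be written in a few lines of pure Python. This is a minimal sketch, not the CUDA kernel: it assumes the destination offset follows the SfAtom strides (16 and 4 for the two M sub-modes, 1 for the K sub-mode inside a 512-entry atom) and that K tiles advance before M tiles. The names `decode_dest_index` and `remap_sf_reference` are hypothetical, and the real kernel derives coordinates with `cute::idx2crd` rather than hand-written arithmetic.

```python
ATOM_ROWS, ATOM_COLS = 128, 4        # one SfAtom covers 128 rows x 4 scale columns
ATOM_SIZE = ATOM_ROWS * ATOM_COLS    # 512 scale factors per atom


def decode_dest_index(idx, k_tiles):
    """Map a flat destination index to the logical (m, k_sf) it holds.

    Assumption: strides inside the atom are 1 (sub_k), 4 (sub_m), 16 (inner_m),
    and whole atoms are ordered with K tiles varying before M tiles.
    """
    atom_off, tile = idx % ATOM_SIZE, idx // ATOM_SIZE
    sub_k = atom_off % 4
    sub_m = (atom_off // 4) % 4
    inner_m = atom_off // 16
    tile_k, tile_m = tile % k_tiles, tile // k_tiles
    return inner_m + 32 * sub_m + 128 * tile_m, sub_k + 4 * tile_k


def remap_sf_reference(src, M, K_sf):
    """Row-major (M, K_sf) list -> zero-padded interleaved list."""
    m_tiles = (M + ATOM_ROWS - 1) // ATOM_ROWS
    k_tiles = (K_sf + ATOM_COLS - 1) // ATOM_COLS
    dst = [0.0] * (m_tiles * k_tiles * ATOM_SIZE)  # zero-init covers tile padding
    for idx in range(len(dst)):
        m, k_sf = decode_dest_index(idx, k_tiles)
        if m < M and k_sf < K_sf:                  # padding slots stay zero
            dst[idx] = src[m * K_sf + k_sf]
    return dst
```

With `K_sf = 448` (so `k_tiles = 112`), `decode_dest_index(65536, 112)` gives `(128, 64)`, which can be compared against a coordinate dump taken from the kernel. Iterating over destination indices, as the kernel does, means every padded slot is visited and left at zero.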
### Flattened coordinate decomposition (flat_rank=8)

From the SfAtom layout with `Step<_2, _1>` tiling, `flatten(idx2crd(idx, ...))` produces 8 values:

```
f0 = inner_m  (0..31)  - varies fastest within M atom
f1 = sub_m    (0..3)   - second M sub-coordinate
f2 = tile_m   (0..)    - M tile index
f3 = step_m stride     - degenerate (always = sfa_size, not a coordinate)
f4 = sub_k    (0..3)   - K sub-coordinate within atom
f5 = tile_k   (0..)    - K tile index
f6 = 0                 - unused
f7 = 0                 - unused
```

#### Empirical coordinate dump (MN=8192, K_sf=448, T = sfa_size = 58720256)

| idx | f0 | f1 | f2 | f3 | f4 | f5 | f6 | f7 |
|--------|----|----|----|----|----|----|----|----|
| 0 | 0 | 0 | 0 | T | 0 | 0 | 0 | 0 |
| 1 | 0 | 0 | 0 | T | 1 | 0 | 0 | 0 |
| 4 | 0 | 1 | 0 | T | 0 | 0 | 0 | 0 |
| 16 | 1 | 0 | 0 | T | 0 | 0 | 0 | 0 |
| 511 | 31 | 3 | 0 | T | 3 | 0 | 0 | 0 |
| 512 | 0 | 0 | 0 | T | 0 | 1 | 0 | 0 |
| 1024 | 0 | 0 | 0 | T | 0 | 2 | 0 | 0 |
| 2048 | 0 | 0 | 0 | T | 0 | 4 | 0 | 0 |
| 4096 | 0 | 0 | 0 | T | 0 | 8 | 0 | 0 |
| 8192 | 0 | 0 | 0 | T | 0 | 16 | 0 | 0 |
| 65536 | 0 | 0 | 1 | T | 0 | 16 | 0 | 0 |
| 131072 | 0 | 0 | 2 | T | 0 | 32 | 0 | 0 |

#### Extraction formula

CuTe uses "first sub varies fastest" for `Shape<32, 4>`:

```cpp
m    = f0 + f1 * 32 + f2 * 128;
k_sf = f4 + f5 * 4;
```

This was verified with 6 independent probes:

| Probe | Expected | Observed | Confirms |
|-------|----------|----------|----------|
| SFA[1, 0] = 2.0 | row 1 changes | ✅ only row 1 | `f0` term |
| SFA[32, 0] = 2.0 | row 32 changes | ✅ only row 32 | `f1*32` (rules out `f0*4+f1`) |
| SFA[128, 0] = 2.0 | row 128 changes | ✅ only row 128 | `f2*128` |
| SFA[0, 1] = 2.0 | row 0 changes (k=1) | ✅ only row 0 | `f4` term |
| SFA[0, 4] = 2.0 | row 0 changes (k=4) | ✅ only row 0 | `f5*4` term |
| SFA[0, 100] = 2.0 | row 0 changes (k=100) | ✅ only row 0 | tile-overflow range |

#### Why the previous remap was broken

The previous code used `cute::get<0>(flat)` and `cute::get<1>(flat)` to
extract (m, k). Since flatten produces `(inner_m, sub_m, tile_m, ...)` in order, `get<0>` and `get<1>` are both **M sub-indices** and carry no K information. This caused only `k_group=0` to work; all other K-groups were silently mapped to the wrong source offset.

Additionally, the dest buffer must be zero-initialized before remap because CUTLASS pads to tile boundaries (128 × 64), making the dest buffer larger than `M * K_sf`. Unmapped padding slots reading garbage caused sporadic wrong results.

---

## Bugs Found & Fixed

### 1. unpack_ue4m3_u32: value cast vs bit reinterpret

**File:** `nvfp4_mega_moe.py`

**Bug:** `(x_u32 & 0xFF).to(torch.int32).to(torch.float8_e4m3fn)` value-casts the integer 63 to the float8 value nearest 63.

**Fix:** `(x_u32 & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn)` reinterprets the bit pattern 0x3F as the float8 value 1.875 (sign 0, exponent `0111`, mantissa `111`).

**Also:** `uint32` lacks CUDA bitwise ops, so cast to `int32` first.

**Impact:** Corrupted every activation scale fed to the L1 GEMM. Weight scales were fine (already float8 from weight_transform). The result was "structured garbage": plausible-looking but wrong output.

### 2. stage_activation: three independent bugs

**File:** `nvfp4_mega_moe.py`

**Bug A:** `clamp(0, 15)` zeroed every negative value. E2M1 is sign-magnitude 4-bit (bit3=sign, bits2:0=mag).

**Bug B:** Stored `block_max` but divided by `block_max/6.0` → stored scale was 6× too large.

**Bug C:** Uniform 0.5 step doesn't match E2M1 values {0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6}, which are non-uniform above ±2.

**Fix:** Rewrote with proper nearest-neighbor E2M1 quantization.

**Impact:** Half the L1→L2 activation was zeroed, 6× scale mismatch, quantization noise on top.

### 3. _fold_global_scale: logical_widths branch

**File:** `weight_transform.py`

**Bug:** `logical_widths=[3072, 3072]` caused the function to apply expert 0's scale to gate half and expert 1's scale to up half of ALL experts. All other experts' global scales were discarded.

**Fix:** Removed the `logical_widths` branch entirely.
The `else` branch correctly broadcasts each expert's own `(E, 1)` global scale across `(E, N, K//16)`.

### 4. L1 weight interleave removed (transpose still needed)

**File:** `weight_transform.py`

**Bug:** `_interleave_l1_weights` assumed gate/up were pre-interleaved in groups of 16 and that the kernel used 2CTA UMMA layout. vLLM uses plain concat `[gate; up]` along the output dim, and our CUTLASS kernel uses `ClusterShape<1, 1, 1>`.

**Fix:** Removed the interleave function. Weights still need a transpose from checkpoint layout `(N, K_half)` row-major to CUTLASS layout `(K_half, N)` column-major; this is standard row→column conversion, not interleaving. Both L1 and L2 weights and scales are transposed.

### 5. SF remap: idx2crd+flatten coordinate extraction

**File:** `cutlass_nvfp4_gemm.cu`

**Bug:** `cute::flatten(coord)` produces 8 sub-indices (flat_rank=8). `get<0>` and `get<1>` are both M sub-indices (inner_m, sub_m), carrying zero K information. Only k_group=0 worked; all other K-groups were silently wrong.

**Fix:** Correct extraction: `m = f0 + f1*32 + f2*128`, `k_sf = f4 + f5*4`. Zero-init dest buffer before remap.

**Diagnostic trail:** Constant-scale test (all SF=1.0) → cosine 1.0 proved FP4 path was correct. Real scales → cosine 0.83 proved SF remap was broken. Single-element probes (SFA[0,0] vs SFA[0,3]) proved only k_group=0 worked. Printf dump of flat coordinates at specific indices revealed flat_rank=8 and the correct extraction formula.

### 6. SiLU after summing expert paths (math error)

**File:** `nvfp4_mega_moe.py`

**Bug:** The old grouped GEMM collapsed expert outputs into a weighted sum, then applied SiLU+Mul on the sum. `silu(Σ wᵢ·gateᵢ) * (Σ wᵢ·upᵢ) ≠ Σ wᵢ·silu(gateᵢ)·upᵢ`. The nonlinearity must happen per-expert-path.

**Fix:** Slot-based dispatch: L1 GEMM returns per-slot output, SiLU+Mul applied per-slot, L2 GEMM per-slot, routing weights applied once at final `index_add_` scatter.

### 7. Routing weights applied twice

**File:** `cutlass_nvfp4_gemm/kernel.py`

**Bug:** `cutlass_grouped_nvfp4_gemm` applied `topk_weights` in its scatter loop. Because it was called for both L1 and L2, each expert's contribution was scaled by `topk_weight²`.

**Fix:** GEMM returns per-slot results with no routing weights. Single `y.index_add_(0, slot_token, slot_weight * l2_slots)` at the end.

### Diagnostic: constant-scale test (smoking gun for SF bugs)

When all scale factors are set to UE4M3(1.0):

- **Cosine = 1.0000, MSE = 0.19** (expected FP4 quantization noise)

With real (variable) scale factors and the broken remap:

- **Cosine = 0.83** → scales are misaligned, not fundamentally broken

After the fix with correct coordinate extraction:

- **Cosine = 1.0000, MSE = 0.0** → perfect match with dequantized reference

---

## Build & Deploy (B200)

```bash
# On B200 host: CUTLASS must be cloned and mounted
cd /root/nvidia-meeting/deepseek-v4-quant/

# Rebuild container (CUTLASS is host-mounted at /root/cutlass)
KERNEL_CACHE_BUSTER=$(date +%s) docker compose build --no-cache
docker compose up -d
```

The CUTLASS extension builds inside the container during `pip install` of the nvfp4-megamoe-kernel package. It needs:

- CUDA 13.0 toolkit (in the vllm/vllm-openai:nightly image)
- CUTLASS headers at `/root/cutlass/include/`
- CCCL headers at `/usr/local/cuda-13.0/targets/x86_64-linux/include/cccl/`
- Device with SM100 compute capability (B200)

---

## Known Issues / TODO

1. ~~**MoE dispatch is slow**~~: Fixed. Slot-based `index_add_` replaces the Python double loop over tokens×topk. Routing weights applied once at final scatter.
2. **stage_activation is Python**: Re-quantization from L1 BF16 output to FP4 for L2 input runs in PyTorch. Should use the Triton staging kernel for speed and consistency with vLLM's built-in staging.
3. ~~**SF remap allocates every call**~~: Accepted as-is. A prepacked SFB cache (lazy, per layer) was tried and then removed for the reasons listed under "SFB (Weight Scale Factors)" in Architecture. Both SFA (activation scales) and SFB (weight scales) are now remapped on every call, which costs microseconds.
4. **Per-expert GEMM dispatch is a serial Python loop**: `cutlass_grouped_nvfp4_gemm` iterates over 48 experts in a Python `for` loop. Each iteration launches one CUTLASS GEMM. Could benefit from a true grouped GEMM kernel or CUDA-side expert dispatch.

---

## Environment Variables

| Variable | Default | Description |
|----------|---------|-------------|
| `MEGA_MOE_STATIC` | 0 | Set to 1 to skip MoE kernel entirely (return zeros) |
| `MEGA_MOE_DEBUG` | 0 | Set to 1 for verbose logging |
| `SKIP_ATTENTION` | 0 | Skip attention layers (debug) |

---

## Repos

- **Kernel:** `sweetapi.com/biondizzle/nvfp4-megamoe-kernel` (branch: master)
- **Deployment:** `sweetapi.com/biondizzle/deepseek-v4-quant` (branch: modelopt-nvfp4)
- **Local:** `~/dev/nvfp4-megamoe-kernel/`, `~/dev/deepseek-v4-quant/`