biondizzle f6fd549800 fix: restore col_major_src handling for SFB source layout
SFB scales arrive as (K_sf, N) row-major after transpose+contiguous
in weight_transform.py. The col_major_src flag correctly describes
this. Don't assume both sources are (MN, K_sf).
2026-05-15 21:19:58 +00:00
2026-05-15 19:58:57 +00:00

nvfp4-megamoe-kernel

Native NVFP4 block-scaled MoE kernel for DeepSeek-V4-Pro on NVIDIA Blackwell (SM100).

Replaces the broken fp8_nvfp4_mega_moe kernel from DeepGEMM with a working CUTLASS-based implementation that emits real SM100_MMA_MXF4_SS tensor core instructions.


Architecture

DeepSeek-V4-Pro is a 384-expert MoE model with expert parallelism across 8 ranks (B200 GPUs). Each rank handles 48 experts. For each token, the router picks the top-6 experts.

The MoE Forward Pass

Input hidden states (BF16)
        │
        ▼
┌─────────────────┐
│  Shared Experts │  ← vLLM native FlashInfer CUTLASS NVFP4 path
│  (gate + up →   │
│   SiLU * up →   │
│   down)         │
└─────────────────┘
        │
        ▼
  Staging Kernel (vLLM built-in)
  BF16 → packed E2M1 (int8) + UE4M3 block-16 scales (uint32)
  Writes to SymmBuffer.x / SymmBuffer.x_sf
        │
        ▼
  Router (vLLM built-in)
  Writes topk_ids / topk_weights to SymmBuffer
        │
        ▼
┌─────────────────────────────────────────────────┐
│          nvfp4_mega_moe_full                    │  ← nvfp4_mega_moe.py
│                                                 │
│  1. Read staged activation from buffer          │
│  2. Build slot mapping (token, topk) → local    │
│     expert, routing weight                      │
│  3. L1 GEMM: gate_up_proj (slot-based)          │  ← CUTLASS NVFP4 block-scaled
│     E2M1 × E2M1 + UE4M3 scales                  │    SM100_MMA_MXF4_SS PTX
│     → BF16 per-slot output (6144-wide)          │
│  4. SiLU(gate) * up PER SLOT                    │
│     (nonlinearity before combining paths)       │
│  5. stage_activation: BF16 → FP4                │  ← proper E2M1 quantization
│  6. L2 GEMM: down_proj (slot-based)             │  ← CUTLASS NVFP4 block-scaled
│     E2M1 × E2M1 + UE4M3 scales                  │    SM100_MMA_MXF4_SS PTX
│     → BF16 per-slot output (7168-wide)          │
│  7. Final scatter:                              │
│     y.index_add_(0, slot_token,                 │
│       slot_weight * l2_slots)                   │
│     Routing weight applied ONCE at scatter      │
└─────────────────────────────────────────────────┘
        │
        ▼
  Cross-rank all-reduce (vLLM built-in)

Slot-Based Dispatch

The kernel uses a slot representation instead of collapsing expert outputs early. A slot is one (token, topk_expert) pair. For a batch of T tokens with top-6 routing, there are up to 6T slots (fewer if some experts are out of the local rank's range).
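A minimal sketch of how one rank could build its slots and later scatter them, written in plain Python for readability rather than with the PyTorch tensors nvfp4_mega_moe.py works on. It assumes experts are partitioned contiguously (rank r owns global expert ids 48·r through 48·r + 47); build_slots, group_by_expert, and scatter are hypothetical names, not the module's API.

```python
from collections import defaultdict

NUM_EXPERTS = 384
NUM_RANKS = 8
EXPERTS_PER_RANK = NUM_EXPERTS // NUM_RANKS  # 48


def build_slots(topk_ids, topk_weights, rank):
    """Return parallel lists (slot_token, slot_expert, slot_weight) for one rank.

    A slot is one (token, k) pair whose expert lives on this rank. The global
    expert id is converted to a local index in [0, EXPERTS_PER_RANK).
    Assumes a contiguous expert partition across ranks.
    """
    lo = rank * EXPERTS_PER_RANK
    hi = lo + EXPERTS_PER_RANK
    slot_token, slot_expert, slot_weight = [], [], []
    for t, (ids, ws) in enumerate(zip(topk_ids, topk_weights)):
        for e, w in zip(ids, ws):
            if lo <= e < hi:  # experts owned by other ranks produce no local slot
                slot_token.append(t)
                slot_expert.append(e - lo)
                slot_weight.append(w)
    return slot_token, slot_expert, slot_weight


def group_by_expert(slot_expert):
    """Map local expert -> slot indices (what a per-expert GEMM loop gathers)."""
    groups = defaultdict(list)
    for s, e in enumerate(slot_expert):
        groups[e].append(s)
    return dict(groups)


def scatter(num_tokens, hidden, slot_token, slot_weight, slot_out):
    """Emulate y.index_add_(0, slot_token, slot_weight[:, None] * slot_out)."""
    y = [[0.0] * hidden for _ in range(num_tokens)]
    for t, w, row in zip(slot_token, slot_weight, slot_out):
        for j, v in enumerate(row):
            y[t][j] += w * v  # the routing weight is applied here and nowhere else
    return y
```

group_by_expert yields the per-expert slot lists that a per-expert GEMM loop would gather, and scatter plays the role of the final index_add_.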

Why slots? Two bugs in the previous approach:

  1. SiLU after summing is mathematically wrong. silu(Σ wᵢ·gateᵢ) * (Σ wᵢ·upᵢ) ≠ Σ wᵢ·silu(gateᵢ)·upᵢ. The nonlinearity must happen per-expert-path before combining.

  2. Routing weights applied twice. The old grouped GEMM applied topk_weights in its scatter loop, and was called for both L1 and L2 — squaring the weights.

The slot approach fixes both: SiLU+Mul happens per-slot, and routing weights are applied exactly once at the final index_add_ scatter.
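A scalar counterexample for both bugs: one token routed to two experts, with a hidden size of 1 so every vector is a single number (the values are made up for illustration). It compares the slot-based result with SiLU-after-sum and with the routing weight applied twice; down_proj is treated as the identity because it is linear and does not change the argument.

```python
import math


def silu(x):
    return x / (1.0 + math.exp(-x))


# One token routed to two experts; hidden size 1 so each "vector" is a scalar.
weights = [0.7, 0.3]   # routing weights w_i
gate = [2.0, -1.0]     # gate_i from L1
up = [1.5, 4.0]        # up_i from L1

# Slot-based (correct): nonlinearity per expert path, weight applied once at the end.
per_slot = sum(w * silu(g) * u for w, g, u in zip(weights, gate, up))

# Bug 1: combine gate and up across experts first, then apply SiLU * up.
summed_gate = sum(w * g for w, g in zip(weights, gate))
summed_up = sum(w * u for w, u in zip(weights, up))
silu_after_sum = silu(summed_gate) * summed_up

# Bug 2: routing weight applied in both the L1 and the L2 scatter, i.e. w**2.
weights_twice = sum(w * w * silu(g) * u for w, g, u in zip(weights, gate, up))
```

Both buggy variants land about 0.33 away from the slot-based value here, which is why neither can be fixed by rescaling alone.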

SFB (Weight Scale Factors) — Remapped Per-Call, NOT Cached

Weight scale factors (SFB) are remapped from row-major to CUTLASS interleaved layout on every GEMM call. This is a lightweight scatter kernel (~µs) and is NOT the bottleneck compared to the GEMM itself.

⚠️ DO NOT ADD A PREPACK CACHE FOR SFB. Previous attempts caused critical issues:

| Problem | Impact |
|---|---|
| OOM | ~1.75 GiB per prepacked tensor × 61 MoE layers × 2 (L1+L2) = ~214 GiB, which exceeds B200 capacity |
| Peak memory 2× | torch.stack held all expert tensors and the final stack simultaneously before LRU eviction |
| CUDA graph trap | LRU eviction frees tensors that CUDA graphs still reference → use-after-free → silent corruption or crash |
| M-dependent layout | prepack_sfb(M=128) assumed the SFB layout size is M-independent (never verified). If wrong, the entire prepack is invalid |
| Cross-layer cache collision | The tag-based cache ("l1"/"l2") returned layer N-1's data for layer N. Fixed with a data_ptr key, but the cache itself was the root problem |

The per-call remap costs microseconds. The cache cost was hours of debugging. Don't repeat this mistake.


vLLM Startup Sequence (how our code plugs in)

1. vLLM engine init
   └─ ModelOptNvFp4Config selected (NVFP4 quantization scheme)
   └─ FlashInferCutlassNvFp4LinearKernel for linear layers

2. Model construction
   └─ DeepseekV4ForCausalLM → DeepseekV4DecoderLayer → DeepseekV4MoE
       Each layer has: attention + MoE block
       MoE block has: shared experts + 384 routed experts

3. Weight loading
   └─ 95 safetensor shards loaded
   └─ weight, weight_scale, weight_scale_2 loaded per linear

4. process_weights_after_loading  ← THIS IS WHERE WE HOOK IN
   └─ ModelOptNvFp4LinearMethod swizzles/pads weights for CUTLASS
   └─ finalize_mega_moe_weights()
       └─ weight_transform.py: transform_nvfp4_weights_for_mega_moe()
           • Folds weight_scale_2 (global scale) into weight_scale (block scale)
            • UE4M3 block-16 scales, returned as float8_e4m3fn (not packed 4 per uint32; only activation scales use the uint32 packing)
           • Returns ((l1_w, l1_sf), (l2_w, l2_sf)) per rank

5. SymmBuffer allocation
   └─ symm_buffer.py: get_symm_buffer_for_nvfp4_mega_moe()
       • Pre-allocates GPU buffers for:
         - x: int8 packed E2M1 activations
         - x_sf: uint32 packed UE4M3 activation scales
         - topk_idx: int32 expert indices
         - topk_weights: float32 routing weights
         - buffer: BF16 all-reduce buffer

6. Profile run (warmup)
   └─ First forward pass to allocate KV cache, etc.
   └─ This is where the CUTLASS GEMM first executes
   └─ SFB weight scales remapped per-expert inside CUTLASS (no cache)

7. Ready to serve

File Map

nvfp4_megamoe_kernel/
├── __init__.py              # Public API exports
├── nvfp4_mega_moe.py       # Main kernel: nvfp4_mega_moe_full, L1/L2, stage_activation
├── weight_transform.py     # Weight prep: fold global scale, transpose to CUTLASS layout
├── symm_buffer.py          # GPU buffer allocation for MoE dispatch
│
└── cutlass_nvfp4_gemm/     # CUTLASS CUDA extension (the actual hardware kernel)
    ├── cutlass_nvfp4_gemm.cu   # CUDA: CUTLASS GEMM + SF remap + prepack SFB + prepacked-SFB GEMM path
    ├── pytorch_binding.cpp     # PyTorch C++ binding (forward, forward_prepacked_sfb, prepack_sfb)
    ├── kernel.py               # Python: cutlass_grouped_nvfp4_gemm (slot-based, per-expert loop)
    ├── sf_layout.py            # CUTLASS SF layout reference docs
    ├── setup.py                # Build config (nvcc, CUTLASS include paths)
    ├── build.sh                # Build script
    ├── test_gemm.py            # Standalone test
    └── README.md

What each file does (in call order)

| File | When it runs | What it does |
|---|---|---|
| weight_transform.py | Once at startup (weight loading) | Takes raw NVFP4 checkpoint weights and folds global scales into block scales. Returns scales as float8_e4m3fn (not packed uint32). Output: ((l1_w, l1_sf), (l2_w, l2_sf)) |
| symm_buffer.py | Once at startup (buffer alloc) | Pre-allocates GPU tensors for activations, scales, routing data, and all-reduce. These persist across forward passes. |
| nvfp4_mega_moe.py | Every forward pass | Orchestrates the MoE: read from symm buffer → build slot mapping → L1 GEMM → SiLU+Mul per slot → re-quantize → L2 GEMM → final index_add_ scatter with routing weights. Contains stage_activation (BF16→FP4) and unpack_ue4m3_u32. NO prepack cache: SFB is remapped per call inside the CUTLASS extension. |
| cutlass_nvfp4_gemm/kernel.py | Every forward pass (called by nvfp4_mega_moe) | Slot-based per-expert loop: gathers the slots for each expert, calls the CUTLASS GEMM (SFB remapped inside the C extension), and writes results to the slot buffer. No routing weights; the caller handles the scatter. |
| cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu | Every forward pass (CUDA kernel) | The actual CUTLASS kernel: native NVFP4 block-scaled GEMM plus GPU-side SFA and SFB remap. |
| cutlass_nvfp4_gemm/sf_layout.py | Reference only | Documents the CUTLASS SfAtom layout. Not used at runtime (the remap is in CUDA). |

Data Formats

Weights

  • Packed E2M1 (int8): 2 FP4 values per byte. Shape: (E_per_rank, N, K//2), K-major layout.
  • UE4M3 block scales (float8_e4m3fn): 1 scale per 16 FP4 values (group_size=16). Shape: (E_per_rank, N, K//16). Returned as float8_e4m3fn from weight_transform.py — NOT packed uint32. The CUTLASS GEMM consumes float8 directly.

Activations (after staging kernel)

  • Packed E2M1 (int8): Shape: (num_tokens, K//2).
  • UE4M3 scales (uint32): 4 UE4M3 values packed per uint32. Shape: (num_tokens, K//64). Unpacked to float8_e4m3fn via unpack_ue4m3_u32 before reaching the CUTLASS GEMM.

GEMM dimensions (DeepSeek-V4-Pro)

  • L1 (gate_up_proj): M × N × K = M×6144×7168 per expert, where M is the number of slots routed to that expert
  • L2 (down_proj): M × N × K = M×7168×3072 per expert
  • 48 experts per rank (384 total / 8 ranks), top-6 routing
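The shapes in Data Formats follow directly from K and the group size of 16. This small helper (hypothetical, for checking the arithmetic only) derives them for both GEMMs and makes the L1→L2 hand-off explicit: the 6144-wide L1 output is the [gate; up] concat of two 3072-wide halves, so SiLU(gate) * up leaves 3072 values per slot, which is exactly L2's K.

```python
def nvfp4_shapes(n, k, group_size=16):
    """Per-expert shapes for an M x N x K NVFP4 GEMM, following the Data Formats notes."""
    if k % (group_size * 4):
        raise ValueError("K must be a multiple of 64 so activation scales pack 4 per uint32")
    return {
        "weight_e2m1_int8": (n, k // 2),           # 2 FP4 values per byte
        "weight_scale_fp8": (n, k // group_size),  # 1 UE4M3 scale per 16 values
        "act_e2m1_int8_per_token": k // 2,
        "act_scale_u32_per_token": k // (group_size * 4),  # 4 UE4M3 scales per uint32
    }


l1_shapes = nvfp4_shapes(n=6144, k=7168)  # gate_up_proj: [gate; up] concat, 2 x 3072
l2_shapes = nvfp4_shapes(n=7168, k=3072)  # down_proj: K = 3072 after SiLU(gate) * up
```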

CUTLASS Scale Factor Remap

CUTLASS's Sm1xxBlockScaledConfig expects scale factors in a specific interleaved layout, not simple row-major. The SfAtom is:

Atom Shape:  Shape<Shape<32, 4>, Shape<16, 4>>
Atom Stride: Stride<Stride<16, 4>, Stride<0, 1>>
Tiling:      Step<_2, _1>  (M tiled with step 2, K with step 1)

Our source data is row-major (M, K_sf) where K_sf = K / 16. The remap kernel (remap_sf_to_cutlass_kernel in cutlass_nvfp4_gemm.cu) converts from row-major to CUTLASS's interleaved layout.

How the remap works

The kernel iterates over CUTLASS destination indices, uses cute::idx2crd to get the hierarchical coordinate, then cute::flatten to get a flat tuple of 8 sub-indices. From those, we extract logical (m, k_sf) and read from the row-major source.

Flattened coordinate decomposition (flat_rank=8)

From the SfAtom layout with Step<_2, _1> tiling, flatten(idx2crd(idx, ...)) produces 8 values:

f0 = inner_m  (0..31)  — varies fastest within M atom
f1 = sub_m    (0..3)   — second M sub-coordinate
f2 = tile_m   (0..)    — M tile index
f3 = step_m stride     — degenerate (always = sfa_size, not a coordinate)
f4 = sub_k    (0..3)   — K sub-coordinate within atom
f5 = tile_k   (0..)    — K tile index
f6 = 0                  — unused
f7 = 0                  — unused

Empirical coordinate dump (MN=8192, K_sf=448, T = sfa_size = 58720256)

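| idx | f0 | f1 | f2 | f3 | f4 | f5 | f6 | f7 |
|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 0 | 0 | T | 0 | 0 | 0 | 0 |
| 1 | 0 | 0 | 0 | T | 1 | 0 | 0 | 0 |
| 4 | 0 | 1 | 0 | T | 0 | 0 | 0 | 0 |
| 16 | 1 | 0 | 0 | T | 0 | 0 | 0 | 0 |
| 511 | 31 | 3 | 0 | T | 3 | 0 | 0 | 0 |
| 512 | 0 | 0 | 0 | T | 0 | 1 | 0 | 0 |
| 1024 | 0 | 0 | 0 | T | 0 | 2 | 0 | 0 |
| 2048 | 0 | 0 | 0 | T | 0 | 4 | 0 | 0 |
| 4096 | 0 | 0 | 0 | T | 0 | 8 | 0 | 0 |
| 8192 | 0 | 0 | 0 | T | 0 | 16 | 0 | 0 |
| 65536 | 0 | 0 | 1 | T | 0 | 16 | 0 | 0 |
| 131072 | 0 | 0 | 2 | T | 0 | 32 | 0 | 0 |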

Extraction formula

CuTe uses "first sub varies fastest" for Shape<32, 4>:

m    = f0 + f1 * 32 + f2 * 128;
k_sf = f4 + f5 * 4;

This was verified with 6 independent probes:

| Probe | Source | Expected | Result |
|---|---|---|---|
| SFA[1, 0] = 2.0 | row 1 | changes only row 1 | Confirms f0 term |
| SFA[32, 0] = 2.0 | row 32 | changes only row 32 | Confirms f1*32, rules out f0*4+f1 |
| SFA[128, 0] = 2.0 | row 128 | changes only row 128 | Confirms f2*128 |
| SFA[0, 1] = 2.0 | row 0 (k=1) | changes only row 0 | Confirms f4 term |
| SFA[0, 4] = 2.0 | row 0 (k=4) | changes only row 0 | Confirms f5*4 term |
| SFA[0, 100] = 2.0 | row 0 (k=100) | changes only row 0 | Confirms tile-overflow range |
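The coordinate dump and extraction formula above can be modeled in plain Python to reason about the remap without a GPU. This is a sketch, not the CUDA kernel: it assumes the in-atom strides read off the SfAtom (inner_m stride 16, sub_m stride 4, sub_k stride 1, so 512 stored scales per atom) and that K tiles vary fastest across atoms, as the Step<_2, _1> tiling and the dump suggest. It drops the stride-0 16-wide mode, which would account for the factor of 16 between this model's buffer size for the dump's shape (8192 × 448) and sfa_size. decode and remap_sf are hypothetical names.

```python
import math

ATOM_M = 128                 # 32 (inner_m) x 4 (sub_m) rows per SfAtom
ATOM_K = 4                   # 4 scale columns (4 x 16 = 64 K elements) per SfAtom
ATOM_SIZE = ATOM_M * ATOM_K  # 512 stored scales per atom


def decode(idx, num_k_tiles):
    """Destination index -> logical (m, k_sf), mirroring f0, f1, f2, f4, f5."""
    atom, within = divmod(idx, ATOM_SIZE)
    f2, f5 = divmod(atom, num_k_tiles)  # tile_m, tile_k (K tiles vary fastest)
    f0, rest = divmod(within, 16)       # inner_m has stride 16
    f1, f4 = divmod(rest, 4)            # sub_m has stride 4, sub_k has stride 1
    return f0 + f1 * 32 + f2 * 128, f4 + f5 * 4


def remap_sf(src, m, k_sf):
    """Row-major src[m][k_sf] -> flat list in the modeled CUTLASS order, zero padded."""
    m_pad = math.ceil(m / ATOM_M) * ATOM_M
    num_k_tiles = math.ceil(k_sf / ATOM_K)
    dst = [0.0] * (m_pad * num_k_tiles * ATOM_K)  # zero-init so padding is never garbage
    for idx in range(len(dst)):                    # one iteration per destination element
        row, col = decode(idx, num_k_tiles)
        if row < m and col < k_sf:                 # padding positions stay 0.0
            dst[idx] = src[row][col]
    return dst
```

Like the real kernel, remap_sf walks destination indices and pulls from the row-major source, so every destination slot is written exactly once and out-of-range (padding) slots keep their zero initialization.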

Why the previous remap was broken

The previous code used cute::get<0>(flat) and cute::get<1>(flat) to extract (m, k). Since flatten produces (inner_m, sub_m, tile_m, ...) in order, get<0> and get<1> are both M sub-indices — they carry no K information. This caused only k_group=0 to work; all other K-groups were silently mapped to the wrong source offset.

Additionally, the dest buffer must be zero-initialized before remap because CUTLASS pads to tile boundaries (128 × 64), making the dest buffer larger than M * K_sf. Unmapped padding slots reading garbage caused sporadic wrong results.


Bugs Found & Fixed

1. unpack_ue4m3_u32: value cast vs bit reinterpret

File: nvfp4_mega_moe.py
Bug: (x_u32 & 0xFF).to(torch.int32).to(torch.float8_e4m3fn) is a value cast: the byte 0x3F becomes the number 63, which E4M3 rounds to 64.0.
Fix: (x_u32 & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn) reinterprets the bit pattern, so 0x3F decodes as the E4M3 value 1.875.
Also: uint32 lacks CUDA bitwise ops, so cast to int32 first.
Impact: Corrupted every activation scale fed to the L1 GEMM. Weight scales were fine (already float8 from weight_transform). A recipe for "structured garbage" output.
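The difference between a value cast and a bit reinterpretation can be checked without a GPU. This plain-Python decoder for float8_e4m3fn (1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits, NaN but no infinities) is a sketch. unpack_u32_scales is a hypothetical stand-in for unpack_ue4m3_u32 and assumes the first scale sits in the lowest byte of each uint32, which the masking with 0xFF suggests but which should be confirmed against the staging kernel.

```python
import math


def e4m3fn_from_bits(b):
    """Decode one float8_e4m3fn byte: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits."""
    sign = -1.0 if b & 0x80 else 1.0
    exp = (b >> 3) & 0xF
    man = b & 0x7
    if exp == 0xF and man == 0x7:
        return math.nan                         # the "fn" variant has NaN but no infinities
    if exp == 0:
        return sign * (man / 8.0) * 2.0 ** -6   # subnormals
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)


def unpack_u32_scales(word):
    """Split one packed uint32 into 4 scales, assuming scale 0 is the lowest byte."""
    return [e4m3fn_from_bits((word >> (8 * i)) & 0xFF) for i in range(4)]


byte = 0x38
value_cast = float(byte)            # the bug: the bits are read as the integer 56
bit_cast = e4m3fn_from_bits(byte)   # the fix: the same bits decode to 1.0
```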

2. stage_activation: three independent bugs

File: nvfp4_mega_moe.py
Bug A: clamp(0, 15) zeroed every negative value. E2M1 is sign-magnitude 4-bit (bit3=sign, bits2:0=mag).
Bug B: Stored block_max but divided by block_max/6.0 → stored scale was 6× too large.
Bug C: Uniform 0.5 step doesn't match E2M1 values {0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6} — non-uniform above ±2.
Fix: Rewrote with proper nearest-neighbor E2M1 quantization.
Impact: Half the L1→L2 activation was zeroed, 6× scale mismatch, quantization noise on top.
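A sketch of nearest-neighbor E2M1 quantization for one block of 16 values, in plain Python rather than the PyTorch code in stage_activation. It addresses all three bugs: the sign goes in bit 3 instead of being clamped away, the stored scale is amax / 6 (the same value the division uses), and values snap to the non-uniform E2M1 grid. Two simplifications are assumptions of this sketch: the scale stays a Python float instead of being rounded to UE4M3, and packing puts the first element in the low nibble. All function names are hypothetical.

```python
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # list index = 3-bit magnitude code


def quantize_block_e2m1(block):
    """Quantize one 16-value block to sign-magnitude E2M1 codes plus a block scale.

    scale = amax / 6 (6.0 is the largest E2M1 magnitude), so the block's largest
    value maps to code 0b111. Ties snap to the smaller magnitude.
    """
    amax = max(abs(v) for v in block)
    scale = amax / 6.0 if amax > 0.0 else 1.0
    codes = []
    for v in block:
        mag = abs(v) / scale
        idx = min(range(8), key=lambda i: abs(E2M1_MAGNITUDES[i] - mag))
        sign = 0x8 if (v < 0.0 and idx != 0) else 0x0  # bit 3 = sign; negatives survive
        codes.append(sign | idx)
    return codes, scale


def pack_codes(codes):
    """Pack two 4-bit codes per byte, first element in the low nibble (assumed order)."""
    return bytes(codes[i] | (codes[i + 1] << 4) for i in range(0, len(codes), 2))


def dequantize(codes, scale):
    """Inverse mapping: sign bit, magnitude from the E2M1 table, times the block scale."""
    return [(-1.0 if c & 0x8 else 1.0) * E2M1_MAGNITUDES[c & 0x7] * scale for c in codes]
```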

3. _fold_global_scale: logical_widths branch

File: weight_transform.py
Bug: logical_widths=[3072, 3072] caused the function to apply expert 0's scale to gate half and expert 1's scale to up half of ALL experts. All other experts' global scales were discarded.
Fix: Removed the logical_widths branch entirely. The else branch correctly broadcasts each expert's own (E, 1) global scale across (E, N, K//16).
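The corrected fold is a per-expert broadcast. A nested-list sketch follows; fold_global_scale is a hypothetical name, and the real function works on tensors and casts the result back to float8_e4m3fn, which this version skips.

```python
def fold_global_scale(block_scales, global_scales):
    """Fold each expert's global scale into its own block scales.

    block_scales: nested lists shaped [E][N][K // 16]
    global_scales: one float per expert, i.e. the (E, 1) global scale
    Returns new nested lists; every expert uses only its own global scale.
    """
    if len(block_scales) != len(global_scales):
        raise ValueError("need exactly one global scale per expert")
    return [
        [[s * g for s in row] for row in expert]
        for expert, g in zip(block_scales, global_scales)
    ]
```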

4. L1 weight interleave removed (transpose still needed)

File: weight_transform.py
Bug: _interleave_l1_weights assumed gate/up were pre-interleaved in groups of 16 and that the kernel used 2CTA UMMA layout. vLLM uses plain concat [gate; up] along the output dim, and our CUTLASS kernel uses ClusterShape<1, 1, 1>.
Fix: Removed the interleave function. Weights still need a transpose from checkpoint layout (N, K_half) row-major to CUTLASS layout (K_half, N) column-major — this is standard row→column conversion, not interleaving. Both L1 and L2 weights and scales are transposed.

5. SF remap: idx2crd+flatten coordinate extraction

File: cutlass_nvfp4_gemm.cu
Bug: cute::flatten(coord) produces 8 sub-indices (flat_rank=8). get<0> and get<1> are both M sub-indices (inner_m, sub_m), carrying zero K information. Only k_group=0 worked; all other K-groups were silently wrong.
Fix: Correct extraction: m = f0 + f1*32 + f2*128, k_sf = f4 + f5*4. Zero-init dest buffer before remap.
Diagnostic trail: Constant-scale test (all SF=1.0) → cosine 1.0 proved FP4 path was correct. Real scales → cosine 0.83 proved SF remap was broken. Single-element probes (SFA[0,0] vs SFA[0,3]) proved only k_group=0 worked. Printf dump of flat coordinates at specific indices revealed flat_rank=8 and the correct extraction formula.

6. SiLU after summing expert paths (math error)

File: nvfp4_mega_moe.py
Bug: The old grouped GEMM collapsed expert outputs into a weighted sum, then applied SiLU+Mul on the sum. silu(Σ wᵢ·gateᵢ) * (Σ wᵢ·upᵢ) ≠ Σ wᵢ·silu(gateᵢ)·upᵢ. The nonlinearity must happen per-expert-path.
Fix: Slot-based dispatch — L1 GEMM returns per-slot output, SiLU+Mul applied per-slot, L2 GEMM per-slot, routing weights applied once at final index_add_ scatter.

7. Routing weights applied twice

File: cutlass_nvfp4_gemm/kernel.py
Bug: cutlass_grouped_nvfp4_gemm applied topk_weights in its scatter loop. Because it was called for both L1 and L2, each expert's contribution was scaled by topk_weight².
Fix: GEMM returns per-slot results with no routing weights. Single y.index_add_(0, slot_token, slot_weight * l2_slots) at the end.

Diagnostic: constant-scale test (smoking gun for SF bugs)

When all scale factors are set to UE4M3(1.0):

  • Cosine = 1.0000, MSE = 0.19 (expected FP4 quantization noise)

With real (variable) scale factors and the broken remap:

  • Cosine = 0.83 → scales are misaligned, not fundamentally broken

After the fix with correct coordinate extraction:

  • Cosine = 1.0000, MSE = 0.0 → perfect match with dequantized reference
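Why the constant-scale test is such a clean discriminator: a broken remap reads scales from the wrong positions, and when every scale is identical, reading the wrong one changes nothing, so cosine stays at 1.0 and only FP4 noise remains. A toy illustration in plain Python, assuming a group size of 2 instead of 16 and modeling the broken remap as a rotation of the scale list (both simplifications are hypothetical):

```python
import math


def cosine(a, b):
    """Cosine similarity of two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def dequant_row(values, scales, group=2):
    """Multiply each run of `group` values by its block scale (toy block scaling)."""
    return [v * scales[i // group] for i, v in enumerate(values)]


def misplace(scales):
    """Hypothetical broken remap: every group reads its neighbor's scale."""
    return scales[1:] + scales[:1]


values = [1.0, -2.0, 3.0, 0.5, -1.5, 2.5, 4.0, -0.5]
constant = [1.0, 1.0, 1.0, 1.0]
variable = [1.0, 8.0, 0.25, 2.0]

cos_constant = cosine(dequant_row(values, constant), dequant_row(values, misplace(constant)))
cos_variable = cosine(dequant_row(values, variable), dequant_row(values, misplace(variable)))
```

The toy numbers do not reproduce the 0.83 measured on real data; the point is only that constant scales hide any layout error while variable scales expose it.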

Build & Deploy (B200)

# On B200 host — CUTLASS must be cloned and mounted
cd /root/nvidia-meeting/deepseek-v4-quant/

# Rebuild container (CUTLASS is host-mounted at /root/cutlass)
KERNEL_CACHE_BUSTER=$(date +%s) docker compose build --no-cache
docker compose up -d

The CUTLASS extension builds inside the container during pip install of the nvfp4-megamoe-kernel package. It needs:

  • CUDA 13.0 toolkit (in the vllm/vllm-openai:nightly image)
  • CUTLASS headers at /root/cutlass/include/
  • CCCL headers at /usr/local/cuda-13.0/targets/x86_64-linux/include/cccl/
  • Device with SM100 compute capability (B200)

Known Issues / TODO

  1. MoE dispatch is slow — Fixed. Slot-based index_add_ replaces the Python double loop over tokens×topk. Routing weights applied once at final scatter.

  2. stage_activation is Python — Re-quantization from L1 BF16 output to FP4 for L2 input runs in PyTorch. Should use the Triton staging kernel for speed and consistency with vLLM's built-in staging.

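  3. SF remap allocates every call — Superseded. An interim fix prepacked SFB weight scales into CUTLASS layout once (lazy, cached per layer) and remapped only SFA dynamically, but that cache caused the OOM, CUDA-graph, and cross-layer problems listed in the SFB section above and was removed. SFA and SFB are now both remapped per call, which costs microseconds.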

  4. Per-expert GEMM dispatch is serial Python loop — The cutlass_grouped_nvfp4_gemm iterates over 48 experts in a Python for loop. Each iteration launches one CUTLASS GEMM. Could benefit from a true grouped GEMM kernel or CUDA-side expert dispatch.


Environment Variables

| Variable | Default | Description |
|---|---|---|
| MEGA_MOE_STATIC | 0 | Set to 1 to skip the MoE kernel entirely (return zeros) |
| MEGA_MOE_DEBUG | 0 | Set to 1 for verbose logging |
| SKIP_ATTENTION | 0 | Skip attention layers (debug) |

Repos

  • Kernel: sweetapi.com/biondizzle/nvfp4-megamoe-kernel (branch: master)
  • Deployment: sweetapi.com/biondizzle/deepseek-v4-quant (branch: modelopt-nvfp4)
  • Local: ~/dev/nvfp4-megamoe-kernel/, ~/dev/deepseek-v4-quant/