biondizzle c3841983a0 fix: SF remap uses cute::cosize() instead of cute::size()
The comment explicitly warned about this: allocation uses cosize (physical
size including tile padding) but the iteration bound used size (logical size).
This meant padding positions in the CUTLASS SF layout were never written,
leaving them as zero instead of their actual SF values. With uniform data
(all-ones), all SF values are the same so the bug was invisible. With
random data, different SF values are needed at different positions and
the missing writes corrupt the result.
2026-05-15 18:52:23 +00:00

nvfp4-megamoe-kernel

Native NVFP4 block-scaled MoE kernel for DeepSeek-V4-Pro on NVIDIA Blackwell (SM100).

Replaces the broken fp8_nvfp4_mega_moe kernel from DeepGEMM with a working CUTLASS-based implementation that emits real SM100_MMA_MXF4_SS tensor core instructions.


Architecture

DeepSeek-V4-Pro is a 384-expert MoE model with expert parallelism across 8 ranks (B200 GPUs). Each rank handles 48 experts. For each token, the router picks the top-6 experts.

The MoE Forward Pass

Input hidden states (BF16)
        │
        ▼
┌─────────────────┐
│  Shared Experts │  ← vLLM native FlashInfer CUTLASS NVFP4 path
│  (gate + up →  │     (not our kernel)
│   SiLU * up →   │
│   down)         │
└─────────────────┘
        │
        ▼
  Staging Kernel (vLLM built-in)
  BF16 → packed E2M1 (int8) + UE4M3 block-16 scales (uint32)
  Writes to SymmBuffer.x / SymmBuffer.x_sf
        │
        ▼
  Router (vLLM built-in)
  Writes topk_ids / topk_weights to SymmBuffer
        │
        ▼
┌─────────────────────────────────────────────────┐
│          nvfp4_mega_moe_full                    │  ← nvfp4_mega_moe.py
│                                                 │
│  1. Read staged activation from buffer          │
│  2. Build slot mapping (token, topk) → local    │
│     expert, routing weight                      │
│  3. L1 GEMM: gate_up_proj (slot-based)         │  ← CUTLASS NVFP4 block-scaled
│     E2M1 × E2M1 + UE4M3 scales                 │    SM100_MMA_MXF4_SS PTX
│     → BF16 per-slot output (6144-wide)         │
│  4. SiLU(gate) * up PER SLOT                   │
│     (nonlinearity before combining paths)       │
│  5. stage_activation: BF16 → FP4               │  ← proper E2M1 quantization
│  6. L2 GEMM: down_proj (slot-based)            │  ← CUTLASS NVFP4 block-scaled
│     E2M1 × E2M1 + UE4M3 scales                 │    SM100_MMA_MXF4_SS PTX
│     → BF16 per-slot output (7168-wide)         │
│  7. Final scatter:                              │
│     y.index_add_(slot_token,                    │
│       slot_weight * l2_slots)                   │
│     Routing weight applied ONCE at scatter      │
└─────────────────────────────────────────────────┘
        │
        ▼
  Cross-rank all-reduce (vLLM built-in)

Slot-Based Dispatch

The kernel uses a slot representation instead of collapsing expert outputs early. A slot is one (token, topk_expert) pair. For a batch of T tokens with top-6 routing, there are up to 6T slots (fewer if some experts are out of the local rank's range).
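
As a rough illustration of the slot idea, the sketch below flattens per-token top-k routing into slots and keeps only those owned by the local rank. It is plain Python rather than the tensor code in nvfp4_mega_moe.py; the function name `build_slots` and the `expert_lo`/`expert_hi` parameters are hypothetical.

```python
def build_slots(topk_ids, topk_weights, expert_lo, expert_hi):
    """Flatten (token, k) routing pairs into slots owned by this rank.

    Hypothetical helper: expert_lo/expert_hi delimit the local expert range.
    """
    slot_token, slot_expert, slot_weight = [], [], []
    for token, (ids, weights) in enumerate(zip(topk_ids, topk_weights)):
        for expert, weight in zip(ids, weights):
            if expert_lo <= expert < expert_hi:
                slot_token.append(token)
                slot_expert.append(expert - expert_lo)  # local expert index
                slot_weight.append(weight)
    return slot_token, slot_expert, slot_weight


# Two tokens, top-6 routing, rank 0 assumed to own experts [0, 48).
topk_ids = [[3, 50, 47, 100, 200, 383], [0, 48, 96, 144, 192, 240]]
topk_weights = [[0.3, 0.2, 0.2, 0.1, 0.1, 0.1], [0.5, 0.1, 0.1, 0.1, 0.1, 0.1]]
slot_token, slot_expert, slot_weight = build_slots(topk_ids, topk_weights, 0, 48)
```

Here only 3 of the 12 (token, expert) pairs land on the local rank, which is why the slot count can be well below 6T.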

Why slots? Two bugs in the previous approach:

  1. SiLU after summing is mathematically wrong. silu(Σ wᵢ·gateᵢ) * (Σ wᵢ·upᵢ) ≠ Σ wᵢ·silu(gateᵢ)·upᵢ. The nonlinearity must happen per-expert-path before combining.

  2. Routing weights applied twice. The old grouped GEMM applied topk_weights in its scatter loop, and was called for both L1 and L2 — squaring the weights.

The slot approach fixes both: SiLU+Mul happens per-slot, and routing weights are applied exactly once at the final index_add_ scatter.
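
Both bugs can be reproduced with scalars. The sketch below uses made-up gate/up values and routing weights (none of them from the model) to show that SiLU does not commute with the weighted sum, and that applying the routing weight at both L1 and L2 shrinks the total weight mass.

```python
import math


def silu(x):
    """SiLU(x) = x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))


# Illustrative values only: two expert paths for one token.
weights = [0.6, 0.4]
gate = [2.0, -1.0]
up = [1.0, 3.0]

# Old behaviour: sum the expert paths first, then apply the nonlinearity.
summed_first = silu(sum(w * g for w, g in zip(weights, gate))) * sum(
    w * u for w, u in zip(weights, up)
)

# Slot behaviour: nonlinearity per path, routing weight applied once at the end.
per_slot = sum(w * silu(g) * u for w, g, u in zip(weights, gate, up))

# Applying the weight in both the L1 and L2 scatter squares it.
weight_mass_once = sum(weights)
weight_mass_twice = sum(w * w for w in weights)
```

With these numbers the two orderings differ by roughly 0.26, and the squared weights sum to 0.52 instead of 1.0.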

SFB (Weight Scale Factors) — Remapped Per-Call, NOT Cached

Weight scale factors (SFB) are remapped from row-major to CUTLASS interleaved layout on every GEMM call. This is a lightweight scatter kernel (~µs) and is NOT the bottleneck compared to the GEMM itself.

⚠️ DO NOT ADD A PREPACK CACHE FOR SFB. Previous attempts caused critical issues:

| Problem | Impact |
| --- | --- |
| OOM | ~1.75 GiB per prepacked tensor × 61 MoE layers × 2 (L1+L2) = ~214 GiB, which exceeds B200 capacity |
| Peak memory 2× | torch.stack held all expert tensors + final stack simultaneously before LRU eviction |
| CUDA graph trap | LRU eviction frees tensors that CUDA graphs still reference → use-after-free → silent corruption or crash |
| M-dependent layout | prepack_sfb(M=128) assumed SFB layout size is M-independent (never verified). If wrong, entire prepack is invalid |
| Cross-layer cache collision | Tag-based cache ("l1"/"l2") returned layer N-1's data for layer N. Fixed with data_ptr key, but the cache itself was the root problem |

The per-call remap costs microseconds. The cache cost was hours of debugging. Don't repeat this mistake.
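
The OOM figure in the table follows from simple arithmetic. A minimal sketch using only the numbers quoted above (the variable names are illustrative):

```python
# Figures taken from the table above; names are illustrative.
GIB_PER_PREPACKED_TENSOR = 1.75
NUM_MOE_LAYERS = 61
GEMMS_PER_LAYER = 2  # L1 and L2

total_prepack_gib = GIB_PER_PREPACKED_TENSOR * NUM_MOE_LAYERS * GEMMS_PER_LAYER
print(f"prepacked SFB footprint: {total_prepack_gib} GiB")
```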


vLLM Startup Sequence (how our code plugs in)

1. vLLM engine init
   └─ ModelOptNvFp4Config selected (NVFP4 quantization scheme)
   └─ FlashInferCutlassNvFp4LinearKernel for linear layers

2. Model construction
   └─ DeepseekV4ForCausalLM → DeepseekV4MoE → DeepseekV4DecoderLayer
       Each layer has: attention + MoE block
       MoE block has: shared experts + 384 routed experts

3. Weight loading
   └─ 95 safetensor shards loaded
   └─ weight, weight_scale, weight_scale_2 loaded per linear

4. process_weights_after_loading  ← THIS IS WHERE WE HOOK IN
   └─ ModelOptNvFp4LinearMethod swizzles/pads weights for CUTLASS
   └─ finalize_mega_moe_weights()
       └─ weight_transform.py: transform_nvfp4_weights_for_mega_moe()
           • Folds weight_scale_2 (global scale) into weight_scale (block scale)
           • UE4M3 block-16 scales returned as float8_e4m3fn (not packed uint32)
           • Returns ((l1_w, l1_sf), (l2_w, l2_sf)) per rank

5. SymmBuffer allocation
   └─ symm_buffer.py: get_symm_buffer_for_nvfp4_mega_moe()
       • Pre-allocates GPU buffers for:
         - x: int8 packed E2M1 activations
         - x_sf: uint32 packed UE4M3 activation scales
         - topk_idx: int32 expert indices
         - topk_weights: float32 routing weights
         - buffer: BF16 all-reduce buffer

6. Profile run (warmup)
   └─ First forward pass to allocate KV cache, etc.
   └─ This is where the CUTLASS GEMM first executes
   └─ SFB weight scales remapped per-expert inside CUTLASS (no cache)

7. Ready to serve

File Map

nvfp4_megamoe_kernel/
├── __init__.py              # Public API exports
├── nvfp4_mega_moe.py       # Main kernel: nvfp4_mega_moe_full, L1/L2, stage_activation
├── weight_transform.py     # Weight prep: fold global scale, pack UE4M3
├── symm_buffer.py          # GPU buffer allocation for MoE dispatch
│
└── cutlass_nvfp4_gemm/     # CUTLASS CUDA extension (the actual hardware kernel)
    ├── cutlass_nvfp4_gemm.cu   # CUDA: CUTLASS GEMM + SF remap + prepack SFB + prepacked-SFB GEMM path
    ├── pytorch_binding.cpp     # PyTorch C++ binding (forward, forward_prepacked_sfb, prepack_sfb)
    ├── kernel.py               # Python: cutlass_grouped_nvfp4_gemm (slot-based, per-expert loop)
    ├── sf_layout.py            # CUTLASS SF layout reference docs
    ├── setup.py                # Build config (nvcc, CUTLASS include paths)
    ├── build.sh                # Build script
    ├── test_gemm.py            # Standalone test
    └── README.md

What each file does (in call order)

| File | When it runs | What it does |
| --- | --- | --- |
| weight_transform.py | Once at startup (weight loading) | Takes raw NVFP4 checkpoint weights, folds global scales into block scales. Returns scales as float8_e4m3fn (not packed uint32). Output: ((l1_w, l1_sf), (l2_w, l2_sf)) |
| symm_buffer.py | Once at startup (buffer alloc) | Pre-allocates GPU tensors for activations, scales, routing data, and all-reduce. These persist across forward passes. |
| nvfp4_mega_moe.py | Every forward pass | Orchestrates the MoE: reads from symm buffer → build slot mapping → L1 GEMM → SiLU+Mul per-slot → re-quantize → L2 GEMM → final index_add_ scatter with routing weights. Contains stage_activation (BF16→FP4) and unpack_ue4m3_u32. NO prepack cache: SFB remapped per-call inside CUTLASS. |
| cutlass_nvfp4_gemm/kernel.py | Every forward pass (called by nvfp4_mega_moe) | Slot-based per-expert loop: gather slots for each expert, call CUTLASS GEMM (SFB remapped inside C extension), write results to slot buffer. No routing weights; caller handles scatter. |
| cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu | Every forward pass (CUDA kernel) | The actual CUTLASS kernel: native NVFP4 block-scaled GEMM + GPU-side SFA and SFB remap. |
| cutlass_nvfp4_gemm/sf_layout.py | Reference only | Documents the CUTLASS SfAtom layout. Not used at runtime (remap is in CUDA). |

Data Formats

Weights

  • Packed E2M1 (int8): 2 FP4 values per byte. Shape: (E_per_rank, N, K//2), K-major layout.
  • UE4M3 block scales (float8_e4m3fn): 1 scale per 16 FP4 values (group_size=16). Shape: (E_per_rank, N, K//16). Returned as float8_e4m3fn from weight_transform.py — NOT packed uint32. The CUTLASS GEMM consumes float8 directly.

Activations (after staging kernel)

  • Packed E2M1 (int8): Shape: (num_tokens, K//2).
  • UE4M3 scales (uint32): 4 UE4M3 values packed per uint32. Shape: (num_tokens, K//64). Unpacked to float8_e4m3fn via unpack_ue4m3_u32 before reaching the CUTLASS GEMM.

GEMM dimensions (DeepSeek-V4-Pro)

  • L1 (gate_up_proj): M×6144×7168 (per expert)
  • L2 (down_proj): M×7168×3072 (per expert)
  • 48 experts per rank (384 total / 8 ranks), top-6 routing
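
The tensor shapes implied by these formats and dimensions can be derived mechanically. The helpers below are hypothetical (they are not part of the package) and only encode the rules stated above: two FP4 values per byte, one scale per 16 values, and four activation scales per uint32.

```python
def weight_shapes(experts, n, k, group_size=16):
    """Shapes of packed E2M1 weights and float8 block scales (hypothetical helper)."""
    return {
        "weight": (experts, n, k // 2),           # 2 FP4 values per byte
        "weight_sf": (experts, n, k // group_size),
    }


def activation_shapes(num_tokens, k, group_size=16, sf_per_u32=4):
    """Shapes of staged activations and their packed uint32 scales."""
    return {
        "x": (num_tokens, k // 2),
        "x_sf": (num_tokens, k // (group_size * sf_per_u32)),
    }


l1 = weight_shapes(experts=48, n=6144, k=7168)   # gate_up_proj
l2 = weight_shapes(experts=48, n=7168, k=3072)   # down_proj
x_l1 = activation_shapes(num_tokens=4, k=7168)   # 4 tokens is an arbitrary example
```

For L1 this gives K_sf = 7168 / 16 = 448 scale columns per row.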

CUTLASS Scale Factor Remap

CUTLASS's Sm1xxBlockScaledConfig expects scale factors in a specific interleaved layout, not simple row-major. The SfAtom is:

Atom Shape:  Shape<Shape<32, 4>, Shape<16, 4>>
Atom Stride: Stride<Stride<16, 4>, Stride<0, 1>>
Tiling:      Step<_2, _1>  (M tiled with step 2, K with step 1)

Our source data is row-major (M, K_sf) where K_sf = K / 16. The remap kernel (remap_sf_to_cutlass_kernel in cutlass_nvfp4_gemm.cu) converts from row-major to CUTLASS's interleaved layout.

How the remap works

The kernel iterates over CUTLASS destination indices, uses cute::idx2crd to get the hierarchical coordinate, then cute::flatten to get a flat tuple of 8 sub-indices. From those, we extract logical (m, k_sf) and read from the row-major source.

Flattened coordinate decomposition (flat_rank=8)

From the SfAtom layout with Step<_2, _1> tiling, flatten(idx2crd(idx, ...)) produces 8 values:

f0 = inner_m  (0..31)  — varies fastest within M atom
f1 = sub_m    (0..3)   — second M sub-coordinate
f2 = tile_m   (0..)    — M tile index
f3 = step_m stride     — degenerate (always = sfa_size, not a coordinate)
f4 = sub_k    (0..3)   — K sub-coordinate within atom
f5 = tile_k   (0..)    — K tile index
f6 = 0                  — unused
f7 = 0                  — unused

Empirical coordinate dump (MN=8192, K_sf=448, T = sfa_size = 58720256)

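| idx | f0 | f1 | f2 | f3 | f4 | f5 | f6 | f7 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | 0 | 0 | 0 | T | 0 | 0 | 0 | 0 |
| 1 | 0 | 0 | 0 | T | 1 | 0 | 0 | 0 |
| 4 | 0 | 1 | 0 | T | 0 | 0 | 0 | 0 |
| 16 | 1 | 0 | 0 | T | 0 | 0 | 0 | 0 |
| 511 | 31 | 3 | 0 | T | 3 | 0 | 0 | 0 |
| 512 | 0 | 0 | 0 | T | 0 | 1 | 0 | 0 |
| 1024 | 0 | 0 | 0 | T | 0 | 2 | 0 | 0 |
| 2048 | 0 | 0 | 0 | T | 0 | 4 | 0 | 0 |
| 4096 | 0 | 0 | 0 | T | 0 | 8 | 0 | 0 |
| 8192 | 0 | 0 | 0 | T | 0 | 16 | 0 | 0 |
| 65536 | 0 | 0 | 1 | T | 0 | 16 | 0 | 0 |
| 131072 | 0 | 0 | 2 | T | 0 | 32 | 0 | 0 |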

Extraction formula

CuTe uses "first sub varies fastest" for Shape<32, 4>:

m    = f0 + f1 * 32 + f2 * 128;
k_sf = f4 + f5 * 4;

This was verified with 6 independent probes:

| Probe | Source | Expected | Result |
| --- | --- | --- | --- |
| SFA[1, 0] = 2.0 | row 1 changes | only row 1 | Confirms f0 term |
| SFA[32, 0] = 2.0 | row 32 changes | only row 32 | Confirms f1*32, rules out f0*4+f1 |
| SFA[128, 0] = 2.0 | row 128 changes | only row 128 | Confirms f2*128 |
| SFA[0, 1] = 2.0 | row 0 changes (k=1) | only row 0 | Confirms f4 term |
| SFA[0, 4] = 2.0 | row 0 changes (k=4) | only row 0 | Confirms f5*4 term |
| SFA[0, 100] = 2.0 | row 0 changes (k=100) | only row 0 | Confirms tile-overflow range |
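
The extraction formula can be exercised on the host without CuTe. In the sketch below, `flat_to_logical` is a direct transcription of the formula above, while `logical_to_flat` is a hypothetical inverse written only for this illustration (it is not in the CUDA source, and it sets the degenerate f3 slot to 0).

```python
def flat_to_logical(flat):
    """Map the 8 flattened sub-indices to logical (m, k_sf); f3, f6, f7 are ignored."""
    f0, f1, f2, _f3, f4, f5, _f6, _f7 = flat
    m = f0 + f1 * 32 + f2 * 128
    k_sf = f4 + f5 * 4
    return m, k_sf


def logical_to_flat(m, k_sf):
    """Hypothetical inverse used to check the formula against the probes."""
    f2, rem = divmod(m, 128)
    f1, f0 = divmod(rem, 32)
    f5, f4 = divmod(k_sf, 4)
    return (f0, f1, f2, 0, f4, f5, 0, 0)


# The six probe positions from the table above.
probes = [(1, 0), (32, 0), (128, 0), (0, 1), (0, 4), (0, 100)]
round_trips = [flat_to_logical(logical_to_flat(m, k)) for m, k in probes]

# Row idx=511 of the coordinate dump: (31, 3, 0, T, 3, 0, 0, 0).
last_in_atom = flat_to_logical((31, 3, 0, 58720256, 3, 0, 0, 0))
```

Each probe touches exactly one sub-index (for example, row 32 maps to f1 = 1 and k = 100 maps to f5 = 25), which is what makes them useful for isolating one term of the formula at a time.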

Why the previous remap was broken

The previous code used cute::get<0>(flat) and cute::get<1>(flat) to extract (m, k). Since flatten produces (inner_m, sub_m, tile_m, ...) in order, get<0> and get<1> are both M sub-indices — they carry no K information. This caused only k_group=0 to work; all other K-groups were silently mapped to the wrong source offset.

Additionally, the dest buffer must be zero-initialized before remap because CUTLASS pads to tile boundaries (128 × 64), making the dest buffer larger than M * K_sf. Unmapped padding slots reading garbage caused sporadic wrong results.


Bugs Found & Fixed

1. unpack_ue4m3_u32: value cast vs bit reinterpret

File: nvfp4_mega_moe.py
Bug: (x_u32 & 0xFF).to(torch.int32).to(torch.float8_e4m3fn) is a value cast: the byte 0x3F (integer 63) becomes a float8 near 63.0.
Fix: (x_u32 & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn) reinterprets the bit pattern: 0x3F → float8(1.875).
Also: uint32 lacks CUDA bitwise ops, so cast to int32 first.
Impact: Corrupted every activation scale fed to the L1 GEMM. Weight scales were fine (already float8 from weight_transform). A recipe for "structured garbage".
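
To make the difference between a value cast and a bit reinterpretation concrete, here is a pure-Python reference decoder for E4M3 bytes (4 exponent bits, bias 7, 3 mantissa bits). It is a sketch, not the PyTorch implementation: the function names are hypothetical, and the byte order (lowest byte first) is an assumption about how the staging kernel packs scales.

```python
def decode_e4m3(byte):
    """Decode one float8_e4m3fn bit pattern to a Python float."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    mant = byte & 0x7
    if exp == 0xF and mant == 0x7:
        return float("nan")                    # the only NaN encoding in e4m3fn
    if exp == 0:
        return sign * mant * 2.0 ** -9         # subnormal: (mant / 8) * 2^-6
    return sign * (1.0 + mant / 8.0) * 2.0 ** (exp - 7)


def unpack_ue4m3_u32_reference(word):
    """Split a uint32 into 4 scale bytes (ASSUMED lowest byte first) and decode each."""
    return [decode_e4m3((word >> (8 * i)) & 0xFF) for i in range(4)]


scales = unpack_ue4m3_u32_reference(0x403F3830)
```

Under the little-endian assumption this word decodes to 0.5, 1.0, 1.875, 2.0, whereas a value cast would have treated the same bytes as the numbers 48, 56, 63 and 64.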

2. stage_activation: three independent bugs

File: nvfp4_mega_moe.py
Bug A: clamp(0, 15) zeroed every negative value. E2M1 is sign-magnitude 4-bit (bit3=sign, bits2:0=mag).
Bug B: Stored block_max but divided by block_max/6.0 → stored scale was 6× too large.
Bug C: Uniform 0.5 step doesn't match E2M1 values {0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6} — non-uniform above ±2.
Fix: Rewrote with proper nearest-neighbor E2M1 quantization.
Impact: Half the L1→L2 activation was zeroed, 6× scale mismatch, quantization noise on top.
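
A minimal sketch of nearest-neighbor E2M1 quantization for one block, addressing the three bugs above: sign is kept in bit 3, the stored scale is block_max / 6.0, and rounding targets the non-uniform E2M1 grid. This is not the repository's stage_activation. It keeps the scale as a Python float (the real path stores UE4M3), breaks ties toward the smaller magnitude, and all names are hypothetical.

```python
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # indexed by bits 2:0


def quantize_block_e2m1(block):
    """Quantize one block to 4-bit sign-magnitude codes plus a shared scale."""
    block_max = max(abs(v) for v in block)
    scale = block_max / 6.0 if block_max > 0 else 1.0  # 6.0 is the largest E2M1 value
    codes = []
    for v in block:
        target = abs(v) / scale
        idx = min(range(8), key=lambda i: abs(E2M1_MAGNITUDES[i] - target))
        sign_bit = 0x8 if (v < 0 and idx != 0) else 0x0
        codes.append(sign_bit | idx)
    return codes, scale


def dequantize_block_e2m1(codes, scale):
    """Inverse mapping: code -> signed magnitude * scale."""
    return [
        (-1.0 if c & 0x8 else 1.0) * E2M1_MAGNITUDES[c & 0x7] * scale for c in codes
    ]


block = [6.0, -3.0, 0.5, -0.2] + [0.0] * 12   # one block of 16 values
codes, scale = quantize_block_e2m1(block)
restored = dequantize_block_e2m1(codes, scale)
```

Negative inputs survive the round trip (-3.0 becomes code 0b1101), which is exactly what the clamp(0, 15) version lost.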

3. _fold_global_scale: logical_widths branch

File: weight_transform.py
Bug: logical_widths=[3072, 3072] caused the function to apply expert 0's scale to gate half and expert 1's scale to up half of ALL experts. All other experts' global scales were discarded.
Fix: Removed the logical_widths branch entirely. The else branch correctly broadcasts each expert's own (E, 1) global scale across (E, N, K//16).

4. L1 weight interleave removed (transpose still needed)

File: weight_transform.py
Bug: _interleave_l1_weights assumed gate/up were pre-interleaved in groups of 16 and that the kernel used 2CTA UMMA layout. vLLM uses plain concat [gate; up] along the output dim, and our CUTLASS kernel uses ClusterShape<1, 1, 1>.
Fix: Removed the interleave function. Weights still need a transpose from checkpoint layout (N, K_half) row-major to CUTLASS layout (K_half, N) column-major — this is standard row→column conversion, not interleaving. Both L1 and L2 weights and scales are transposed.

5. SF remap: idx2crd+flatten coordinate extraction

File: cutlass_nvfp4_gemm.cu
Bug: cute::flatten(coord) produces 8 sub-indices (flat_rank=8). get<0> and get<1> are both M sub-indices (inner_m, sub_m), carrying zero K information. Only k_group=0 worked; all other K-groups were silently wrong.
Fix: Correct extraction: m = f0 + f1*32 + f2*128, k_sf = f4 + f5*4. Zero-init dest buffer before remap.
Diagnostic trail: Constant-scale test (all SF=1.0) → cosine 1.0 proved FP4 path was correct. Real scales → cosine 0.83 proved SF remap was broken. Single-element probes (SFA[0,0] vs SFA[0,3]) proved only k_group=0 worked. Printf dump of flat coordinates at specific indices revealed flat_rank=8 and the correct extraction formula.

6. SiLU after summing expert paths (math error)

File: nvfp4_mega_moe.py
Bug: The old grouped GEMM collapsed expert outputs into a weighted sum, then applied SiLU+Mul on the sum. silu(Σ wᵢ·gateᵢ) * (Σ wᵢ·upᵢ) ≠ Σ wᵢ·silu(gateᵢ)·upᵢ. The nonlinearity must happen per-expert-path.
Fix: Slot-based dispatch — L1 GEMM returns per-slot output, SiLU+Mul applied per-slot, L2 GEMM per-slot, routing weights applied once at final index_add_ scatter.

7. Routing weights applied twice

File: cutlass_nvfp4_gemm/kernel.py
Bug: cutlass_grouped_nvfp4_gemm applied topk_weights in its scatter loop. Called for both L1 and L2, each expert's contribution was scaled by topk_weight².
Fix: GEMM returns per-slot results with no routing weights. Single y.index_add_(0, slot_token, slot_weight * l2_slots) at the end.
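
The final scatter is equivalent to the loop below, shown in plain Python so the single application of the routing weight is explicit. The function name and the toy values are illustrative; the real code performs this with one index_add_ call on tensors.

```python
def scatter_slots(num_tokens, slot_token, slot_weight, l2_slots):
    """Accumulate weighted per-slot L2 outputs into per-token rows."""
    width = len(l2_slots[0]) if l2_slots else 0
    y = [[0.0] * width for _ in range(num_tokens)]
    for token, weight, row in zip(slot_token, slot_weight, l2_slots):
        for j, value in enumerate(row):
            y[token][j] += weight * value   # routing weight applied exactly once
    return y


# Three slots: two belong to token 0, one to token 1 (toy values).
y = scatter_slots(
    num_tokens=2,
    slot_token=[0, 0, 1],
    slot_weight=[0.5, 0.25, 1.0],
    l2_slots=[[2.0, 4.0], [4.0, 8.0], [1.0, 1.0]],
)
```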

Diagnostic: constant-scale test (smoking gun for SF bugs)

When all scale factors are set to UE4M3(1.0):

  • Cosine = 1.0000, MSE = 0.19 (expected FP4 quantization noise)

With real (variable) scale factors and the broken remap:

  • Cosine = 0.83 → scales are misaligned, not fundamentally broken

After the fix with correct coordinate extraction:

  • Cosine = 1.0000, MSE = 0.0 → perfect match with dequantized reference
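
The two metrics used in this diagnostic are easy to reproduce. A minimal sketch on flat Python lists follows; the helper names are hypothetical and the repository's test presumably computes the same quantities on tensors.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_sq_a = sum(x * x for x in a)
    norm_sq_b = sum(y * y for y in b)
    return dot / math.sqrt(norm_sq_a * norm_sq_b)


def mse(a, b):
    """Mean squared error between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


reference = [1.0, 2.0, 3.0]
cos_match = cosine_similarity(reference, reference)
mse_match = mse(reference, reference)
cos_orthogonal = cosine_similarity([1.0, 0.0], [0.0, 1.0])
```

Cosine similarity is insensitive to a uniform rescaling of the output, so a high cosine combined with a non-zero MSE points at magnitude errors, while a degraded cosine (such as 0.83) points at per-element misalignment.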

Build & Deploy (B200)

# On B200 host — CUTLASS must be cloned and mounted
cd /root/nvidia-meeting/deepseek-v4-quant/

# Rebuild container (CUTLASS is host-mounted at /root/cutlass)
KERNEL_CACHE_BUSTER=$(date +%s) docker compose build --no-cache
docker compose up -d

The CUTLASS extension builds inside the container during pip install of the nvfp4-megamoe-kernel package. It needs:

  • CUDA 13.0 toolkit (in the vllm/vllm-openai:nightly image)
  • CUTLASS headers at /root/cutlass/include/
  • CCCL headers at /usr/local/cuda-13.0/targets/x86_64-linux/include/cccl/
  • Device with SM100 compute capability (B200)

Known Issues / TODO

  1. MoE dispatch is slow — Fixed. Slot-based index_add_ replaces the Python double loop over tokens×topk. Routing weights applied once at final scatter.

  2. stage_activation is Python — Re-quantization from L1 BF16 output to FP4 for L2 input runs in PyTorch. Should use the Triton staging kernel for speed and consistency with vLLM's built-in staging.

  3. SF remap allocates every call: by design. An SFB prepack cache was tried and removed (see "SFB (Weight Scale Factors) — Remapped Per-Call, NOT Cached" above). Both SFA (activation scales) and SFB (weight scales) are remapped on every GEMM call.

  4. Per-expert GEMM dispatch is serial Python loop — The cutlass_grouped_nvfp4_gemm iterates over 48 experts in a Python for loop. Each iteration launches one CUTLASS GEMM. Could benefit from a true grouped GEMM kernel or CUDA-side expert dispatch.


Environment Variables

| Variable | Default | Description |
| --- | --- | --- |
| MEGA_MOE_STATIC | 0 | Set to 1 to skip MoE kernel entirely (return zeros) |
| MEGA_MOE_DEBUG | 0 | Set to 1 for verbose logging |
| SKIP_ATTENTION | 0 | Skip attention layers (debug) |
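
One plausible way to read these flags is sketched below. The variable names come from the table; the `env_flag` helper and the rule that only the exact string "1" enables a flag are assumptions, not the package's actual parsing code.

```python
import os


def env_flag(name, default="0", environ=None):
    """Return True when the variable is set to "1" (assumed convention)."""
    env = os.environ if environ is None else environ
    return env.get(name, default) == "1"


# Example with an explicit mapping instead of the real process environment.
example_env = {"MEGA_MOE_DEBUG": "1"}
debug_enabled = env_flag("MEGA_MOE_DEBUG", environ=example_env)
static_enabled = env_flag("MEGA_MOE_STATIC", environ=example_env)
```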

Repos

  • Kernel: sweetapi.com/biondizzle/nvfp4-megamoe-kernel (branch: master)
  • Deployment: sweetapi.com/biondizzle/deepseek-v4-quant (branch: modelopt-nvfp4)
  • Local: ~/dev/nvfp4-megamoe-kernel/, ~/dev/deepseek-v4-quant/