nvfp4-megamoe-kernel
Native NVFP4 block-scaled MoE kernel for DeepSeek-V4-Pro on NVIDIA Blackwell (SM100).
Replaces the broken fp8_nvfp4_mega_moe kernel from DeepGEMM with a working CUTLASS-based implementation that emits real SM100_MMA_MXF4_SS tensor core instructions.
Architecture
DeepSeek-V4-Pro is a 256-expert MoE model with expert parallelism across 8 ranks (B200 GPUs). Each rank handles 32 experts. For each token, the router picks the top-6 experts.
The MoE Forward Pass
Input hidden states (BF16)
│
▼
┌─────────────────┐
│ Shared Experts │ ← vLLM native FlashInfer CUTLASS NVFP4 path
│ (gate + up → │ (not our kernel)
│ SiLU * up → │
│ down) │
└─────────────────┘
│
▼
Staging Kernel (vLLM built-in)
BF16 → packed E2M1 (int8) + UE4M3 block-16 scales (uint32)
Writes to SymmBuffer.x / SymmBuffer.x_sf
│
▼
Router (vLLM built-in)
Writes topk_ids / topk_weights to SymmBuffer
│
▼
┌─────────────────────────────────────────┐
│ nvfp4_mega_moe_full │ ← nvfp4_mega_moe.py
│ │
│ 1. Read staged activation from buffer │
│ 2. L1 GEMM: gate_up_proj │ ← CUTLASS NVFP4 block-scaled
│ E2M1 × E2M1 + UE4M3 scales │ SM100_MMA_MXF4_SS PTX
│ → BF16 output (6144-wide) │
│ 3. SiLU(gate) * up (activation) │
│ 4. stage_activation: BF16 → FP4 │ ← proper E2M1 quantization
│ 5. L2 GEMM: down_proj │ ← CUTLASS NVFP4 block-scaled
│ E2M1 × E2M1 + UE4M3 scales │ SM100_MMA_MXF4_SS PTX
│ → BF16 output (7168-wide) │
│ 6. Write to output tensor │
└─────────────────────────────────────────┘
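The routed part of the diagram reduces to the per-token math below. This is a pure-Python sketch in plain floats (dequantized, no FP4). It assumes the L1 output is a plain concat [gate; up] and that the top-k expert outputs are combined as a routing-weighted sum. All function names are hypothetical, and the shared-expert path is left out.

```python
import math

def silu(x: float) -> float:
    # SiLU(x) = x * sigmoid(x), written to avoid overflow for large |x|
    if x >= 0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)
    return x * e / (1.0 + e)

def matvec(w, x):
    # w: N rows of length K; x: length K -> length-N output
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]

def expert_ffn(w1, w2, x, inter):
    # L1 (gate_up_proj): 2 * inter output rows, laid out as [gate; up]
    h = matvec(w1, x)
    gate, up = h[:inter], h[inter:]
    act = [silu(g) * u for g, u in zip(gate, up)]  # SiLU(gate) * up
    # L2 (down_proj): hidden output rows of length inter
    return matvec(w2, act)

def routed_moe_token(x, topk_ids, topk_weights, experts, inter):
    # Routing-weighted sum of the top-k experts' outputs for one token.
    out = [0.0] * len(x)
    for e, weight in zip(topk_ids, topk_weights):
        w1, w2 = experts[e]
        for i, yi in enumerate(expert_ffn(w1, w2, x, inter)):
            out[i] += weight * yi
    return out
```

A reference like this is what the GEMM outputs are compared against after dequantizing the FP4 inputs.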
vLLM Startup Sequence (how our code plugs in)
1. vLLM engine init
└─ ModelOptNvFp4Config selected (NVFP4 quantization scheme)
└─ FlashInferCutlassNvFp4LinearKernel for linear layers
2. Model construction
└─ DeepseekV4ForCausalLM → DeepseekV4DecoderLayer → DeepseekV4MoE
Each layer has: attention + MoE block
MoE block has: shared experts + 256 routed experts
3. Weight loading
└─ 95 safetensor shards loaded
└─ weight, weight_scale, weight_scale_2 loaded per linear
4. process_weights_after_loading ← THIS IS WHERE WE HOOK IN
└─ ModelOptNvFp4LinearMethod swizzles/pads weights for CUTLASS
└─ finalize_mega_moe_weights()
└─ weight_transform.py: transform_nvfp4_weights_for_mega_moe()
• Folds weight_scale_2 (global scale) into weight_scale (block scale)
• UE4M3 block-16 scales: 4 values packed per uint32
• Returns ((l1_w, l1_sf), (l2_w, l2_sf)) per rank
5. SymmBuffer allocation
└─ symm_buffer.py: get_symm_buffer_for_nvfp4_mega_moe()
• Pre-allocates GPU buffers for:
- x: int8 packed E2M1 activations
- x_sf: uint32 packed UE4M3 activation scales
- topk_idx: int32 expert indices
- topk_weights: float32 routing weights
- buffer: BF16 all-reduce buffer
6. Profile run (warmup)
└─ First forward pass to allocate KV cache, etc.
└─ This is where the CUTLASS GEMM first executes
7. Ready to serve
File Map
nvfp4_megamoe_kernel/
├── __init__.py # Public API exports
├── nvfp4_mega_moe.py # Main kernel: nvfp4_mega_moe_full, nvfp4_mega_moe_l1/l2, stage_activation
├── weight_transform.py # Weight prep: fold global scale, pack UE4M3
├── symm_buffer.py # GPU buffer allocation for MoE dispatch
│
└── cutlass_nvfp4_gemm/ # CUTLASS CUDA extension (the actual hardware kernel)
├── cutlass_nvfp4_gemm.cu # CUDA: CUTLASS GEMM + SF remap kernel
├── pytorch_binding.cpp # PyTorch C++ binding (_C.forward)
├── kernel.py # Python: cutlass_grouped_nvfp4_gemm (per-expert loop)
├── sf_layout.py # CUTLASS SF layout reference docs
├── setup.py # Build config (nvcc, CUTLASS include paths)
├── build.sh # Build script
├── test_gemm.py # Standalone test
└── README.md
What each file does (in call order)
| File | When it runs | What it does |
|---|---|---|
| weight_transform.py | Once at startup (weight loading) | Takes raw NVFP4 checkpoint weights, folds global scales into block scales, packs UE4M3 into uint32. Output: ((l1_w, l1_sf), (l2_w, l2_sf)) |
| symm_buffer.py | Once at startup (buffer alloc) | Pre-allocates GPU tensors for activations, scales, routing data, and all-reduce. These persist across forward passes. |
| nvfp4_mega_moe.py | Every forward pass | Orchestrates the MoE: reads from symm buffer → L1 GEMM → activation → re-quantize → L2 GEMM → output. Contains stage_activation (BF16→FP4 quantize for L1→L2) and unpack_ue4m3_u32 (uint32 packed scales → float8). |
| cutlass_nvfp4_gemm/kernel.py | Every forward pass (called by nvfp4_mega_moe) | Per-expert loop: gather tokens for each expert, call CUTLASS GEMM, scatter results with routing weights. |
| cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu | Every forward pass (CUDA kernel) | The actual CUTLASS kernel: native NVFP4 block-scaled GEMM + GPU-side scale factor remap (row-major → CUTLASS interleaved layout). |
| cutlass_nvfp4_gemm/sf_layout.py | Reference only | Documents the CUTLASS SfAtom layout. Not used at runtime (remap is in CUDA). |
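The per-expert loop in kernel.py (gather, GEMM, scatter with routing weights) can be sketched in pure Python as follows. expert_fn stands in for one CUTLASS GEMM call, and the contiguous expert range per rank (first_expert) is an assumption for illustration; none of these names come from the package.

```python
from collections import defaultdict

def grouped_moe(x, topk_ids, topk_weights, expert_fn, num_local_experts, first_expert=0):
    """x: list of token rows; topk_ids / topk_weights: per-token top-k lists.
    expert_fn(local_expert, rows) -> output rows (stand-in for one GEMM call)."""
    if not x:
        return []
    # Gather: bucket (token index, routing weight) pairs by local expert.
    buckets = defaultdict(list)
    for t, (ids, weights) in enumerate(zip(topk_ids, topk_weights)):
        for e, w in zip(ids, weights):
            local = e - first_expert
            if 0 <= local < num_local_experts:  # experts on other ranks are skipped
                buckets[local].append((t, w))
    # Output width equals input width (hidden -> hidden for the full MoE).
    out = [[0.0] * len(x[0]) for _ in x]
    # One GEMM per expert over its gathered rows, then a weighted scatter-add.
    for local in sorted(buckets):
        pairs = buckets[local]
        ys = expert_fn(local, [x[t] for t, _ in pairs])
        for (t, w), y in zip(pairs, ys):
            out[t] = [o + w * yi for o, yi in zip(out[t], y)]
    return out
```

Each expert_fn call is one small GEMM, and the Python-level bucketing plus scatter is the overhead that a grouped GEMM would fold into a single launch.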
Data Formats
Weights
- Packed E2M1 (int8): 2 FP4 values per byte. Shape: (E_per_rank, N, K//2), K-major layout.
- UE4M3 block scales (float8_e4m3fn): 1 scale per 16 FP4 values (group_size=16). Shape: (E_per_rank, N, K//16).
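A pure-Python sketch of unpacking one packed byte into two E2M1 values and applying the group-16 block scale. The nibble order (low nibble = first element) is an assumption and should be checked against the staging kernel. The helper names are hypothetical.

```python
# E2M1 magnitudes indexed by the 3-bit magnitude code (bits 2:0); bit 3 is the sign.
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def decode_e2m1(code: int) -> float:
    mag = E2M1_MAGNITUDES[code & 0x7]
    return -mag if code & 0x8 else mag

def unpack_fp4_byte(b: int):
    # int8 storage may be negative in Python; take its unsigned bit pattern.
    b &= 0xFF
    # Assumption: the low nibble holds the first (even-index) element.
    return decode_e2m1(b & 0xF), decode_e2m1(b >> 4)

def dequant_row(packed, scales, group_size=16):
    # packed: K//2 bytes; scales: K//16 block scales, already decoded to float.
    vals = []
    for b in packed:
        vals.extend(unpack_fp4_byte(b))
    return [v * scales[i // group_size] for i, v in enumerate(vals)]
```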
Activations (after staging kernel)
- Packed E2M1 (int8): Shape: (num_tokens, K//2).
- UE4M3 scales (uint32): 4 UE4M3 values packed per uint32. Shape: (num_tokens, K//64).
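The K//64 shape follows from packing four one-byte scales per word. This pure-Python sketch packs and unpacks raw E4M3 bit patterns. The byte order (scale j in bits 8j to 8j+7, first scale in the low byte) is an assumption, and the helper names are hypothetical, not the package's unpack_ue4m3_u32.

```python
def pack_scale_words(scale_bytes):
    """Pack a row of K//16 E4M3 bit patterns into K//64 uint32 words."""
    assert len(scale_bytes) % 4 == 0, "K//16 must be a multiple of 4"
    words = []
    for i in range(0, len(scale_bytes), 4):
        w = 0
        for j in range(4):
            # Assumption: scale j of the word occupies bits [8*j, 8*j + 8).
            w |= (scale_bytes[i + j] & 0xFF) << (8 * j)
        words.append(w)
    return words

def unpack_scale_words(words):
    """Inverse: recover the raw bytes (bit patterns, not numeric values)."""
    return [(w >> (8 * j)) & 0xFF for w in words for j in range(4)]
```

For K = 7168 this yields 448 scales per row packed into 112 words, matching K//64.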
GEMM dimensions (DeepSeek-V4-Pro)
- L1 (gate_up_proj): M × N × K = M × 6144 × 7168 per expert, where M is the number of tokens routed to that expert
- L2 (down_proj): M × 7168 × 3072 per expert
- 32 experts per rank (256 total / 8 ranks), top-6 routing
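The arithmetic implied by these shapes (packed weight and scale bytes per expert, and FLOPs per routed token) can be sketched as below. The constants come from the numbers above. Shared experts and the SiLU are excluded, and the helper names are hypothetical.

```python
HIDDEN = 7168                      # L1 K, L2 N
INTER = 3072                       # gate and up are each INTER wide
NUM_EXPERTS, NUM_RANKS, TOP_K = 256, 8, 6

def gemm_shapes():
    # (N, K) of each per-expert GEMM; M is the number of tokens routed to the expert.
    return {"l1": (2 * INTER, HIDDEN), "l2": (HIDDEN, INTER)}

def weight_bytes(n: int, k: int, group_size: int = 16):
    # Packed E2M1 holds two values per byte; one 1-byte E4M3 scale per group_size values.
    return n * k // 2, n * k // group_size

def flops_per_token() -> int:
    # N * K multiply-adds = 2 * N * K FLOPs per GEMM, for each of the TOP_K routed experts.
    per_expert = sum(2 * n * k for n, k in gemm_shapes().values())
    return TOP_K * per_expert
```

For L1, weight_bytes(6144, 7168) gives 22,020,096 bytes of packed E2M1 plus 2,752,512 bytes of scales per expert.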
CUTLASS Scale Factor Remap
CUTLASS's Sm1xxBlockScaledConfig expects scale factors in a specific interleaved layout, not simple row-major. The SfAtom is:
Atom Shape: Shape<Shape<32, 4>, Shape<16, 4>>
Atom Stride: Stride<Stride<16, 4>, Stride<0, 1>>
Tiling: Step<_2, _1> (tile order: atoms along K vary fastest, then atoms along M; compare the idx 512 and idx 65536 rows of the dump below)
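The sketch below evaluates the quoted shape and stride by hand to make the atom concrete: it takes colexicographic coordinates, then a dot product with the strides. It is a pure-Python illustration, not CuTe. The stride-0 mode makes 16 consecutive K elements share one scale factor.

```python
# SfAtom for NVFP4 (SF vector size 16), as quoted above:
#   Shape<Shape<32, 4>, Shape<16, 4>>,  Stride<Stride<16, 4>, Stride<0, 1>>
ATOM_M_SHAPE, ATOM_M_STRIDE = (32, 4), (16, 4)
ATOM_K_SHAPE, ATOM_K_STRIDE = (16, 4), (0, 1)

def atom_offset(m: int, k: int) -> int:
    """Offset of the scale factor for element (m, k) inside one 128 x 64 atom.
    m in [0, 128); k in [0, 64) counts FP4 elements along K."""
    # CuTe coordinates are colexicographic: the first sub-mode varies fastest.
    m0, m1 = m % ATOM_M_SHAPE[0], m // ATOM_M_SHAPE[0]
    k0, k1 = k % ATOM_K_SHAPE[0], k // ATOM_K_SHAPE[0]
    return (m0 * ATOM_M_STRIDE[0] + m1 * ATOM_M_STRIDE[1]
            + k0 * ATOM_K_STRIDE[0] + k1 * ATOM_K_STRIDE[1])
```

Each atom therefore holds 128 × 4 = 512 distinct scale factors, which is why idx 512 in the dump starts the next K atom.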
Our source data is row-major (M, K_sf) where K_sf = K / 16. The remap kernel (remap_sf_to_cutlass_kernel in cutlass_nvfp4_gemm.cu) converts from row-major to CUTLASS's interleaved layout.
How the remap works
The kernel iterates over CUTLASS destination indices, uses cute::idx2crd to get the hierarchical coordinate, then cute::flatten to get a flat tuple of 8 sub-indices. From those, we extract logical (m, k_sf) and read from the row-major source.
Flattened coordinate decomposition (flat_rank=8)
From the SfAtom layout with Step<_2, _1> tiling, flatten(idx2crd(idx, ...)) produces 8 values:
f0 = inner_m (0..31) — varies fastest within M atom
f1 = sub_m (0..3) — second M sub-coordinate
f2 = tile_m (0..) — M tile index
f3 = step_m stride — degenerate (always = sfa_size, not a coordinate)
f4 = sub_k (0..3) — K sub-coordinate within atom
f5 = tile_k (0..) — K tile index
f6 = 0 — unused
f7 = 0 — unused
Empirical coordinate dump (MN=8192, K_sf=448, T = sfa_size = 58720256 = 8192 × 7168, i.e. MN × K)
| idx | f0 | f1 | f2 | f3 | f4 | f5 | f6 | f7 |
|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 0 | 0 | T | 0 | 0 | 0 | 0 |
| 1 | 0 | 0 | 0 | T | 1 | 0 | 0 | 0 |
| 4 | 0 | 1 | 0 | T | 0 | 0 | 0 | 0 |
| 16 | 1 | 0 | 0 | T | 0 | 0 | 0 | 0 |
| 511 | 31 | 3 | 0 | T | 3 | 0 | 0 | 0 |
| 512 | 0 | 0 | 0 | T | 0 | 1 | 0 | 0 |
| 1024 | 0 | 0 | 0 | T | 0 | 2 | 0 | 0 |
| 2048 | 0 | 0 | 0 | T | 0 | 4 | 0 | 0 |
| 4096 | 0 | 0 | 0 | T | 0 | 8 | 0 | 0 |
| 8192 | 0 | 0 | 0 | T | 0 | 16 | 0 | 0 |
| 65536 | 0 | 0 | 1 | T | 0 | 16 | 0 | 0 |
| 131072 | 0 | 0 | 2 | T | 0 | 32 | 0 | 0 |
Extraction formula
CuTe coordinates are colexicographic (the first sub-mode varies fastest), so for Shape<32, 4> the M sub-indices combine as f0 + f1 * 32:
m = f0 + f1 * 32 + f2 * 128;
k_sf = f4 + f5 * 4;
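The extraction formula can be checked against the dump with a small pure-Python decoder. It hard-codes the strides visible in the dump (1 for sub_k, 4 for sub_m, 16 for inner_m, 512 per atom). It also assumes K atoms vary faster than M atoms, as the idx 65536 and 131072 rows show. It is a cross-check, not a replacement for cute::idx2crd, and the function name is hypothetical.

```python
def decode_sf_index(idx: int, k_sf: int):
    """Map a CUTLASS SF buffer index to the logical (m, k_sf) it holds."""
    k_tiles = -(-k_sf // 4)        # 4 scale factors per atom along K (ceil division)
    f4 = idx % 4                   # sub_k
    f1 = (idx // 4) % 4            # sub_m
    f0 = (idx // 16) % 32          # inner_m
    tile = idx // 512              # 512 scale factors per 128 x 4 atom
    f5 = tile % k_tiles            # tile_k: K atoms are innermost
    f2 = tile // k_tiles           # tile_m
    return f0 + f1 * 32 + f2 * 128, f4 + f5 * 4
```

With K_sf = 448 there are 112 K atoms per M tile, so idx 65536 (atom 128) wraps to tile_m = 1, tile_k = 16, exactly as in the dump.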
This was verified with 6 independent probes:
| Probe | Expected | Result | Conclusion |
|---|---|---|---|
| SFA[1, 0] = 2.0 | row 1 changes | ✅ only row 1 | Confirms f0 term |
| SFA[32, 0] = 2.0 | row 32 changes | ✅ only row 32 | Confirms `f1*32`, rules out `f0*4 + f1` |
| SFA[128, 0] = 2.0 | row 128 changes | ✅ only row 128 | Confirms f2*128 |
| SFA[0, 1] = 2.0 | row 0 changes (k=1) | ✅ only row 0 | Confirms f4 term |
| SFA[0, 4] = 2.0 | row 0 changes (k=4) | ✅ only row 0 | Confirms f5*4 term |
| SFA[0, 100] = 2.0 | row 0 changes (k=100) | ✅ only row 0 | Confirms tile-overflow range |
Why the previous remap was broken
The previous code used cute::get<0>(flat) and cute::get<1>(flat) to extract (m, k). Since flatten produces (inner_m, sub_m, tile_m, ...) in order, get<0> and get<1> are both M sub-indices — they carry no K information. This caused only k_group=0 to work; all other K-groups were silently mapped to the wrong source offset.
Additionally, the destination buffer must be zero-initialized before the remap. CUTLASS pads the SF layout to tile boundaries (128 rows × 64 K elements, i.e. 128 × 4 scale factors), so the destination is larger than M * K_sf. The remap never writes the padding slots, and leaving them uninitialized caused sporadic wrong results.
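Both fixes combine into the pure-Python model of the remap below. It pads M to 128 and K_sf to 4, zero-fills the buffer, then fills every destination slot whose decoded (m, k_sf) falls inside the source. The decode reuses the strides from the dump. The function name is hypothetical; the real remap_sf_to_cutlass_kernel does this per thread in CUDA.

```python
def remap_sf_to_cutlass(src, m_rows: int, k_sf: int):
    """Row-major (m_rows, k_sf) scale bytes -> flat CUTLASS-interleaved buffer."""
    m_pad = -(-m_rows // 128) * 128        # pad M up to the 128-row atom
    k_pad = -(-k_sf // 4) * 4              # pad K_sf up to 4 scales (64 FP4 elements)
    k_tiles = k_pad // 4
    dst = [0] * (m_pad * k_pad)            # zero-init: padding slots are never written
    for idx in range(len(dst)):
        f4, f1, f0 = idx % 4, (idx // 4) % 4, (idx // 16) % 32
        f5, f2 = (idx // 512) % k_tiles, (idx // 512) // k_tiles
        m, k = f0 + f1 * 32 + f2 * 128, f4 + f5 * 4
        if m < m_rows and k < k_sf:        # padding rows/columns keep their zero
            dst[idx] = src[m * k_sf + k]
    return dst
```

The single-element probes translate directly into checks on this model: setting one source scale should make exactly one destination slot nonzero, at the index the formula predicts.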
Bugs Found & Fixed
1. unpack_ue4m3_u32: value cast vs bit reinterpret
File: nvfp4_mega_moe.py
Bug: (x_u32 & 0xFF).to(torch.int32).to(torch.float8_e4m3fn) is a value cast: the byte 0x3F (integer 63) becomes a float8 of value ≈63 (64.0 after E4M3 rounding).
Fix: (x_u32 & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn) reinterprets the bit pattern: 0x3F decodes to 1.875.
Also: PyTorch's uint32 lacks bitwise ops on CUDA, so cast to int32 before masking and shifting.
Impact: Corrupted every activation scale fed to the L1 GEMM, producing "structured garbage" output. Weight scales were unaffected (already float8 from weight_transform).
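The difference between the value cast and the bit reinterpretation can be reproduced without a GPU. This pure-Python decoder follows the E4M3 (fn) encoding (bias 7, no infinities, S.1111.111 = NaN). nth_scale_byte models the int32 path and assumes the cast keeps the 32-bit pattern (two's-complement wrap). Both helpers are hypothetical stand-ins for the torch calls.

```python
def e4m3_bits_to_float(b: int) -> float:
    """Decode an 8-bit E4M3 (fn) pattern: what .view(torch.float8_e4m3fn) yields."""
    sign = -1.0 if b & 0x80 else 1.0
    exp, man = (b >> 3) & 0xF, b & 0x7
    if exp == 0xF and man == 0x7:
        return float("nan")                          # E4M3fn: S.1111.111 is NaN
    if exp == 0:
        return sign * (man / 8) * 2.0 ** -6          # subnormal
    return sign * (1 + man / 8) * 2.0 ** (exp - 7)   # normal, exponent bias 7

def nth_scale_byte(word_u32: int, j: int) -> int:
    """Extract scale byte j (0 = low byte) after an int32 cast of the packed word."""
    # Assumption: the int32 cast keeps the 32-bit pattern (two's-complement wrap).
    as_i32 = word_u32 - (1 << 32) if word_u32 & 0x80000000 else word_u32
    # Shifting a negative int32 sign-extends; the 0xFF mask removes those bits.
    return (as_i32 >> (8 * j)) & 0xFF
```

Here e4m3_bits_to_float(0x3F) returns 1.875, while the value cast treats the same byte as the number 63.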
2. stage_activation: three independent bugs
File: nvfp4_mega_moe.py
Bug A: clamp(0, 15) zeroed every negative value. E2M1 is sign-magnitude 4-bit (bit3=sign, bits2:0=mag).
Bug B: Stored block_max but divided by block_max/6.0 → stored scale was 6× too large.
Bug C: Uniform 0.5 step doesn't match E2M1 values {0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6} — non-uniform above ±2.
Fix: Rewrote with proper nearest-neighbor E2M1 quantization.
Impact: Half the L1→L2 activation was zeroed, 6× scale mismatch, quantization noise on top.
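The pure-Python sketch below applies nearest-neighbor E2M1 quantization to one 16-value block and addresses all three bugs. The sign goes in bit 3, the stored scale is the divisor block_max / 6, and rounding snaps to the non-uniform E2M1 grid. It makes three assumptions: the scale is kept in float (the real path would round it to UE4M3), ties round toward the smaller magnitude, and the names are hypothetical.

```python
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def _nearest_index(mag: float) -> int:
    # min() returns the first minimum, so ties resolve to the smaller magnitude.
    return min(range(len(E2M1_MAGNITUDES)), key=lambda i: abs(E2M1_MAGNITUDES[i] - mag))

def quantize_block_e2m1(block):
    """Quantize one block to sign-magnitude E2M1 codes plus a float scale."""
    block_max = max(abs(v) for v in block)
    # Bug B: store the scale actually used as the divisor, block_max / 6.
    scale = block_max / 6.0 if block_max > 0 else 1.0
    codes = []
    for v in block:
        idx = _nearest_index(abs(v) / scale)      # Bug C: non-uniform grid
        sign = 0x8 if v < 0 and idx != 0 else 0   # Bug A: keep negatives
        codes.append(sign | idx)
    return codes, scale

def dequantize_block_e2m1(codes, scale):
    return [(-1.0 if c & 0x8 else 1.0) * E2M1_MAGNITUDES[c & 0x7] * scale for c in codes]
```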
3. _fold_global_scale: logical_widths branch
File: weight_transform.py
Bug: logical_widths=[3072, 3072] caused the function to apply expert 0's scale to gate half and expert 1's scale to up half of ALL experts. All other experts' global scales were discarded.
Fix: Removed the logical_widths branch entirely. The else branch correctly broadcasts each expert's own (E, 1) global scale across (E, N, K//16).
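The corrected broadcast amounts to one multiply per block scale, with each expert using only its own global scale. This pure-Python sketch works over nested lists. Re-encoding the folded values to E4M3 (and clamping to its 448 maximum) is omitted, and the function name is hypothetical.

```python
def fold_global_scale(block_scales, global_scales):
    """block_scales: [E][N][K//16] floats; global_scales: [E] (the (E, 1) weight_scale_2).
    Every block scale of expert e is multiplied by global_scales[e] and nothing else."""
    assert len(block_scales) == len(global_scales)
    return [
        [[s * g for s in row] for row in expert]
        for expert, g in zip(block_scales, global_scales)
    ]
```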
4. L1 weight interleave removed
File: weight_transform.py
Bug: _interleave_l1_weights assumed gate/up were pre-interleaved in groups of 16 and that the kernel used 2CTA UMMA layout. vLLM uses plain concat [gate; up] along the output dim, and our CUTLASS kernel uses ClusterShape<1, 1, 1>.
Fix: Removed entirely. l1_weight_out = l1_weight.contiguous().
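The interleave assumption mattered because splitting the same L1 output row with each layout yields different gate/up halves. In this pure-Python sketch with hypothetical helper names, split_interleaved models the removed groups-of-16 assumption.

```python
def split_concat(h, inter):
    # vLLM layout: plain concat [gate; up] along the output dim.
    return h[:inter], h[inter:]

def split_interleaved(h, group=16):
    # Removed assumption: alternating groups of `group` gate rows and `group` up rows.
    gate, up = [], []
    for i in range(0, len(h), 2 * group):
        gate.extend(h[i:i + group])
        up.extend(h[i + group:i + 2 * group])
    return gate, up
```

On a concatenated weight, the interleaved split pairs half of the gate rows with half of the up rows, so SiLU(gate) * up mixes unrelated channels.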
5. SF remap: idx2crd+flatten coordinate extraction
File: cutlass_nvfp4_gemm.cu
Bug: cute::flatten(coord) produces 8 sub-indices (flat_rank=8). get<0> and get<1> are both M sub-indices (inner_m, sub_m), carrying zero K information. Only k_group=0 worked; all other K-groups were silently wrong.
Fix: Correct extraction: m = f0 + f1*32 + f2*128, k_sf = f4 + f5*4. Zero-init dest buffer before remap.
Diagnostic trail: A constant-scale test (all SF = 1.0) gave cosine 1.0, proving the FP4 path was correct. Real scales gave cosine 0.83, showing the SF remap was broken. Single-element probes (SFA[0,0] vs SFA[0,3]) showed that only k_group=0 worked. A printf dump of the flat coordinates at specific indices revealed flat_rank=8 and the correct extraction formula.
Diagnostic: constant-scale test (smoking gun for SF bugs)
When all scale factors are set to UE4M3(1.0):
- Cosine = 1.0000, MSE = 0.19 (expected FP4 quantization noise)
With real (variable) scale factors and the broken remap:
- Cosine = 0.83 → scales are misaligned, not fundamentally broken
After the fix with correct coordinate extraction:
- Cosine = 1.0000, MSE = 0.0 → perfect match with dequantized reference
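The two metrics above can be written in pure Python as follows. Cosine similarity ignores a uniform scale error (x and 6·x have cosine 1.0), which is why MSE is worth checking alongside it. The helper names are illustrative.

```python
import math

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb) if na and nb else float("nan")

def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)
```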
Build & Deploy (B200)
# On B200 host — CUTLASS must be cloned and mounted
cd /root/nvidia-meeting/deepseek-v4-quant/
# Rebuild container (CUTLASS is host-mounted at /root/cutlass)
KERNEL_CACHE_BUSTER=$(date +%s) docker compose build --no-cache
docker compose up -d
The CUTLASS extension builds inside the container during pip install of the nvfp4-megamoe-kernel package. It needs:
- CUDA 13.0 toolkit (in the vllm/vllm-openai:nightly image)
- CUTLASS headers at /root/cutlass/include/
- CCCL headers at /usr/local/cuda-13.0/targets/x86_64-linux/include/cccl/
- A device with SM100 compute capability (B200)
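A pre-flight check for these requirements can save a failed container build. This is a hypothetical helper. The default paths are the ones listed above, cutlass/cutlass.h serves as the marker file for a CUTLASS checkout, and compute capability 10.0 is the B200 value; adjust for other hosts.

```python
import os

def missing_build_prereqs(
    cutlass_include: str = "/root/cutlass/include",
    cccl_include: str = "/usr/local/cuda-13.0/targets/x86_64-linux/include/cccl",
):
    """Return a list of problems; an empty list means the prerequisites look present."""
    problems = []
    # cutlass/cutlass.h is a top-level header in a CUTLASS checkout.
    if not os.path.isfile(os.path.join(cutlass_include, "cutlass", "cutlass.h")):
        problems.append(f"CUTLASS headers not found under {cutlass_include}")
    if not os.path.isdir(cccl_include):
        problems.append(f"CCCL headers not found at {cccl_include}")
    try:
        import torch
    except ImportError:
        problems.append("PyTorch is not importable")
        return problems
    if not torch.cuda.is_available():
        problems.append("no CUDA device is visible")
    else:
        major, minor = torch.cuda.get_device_capability()
        if (major, minor) != (10, 0):
            problems.append(f"compute capability {major}.{minor}, expected 10.0 (SM100/B200)")
    return problems
```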
Known Issues
- MoE dispatch is slow: cutlass_grouped_nvfp4_gemm uses a Python loop over the rank's local experts with per-token scatter/gather. It needs a proper grouped GEMM, or at least CUDA-side dispatch.
- stage_activation is Python: re-quantization of the L1 BF16 output to FP4 for the L2 input runs in PyTorch. It should use the Triton staging kernel for speed and for consistency with vLLM's built-in staging.
- SF remap allocates every call: the cudaMemset + remap kernel runs on every GEMM invocation. The CUTLASS-layout buffer could be precomputed once during weight transform, at least for the static weight scales.
Environment Variables
| Variable | Default | Description |
|---|---|---|
| MEGA_MOE_STATIC | 0 | Set to 1 to skip the MoE kernel entirely (returns zeros) |
| MEGA_MOE_DEBUG | 0 | Set to 1 for verbose logging |
| SKIP_ATTENTION | 0 | Skip attention layers (debug) |
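The sketch below shows one way these flags might be read, treating exactly "1" as enabled. env_flag and moe_forward are hypothetical names, not the package's actual code.

```python
import os

def env_flag(name: str, default: str = "0") -> bool:
    # Only the exact value "1" enables a flag, matching the table's "Set to 1".
    return os.environ.get(name, default) == "1"

def moe_forward(x_rows, hidden: int, run_kernel):
    # Hypothetical wrapper: MEGA_MOE_STATIC=1 bypasses the kernel and returns zeros.
    if env_flag("MEGA_MOE_STATIC"):
        return [[0.0] * hidden for _ in x_rows]
    return run_kernel(x_rows)
```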
Repos
- Kernel: sweetapi.com/biondizzle/nvfp4-megamoe-kernel (branch: master)
- Deployment: sweetapi.com/biondizzle/deepseek-v4-quant (branch: modelopt-nvfp4)
- Local: ~/dev/nvfp4-mojo/, ~/dev/deepseek-v4-quant/