DeepGEMM NVFP4 Mega MoE Kernel
Overview
A native NVFP4 mega MoE kernel for DeepGEMM, built on the nvfp4-mega-moe branch.
Adapts the existing sm100_fp8_fp4_mega_moe_impl to handle NVFP4 checkpoint weights
(E2M1 + UE4M3 block scales, group_size=16).
HARD RULE: MoE experts stay in NVFP4. Never convert to MXFP4.
SM100 (B200) Hardware Constraint
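CRITICAL: In this build, B200 (SM100) does NOT accept kind::mxf4nvf4. NVCC rejects both scale_vec::2X and scale_vec::4X for the sm_100f target (builds 17–18 in the debugging log).
This instruction is therefore treated as requiring SM103 (B300/GB300). On SM100, the FP8×FP4 block-scaled
MMA available to the kernel is kind::mxf8f6f4.block_scale with UE8M0 scales (block32, group_size=32).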
Strategy: Keep NVFP4 E2M1 weights (same as MXFP4), convert UE4M3 block scales to UE8M0 for hardware compatibility. Merge NVFP4 block16→block32 (max of adjacent pairs). This is a scale format adaptation, not a weight format conversion.
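The merge step changes the scale a block is dequantized with, while its E2M1 codes stay the same. The minimal pure-Python sketch below shows the effect on two adjacent blocks. The E2M1 magnitude table is the standard one; the scales, the codes, and the `dequant` helper are hypothetical and chosen only for illustration (4 codes per block instead of 16, for brevity).

```python
# E2M1 magnitudes indexed by the low 3 bits of a code; bit 3 is the sign.
E2M1 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def e2m1_to_float(code: int) -> float:
    """Decode one 4-bit E2M1 code."""
    mag = E2M1[code & 0x7]
    return -mag if code & 0x8 else mag


def dequant(codes, scale):
    """Hypothetical helper: dequantize a block of E2M1 codes with one scale."""
    return [e2m1_to_float(c) * scale for c in codes]


# Two adjacent blocks quantized with different scales
# (only 4 of the 16 codes per block are shown).
block_a, scale_a = [1, 2, 3, 7], 0.5
block_b, scale_b = [7, 6, 5, 4], 2.0

# The block16 -> block32 merge keeps the larger scale for both halves,
# while the E2M1 codes are left unchanged.
merged = max(scale_a, scale_b)

exact_a = dequant(block_a, scale_a)
merged_a = dequant(block_a, merged)  # every value is scaled by merged / scale_a
exact_b = dequant(block_b, scale_b)
merged_b = dequant(block_b, merged)  # identical to exact_b
```

In this sketch the block with the larger scale is reproduced exactly, while every value in the other block comes out 4× too large (the ratio of the two scales), because its codes were never re-quantized against the merged scale.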
| Parameter | NVFP4 Checkpoint | Kernel (SM100 Adapted) |
|---|---|---|
| Weight format | E2M1 uint8 | E2M1 uint8 (unchanged) |
| Block scale format | UE4M3 (float8_e4m3fn) | UE8M0 (uint8) |
| Block size | 16 | 32 (merged) |
| Global scale | float32 | Folded in before UE4M3→UE8M0 |
| PTX instruction | N/A (requires SM103+) | mxf8f6f4.block_scale |
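Result: Server starts and serves, but output is garbled. The UE4M3→UE8M0 conversion discards all 3 mantissa bits of every scale: each scale is truncated to a power of two (8 representable steps per octave collapse to 1), so a block can be under-scaled by up to 1.875×. This destroys output quality.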
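The size of this loss can be checked exhaustively on the raw UE4M3 grid. The plain-Python sketch below decodes every positive, finite float8_e4m3fn value, converts it with the same `(bits >> 23) & 0xFF` exponent extraction used in the pipeline, and measures the worst case. The helper names are hypothetical, and the sketch models only the scale conversion, not the global-scale fold or the kernel.

```python
import struct


def e4m3fn_to_float(byte: int) -> float:
    """Decode a float8_e4m3fn byte (bias 7, 3 mantissa bits, no infinities)."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0xF and man == 0x7:
        return float("nan")  # S.1111.111 is the only NaN encoding
    if exp == 0:
        return sign * (man / 8) * 2.0 ** -6  # subnormal
    return sign * (1 + man / 8) * 2.0 ** (exp - 7)


def float_to_ue8m0(x: float) -> int:
    """UE8M0 via exponent extraction, as in the pipeline: (bits >> 23) & 0xFF."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return (bits >> 23) & 0xFF


def ue8m0_to_float(e: int) -> float:
    """UE8M0 stores only a biased exponent: value = 2 ** (e - 127)."""
    return 2.0 ** (e - 127)


# All positive, finite, nonzero UE4M3 scales (bytes 0x01..0x7E; 0x7F is NaN).
scales = [e4m3fn_to_float(b) for b in range(1, 0x7F)]
converted = [ue8m0_to_float(float_to_ue8m0(s)) for s in scales]

# Ratio > 1 means the converted scale is smaller than the original.
ratios = [s / c for s, c in zip(scales, converted)]
print(len(set(scales)), "UE4M3 scales ->", len(set(converted)), "UE8M0 scales")
print("worst-case under-scaling:", max(ratios))
```

The 126 distinct UE4M3 scales collapse to 18 powers of two, and because the exponent extraction truncates rather than rounds, the worst case (mantissa `111`) under-scales a block by 1.875×.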
Kernel Changes from MXFP4
The kernel is functionally identical to the MXFP4 mega_moe. The only differences are in the weight transformation pipeline (Python) and the documentation/comments.
CUDA Kernel (sm100_fp8_nvfp4_mega_moe.cuh)
Identical to sm100_fp8_fp4_mega_moe.cuh in runtime behavior:
- Same `kGranK = 32`, same `float_ue8m0_t` descriptor, same `mxf8f6f4` instruction
- Same TMEM layout (2X), UTCCP copy (`i*4`), SF counts
- Same L1 epilogue (UE8M0, `>> 23`)
The renamed file exists for documentation and future SM103+ support.
Python API (deep_gemm/mega/__init__.py)
```python
from deep_gemm.mega import (
    fp8_nvfp4_mega_moe,
    transform_nvfp4_weights_for_mega_moe,
    get_symm_buffer_for_nvfp4_mega_moe,
)

# Transform NVFP4 checkpoint weights for the kernel
l1_weights, l2_weights = transform_nvfp4_weights_for_mega_moe(
    (w13_weight, w13_weight_scale),  # uint8, float8_e4m3fn
    (w2_weight, w2_weight_scale),
    l1_weight_scale_2=w13_weight_scale_2,  # float32 global scale
    l2_weight_scale_2=w2_weight_scale_2,
)
# l1_weights/l2_weights: (int8 weight, int32 TMA-aligned SF)

# Run the kernel (symm_buffer is obtained via get_symm_buffer_for_nvfp4_mega_moe)
fp8_nvfp4_mega_moe(y, l1_weights, l2_weights, symm_buffer, ...)
```
Weight Transformation Pipeline
```text
NVFP4 Checkpoint                        Kernel Format
┌─────────────────────┐                 ┌────────────────────────┐
│ weight: uint8       │────────────────→│ int8 (E2M1, same)      │
│ (E2M1, 2 per byte)  │   .view(int8)   │ packed, interleaved    │
├─────────────────────┤                 ├────────────────────────┤
│ weight_scale:       │ 1. fold global  │ int32 (TMA-aligned     │
│ float8_e4m3fn       │ 2. merge 16→32  │ UTCCP layout)          │
│ (UE4M3, group=16)   │ 3. UE4M3→UE8M0  │                        │
├─────────────────────┤ 4. pack 4→i32   └────────────────────────┘
│ weight_scale_2:     │ 5. transpose    (same as MXFP4 pipeline)
│ float32 (global)    │ 6. TMA-align
└─────────────────────┘
```
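Key steps:
- Fold global scale: `block_scale * global_scale` → UE4M3 (computed in float32, then re-quantized)
- Merge block16→block32: `max(adjacent_pair)`. The max keeps the larger block's magnitude without clipping, but because the E2M1 codes are not re-quantized, the block with the smaller scale is over-scaled by the ratio of the two scales.
- UE4M3→UE8M0: `float32 → (bits >> 23) & 0xFF → uint8` (extract the IEEE 754 exponent, i.e. truncate to a power of two)
- Pack: 4 uint8 UE8M0 values → 1 int32
- Transpose: MN-major layout
- TMA-align: `transform_sf_into_required_layout(sf, mn, k, (1,32), num_experts)`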
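Steps 1–4 can be sketched for a single row of scales in plain Python. This is an illustration under stated assumptions, not the DeepGEMM implementation: the function names are hypothetical, the fold is kept in float32 without the UE4M3 re-quantize, and packing assumes little-endian byte order (first scale in the lowest byte of each int32). Transpose and TMA alignment are DeepGEMM-specific (`transform_sf_into_required_layout`) and are omitted.

```python
import struct


def fold_global(block_scales, global_scale):
    """Step 1: fold the float32 global scale into each block16 scale."""
    return [s * global_scale for s in block_scales]


def merge_16_to_32(scales16):
    """Step 2: one block32 scale per adjacent pair, keeping the larger."""
    assert len(scales16) % 2 == 0
    return [max(scales16[i], scales16[i + 1]) for i in range(0, len(scales16), 2)]


def to_ue8m0(x):
    """Step 3: keep only the float32 exponent byte (truncates to 2**k)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return (bits >> 23) & 0xFF


def pack4(ue8m0_bytes):
    """Step 4: pack 4 UE8M0 bytes into one signed int32, first byte lowest."""
    assert len(ue8m0_bytes) % 4 == 0
    out = []
    for i in range(0, len(ue8m0_bytes), 4):
        word = sum(b << (8 * j) for j, b in enumerate(ue8m0_bytes[i:i + 4]))
        # Reinterpret as signed int32, as an int32 tensor would hold it.
        out.append(word - (1 << 32) if word >= 1 << 31 else word)
    return out


# One row with K = 128: 8 block16 scales -> 4 block32 scales -> 1 int32.
row16 = [1.0, 0.75, 2.0, 3.0, 0.5, 0.5, 4.0, 1.5]
row32 = merge_16_to_32(fold_global(row16, 0.5))  # [0.5, 1.5, 0.25, 2.0]
ue8m0 = [to_ue8m0(s) for s in row32]  # [126, 127, 125, 128]
packed = pack4(ue8m0)
```

Note that the merged 1.5 becomes UE8M0 127, i.e. a scale of 1.0, which is the truncation loss discussed above.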
C++ Changes
| File | Change |
|---|---|
| `csrc/apis/mega_nvfp4.hpp` | NVFP4 SymmBuffer (SF stride K/32, K/128 packed) |
| `csrc/apis/layout.hpp` | `gran_k=32` support (already existed, just documented) |
| `csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp` | JIT kernel header |
| `csrc/python_api.cpp` | Register NVFP4 APIs |
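The SymmBuffer stride in the first row follows from simple arithmetic, reading it as per-row sizes along K: one UE8M0 byte per 32 elements (K/32), and four bytes per int32 (K/128 words). The helper below is hypothetical and ignores any padding that TMA alignment may add.

```python
def sf_row_sizes(k: int, block: int = 32, per_int32: int = 4):
    """Per-row scale-factor sizes for a K-wide operand.

    Returns (scale bytes per row, packed int32 words per row).
    """
    if k % (block * per_int32) != 0:
        raise ValueError(f"K={k} must be a multiple of {block * per_int32}")
    n_sf = k // block
    return n_sf, n_sf // per_int32


# Illustrative hidden sizes only.
for k in (2048, 7168):
    print(k, sf_row_sizes(k))
```

Passing `block=16` (native NVFP4 granularity, as on the SM103+ path) doubles both numbers for the same K.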
PTX Wrappers Added (tcgen05.cuh)
- `SM100_MMA_MXF4NVF4_2x1SM_SS`: for future SM103+ support
- `SM100_MMA_MXF4NVF4_SS`: single-CTA variant
Note: These are NOT used on SM100. The kernel uses SM100_MMA_MXF8F6F4_2x1SM_SS instead.
Debugging Log (Builds 1–22)
| Build | Error | Fix |
|---|---|---|
| 1–6 | Dockerfile/build issues | NVRTC symlink, CPATH, PYTHONPATH |
| 7 | `kPackedFP4` type mismatch | uint8→int8 view |
| 9 | SF stride assertion | MN-major layout + TMA alignment |
| 10 | `transform_sf` has no `gran_k=16` | C++ fix (later reverted to 32) |
| 11 | SF dtype float8_e4m3fn rejected | Pack to int32 first |
| 12–14 | SF stride layout | Transpose to MN-major |
| 15 | SymmBuffer too small | NVFP4-specific SymmBuffer |
| 16 | ImportError | Python wrapper |
| 17 | NVCC: `scale_vec::4X` not on sm_100f | Hardware limit |
| 18 | NVCC: `scale_vec::2X` also not on sm_100f | Hardware limit |
| 19 | `kGranK=16` in C++ binding | → 32 |
| 20 | `uint32 >> 23` fails | Cast to int32 first |
| 22 | Garbled output | UE4M3→UE8M0 precision loss (unfixable on SM100) |
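The build 20 fix is safe because the `& 0xFF` mask discards any bits that a signed shift introduces. The plain-Python sketch below uses `struct` as a stand-in for reinterpreting a float32 tensor as uint32 or int32; the function names are hypothetical. Python's `>>` on a negative int is an arithmetic shift, matching two's-complement int32 behavior.

```python
import struct


def exponent_via_uint32(x: float) -> int:
    """Exponent byte from the unsigned bit pattern."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # 0 .. 2**32 - 1
    return (bits >> 23) & 0xFF


def exponent_via_int32(x: float) -> int:
    """Exponent byte from the signed bit pattern (the build 20 workaround)."""
    bits = struct.unpack("<i", struct.pack("<f", x))[0]  # may be negative
    # For negative bits, >> shifts copies of the sign bit in from the left,
    # but & 0xFF keeps only the 8 bits that hold the exponent.
    return (bits >> 23) & 0xFF


samples = [1.0, 0.375, 448.0, 2.0 ** -9, -3.0, -0.001953125]
assert all(exponent_via_uint32(x) == exponent_via_int32(x) for x in samples)
```

Scales are non-negative in practice, so the sign bit is clear anyway; the check above shows the result would match even if it were set.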
Path Forward
For SM103+ (B300/GB300)
The SM100_MMA_MXF4NVF4 PTX wrappers are already in the code. On SM103+:
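- Switch kernel to use `mxf4nvf4.block_scale.scale_vec::4X` (block16)
- Keep UE4M3 scales (no conversion to UE8M0)
- Update TMEM layout to 4X (`i*8` stride, 4 sub-columns)
- Quantize activations to FP4 as well: `kind::mxf4nvf4` takes E2M1 for both A and B, unlike the FP8×FP4 `mxf8f6f4` path
- This should produce correct output with full NVFP4 precision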
For SM100 (B200)
The UE4M3→UE8M0 conversion is fundamentally lossy. Options:
- FlashInfer FP4 MoE — dequant NVFP4→BF16, use BF16 GEMM (avoids scale conversion)
- BF16 mega_moe — dequant in shared memory, use BF16 MMA
- Accept the precision loss — only viable if the model is robust to scale quantization
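Whichever option is chosen, a full-precision reference is useful for checking kernel output. The sketch below dequantizes one NVFP4 row in plain Python as E2M1 × UE4M3 block16 scale × float32 global scale, which is the math a dequant-to-BF16 path performs before rounding. The helper names are hypothetical, and the nibble order (low nibble holds the even element) is an assumption about the checkpoint packing.

```python
E2M1 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def e2m1(code: int) -> float:
    """Decode one 4-bit E2M1 code (sign in bit 3)."""
    mag = E2M1[code & 0x7]
    return -mag if code & 0x8 else mag


def e4m3fn_to_float(byte: int) -> float:
    """Decode a float8_e4m3fn byte (bias 7, 3 mantissa bits)."""
    exp, man = (byte >> 3) & 0xF, byte & 0x7
    sign = -1.0 if byte & 0x80 else 1.0
    if exp == 0xF and man == 0x7:
        return float("nan")
    if exp == 0:
        return sign * (man / 8) * 2.0 ** -6  # subnormal
    return sign * (1 + man / 8) * 2.0 ** (exp - 7)


def dequant_nvfp4_row(packed, scales_e4m3, global_scale, group=16):
    """Reference NVFP4 dequantization of one row.

    packed: uint8 values, two E2M1 codes per byte (low nibble first, assumed).
    scales_e4m3: one float8_e4m3fn byte per `group` elements.
    global_scale: float32 per-tensor scale (weight_scale_2).
    """
    codes = []
    for b in packed:
        codes += [b & 0xF, b >> 4]
    assert len(codes) == group * len(scales_e4m3)
    return [
        e2m1(c) * e4m3fn_to_float(scales_e4m3[i // group]) * global_scale
        for i, c in enumerate(codes)
    ]


# UE4M3 byte 0x3C is 1.5, which UE8M0 truncation would turn into 1.0.
row = dequant_nvfp4_row([0x21] * 8, [0x3C], 1.0)  # [0.75, 1.5] repeated 8 times
```

Comparing a kernel's output against a reference built this way separates errors from the scale conversion from errors elsewhere in the pipeline.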
Integration with vLLM
The mega-moe-nvfp4 branch of deepseek-v4-quant wires this kernel into
`DeepseekV4MegaMoEExperts`:
- `finalize_weights()`: calls `transform_nvfp4_weights_for_mega_moe()`
- `forward()`: calls `fp8_nvfp4_mega_moe()` with `recipe=(1,1,32)`
- `get_symm_buffer()`: uses `get_symm_buffer_for_nvfp4_mega_moe()`