
DeepGEMM NVFP4 Mega MoE Kernel

Overview

A native NVFP4 mega MoE kernel for DeepGEMM, built on the nvfp4-mega-moe branch. It adapts the existing sm100_fp8_fp4_mega_moe_impl to handle NVFP4 checkpoint weights (E2M1 values + UE4M3 block scales, group_size=16).

HARD RULE: MoE experts stay in NVFP4. Never convert to MXFP4.

SM100 (B200) Hardware Constraint

CRITICAL: B200 (SM100) does NOT support kind::mxf4nvf4 (neither scale_vec::2X nor 4X); NVCC rejects both when targeting sm_100f (see builds 17 and 18 in the debugging log). This instruction requires SM103 (B300/GB300). On SM100, the only FP4 block-scaled MMA is kind::mxf8f6f4.block_scale with UE8M0 scales (block32, group_size=32).

Strategy: Keep the NVFP4 E2M1 weights unchanged (MXFP4 uses the same E2M1 element format) and convert the UE4M3 block scales to UE8M0 for hardware compatibility, merging NVFP4 block16 scales into block32 by taking the max of each adjacent pair. This is a scale format adaptation, not a weight format conversion.

| Parameter | NVFP4 Checkpoint | Kernel (SM100 Adapted) |
|-----------|------------------|------------------------|
| Weight format | E2M1 uint8 | E2M1 uint8 (unchanged) |
| Block scale format | UE4M3 (float8_e4m3fn) | UE8M0 (uint8) |
| Block size | 16 | 32 (merged) |
| Global scale | float32 | Folded in before UE4M3→UE8M0 |
| PTX instruction | N/A (requires SM103+) | mxf8f6f4.block_scale |
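One way to see what the adaptation does to each dequantized weight, in our own notation rather than the kernel's: let $q_i$ be the decoded E2M1 value of element $i$, $s_b$ the UE4M3 scale of block16 $b$, $g$ the float32 global scale, and $j = \lfloor i/32 \rfloor$ the merged block32 index. Assuming the E2M1 codes are carried over unchanged (the .view(int8) step in the pipeline below) and the merged scale is truncated to a power of two, the checkpoint value $w_i$ and the kernel's value $\hat{w}_i$ relate as:

$$
\begin{aligned}
w_i &= q_i \cdot s_{\lfloor i/16 \rfloor} \cdot g \\
\hat{w}_i &= q_i \cdot 2^{\left\lfloor \log_2\left(g \cdot m_j\right) \right\rfloor},
\qquad m_j = \max\left(s_{2j},\, s_{2j+1}\right) \\
\frac{\hat{w}_i}{w_i} &= \underbrace{\frac{2^{\lfloor \log_2(g\, m_j) \rfloor}}{g\, m_j}}_{\in\,(1/2,\;1]} \cdot \underbrace{\frac{m_j}{s_{\lfloor i/16 \rfloor}}}_{\ge 1}
\end{aligned}
$$

For nonzero $q_i$, the first factor is the UE8M0 rounding. The second equals 1 in the half-block that supplied the max and exceeds 1 in the other half whenever the two UE4M3 scales differ, because that half's E2M1 codes were quantized against the smaller scale.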

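Result: The server starts and serves requests, but the output is garbled. UE8M0 encodes only a power of two, so the conversion discards every mantissa bit of the folded scale (UE4M3 alone carries 3, i.e. 8 steps per binade), and truncating with >> 23 rounds each block scale down by a factor of up to almost 2×. This precision loss destroys output quality.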
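A minimal sketch of the exponent extraction from step 3 of the pipeline, in plain Python, using struct to reinterpret a float32's bits. ue8m0_truncate is a hypothetical helper name; it assumes a positive, normal scale and returns the power of two that the extracted UE8M0 code represents.

```python
import struct


def ue8m0_truncate(scale: float) -> float:
    """Round a positive float32 scale down to a power of two, as (bits >> 23) & 0xFF does.

    The biased float32 exponent uses bias 127, the same bias as UE8M0, so the
    extracted byte decodes directly as 2 ** (code - 127).
    """
    bits = struct.unpack("<I", struct.pack("<f", scale))[0]
    code = (bits >> 23) & 0xFF  # biased IEEE 754 exponent, reused as the UE8M0 code
    return 2.0 ** (code - 127)


# The eight UE4M3 mantissa steps in one binade, [1, 2), all collapse onto 1.0.
binade = [1.0 + m / 8 for m in range(8)]
for s in binade:
    t = ue8m0_truncate(s)
    print(f"UE4M3 {s:.3f} -> UE8M0 {t:.3f} (block scaled by {t / s:.3f})")
```

Once the float32 global scale has been folded in, the mantissa is no longer limited to these eight steps, so the ratio can land anywhere in (0.5, 1].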

Kernel Changes from MXFP4

The kernel is functionally identical to the MXFP4 mega_moe. The only differences are in the weight transformation pipeline (Python) and the documentation/comments.

CUDA Kernel (sm100_fp8_nvfp4_mega_moe.cuh)

Identical to sm100_fp8_fp4_mega_moe.cuh in runtime behavior:

  • Same kGranK = 32, same float_ue8m0_t descriptor, same mxf8f6f4 instruction
  • Same TMEM layout (2X), UTCCP copy (i*4), SF counts
  • Same L1 epilogue (UE8M0, >> 23)

The renamed file exists for documentation and future SM103+ support.

Python API (deep_gemm/mega/__init__.py)

```python
from deep_gemm.mega import (
    fp8_nvfp4_mega_moe,
    transform_nvfp4_weights_for_mega_moe,
    get_symm_buffer_for_nvfp4_mega_moe,
)

# Transform NVFP4 checkpoint weights for the kernel
l1_weights, l2_weights = transform_nvfp4_weights_for_mega_moe(
    (w13_weight, w13_weight_scale),  # uint8, float8_e4m3fn
    (w2_weight, w2_weight_scale),
    l1_weight_scale_2=w13_weight_scale_2,  # float32 global scale
    l2_weight_scale_2=w2_weight_scale_2,
)
# l1_weights/l2_weights: (int8 weight, int32 TMA-aligned SF)

# Run the kernel
fp8_nvfp4_mega_moe(y, l1_weights, l2_weights, symm_buffer, ...)
```

Weight Transformation Pipeline

```text
NVFP4 Checkpoint                         Kernel Format
┌─────────────────────┐                 ┌────────────────────────┐
│ weight: uint8       │────────────────→│ int8 (E2M1, same)      │
│ (E2M1, 2 per byte)  │  .view(int8)    │ packed, interleaved    │
├─────────────────────┤                 ├────────────────────────┤
│ weight_scale:       │ 1. fold global  │ int32 (TMA-aligned     │
│ float8_e4m3fn       │ 2. merge 16→32  │  UTCCP layout)         │
│ (UE4M3, group=16)   │ 3. UE4M3→UE8M0  │                        │
├─────────────────────┤ 4. pack 4→i32   └────────────────────────┘
│ weight_scale_2:     │ 5. transpose    (same as MXFP4 pipeline)
│ float32 (global)    │ 6. TMA-align
└─────────────────────┘
```

Key steps:

  1. Fold global scale: block_scale * global_scale → UE4M3 (in float32, re-quantize)
  2. Merge block16→block32: max(adjacent_pair) — max preserves magnitude
  3. UE4M3→UE8M0: reinterpret the float32 scale as bits and keep (bits >> 23) & 0xFF as uint8, i.e. the biased IEEE 754 exponent; this rounds each scale down to a power of two
  4. Pack: 4 uint8 UE8M0 values → 1 int32
  5. Transpose: MN-major layout
  6. TMA-align: transform_sf_into_required_layout(sf, mn, k, (1,32), num_experts)
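Steps 1–4 for a single weight row can be sketched in plain Python. This is an illustrative reconstruction, not the repository's implementation: decode_e4m3, float32_exponent_byte and pack_row_scales are hypothetical names, the global-scale product is kept in float32 rather than re-quantized, and packing the first scale into the lowest byte of each int32 (little-endian) is an assumption. Steps 5–6 are left to DeepGEMM's transform_sf_into_required_layout.

```python
import struct


def decode_e4m3(byte: int) -> float:
    """Decode one float8_e4m3fn byte: 1 sign, 4 exponent (bias 7), 3 mantissa bits."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp, man = (byte >> 3) & 0xF, byte & 0x7
    if exp == 0xF and man == 0x7:
        return float("nan")                # e4m3fn has NaN but no infinity
    if exp == 0:
        return sign * man * 2.0 ** -9      # subnormal: (man / 8) * 2 ** (1 - 7)
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)


def float32_exponent_byte(x: float) -> int:
    """Step 3: biased float32 exponent, reused as a UE8M0 code (same bias, 127)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return (bits >> 23) & 0xFF


def pack_row_scales(scale_bytes: list[int], global_scale: float) -> list[int]:
    """Hypothetical helper: K/16 UE4M3 scale bytes in, K/128 signed int32 words out."""
    if len(scale_bytes) % 8 != 0:
        raise ValueError("need a multiple of 8 block16 scales (K % 128 == 0)")
    folded = [decode_e4m3(b) * global_scale for b in scale_bytes]                # step 1
    merged = [max(folded[i], folded[i + 1]) for i in range(0, len(folded), 2)]  # step 2
    codes = [float32_exponent_byte(s) for s in merged]                          # step 3
    return [                                                                     # step 4
        int.from_bytes(bytes(codes[i:i + 4]), "little", signed=True)
        for i in range(0, len(codes), 4)
    ]


# Eight block16 scales (K = 128): 1, 2, 1, 1, 0.5, 0.5, 4, 1.
row = [0x38, 0x40, 0x38, 0x38, 0x30, 0x30, 0x48, 0x38]
# Merged block32 scales are 2.0, 1.0, 0.5, 4.0 -> UE8M0 codes 128, 127, 126, 129.
words = pack_row_scales(row, global_scale=1.0)
print(words)
```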

C++ Changes

| File | Change |
|------|--------|
| csrc/apis/mega_nvfp4.hpp | NVFP4 SymmBuffer (SF stride: K/32 UE8M0 bytes per row, i.e. K/128 int32 words when packed) |
| csrc/apis/layout.hpp | gran_k=32 support (already existed, just documented) |
| csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp | JIT kernel header |
| csrc/python_api.cpp | Register NVFP4 APIs |
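The SymmBuffer sizing follows from per-row arithmetic on K. The helper below is hypothetical and only restates that arithmetic; it assumes K is a multiple of 128 so the UE8M0 bytes fill whole int32 words (the real code may pad instead), and 7168 is used purely as an example K.

```python
def sf_counts_per_row(k: int) -> dict[str, int]:
    """Hypothetical helper: scale-factor counts for one row of a K-wide NVFP4 weight."""
    if k % 128 != 0:
        raise ValueError(f"K={k} is not a multiple of 128; UE8M0 bytes would not fill whole int32 words")
    return {
        "ue4m3_block16_scales": k // 16,  # as stored in the NVFP4 checkpoint
        "ue8m0_block32_scales": k // 32,  # after the 16 -> 32 max-merge, one byte each
        "packed_int32_words": k // 128,   # four UE8M0 bytes per int32
    }


counts = sf_counts_per_row(7168)
print(counts)
```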

PTX Wrappers Added (tcgen05.cuh)

  • SM100_MMA_MXF4NVF4_2x1SM_SS — for future SM103+ support
  • SM100_MMA_MXF4NVF4_SS — single-CTA variant

Note: These are NOT used on SM100. The kernel uses SM100_MMA_MXF8F6F4_2x1SM_SS instead.

Debugging Log (Builds 1–22)

| Build | Error | Fix |
|-------|-------|-----|
| 1–6 | Dockerfile/build issues | NVRTC symlink, CPATH, PYTHONPATH |
| 7 | kPackedFP4 type mismatch | uint8→int8 view |
| 9 | SF stride assertion | MN-major layout + TMA alignment |
| 10 | transform_sf: no gran_k=16 | C++ fix (later reverted to 32) |
| 11 | SF dtype float8_e4m3fn rejected | Pack to int32 first |
| 12–14 | SF stride layout | Transpose to MN-major |
| 15 | SymmBuffer too small | NVFP4-specific SymmBuffer |
| 16 | ImportError | Python wrapper |
| 17 | NVCC: scale_vec::4X not on sm_100f | Hardware limit |
| 18 | NVCC: scale_vec::2X also not on sm_100f | Hardware limit |
| 19 | kGranK=16 in C++ binding | Changed to 32 |
| 20 | uint32 >> 23 fails | Cast to int32 first |
| 22 | Garbled output | UE4M3→UE8M0 precision loss (unfixable on SM100) |

Path Forward

For SM103+ (B300/GB300)

The SM100_MMA_MXF4NVF4 PTX wrappers are already in the code. On SM103+:

  1. Switch kernel to use mxf4nvf4.block_scale.scale_vec::4X (block16)
  2. Keep UE4M3 scales (no conversion to UE8M0)
  3. Update TMEM layout to 4X (i*8 stride, 4 sub-columns)
  4. This should produce correct output with full NVFP4 precision

For SM100 (B200)

The UE4M3→UE8M0 conversion is fundamentally lossy. Options:

  1. FlashInfer FP4 MoE — dequant NVFP4→BF16, use BF16 GEMM (avoids scale conversion)
  2. BF16 mega_moe — dequant in shared memory, use BF16 MMA
  3. Accept the precision loss — only viable if the model is robust to scale quantization
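Options 1 and 2 both start by dequantizing NVFP4 to a wider type. A plain-Python reference of that dequantization is sketched below with hypothetical names; it returns Python floats (the BF16 cast is omitted) and applies the 16-element UE4M3 scale and the float32 global scale directly, so no UE8M0 rounding is involved. Storing the even element in the low nibble is an assumption about the checkpoint layout.

```python
# E2M1 magnitudes for the 3 low bits of a nibble; bit 3 is the sign.
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def decode_e2m1(nibble: int) -> float:
    mag = E2M1_MAGNITUDES[nibble & 0x7]
    return -mag if nibble & 0x8 else mag


def decode_e4m3(byte: int) -> float:
    """float8_e4m3fn: 1 sign, 4 exponent (bias 7), 3 mantissa bits; 0x7F/0xFF are NaN."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp, man = (byte >> 3) & 0xF, byte & 0x7
    if exp == 0xF and man == 0x7:
        return float("nan")
    if exp == 0:
        return sign * man * 2.0 ** -9
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)


def dequant_nvfp4_row(packed: bytes, scales: bytes, global_scale: float) -> list[float]:
    """Hypothetical reference: w = e2m1 * ue4m3_scale[block16] * global_scale."""
    values = []
    for byte in packed:
        values.append(decode_e2m1(byte & 0xF))  # assumed: even element in the low nibble
        values.append(decode_e2m1(byte >> 4))   # odd element in the high nibble
    if len(values) != 16 * len(scales):
        raise ValueError("expected one UE4M3 scale per 16 elements")
    return [v * decode_e4m3(scales[i // 16]) * global_scale for i, v in enumerate(values)]


# 16 elements (8 bytes of 0x21) sharing one block scale of 2.0 (0x40), global scale 0.25.
row = dequant_nvfp4_row(bytes([0x21] * 8), bytes([0x40]), 0.25)
# row == [0.25, 0.5] * 8
```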

Integration with vLLM

The mega-moe-nvfp4 branch of deepseek-v4-quant wires this kernel into DeepseekV4MegaMoEExperts:

  • finalize_weights(): calls transform_nvfp4_weights_for_mega_moe()
  • forward(): calls fp8_nvfp4_mega_moe() with recipe=(1,1,32)
  • get_symm_buffer(): uses get_symm_buffer_for_nvfp4_mega_moe()