
DeepGEMM NVFP4 Mega MoE Kernel

Overview

This branch adds native NVFP4 support to DeepGEMM's mega MoE kernel. NVFP4 uses E2M1 packed weights with UE4M3 block scales (group_size=16), as opposed to MXFP4 which uses UE8M0 scales (group_size=32).
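To make the weight encoding concrete, here is a minimal sketch of how packed E2M1 bytes decode. It assumes the element with the lower index sits in the low nibble of each byte, so check that against the actual checkpoint layout before relying on it. `decode_e2m1` and `unpack_fp4_bytes` are illustrative helpers, not part of DeepGEMM.

```python
# E2M1 magnitudes indexed by the low 3 bits of a nibble (2 exponent bits, 1 mantissa bit);
# bit 3 is the sign.
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def decode_e2m1(nibble: int) -> float:
    """Decode one 4-bit E2M1 code."""
    magnitude = E2M1_MAGNITUDES[nibble & 0x7]
    return -magnitude if nibble & 0x8 else magnitude


def unpack_fp4_bytes(packed: bytes) -> list[float]:
    """Unpack two E2M1 values per byte.

    Assumption: the lower-indexed element is stored in the low nibble.
    """
    values = []
    for byte in packed:
        values.append(decode_e2m1(byte & 0x0F))
        values.append(decode_e2m1(byte >> 4))
    return values


# 0x21 holds 0.5 (low nibble 0x1) and 1.0 (high nibble 0x2);
# 0xF7 holds 6.0 (low nibble 0x7) and -6.0 (high nibble 0xF).
print(unpack_fp4_bytes(bytes([0x21, 0xF7])))  # [0.5, 1.0, 6.0, -6.0]
```

With this packing, one 16-element NVFP4 group occupies 8 weight bytes plus one UE4M3 scale byte.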

HARD RULE: MoE experts stay in NVFP4. Never convert to MXFP4.

What Changed

CUDA Kernel (sm100_fp8_nvfp4_mega_moe.cuh)

Parameter             MXFP4 (original)            NVFP4 (new)
--------------------  --------------------------  ----------------------------------------
kGranK                32                          16
Scale factor type     float_ue8m0_t               float_ue4m3_t
PTX instruction       kind::mxf8f6f4.block_scale  kind::mxf4nvf4.block_scale.scale_vec::4X
SF vector size        block32 / 2X                block16 / 4X
TMEM SF layout        2 sub-columns per UMMA      4 sub-columns per UMMA
UTCCP col stride      i * 4                       i * 8
kNumSFATmemCols       SF_BLOCK_M / 32             SF_BLOCK_M / 32 * 4
kNumSFBTmemCols       SF_BLOCK_N / 32             SF_BLOCK_N / 32 * 4
Activation SF format  UE8M0 (>> 23)               UE4M3 (float→e4m3 cast)
kNumSFUint32          kHidden / 128               kHidden / 64
recipe                (1, 1, 32)                  (1, 1, 16)
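The kNumSFUint32 row follows from simple arithmetic: one 8-bit scale per kGranK elements and four scales per int32, so halving kGranK doubles the word count. The sketch also shows what the `>> 23` in the MXFP4 activation row does: a UE8M0 code is the biased exponent field of a float32. Both helpers are hypothetical, and 7168 is only an illustrative hidden size.

```python
import struct


def num_sf_uint32(hidden: int, gran_k: int) -> int:
    """int32 words holding one row's K-direction scales (4 one-byte scales per word)."""
    assert hidden % (gran_k * 4) == 0, "hidden must be a multiple of 4 * gran_k"
    return hidden // gran_k // 4


def float_to_ue8m0(x: float) -> int:
    """UE8M0 code of a positive float32: its biased exponent field (bits >> 23).

    The shift drops the mantissa, so a non-power-of-two input rounds down.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return (bits >> 23) & 0xFF


hidden = 7168  # illustrative size
print(num_sf_uint32(hidden, 32), num_sf_uint32(hidden, 16))  # 56 112
print(float_to_ue8m0(1.0), float_to_ue8m0(3.0))  # 127 128
```

Note that 3.0 maps to code 128, i.e. 2.0: the bare shift truncates toward the lower power of two.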

PTX Wrappers (tcgen05.cuh)

Added:

  • SM100_MMA_MXF4NVF4_2x1SM_SS — 2-CTA NVFP4 block-scaled MMA
  • SM100_MMA_MXF4NVF4_SS — 1-CTA NVFP4 block-scaled MMA

Python API (mega/__init__.py)

  • fp8_nvfp4_mega_moe() — main entry point
  • transform_nvfp4_weights_for_mega_moe() — converts checkpoint weights to kernel format
  • _pack_nvfp4_sf_for_utccp() — packs UE4M3 scales into int32 UTCCP layout

C++ Bindings

  • csrc/apis/mega_nvfp4.hpp — SymmBuffer with NVFP4 SF strides (K/16)
  • csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp — JIT kernel

Architecture

NVFP4 Checkpoint                    Kernel Format
┌─────────────────────┐            ┌────────────────────────┐
│ weight: uint8       │───────────>│ uint8 (same, E2M1)     │
│ (E2M1, 2 per byte)  │            │ packed, interleaved    │
├─────────────────────┤            ├────────────────────────┤
│ weight_scale:       │            │ int32 (UTCCP layout)   │
│ float8_e4m3fn       │──pack─────>│ 4 UE4M3 bytes → 1 i32  │
│ (UE4M3, group=16)   │  +UTCCP    │ then 4x32 transpose    │
├─────────────────────┤            └────────────────────────┘
│ weight_scale_2:     │
│ float32 (global)    │──────> Applied at weight load time
│                     │        (multiply into block scales)
└─────────────────────┘

Scale Factor Flow

Checkpoint: weight (E2M1) × weight_scale (UE4M3) × weight_scale_2 (FP32) = dequantized BF16 weight
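The checkpoint-side formula for one 16-element block can be sketched as follows. This is an illustration under the assumption that weight_scale is given as a raw E4M3FN byte and weight_scale_2 as a Python float; the final rounding to BF16 is omitted, and the helper names are hypothetical.

```python
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def decode_e2m1(nibble: int) -> float:
    """Decode one 4-bit E2M1 code; bit 3 is the sign."""
    magnitude = E2M1_MAGNITUDES[nibble & 0x7]
    return -magnitude if nibble & 0x8 else magnitude


def decode_e4m3(code: int) -> float:
    """Decode a non-NaN E4M3FN byte (exponent bias 7, no infinities)."""
    sign = -1.0 if code & 0x80 else 1.0
    exp = (code >> 3) & 0xF
    man = code & 0x7
    if exp == 0:
        return sign * man * 2.0 ** -9  # subnormal: (man / 8) * 2^(1 - 7)
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)


def dequantize_nvfp4_block(nibbles: list[int], weight_scale: int,
                           weight_scale_2: float) -> list[float]:
    """weight (E2M1) x weight_scale (UE4M3 byte) x weight_scale_2 (FP32)."""
    assert len(nibbles) == 16  # NVFP4 group_size
    block_scale = decode_e4m3(weight_scale) * weight_scale_2
    return [decode_e2m1(n) * block_scale for n in nibbles]


# 0x40 is the E4M3 code for 2.0; with weight_scale_2 = 0.5 the block scale is 1.0.
block = dequantize_nvfp4_block([0x7] * 8 + [0xA] * 8, 0x40, 0.5)
print(block[0], block[-1])  # 6.0 -1.0
```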

Kernel path:
  1. pre-scale: weight_scale_2 * weight_scale → float32 block scales
  2. pack to UE4M3: float32 → cutlass::float_e4m3_t → uint8 → pack 4→int32
  3. UTCCP transpose for TMA consumption
  4. Tensor Core reads UE4M3 scales via mxf4nvf4 instruction
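Steps 1 and 2 can be sketched in plain Python. This is not `_pack_nvfp4_sf_for_utccp()`; it is a hypothetical illustration that assumes little-endian byte order inside each int32 and saturates values above 448 (a choice of this sketch, not necessarily how `cutlass::float_e4m3_t` converts). The UTCCP transpose in step 3 is not shown.

```python
def decode_e4m3(code: int) -> float:
    """Decode a non-NaN E4M3FN byte (exponent bias 7, no infinities)."""
    sign = -1.0 if code & 0x80 else 1.0
    exp = (code >> 3) & 0xF
    man = code & 0x7
    if exp == 0:
        return sign * man * 2.0 ** -9  # subnormal
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)


def encode_ue4m3(x: float) -> int:
    """Round a non-negative float to the nearest E4M3FN code, ties to even.

    Only codes 0x00..0x7E are searched, so values above 448 saturate to 0x7E.
    """
    assert x >= 0.0
    return min(range(0x7F), key=lambda c: (abs(decode_e4m3(c) - x), c & 1))


def pack4_le(codes: list[int]) -> int:
    """Pack four scale bytes into one signed int32, byte i at bits [8*i, 8*i + 8)."""
    assert len(codes) == 4
    word = codes[0] | (codes[1] << 8) | (codes[2] << 16) | (codes[3] << 24)
    return word - (1 << 32) if word >= (1 << 31) else word


def fold_and_pack(weight_scales: list[float], weight_scale_2: float) -> list[int]:
    """Step 1: fold the global scale in. Step 2: re-encode as UE4M3, pack 4 -> int32."""
    assert len(weight_scales) % 4 == 0
    codes = [encode_ue4m3(s * weight_scale_2) for s in weight_scales]
    return [pack4_le(codes[i:i + 4]) for i in range(0, len(codes), 4)]


words = fold_and_pack([2.0, 4.0, 8.0, 1.0], 0.5)
print([hex(w) for w in words])  # ['0x30484038']
```

The folded scales 1.0, 2.0, 4.0 and 0.5 encode to 0x38, 0x40, 0x48 and 0x30, which land in bytes 0 through 3 of the word.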

Important Notes

scale_format_ constraint

The CUTLASS instruction descriptor has a single scale_format_ bit (0=E4M3, 1=E8M0) that applies to BOTH A and B scale factors. This means both activation (SFA) and weight (SFB) scales must use the same format. The L1 epilogue outputs UE4M3 activation scales to match the NVFP4 weight scales.
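The constraint amounts to a consistency check before building the descriptor. The sketch below is a hypothetical host-side guard, not a CUTLASS API; it only encodes the 0=E4M3, 1=E8M0 mapping described above.

```python
from enum import IntEnum


class ScaleFormat(IntEnum):
    """Values of the scale_format_ bit as described above (0=E4M3, 1=E8M0)."""
    E4M3 = 0
    E8M0 = 1


def shared_scale_format(sfa: ScaleFormat, sfb: ScaleFormat) -> ScaleFormat:
    """Return the single descriptor value that covers both SFA and SFB.

    One bit describes both operands, so mismatched formats cannot be encoded.
    """
    if sfa != sfb:
        raise ValueError(f"SFA is {sfa.name} but SFB is {sfb.name}; "
                         "one scale_format_ bit cannot describe both")
    return sfa


print(int(shared_scale_format(ScaleFormat.E4M3, ScaleFormat.E4M3)))  # 0
```

Failing early on a mismatch is cheaper than debugging silently mis-scaled MMA output.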

Weight scale_2 handling

The NVFP4 checkpoint has a dual-level scaling scheme:

  • weight_scale: per-block UE4M3 (group_size=16)
  • weight_scale_2: per-tensor float32 global scale

The weight_scale_2 must be multiplied into the block scales before packing for the kernel. This is done in transform_nvfp4_weights_for_mega_moe().
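Folding is only lossless while the product stays inside UE4M3's range. The largest finite E4M3 value is 448, normal values bottom out at 2^-6, and the smallest subnormal is 2^-9, so a small per-tensor weight_scale_2 can push block scales into the subnormal range or flush them to zero. UE8M0 covers powers of two from 2^-127 to 2^127, so this concern is specific to UE4M3 packing. The hypothetical check below buckets folded scales; the input values are illustrative only.

```python
E4M3_MAX = 448.0               # largest finite E4M3FN value
E4M3_MIN_NORMAL = 2.0 ** -6    # smallest normal value
E4M3_MIN_SUBNORMAL = 2.0 ** -9


def classify_folded_scales(weight_scales: list[float],
                           weight_scale_2: float) -> dict[str, int]:
    """Bucket weight_scale * weight_scale_2 by how it would fare as UE4M3."""
    counts = {"overflow": 0, "zero": 0, "subnormal": 0, "normal": 0}
    for s in weight_scales:
        folded = s * weight_scale_2
        if folded > E4M3_MAX:
            counts["overflow"] += 1   # would saturate or become NaN
        elif folded <= E4M3_MIN_SUBNORMAL / 2:
            counts["zero"] += 1       # rounds to zero (a tie goes to even code 0)
        elif folded < E4M3_MIN_NORMAL:
            counts["subnormal"] += 1  # at most 3 significant bits left
        else:
            counts["normal"] += 1
    return counts


print(classify_folded_scales([448.0, 100.0, 8.0, 1.0], 1e-4))
# {'overflow': 0, 'zero': 2, 'subnormal': 1, 'normal': 1}
```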

Remaining Work

  • Test compilation on B200 (SM100) — COMPILED
  • Verify UTCCP 4X column stride (i*8)
  • Verify SF packing: UE4M3 → int32 → TMA-aligned layout
  • Add gran_k=16 to C++ transform_sf_into_required_layout
  • Fix SF layout: must be MN-major (stride(-2)=1) with TMA-aligned stride
  • Verify the L1 epilogue UE4M3 conversion
  • Integration with vLLM DeepseekV4MegaMoEExperts — wired, debugging
  • End-to-end quality test

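SM100 (B200) MMA Constraint

CRITICAL: The native NVFP4 instruction is not usable in this kernel on SM100. NVCC rejects kind::mxf4nvf4 for the sm_100f target with both scale_vec::4X and scale_vec::2X (Builds 17 and 18). Even where kind::mxf4nvf4 compiles, it accepts only E2M1 for both the A and B operands, so it cannot consume the FP8 activations of this FP8 × FP4 kernel. The block-scaled MMA that does accept FP8 × FP4 on SM100 is kind::mxf8f6f4.block_scale, which takes UE8M0 scales (block32, group_size=32).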

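Strategy: Keep the NVFP4 E2M1 weight format (the same 4-bit encoding MXFP4 uses) and convert the UE4M3 block scales to UE8M0 for hardware compatibility, merging each pair of adjacent block16 scales into one block32 scale (max of the pair). This is a scale format adaptation, not a weight format conversion. Because UE8M0 can only represent powers of two, the merged scale generally differs from the original block16 scales, so the E2M1 values must be re-quantized against it; reusing them unchanged would rescale every element of the smaller-scaled half-block.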

Parameter           NVFP4 Checkpoint           Kernel (SM100 Adapted)
------------------  -------------------------  ----------------------------
Weight format       E2M1 uint8                 E2M1 uint8 (same format)
Block scale format  UE4M3 (float8_e4m3fn)      UE8M0 (uint8)
Block size          16                         32 (merged)
Global scale        float32                    Folded in before UE4M3→UE8M0
PTX instruction     N/A (mxf4nvf4 not usable)  kind::mxf8f6f4.block_scale
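The block16→block32 merge can be sketched as follows, under stated assumptions: the block scales already have weight_scale_2 folded in (as plain floats), each 32-element K group holds its two block16 halves in order, and the UE8M0 exponent is rounded up so that no re-quantized value exceeds E2M1's maximum of 6. Rounding up is one reasonable choice, not necessarily what transform_nvfp4_weights_for_mega_moe() does, and all helper names are hypothetical.

```python
import math

E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def decode_e2m1(nibble: int) -> float:
    magnitude = E2M1_MAGNITUDES[nibble & 0x7]
    return -magnitude if nibble & 0x8 else magnitude


def encode_e2m1(x: float) -> int:
    """Nearest E2M1 code (ties to even mantissa), saturating at +/-6."""
    mag = min(abs(x), 6.0)
    idx = min(range(8), key=lambda i: (abs(E2M1_MAGNITUDES[i] - mag), i & 1))
    return (idx | 0x8) if (x < 0 and idx != 0) else idx


def ue8m0_round_up(scale: float) -> int:
    """UE8M0 code of the smallest power of two >= scale (code = exponent + 127)."""
    if scale == 0.0:
        return 127  # all-zero block: any scale works, use 2^0
    mantissa, exponent = math.frexp(scale)  # scale = mantissa * 2**exponent, 0.5 <= mantissa < 1
    exp = exponent - 1 if mantissa == 0.5 else exponent
    code = exp + 127
    assert 0 <= code <= 254, "scale outside UE8M0 range"
    return code


def merge_block16_to_ue8m0(nibbles: list[int], scale_lo: float,
                           scale_hi: float) -> tuple[list[int], int]:
    """Re-quantize 32 E2M1 values from two block16 scales onto one UE8M0 block32 scale."""
    assert len(nibbles) == 32
    code = ue8m0_round_up(max(scale_lo, scale_hi))
    new_scale = 2.0 ** (code - 127)
    merged = []
    for i, nibble in enumerate(nibbles):
        old_scale = scale_lo if i < 16 else scale_hi
        merged.append(encode_e2m1(decode_e2m1(nibble) * old_scale / new_scale))
    return merged, code


merged, code = merge_block16_to_ue8m0([0x7] * 32, 1.0, 0.25)
print(code, merged[0], merged[-1])  # 127 7 3
```

In the example, the second half-block (scale 0.25) holds 1.5 rather than 6.0, so its nibbles change from 0x7 to 0x3. Rounding also costs precision: a half-block scale of 3.0 becomes 4.0, and a stored 6.0 (value 18.0) re-quantizes to 4.0, i.e. 16.0.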

Debugging Log

  • Build 7: kPackedFP4 mismatch → uint8→int8 view
  • Build 9: SF stride assertion → need MN-major layout + TMA alignment
  • Build 10: transform_sf_into_required_layout doesn't support gran_k=16 → C++ fix
  • Build 11: SF dtype mismatch (float8_e4m3fn → must pack to int32 first)
  • Build 12-14: SF stride layout — transpose to MN-major before transform
  • Build 15: SymmBuffer too small (NVFP4 has 2x SF) → use NVFP4 SymmBuffer
  • Build 16: ImportError (deep_gemm.mega.nvfp4) → Python wrapper
  • Build 17: NVCC error: scale_vec::4X not supported on sm_100f
  • Build 18: NVCC error: scale_vec::2X ALSO not supported on sm_100f
  • Build 19: kGranK still 16 in C++ binding
  • Build 20: Use mxf8f6f4 (same as MXFP4) with UE8M0 conversion