DeepGEMM NVFP4 Mega MoE Kernel
Overview
This branch adds native NVFP4 support to DeepGEMM's mega MoE kernel. NVFP4 uses E2M1 packed weights with UE4M3 block scales (group_size=16), as opposed to MXFP4 which uses UE8M0 scales (group_size=32).
HARD RULE: MoE experts stay in NVFP4. Never convert to MXFP4.
What Changed
CUDA Kernel (sm100_fp8_nvfp4_mega_moe.cuh)
| Parameter | MXFP4 (original) | NVFP4 (new) |
|---|---|---|
| kGranK | 32 | 16 |
| Scale factor type | float_ue8m0_t | float_ue4m3_t |
| PTX instruction | kind::mxf8f6f4.block_scale | kind::mxf4nvf4.block_scale.scale_vec::4X |
| SF vector size | block32 / 2X | block16 / 4X |
| TMEM SF layout | 2 sub-columns per UMMA | 4 sub-columns per UMMA |
| UTCCP col stride | i * 4 | i * 8 |
| kNumSFATmemCols | SF_BLOCK_M / 32 | SF_BLOCK_M / 32 * 4 |
| kNumSFBTmemCols | SF_BLOCK_N / 32 | SF_BLOCK_N / 32 * 4 |
| Activation SF format | UE8M0 (>> 23) | UE4M3 (float→e4m3 cast) |
| kNumSFUint32 | kHidden / 128 | kHidden / 64 |
| recipe | (1, 1, 32) | (1, 1, 16) |
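The kNumSFUint32 row follows directly from kGranK: a row of kHidden elements carries kHidden / kGranK one-byte scales, and four of them fit in one uint32. A minimal sketch of that arithmetic, assuming exactly that packing (the helper name num_sf_uint32 and the example hidden size are hypothetical):

```python
def num_sf_uint32(hidden: int, gran_k: int, sf_per_word: int = 4) -> int:
    """Packed uint32 scale words per row of `hidden` elements (hypothetical helper)."""
    if hidden % (gran_k * sf_per_word) != 0:
        raise ValueError("hidden must be a multiple of gran_k * sf_per_word")
    num_scales = hidden // gran_k        # one 1-byte scale per gran_k elements
    return num_scales // sf_per_word     # four 1-byte scales per uint32 word


if __name__ == "__main__":
    hidden = 7168  # example size, chosen only for illustration
    print("MXFP4 (gran_k=32):", num_sf_uint32(hidden, 32))  # 56  == hidden / 128
    print("NVFP4 (gran_k=16):", num_sf_uint32(hidden, 16))  # 112 == hidden / 64
```

Halving the group size doubles the number of scale words per row, which is why every SF-sized buffer in the NVFP4 path is twice as large as its MXFP4 counterpart.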
PTX Wrappers (tcgen05.cuh)
Added:
- SM100_MMA_MXF4NVF4_2x1SM_SS: 2-CTA NVFP4 block-scaled MMA
- SM100_MMA_MXF4NVF4_SS: 1-CTA NVFP4 block-scaled MMA
Python API (mega/__init__.py)
- fp8_nvfp4_mega_moe(): main entry point
- transform_nvfp4_weights_for_mega_moe(): converts checkpoint weights to kernel format
- _pack_nvfp4_sf_for_utccp(): packs UE4M3 scales into int32 UTCCP layout
C++ Bindings
- csrc/apis/mega_nvfp4.hpp: SymmBuffer with NVFP4 SF strides (K/16)
- csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp: JIT kernel
Architecture
```
NVFP4 Checkpoint                   Kernel Format
┌─────────────────────┐            ┌────────────────────────┐
│ weight: uint8       │───────────>│ uint8 (same, E2M1)     │
│ (E2M1, 2 per byte)  │            │ packed, interleaved    │
├─────────────────────┤            ├────────────────────────┤
│ weight_scale:       │            │ int32 (UTCCP layout)   │
│ float8_e4m3fn       │──pack─────>│ 4 UE4M3 bytes → 1 i32  │
│ (UE4M3, group=16)   │   +UTCCP   │ then 4x32 transpose    │
├─────────────────────┤            └────────────────────────┘
│ weight_scale_2:     │
│ float32 (global)    │──────> Applied at weight load time
│                     │        (multiply into block scales)
└─────────────────────┘
```
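A minimal sketch of how the checkpoint side of this diagram dequantizes: each 16-element group is its E2M1 values times the UE4M3 block scale times the global weight_scale_2. It assumes the low nibble of each byte holds the even-indexed element and that scale bytes are float8_e4m3fn bit patterns; decode_e2m1, decode_e4m3 and dequant_block are hypothetical names, and the kernel-side "interleaved" layout is not modeled.

```python
def decode_e2m1(nibble: int) -> float:
    """Decode a 4-bit E2M1 value (1 sign, 2 exponent, 1 mantissa bit)."""
    sign = -1.0 if nibble & 0x8 else 1.0
    exp = (nibble >> 1) & 0x3
    man = nibble & 0x1
    if exp == 0:                                   # subnormal: 0 or 0.5
        return sign * 0.5 * man
    return sign * (1.0 + 0.5 * man) * 2.0 ** (exp - 1)


def decode_e4m3(byte: int) -> float:
    """Decode an FP8 E4M3 (e4m3fn) bit pattern; 0x7F and 0xFF are NaN."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0xF and man == 0x7:
        return float("nan")
    if exp == 0:                                   # subnormal
        return sign * (man / 8.0) * 2.0 ** -6
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)


def dequant_block(packed: bytes, scale_byte: int, scale_2: float) -> list:
    """16 E2M1 values (8 packed bytes) x UE4M3 block scale x FP32 global scale."""
    if len(packed) != 8:
        raise ValueError("group_size=16 means 8 packed bytes per block")
    block_scale = decode_e4m3(scale_byte) * scale_2
    out = []
    for b in packed:
        out.append(decode_e2m1(b & 0xF) * block_scale)  # low nibble first (assumed)
        out.append(decode_e2m1(b >> 4) * block_scale)
    return out


if __name__ == "__main__":
    # 0x38 encodes 1.0 in E4M3; weight_scale_2 = 0.5.
    print(dequant_block(bytes([0x21, 0x73] + [0] * 6), 0x38, 0.5)[:4])
```

E2M1 can only represent 0, 0.5, 1, 1.5, 2, 3, 4 and 6 (plus signs), so all dynamic range comes from the two scale levels.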
Scale Factor Flow
Checkpoint: E2M1 value × weight_scale (UE4M3) × weight_scale_2 (FP32) = dequantized BF16
Kernel path:
1. pre-scale: weight_scale_2 * weight_scale → float32 block scales
2. pack to UE4M3: float32 → cutlass::float_e4m3_t → uint8 → pack 4→int32
3. UTCCP transpose for TMA consumption
4. Tensor Core reads UE4M3 scales via mxf4nvf4 instruction
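Steps 1 and 2 of the kernel path can be sketched in plain Python as below. It assumes round-to-nearest-even with saturation at 448 (the E4M3 maximum) and little-endian packing, so the first scale lands in the lowest byte of the int32; whether that matches the byte order _pack_nvfp4_sf_for_utccp() uses is not confirmed here. The UTCCP transpose (step 3) is omitted because its exact layout is kernel-specific, and all helper names are hypothetical.

```python
import struct


def _e4m3_value(code: int) -> float:
    """Value of a non-negative E4M3 (e4m3fn) code."""
    exp, man = code >> 3, code & 0x7
    if exp == 0:                                   # subnormal
        return (man / 8.0) * 2.0 ** -6
    return (1.0 + man / 8.0) * 2.0 ** (exp - 7)


# Codes 0x00..0x7E are the finite non-negative values (0x7F is NaN).
_UE4M3 = [_e4m3_value(c) for c in range(0x7F)]


def float_to_ue4m3(x: float) -> int:
    """Nearest UE4M3 code; ties go to the even code, large values saturate."""
    if not x >= 0.0:                               # rejects negatives and NaN
        raise ValueError("scale must be non-negative")
    x = min(x, 448.0)                              # saturate, also covers +inf
    return min(range(0x7F), key=lambda c: (abs(_UE4M3[c] - x), c & 1))


def prescale_and_pack(block_scales: list, scale_2: float) -> list:
    """Step 1: multiply in weight_scale_2; step 2: cast to UE4M3, pack 4 -> int32."""
    if len(block_scales) % 4 != 0:
        raise ValueError("expected a multiple of 4 block scales")
    codes = [float_to_ue4m3(s * scale_2) for s in block_scales]
    words = []
    for i in range(0, len(codes), 4):
        # Reinterpret 4 consecutive bytes as one little-endian signed int32.
        (word,) = struct.unpack("<i", bytes(codes[i:i + 4]))
        words.append(word)
    return words


if __name__ == "__main__":
    # Pre-scaled values are 1.0, 2.0, 0.5, 448.0 -> codes 0x38, 0x40, 0x30, 0x7E.
    print([hex(w) for w in prescale_and_pack([2.0, 4.0, 1.0, 896.0], 0.5)])
```

One property worth checking on real checkpoints: folding a small weight_scale_2 into the block scales can push products below 2^-6, into the E4M3 subnormal range where the spacing is fixed at 2^-9, so small scales lose relative precision (and products below 2^-10 round to zero). This sketch does not guard against that.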
Important Notes
scale_format_ constraint
The CUTLASS instruction descriptor has a single scale_format_ bit (0=E4M3, 1=E8M0) that applies to BOTH A and B scale factors. This means both activation (SFA) and weight (SFB) scales must use the same format. The L1 epilogue outputs UE4M3 activation scales to match the NVFP4 weight scales.
Weight scale_2 handling
The NVFP4 checkpoint has a dual-level scaling scheme:
- weight_scale: per-block UE4M3 (group_size=16)
- weight_scale_2: per-tensor float32 global scale
The weight_scale_2 must be multiplied into the block scales before packing for the kernel. This is done in transform_nvfp4_weights_for_mega_moe().
Remaining Work
- Test compilation on B200 (SM100) — COMPILED
- Verify UTCCP 4X column stride (i*8)
- Verify SF packing: UE4M3 → int32 → TMA-aligned layout
- Add gran_k=16 to C++ transform_sf_into_required_layout
- Fix SF layout: must be MN-major (stride(-2)=1) with TMA-aligned stride
- Verify the L1 epilogue UE4M3 conversion
- Integration with vLLM DeepseekV4MegaMoEExperts — wired, debugging
- End-to-end quality test
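The MN-major SF layout item can be illustrated with a small offset computation. It assumes a logical SF tensor of shape (mn, k_words) holding int32 words, where MN-major means stride(-2) == 1, and applies the CUDA TMA rule that global strides must be multiples of 16 bytes, so the K stride is mn rounded up to a multiple of 4 words. The helper names are hypothetical and are not DeepGEMM's own.

```python
def tma_aligned_mn(mn: int, elem_bytes: int = 4, align_bytes: int = 16) -> int:
    """Round mn up so that mn * elem_bytes is a multiple of align_bytes."""
    step = align_bytes // elem_bytes               # 4 int32 words per 16 bytes
    return (mn + step - 1) // step * step


def mn_major_strides(mn: int) -> tuple:
    """(stride of the mn dim, stride of the k_words dim), in elements."""
    return (1, tma_aligned_mn(mn))


def to_mn_major(rows: list) -> list:
    """Scatter a row-major (mn x k_words) list of rows into a flat MN-major buffer."""
    mn, k_words = len(rows), len(rows[0])
    s_mn, s_k = mn_major_strides(mn)
    flat = [0] * (s_k * k_words)                   # padding slots stay 0
    for i in range(mn):
        for j in range(k_words):
            flat[i * s_mn + j * s_k] = rows[i][j]
    return flat


if __name__ == "__main__":
    rows = [[10 * i + j for j in range(2)] for i in range(6)]  # mn=6, k_words=2
    print(mn_major_strides(6))   # (1, 8)
    print(to_mn_major(rows))
```

With mn = 6 the K stride is padded to 8 words (32 bytes), so each K column starts on a 16-byte boundary even though only 6 MN entries are live.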
Debugging Log
- Build 7: kPackedFP4 mismatch → uint8→int8 view
- Build 9: SF stride assertion → need MN-major layout + TMA alignment
- Build 10: transform_sf_into_required_layout doesn't support gran_k=16 → C++ fix
- Build 11: SF dtype mismatch (float8_e4m3fn → must pack to int32 first)
- Build 12-14: SF stride layout — transpose to MN-major before transform