docs: add NVFP4 mega MoE kernel README
This commit is contained in:
96
README_NVFP4.md
Normal file
96
README_NVFP4.md
Normal file
@@ -0,0 +1,96 @@
|
||||
# DeepGEMM NVFP4 Mega MoE Kernel
|
||||
|
||||
## Overview
|
||||
|
||||
This branch adds native NVFP4 support to DeepGEMM's mega MoE kernel. NVFP4 uses E2M1 packed weights with UE4M3 block scales (group_size=16), as opposed to MXFP4 which uses UE8M0 scales (group_size=32).
|
||||
|
||||
**HARD RULE: MoE experts stay in NVFP4. Never convert to MXFP4.**
|
||||
|
||||
## What Changed
|
||||
|
||||
### CUDA Kernel (`sm100_fp8_nvfp4_mega_moe.cuh`)
|
||||
|
||||
| Parameter | MXFP4 (original) | NVFP4 (new) |
|
||||
|-----------|-----------------|-------------|
|
||||
| `kGranK` | 32 | 16 |
|
||||
| Scale factor type | `float_ue8m0_t` | `float_ue4m3_t` |
|
||||
| PTX instruction | `kind::mxf8f6f4.block_scale` | `kind::mxf4nvf4.block_scale.scale_vec::4X` |
|
||||
| SF vector size | block32 / 2X | block16 / 4X |
|
||||
| TMEM SF layout | 2 sub-columns per UMMA | 4 sub-columns per UMMA |
|
||||
| UTCCP col stride | `i * 4` | `i * 8` |
|
||||
| `kNumSFATmemCols` | `SF_BLOCK_M / 32` | `SF_BLOCK_M / 32 * 4` |
|
||||
| `kNumSFBTmemCols` | `SF_BLOCK_N / 32` | `SF_BLOCK_N / 32 * 4` |
|
||||
| Activation SF format | UE8M0 (`>> 23`) | UE4M3 (float→e4m3 cast) |
|
||||
| `kNumSFUint32` | `kHidden / 128` | `kHidden / 64` |
|
||||
| `recipe` | `(1, 1, 32)` | `(1, 1, 16)` |
|
||||
|
||||
### PTX Wrappers (`tcgen05.cuh`)
|
||||
|
||||
Added:
|
||||
- `SM100_MMA_MXF4NVF4_2x1SM_SS` — 2-CTA NVFP4 block-scaled MMA
|
||||
- `SM100_MMA_MXF4NVF4_SS` — 1-CTA NVFP4 block-scaled MMA
|
||||
|
||||
### Python API (`mega/__init__.py`)
|
||||
|
||||
- `fp8_nvfp4_mega_moe()` — main entry point
|
||||
- `transform_nvfp4_weights_for_mega_moe()` — converts checkpoint weights to kernel format
|
||||
- `_pack_nvfp4_sf_for_utccp()` — packs UE4M3 scales into int32 UTCCP layout
|
||||
|
||||
### C++ Bindings
|
||||
|
||||
- `csrc/apis/mega_nvfp4.hpp` — SymmBuffer with NVFP4 SF strides (K/16)
|
||||
- `csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp` — JIT kernel
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
NVFP4 Checkpoint Kernel Format
|
||||
┌─────────────────────┐ ┌────────────────────────┐
|
||||
│ weight: uint8 │───────────>│ uint8 (same, E2M1) │
|
||||
│ (E2M1, 2 per byte) │ │ packed, interleaved │
|
||||
├─────────────────────┤ ├────────────────────────┤
|
||||
│ weight_scale: │ │ int32 (UTCCP layout) │
|
||||
│ float8_e4m3fn │──pack────>│ 4 UE4M3 bytes → 1 i32 │
|
||||
│ (UE4M3, group=16) │ +UTCCP │ then 4x32 transpose │
|
||||
├─────────────────────┤ └────────────────────────┘
|
||||
│ weight_scale_2: │
|
||||
│ float32 (global) │──────> Applied at weight load time
|
||||
│ │ (multiply into block scales)
|
||||
└─────────────────────┘
|
||||
```
|
||||
|
||||
### Scale Factor Flow
|
||||
|
||||
```
|
||||
Checkpoint: weight_scale (UE4M3) × weight_scale_2 (FP32) = dequantized BF16
|
||||
|
||||
Kernel path:
|
||||
1. pre-scale: weight_scale_2 * weight_scale → float32 block scales
|
||||
2. pack to UE4M3: float32 → cutlass::float_e4m3_t → uint8 → pack 4→int32
|
||||
3. UTCCP transpose for TMA consumption
|
||||
4. Tensor Core reads UE4M3 scales via mxf4nvf4 instruction
|
||||
```
|
||||
|
||||
## Important Notes
|
||||
|
||||
### scale_format_ constraint
|
||||
The CUTLASS instruction descriptor has a single `scale_format_` bit (0=E4M3, 1=E8M0) that applies to BOTH A and B scale factors. This means both activation (SFA) and weight (SFB) scales must use the same format. The L1 epilogue outputs UE4M3 activation scales to match the NVFP4 weight scales.
|
||||
|
||||
### Weight scale_2 handling
|
||||
The NVFP4 checkpoint has a dual-level scaling scheme:
|
||||
- `weight_scale`: per-block UE4M3 (group_size=16)
|
||||
- `weight_scale_2`: per-tensor float32 global scale
|
||||
|
||||
The `weight_scale_2` must be multiplied into the block scales **before** packing for the kernel. This is done in `transform_nvfp4_weights_for_mega_moe()`.
|
||||
|
||||
## Remaining Work
|
||||
|
||||
- [ ] Test compilation on B200 (SM100)
|
||||
- [ ] Verify UTCCP 4X column stride (i*8) — may need adjustment based on TMEM layout diagrams
|
||||
- [ ] Verify SF packing: UE4M3 bytes → int32 layout matches what the UTCCP instruction expects
|
||||
- [ ] Verify the L1 epilogue UE4M3 conversion (float → e4m3 cast + sign bit clear)
|
||||
- [ ] Validate scale_format_ bit value: currently set by `make_instr_desc_block_scaled<float_ue4m3_t>` which sets scale_format_=0 (E4M3)
|
||||
- [ ] Verify kNumSFATmemCols and kNumSFBTmemCols calculations for 4X layout
|
||||
- [ ] Integration with vLLM DeepseekV4MegaMoEExperts class
|
||||
- [ ] Weight loading: map NVFP4 checkpoint params to DeepseekV4MegaMoEExperts
|
||||
- [ ] End-to-end quality test: compare NVFP4 mega_moe output vs FlashInfer FP4 MoE
|
||||
Reference in New Issue
Block a user