# DeepGEMM NVFP4 Mega MoE Kernel

## Overview

This branch adds native NVFP4 support to DeepGEMM's mega MoE kernel. NVFP4 uses E2M1 packed weights with UE4M3 block scales (group_size=16), as opposed to MXFP4, which uses UE8M0 scales (group_size=32).

**HARD RULE: MoE experts stay in NVFP4. Never convert to MXFP4.**

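
For reference, the E2M1 element format shared by NVFP4 and MXFP4 has one sign bit, two exponent bits and one mantissa bit, so every code decodes to one of eight magnitudes. The sketch below decodes a code and unpacks the two codes stored in one checkpoint byte. The helper names are hypothetical, and the nibble order (low nibble first) is an assumption that may not match the kernel's interleaved layout.

```python
# Non-negative E2M1 magnitudes indexed by the low 3 bits (2 exponent bits, 1 mantissa bit).
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def decode_e2m1(code: int) -> float:
    """Decode one 4-bit E2M1 code; bit 3 is the sign bit."""
    if not 0 <= code <= 0xF:
        raise ValueError("E2M1 codes are 4 bits wide")
    magnitude = E2M1_MAGNITUDES[code & 0x7]
    return -magnitude if code & 0x8 else magnitude


def unpack_fp4_byte(byte: int) -> tuple[float, float]:
    """Split one uint8 into two E2M1 values (assumed order: low nibble first)."""
    if not 0 <= byte <= 0xFF:
        raise ValueError("expected a uint8 value")
    return decode_e2m1(byte & 0x0F), decode_e2m1(byte >> 4)


if __name__ == "__main__":
    print(unpack_fp4_byte(0x9A))  # (-1.0, -0.5)
```

Each decoded value is then multiplied by its block scale (one per 16 elements for NVFP4, one per 32 for MXFP4) and, for NVFP4, by the per-tensor `weight_scale_2`.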
## What Changed

### CUDA Kernel (`sm100_fp8_nvfp4_mega_moe.cuh`)

| Parameter | MXFP4 (original) | NVFP4 (new) |
|-----------|------------------|-------------|
| `kGranK` | 32 | 16 |
| Scale factor type | `float_ue8m0_t` | `float_ue4m3_t` |
| PTX instruction | `kind::mxf8f6f4.block_scale` | `kind::mxf4nvf4.block_scale.scale_vec::4X` |
| SF vector size | block32 / 2X | block16 / 4X |
| TMEM SF layout | 2 sub-columns per UMMA | 4 sub-columns per UMMA |
| UTCCP col stride | `i * 4` | `i * 8` |
| `kNumSFATmemCols` | `SF_BLOCK_M / 32` | `SF_BLOCK_M / 32 * 4` |
| `kNumSFBTmemCols` | `SF_BLOCK_N / 32` | `SF_BLOCK_N / 32 * 4` |
| Activation SF format | UE8M0 (`>> 23`) | UE4M3 (float→e4m3 cast) |
| `kNumSFUint32` | `kHidden / 128` | `kHidden / 64` |
| `recipe` | `(1, 1, 32)` | `(1, 1, 16)` |

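
The `kNumSFUint32` row is plain arithmetic: one row of `kHidden` elements carries `kHidden / kGranK` one-byte scale factors, and four of them fit in each `uint32`, giving `kHidden / (4 * kGranK)` words. The sketch below reproduces both table entries; the helper name is hypothetical and `hidden = 7168` is only an example value.

```python
SF_BYTES_PER_UINT32 = 4  # one 8-bit scale factor (UE8M0 or UE4M3) per byte


def num_sf_uint32(hidden: int, gran_k: int) -> int:
    """Number of packed uint32 words holding the block scales of one row."""
    if hidden % (gran_k * SF_BYTES_PER_UINT32):
        raise ValueError("hidden must be a multiple of 4 * gran_k")
    return hidden // (gran_k * SF_BYTES_PER_UINT32)


if __name__ == "__main__":
    hidden = 7168  # example hidden size
    mxfp4 = num_sf_uint32(hidden, gran_k=32)  # kHidden / 128
    nvfp4 = num_sf_uint32(hidden, gran_k=16)  # kHidden / 64
    print(mxfp4, nvfp4)  # 56 112
```

Halving `kGranK` doubles the scale-factor storage per row, the same 2x growth behind the SymmBuffer resize noted in the debugging log (Build 15).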
### PTX Wrappers (`tcgen05.cuh`)

Added:
- `SM100_MMA_MXF4NVF4_2x1SM_SS`: 2-CTA NVFP4 block-scaled MMA
- `SM100_MMA_MXF4NVF4_SS`: 1-CTA NVFP4 block-scaled MMA

### Python API (`mega/__init__.py`)

- `fp8_nvfp4_mega_moe()`: main entry point
- `transform_nvfp4_weights_for_mega_moe()`: converts checkpoint weights to kernel format
- `_pack_nvfp4_sf_for_utccp()`: packs UE4M3 scales into int32 UTCCP layout

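
The first packing step described for `_pack_nvfp4_sf_for_utccp()` (four UE4M3 bytes become one int32, per the diagram in the Architecture section) can be sketched in plain Python. The little-endian byte order (byte `j` of a group in bits `8j..8j+7`) and the grouping of four consecutive scale bytes per word are assumptions, the UTCCP 4x32 transpose that follows is not modeled, and `pack_sf_bytes_to_int32` is a hypothetical name.

```python
import struct


def pack_sf_bytes_to_int32(sf_bytes: bytes) -> list[int]:
    """Pack each group of 4 one-byte scale factors into one signed int32.

    Assumed byte order: little-endian, so byte j of a group lands in bits 8j..8j+7.
    """
    if len(sf_bytes) % 4:
        raise ValueError("the number of scale bytes must be a multiple of 4")
    return list(struct.unpack(f"<{len(sf_bytes) // 4}i", sf_bytes))


if __name__ == "__main__":
    # UE4M3 codes: 0x38 encodes 1.0 and 0x40 encodes 2.0.
    words = pack_sf_bytes_to_int32(bytes([0x38, 0x40, 0x38, 0x40]))
    print(hex(words[0]))  # 0x40384038
```

Because valid UE4M3 bytes never set the top bit, words packed from UE4M3 scales come out non-negative even though the target dtype is signed int32.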
### C++ Bindings

- `csrc/apis/mega_nvfp4.hpp`: SymmBuffer with NVFP4 SF strides (K/16)
- `csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp`: JIT kernel

## Architecture

```
   NVFP4 Checkpoint                      Kernel Format
┌─────────────────────┐            ┌────────────────────────┐
│ weight: uint8       │───────────>│ uint8 (same, E2M1)     │
│ (E2M1, 2 per byte)  │            │ packed, interleaved    │
├─────────────────────┤            ├────────────────────────┤
│ weight_scale:       │            │ int32 (UTCCP layout)   │
│ float8_e4m3fn       │──pack─────>│ 4 UE4M3 bytes → 1 i32  │
│ (UE4M3, group=16)   │   +UTCCP   │ then 4x32 transpose    │
├─────────────────────┤            └────────────────────────┘
│ weight_scale_2:     │
│ float32 (global)    │──────> Applied at weight load time
│                     │        (multiply into block scales)
└─────────────────────┘
```

### Scale Factor Flow

```
Checkpoint: E2M1 value × weight_scale (UE4M3) × weight_scale_2 (FP32) = dequantized BF16

Kernel path:
1. pre-scale: weight_scale_2 * weight_scale → float32 block scales
2. pack to UE4M3: float32 → cutlass::float_e4m3_t → uint8 → pack 4→int32
3. UTCCP transpose for TMA consumption
4. Tensor Core reads UE4M3 scales via mxf4nvf4 instruction
```

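
Steps 1 and 2 can be sketched in plain Python. This stands in for the `cutlass::float_e4m3_t` cast rather than reproducing it: round-to-nearest-even and saturation at 448 (the largest finite `float8_e4m3fn` value) are assumptions of the sketch, and the function names are hypothetical. It also shows a side effect of folding the global scale in before encoding: a small per-tensor scale can push block scales below the smallest UE4M3 subnormal (2^-9), where they round to zero.

```python
E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


def e4m3_value(code: int) -> float:
    """Decode a non-negative E4M3 code (4 exponent bits, bias 7, 3 mantissa bits)."""
    exponent, mantissa = (code >> 3) & 0xF, code & 0x7
    if exponent == 0:  # subnormal range
        return (mantissa / 8) * 2.0 ** -6
    return (1 + mantissa / 8) * 2.0 ** (exponent - 7)


def encode_ue4m3(x: float) -> int:
    """Round a non-negative float to the nearest UE4M3 code, ties to even.

    Saturating at 448 is an assumption of this sketch.
    """
    if not x >= 0:  # also rejects NaN
        raise ValueError("scale factors must be non-negative")
    x = min(x, E4M3_MAX)
    # Codes 0x00..0x7E are the finite non-negative values; 0x7F is NaN.
    return min(range(0x7F), key=lambda c: (abs(e4m3_value(c) - x), c & 1))


def fold_and_encode(block_scales: list[float], global_scale: float) -> list[int]:
    """Steps 1-2: fold weight_scale_2 into each block scale, then encode as UE4M3."""
    return [encode_ue4m3(s * global_scale) for s in block_scales]


if __name__ == "__main__":
    codes = fold_and_encode([448.0, 8.0, 2.0], global_scale=2.0 ** -12)
    print([hex(c) for c in codes])  # ['0x1e', '0x1', '0x0']
```

With a per-tensor scale of 2^-12, the block whose scale was 2.0 collapses to zero after folding, so checking the folded scales for underflow before packing may be worthwhile.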
## Important Notes

### scale_format_ constraint

The CUTLASS instruction descriptor has a single `scale_format_` bit (0=E4M3, 1=E8M0) that applies to BOTH A and B scale factors. This means both activation (SFA) and weight (SFB) scales must use the same format. The L1 epilogue outputs UE4M3 activation scales to match the NVFP4 weight scales.

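
The two activation scale encodings listed in the CUDA kernel table differ in what they can hold: UE8M0 stores only a biased exponent (what `>> 23` extracts from a float32), while UE4M3 keeps three mantissa bits. The sketch below shows the `>> 23` extraction plus a hypothetical host-side guard reflecting the single `scale_format_` bit (0=E4M3, 1=E8M0, as stated above). None of these names come from CUTLASS or DeepGEMM.

```python
import struct

SCALE_FORMAT_E4M3 = 0  # scale_format_ values as described in this section
SCALE_FORMAT_E8M0 = 1


def ue8m0_from_float(scale: float) -> int:
    """Return the float32 biased exponent, i.e. `bits >> 23` for a positive scale.

    The mantissa is discarded, so non-power-of-two scales are truncated.
    """
    if not scale > 0:
        raise ValueError("scale must be positive")
    bits = struct.unpack("<I", struct.pack("<f", scale))[0]
    return (bits >> 23) & 0xFF


def ue8m0_value(code: int) -> float:
    return 2.0 ** (code - 127)


def shared_scale_format(sfa_format: int, sfb_format: int) -> int:
    """One descriptor bit covers both SFA and SFB, so the formats must agree."""
    if sfa_format != sfb_format:
        raise ValueError("SFA and SFB must use the same scale format")
    return sfa_format


if __name__ == "__main__":
    print(ue8m0_from_float(1.0), ue8m0_from_float(0.75))  # 127 126
    print(ue8m0_value(ue8m0_from_float(0.75)))  # 0.5
    print(shared_scale_format(SCALE_FORMAT_E4M3, SCALE_FORMAT_E4M3))  # 0
```

The scale 0.75 (1.5 × 2^-1) loses its mantissa under UE8M0 and reads back as 0.5, whereas UE4M3 represents it exactly.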
### Weight scale_2 handling

The NVFP4 checkpoint has a dual-level scaling scheme:
- `weight_scale`: per-block UE4M3 (group_size=16)
- `weight_scale_2`: per-tensor float32 global scale

The `weight_scale_2` must be multiplied into the block scales **before** packing for the kernel. This is done in `transform_nvfp4_weights_for_mega_moe()`.

## Remaining Work

- [x] Test compilation on B200 (SM100): **COMPILED**
- [ ] Verify UTCCP 4X column stride (`i * 8`)
- [ ] Verify SF packing: UE4M3 → int32 → TMA-aligned layout
- [x] Add gran_k=16 to C++ `transform_sf_into_required_layout`
- [ ] Fix SF layout: must be MN-major (stride(-2)=1) with TMA-aligned stride
- [ ] Verify the L1 epilogue UE4M3 conversion
- [ ] Integration with vLLM `DeepseekV4MegaMoEExperts` (wired, debugging)
- [ ] End-to-end quality test

### SM100 (B200) Hardware Constraint

**CRITICAL**: When compiling for `sm_100f` on B200 (SM100), NVCC rejects `kind::mxf4nvf4` with both `scale_vec::2X` and `scale_vec::4X` (Builds 17 and 18 in the debugging log). This instruction requires SM103 (B300/GB300) or SM120. On SM100, the only FP4 block-scaled MMA usable here is `kind::mxf8f6f4.block_scale` with UE8M0 scales (block32, group_size=32).

**Strategy**: Keep the NVFP4 E2M1 weights (MXFP4 uses the same E2M1 element format) and convert the UE4M3 block scales to UE8M0 for hardware compatibility, merging NVFP4's 16-element blocks into 32-element blocks by taking the max of each adjacent pair of scales. This is a scale format adaptation, not a weight format conversion.

| Parameter | NVFP4 Checkpoint | Kernel (SM100 Adapted) |
|-----------|------------------|------------------------|
| Weight format | E2M1 uint8 | E2M1 uint8 (unchanged) |
| Block scale format | UE4M3 (float8_e4m3fn) | UE8M0 (uint8) |
| Block size | 16 | 32 (merged) |
| Global scale | float32 | Folded in before UE4M3→UE8M0 |
| PTX instruction | N/A (requires SM103+) | mxf8f6f4.block_scale |

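
The scale adaptation in the Strategy paragraph can be sketched for one row of NVFP4 block scales: fold in `weight_scale_2`, take the max of each adjacent pair of 16-element blocks, and encode the result as a UE8M0 exponent. Rounding the merged scale up to the next power of two is an assumption of this sketch, and all function names are hypothetical. `scale_drift` reports, per original 16-element block, the ratio of the new effective scale to the old one. Wherever that ratio is not 1, E2M1 codes kept unchanged dequantize to different values than in the checkpoint unless they are re-quantized against the merged scale.

```python
import math

UE8M0_BIAS = 127


def to_ue8m0_ceil(x: float) -> int:
    """Encode x as the UE8M0 code of the smallest power of two >= x (assumed rounding)."""
    if x < 0 or math.isnan(x):
        raise ValueError("scale must be non-negative")
    if x == 0:
        return 0  # smallest UE8M0 value (2**-127) stands in for an all-zero pair
    mantissa, exponent = math.frexp(x)  # x = mantissa * 2**exponent, 0.5 <= mantissa < 1
    power = exponent - 1 if mantissa == 0.5 else exponent
    return max(0, min(254, power + UE8M0_BIAS))  # code 255 is reserved for NaN


def ue8m0_value(code: int) -> float:
    return 2.0 ** (code - UE8M0_BIAS)


def merge_nvfp4_scales_to_ue8m0(block16_scales: list[float], global_scale: float) -> list[int]:
    """Fold the global scale in, then merge adjacent 16-element blocks by max."""
    if len(block16_scales) % 2:
        raise ValueError("need an even number of 16-element blocks")
    folded = [s * global_scale for s in block16_scales]
    return [to_ue8m0_ceil(max(folded[j], folded[j + 1])) for j in range(0, len(folded), 2)]


def scale_drift(block16_scales: list[float], global_scale: float, codes: list[int]) -> list[float]:
    """Per 16-element block: new effective scale / original effective scale."""
    drift = []
    for i, s in enumerate(block16_scales):
        original = s * global_scale
        drift.append(ue8m0_value(codes[i // 2]) / original if original else math.nan)
    return drift


if __name__ == "__main__":
    scales, g = [1.0, 0.75, 3.0, 1.0], 0.5
    codes = merge_nvfp4_scales_to_ue8m0(scales, g)
    print(codes)  # [126, 128]
    print(scale_drift(scales, g, codes))
```

Here only the first 16-element block keeps its exact scale; the others would be rescaled by 4/3, 4/3 and 4. Rounding down instead would keep every ratio at or below 1, but any ratio other than 1 still changes the dequantized block.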
### Debugging Log

- Build 7: kPackedFP4 mismatch → uint8→int8 view
- Build 9: SF stride assertion → need MN-major layout + TMA alignment
- Build 10: transform_sf_into_required_layout doesn't support gran_k=16 → C++ fix
- Build 11: SF dtype mismatch (float8_e4m3fn → must pack to int32 first)
- Builds 12-14: SF stride layout → transpose to MN-major before transform
- Build 15: SymmBuffer too small (NVFP4 has 2x SF) → use NVFP4 SymmBuffer
- Build 16: ImportError (deep_gemm.mega.nvfp4) → Python wrapper
- Build 17: NVCC error: scale_vec::4X not supported on sm_100f
- Build 18: NVCC error: scale_vec::2X ALSO not supported on sm_100f
- Build 19: kGranK still 16 in C++ binding
- Build 20: Use mxf8f6f4 (same as MXFP4) with UE8M0 conversion