# DeepGEMM NVFP4 Mega MoE Kernel
## Overview
This branch adds native NVFP4 support to DeepGEMM's mega MoE kernel. NVFP4 stores weights as E2M1 values packed two per byte, with one UE4M3 block scale per 16 elements (group_size=16) plus a per-tensor FP32 scale. MXFP4, by contrast, uses one UE8M0 scale per 32 elements (group_size=32).

**HARD RULE: MoE experts stay in NVFP4. Never convert to MXFP4.**
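
For reference, here is a minimal sketch of the dequantization this layout implies. Each 4-bit E2M1 code maps to one of eight magnitudes plus a sign bit, and every group of 16 codes shares one block scale. The sketch assumes the codes are already unpacked from bytes and the block scales are already decoded to floats. `decode_e2m1` and `dequantize_blocks` are illustrative names, not DeepGEMM functions.

```python
# E2M1 magnitudes indexed by the 3 low bits of a 4-bit code; bit 3 is the sign.
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def decode_e2m1(code: int) -> float:
    """Decode one 4-bit E2M1 code (0..15) to a float."""
    if not 0 <= code <= 0xF:
        raise ValueError(f"not a 4-bit code: {code}")
    magnitude = E2M1_MAGNITUDES[code & 0x7]
    return -magnitude if code & 0x8 else magnitude


def dequantize_blocks(codes, block_scales, group_size=16, global_scale=1.0):
    """Dequantize a row of E2M1 codes using one block scale per `group_size` codes.

    NVFP4: group_size=16 plus a per-tensor global scale.
    MXFP4: group_size=32 with power-of-two (UE8M0) scales.
    """
    if len(codes) != len(block_scales) * group_size:
        raise ValueError("need exactly one block scale per group")
    return [
        decode_e2m1(c) * block_scales[i // group_size] * global_scale
        for i, c in enumerate(codes)
    ]


# 32 codes form 2 NVFP4 blocks, so they need 2 scales.
codes = [0x1, 0x9] * 16          # alternating +0.5 / -0.5
row = dequantize_blocks(codes, block_scales=[2.0, 0.25])
print(row[:2], row[16:18])       # [1.0, -1.0] [0.125, -0.125]
```

Passing `group_size=32` gives the MXFP4 grouping, where the same 32-element row needs only one scale.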
## What Changed
### CUDA Kernel (`sm100_fp8_nvfp4_mega_moe.cuh`)
| Parameter | MXFP4 (original) | NVFP4 (new) |
|-----------|-----------------|-------------|
| `kGranK` | 32 | 16 |
| Scale factor type | `float_ue8m0_t` | `float_ue4m3_t` |
| PTX instruction | `kind::mxf8f6f4.block_scale` | `kind::mxf4nvf4.block_scale.scale_vec::4X` |
| SF vector size | block32 / 2X | block16 / 4X |
| TMEM SF layout | 2 sub-columns per UMMA | 4 sub-columns per UMMA |
| UTCCP col stride | `i * 4` | `i * 8` |
| `kNumSFATmemCols` | `SF_BLOCK_M / 32` | `SF_BLOCK_M / 32 * 4` |
| `kNumSFBTmemCols` | `SF_BLOCK_N / 32` | `SF_BLOCK_N / 32 * 4` |
| Activation SF format | UE8M0 (`>> 23`) | UE4M3 (float→e4m3 cast) |
| `kNumSFUint32` | `kHidden / 128` | `kHidden / 64` |
| `recipe` | `(1, 1, 32)` | `(1, 1, 16)` |
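
The `kNumSFUint32` change follows directly from the group size. A row of `kHidden` elements carries `kHidden / kGranK` one-byte scales, and a `uint32` holds four of them. The sketch below re-derives both table entries and evaluates the TMEM column formulas exactly as the table writes them. `hidden=7168` and `SF_BLOCK_M = 128` are example values only; they are not read from the kernel.

```python
def sf_counts(hidden: int, gran_k: int) -> dict:
    """Per-row scale-factor bookkeeping for a K dimension of size `hidden`."""
    assert hidden % (gran_k * 4) == 0, "hidden must pack evenly into uint32 words"
    num_sf_bytes = hidden // gran_k              # one 8-bit scale per gran_k elements
    return {
        "sf_bytes_per_row": num_sf_bytes,
        "sf_uint32_per_row": num_sf_bytes // 4,  # 4 scale bytes per uint32
    }


mxfp4 = sf_counts(7168, gran_k=32)
nvfp4 = sf_counts(7168, gran_k=16)
# Matches the table: kHidden / 128 (MXFP4) vs kHidden / 64 (NVFP4).
assert mxfp4["sf_uint32_per_row"] == 7168 // 128 == 56
assert nvfp4["sf_uint32_per_row"] == 7168 // 64 == 112

# kNumSFATmemCols formulas copied from the table, for an example SF_BLOCK_M = 128.
sf_block = 128
print(sf_block // 32, sf_block // 32 * 4)   # 4 16
```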
### PTX Wrappers (`tcgen05.cuh`)
Added:
- `SM100_MMA_MXF4NVF4_2x1SM_SS` — 2-CTA NVFP4 block-scaled MMA
- `SM100_MMA_MXF4NVF4_SS` — 1-CTA NVFP4 block-scaled MMA
### Python API (`mega/__init__.py`)
- `fp8_nvfp4_mega_moe()` — main entry point
- `transform_nvfp4_weights_for_mega_moe()` — converts checkpoint weights to kernel format
- `_pack_nvfp4_sf_for_utccp()` — packs UE4M3 scales into int32 UTCCP layout
### C++ Bindings
- `csrc/apis/mega_nvfp4.hpp` — SymmBuffer with NVFP4 SF strides (K/16)
- `csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp` — JIT kernel
## Architecture
```
NVFP4 Checkpoint                   Kernel Format
┌─────────────────────┐            ┌────────────────────────┐
│ weight: uint8       │───────────>│ uint8 (same, E2M1)     │
│ (E2M1, 2 per byte)  │            │ packed, interleaved    │
├─────────────────────┤            ├────────────────────────┤
│ weight_scale:       │            │ int32 (UTCCP layout)   │
│ float8_e4m3fn       │──pack────> │ 4 UE4M3 bytes → 1 i32  │
│ (UE4M3, group=16)   │  +UTCCP    │ then 4x32 transpose    │
├─────────────────────┤            └────────────────────────┘
│ weight_scale_2:     │
│ float32 (global)    │──────> Applied at weight load time
│                     │        (multiply into block scales)
└─────────────────────┘
```
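
The `weight` box on the left stores two E2M1 codes per byte. Below is a minimal sketch of that packing. It assumes the even-indexed element occupies the low nibble, so check this against the checkpoint before relying on it. `pack_e2m1` and `unpack_e2m1` are hypothetical helpers, and the sketch does not model the kernel-side interleaving shown in the right-hand box.

```python
def pack_e2m1(codes):
    """Pack 4-bit codes two per byte. ASSUMPTION: even index -> low nibble."""
    if len(codes) % 2:
        raise ValueError("need an even number of codes")
    if any(not 0 <= c <= 0xF for c in codes):
        raise ValueError("codes must be 4-bit")
    return bytes(lo | (hi << 4) for lo, hi in zip(codes[0::2], codes[1::2]))


def unpack_e2m1(packed: bytes):
    """Inverse of pack_e2m1: each byte yields (low nibble, high nibble)."""
    out = []
    for b in packed:
        out.extend((b & 0xF, b >> 4))
    return out


codes = [0x1, 0xA, 0x7, 0x0]
packed = pack_e2m1(codes)
print(packed.hex())               # a107
assert unpack_e2m1(packed) == codes
```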
### Scale Factor Flow
```
Checkpoint: E2M1 value × weight_scale (UE4M3) × weight_scale_2 (FP32) = dequantized BF16
Kernel path:
1. pre-scale: weight_scale_2 * weight_scale → float32 block scales
2. pack to UE4M3: float32 → cutlass::float_e4m3_t → uint8 → pack 4→int32
3. UTCCP transpose for TMA consumption
4. Tensor Core reads UE4M3 scales via mxf4nvf4 instruction
```
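
Step 2 ends by grouping four UE4M3 bytes into one 32-bit word. The sketch below shows only that grouping. It assumes the first scale goes into the least-significant byte, which is what you get when a `uint8` buffer is viewed as `int32` on a little-endian host. The UTCCP 4x32 transpose in step 3 depends on the TMEM layout and is left out on purpose. The helper names are hypothetical.

```python
import struct


def pack_sf_bytes_to_int32(sf_bytes: bytes):
    """Group UE4M3 scale bytes four at a time into signed 32-bit words.

    ASSUMPTION: byte i of each group becomes bits [8*i, 8*i + 8) of the word,
    i.e. the same result as viewing a uint8 buffer as little-endian int32.
    """
    if len(sf_bytes) % 4:
        raise ValueError("scale count must be a multiple of 4")
    return list(struct.unpack(f"<{len(sf_bytes) // 4}i", sf_bytes))


def unpack_int32_to_sf_bytes(words):
    """Inverse of pack_sf_bytes_to_int32."""
    return struct.pack(f"<{len(words)}i", *words)


sf = bytes([0x38, 0x30, 0x40, 0x7E])      # UE4M3 codes for 1.0, 0.5, 2.0, 448.0
words = pack_sf_bytes_to_int32(sf)
print(hex(words[0]))                       # 0x7e403038
assert unpack_int32_to_sf_bytes(words) == sf
```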
## Important Notes
### scale_format_ constraint
The CUTLASS instruction descriptor has a single `scale_format_` bit (0=E4M3, 1=E8M0) that applies to BOTH A and B scale factors. This means both activation (SFA) and weight (SFB) scales must use the same format. The L1 epilogue outputs UE4M3 activation scales to match the NVFP4 weight scales.
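
The two encodings also carry different information. The UE8M0 path in the table takes the FP32 exponent field (`bits >> 23`), which works because FP32 and UE8M0 both use an 8-bit exponent with bias 127. The result, however, can only represent a power of two. E4M3 keeps three mantissa bits. The sketch below illustrates this with plain Python floats; all helper names are hypothetical.

```python
import struct


def fp32_bits(x: float) -> int:
    """Raw IEEE-754 binary32 bit pattern of x."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def ue8m0_from_fp32(x: float) -> int:
    """UE8M0 code taken from the FP32 exponent field (both use bias 127)."""
    if x <= 0:
        raise ValueError("scale must be positive")
    return fp32_bits(x) >> 23


def decode_ue8m0(code: int) -> float:
    """UE8M0 holds only a biased exponent, so every value is a power of two."""
    return 2.0 ** (code - 127)


def decode_ue4m3(code: int) -> float:
    """Value of an unsigned E4M3 code: 4 exponent bits (bias 7), 3 mantissa bits."""
    if not 0 <= code <= 0x7E:          # 0x7F is NaN; bit 7 (sign) must be clear
        raise ValueError(f"not a finite UE4M3 code: {code:#x}")
    exp, mant = code >> 3, code & 0x7
    if exp == 0:                       # subnormal: mant/8 * 2^(1-7)
        return mant * 2.0 ** -9
    return (1 + mant / 8) * 2.0 ** (exp - 7)


scale = 0.75
print(decode_ue8m0(ue8m0_from_fp32(scale)))   # 0.5: the mantissa bits are dropped
print(decode_ue4m3(0x34))                      # 0.75: representable in E4M3
```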
### Weight scale_2 handling
The NVFP4 checkpoint has a dual-level scaling scheme:
- `weight_scale`: per-block UE4M3 (group_size=16)
- `weight_scale_2`: per-tensor float32 global scale

The `weight_scale_2` must be multiplied into the block scales **before** packing for the kernel. This is done in `transform_nvfp4_weights_for_mega_moe()`.
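
Below is a minimal sketch of the fold. It assumes the checkpoint's block scales are already decoded to floats. `fold_global_scale` and `round_to_ue4m3` are hypothetical helpers, not the code in `transform_nvfp4_weights_for_mega_moe()`. The rounding here is round-to-nearest with ties to an even code, saturating at 448 (the largest finite E4M3 value). The real cast path may handle overflow differently. Because the product is re-quantized to 8 bits, printing the relative error is a cheap way to see how much precision the fold costs, especially once a folded scale drops below 2^-6, where E4M3 becomes subnormal.

```python
def decode_e4m3(code: int) -> float:
    """Value of a non-negative E4M3 (fn) code in 0x00..0x7E (0x7F is NaN)."""
    exp, mant = code >> 3, code & 0x7
    if exp == 0:
        return mant * 2.0 ** -9        # subnormal: mant/8 * 2^(1-7)
    return (1 + mant / 8) * 2.0 ** (exp - 7)


# All finite non-negative E4M3 values, indexed by code.
E4M3_VALUES = [decode_e4m3(c) for c in range(0x7F)]


def round_to_ue4m3(x: float) -> int:
    """Nearest UE4M3 code for x >= 0; ties go to the even code; saturates at 448."""
    if x < 0:
        raise ValueError("UE4M3 scales are unsigned")
    return min(range(0x7F), key=lambda c: (abs(E4M3_VALUES[c] - x), c & 1))


def fold_global_scale(block_scales, weight_scale_2: float):
    """Multiply the per-tensor scale into each block scale, then re-encode."""
    codes = [round_to_ue4m3(s * weight_scale_2) for s in block_scales]
    return codes, [E4M3_VALUES[c] for c in codes]


block_scales = [448.0, 3.0, 1.25]
weight_scale_2 = 0.01
codes, folded = fold_global_scale(block_scales, weight_scale_2)
for s, got in zip(block_scales, folded):
    want = s * weight_scale_2
    print(f"{want:.5f} -> {got:.8f}  (rel. error {abs(got - want) / want:.2%})")
```

In this example the last folded scale (0.0125) falls in the subnormal range and has the largest relative error of the three.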
## Remaining Work
- [ ] Test compilation on B200 (SM100)
- [ ] Verify the UTCCP 4X column stride (`i * 8`); it may need adjustment based on the TMEM layout diagrams
- [ ] Verify SF packing: UE4M3 bytes → int32 layout matches what the UTCCP instruction expects
- [ ] Verify the L1 epilogue UE4M3 conversion (float → e4m3 cast + sign bit clear)
- [ ] Validate the `scale_format_` bit: it is currently set by `make_instr_desc_block_scaled<float_ue4m3_t>`, which sets `scale_format_ = 0` (E4M3)
- [ ] Verify the `kNumSFATmemCols` and `kNumSFBTmemCols` calculations for the 4X layout
- [ ] Integration with vLLM's `DeepseekV4MegaMoEExperts` class
- [ ] Weight loading: map NVFP4 checkpoint params to `DeepseekV4MegaMoEExperts`
- [ ] End-to-end quality test: compare NVFP4 `mega_moe` output against FlashInfer's FP4 MoE