diff --git a/README_NVFP4.md b/README_NVFP4.md new file mode 100644 index 0000000..620b757 --- /dev/null +++ b/README_NVFP4.md @@ -0,0 +1,96 @@ +# DeepGEMM NVFP4 Mega MoE Kernel + +## Overview + +This branch adds native NVFP4 support to DeepGEMM's mega MoE kernel. NVFP4 uses E2M1 packed weights with UE4M3 block scales (group_size=16), as opposed to MXFP4 which uses UE8M0 scales (group_size=32). + +**HARD RULE: MoE experts stay in NVFP4. Never convert to MXFP4.** + +## What Changed + +### CUDA Kernel (`sm100_fp8_nvfp4_mega_moe.cuh`) + +| Parameter | MXFP4 (original) | NVFP4 (new) | +|-----------|-----------------|-------------| +| `kGranK` | 32 | 16 | +| Scale factor type | `float_ue8m0_t` | `float_ue4m3_t` | +| PTX instruction | `kind::mxf8f6f4.block_scale` | `kind::mxf4nvf4.block_scale.scale_vec::4X` | +| SF vector size | block32 / 2X | block16 / 4X | +| TMEM SF layout | 2 sub-columns per UMMA | 4 sub-columns per UMMA | +| UTCCP col stride | `i * 4` | `i * 8` | +| `kNumSFATmemCols` | `SF_BLOCK_M / 32` | `SF_BLOCK_M / 32 * 4` | +| `kNumSFBTmemCols` | `SF_BLOCK_N / 32` | `SF_BLOCK_N / 32 * 4` | +| Activation SF format | UE8M0 (`>> 23`) | UE4M3 (float→e4m3 cast) | +| `kNumSFUint32` | `kHidden / 128` | `kHidden / 64` | +| `recipe` | `(1, 1, 32)` | `(1, 1, 16)` | + +### PTX Wrappers (`tcgen05.cuh`) + +Added: +- `SM100_MMA_MXF4NVF4_2x1SM_SS` — 2-CTA NVFP4 block-scaled MMA +- `SM100_MMA_MXF4NVF4_SS` — 1-CTA NVFP4 block-scaled MMA + +### Python API (`mega/__init__.py`) + +- `fp8_nvfp4_mega_moe()` — main entry point +- `transform_nvfp4_weights_for_mega_moe()` — converts checkpoint weights to kernel format +- `_pack_nvfp4_sf_for_utccp()` — packs UE4M3 scales into int32 UTCCP layout + +### C++ Bindings + +- `csrc/apis/mega_nvfp4.hpp` — SymmBuffer with NVFP4 SF strides (K/16) +- `csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp` — JIT kernel + +## Architecture + +``` +NVFP4 Checkpoint Kernel Format +┌─────────────────────┐ ┌────────────────────────┐ +│ weight: uint8 │───────────>│ uint8 (same, E2M1) │ +│ (E2M1, 2 per byte) │ │ packed, interleaved │ +├─────────────────────┤ ├────────────────────────┤ +│ weight_scale: │ │ int32 (UTCCP layout) │ +│ float8_e4m3fn │──pack────>│ 4 UE4M3 bytes → 1 i32 │ +│ (UE4M3, group=16) │ +UTCCP │ then 4x32 transpose │ +├─────────────────────┤ └────────────────────────┘ +│ weight_scale_2: │ +│ float32 (global) │──────> Applied at weight load time +│ │ (multiply into block scales) +└─────────────────────┘ +``` + +### Scale Factor Flow + +``` +Checkpoint: weight_scale (UE4M3) × weight_scale_2 (FP32) = dequantized BF16 + +Kernel path: + 1. pre-scale: weight_scale_2 * weight_scale → float32 block scales + 2. pack to UE4M3: float32 → cutlass::float_e4m3_t → uint8 → pack 4→int32 + 3. UTCCP transpose for TMA consumption + 4. Tensor Core reads UE4M3 scales via mxf4nvf4 instruction +``` + +## Important Notes + +### scale_format_ constraint +The CUTLASS instruction descriptor has a single `scale_format_` bit (0=E4M3, 1=E8M0) that applies to BOTH A and B scale factors. This means both activation (SFA) and weight (SFB) scales must use the same format. The L1 epilogue outputs UE4M3 activation scales to match the NVFP4 weight scales. + +### Weight scale_2 handling +The NVFP4 checkpoint has a dual-level scaling scheme: +- `weight_scale`: per-block UE4M3 (group_size=16) +- `weight_scale_2`: per-tensor float32 global scale + +The `weight_scale_2` must be multiplied into the block scales **before** packing for the kernel. This is done in `transform_nvfp4_weights_for_mega_moe()`. + +## Remaining Work + +- [ ] Test compilation on B200 (SM100) +- [ ] Verify UTCCP 4X column stride (i*8) — may need adjustment based on TMEM layout diagrams +- [ ] Verify SF packing: UE4M3 bytes → int32 layout matches what the UTCCP instruction expects +- [ ] Verify the L1 epilogue UE4M3 conversion (float → e4m3 cast + sign bit clear) +- [ ] Validate scale_format_ bit value: currently set by `make_instr_desc_block_scaled` which sets scale_format_=0 (E4M3) +- [ ] Verify kNumSFATmemCols and kNumSFBTmemCols calculations for 4X layout +- [ ] Integration with vLLM DeepseekV4MegaMoEExperts class +- [ ] Weight loading: map NVFP4 checkpoint params to DeepseekV4MegaMoEExperts +- [ ] End-to-end quality test: compare NVFP4 mega_moe output vs FlashInfer FP4 MoE