biondizzle e608a20dec docs: major README update — packed FP4 SMEM layout, L1 epilogue, TMA descriptors
Added detailed documentation of the packed FP4 architecture:
- mxf4nvf4 reads packed (2 per byte), NOT unpacked like mxf8f6f4
- SMEM layout: float_e2m1_t, BLOCK_K/2 swizzle, UMMA desc byte math
- L1 epilogue: st.shared.u16, no swizzle, kWarpBytesPerRow
- Host TMA: hidden/2 K-dim, block_k/2 inner, fp4_unpacked_smem=false
- Build history through Build 35
2026-05-11 22:40:09 +00:00

# DeepGEMM NVFP4 Mega MoE Kernel
## Overview
A native NVFP4 mega MoE kernel for DeepGEMM that uses `kind::mxf4nvf4.block_scale.scale_vec::4X`
to consume NVFP4 weights (E2M1 + UE4M3 block scales, group_size=16) directly on B200 (SM100a).
**HARD RULE: MoE experts stay in NVFP4. Never convert to MXFP4.**
## SM100a (B200) Hardware Support
**B200 (SM100a) DOES support `kind::mxf4nvf4` with `scale_vec::4X`** (block16, UE4M3 scales).
Documented in PTX ISA 8.7 (CUDA 12.8+), confirmed by NVIDIA/CUTLASS/Colfax.
The key requirement: target **`sm_100a`** (not `sm_100`). The `a` suffix enables the architecture-specific FP4
block-scaled instructions, including `mxf4nvf4`. Targeting plain `sm_100`, or the family target `sm_100f`, fails with
errors such as "Feature '.scale_vec::4X' not supported on .target 'sm_100f'".
## Kernel Architecture
```
sm100_fp8_nvfp4_mega_moe_impl
├── kGranK = 16 (NVFP4 native block size)
├── kind::mxf4nvf4.block_scale.scale_vec::4X PTX instruction
├── float_ue4m3_t instruction descriptor
├── SF layout: scale_vec::4X, 4 TMEM sub-columns per UMMA atom
├── UTCCP copy: i*8 stride (4X layout, 8 TMEM cols per 128-element group)
├── kNumSFATmemCols = SF_BLOCK_M / 32 * 4
├── kNumSFBTmemCols = SF_BLOCK_N / 32 * 4
├── kNumSFUint32 = kHidden / 64 (4 UE4M3 per int32)
├── UE4M3 L1 epilogue (float → e4m3 cast, sign bit cleared)
└── recipe = (1, 1, 16)
```
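The scale-factor counts in the tree follow from `kGranK = 16` and four 8-bit UE4M3 scales per 32-bit word. A minimal host-side sketch of that arithmetic, assuming `SF_BLOCK_M`/`SF_BLOCK_N` are multiples of 32 and `kHidden` is a multiple of 64; the function names are hypothetical, not DeepGEMM's:

```cpp
#include <cstdint>

// NVFP4: one UE4M3 scale per 16 K-elements (group_size = 16).
constexpr uint32_t kGranK = 16;
// Four 8-bit UE4M3 scales are packed into each 32-bit word.
constexpr uint32_t kScalesPerUint32 = 4;

// kNumSFUint32 = kHidden / 64: hidden / 16 scales per row, 4 scales per word.
constexpr uint32_t num_sf_uint32(uint32_t hidden) {
    return hidden / kGranK / kScalesPerUint32;
}

// kNumSFATmemCols / kNumSFBTmemCols = SF_BLOCK / 32 * 4 (scale_vec::4X layout).
constexpr uint32_t num_sf_tmem_cols(uint32_t sf_block) {
    return sf_block / 32 * 4;
}

// Example values (7168 and 128 are illustrative sizes, not fixed by the kernel).
static_assert(num_sf_uint32(7168) == 112, "7168 / 64");
static_assert(num_sf_tmem_cols(128) == 16, "128 / 32 * 4");
```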
## Critical: mxf4nvf4 Requires FP4×FP4 (Packed)
**The `mxf4nvf4` PTX instruction requires BOTH A and B to be FP4 (E2M1 packed, 2 per byte).**
Unlike `mxf8f6f4`, which reads FP4 from SMEM as if it were FP8 (1 byte/element, low nibble = value,
high nibble ignored, via `float_e2m1_unpacksmem_t`), `mxf4nvf4` reads FP4 **packed** from SMEM:
- 2 E2M1 values per byte (low nibble = first, high nibble = second)
- The descriptor advances K by 32 bytes per UMMA_K = 64 atom (64 packed nibbles = 32 bytes)
- Each M/N row of the K-major tile is `BLOCK_K / 2` bytes wide (not `BLOCK_K * sizeof(dtype_t)`)
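The packed layout rules above reduce to a few lines of byte arithmetic. A minimal sketch, assuming `BLOCK_K = 128` (implied by the 64-byte swizzle below); `pack_e2m1_pair` and `packed_fp4_bytes` are hypothetical helpers, not part of DeepGEMM or CUTLASS:

```cpp
#include <cstddef>
#include <cstdint>

// Packs two 4-bit E2M1 codes into one byte: first element in the low nibble,
// second element in the high nibble (the order mxf4nvf4 reads from SMEM).
constexpr uint8_t pack_e2m1_pair(uint8_t first, uint8_t second) {
    return static_cast<uint8_t>((first & 0xF) | ((second & 0xF) << 4));
}

// Bytes occupied by `elems` packed FP4 values (4 bits each).
constexpr size_t packed_fp4_bytes(size_t elems) { return elems * 4 / 8; }

// One UMMA_K = 64 atom spans 32 bytes of K.
static_assert(packed_fp4_bytes(64) == 32, "UMMA_K = 64 atom");
// With BLOCK_K = 128 (assumed), each K-major row of the tile is 64 bytes.
static_assert(packed_fp4_bytes(128) == 64, "BLOCK_K / 2");
```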
### SMEM Layout for mxf4nvf4
```cpp
using a_dtype_t = cutlass::float_e2m1_t; // packed, 4 bits/element
using b_dtype_t = cutlass::float_e2m1_t;
static_assert(cutlass::sizeof_bits_v<a_dtype_t> == 4,
"mxf4nvf4 requires packed FP4 (4 bits/element) in SMEM");
constexpr uint32_t kSwizzleAMode = BLOCK_K / 2; // 64 bytes (packed)
constexpr uint32_t kSwizzleBMode = BLOCK_K / 2; // 64 bytes (packed)
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K / 2; // packed
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K / 2; // packed
```
### UMMA Descriptors
The `make_umma_desc` and `advance_umma_desc_lo` helpers compute byte offsets from `sizeof(dtype_t)`.
Since `sizeof(float_e2m1_t) == 1` while the real footprint is 4 bits/element, pass byte-based
parameters instead (an 8-bit dtype and halved K extents):
```cpp
// Use BLOCK_K/2 as the template param and uint8_t as dtype for correct byte math
auto a_desc = make_umma_desc<K, LOAD_BLOCK_M, BLOCK_K/2, kSwizzleAMode, false, uint8_t>(...);
// Advance by UMMA_K/2 bytes per atom (64 packed elements = 32 bytes)
a_desc.lo = advance_umma_desc_lo<K, LOAD_BLOCK_M, kSwizzleAMode, uint8_t>(..., k * (UMMA_K / 2));
```
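To see why the `sizeof`-based math breaks, compare the start offset of each UMMA atom within one packed row. A sketch assuming `UMMA_K = 64` and `BLOCK_K = 128`; the function names are hypothetical and model only the byte offset, not the full descriptor encoding:

```cpp
#include <cstdint>

constexpr uint32_t kUmmaK = 64;   // elements per mxf4nvf4 atom along K
constexpr uint32_t kBlockK = 128; // elements per K-block (assumed)

// Correct for packed FP4: atom k starts k * UMMA_K / 2 bytes into the row.
constexpr uint32_t packed_atom_offset(uint32_t k) { return k * (kUmmaK / 2); }

// What sizeof(float_e2m1_t) == 1 would give: one byte per element.
constexpr uint32_t sizeof_based_offset(uint32_t k) { return k * kUmmaK * 1; }

constexpr uint32_t kAtomsPerBlock = kBlockK / kUmmaK; // 2 atoms per BLOCK_K
static_assert(packed_atom_offset(kAtomsPerBlock - 1) < kBlockK / 2,
              "last atom starts inside the 64-byte packed row");
static_assert(sizeof_based_offset(kAtomsPerBlock - 1) >= kBlockK / 2,
              "sizeof-based math steps past the packed row");
```

With `sizeof`-based math the second atom would start at byte 64, i.e. in the next row, which is why the helpers are called with `uint8_t` and halved K values.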
### L1 Epilogue (Packed FP4 Output)
The L1 epilogue writes packed E2M1 to SMEM using direct `st.shared.u16` (no STSM, no swizzle for v1):
```
Each lane quantizes 4 BF16 → 4 E2M1 nibbles → 2 bytes (uint16)
Lane mapping: row_in_atom = lane_idx / 4, col_pair = lane_idx % 4
Row stride: L1_OUT_BLOCK_N / 2 bytes (packed)
kWarpBytesPerRow = L1_OUT_BLOCK_N / 8
```
TMA store: `block_n / 4` inner dim, no swizzle (`CU_TENSOR_MAP_SWIZZLE_NONE` for v1).
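A host-side reference for what each lane computes can help when checking epilogue output. This sketch takes `float` inputs that are assumed to be already divided by their UE4M3 block scale, and uses the sum-of-comparisons idea noted under Build 34 with ties rounded to the even code; the exact rounding of the real kernel is an assumption here, and all names are hypothetical:

```cpp
#include <cstdint>

// E2M1 magnitude codes 0..7 map to {0, 0.5, 1, 1.5, 2, 3, 4, 6}, so the code is
// the number of midpoint thresholds the value exceeds (ties go to the even code).
inline uint8_t quantize_e2m1(float x) {
    const float a = x < 0.0f ? -x : x;
    const uint8_t mag = static_cast<uint8_t>(
        (a > 0.25f) + (a >= 0.75f) + (a > 1.25f) + (a >= 1.75f) +
        (a > 2.5f) + (a >= 3.5f) + (a > 5.0f)); // saturates at 6.0 (code 7)
    return static_cast<uint8_t>((x < 0.0f ? 0x8 : 0x0) | mag); // bit 3 = sign
}

// Packs 4 consecutive values into the uint16 one lane stores with st.shared.u16:
// element 0 in bits [3:0], element 1 in [7:4], element 2 in [11:8], element 3 in [15:12].
inline uint16_t pack_e2m1_x4(const float (&v)[4]) {
    uint16_t out = 0;
    for (int i = 0; i < 4; ++i)
        out |= static_cast<uint16_t>(quantize_e2m1(v[i]) << (4 * i));
    return out;
}

// Lane mapping from the block above: 8 rows per atom, 4 lanes per row.
struct LaneCoord { uint32_t row_in_atom, col_pair; };
constexpr LaneCoord lane_coord(uint32_t lane_idx) {
    return {lane_idx / 4, lane_idx % 4};
}
```

For example, `{1.0f, 2.0f, 3.0f, 6.0f}` quantizes to codes 2, 4, 5, 7 and packs to `0x7542`.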
### Host-side TMA Descriptors
All activation TMA descriptors use packed FP4 dimensions:
- K-dim: `hidden / 2` bytes (not `hidden` elements)
- Inner block: `block_k / 2` bytes
- Swizzle mode: `swizzle_acts_mode / 2`
- `fp4_unpacked_smem = false` for all activation descriptors
- L1 output: `block_n / 4` inner, swizzle=0 (no swizzle)
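The descriptor parameters above can be collected in one place. A sketch with a hypothetical `PackedFp4TmaParams` struct (it does not call `cuTensorMapEncodeTiled`, it only computes the byte values), treating the tensor element type as `uint8` so that one element holds two E2M1 values; swizzle modes are assumed to be expressed in bytes:

```cpp
#include <cstdint>

// Hypothetical summary of the packed-FP4 TMA parameters listed above.
struct PackedFp4TmaParams {
    uint32_t k_dim_bytes;     // full row width in bytes
    uint32_t box_inner_bytes; // inner box dimension in bytes
    uint32_t swizzle_bytes;   // 0 means CU_TENSOR_MAP_SWIZZLE_NONE
    bool fp4_unpacked_smem;   // always false for mxf4nvf4
};

// Activations: hidden / 2 row bytes, block_k / 2 inner box, halved swizzle.
constexpr PackedFp4TmaParams activation_tma_params(uint32_t hidden, uint32_t block_k,
                                                   uint32_t swizzle_acts_mode) {
    return {hidden / 2, block_k / 2, swizzle_acts_mode / 2, false};
}

// L1 output: intermediate_hidden / 2 row bytes, block_n / 4 inner box, no swizzle (v1).
constexpr PackedFp4TmaParams l1_output_tma_params(uint32_t intermediate_hidden,
                                                  uint32_t block_n) {
    return {intermediate_hidden / 2, block_n / 4, 0, false};
}
```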
### Pybind Buffer Sizing
The `NVFP4SymmBuffer` uses packed byte counts:
- `fp4_token_layout = layout::Data(hidden / 2)` (packed bytes per token)
- `fp4_intermediate_token_layout = layout::Data(intermediate_hidden / 2)`
- Tensor shapes: `{M, hidden / 2}` for uint8 packed activations
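A rough per-token size estimate shows where the packed sizing and the larger NVFP4 scale storage (Build 15) come from. This sketch assumes one 8-bit scale per `gran_k` elements stored alongside each token and ignores any alignment padding the real `NVFP4SymmBuffer` may add; all names are hypothetical:

```cpp
#include <cstddef>

// Packed NVFP4 activation bytes per token: two E2M1 values per byte.
constexpr size_t fp4_token_bytes(size_t hidden) { return hidden / 2; }

// Scale bytes per token: one 8-bit scale per gran_k elements.
// NVFP4 (gran_k = 16) needs twice the scale storage of MXFP4 (gran_k = 32).
constexpr size_t sf_token_bytes(size_t hidden, size_t gran_k) { return hidden / gran_k; }

// Bytes for an {M, hidden / 2} uint8 activation tensor plus its UE4M3 scales.
constexpr size_t nvfp4_activation_bytes(size_t m, size_t hidden) {
    return m * (fp4_token_bytes(hidden) + sf_token_bytes(hidden, 16));
}
```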
## Weight Transformation Pipeline
```
NVFP4 Checkpoint Kernel Format
┌─────────────────────┐ ┌────────────────────────┐
│ weight: uint8 │────────────────→│ int8 (E2M1, same) │
│ (E2M1, 2 per byte) │ .view(int8) │ packed, interleaved │
├─────────────────────┤ ├────────────────────────┤
│ weight_scale: │ 1. fold global │ int32 (TMA-aligned │
│ float8_e4m3fn │ 2. pack 4→i32 │ UTCCP layout, │
│ (UE4M3, group=16) │ 3. transpose │ gran_k=16) │
├─────────────────────┤ 4. TMA-align └────────────────────────┘
│ weight_scale_2: │
│ float32 (global) │──folded into block scales before packing
└─────────────────────┘
```
**NO UE4M3→UE8M0 conversion. NO block16→block32 merge.** The kernel consumes
native UE4M3 scales with block16 grouping.
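Steps 1 and 2 of the scale pipeline (fold, then pack) can be sketched on the host as follows; the transpose and TMA-alignment steps are not shown. This assumes the checkpoint's dequantization is `e2m1 * weight_scale * weight_scale_2`, that the folded product stays within UE4M3's range without significant precision loss, and that the first scale goes into the lowest byte of each 32-bit word. The encoder is a brute-force nearest-value search (ties to the even code), chosen for clarity rather than speed; all names are hypothetical:

```cpp
#include <cmath>
#include <cstdint>

// Decodes the magnitude of an E4M3 (bias 7) code; the sign bit is ignored.
inline float decode_e4m3(uint8_t code) {
    const int exp = (code >> 3) & 0xF;
    const int man = code & 0x7;
    if (exp == 0) return std::ldexp(static_cast<float>(man), -9);  // subnormal: man/8 * 2^-6
    return std::ldexp(static_cast<float>(8 + man), exp - 10);       // (1 + man/8) * 2^(exp-7)
}

// Encodes a float as UE4M3 (sign bit cleared) by searching the finite positive
// codes 0x00..0x7E; ties go to the even code and values above 448 saturate.
inline uint8_t encode_ue4m3(float x) {
    uint8_t best = 0;
    float best_err = std::fabs(x - decode_e4m3(0));
    for (uint8_t c = 1; c <= 0x7E; ++c) {
        const float err = std::fabs(x - decode_e4m3(c));
        if (err < best_err || (err == best_err && (c & 1) == 0)) {
            best = c;
            best_err = err;
        }
    }
    return best;
}

// Step 1: fold the per-tensor FP32 weight_scale_2 into one block scale.
inline uint8_t fold_global_scale(uint8_t block_scale, float weight_scale_2) {
    return encode_ue4m3(decode_e4m3(block_scale) * weight_scale_2);
}

// Step 2: pack four UE4M3 scales into one 32-bit word, first scale in the lowest byte.
constexpr uint32_t pack_ue4m3_x4(uint8_t s0, uint8_t s1, uint8_t s2, uint8_t s3) {
    return uint32_t(s0) | (uint32_t(s1) << 8) | (uint32_t(s2) << 16) | (uint32_t(s3) << 24);
}
```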
## Key Differences from MXFP4 mega_moe
| Parameter | MXFP4 | NVFP4 (this kernel) |
|-----------|-------|---------------------|
| `kGranK` | 32 | 16 |
| PTX instruction | `mxf8f6f4.block_scale` | `mxf4nvf4.block_scale.scale_vec::4X` |
| Scale factor type | `float_ue8m0_t` | `float_ue4m3_t` |
| SF vector size | block32 / 1X | block16 / 4X |
| TMEM SF cols (SFA) | `SF_BLOCK_M / 32` | `SF_BLOCK_M / 32 * 4` |
| UTCCP col stride | `i * 4` | `i * 8` |
| `kNumSFUint32` | `kHidden / 128` | `kHidden / 64` |
| SMEM dtype | `float_e2m1_unpacksmem_t` (1 byte/elem) | `float_e2m1_t` (packed, 4 bits/elem) |
| SMEM row stride | `BLOCK_K * sizeof(dtype_t)` = 128 | `BLOCK_K / 2` = 64 (packed) |
| UMMA_K | 32 | 64 |
| L1 epilogue | STSM + UE8M0 scales | st.shared.u16 + UE4M3 scales, no swizzle |
| recipe | `(1, 1, 32)` | `(1, 1, 16)` |
## Build History
| Build | Error | Fix |
|-------|-------|-----|
| 1–6 | Dockerfile/build issues | NVRTC symlink, CPATH, PYTHONPATH |
| 7 | `kPackedFP4` type mismatch | uint8→int8 view |
| 9 | SF stride assertion | MN-major layout + TMA alignment |
| 10 | `transform_sf` no gran_k=16 | C++ fix |
| 11 | SF dtype float8_e4m3fn rejected | Pack UE4M3→int32 first |
| 12–14 | SF stride layout | Transpose to MN-major |
| 15 | SymmBuffer too small | NVFP4-specific SymmBuffer (2× SF) |
| 16 | `ImportError` | Python wrapper |
| **17** | **NVCC: `scale_vec::4X` not on sm_100f** | **Wrong arch: need `sm_100a`** |
| 18 | `scale_vec::2X` also failed | Same — `sm_100a` required |
| 19 | kGranK still 16 in C++ binding | Should stay 16 — was wrongly changed to 32 |
| 20 | `uint32 >> 23` fails | Cast to int32 first |
| 22 | Garbled output | Fell back to mxf8f6f4 — should use mxf4nvf4 on sm_100a |
| 23–24 | transform_nvfp4 l1_weight_scale_2 | Added global scale folding to Python |
| 25 | Triton float8→uint8 cast | Manual FP32→E4M3 bit manipulation |
| 25 | Triton `__rpow__` | IEEE 754 bit-level 2^exp |
| 26–28 | Unpacked SMEM (wrong for mxf4nvf4) | Half-zeroed accumulators |
| 29–31 | Packed FP4 SMEM + STSM | Wrong: STSM writes 1 byte/elem, MMA reads 2 elem/byte |
| **32** | **Full packed FP4 revert** | **float_e2m1_t, BLOCK_K/2 swizzle, UMMA desc fixes** |
| 33 | Syntax error in patch | Orphan `@triton.jit` decorator |
| 34 | Triton nested `tl.where` | Sum-of-comparisons for E2M1 quantization |
| 35 | Triton `constexpr[0]` indexing | `tl.split()` for E2M1 pair packing |
## Remaining Work
- [ ] End-to-end quality test on B200 (Build 35+)
- [ ] Add 32B/64B swizzle to L1 epilogue for perf (v2)
- [ ] Optimize L1→L2 bandwidth with packed format
- [ ] Performance benchmarking vs standard FusedMoE path