# DeepGEMM NVFP4 Mega MoE Kernel

## Overview

A native NVFP4 mega MoE kernel for DeepGEMM that uses `kind::mxf4nvf4.block_scale.scale_vec::4X` to consume NVFP4 weights (E2M1 + UE4M3 block scales, group_size=16) directly on B200 (SM100a).

**HARD RULE: MoE experts stay in NVFP4. Never convert to MXFP4.**

## SM100a (B200) Hardware Support

**B200 (SM100a) DOES support `kind::mxf4nvf4` with `scale_vec::4X`** (block16, UE4M3 scales). Documented in PTX ISA 8.7 (CUDA 12.8+), confirmed by NVIDIA/CUTLASS/Colfax.

The key requirement: target **`sm_100a`** (not `sm_100`). The `a` suffix enables the FP4 block-scaled instructions, including `mxf4nvf4`. Targeting plain `sm_100` produces "Feature '.scale_vec::4X' not supported on .target 'sm_100f'" errors.

## Kernel Architecture

```
sm100_fp8_nvfp4_mega_moe_impl
├── kGranK = 16 (NVFP4 native block size)
├── kind::mxf4nvf4.block_scale.scale_vec::4X PTX instruction
├── float_ue4m3_t instruction descriptor
├── SF layout: scale_vec::4X, 4 TMEM sub-columns per UMMA atom
├── UTCCP copy: i*8 stride (4X layout, 8 TMEM cols per 128-element group)
├── kNumSFATmemCols = SF_BLOCK_M / 32 * 4
├── kNumSFBTmemCols = SF_BLOCK_N / 32 * 4
├── kNumSFUint32 = kHidden / 64 (4 UE4M3 per int32)
├── UE4M3 L1 epilogue (float → e4m3 cast, sign bit cleared)
└── recipe = (1, 1, 16)
```

## Critical: mxf4nvf4 Requires FP4×FP4 (Packed)

**The `mxf4nvf4` PTX instruction requires BOTH A and B to be FP4 (E2M1 packed, 2 per byte).**

Unlike `mxf8f6f4`, which reads FP4 from SMEM as if it were FP8 (1 byte/element, low nibble=value, high nibble=ignored via `float_e2m1_unpacksmem_t`), `mxf4nvf4` reads FP4 **packed** from SMEM:

- 2 E2M1 values per byte (low nibble=first, high nibble=second)
- Advances K by 32 bytes per UMMA_K=64 atom (64 packed nibbles = 32 bytes)
- Byte stride of `BLOCK_K / 2` per K-row (not `BLOCK_K * sizeof(dtype_t)`)

### SMEM Layout for mxf4nvf4

```cpp
using a_dtype_t = cutlass::float_e2m1_t;  // packed, 4 bits/element
using b_dtype_t = cutlass::float_e2m1_t;

static_assert(cutlass::sizeof_bits_v<a_dtype_t> == 4,
              "mxf4nvf4 requires packed FP4 (4 bits/element) in SMEM");

constexpr uint32_t kSwizzleAMode = BLOCK_K / 2;  // 64 bytes (packed)
constexpr uint32_t kSwizzleBMode = BLOCK_K / 2;  // 64 bytes (packed)
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K / 2;  // packed
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K / 2;  // packed
```

### UMMA Descriptors

The `make_umma_desc` and `advance_umma_desc_lo` helpers use `sizeof(dtype_t)` for byte math. Since `sizeof(float_e2m1_t) == 1` but the real stride is 4 bits/element, pass the effective byte-stride parameters:

```cpp
// Use BLOCK_K/2 as the template param and uint8_t as dtype for correct byte math
auto a_desc = make_umma_desc(...);

// Advance by UMMA_K/2 bytes per atom (64 packed elements = 32 bytes)
a_desc.lo = advance_umma_desc_lo(..., k * (UMMA_K / 2));
```

### L1 Epilogue (Packed FP4 Output)

The L1 epilogue writes packed E2M1 to SMEM using direct `st.shared.u16` (no STSM, no swizzle for v1):

```
Each lane quantizes 4 BF16 → 4 E2M1 nibbles → 2 bytes (uint16)
Lane mapping: row_in_atom = lane_idx / 4, col_pair = lane_idx % 4
Row stride: L1_OUT_BLOCK_N / 2 bytes (packed)
kWarpBytesPerRow = L1_OUT_BLOCK_N / 8
```

TMA store: `block_n / 4` inner dim, no swizzle (`CU_TENSOR_MAP_SWIZZLE_NONE` for v1).

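
The packed layout described above (two E2M1 codes per byte, first element in the low nibble) can be checked on the host with a few lines of Python. This is a minimal sketch, not part of the kernel: the helper names are hypothetical, and `BLOCK_K = 128` is an assumption inferred from the "64 bytes (packed)" comments in the SMEM layout snippet.

```python
# E2M1 magnitudes for the 3 low bits of a code; bit 3 is the sign.
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def decode_e2m1(code: int) -> float:
    """Decode one 4-bit E2M1 code to a Python float."""
    sign = -1.0 if code & 0x8 else 1.0
    return sign * E2M1_MAGNITUDES[code & 0x7]


def pack_e2m1(codes: list[int]) -> bytes:
    """Pack 4-bit codes two per byte: even index -> low nibble, odd -> high."""
    if len(codes) % 2:
        raise ValueError("need an even number of E2M1 codes")
    return bytes(
        (codes[i] & 0xF) | ((codes[i + 1] & 0xF) << 4)
        for i in range(0, len(codes), 2)
    )


def unpack_e2m1(packed: bytes) -> list[int]:
    """Inverse of pack_e2m1: low nibble first, then high nibble."""
    codes = []
    for byte in packed:
        codes.append(byte & 0xF)
        codes.append(byte >> 4)
    return codes


# Byte math for the packed layout (BLOCK_K is an assumed value).
BLOCK_K = 128
UMMA_K = 64
ROW_STRIDE_BYTES = BLOCK_K // 2    # bytes per K-row in SMEM
ATOM_ADVANCE_BYTES = UMMA_K // 2   # bytes advanced per UMMA_K atom

if __name__ == "__main__":
    row = list(range(16)) * (BLOCK_K // 16)  # one K-row of dummy codes
    packed = pack_e2m1(row)
    print(len(packed), ROW_STRIDE_BYTES, ATOM_ADVANCE_BYTES)
```

A round trip through `pack_e2m1` and `unpack_e2m1` is a cheap way to confirm that host-side packing agrees with the nibble order the MMA expects before debugging on the device.
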
### Host-side TMA Descriptors

All activation TMA descriptors use packed FP4 dimensions:

- K-dim: `hidden / 2` bytes (not `hidden` elements)
- Inner block: `block_k / 2` bytes
- Swizzle mode: `swizzle_acts_mode / 2`
- `fp4_unpacked_smem = false` for all activation descriptors
- L1 output: `block_n / 4` inner, swizzle=0 (no swizzle)

### Pybind Buffer Sizing

The `NVFP4SymmBuffer` uses packed byte counts:

- `fp4_token_layout = layout::Data(hidden / 2)` (packed bytes per token)
- `fp4_intermediate_token_layout = layout::Data(intermediate_hidden / 2)`
- Tensor shapes: `{M, hidden / 2}` for uint8 packed activations

## Weight Transformation Pipeline

```
NVFP4 Checkpoint                        Kernel Format
┌─────────────────────┐                 ┌────────────────────────┐
│ weight: uint8       │────────────────→│ int8 (E2M1, same)      │
│ (E2M1, 2 per byte)  │  .view(int8)    │ packed, interleaved    │
├─────────────────────┤                 ├────────────────────────┤
│ weight_scale:       │ 1. fold global  │ int32 (TMA-aligned     │
│ float8_e4m3fn       │ 2. pack 4→i32   │ UTCCP layout,          │
│ (UE4M3, group=16)   │ 3. transpose    │ gran_k=16)             │
├─────────────────────┤ 4. TMA-align    └────────────────────────┘
│ weight_scale_2:     │
│ float32 (global)    │──folded into block scales before packing
└─────────────────────┘
```

**NO UE4M3→UE8M0 conversion. NO block16→block32 merge.** The kernel consumes native UE4M3 scales with block16 grouping.

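
Steps 1 and 2 of the scale pipeline (fold the global scale, pack four UE4M3 bytes into one int32) can be illustrated in plain Python. This is a sketch under stated assumptions, not the actual transform: the function names are hypothetical, folding is assumed to mean "decode, multiply by the global scale, re-quantize to the nearest UE4M3 code", and the little-endian byte order inside each int32 is also an assumption. The transpose and TMA-alignment steps are not shown.

```python
import struct


def decode_ue4m3(code: int) -> float:
    """Decode an E4M3 byte with the sign bit clear (bias 7, 3 mantissa bits)."""
    exp = (code >> 3) & 0xF
    man = code & 0x7
    if exp == 0:  # subnormal range
        return (man / 8.0) * 2.0 ** -6
    return (1.0 + man / 8.0) * 2.0 ** (exp - 7)


# Finite non-negative codes only; 0x7F is NaN in float8_e4m3fn.
_FINITE_CODES = range(0x7F)


def encode_ue4m3(value: float) -> int:
    """Nearest finite UE4M3 code by brute force (clamps out-of-range values)."""
    return min(_FINITE_CODES, key=lambda c: abs(decode_ue4m3(c) - value))


def fold_global_scale(block_codes: list[int], global_scale: float) -> list[int]:
    """Assumed folding: multiply each block scale by the global scale, re-quantize."""
    return [encode_ue4m3(decode_ue4m3(c) * global_scale) for c in block_codes]


def pack_scales_to_int32(codes: list[int]) -> list[int]:
    """Pack 4 consecutive UE4M3 bytes into one int32 (little-endian, assumed)."""
    if len(codes) % 4:
        raise ValueError("scale count must be a multiple of 4")
    return [
        struct.unpack("<i", bytes(codes[i:i + 4]))[0]
        for i in range(0, len(codes), 4)
    ]


def num_sf_int32(hidden: int, gran_k: int = 16) -> int:
    """int32 words of scales per row: hidden/gran_k scales, 4 per word."""
    return hidden // gran_k // 4


if __name__ == "__main__":
    folded = fold_global_scale([0x38, 0x40, 0x38, 0x40], global_scale=2.0)
    print([hex(c) for c in folded], pack_scales_to_int32(folded))
```

With `gran_k = 16`, `num_sf_int32(hidden)` reduces to `hidden / 64`, which matches the `kNumSFUint32 = kHidden / 64` entry in the kernel architecture tree.
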
## Key Differences from MXFP4 mega_moe

| Parameter | MXFP4 | NVFP4 (this kernel) |
|-----------|-------|---------------------|
| `kGranK` | 32 | 16 |
| PTX instruction | `mxf8f6f4.block_scale` | `mxf4nvf4.block_scale.scale_vec::4X` |
| Scale factor type | `float_ue8m0_t` | `float_ue4m3_t` |
| SF vector size | block32 / 2X | block16 / 4X |
| TMEM SF cols (SFA) | `SF_BLOCK_M / 32` | `SF_BLOCK_M / 32 * 4` |
| UTCCP col stride | `i * 4` | `i * 8` |
| `kNumSFUint32` | `kHidden / 128` | `kHidden / 64` |
| SMEM dtype | `float_e2m1_unpacksmem_t` (1 byte/elem) | `float_e2m1_t` (packed, 4 bits/elem) |
| SMEM row stride | `BLOCK_K * sizeof(dtype_t)` = 128 | `BLOCK_K / 2` = 64 (packed) |
| UMMA_K | 32 | 64 |
| L1 epilogue | STSM + UE8M0 scales | st.shared.u16 + UE4M3 scales, no swizzle |
| recipe | `(1, 1, 32)` | `(1, 1, 16)` |

## Build History

| Build | Error | Fix |
|-------|-------|-----|
| 1–6 | Dockerfile/build issues | NVRTC symlink, CPATH, PYTHONPATH |
| 7 | `kPackedFP4` type mismatch | uint8→int8 view |
| 9 | SF stride assertion | MN-major layout + TMA alignment |
| 10 | `transform_sf` no gran_k=16 | C++ fix |
| 11 | SF dtype float8_e4m3fn rejected | Pack UE4M3→int32 first |
| 12–14 | SF stride layout | Transpose to MN-major |
| 15 | SymmBuffer too small | NVFP4-specific SymmBuffer (2× SF) |
| 16 | `ImportError` | Python wrapper |
| **17** | **NVCC: `scale_vec::4X` not on sm_100f** | **Wrong arch: need `sm_100a`** |
| 18 | `scale_vec::2X` also failed | Same: `sm_100a` required |
| 19 | kGranK still 16 in C++ binding | Should stay 16; was wrongly changed to 32 |
| 20 | `uint32 >> 23` fails | Cast to int32 first |
| 22 | Garbled output | Fell back to mxf8f6f4; should use mxf4nvf4 on sm_100a |
| 23–24 | transform_nvfp4 l1_weight_scale_2 | Added global scale folding to Python |
| 25 | Triton float8→uint8 cast | Manual FP32→E4M3 bit manipulation |
| 25 | Triton `__rpow__` | IEEE 754 bit-level 2^exp |
| 26–28 | Unpacked SMEM (wrong for mxf4nvf4) | Half-zeroed accumulators |
| 29–31 | Packed FP4 SMEM + STSM | Wrong: STSM writes 1 byte/elem, MMA reads 2/elem |
| **32** | **Full packed FP4 revert** | **float_e2m1_t, BLOCK_K/2 swizzle, UMMA desc fixes** |
| 33 | Syntax error in patch | Orphan `@triton.jit` decorator |
| 34 | Triton nested `tl.where` | Sum-of-comparisons for E2M1 quantization |
| 35 | Triton `constexpr[0]` indexing | `tl.split()` for E2M1 pair packing |

## Remaining Work

- [ ] End-to-end quality test on B200 (Build 35+)
- [ ] Add 32B/64B swizzle to L1 epilogue for perf (v2)
- [ ] Optimize L1→L2 bandwidth with packed format
- [ ] Performance benchmarking vs standard FusedMoE path