DeepGEMM NVFP4 Mega MoE Kernel
Overview
A native NVFP4 mega MoE kernel for DeepGEMM that uses kind::mxf4nvf4.block_scale.scale_vec::4X
to consume NVFP4 weights (E2M1 + UE4M3 block scales, group_size=16) directly on B200 (SM100a).
HARD RULE: MoE experts stay in NVFP4. Never convert to MXFP4.
SM100a (B200) Hardware Support
B200 (SM100a) DOES support kind::mxf4nvf4 with scale_vec::4X (block16, UE4M3 scales).
This is documented in PTX ISA 8.7 (CUDA 12.8+) and confirmed by NVIDIA, CUTLASS, and Colfax.
The key requirement is to target sm_100a, not sm_100. The a suffix enables the
architecture-specific FP4 block-scaled instructions, including mxf4nvf4. Compiling for a
target without the suffix fails. For the family target sm_100f, for example, the error reads
"Feature '.scale_vec::4X' not supported on .target 'sm_100f'".
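A preprocessor guard can catch this mistake before ptxas does. The sketch below is hypothetical and is not part of DeepGEMM. It relies on nvcc's `__CUDA_ARCH_FEAT_SM100_ALL` macro, which is defined only in device passes compiled for the architecture-specific `sm_100a` target:

```cpp
// Hypothetical guard, not taken from the DeepGEMM sources.
// nvcc defines __CUDA_ARCH_FEAT_SM100_ALL only in device passes compiled for
// the architecture-specific sm_100a target. It stays undefined for sm_100,
// sm_100f, and host compilation.
#if defined(__CUDA_ARCH_FEAT_SM100_ALL)
constexpr bool kHasSm100aFeatures = true;
#else
constexpr bool kHasSm100aFeatures = false;
#endif

// Replace the ptxas "scale_vec::4X not supported" failure with an explicit
// message when a compute capability 10.0 device pass lacks the 'a' feature set.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 1000) && !defined(__CUDA_ARCH_FEAT_SM100_ALL)
#error "mxf4nvf4 with scale_vec::4X needs an sm_100a target (e.g. -arch=sm_100a)"
#endif
```

Host-side compilation sees neither macro. The guard therefore fires only in a 10.0 device pass that was built without the `a` suffix.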
Kernel Architecture
```
sm100_fp8_nvfp4_mega_moe_impl
├── kGranK = 16 (NVFP4 native block size)
├── kind::mxf4nvf4.block_scale.scale_vec::4X PTX instruction
├── float_ue4m3_t instruction descriptor
├── SF layout: scale_vec::4X, 4 TMEM sub-columns per UMMA atom
├── UTCCP copy: i*8 stride (4X layout, 8 TMEM cols per 128-element group)
├── kNumSFATmemCols = SF_BLOCK_M / 32 * 4
├── kNumSFBTmemCols = SF_BLOCK_N / 32 * 4
├── kNumSFUint32 = kHidden / 64 (4 UE4M3 per int32)
├── UE4M3 L1 epilogue (float → e4m3 cast, sign bit cleared)
└── recipe = (1, 1, 16)
```
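The scale-factor constants in the tree follow from simple arithmetic. The sketch below uses hypothetical helper names (`num_sf_uint32`, `num_sf_tmem_cols`) and an example `kHidden` of 7168. For MXFP4, the table further down lists `SF_BLOCK_M / 32`, which corresponds to one sub-column per atom.

```cpp
#include <cstdint>

// Hypothetical helpers mirroring the constants listed above.
// One scale factor covers gran_k K-elements, and four 8-bit scales pack into one int32.
constexpr uint32_t num_sf_uint32(uint32_t hidden, uint32_t gran_k) {
    return hidden / gran_k / 4;
}

// TMEM columns for a scale-factor tile: SF_BLOCK / 32, times the number of
// scale sub-columns per UMMA atom (4 for scale_vec::4X).
constexpr uint32_t num_sf_tmem_cols(uint32_t sf_block, uint32_t sub_cols_per_atom) {
    return sf_block / 32 * sub_cols_per_atom;
}

// NVFP4: kGranK = 16, so kNumSFUint32 reduces to kHidden / 64.
static_assert(num_sf_uint32(7168, 16) == 7168 / 64, "NVFP4 SF packing");
// MXFP4 reference: kGranK = 32 gives kHidden / 128.
static_assert(num_sf_uint32(7168, 32) == 7168 / 128, "MXFP4 SF packing");
```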
Critical: mxf4nvf4 Requires FP4×FP4 (Packed)
The mxf4nvf4 PTX instruction requires BOTH A and B to be FP4 (E2M1 packed, 2 per byte).
Unlike mxf8f6f4, which reads FP4 from SMEM as if it were FP8 (1 byte/element, low nibble = value,
high nibble ignored; CUTLASS type float_e2m1_unpacksmem_t), mxf4nvf4 reads FP4 packed from SMEM:
- 2 E2M1 values per byte (low nibble=first, high nibble=second)
- Advances K by 32 bytes per UMMA_K=64 atom (64 packed nibbles = 32 bytes)
- Byte stride of `BLOCK_K / 2` per K-row (not `BLOCK_K * sizeof(dtype_t)`)
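A pair of small helpers can illustrate the packed nibble order. This is a sketch with hypothetical names, not kernel code. In the example, code `0x3` is +1.5 and `0xA` (sign bit set, magnitude code 2) is -1.0.

```cpp
#include <cstdint>

// Pack two 4-bit E2M1 codes into one byte. The first element goes in the low
// nibble and the second in the high nibble.
constexpr uint8_t pack_e2m1_pair(uint8_t first, uint8_t second) {
    return static_cast<uint8_t>((first & 0xF) | ((second & 0xF) << 4));
}

// Fetch the 4-bit code of element k from a packed K-row (2 elements per byte).
constexpr uint8_t unpack_e2m1(const uint8_t* row, uint32_t k) {
    return static_cast<uint8_t>((k % 2 == 0) ? (row[k / 2] & 0xF) : (row[k / 2] >> 4));
}

// A K-row of block_k packed elements occupies block_k / 2 bytes.
constexpr uint32_t packed_row_bytes(uint32_t block_k) { return block_k / 2; }
```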
SMEM Layout for mxf4nvf4
```cpp
using a_dtype_t = cutlass::float_e2m1_t;  // packed, 4 bits/element
using b_dtype_t = cutlass::float_e2m1_t;
static_assert(cutlass::sizeof_bits_v<a_dtype_t> == 4,
              "mxf4nvf4 requires packed FP4 (4 bits/element) in SMEM");
constexpr uint32_t kSwizzleAMode = BLOCK_K / 2;  // 64 bytes (packed)
constexpr uint32_t kSwizzleBMode = BLOCK_K / 2;  // 64 bytes (packed)
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K / 2;  // packed
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K / 2;  // packed
```
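Plugging example tile sizes into the definitions above shows what the packed layout saves. `BLOCK_K = 128` matches the row strides in the Key Differences table. `LOAD_BLOCK_M = LOAD_BLOCK_N = 128` are assumptions chosen for illustration.

```cpp
#include <cstdint>

// Example tile sizes (assumed for illustration).
constexpr uint32_t BLOCK_K = 128;
constexpr uint32_t LOAD_BLOCK_M = 128;
constexpr uint32_t LOAD_BLOCK_N = 128;

constexpr uint32_t kFp4Bits = 4;                         // packed E2M1
constexpr uint32_t kRowBytes = BLOCK_K * kFp4Bits / 8;   // bytes per K-row
constexpr uint32_t kSwizzleAMode = BLOCK_K / 2;          // 64-byte swizzle
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K / 2;
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K / 2;

// The swizzle span must be a supported width and hold exactly one packed K-row.
static_assert(kSwizzleAMode == 32 || kSwizzleAMode == 64 || kSwizzleAMode == 128,
              "unsupported swizzle width");
static_assert(kRowBytes == kSwizzleAMode, "one packed K-row per swizzle span");
// An unpacked layout (1 byte per element) would double the per-stage footprint.
static_assert(2 * SMEM_A_SIZE_PER_STAGE == LOAD_BLOCK_M * BLOCK_K, "packed = half");
```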
UMMA Descriptors
The make_umma_desc and advance_umma_desc_lo helpers use sizeof(dtype_t) for byte math.
Since sizeof(float_e2m1_t) == 1 but the real stride is 4 bits/element, pass the effective
byte-stride parameters:
```cpp
// Use BLOCK_K/2 as the template param and uint8_t as dtype for correct byte math
auto a_desc = make_umma_desc<K, LOAD_BLOCK_M, BLOCK_K/2, kSwizzleAMode, false, uint8_t>(...);
// Advance by UMMA_K/2 bytes per atom (64 packed elements = 32 bytes)
a_desc.lo = advance_umma_desc_lo<K, LOAD_BLOCK_M, kSwizzleAMode, uint8_t>(..., k * (UMMA_K / 2));
```
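The `sizeof` pitfall and the per-atom advance both come down to a few lines of byte arithmetic. The sketch uses a hypothetical stand-in struct in place of the real CUTLASS type and does not call `make_umma_desc`. It only reproduces the numbers those helpers should receive.

```cpp
#include <cstdint>

// Stand-in for cutlass::float_e2m1_t (hypothetical): one 4-bit value held in a
// uint8_t, so sizeof() is 1 even though SMEM packs two values per byte.
struct e2m1_storage { uint8_t bits; };

constexpr uint32_t BLOCK_K = 128;  // example value
constexpr uint32_t UMMA_K = 64;    // K per mxf4nvf4 atom

// Trusting sizeof() gives 128 bytes per K-row, twice the real packed stride.
constexpr uint32_t naive_row_bytes = BLOCK_K * sizeof(e2m1_storage);
// Passing BLOCK_K / 2 with a 1-byte dtype (uint8_t) gives the real stride.
constexpr uint32_t packed_row_bytes = (BLOCK_K / 2) * sizeof(uint8_t);

// Byte offset of the k-th UMMA atom inside a K-row: UMMA_K / 2 = 32 bytes each.
constexpr uint32_t atom_byte_offset(uint32_t k) { return k * (UMMA_K / 2); }

// UMMA shared-memory descriptors encode the start address in 16-byte units,
// so each 32-byte atom step adds 2 to the descriptor's address field.
constexpr uint32_t atom_desc_lo_step(uint32_t k) { return atom_byte_offset(k) >> 4; }

static_assert(sizeof(e2m1_storage) == 1, "byte-sized storage");
static_assert((BLOCK_K / UMMA_K) * (UMMA_K / 2) == packed_row_bytes,
              "atoms tile the packed K-row exactly");
```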
L1 Epilogue (Packed FP4 Output)
The L1 epilogue writes packed E2M1 to SMEM using direct st.shared.u16 (no STSM, no swizzle for v1):
- Each lane quantizes 4 BF16 → 4 E2M1 nibbles → 2 bytes (uint16)
- Lane mapping: row_in_atom = lane_idx / 4, col_pair = lane_idx % 4
- Row stride: L1_OUT_BLOCK_N / 2 bytes (packed)
- kWarpBytesPerRow = L1_OUT_BLOCK_N / 8
TMA store: block_n / 4 inner dim, no swizzle (CU_TENSOR_MAP_SWIZZLE_NONE for v1).
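Below is a host-side model of one lane's work in this epilogue. The function names are hypothetical, and the inputs are assumed to be already divided by their UE4M3 block scale. The per-lane byte offset (2-byte slots at `lane % 4`) is an assumption about the layout, not a copy of the kernel. The quantizer uses the sum-of-comparisons approach mentioned in the build history, so ties resolve toward the smaller magnitude rather than by IEEE round-to-nearest-even.

```cpp
#include <cmath>
#include <cstdint>

// Quantize one (already block-scaled) value to a 4-bit E2M1 code.
// Magnitudes by code 0..7 are 0, 0.5, 1, 1.5, 2, 3, 4, 6. The magnitude code is
// the number of midpoints that |x| exceeds ("sum of comparisons"). Ties therefore
// round toward the smaller magnitude, and anything above 5 saturates to 6.
inline uint8_t quantize_e2m1(float x) {
    const float mids[7] = {0.25f, 0.75f, 1.25f, 1.75f, 2.5f, 3.5f, 5.0f};
    const float a = std::fabs(x);
    uint8_t code = 0;
    for (float m : mids) code = static_cast<uint8_t>(code + (a > m ? 1 : 0));
    const uint8_t sign = (x < 0.0f) ? 0x8 : 0x0;
    return static_cast<uint8_t>(sign | code);
}

// Four values -> four nibbles -> one uint16, with the first value in bits 0-3.
// Stored little-endian, byte 0 holds element 0 (low nibble) and element 1 (high nibble).
inline uint16_t pack_e2m1x4(const float v[4]) {
    uint16_t out = 0;
    for (int i = 0; i < 4; ++i)
        out = static_cast<uint16_t>(out | (quantize_e2m1(v[i]) << (4 * i)));
    return out;
}

// Hypothetical byte offset of a lane's st.shared.u16 inside one 8-row atom:
// row = lane / 4 (row stride L1_OUT_BLOCK_N / 2 bytes), and the 4 lanes of a
// row write consecutive 2-byte slots.
constexpr uint32_t lane_byte_offset(uint32_t lane, uint32_t l1_out_block_n) {
    return (lane / 4) * (l1_out_block_n / 2) + (lane % 4) * 2;
}
```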
Host-side TMA Descriptors
All activation TMA descriptors use packed FP4 dimensions:
- K-dim: `hidden / 2` bytes (not `hidden` elements)
- Inner block: `block_k / 2` bytes
- Swizzle mode: `swizzle_acts_mode / 2`
- `fp4_unpacked_smem = false` for all activation descriptors
- L1 output: `block_n / 4` inner, swizzle=0 (no swizzle)
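These rules can be collected into a single shape calculation. The struct and function below are hypothetical. They assume a byte-typed map (for example `CU_TENSOR_MAP_DATA_TYPE_UINT8`) whose parameters would then be passed to `cuTensorMapEncodeTiled`, and they check two of that API's documented constraints.

```cpp
#include <cstdint>

// Hypothetical summary of the packed-FP4 activation tensor-map parameters.
// Every size is in bytes: the map is assumed to be byte-typed, with two E2M1
// values per byte.
struct PackedFp4TmaShape {
    uint64_t global_inner_bytes;  // K-dim: hidden / 2
    uint64_t global_rows;         // tokens
    uint64_t row_stride_bytes;    // distance between token rows
    uint32_t box_inner_bytes;     // block_k / 2
    uint32_t box_rows;            // block_m
    uint32_t swizzle_bytes;       // swizzle_acts_mode / 2, 0 = no swizzle
};

constexpr PackedFp4TmaShape make_act_shape(uint64_t num_tokens, uint64_t hidden,
                                           uint32_t block_m, uint32_t block_k,
                                           uint32_t swizzle_acts_mode) {
    return {hidden / 2, num_tokens, hidden / 2, block_k / 2, block_m,
            swizzle_acts_mode / 2};
}

// Two cuTensorMapEncodeTiled requirements: global strides must be multiples of
// 16 bytes, and with swizzling the inner box extent (in bytes) must fit inside
// the swizzle span.
constexpr bool tma_shape_ok(const PackedFp4TmaShape& s) {
    return s.row_stride_bytes % 16 == 0 &&
           (s.swizzle_bytes == 0 || s.box_inner_bytes <= s.swizzle_bytes);
}
```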
Pybind Buffer Sizing
The NVFP4SymmBuffer uses packed byte counts:
- `fp4_token_layout = layout::Data(hidden / 2)` (packed bytes per token)
- `fp4_intermediate_token_layout = layout::Data(intermediate_hidden / 2)`
- Tensor shapes: `{M, hidden / 2}` for uint8 packed activations
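The per-token byte counts follow directly from the packing. Here is a rough sketch with hypothetical helper names. It ignores any alignment or padding the real `NVFP4SymmBuffer` may add. The 2× scale-factor count matches the Build 15 entry in the history table.

```cpp
#include <cstdint>

// Hypothetical per-token byte counts for packed NVFP4 activations.
// Data: two E2M1 values per byte. Scales: one UE4M3 byte per 16 elements, which
// is twice the scale count of MXFP4 (one UE8M0 byte per 32 elements).
constexpr uint64_t fp4_data_bytes(uint64_t hidden) { return hidden / 2; }
constexpr uint64_t nvfp4_sf_bytes(uint64_t hidden) { return hidden / 16; }
constexpr uint64_t mxfp4_sf_bytes(uint64_t hidden) { return hidden / 32; }

// Bytes for m tokens: a {m, hidden / 2} uint8 tensor plus its block scales.
constexpr uint64_t nvfp4_token_block_bytes(uint64_t m, uint64_t hidden) {
    return m * (fp4_data_bytes(hidden) + nvfp4_sf_bytes(hidden));
}
```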
Weight Transformation Pipeline
```
NVFP4 Checkpoint                        Kernel Format
┌─────────────────────┐                 ┌────────────────────────┐
│ weight: uint8       │────────────────→│ int8 (E2M1, same)      │
│ (E2M1, 2 per byte)  │   .view(int8)   │ packed, interleaved    │
├─────────────────────┤                 ├────────────────────────┤
│ weight_scale:       │ 1. fold global  │ int32 (TMA-aligned     │
│ float8_e4m3fn       │ 2. pack 4→i32   │ UTCCP layout,          │
│ (UE4M3, group=16)   │ 3. transpose    │ gran_k=16)             │
├─────────────────────┤ 4. TMA-align    └────────────────────────┘
│ weight_scale_2:     │
│ float32 (global)    │──folded into block scales before packing
└─────────────────────┘
```
NO UE4M3→UE8M0 conversion. NO block16→block32 merge. The kernel consumes native UE4M3 scales with block16 grouping.
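Steps 1 and 2 of the scale path (fold, then pack) are sketched below in C++ for illustration only. The actual pipeline does this in Python/Triton, and this sketch omits transposition and TMA alignment. The byte order inside each 32-bit word is an assumption. `encode_ue4m3` is a deliberately simple brute-force search, not the bit-manipulation conversion used in the pipeline.

```cpp
#include <cmath>
#include <cstdint>

// Decode an unsigned E4M3 byte: 4 exponent bits (bias 7), 3 mantissa bits.
inline float decode_ue4m3(uint8_t b) {
    const int e = (b >> 3) & 0xF;
    const int m = b & 0x7;
    if (e == 0) return std::ldexp(static_cast<float>(m), -9);  // subnormal: m * 2^-9
    return std::ldexp(1.0f + m / 8.0f, e - 7);                 // normal
}

// Nearest finite UE4M3 code (0x00..0x7E, max 448) for a non-negative float.
// Saturates at 448 and breaks ties toward the even code.
inline uint8_t encode_ue4m3(float x) {
    uint8_t best = 0;
    float best_err = std::fabs(x - decode_ue4m3(0));
    for (int c = 1; c <= 0x7E; ++c) {
        const float err = std::fabs(x - decode_ue4m3(static_cast<uint8_t>(c)));
        if (err < best_err || (err == best_err && (c & 1) == 0)) {
            best = static_cast<uint8_t>(c);
            best_err = err;
        }
    }
    return best;
}

// Step 1: fold the per-tensor global scale (weight_scale_2) into a block scale.
inline uint8_t fold_global_scale(uint8_t block_scale, float global_scale) {
    return encode_ue4m3(decode_ue4m3(block_scale) * global_scale);
}

// Step 2: pack four consecutive UE4M3 scales into one 32-bit word, with the first
// scale in the lowest byte (assumed order; the kernel's UTCCP layout decides).
inline uint32_t pack_ue4m3x4(const uint8_t s[4]) {
    return static_cast<uint32_t>(s[0]) | (static_cast<uint32_t>(s[1]) << 8) |
           (static_cast<uint32_t>(s[2]) << 16) | (static_cast<uint32_t>(s[3]) << 24);
}
```

The same 32-bit pattern can be viewed as the int32 the kernel consumes.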
Key Differences from MXFP4 mega_moe
| Parameter | MXFP4 | NVFP4 (this kernel) |
|---|---|---|
| `kGranK` | 32 | 16 |
| PTX instruction | `mxf8f6f4.block_scale` | `mxf4nvf4.block_scale.scale_vec::4X` |
| Scale factor type | `float_ue8m0_t` | `float_ue4m3_t` |
| SF vector size | block32 / 1X | block16 / 4X |
| TMEM SF cols (SFA) | `SF_BLOCK_M / 32` | `SF_BLOCK_M / 32 * 4` |
| UTCCP col stride | `i * 4` | `i * 8` |
| `kNumSFUint32` | `kHidden / 128` | `kHidden / 64` |
| SMEM dtype | `float_e2m1_unpacksmem_t` (1 byte/elem) | `float_e2m1_t` (packed, 4 bits/elem) |
| SMEM row stride (bytes, BLOCK_K = 128) | `BLOCK_K * sizeof(dtype_t)` = 128 | `BLOCK_K / 2` = 64 (packed) |
| UMMA_K | 32 | 64 |
| L1 epilogue | STSM + UE8M0 scales | `st.shared.u16` + UE4M3 scales, no swizzle |
| recipe | (1, 1, 32) | (1, 1, 16) |
Build History
| Build | Error | Fix |
|---|---|---|
| 1–6 | Dockerfile/build issues | NVRTC symlink, CPATH, PYTHONPATH |
| 7 | `kPackedFP4` type mismatch | uint8→int8 view |
| 9 | SF stride assertion | MN-major layout + TMA alignment |
| 10 | `transform_sf` lacked `gran_k=16` support | C++ fix |
| 11 | SF dtype `float8_e4m3fn` rejected | Pack UE4M3→int32 first |
| 12–14 | SF stride layout | Transpose to MN-major |
| 15 | SymmBuffer too small | NVFP4-specific SymmBuffer (2× SF) |
| 16 | ImportError | Python wrapper |
| 17 | NVCC: `scale_vec::4X` not supported on `sm_100f` | Wrong arch: need `sm_100a` |
| 18 | `scale_vec::2X` also failed | Same cause: `sm_100a` required |
| 19 | `kGranK` still 16 in C++ binding | Should stay 16; it had been wrongly changed to 32 |
| 20 | `uint32 >> 23` fails | Cast to int32 first |
| 22 | Garbled output | Kernel fell back to mxf8f6f4; it should use mxf4nvf4 on sm_100a |
| 23–24 | `transform_nvfp4`: `l1_weight_scale_2` | Added global scale folding to Python |
| 25 | Triton float8→uint8 cast | Manual FP32→E4M3 bit manipulation |
| 25 | Triton `__rpow__` | IEEE 754 bit-level 2^exp |
| 26–28 | Half-zeroed accumulators | Cause: unpacked SMEM is wrong for mxf4nvf4 |
| 29–31 | Packed FP4 SMEM + STSM epilogue | Still wrong: STSM writes 1 byte/elem, but MMA reads 2 elems/byte |
| 32 | Full packed FP4 revert | `float_e2m1_t`, BLOCK_K/2 swizzle, UMMA desc fixes |
| 33 | Syntax error in patch | Removed orphan `@triton.jit` decorator |
| 34 | Triton nested `tl.where` | Sum-of-comparisons for E2M1 quantization |
| 35 | Triton `constexpr[0]` indexing | `tl.split()` for E2M1 pair packing |
Remaining Work
- End-to-end quality test on B200 (Build 35+)
- Add 32B/64B swizzle to L1 epilogue for perf (v2)
- Optimize L1→L2 bandwidth with packed format
- Performance benchmarking vs standard FusedMoE path