biondizzle e608a20dec docs: major README update — packed FP4 SMEM layout, L1 epilogue, TMA descriptors
Added detailed documentation of the packed FP4 architecture:
- mxf4nvf4 reads packed (2 per byte), NOT unpacked like mxf8f6f4
- SMEM layout: float_e2m1_t, BLOCK_K/2 swizzle, UMMA desc byte math
- L1 epilogue: st.shared.u16, no swizzle, kWarpBytesPerRow
- Host TMA: hidden/2 K-dim, block_k/2 inner, fp4_unpacked_smem=false
- Build history through Build 35
2026-05-11 22:40:09 +00:00

DeepGEMM NVFP4 Mega MoE Kernel

Overview

A native NVFP4 mega MoE kernel for DeepGEMM that uses kind::mxf4nvf4.block_scale.scale_vec::4X to consume NVFP4 weights (E2M1 + UE4M3 block scales, group_size=16) directly on B200 (SM100a).

HARD RULE: MoE experts stay in NVFP4. Never convert to MXFP4.
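As a reference for what "E2M1 + UE4M3 block scales, group_size=16" means numerically, here is a minimal dequantization sketch in plain Python. It is not code from the kernel: the function names are illustrative, and the block scales are passed in as already-decoded floats for brevity.

```python
from typing import List

# E2M1 magnitudes indexed by the low 3 bits of the 4-bit code; bit 3 is the sign.
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
GROUP_SIZE = 16  # NVFP4 block size (kGranK in the kernel)


def decode_e2m1(code: int) -> float:
    """Decode one 4-bit E2M1 code (1 sign, 2 exponent, 1 mantissa bit)."""
    magnitude = E2M1_MAGNITUDES[code & 0x7]
    return -magnitude if code & 0x8 else magnitude


def dequantize_nvfp4(codes: List[int], block_scales: List[float]) -> List[float]:
    """Scale each group of 16 decoded values by that group's block scale.

    Assumption: block_scales are plain floats here; a checkpoint stores
    them as UE4M3 bytes.
    """
    if len(codes) != GROUP_SIZE * len(block_scales):
        raise ValueError("expected one block scale per 16 elements")
    return [decode_e2m1(c) * block_scales[i // GROUP_SIZE]
            for i, c in enumerate(codes)]
```

Two groups with scales 0.5 and 2.0 are decoded independently, which is the block16 granularity the kernel has to preserve.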

SM100a (B200) Hardware Support

B200 (SM100a) DOES support kind::mxf4nvf4 with scale_vec::4X (block16, UE4M3 scales). Documented in PTX ISA 8.7 (CUDA 12.8+), confirmed by NVIDIA/CUTLASS/Colfax.

The key requirement: target sm_100a (not sm_100 or sm_100f). The a suffix enables the FP4 block-scaled instructions, including mxf4nvf4. Targeting another variant fails at compile time; Build 17, for example, targeted sm_100f and hit "Feature '.scale_vec::4X' not supported on .target 'sm_100f'".

Kernel Architecture

sm100_fp8_nvfp4_mega_moe_impl
├── kGranK = 16 (NVFP4 native block size)
├── kind::mxf4nvf4.block_scale.scale_vec::4X PTX instruction
├── float_ue4m3_t instruction descriptor
├── SF layout: scale_vec::4X, 4 TMEM sub-columns per UMMA atom
├── UTCCP copy: i*8 stride (4X layout, 8 TMEM cols per 128-element group)
├── kNumSFATmemCols = SF_BLOCK_M / 32 * 4
├── kNumSFBTmemCols = SF_BLOCK_N / 32 * 4
├── kNumSFUint32 = kHidden / 64 (4 UE4M3 per int32)
├── UE4M3 L1 epilogue (float → e4m3 cast, sign bit cleared)
└── recipe = (1, 1, 16)
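The scale-factor constants in the tree follow directly from the block size and the one-byte UE4M3 encoding. The sketch below reproduces that arithmetic in Python as a sanity check. The function name and the example shapes in the usage note are hypothetical, not values taken from the kernel.

```python
GRAN_K = 16            # elements covered by one UE4M3 scale (NVFP4 block size)
SCALES_PER_UINT32 = 4  # one byte per UE4M3 scale, four per 32-bit word


def sf_constants(sf_block_m: int, sf_block_n: int, hidden: int) -> dict:
    """Recompute the SF-related constants listed in the architecture tree."""
    if sf_block_m % 32 or sf_block_n % 32:
        raise ValueError("SF block sizes must be multiples of 32")
    if hidden % (GRAN_K * SCALES_PER_UINT32):
        raise ValueError("hidden must be a multiple of 64")
    return {
        # scale_vec::4X: 4 TMEM columns per 32 rows
        "kNumSFATmemCols": sf_block_m // 32 * 4,
        "kNumSFBTmemCols": sf_block_n // 32 * 4,
        # hidden / 16 scales per row, packed 4 per uint32
        "kNumSFUint32": hidden // (GRAN_K * SCALES_PER_UINT32),
    }
```

For example, `sf_constants(128, 128, 7168)` gives 16 TMEM columns each for SFA and SFB and 112 packed uint32 scale words per row.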

Critical: mxf4nvf4 Requires FP4×FP4 (Packed)

The mxf4nvf4 PTX instruction requires BOTH A and B to be FP4 (E2M1 packed, 2 per byte).

Unlike mxf8f6f4, which reads FP4 from SMEM as if it were FP8 (1 byte/element via float_e2m1_unpacksmem_t: the low nibble holds the value and the high nibble is ignored), mxf4nvf4 reads FP4 packed from SMEM:

  • 2 E2M1 values per byte (low nibble=first, high nibble=second)
  • Advances K by 32 bytes per UMMA_K=64 atom (64 packed nibbles = 32 bytes)
  • Byte stride of BLOCK_K / 2 per K-row (not BLOCK_K * sizeof(dtype_t))
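The packing convention above (first element in the low nibble, second in the high nibble) can be sketched in a few lines of Python. The helper names are illustrative and not part of DeepGEMM.

```python
from typing import List, Sequence


def pack_e2m1(codes: Sequence[int]) -> bytes:
    """Pack 4-bit codes two per byte: even index -> low nibble, odd -> high."""
    if len(codes) % 2:
        raise ValueError("need an even number of 4-bit codes")
    return bytes((codes[i] & 0xF) | ((codes[i + 1] & 0xF) << 4)
                 for i in range(0, len(codes), 2))


def unpack_e2m1(packed: bytes) -> List[int]:
    """Inverse of pack_e2m1: low nibble first, then high nibble."""
    codes = []
    for byte in packed:
        codes.append(byte & 0xF)
        codes.append(byte >> 4)
    return codes


def bytes_per_umma_atom(umma_k: int = 64) -> int:
    """Packed bytes consumed along K by one UMMA atom."""
    return umma_k // 2
```

Packing the 64 elements of one UMMA_K atom yields 32 bytes, which matches the per-atom K advance stated above.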

SMEM Layout for mxf4nvf4

using a_dtype_t = cutlass::float_e2m1_t;  // packed, 4 bits/element
using b_dtype_t = cutlass::float_e2m1_t;

static_assert(cutlass::sizeof_bits_v<a_dtype_t> == 4,
    "mxf4nvf4 requires packed FP4 (4 bits/element) in SMEM");

constexpr uint32_t kSwizzleAMode = BLOCK_K / 2;  // 64 bytes (packed)
constexpr uint32_t kSwizzleBMode = BLOCK_K / 2;  // 64 bytes (packed)
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K / 2;  // packed
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K / 2;  // packed

UMMA Descriptors

The make_umma_desc and advance_umma_desc_lo helpers use sizeof(dtype_t) for byte math. Since sizeof(float_e2m1_t) == 1 but the real stride is 4 bits/element, pass the effective byte-stride parameters:

// Use BLOCK_K/2 as the template param and uint8_t as dtype for correct byte math
auto a_desc = make_umma_desc<K, LOAD_BLOCK_M, BLOCK_K/2, kSwizzleAMode, false, uint8_t>(...);
// Advance by UMMA_K/2 bytes per atom (64 packed elements = 32 bytes)
a_desc.lo = advance_umma_desc_lo<K, LOAD_BLOCK_M, kSwizzleAMode, uint8_t>(..., k * (UMMA_K / 2));
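To see why the descriptor must advance by UMMA_K / 2, the following Python sketch compares the packed per-atom byte offsets with the offsets that result from treating each element as one byte. It models only the logical, pre-swizzle offset within a K-row. The function name is hypothetical and BLOCK_K = 128 is inferred from the "64 bytes (packed)" comments above.

```python
from typing import List


def umma_atom_byte_offsets(block_k: int, umma_k: int, bits_per_elem: int) -> List[int]:
    """Logical byte offset of each UMMA_K atom within one K-row."""
    if block_k % umma_k:
        raise ValueError("BLOCK_K must be a multiple of UMMA_K")
    atom_bytes = umma_k * bits_per_elem // 8
    return [k * atom_bytes for k in range(block_k // umma_k)]


BLOCK_K, UMMA_K = 128, 64
row_bytes = BLOCK_K // 2                                   # 64 packed bytes per K-row
packed = umma_atom_byte_offsets(BLOCK_K, UMMA_K, 4)        # true 4-bit stride
byte_sized = umma_atom_byte_offsets(BLOCK_K, UMMA_K, 8)    # sizeof(float_e2m1_t) == 1
```

With these values `packed` is `[0, 32]`, while `byte_sized` is `[0, 64]`, so the second atom would start at the end of the 64-byte packed row instead of at its midpoint.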

L1 Epilogue (Packed FP4 Output)

The L1 epilogue writes packed E2M1 to SMEM using direct st.shared.u16 (no STSM, no swizzle for v1):

Each lane quantizes 4 BF16 → 4 E2M1 nibbles → 2 bytes (uint16)
Lane mapping: row_in_atom = lane_idx / 4, col_pair = lane_idx % 4
Row stride: L1_OUT_BLOCK_N / 2 bytes (packed)
kWarpBytesPerRow = L1_OUT_BLOCK_N / 8

TMA store: block_n / 4 inner dim, no swizzle (CU_TENSOR_MAP_SWIZZLE_NONE for v1).
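The per-lane work described above can be modeled on the host as follows. This is a sketch under assumptions, not the kernel's code: it uses a sum-of-comparisons rounding in which ties go to the smaller magnitude, assumes the inputs are already divided by the block scale, assumes element 0 lands in the lowest nibble of the uint16, and assumes lanes in the same row write adjacent 2-byte words.

```python
from typing import Sequence

# Midpoints between consecutive E2M1 magnitudes (0, 0.5, 1, 1.5, 2, 3, 4, 6).
E2M1_THRESHOLDS = (0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0)


def quantize_e2m1(x: float) -> int:
    """Round to the nearest E2M1 code; values above 6 saturate to code 7."""
    magnitude_code = sum(abs(x) > t for t in E2M1_THRESHOLDS)
    return (0x8 if x < 0 else 0) | magnitude_code


def pack_lane_u16(values: Sequence[float]) -> int:
    """Quantize 4 values and pack them into one uint16 (element 0 lowest)."""
    if len(values) != 4:
        raise ValueError("each lane handles exactly 4 values")
    word = 0
    for i, v in enumerate(values):
        word |= quantize_e2m1(v) << (4 * i)
    return word


def lane_store_offset(lane_idx: int, l1_out_block_n: int) -> int:
    """Byte offset of a lane's u16 store within the warp's 8-row atom."""
    row_in_atom, col_pair = lane_idx // 4, lane_idx % 4
    row_stride = l1_out_block_n // 2      # packed bytes per output row
    return row_in_atom * row_stride + col_pair * 2
```

With a hypothetical L1_OUT_BLOCK_N of 64, lane 5 maps to row 1, column pair 1, that is byte offset 34.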

Host-side TMA Descriptors

All activation TMA descriptors use packed FP4 dimensions:

  • K-dim: hidden / 2 bytes (not hidden elements)
  • Inner block: block_k / 2 bytes
  • Swizzle mode: swizzle_acts_mode / 2
  • fp4_unpacked_smem = false for all activation descriptors
  • L1 output: block_n / 4 inner, swizzle=0 (no swizzle)
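The packed dimensions listed above amount to halving every byte-sized quantity of the activation descriptors. A small Python sketch of that bookkeeping follows. The function and key names are hypothetical and do not correspond to the actual host API.

```python
def packed_activation_tma_dims(hidden: int, block_k: int, swizzle_acts_mode: int) -> dict:
    """Packed-FP4 dimensions for an activation TMA descriptor."""
    if hidden % 2 or block_k % 2 or swizzle_acts_mode % 2:
        raise ValueError("packed FP4 needs even element counts")
    return {
        "k_dim": hidden // 2,              # bytes, two elements per byte
        "inner_block": block_k // 2,       # bytes per K-block row
        "swizzle": swizzle_acts_mode // 2,
        "fp4_unpacked_smem": False,
    }


def l1_output_tma_dims(block_n: int) -> dict:
    """L1 output descriptor: block_n / 4 inner dim, no swizzle in v1."""
    if block_n % 4:
        raise ValueError("block_n must be a multiple of 4")
    return {"inner_block": block_n // 4, "swizzle": 0}
```

For instance, a hidden size of 7168 with block_k = 128 and a 128-byte unpacked swizzle yields a 3584-byte K-dim, a 64-byte inner block, and a 64-byte swizzle.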

Pybind Buffer Sizing

The NVFP4SymmBuffer uses packed byte counts:

  • fp4_token_layout = layout::Data(hidden / 2) (packed bytes per token)
  • fp4_intermediate_token_layout = layout::Data(intermediate_hidden / 2)
  • Tensor shapes: {M, hidden / 2} for uint8 packed activations

Weight Transformation Pipeline

NVFP4 Checkpoint                         Kernel Format
┌─────────────────────┐                 ┌────────────────────────┐
│ weight: uint8       │────────────────→│ int8 (E2M1, same)     │
│ (E2M1, 2 per byte)  │  .view(int8)    │ packed, interleaved    │
├─────────────────────┤                 ├────────────────────────┤
│ weight_scale:       │ 1. fold global  │ int32 (TMA-aligned     │
│ float8_e4m3fn       │ 2. pack 4→i32   │  UTCCP layout,         │
│ (UE4M3, group=16)   │ 3. transpose    │  gran_k=16)            │
├─────────────────────┤ 4. TMA-align    └────────────────────────┘
│ weight_scale_2:     │
│ float32 (global)    │──folded into block scales before packing
└─────────────────────┘

NO UE4M3→UE8M0 conversion. NO block16→block32 merge. The kernel consumes native UE4M3 scales with block16 grouping.
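Steps 1 and 2 of the scale pipeline (fold the global scale, then pack four UE4M3 bytes per 32-bit word) can be sketched in plain Python. This is an illustration, not the repository's transform: the function names are hypothetical, the encoder saturates at 448 and rounds ties to even, and little-endian byte order within each word is an assumption. The transpose and TMA-alignment steps are omitted.

```python
import math
from typing import List, Sequence


def decode_ue4m3(code: int) -> float:
    """Decode an unsigned E4M3 byte (4 exponent bits, bias 7; 3 mantissa bits)."""
    exponent, mantissa = (code >> 3) & 0xF, code & 0x7
    if exponent == 0:
        return mantissa * 2.0 ** -9            # subnormal step is 2**-9
    return (1 + mantissa / 8) * 2.0 ** (exponent - 7)


def encode_ue4m3(x: float) -> int:
    """Encode a non-negative float as UE4M3, saturating at 448."""
    if x <= 0:
        return 0
    x = min(x, 448.0)
    mant, exp = math.frexp(x)                  # x = mant * 2**exp, 0.5 <= mant < 1
    exp -= 1                                   # x = (2*mant) * 2**exp, 1 <= 2*mant < 2
    if exp < -6:
        return round(x * 2 ** 9)               # subnormal; 8 rolls into the exponent
    # A mantissa that rounds up to 8 carries into the exponent field.
    return ((exp + 7) << 3) + round((2 * mant - 1) * 8)


def fold_and_pack_scales(codes: Sequence[int], global_scale: float) -> List[int]:
    """Fold the global scale into each block scale, then pack 4 bytes per word."""
    if len(codes) % 4:
        raise ValueError("need a multiple of 4 block scales")
    folded = bytes(encode_ue4m3(decode_ue4m3(c) * global_scale) for c in codes)
    return [int.from_bytes(folded[i:i + 4], "little")
            for i in range(0, len(folded), 4)]
```

Folding re-rounds each block scale to UE4M3, so the result is exact only when the product is representable, as with a power-of-two global scale.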

Key Differences from MXFP4 mega_moe

| Parameter | MXFP4 | NVFP4 (this kernel) |
|---|---|---|
| kGranK | 32 | 16 |
| PTX instruction | mxf8f6f4.block_scale | mxf4nvf4.block_scale.scale_vec::4X |
| Scale factor type | float_ue8m0_t | float_ue4m3_t |
| SF vector size | block32 / 2X | block16 / 4X |
| TMEM SF cols (SFA) | SF_BLOCK_M / 32 | SF_BLOCK_M / 32 * 4 |
| UTCCP col stride | i * 4 | i * 8 |
| kNumSFUint32 | kHidden / 128 | kHidden / 64 |
| SMEM dtype | float_e2m1_unpacksmem_t (1 byte/elem) | float_e2m1_t (packed, 4 bits/elem) |
| SMEM row stride | BLOCK_K * sizeof(dtype_t) = 128 | BLOCK_K / 2 = 64 (packed) |
| UMMA_K | 32 | 64 |
| L1 epilogue | STSM + UE8M0 scales | st.shared.u16 + UE4M3 scales, no swizzle |
| recipe | (1, 1, 32) | (1, 1, 16) |

Build History

| Build | Error | Fix |
|---|---|---|
| 1-6 | Dockerfile/build issues | NVRTC symlink, CPATH, PYTHONPATH |
| 7 | kPackedFP4 type mismatch | uint8→int8 view |
| 9 | SF stride assertion | MN-major layout + TMA alignment |
| 10 | transform_sf no gran_k=16 | C++ fix |
| 11 | SF dtype float8_e4m3fn rejected | Pack UE4M3→int32 first |
| 12-14 | SF stride layout | Transpose to MN-major |
| 15 | SymmBuffer too small | NVFP4-specific SymmBuffer (2× SF) |
| 16 | ImportError | Python wrapper |
| 17 | NVCC: scale_vec::4X not on sm_100f | Wrong arch: need sm_100a |
| 18 | scale_vec::2X also failed | Same: sm_100a required |
| 19 | kGranK still 16 in C++ binding | Should stay 16; was wrongly changed to 32 |
| 20 | uint32 >> 23 fails | Cast to int32 first |
| 22 | Garbled output | Fell back to mxf8f6f4; should use mxf4nvf4 on sm_100a |
| 23-24 | transform_nvfp4 l1_weight_scale_2 | Added global scale folding to Python |
| 25 | Triton float8→uint8 cast | Manual FP32→E4M3 bit manipulation |
| 25 | Triton __rpow__ | IEEE 754 bit-level 2^exp |
| 26-28 | Unpacked SMEM (wrong for mxf4nvf4) | Half-zeroed accumulators |
| 29-31 | Packed FP4 SMEM + STSM | Wrong: STSM writes 1 byte/elem, MMA reads 2 elem/byte |
| 32 | Full packed FP4 revert | float_e2m1_t, BLOCK_K/2 swizzle, UMMA desc fixes |
| 33 | Syntax error in patch | Orphan @triton.jit decorator |
| 34 | Triton nested tl.where | Sum-of-comparisons for E2M1 quantization |
| 35 | Triton constexpr[0] indexing | tl.split() for E2M1 pair packing |

Remaining Work

  • End-to-end quality test on B200 (Build 35+)
  • Add 32B/64B swizzle to L1 epilogue for perf (v2)
  • Optimize L1→L2 bandwidth with packed format
  • Performance benchmarking vs standard FusedMoE path