diff --git a/README_NVFP4.md b/README_NVFP4.md index 74ab0c3..a458fb0 100644 --- a/README_NVFP4.md +++ b/README_NVFP4.md @@ -16,7 +16,7 @@ The key requirement: target **`sm_100a`** (not `sm_100`). The `a` suffix enables block-scaled instructions including `mxf4nvf4`. Targeting plain `sm_100` will produce "Feature '.scale_vec::4X' not supported on .target 'sm_100f'" errors. -## Kernel Architecture (TARGET) +## Kernel Architecture ``` sm100_fp8_nvfp4_mega_moe_impl @@ -28,10 +28,77 @@ sm100_fp8_nvfp4_mega_moe_impl ├── kNumSFATmemCols = SF_BLOCK_M / 32 * 4 ├── kNumSFBTmemCols = SF_BLOCK_N / 32 * 4 ├── kNumSFUint32 = kHidden / 64 (4 UE4M3 per int32) -├── UE4M3 L1 epilogue (float → cutlass::float_e4m3_t cast, sign bit cleared) +├── UE4M3 L1 epilogue (float → e4m3 cast, sign bit cleared) └── recipe = (1, 1, 16) ``` +## Critical: mxf4nvf4 Requires FP4×FP4 (Packed) + +**The `mxf4nvf4` PTX instruction requires BOTH A and B to be FP4 (E2M1 packed, 2 per byte).** + +Unlike `mxf8f6f4` which reads FP4 from SMEM as if it were FP8 (1 byte/element, low nibble=value, +high nibble=ignored via `float_e2m1_unpacksmem_t`), `mxf4nvf4` reads FP4 **packed** from SMEM: +- 2 E2M1 values per byte (low nibble=first, high nibble=second) +- Advances K by 32 bytes per UMMA_K=64 atom (64 packed nibbles = 32 bytes) +- Byte stride of `BLOCK_K / 2` per K-row (not `BLOCK_K * sizeof(dtype_t)`) + +### SMEM Layout for mxf4nvf4 + +```cpp +using a_dtype_t = cutlass::float_e2m1_t; // packed, 4 bits/element +using b_dtype_t = cutlass::float_e2m1_t; + +static_assert(cutlass::sizeof_bits_v == 4, + "mxf4nvf4 requires packed FP4 (4 bits/element) in SMEM"); + +constexpr uint32_t kSwizzleAMode = BLOCK_K / 2; // 64 bytes (packed) +constexpr uint32_t kSwizzleBMode = BLOCK_K / 2; // 64 bytes (packed) +constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K / 2; // packed +constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K / 2; // packed +``` + +### UMMA Descriptors + +The `make_umma_desc` and `advance_umma_desc_lo` helpers use `sizeof(dtype_t)` for byte math. +Since `sizeof(float_e2m1_t) == 1` but the real stride is 4 bits/element, pass the effective +byte-stride parameters: + +```cpp +// Use BLOCK_K/2 as the template param and uint8_t as dtype for correct byte math +auto a_desc = make_umma_desc(...); +// Advance by UMMA_K/2 bytes per atom (64 packed elements = 32 bytes) +a_desc.lo = advance_umma_desc_lo(..., k * (UMMA_K / 2)); +``` + +### L1 Epilogue (Packed FP4 Output) + +The L1 epilogue writes packed E2M1 to SMEM using direct `st.shared.u16` (no STSM, no swizzle for v1): + +``` +Each lane quantizes 4 BF16 → 4 E2M1 nibbles → 2 bytes (uint16) +Lane mapping: row_in_atom = lane_idx / 4, col_pair = lane_idx % 4 +Row stride: L1_OUT_BLOCK_N / 2 bytes (packed) +kWarpBytesPerRow = L1_OUT_BLOCK_N / 8 +``` + +TMA store: `block_n / 4` inner dim, no swizzle (`CU_TENSOR_MAP_SWIZZLE_NONE` for v1). + +### Host-side TMA Descriptors + +All activation TMA descriptors use packed FP4 dimensions: +- K-dim: `hidden / 2` bytes (not `hidden` elements) +- Inner block: `block_k / 2` bytes +- Swizzle mode: `swizzle_acts_mode / 2` +- `fp4_unpacked_smem = false` for all activation descriptors +- L1 output: `block_n / 4` inner, swizzle=0 (no swizzle) + +### Pybind Buffer Sizing + +The `NVFP4SymmBuffer` uses packed byte counts: +- `fp4_token_layout = layout::Data(hidden / 2)` (packed bytes per token) +- `fp4_intermediate_token_layout = layout::Data(intermediate_hidden / 2)` +- Tensor shapes: `{M, hidden / 2}` for uint8 packed activations + ## Weight Transformation Pipeline ``` @@ -63,27 +130,12 @@ native UE4M3 scales with block16 grouping. | TMEM SF cols (SFA) | `SF_BLOCK_M / 32` | `SF_BLOCK_M / 32 * 4` | | UTCCP col stride | `i * 4` | `i * 8` | | `kNumSFUint32` | `kHidden / 128` | `kHidden / 64` | -| L1 epilogue | UE8M0 (`>> 23`) | UE4M3 (float→e4m3 cast) | +| SMEM dtype | `float_e2m1_unpacksmem_t` (1 byte/elem) | `float_e2m1_t` (packed, 4 bits/elem) | +| SMEM row stride | `BLOCK_K * sizeof(dtype_t)` = 128 | `BLOCK_K / 2` = 64 (packed) | +| UMMA_K | 32 | 64 | +| L1 epilogue | STSM + UE8M0 scales | st.shared.u16 + UE4M3 scales, no swizzle | | recipe | `(1, 1, 32)` | `(1, 1, 16)` | -## Critical Implementation Details - -### scale_format_ constraint -The CUTLASS instruction descriptor has a single `scale_format_` bit (0=E4M3, 1=E8M0) -that applies to BOTH A and B scale factors. For NVFP4 (E4M3), both activation (SFA) -and weight (SFB) scales must use UE4M3. The L1 epilogue outputs UE4M3 activation scales -(float → `cutlass::float_e4m3_t` with sign bit cleared). - -### Arch flag -The JIT compiler MUST target `sm_100a`, not `sm_100`. Without the `a` suffix, the -`mxf4nvf4` instruction is unavailable and compilation will fail with -"Feature '.scale_vec::4X' not supported on .target 'sm_100f'". - -### Weight scale_2 folding -The NVFP4 checkpoint has dual-level scaling: per-block UE4M3 + per-tensor float32. -The `weight_scale_2` must be folded into the block scales before packing: -`effective_scale = block_scale * global_scale`, then re-quantize to UE4M3. - ## Build History | Build | Error | Fix | @@ -101,14 +153,19 @@ The `weight_scale_2` must be folded into the block scales before packing: | 19 | kGranK still 16 in C++ binding | Should stay 16 — was wrongly changed to 32 | | 20 | `uint32 >> 23` fails | Cast to int32 first | | 22 | Garbled output | Fell back to mxf8f6f4 — should use mxf4nvf4 on sm_100a | +| 23–24 | transform_nvfp4 l1_weight_scale_2 | Added global scale folding to Python | +| 25 | Triton float8→uint8 cast | Manual FP32→E4M3 bit manipulation | +| 25 | Triton `__rpow__` | IEEE 754 bit-level 2^exp | +| 26–28 | Unpacked SMEM (wrong for mxf4nvf4) | Half-zeroed accumulators | +| 29–31 | Packed FP4 SMEM + STSM | Wrong: STSM writes 1 byte/elem, MMA reads 2/elem | +| **32** | **Full packed FP4 revert** | **float_e2m1_t, BLOCK_K/2 swizzle, UMMA desc fixes** | +| 33 | Syntax error in patch | Orphan `@triton.jit` decorator | +| 34 | Triton nested `tl.where` | Sum-of-comparisons for E2M1 quantization | +| 35 | Triton `constexpr[0]` indexing | `tl.split()` for E2M1 pair packing | ## Remaining Work -- [ ] Fix DeepGEMM JIT to target `sm_100a` instead of `sm_100` -- [ ] Add NVFP4 MMA kind enum to DeepGEMM runtime (not just MXFP8FP4 with NVFP4 hat) -- [ ] Revert to Build 17's `mxf4nvf4.scale_vec::4X` instruction (was correct, just wrong arch) -- [ ] Revert `kGranK` to 16, UE4M3 scales, block16 SF layout -- [ ] Add `get_sf_uttcp_aligned_block_sizes` branch for block16 layout -- [ ] Remove UE4M3→UE8M0 conversion and block16→block32 merge from Python -- [ ] Verify TMEM 4X layout (i*8 stride, 4 sub-columns) -- [ ] End-to-end quality test on B200 +- [ ] End-to-end quality test on B200 (Build 35+) +- [ ] Add 32B/64B swizzle to L1 epilogue for perf (v2) +- [ ] Optimize L1→L2 bandwidth with packed format +- [ ] Performance benchmarking vs standard FusedMoE path