diff --git a/README.md b/README.md index 93524a59..c8080d0b 100644 --- a/README.md +++ b/README.md @@ -17,9 +17,10 @@ Input hidden states (BF16) │ ▼ ┌─────────────────┐ -│ Shared Experts │ ← BYPASSED (returning zeros — FlashInfer TF32 GEMM crashes) -│ (FlashInfer │ -│ CUTLASS) │ +│ Shared Experts │ ← vLLM native FlashInfer CUTLASS NVFP4 path +│ (gate + up → │ (not our kernel) +│ SiLU * up → │ +│ down) │ └─────────────────┘ │ ▼ @@ -40,8 +41,8 @@ Input hidden states (BF16) │ E2M1 × E2M1 + UE4M3 scales │ SM100_MMA_MXF4_SS PTX │ → BF16 output (6144-wide) │ │ 3. SiLU(gate) * up (activation) │ -│ 4. stage_activation: BF16 → FP4 │ ← simple absmax quantize (needs work) -│ 5. L2 GEMM: down_proj │ ← CUTLASS NVFP4 block-scaled +│ 4. stage_activation: BF16 → FP4 │ ← proper E2M1 quantization +│ 5. L2 GEMM: down_proj │ ← CUTLASS NVFP4 block-scaled │ E2M1 × E2M1 + UE4M3 scales │ SM100_MMA_MXF4_SS PTX │ → BF16 output (7168-wide) │ │ 6. Write to output tensor │ @@ -70,11 +71,10 @@ Input hidden states (BF16) └─ weight_transform.py: transform_nvfp4_weights_for_mega_moe() • Folds weight_scale_2 (global scale) into weight_scale (block scale) • UE4M3 block-16 scales: 4 values packed per uint32 - • Interleaves L1 (gate_up) weights for 2CTA UMMA • Returns ((l1_w, l1_sf), (l2_w, l2_sf)) per rank 5. SymmBuffer allocation - └─ symm_buffer.py: get_symm_buffer_for_nvfp4 mega_moe() + └─ symm_buffer.py: get_symm_buffer_for_nvfp4_mega_moe() • Pre-allocates GPU buffers for: - x: int8 packed E2M1 activations - x_sf: uint32 packed UE4M3 activation scales @@ -97,17 +97,17 @@ Input hidden states (BF16) nvfp4_megamoe_kernel/ ├── __init__.py # Public API exports ├── nvfp4_mega_moe.py # Main kernel: nvfp4_mega_moe_full, nvfp4_mega_moe_l1/l2, stage_activation -├── weight_transform.py # Weight prep: fold global scale, pack UE4M3, interleave L1 +├── weight_transform.py # Weight prep: fold global scale, pack UE4M3 ├── symm_buffer.py # GPU buffer allocation for MoE dispatch │ └── cutlass_nvfp4_gemm/ # CUTLASS CUDA extension (the actual hardware kernel) - ├── cutlass_nvfp4_gemm.cu # CUDA: CUTLASS GEMM + SF remap kernel - ├── pytorch_binding.cpp # PyTorch C++ binding (_C.forward) - ├── kernel.py # Python: cutlass_grouped_nvfp4_gemm (per-expert loop) - ├── sf_layout.py # CUTLASS SF interleaved layout math - ├── setup.py # Build config (nvcc, CUTLASS include paths) - ├── build.sh # Build script - ├── test_gemm.py # Standalone test + ├── cutlass_nvfp4_gemm.cu # CUDA: CUTLASS GEMM + SF remap kernel + ├── pytorch_binding.cpp # PyTorch C++ binding (_C.forward) + ├── kernel.py # Python: cutlass_grouped_nvfp4_gemm (per-expert loop) + ├── sf_layout.py # CUTLASS SF layout reference docs + ├── setup.py # Build config (nvcc, CUTLASS include paths) + ├── build.sh # Build script + ├── test_gemm.py # Standalone test └── README.md ``` @@ -115,12 +115,12 @@ nvfp4_megamoe_kernel/ | File | When it runs | What it does | |------|-------------|--------------| -| `weight_transform.py` | Once at startup (weight loading) | Takes raw NVFP4 checkpoint weights, folds global scales into block scales, packs UE4M3 into uint32, interleaves L1 gate_up weights. Output: `((l1_w, l1_sf), (l2_w, l2_sf))` | +| `weight_transform.py` | Once at startup (weight loading) | Takes raw NVFP4 checkpoint weights, folds global scales into block scales, packs UE4M3 into uint32. Output: `((l1_w, l1_sf), (l2_w, l2_sf))` | | `symm_buffer.py` | Once at startup (buffer alloc) | Pre-allocates GPU tensors for activations, scales, routing data, and all-reduce. These persist across forward passes. | -| `nvfp4_mega_moe.py` | Every forward pass | Orchestrates the MoE: reads from symm buffer → L1 GEMM → activation → re-quantize → L2 GEMM → output. Contains `stage_activation` (BF16→FP4 quantize for L1→L2). | +| `nvfp4_mega_moe.py` | Every forward pass | Orchestrates the MoE: reads from symm buffer → L1 GEMM → activation → re-quantize → L2 GEMM → output. Contains `stage_activation` (BF16→FP4 quantize for L1→L2) and `unpack_ue4m3_u32` (uint32 packed scales → float8). | | `cutlass_nvfp4_gemm/kernel.py` | Every forward pass (called by nvfp4_mega_moe) | Per-expert loop: gather tokens for each expert, call CUTLASS GEMM, scatter results with routing weights. | | `cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu` | Every forward pass (CUDA kernel) | The actual CUTLASS kernel: native NVFP4 block-scaled GEMM + GPU-side scale factor remap (row-major → CUTLASS interleaved layout). | -| `cutlass_nvfp4_gemm/sf_layout.py` | Build time / reference | Documents the CUTLASS SfAtom layout. Currently unused at runtime (remap is in CUDA). | +| `cutlass_nvfp4_gemm/sf_layout.py` | Reference only | Documents the CUTLASS SfAtom layout. Not used at runtime (remap is in CUDA). | --- @@ -141,6 +141,128 @@ nvfp4_megamoe_kernel/ --- +## CUTLASS Scale Factor Remap + +CUTLASS's `Sm1xxBlockScaledConfig` expects scale factors in a specific interleaved layout, not simple row-major. The SfAtom is: + +``` +Atom Shape: Shape, Shape<16, 4>> +Atom Stride: Stride, Stride<0, 1>> +Tiling: Step<_2, _1> (M tiled with step 2, K with step 1) +``` + +Our source data is row-major `(M, K_sf)` where `K_sf = K / 16`. The remap kernel (`remap_sf_to_cutlass_kernel` in `cutlass_nvfp4_gemm.cu`) converts from row-major to CUTLASS's interleaved layout. + +### How the remap works + +The kernel iterates over CUTLASS destination indices, uses `cute::idx2crd` to get the hierarchical coordinate, then `cute::flatten` to get a flat tuple of 8 sub-indices. From those, we extract logical `(m, k_sf)` and read from the row-major source. + +### Flattened coordinate decomposition (flat_rank=8) + +From the SfAtom layout with Step<_2, _1> tiling, `flatten(idx2crd(idx, ...))` produces 8 values: + +``` +f0 = inner_m (0..31) — varies fastest within M atom +f1 = sub_m (0..3) — second M sub-coordinate +f2 = tile_m (0..) — M tile index +f3 = step_m stride — degenerate (always = sfa_size, not a coordinate) +f4 = sub_k (0..3) — K sub-coordinate within atom +f5 = tile_k (0..) — K tile index +f6 = 0 — unused +f7 = 0 — unused +``` + +#### Empirical coordinate dump (MN=8192, K_sf=448, T = sfa_size = 58720256) + +| idx | f0 | f1 | f2 | f3 | f4 | f5 | f6 | f7 | +| ----- | --- | --- | --- | --- | --- | --- | --- | --- | +| 0 | 0 | 0 | 0 | T | 0 | 0 | 0 | 0 | +| 1 | 0 | 0 | 0 | T | 1 | 0 | 0 | 0 | +| 4 | 0 | 1 | 0 | T | 0 | 0 | 0 | 0 | +| 16 | 1 | 0 | 0 | T | 0 | 0 | 0 | 0 | +| 511 | 31 | 3 | 0 | T | 3 | 0 | 0 | 0 | +| 512 | 0 | 0 | 0 | T | 0 | 1 | 0 | 0 | +| 1024 | 0 | 0 | 0 | T | 0 | 2 | 0 | 0 | +| 2048 | 0 | 0 | 0 | T | 0 | 4 | 0 | 0 | +| 4096 | 0 | 0 | 0 | T | 0 | 8 | 0 | 0 | +| 8192 | 0 | 0 | 0 | T | 0 | 16 | 0 | 0 | +| 65536 | 0 | 0 | 1 | T | 0 | 16 | 0 | 0 | +| 131072 | 0 | 0 | 2 | T | 0 | 32 | 0 | 0 | + +#### Extraction formula + +CuTe uses "first sub varies fastest" for `Shape<32, 4>`: + +```cpp +m = f0 + f1 * 32 + f2 * 128; +k_sf = f4 + f5 * 4; +``` + +This was verified with 6 independent probes: + +| Probe | Source | Expected | Result | +|-------|--------|----------|--------| +| SFA[1, 0] = 2.0 | row 1 changes | ✅ only row 1 | Confirms f0 term | +| SFA[32, 0] = 2.0 | row 32 changes | ✅ only row 32 | Confirms f1*32, rules out f0*4+f1 | +| SFA[128, 0] = 2.0 | row 128 changes | ✅ only row 128 | Confirms f2*128 | +| SFA[0, 1] = 2.0 | row 0 changes (k=1) | ✅ only row 0 | Confirms f4 term | +| SFA[0, 4] = 2.0 | row 0 changes (k=4) | ✅ only row 0 | Confirms f5*4 term | +| SFA[0, 100] = 2.0 | row 0 changes (k=100) | ✅ only row 0 | Confirms tile-overflow range | + +#### Why the previous remap was broken + +The previous code used `cute::get<0>(flat)` and `cute::get<1>(flat)` to extract (m, k). Since flatten produces `(inner_m, sub_m, tile_m, ...)` in order, `get<0>` and `get<1>` are both **M sub-indices** — they carry no K information. This caused only `k_group=0` to work; all other K-groups were silently mapped to the wrong source offset. + +Additionally, the dest buffer must be zero-initialized before remap because CUTLASS pads to tile boundaries (128 × 64), making the dest buffer larger than `M * K_sf`. Unmapped padding slots reading garbage caused sporadic wrong results. + +--- + +## Bugs Found & Fixed + +### 1. unpack_ue4m3_u32: value cast vs bit reinterpret +**File:** `nvfp4_mega_moe.py` +**Bug:** `(x_u32 & 0xFF).to(torch.int32).to(torch.float8_e4m3fn)` converts integer 63 → float8(63.0). +**Fix:** `(x_u32 & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn)` reinterprets bit pattern 0x3F → float8(~0.984). +**Also:** `uint32` lacks CUDA bitwise ops — cast to `int32` first. +**Impact:** Corrupted every activation scale fed to the L1 GEMM. Weight scales were fine (already float8 from weight_transform). "Structured garbage" recipe. + +### 2. stage_activation: three independent bugs +**File:** `nvfp4_moe.py` +**Bug A:** `clamp(0, 15)` zeroed every negative value. E2M1 is sign-magnitude 4-bit (bit3=sign, bits2:0=mag). +**Bug B:** Stored `block_max` but divided by `block_max/6.0` → stored scale was 6× too large. +**Bug C:** Uniform 0.5 step doesn't match E2M1 values {0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6} — non-uniform above ±2. +**Fix:** Rewrote with proper nearest-neighbor E2M1 quantization. +**Impact:** Half the L1→L2 activation was zeroed, 6× scale mismatch, quantization noise on top. + +### 3. _fold_global_scale: logical_widths branch +**File:** `weight_transform.py` +**Bug:** `logical_widths=[3072, 3072]` caused the function to apply expert 0's scale to gate half and expert 1's scale to up half of ALL experts. All other experts' global scales were discarded. +**Fix:** Removed the `logical_widths` branch entirely. The `else` branch correctly broadcasts each expert's own `(E, 1)` global scale across `(E, N, K//16)`. + +### 4. L1 weight interleave removed +**File:** `weight_transform.py` +**Bug:** `_interleave_l1_weights` assumed gate/up were pre-interleaved in groups of 16 and that the kernel used 2CTA UMMA layout. vLLM uses plain concat `[gate; up]` along the output dim, and our CUTLASS kernel uses `ClusterShape<1, 1, 1>`. +**Fix:** Removed entirely. `l1_weight_out = l1_weight.contiguous()`. + +### 5. SF remap: idx2crd+flatten coordinate extraction +**File:** `cutlass_nvfp4_gemm.cu` +**Bug:** `cute::flatten(coord)` produces 8 sub-indices (flat_rank=8). `get<0>` and `get<1>` are both M sub-indices (inner_m, sub_m), carrying zero K information. Only k_group=0 worked; all other K-groups were silently wrong. +**Fix:** Correct extraction: `m = f0 + f1*32 + f2*128`, `k_sf = f4 + f5*4`. Zero-init dest buffer before remap. +**Diagnostic trail:** Constant-scale test (all SF=1.0) → cosine 1.0 proved FP4 path was correct. Real scales → cosine 0.83 proved SF remap was broken. Single-element probes (SFA[0,0] vs SFA[0,3]) proved only k_group=0 worked. Printf dump of flat coordinates at specific indices revealed flat_rank=8 and the correct extraction formula. + +### Diagnostic: constant-scale test (smoking gun for SF bugs) + +When all scale factors are set to UE4M3(1.0): +- **Cosine = 1.0000, MSE = 0.19** (expected FP4 quantization noise) + +With real (variable) scale factors and the broken remap: +- **Cosine = 0.83** → scales are misaligned, not fundamentally broken + +After the fix with correct coordinate extraction: +- **Cosine = 1.0000, MSE = 0.0** → perfect match with dequantized reference + +--- + ## Build & Deploy (B200) ```bash @@ -162,13 +284,11 @@ The CUTLASS extension builds inside the container during `pip install` of the nv ## Known Issues -1. **Shared experts bypassed** — FlashInfer/DeepGEMM TF32 GEMM crashes the vLLM worker. Currently returning zeros for shared expert output. This produces garbage text. +1. **MoE dispatch is slow** — `cutlass_grouped_nvfp4_gemm` uses a Python loop over 48 experts with per-token scatter/gather. Needs a proper grouped GEMM or at least CUDA-side dispatch. -2. **MoE dispatch is slow** — `cutlass_grouped_nvfp4_gemm` uses a Python loop over 48 experts with per-token scatter/gather. Needs a proper grouped GEMM or at least CUDA-side dispatch. +2. **stage_activation is Python** — Re-quantization from L1 BF16 output to FP4 for L2 input runs in PyTorch. Should use the Triton staging kernel for speed and consistency with vLLM's built-in staging. -3. **stage_activation is approximate** — Simple per-token absmax quantization for L1→L2 re-quant. Should use proper E2M1 quantization matching vLLM's staging kernel. - -4. **Scale factor remap adds overhead** — GPU kernel remaps row-major → CUTLASS interleaved layout every GEMM call. Should pre-compute during weight transform. +3. **SF remap allocates every call** — `cudaMemset` + remap kernel runs per GEMM invocation. Could pre-compute the CUTLASS-layout buffer once during weight transform. --- @@ -178,5 +298,12 @@ The CUTLASS extension builds inside the container during `pip install` of the nv |----------|---------|-------------| | `MEGA_MOE_STATIC` | 0 | Set to 1 to skip MoE kernel entirely (return zeros) | | `MEGA_MOE_DEBUG` | 0 | Set to 1 for verbose logging | -| `MEGA_MOE_USE_CUTLASS` | 1 | Use CUTLASS path (always 1 now, TileLang removed) | | `SKIP_ATTENTION` | 0 | Skip attention layers (debug) | + +--- + +## Repos + +- **Kernel:** `sweetapi.com/biondizzle/nvfp4-megamoe-kernel` (branch: master) +- **Deployment:** `sweetapi.com/biondizzle/deepseek-v4-quant` (branch: modelopt-nvfp4) +- **Local:** `~/dev/nvfp4-mojo/`, `~/dev/deepseek-v4-quant/`