biondizzle 9908fd64d9 feat: CUTLASS NVFP4 mega_moe kernel — slot-based L1/L2, source-first SF remap
Major changes from initial TileLang prototype:

Kernel:
- CUTLASS NVFP4 block-scaled GEMM (SM100 Blackwell, OpClassBlockScaledTensorOp)
- Slot-based dispatch: L1 GEMM → SiLU+Mul per-slot → L2 GEMM → index_add scatter
- 1D slot_expert_ids passed to both L1 and L2 (no 2D topk_ids rebuild)
- slot_token gathered in cutlass_grouped_nvfp4_gemm when provided

SF Remap (source-first):
- Iterates logical (m, k_sf) source grid, uses layout_sf(make_coord(m, k_sf))
  for CUTLASS dest index — no idx2crd/flatten coordinate extraction
- 2D kernel launch: dim3 block(32,8), grid over (K_sf, MN)
- Uses cute::cosize() for physical allocation size (not cute::size)
- SFA: (MN, K_sf) row-major; SFB: (K_sf, MN) row-major (col-major)

Weight transform:
- UE4M3 unpack with bit reinterpret (not value cast)
- Global scale folding (weight_scale_2) for gate/up split
- clamp(0,448) → float8_e4m3fn, transpose (N,K)→(K,N) for CUTLASS

No prepack cache:
- SFB remapped per-call inside CUTLASS (~µs, not the bottleneck)
- See README for why prepack cache must never return (OOM, CUDA graphs,
  M-dependent layout, cross-layer collisions)

Stage activation:
- Nearest-neighbor E2M1 quantization (no clamp, no uniform steps)
- Per-tensor global scale → alpha for L2 GEMM

Bug fixes:
- _fold_global_scale: removed broken logical_widths branch
- unpack_ue4m3_u32: int32 for CUDA bitwise, view not to, ND support
- Correct expert param mapping for NVFP4 checkpoint
- SiLU applied per-slot (not after summing expert paths)
2026-05-15 11:38:18 +00:00

# nvfp4-megamoe-kernel
Native NVFP4 block-scaled MoE kernel for DeepSeek-V4-Pro on NVIDIA Blackwell (SM100).
Replaces the broken `fp8_nvfp4_mega_moe` kernel from DeepGEMM with a working CUTLASS-based implementation that emits real `SM100_MMA_MXF4_SS` tensor core instructions.

---
## Architecture
DeepSeek-V4-Pro is a 384-expert MoE model with expert parallelism across 8 ranks (B200 GPUs). Each rank handles 48 experts. For each token, the router picks the top-6 experts.
### The MoE Forward Pass
```
Input hidden states (BF16)
┌──────────────────┐
│ Shared Experts   │ ← vLLM native FlashInfer CUTLASS NVFP4 path
│ (gate + up →     │   (not our kernel)
│  SiLU * up →     │
│  down)           │
└──────────────────┘
Staging Kernel (vLLM built-in)
  BF16 → packed E2M1 (int8) + UE4M3 block-16 scales (uint32)
  Writes to SymmBuffer.x / SymmBuffer.x_sf
Router (vLLM built-in)
  Writes topk_ids / topk_weights to SymmBuffer
┌─────────────────────────────────────────────────┐
│ nvfp4_mega_moe_full                             │ ← nvfp4_mega_moe.py
│                                                 │
│ 1. Read staged activation from buffer           │
│ 2. Build slot mapping (token, topk) → local     │
│    expert, routing weight                       │
│ 3. L1 GEMM: gate_up_proj (slot-based)           │ ← CUTLASS NVFP4 block-scaled
│    E2M1 × E2M1 + UE4M3 scales                   │   SM100_MMA_MXF4_SS PTX
│    → BF16 per-slot output (6144-wide)           │
│ 4. SiLU(gate) * up PER SLOT                     │
│    (nonlinearity before combining paths)        │
│ 5. stage_activation: BF16 → FP4                 │ ← proper E2M1 quantization
│ 6. L2 GEMM: down_proj (slot-based)              │ ← CUTLASS NVFP4 block-scaled
│    E2M1 × E2M1 + UE4M3 scales                   │   SM100_MMA_MXF4_SS PTX
│    → BF16 per-slot output (7168-wide)           │
│ 7. Final scatter:                               │
│    y.index_add_(slot_token,                     │
│                 slot_weight * l2_slots)         │
│    Routing weight applied ONCE at scatter       │
└─────────────────────────────────────────────────┘
Cross-rank all-reduce (vLLM built-in)
```
### Slot-Based Dispatch
The kernel uses a **slot representation** instead of collapsing expert outputs early. A slot is one `(token, topk_expert)` pair. For a batch of T tokens with top-6 routing, there are up to 6T slots (fewer if some experts are out of the local rank's range).
**Why slots?** Two bugs in the previous approach:
1. **SiLU after summing is mathematically wrong.** `silu(Σ wᵢ·gateᵢ) * (Σ wᵢ·upᵢ) ≠ Σ wᵢ·silu(gateᵢ)·upᵢ`. The nonlinearity must happen per-expert-path before combining.
2. **Routing weights applied twice.** The old grouped GEMM applied `topk_weights` in its scatter loop, and was called for both L1 and L2 — squaring the weights.
The slot approach fixes both: SiLU+Mul happens per-slot, and routing weights are applied exactly once at the final `index_add_` scatter.
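
To make the ordering concrete, here is a toy pure-Python sketch that assumes scalar "experts" (each gate, up, and down projection is a single multiplier) in place of the real GEMMs. `build_slots`, `moe_slots`, and `moe_collapsed` are hypothetical names for illustration, not functions in this repo. `moe_slots` follows the slot order (SiLU+Mul per slot, routing weight applied once in the final scatter); `moe_collapsed` reproduces the old order (weighted sums before SiLU, weights applied again after L2).

```python
import math

def silu(v: float) -> float:
    return v / (1.0 + math.exp(-v))

def build_slots(topk_ids, topk_weights, expert_start, experts_per_rank):
    """Flatten (token, topk) pairs into 1D slot arrays, keeping only experts on this rank."""
    slot_token, slot_expert, slot_weight = [], [], []
    for t, (ids, ws) in enumerate(zip(topk_ids, topk_weights)):
        for e, w in zip(ids, ws):
            local = e - expert_start
            if 0 <= local < experts_per_rank:  # expert owned by this rank
                slot_token.append(t)
                slot_expert.append(local)
                slot_weight.append(w)
    return slot_token, slot_expert, slot_weight

def moe_slots(x, topk_ids, topk_weights, gate_w, up_w, down_w, expert_start=0):
    """Slot order: L1 and SiLU*up per slot, L2 per slot, routing weight applied once."""
    st, se, sw = build_slots(topk_ids, topk_weights, expert_start, len(gate_w))
    h = [silu(gate_w[e] * x[t]) * (up_w[e] * x[t]) for t, e in zip(st, se)]  # per-slot L1 + SiLU*up
    l2 = [down_w[e] * hs for e, hs in zip(se, h)]                             # per-slot L2
    y = [0.0] * len(x)
    for t, w, v in zip(st, sw, l2):  # like y.index_add_(0, slot_token, slot_weight * l2_slots)
        y[t] += w * v
    return y

def moe_collapsed(x, topk_ids, topk_weights, gate_w, up_w, down_w):
    """Old order (wrong): weighted sums before SiLU, then weights applied again after L2.
    Assumes every routed expert is local (no rank filtering) to keep the toy short."""
    y = []
    for t, (ids, ws) in enumerate(zip(topk_ids, topk_weights)):
        g = sum(w * gate_w[e] * x[t] for e, w in zip(ids, ws))
        u = sum(w * up_w[e] * x[t] for e, w in zip(ids, ws))
        h = silu(g) * u
        y.append(sum(w * down_w[e] * h for e, w in zip(ids, ws)))
    return y

args = ([1.0], [[0, 1]], [[0.5, 0.5]], [1.0, -1.0], [1.0, 1.0], [1.0, 1.0])
print(moe_slots(*args), moe_collapsed(*args))
```

With one token routed 50/50 to a gate=+1 expert and a gate=-1 expert, the slot path gives about 0.231, while the collapsed path gives exactly 0.0, because the weighted gate sum is already zero before SiLU sees it.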
### SFB (Weight Scale Factors) — Remapped Per-Call, NOT Cached
Weight scale factors (SFB) are remapped from row-major to CUTLASS interleaved layout on every GEMM call. This is a lightweight scatter kernel (~µs) and is NOT the bottleneck compared to the GEMM itself.
⚠️ **DO NOT ADD A PREPACK CACHE FOR SFB.** Previous attempts caused critical issues:
| Problem | Impact |
|---------|--------|
| **OOM** | ~1.75 GiB per prepacked tensor × 61 MoE layers × 2 (L1+L2) = ~214 GiB — exceeds B200 capacity |
| **Peak memory 2×** | `torch.stack` held all expert tensors + final stack simultaneously before LRU eviction |
| **CUDA graph trap** | LRU eviction frees tensors that CUDA graphs still reference → use-after-free → silent corruption or crash |
| **M-dependent layout** | `prepack_sfb(M=128)` assumed SFB layout size is M-independent (never verified). If wrong, entire prepack is invalid |
| **Cross-layer cache collision** | Tag-based cache (`"l1"`/`"l2"`) returned layer N-1's data for layer N. Fixed with data_ptr key, but the cache itself was the root problem |
The per-call remap costs microseconds; the cache cost hours of debugging. Don't repeat this mistake.

---
### vLLM Startup Sequence (how our code plugs in)
```
1. vLLM engine init
└─ ModelOptNvFp4Config selected (NVFP4 quantization scheme)
└─ FlashInferCutlassNvFp4LinearKernel for linear layers
2. Model construction
   └─ DeepseekV4ForCausalLM → DeepseekV4DecoderLayer → DeepseekV4MoE
Each layer has: attention + MoE block
MoE block has: shared experts + 384 routed experts
3. Weight loading
└─ 95 safetensor shards loaded
└─ weight, weight_scale, weight_scale_2 loaded per linear
4. process_weights_after_loading ← THIS IS WHERE WE HOOK IN
└─ ModelOptNvFp4LinearMethod swizzles/pads weights for CUTLASS
└─ finalize_mega_moe_weights()
└─ weight_transform.py: transform_nvfp4_weights_for_mega_moe()
• Folds weight_scale_2 (global scale) into weight_scale (block scale)
      • UE4M3 block-16 scales returned as float8_e4m3fn (not packed uint32)
• Returns ((l1_w, l1_sf), (l2_w, l2_sf)) per rank
5. SymmBuffer allocation
└─ symm_buffer.py: get_symm_buffer_for_nvfp4_mega_moe()
• Pre-allocates GPU buffers for:
- x: int8 packed E2M1 activations
- x_sf: uint32 packed UE4M3 activation scales
- topk_idx: int32 expert indices
- topk_weights: float32 routing weights
- buffer: BF16 all-reduce buffer
6. Profile run (warmup)
└─ First forward pass to allocate KV cache, etc.
└─ This is where the CUTLASS GEMM first executes
└─ SFB weight scales remapped per-expert inside CUTLASS (no cache)
7. Ready to serve
```
---
## File Map
```
nvfp4_megamoe_kernel/
├── __init__.py # Public API exports
├── nvfp4_mega_moe.py # Main kernel: nvfp4_mega_moe_full, L1/L2, stage_activation
├── weight_transform.py # Weight prep: fold global scale, pack UE4M3
├── symm_buffer.py # GPU buffer allocation for MoE dispatch
└── cutlass_nvfp4_gemm/ # CUTLASS CUDA extension (the actual hardware kernel)
├── cutlass_nvfp4_gemm.cu # CUDA: CUTLASS GEMM + SF remap + prepack SFB + prepacked-SFB GEMM path
├── pytorch_binding.cpp # PyTorch C++ binding (forward, forward_prepacked_sfb, prepack_sfb)
├── kernel.py # Python: cutlass_grouped_nvfp4_gemm (slot-based, per-expert loop)
├── sf_layout.py # CUTLASS SF layout reference docs
├── setup.py # Build config (nvcc, CUTLASS include paths)
├── build.sh # Build script
├── test_gemm.py # Standalone test
└── README.md
```
### What each file does (in call order)
| File | When it runs | What it does |
|------|-------------|--------------|
| `weight_transform.py` | Once at startup (weight loading) | Takes raw NVFP4 checkpoint weights, folds global scales into block scales. Returns scales as `float8_e4m3fn` (not packed uint32). Output: `((l1_w, l1_sf), (l2_w, l2_sf))` |
| `symm_buffer.py` | Once at startup (buffer alloc) | Pre-allocates GPU tensors for activations, scales, routing data, and all-reduce. These persist across forward passes. |
| `nvfp4_mega_moe.py` | Every forward pass | Orchestrates the MoE: reads from symm buffer → build slot mapping → L1 GEMM → SiLU+Mul per-slot → re-quantize → L2 GEMM → final index_add_ scatter with routing weights. Contains `stage_activation` (BF16→FP4) and `unpack_ue4m3_u32`. NO prepack cache — SFB remapped per-call inside CUTLASS. |
| `cutlass_nvfp4_gemm/kernel.py` | Every forward pass (called by nvfp4_mega_moe) | Slot-based per-expert loop: gather slots for each expert, call CUTLASS GEMM (SFB remapped inside C extension), write results to slot buffer. No routing weights — caller handles scatter. |
| `cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu` | Every forward pass (CUDA kernel) | The actual CUTLASS kernel: native NVFP4 block-scaled GEMM + GPU-side SFA and SFB remap. |
| `cutlass_nvfp4_gemm/sf_layout.py` | Reference only | Documents the CUTLASS SfAtom layout. Not used at runtime (remap is in CUDA). |
---
## Data Formats
### Weights
- **Packed E2M1** (`int8`): 2 FP4 values per byte. Shape: `(E_per_rank, N, K//2)`, K-major layout.
- **UE4M3 block scales** (`float8_e4m3fn`): 1 scale per 16 FP4 values (group_size=16). Shape: `(E_per_rank, N, K//16)`. Returned as `float8_e4m3fn` from `weight_transform.py` — NOT packed uint32. The CUTLASS GEMM consumes float8 directly.
### Activations (after staging kernel)
- **Packed E2M1** (`int8`): Shape: `(num_tokens, K//2)`.
- **UE4M3 scales** (`uint32`): 4 UE4M3 values packed per uint32. Shape: `(num_tokens, K//64)`. Unpacked to `float8_e4m3fn` via `unpack_ue4m3_u32` before reaching the CUTLASS GEMM.
### GEMM dimensions (DeepSeek-V4-Pro)
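- **L1 (gate_up_proj):** M × 6144 × 7168 (M × N × K per expert, where M is the number of slots routed to that expert; N = 2 × 3072 for the concatenated gate and up halves)
- **L2 (down_proj):** M × 7168 × 3072 (M × N × K per expert)
- 48 experts per rank (384 total / 8 ranks), top-6 routing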
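
The storage shapes above follow directly from K: 2 FP4 values per byte, one scale per 16 values, and four scales per uint32. The pure-Python sketch below spells this out; the helper names are hypothetical, and the nibble and byte order (first value in the low nibble, first scale in the lowest byte) are assumptions this README does not specify. For the L1 input (K = 7168), `nvfp4_shapes(1, 7168)` gives `(1, 3584)` packed bytes, `(1, 448)` float8 scales, and `(1, 112)` uint32 words, matching `K_sf = 448` in the coordinate dump in the remap section.

```python
# E2M1 magnitudes indexed by the 3 low bits of a 4-bit code; bit 3 is the sign.
E2M1_MAG = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def decode_e2m1(code: int) -> float:
    mag = E2M1_MAG[code & 0x7]
    return -mag if code & 0x8 else mag

def pack_e2m1_pair(first: int, second: int) -> int:
    """Two 4-bit E2M1 codes -> one byte (assumed order: first value in the low nibble)."""
    return (first & 0xF) | ((second & 0xF) << 4)

def pack_scales_u32(b0: int, b1: int, b2: int, b3: int) -> int:
    """Four UE4M3 bytes -> one uint32 (assumed order: first scale in the lowest byte)."""
    return b0 | (b1 << 8) | (b2 << 16) | (b3 << 24)

def nvfp4_shapes(rows: int, K: int) -> dict:
    """Storage shapes for `rows` rows of K FP4 values (K must be a multiple of 64)."""
    assert K % 64 == 0, "K must hold whole uint32 scale words"
    return {
        "packed_e2m1": (rows, K // 2),  # int8: 2 FP4 values per byte
        "scales_fp8": (rows, K // 16),  # float8_e4m3fn: 1 scale per 16 values
        "scales_u32": (rows, K // 64),  # uint32: 4 UE4M3 scales per word (activations)
    }

print(nvfp4_shapes(1, 7168))  # L1 input, K = 7168
print(nvfp4_shapes(1, 3072))  # L2 input, K = 3072
```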
---
## CUTLASS Scale Factor Remap
CUTLASS's `Sm1xxBlockScaledConfig` expects scale factors in a specific interleaved layout, not simple row-major. The SfAtom is:
```
Atom Shape: Shape<Shape<32, 4>, Shape<16, 4>>
Atom Stride: Stride<Stride<16, 4>, Stride<0, 1>>
Tiling: Step<_2, _1> (tile order: K tiles vary fastest, then M tiles)
```
Our source data is row-major `(M, K_sf)` where `K_sf = K / 16`. The remap kernel (`remap_sf_to_cutlass_kernel` in `cutlass_nvfp4_gemm.cu`) converts from row-major to CUTLASS's interleaved layout.
### How the remap works
The kernel iterates over CUTLASS destination indices, uses `cute::idx2crd` to get the hierarchical coordinate, then `cute::flatten` to get a flat tuple of 8 sub-indices. From those, we extract logical `(m, k_sf)` and read from the row-major source.
### Flattened coordinate decomposition (flat_rank=8)
From the SfAtom layout with Step<_2, _1> tiling, `flatten(idx2crd(idx, ...))` produces 8 values:
```
f0 = inner_m (0..31) — varies fastest within M atom
f1 = sub_m (0..3) — second M sub-coordinate
f2 = tile_m (0..) — M tile index
f3 = step_m stride — degenerate (always = sfa_size, not a coordinate)
f4 = sub_k (0..3) — K sub-coordinate within atom
f5 = tile_k (0..) — K tile index
f6 = 0 — unused
f7 = 0 — unused
```
#### Empirical coordinate dump (MN=8192, K_sf=448, T = sfa_size = 58720256)
| idx | f0 | f1 | f2 | f3 | f4 | f5 | f6 | f7 |
| ----- | --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | 0 | 0 | 0 | T | 0 | 0 | 0 | 0 |
| 1 | 0 | 0 | 0 | T | 1 | 0 | 0 | 0 |
| 4 | 0 | 1 | 0 | T | 0 | 0 | 0 | 0 |
| 16 | 1 | 0 | 0 | T | 0 | 0 | 0 | 0 |
| 511 | 31 | 3 | 0 | T | 3 | 0 | 0 | 0 |
| 512 | 0 | 0 | 0 | T | 0 | 1 | 0 | 0 |
| 1024 | 0 | 0 | 0 | T | 0 | 2 | 0 | 0 |
| 2048 | 0 | 0 | 0 | T | 0 | 4 | 0 | 0 |
| 4096 | 0 | 0 | 0 | T | 0 | 8 | 0 | 0 |
| 8192 | 0 | 0 | 0 | T | 0 | 16 | 0 | 0 |
| 65536 | 0 | 0 | 1 | T | 0 | 16 | 0 | 0 |
| 131072 | 0 | 0 | 2 | T | 0 | 32 | 0 | 0 |
#### Extraction formula
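CuTe orders coordinates so that the first sub-mode varies fastest, which for `Shape<32, 4>` gives:
```cpp
// f0..f5 are the flattened sub-coordinates from the table above (f3, f6, f7 unused).
int m    = f0 + f1 * 32 + f2 * 128;  // inner_m + 32 * sub_m + 128 * tile_m
int k_sf = f4 + f5 * 4;              // sub_k + 4 * tile_k
```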
This was verified with 6 independent probes:
| Probe | Expected change | Observed | Conclusion |
|-------|--------|----------|--------|
| SFA[1, 0] = 2.0 | row 1 changes | ✅ only row 1 | Confirms f0 term |
| SFA[32, 0] = 2.0 | row 32 changes | ✅ only row 32 | Confirms f1*32, rules out f0*4+f1 |
| SFA[128, 0] = 2.0 | row 128 changes | ✅ only row 128 | Confirms f2*128 |
| SFA[0, 1] = 2.0 | row 0 changes (k=1) | ✅ only row 0 | Confirms f4 term |
| SFA[0, 4] = 2.0 | row 0 changes (k=4) | ✅ only row 0 | Confirms f5*4 term |
| SFA[0, 100] = 2.0 | row 0 changes (k=100) | ✅ only row 0 | Confirms tile-overflow range |
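
Reading the extraction formula in the opposite direction gives the destination index for a logical `(m, k_sf)`, which is all a source-first remap needs: walk the row-major source grid and scatter each scale into a zero-initialized, tile-padded buffer. The pure-Python sketch below is a reference model, not the CUDA kernel; `sf_dest_offset` and `remap_sf` are hypothetical names. The in-atom strides come from the SfAtom above, and the tile order (all K tiles of one M tile before the next M tile) is inferred from the dump rows at idx 65536 and 131072: `65536 / 512 = 128 = 1 × 112 + 16`, with 112 = 448 / 4 K tiles.

```python
ATOM_M, ATOM_K_SF, ATOM_SIZE = 128, 4, 512  # one SfAtom: 128 rows x 4 scale factors

def sf_dest_offset(m: int, k_sf: int, num_k_tiles: int) -> int:
    """Destination index for logical (m, k_sf).

    M mode: Shape<32,4>, Stride<16,4>. K mode: Shape<16,4>, Stride<0,1>, so the 16 FP4
    elements sharing one scale collapse to a single k_sf. Assumes K tiles vary fastest.
    """
    inner_m, sub_m, tile_m = m % 32, (m // 32) % 4, m // ATOM_M
    sub_k, tile_k = k_sf % ATOM_K_SF, k_sf // ATOM_K_SF
    in_atom = inner_m * 16 + sub_m * 4 + sub_k
    return in_atom + ATOM_SIZE * (tile_k + num_k_tiles * tile_m)

def ceil_div(a: int, b: int) -> int:
    return -(-a // b)

def remap_sf(src, MN: int, K_sf: int):
    """Source-first remap: iterate the row-major (MN, K_sf) source, scatter to the dest."""
    num_m_tiles, num_k_tiles = ceil_div(MN, ATOM_M), ceil_div(K_sf, ATOM_K_SF)
    dest = [0] * (num_m_tiles * num_k_tiles * ATOM_SIZE)  # zero-init: padding stays 0
    for m in range(MN):
        for k in range(K_sf):
            dest[sf_dest_offset(m, k, num_k_tiles)] = src[m * K_sf + k]
    return dest

# Spot-check against the empirical dump (MN=8192, K_sf=448 -> 112 K tiles).
print(sf_dest_offset(1, 0, 112), sf_dest_offset(32, 0, 112), sf_dest_offset(128, 64, 112))
```

The destination length counts whole 128-row × 4-scale tiles, so it exceeds `MN * K_sf` whenever MN or K_sf is not a tile multiple; that padding is why the buffer has to start zeroed.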
#### Why the previous remap was broken
The previous code used `cute::get<0>(flat)` and `cute::get<1>(flat)` to extract (m, k). Since flatten produces `(inner_m, sub_m, tile_m, ...)` in order, `get<0>` and `get<1>` are both **M sub-indices** — they carry no K information. This caused only `k_group=0` to work; all other K-groups were silently mapped to the wrong source offset.
Additionally, the destination buffer must be zero-initialized before the remap. CUTLASS pads the SF layout to whole tiles of 128 rows × 64 K elements (4 scale factors per row), so the destination is larger than `M * K_sf`, and unmapped padding slots holding garbage caused sporadic wrong results.

---
## Bugs Found & Fixed
### 1. unpack_ue4m3_u32: value cast vs bit reinterpret
**File:** `nvfp4_mega_moe.py`
**Bug:** `(x_u32 & 0xFF).to(torch.int32).to(torch.float8_e4m3fn)` is a value cast: the integer 63 becomes the float 63.0 (which E4M3 rounds to 64.0).
**Fix:** `(x_u32 & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn)` reinterprets the bit pattern, so 0x3F decodes as E4M3 `0 0111 111` = 1.875.
**Also:** `uint32` lacks CUDA bitwise ops in PyTorch, so cast to `int32` first.
**Impact:** Corrupted every activation scale fed to the L1 GEMM. Weight scales were fine (already float8 from weight_transform). The result was "structured garbage" output.
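
The difference between a value cast and a bit reinterpretation is easiest to see by decoding the bits by hand. This pure-Python sketch uses hypothetical helpers (not the repo's `unpack_ue4m3_u32`): it decodes `float8_e4m3fn` (1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits, no infinities) and splits a uint32 into four scales, assuming the first scale sits in the lowest byte.

```python
def decode_e4m3fn(byte: int) -> float:
    """Decode a float8_e4m3fn bit pattern: 1 sign, 4 exponent (bias 7), 3 mantissa bits."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    mant = byte & 0x7
    if exp == 0xF and mant == 0x7:  # 0x7F / 0xFF are NaN; e4m3fn has no infinities
        return float("nan")
    if exp == 0:                    # subnormal: (mant / 8) * 2^-6
        return sign * (mant / 8.0) * 2.0 ** -6
    return sign * (1.0 + mant / 8.0) * 2.0 ** (exp - 7)

def unpack_ue4m3_u32_ref(word: int) -> list:
    """Reinterpret each byte of a uint32 as one scale (assumed: lowest byte first)."""
    return [decode_e4m3fn((word >> (8 * i)) & 0xFF) for i in range(4)]

# The buggy value cast read byte 0x3F as the number 63; its bits actually mean 1.875.
print(decode_e4m3fn(0x3F))
print(unpack_ue4m3_u32_ref(0x7E40383F))
```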
### 2. stage_activation: three independent bugs
**File:** `nvfp4_mega_moe.py`
**Bug A:** `clamp(0, 15)` zeroed every negative value. E2M1 is sign-magnitude 4-bit (bit3=sign, bits2:0=mag).
**Bug B:** Stored `block_max` but divided by `block_max/6.0` → stored scale was 6× too large.
**Bug C:** Uniform 0.5 step doesn't match E2M1 values {0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6} — non-uniform above ±2.
**Fix:** Rewrote with proper nearest-neighbor E2M1 quantization.
**Impact:** Half the L1→L2 activation was zeroed, 6× scale mismatch, quantization noise on top.
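
A minimal pure-Python sketch of nearest-neighbor E2M1 block quantization that covers all three fixes: the sign goes in bit 3 instead of being clamped away, the stored scale is the divisor `block_max / 6` itself, and each magnitude snaps to the non-uniform E2M1 grid. It is a simplified reference with hypothetical function names, not `stage_activation` itself: the scale stays a Python float (the real path encodes it as UE4M3 and folds a per-tensor global scale into `alpha`), and ties go to the smaller magnitude, which may differ from the kernel's rounding.

```python
E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # non-uniform above 2

def quantize_block_e2m1(block):
    """Nearest-neighbor E2M1 quantization of one block (16 values in NVFP4)."""
    amax = max(abs(v) for v in block)
    scale = amax / 6.0 if amax > 0 else 1.0  # store this divisor, not amax (Bug B)
    codes = []
    for v in block:
        mag = abs(v) / scale
        # Nearest grid point (Bug C); min() keeps the first, i.e. smaller, on ties.
        idx = min(range(len(E2M1_GRID)), key=lambda i: abs(E2M1_GRID[i] - mag))
        codes.append((0x8 if v < 0 else 0x0) | idx)  # bit 3 = sign, no clamp (Bug A)
    return codes, scale

def dequantize_block_e2m1(codes, scale):
    return [(-1.0 if c & 0x8 else 1.0) * E2M1_GRID[c & 0x7] * scale for c in codes]

codes, scale = quantize_block_e2m1([-6.0, 2.6, 0.2, 5.2])
print(codes, scale, dequantize_block_e2m1(codes, scale))
```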
### 3. _fold_global_scale: logical_widths branch
**File:** `weight_transform.py`
**Bug:** `logical_widths=[3072, 3072]` caused the function to apply expert 0's scale to gate half and expert 1's scale to up half of ALL experts. All other experts' global scales were discarded.
**Fix:** Removed the `logical_widths` branch entirely. The `else` branch correctly broadcasts each expert's own `(E, 1)` global scale across `(E, N, K//16)`.
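
The corrected broadcast is simple: every expert multiplies its block scales by its own global scale. A pure-Python sketch, assuming the fold is a plain product followed by the `clamp(0, 448)` (448 is the largest finite `float8_e4m3fn` value) that precedes the float8 cast; `fold_global_scale` is a hypothetical stand-in for `_fold_global_scale`, operating on nested lists shaped `(E, N, K//16)`.

```python
def fold_global_scale(block_scales, global_scales, fp8_max=448.0):
    """Fold each expert's global scale into its block scales, then clamp to [0, fp8_max].

    block_scales: nested lists shaped (E, N, K//16); global_scales: one float per expert.
    """
    return [
        [[min(max(s * g, 0.0), fp8_max) for s in row] for row in expert]
        for expert, g in zip(block_scales, global_scales)  # expert e uses only its own g
    ]

print(fold_global_scale([[[1.0, 2.0]], [[1.0, 2.0]]], [0.5, 300.0]))
```

In PyTorch the same broadcast would be `(weight_scale.float() * weight_scale_2.view(-1, 1, 1)).clamp(0, 448)`, where `view(-1, 1, 1)` turns the `(E, 1)` global scales into `(E, 1, 1)` so each expert's value spans only its own `(N, K//16)` slice.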
### 4. L1 weight interleave removed (transpose still needed)
**File:** `weight_transform.py`
**Bug:** `_interleave_l1_weights` assumed gate/up were pre-interleaved in groups of 16 and that the kernel used 2CTA UMMA layout. vLLM uses plain concat `[gate; up]` along the output dim, and our CUTLASS kernel uses `ClusterShape<1, 1, 1>`.
**Fix:** Removed the interleave function. Weights still need a transpose from checkpoint layout `(N, K_half)` row-major to CUTLASS layout `(K_half, N)` column-major — this is standard row→column conversion, not interleaving. Both L1 and L2 weights and scales are transposed.
### 5. SF remap: idx2crd+flatten coordinate extraction
**File:** `cutlass_nvfp4_gemm.cu`
**Bug:** `cute::flatten(coord)` produces 8 sub-indices (flat_rank=8). `get<0>` and `get<1>` are both M sub-indices (inner_m, sub_m), carrying zero K information. Only k_group=0 worked; all other K-groups were silently wrong.
**Fix:** Correct extraction: `m = f0 + f1*32 + f2*128`, `k_sf = f4 + f5*4`. Zero-init dest buffer before remap.
**Diagnostic trail:** Constant-scale test (all SF=1.0) → cosine 1.0 proved FP4 path was correct. Real scales → cosine 0.83 proved SF remap was broken. Single-element probes (SFA[0,0] vs SFA[0,3]) proved only k_group=0 worked. Printf dump of flat coordinates at specific indices revealed flat_rank=8 and the correct extraction formula.
### 6. SiLU after summing expert paths (math error)
**File:** `nvfp4_mega_moe.py`
**Bug:** The old grouped GEMM collapsed expert outputs into a weighted sum, then applied SiLU+Mul on the sum. `silu(Σ wᵢ·gateᵢ) * (Σ wᵢ·upᵢ) ≠ Σ wᵢ·silu(gateᵢ)·upᵢ`. The nonlinearity must happen per-expert-path.
**Fix:** Slot-based dispatch — L1 GEMM returns per-slot output, SiLU+Mul applied per-slot, L2 GEMM per-slot, routing weights applied once at final `index_add_` scatter.
### 7. Routing weights applied twice
**File:** `cutlass_nvfp4_gemm/kernel.py`
**Bug:** `cutlass_grouped_nvfp4_gemm` applied `topk_weights` in its scatter loop. Because it was called for both L1 and L2, each expert's contribution was scaled by `topk_weight²`.
**Fix:** GEMM returns per-slot results with no routing weights. Single `y.index_add_(0, slot_token, slot_weight * l2_slots)` at the end.
### Diagnostic: constant-scale test (smoking gun for SF bugs)
When all scale factors are set to UE4M3(1.0):
- **Cosine = 1.0000, MSE = 0.19** (expected FP4 quantization noise)
With real (variable) scale factors and the broken remap:
- **Cosine = 0.83** → scales are misaligned, not fundamentally broken
After the fix with correct coordinate extraction:
- **Cosine = 1.0000, MSE = 0.0** → perfect match with dequantized reference
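
The two metrics can be computed as below (a pure-Python sketch; `cosine_and_mse` is a hypothetical helper, and the reference is assumed to be a GEMM on the dequantized inputs). Cosine similarity ignores a uniform scale factor, so a single global scale error (such as the 6× stage_activation bug) leaves cosine at 1.0 and only shows up in MSE, whereas misaligned per-block scales, like the broken remap, lower the cosine itself.

```python
import math

def cosine_and_mse(out, ref):
    """Cosine similarity and mean squared error between a kernel output and a reference."""
    dot = sum(a * b for a, b in zip(out, ref))
    norm = math.sqrt(sum(a * a for a in out)) * math.sqrt(sum(b * b for b in ref))
    cos = dot / norm if norm > 0 else float("nan")
    mse = sum((a - b) ** 2 for a, b in zip(out, ref)) / len(ref)
    return cos, mse

# A uniform 6x error: direction unchanged (cosine ~1.0), magnitude wrong (large MSE).
print(cosine_and_mse([6.0, 12.0, 18.0], [1.0, 2.0, 3.0]))
```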
---
## Build & Deploy (B200)
```bash
# On B200 host — CUTLASS must be cloned and mounted
cd /root/nvidia-meeting/deepseek-v4-quant/
# Rebuild container (CUTLASS is host-mounted at /root/cutlass)
KERNEL_CACHE_BUSTER=$(date +%s) docker compose build --no-cache
docker compose up -d
```
The CUTLASS extension builds inside the container during `pip install` of the nvfp4-megamoe-kernel package. It needs:
- CUDA 13.0 toolkit (in the vllm/vllm-openai:nightly image)
- CUTLASS headers at `/root/cutlass/include/`
- CCCL headers at `/usr/local/cuda-13.0/targets/x86_64-linux/include/cccl/`
- Device with SM100 compute capability (B200)
---
## Known Issues / TODO
1. ~~**MoE dispatch is slow**~~ — Fixed. Slot-based `index_add_` replaces the Python double loop over tokens×topk. Routing weights applied once at final scatter.
2. **stage_activation is Python** — Re-quantization from L1 BF16 output to FP4 for L2 input runs in PyTorch. Should use the Triton staging kernel for speed and consistency with vLLM's built-in staging.
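3. ~~**SF remap allocates every call**~~: Superseded. An SFB prepack cache (lazy, cached per layer) was tried and then removed; the SFB (Weight Scale Factors) section under Architecture explains why it must not return. Both SFA and SFB are now remapped on every call, at a cost of microseconds.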
4. **Per-expert GEMM dispatch is serial Python loop** — The `cutlass_grouped_nvfp4_gemm` iterates over 48 experts in a Python `for` loop. Each iteration launches one CUTLASS GEMM. Could benefit from a true grouped GEMM kernel or CUDA-side expert dispatch.
---
## Environment Variables
| Variable | Default | Description |
|----------|---------|-------------|
| `MEGA_MOE_STATIC` | 0 | Set to 1 to skip MoE kernel entirely (return zeros) |
| `MEGA_MOE_DEBUG` | 0 | Set to 1 for verbose logging |
| `SKIP_ATTENTION` | 0 | Skip attention layers (debug) |
---
## Repos
- **Kernel:** `sweetapi.com/biondizzle/nvfp4-megamoe-kernel` (branch: master)
- **Deployment:** `sweetapi.com/biondizzle/deepseek-v4-quant` (branch: modelopt-nvfp4)
- **Local:** `~/dev/nvfp4-megamoe-kernel/`, `~/dev/deepseek-v4-quant/`