DeepSeek-V4-Pro has 384 routed experts, 48 per rank (384/8). The cross-rank all-reduce happens in the parent DeepseekV4MoE.forward, not in our kernel. Our kernel writes local output; caller does reduce. Fixed README, nvfp4_mega_moe.py comments.
# nvfp4-megamoe-kernel

Native NVFP4 block-scaled MoE kernel for DeepSeek-V4-Pro on NVIDIA Blackwell (SM100).

Replaces the broken `fp8_nvfp4_mega_moe` kernel from DeepGEMM with a working CUTLASS-based implementation that emits real `SM100_MMA_MXF4_SS` tensor core instructions.

---

## Architecture

DeepSeek-V4-Pro is a 384-expert MoE model with expert parallelism across 8 ranks (B200 GPUs). Each rank handles 48 experts. For each token, the router picks the top-6 experts.
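
A minimal sketch of the per-rank bookkeeping this implies. It assumes a contiguous partition (rank r owns global experts 48r to 48r + 47), which may not match vLLM's actual expert mapping. `owner_rank` and `local_assignments` are hypothetical helpers, not part of this package.

```python
NUM_EXPERTS = 384
NUM_RANKS = 8
EXPERTS_PER_RANK = NUM_EXPERTS // NUM_RANKS  # 48
TOP_K = 6


def owner_rank(global_expert: int) -> int:
    """Rank holding `global_expert`, assuming a contiguous partition."""
    return global_expert // EXPERTS_PER_RANK


def local_assignments(topk_ids, rank):
    """Return (token, slot, local_expert) triples that `rank` must compute.

    topk_ids: one list of TOP_K global expert ids per token.
    """
    out = []
    for t, ids in enumerate(topk_ids):
        assert len(ids) == TOP_K
        for slot, e in enumerate(ids):
            if owner_rank(e) == rank:
                out.append((t, slot, e - rank * EXPERTS_PER_RANK))
    return out


# Two tokens, top-6 each; rank 1 owns global experts 48..95.
topk_ids = [[0, 47, 48, 95, 200, 383], [50, 51, 52, 300, 10, 96]]
print(local_assignments(topk_ids, rank=1))
# -> [(0, 2, 0), (0, 3, 47), (1, 0, 2), (1, 1, 3), (1, 2, 4)]
```

Each rank computes only the (token, slot) pairs it owns; summing the per-rank partial outputs is the caller's cross-rank all-reduce.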

### The MoE Forward Pass

```
Input hidden states (BF16)
        │
        ▼
┌─────────────────┐
│ Shared Experts  │  ← vLLM native FlashInfer CUTLASS NVFP4 path
│ (gate + up →    │    (not our kernel)
│  SiLU(gate) *   │
│  up → down)     │
└─────────────────┘
        │
        ▼
Staging Kernel (vLLM built-in)
  BF16 → packed E2M1 (int8) + UE4M3 block-16 scales (uint32)
  Writes to SymmBuffer.x / SymmBuffer.x_sf
        │
        ▼
Router (vLLM built-in)
  Writes topk_ids / topk_weights to SymmBuffer
        │
        ▼
┌─────────────────────────────────────────┐
│ nvfp4_mega_moe_full                     │  ← nvfp4_mega_moe.py
│                                         │
│ 1. Read staged activation from buffer   │
│ 2. L1 GEMM: gate_up_proj                │  ← CUTLASS NVFP4 block-scaled
│    E2M1 × E2M1 + UE4M3 scales           │    SM100_MMA_MXF4_SS PTX
│    → BF16 output (6144-wide)            │
│ 3. SiLU(gate) * up (activation)         │
│ 4. stage_activation: BF16 → FP4         │  ← proper E2M1 quantization
│ 5. L2 GEMM: down_proj                   │  ← CUTLASS NVFP4 block-scaled
│    E2M1 × E2M1 + UE4M3 scales           │    SM100_MMA_MXF4_SS PTX
│    → BF16 output (7168-wide)            │
│ 6. Write to output tensor               │  ← caller does the cross-rank all-reduce
│                                         │    (in DeepseekV4MoE.forward)
└─────────────────────────────────────────┘
```
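
For checking numerics it helps to have an unquantized reference for what steps 2 to 5 compute on one rank. This pure-Python sketch skips FP4 entirely, assumes the plain `[gate; up]` L1 weight layout (see bug 4 below), and uses hypothetical helper names.

```python
import math


def silu(x: float) -> float:
    return x / (1.0 + math.exp(-x))


def matvec(w, x):
    """w: list of N rows of length K; x: length-K vector -> length-N vector."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def expert_ffn(x, w1, w2):
    """One routed expert, unquantized.

    w1: (2*I, H) rows = [gate; up] concatenated along the output dim.
    w2: (H, I) down_proj.
    """
    h = matvec(w1, x)                              # L1: gate_up_proj -> 2*I values
    inter = len(h) // 2
    gate, up = h[:inter], h[inter:]                # plain concat, no interleave
    act = [silu(g) * u for g, u in zip(gate, up)]  # SiLU(gate) * up
    return matvec(w2, act)                         # L2: down_proj -> H values


def routed_moe_local(x, local_experts, weights):
    """Routing-weighted sum over this rank's experts for one token.

    local_experts: list of (local_expert_id, routing_weight).
    weights: dict local_expert_id -> (w1, w2).
    The cross-rank all-reduce is left to the caller.
    """
    out = [0.0] * len(x)
    for e, rw in local_experts:
        y = expert_ffn(x, *weights[e])
        out = [o + rw * yi for o, yi in zip(out, y)]
    return out
```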

### vLLM Startup Sequence (how our code plugs in)

```
1. vLLM engine init
   └─ ModelOptNvFp4Config selected (NVFP4 quantization scheme)
      └─ FlashInferCutlassNvFp4LinearKernel for linear layers

2. Model construction
   └─ DeepseekV4ForCausalLM → DeepseekV4DecoderLayer → DeepseekV4MoE
      Each layer has: attention + MoE block
      MoE block has: shared experts + 384 routed experts (48 per rank)

3. Weight loading
   └─ 95 safetensors shards loaded
   └─ weight, weight_scale, weight_scale_2 loaded per linear

4. process_weights_after_loading   ← THIS IS WHERE WE HOOK IN
   └─ ModelOptNvFp4LinearMethod swizzles/pads weights for CUTLASS
   └─ finalize_mega_moe_weights()
      └─ weight_transform.py: transform_nvfp4_weights_for_mega_moe()
         • Folds weight_scale_2 (global scale) into weight_scale (block scale)
         • Keeps UE4M3 block-16 scales as float8_e4m3fn (not packed uint32)
         • Returns ((l1_w, l1_sf), (l2_w, l2_sf)) per rank

5. SymmBuffer allocation
   └─ symm_buffer.py: get_symm_buffer_for_nvfp4_mega_moe()
      • Pre-allocates GPU buffers for:
        - x: int8 packed E2M1 activations
        - x_sf: uint32 packed UE4M3 activation scales
        - topk_idx: int32 expert indices
        - topk_weights: float32 routing weights
        - buffer: BF16 all-reduce buffer

6. Profile run (warmup)
   └─ First forward pass; vLLM measures peak memory here to size the KV cache
   └─ This is where the CUTLASS GEMM first executes

7. Ready to serve
```

---

## File Map

```
nvfp4_megamoe_kernel/
├── __init__.py              # Public API exports
├── nvfp4_mega_moe.py        # Main kernel: nvfp4_mega_moe_full, nvfp4_mega_moe_l1/l2, stage_activation
├── weight_transform.py      # Weight prep: fold global scale into UE4M3 (float8) block scales
├── symm_buffer.py           # GPU buffer allocation for MoE dispatch
│
└── cutlass_nvfp4_gemm/      # CUTLASS CUDA extension (the actual hardware kernel)
    ├── cutlass_nvfp4_gemm.cu    # CUDA: CUTLASS GEMM + SF remap kernel
    ├── pytorch_binding.cpp      # PyTorch C++ binding (_C.forward)
    ├── kernel.py                # Python: cutlass_grouped_nvfp4_gemm (per-expert loop)
    ├── sf_layout.py             # CUTLASS SF layout reference docs
    ├── setup.py                 # Build config (nvcc, CUTLASS include paths)
    ├── build.sh                 # Build script
    ├── test_gemm.py             # Standalone test
    └── README.md
```

### What each file does (in call order)

| File | When it runs | What it does |
|------|--------------|--------------|
| `weight_transform.py` | Once at startup (weight loading) | Takes raw NVFP4 checkpoint weights and folds global scales into block scales. Returns scales as `float8_e4m3fn` (not packed uint32). Output: `((l1_w, l1_sf), (l2_w, l2_sf))` |
| `symm_buffer.py` | Once at startup (buffer alloc) | Pre-allocates GPU tensors for activations, scales, routing data, and all-reduce. These persist across forward passes. |
| `nvfp4_mega_moe.py` | Every forward pass | Orchestrates the MoE: reads from symm buffer → L1 GEMM → activation → re-quantize → L2 GEMM → output. Contains `stage_activation` (BF16→FP4 quantize for L1→L2) and `unpack_ue4m3_u32` (uint32 packed scales → float8). |
| `cutlass_nvfp4_gemm/kernel.py` | Every forward pass (called by nvfp4_mega_moe) | Per-expert loop: gathers tokens for each expert, calls the CUTLASS GEMM, scatters results with routing weights. |
| `cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu` | Every forward pass (CUDA kernel) | The actual CUTLASS kernel: native NVFP4 block-scaled GEMM + GPU-side scale factor remap (row-major → CUTLASS interleaved layout). |
| `cutlass_nvfp4_gemm/sf_layout.py` | Reference only | Documents the CUTLASS SfAtom layout. Not used at runtime (the remap is in CUDA). |
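
The per-expert loop in `kernel.py` follows a gather, GEMM, scatter-add pattern. The sketch below shows that pattern in plain Python; `gemm(expert, rows)` is a hypothetical stand-in for the CUTLASS call and does not match the real `_C.forward` signature.

```python
from collections import defaultdict


def grouped_gemm_loop(x, local_assignments, routing_weights, gemm, out_dim):
    """Gather -> one GEMM per expert -> routing-weighted scatter-add.

    x: list of token rows.
    local_assignments: (token, slot, local_expert) triples owned by this rank.
    routing_weights: per-token list of top-k weights, indexed by slot.
    gemm(expert, rows) -> list of output rows (stand-in for the CUTLASS call).
    """
    by_expert = defaultdict(list)
    for t, slot, e in local_assignments:
        by_expert[e].append((t, slot))

    out = [[0.0] * out_dim for _ in x]
    for e, pairs in sorted(by_expert.items()):
        rows = [x[t] for t, _ in pairs]         # gather this expert's tokens
        ys = gemm(e, rows)                       # one GEMM launch per expert
        for (t, slot), y in zip(pairs, ys):      # scatter with routing weight
            w = routing_weights[t][slot]
            out[t] = [o + w * yi for o, yi in zip(out[t], y)]
    return out
```

One GEMM launch per expert, driven from Python, is exactly the overhead listed first under Known Issues.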

---

## Data Formats

### Weights
- **Packed E2M1** (`int8`): 2 FP4 values per byte. Shape: `(E_per_rank, N, K//2)`, K-major layout.
- **UE4M3 block scales** (`float8_e4m3fn`): 1 scale per 16 FP4 values (group_size=16). Shape: `(E_per_rank, N, K//16)`. Returned as `float8_e4m3fn` from `weight_transform.py`, NOT packed uint32; the CUTLASS GEMM consumes float8 directly.
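
To make the packed format concrete, here is a decode sketch for one K-major row. The E2M1 value table is the standard one (also listed under bug 2); the nibble order (element 2i in the low nibble) is an assumption to check against the checkpoint, and the helper names are hypothetical. `int8` storage is treated as raw bytes 0 to 255.

```python
E2M1_MAG = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # magnitude for codes 0..7


def e2m1_decode(code: int) -> float:
    """4-bit sign-magnitude E2M1: bit3 = sign, bits 2:0 = magnitude code."""
    mag = E2M1_MAG[code & 0x7]
    return -mag if code & 0x8 else mag


def pack_e2m1(codes):
    """Pack pairs of 4-bit codes into bytes; element 2i goes in the low nibble (assumption)."""
    assert len(codes) % 2 == 0
    return bytes(
        (codes[i] & 0xF) | ((codes[i + 1] & 0xF) << 4) for i in range(0, len(codes), 2)
    )


def unpack_e2m1(packed):
    codes = []
    for b in packed:
        codes += [b & 0xF, b >> 4]
    return codes


def dequant_row(packed, scales, group_size=16):
    """Dequantize one row of K values: K//2 bytes and K//16 block scales."""
    vals = [e2m1_decode(c) for c in unpack_e2m1(packed)]
    return [v * scales[i // group_size] for i, v in enumerate(vals)]
```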

### Activations (after staging kernel)
- **Packed E2M1** (`int8`): Shape: `(num_tokens, K//2)`.
- **UE4M3 scales** (`uint32`): 4 UE4M3 values packed per uint32. Shape: `(num_tokens, K//64)`. Unpacked to `float8_e4m3fn` via `unpack_ue4m3_u32` before reaching the CUTLASS GEMM.
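
A sketch of the uint32 scale packing, assuming byte j of each word sits in bits [8j, 8j + 8), which is what the `& 0xFF` extraction in bug 1 implies for byte 0. The helpers are hypothetical and operate on raw bit patterns; turning a byte into a float is a separate bit reinterpret (see bug 1).

```python
K = 7168                    # L1 input width (hidden size)
SCALES_PER_ROW = K // 16    # one UE4M3 scale per 16 FP4 values -> 448
WORDS_PER_ROW = K // 64     # four scales per uint32 -> 112


def pack_ue4m3_u32(sf_bytes):
    """Pack raw UE4M3 bit patterns (ints 0..255) four per uint32, byte j in bits [8j, 8j+8)."""
    assert len(sf_bytes) % 4 == 0
    return [
        sf_bytes[i] | (sf_bytes[i + 1] << 8) | (sf_bytes[i + 2] << 16) | (sf_bytes[i + 3] << 24)
        for i in range(0, len(sf_bytes), 4)
    ]


def split_u32_to_sf_bytes(words):
    """Inverse of pack_ue4m3_u32: four raw bit patterns per word (not float values)."""
    return [(w >> (8 * j)) & 0xFF for w in words for j in range(4)]


row = pack_ue4m3_u32([0x38] * SCALES_PER_ROW)  # 0x38 is the E4M3 bit pattern for 1.0
assert len(row) == WORDS_PER_ROW
```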

### GEMM dimensions (DeepSeek-V4-Pro)
Shapes are M × N × K, where M is the number of tokens routed to that expert.
- **L1 (gate_up_proj):** M×6144×7168 (per expert)
- **L2 (down_proj):** M×7168×3072 (per expert)
- 48 experts per rank (384 total / 8 ranks), top-6 routing

---

## CUTLASS Scale Factor Remap

CUTLASS's `Sm1xxBlockScaledConfig` expects scale factors in a specific interleaved layout, not simple row-major. The SfAtom is:

```
Atom Shape:  Shape<Shape<32, 4>, Shape<16, 4>>
Atom Stride: Stride<Stride<16, 4>, Stride<0, 1>>
Tiling:      Step<_2, _1>  (M tiled with step 2, K with step 1: consecutive atoms advance along K)
```

Our source data is row-major `(M, K_sf)` where `K_sf = K / 16`. The remap kernel (`remap_sf_to_cutlass_kernel` in `cutlass_nvfp4_gemm.cu`) converts from row-major to CUTLASS's interleaved layout.

### How the remap works

The kernel iterates over CUTLASS destination indices, uses `cute::idx2crd` to get the hierarchical coordinate, then `cute::flatten` to get a flat tuple of 8 sub-indices. From those, we extract the logical `(m, k_sf)` and read from the row-major source.

### Flattened coordinate decomposition (flat_rank=8)

From the SfAtom layout with Step<_2, _1> tiling, `flatten(idx2crd(idx, ...))` produces 8 values:

```
f0 = inner_m (0..31)   innermost M sub-coordinate within the atom
f1 = sub_m   (0..3)    second M sub-coordinate
f2 = tile_m  (0..)     M tile index
f3 = step_m stride     degenerate (always = sfa_size, not a coordinate)
f4 = sub_k   (0..3)    K sub-coordinate within the atom
f5 = tile_k  (0..)     K tile index
f6 = 0                 unused
f7 = 0                 unused
```

#### Empirical coordinate dump (MN=8192, K_sf=448, T = sfa_size = 58720256)

| idx | f0 | f1 | f2 | f3 | f4 | f5 | f6 | f7 |
| ----- | --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | 0 | 0 | 0 | T | 0 | 0 | 0 | 0 |
| 1 | 0 | 0 | 0 | T | 1 | 0 | 0 | 0 |
| 4 | 0 | 1 | 0 | T | 0 | 0 | 0 | 0 |
| 16 | 1 | 0 | 0 | T | 0 | 0 | 0 | 0 |
| 511 | 31 | 3 | 0 | T | 3 | 0 | 0 | 0 |
| 512 | 0 | 0 | 0 | T | 0 | 1 | 0 | 0 |
| 1024 | 0 | 0 | 0 | T | 0 | 2 | 0 | 0 |
| 2048 | 0 | 0 | 0 | T | 0 | 4 | 0 | 0 |
| 4096 | 0 | 0 | 0 | T | 0 | 8 | 0 | 0 |
| 8192 | 0 | 0 | 0 | T | 0 | 16 | 0 | 0 |
| 65536 | 0 | 0 | 1 | T | 0 | 16 | 0 | 0 |
| 131072 | 0 | 0 | 2 | T | 0 | 32 | 0 | 0 |
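
#### Extraction formula

CuTe uses "first sub-mode varies fastest" (colexicographic order) for `Shape<32, 4>`:

```cpp
// f0..f5: entries of cute::flatten(cute::idx2crd(idx, ...)), as in the dump above
int m    = f0 + f1 * 32 + f2 * 128;  // inner_m + 32 * sub_m + 128 * tile_m
int k_sf = f4 + f5 * 4;              // sub_k + 4 * tile_k
```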
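
The layout can also be modeled off-GPU to sanity-check the formula. This pure-Python sketch is reconstructed from the atom strides and the coordinate dump above, not taken from CUTLASS; it assumes M is padded to 128 rows and K_sf to a multiple of 4, with all K atoms of an M tile laid out before the next M tile. Function names are hypothetical.

```python
def sf_dest_index(m, k_sf, k_sf_padded):
    """Destination offset of scale (m, k_sf) in the modeled CUTLASS SF layout.

    Atom = 512 elements: inner_m stride 16, sub_m stride 4, sub_k stride 1.
    K atoms are contiguous (step 1); M tiles follow all K atoms (step 2).
    """
    inner_m, sub_m, tile_m = m % 32, (m // 32) % 4, m // 128
    sub_k, tile_k = k_sf % 4, k_sf // 4
    k_atoms = k_sf_padded // 4
    return inner_m * 16 + sub_m * 4 + sub_k + 512 * (tile_k + k_atoms * tile_m)


def sf_coords(idx, k_sf_padded):
    """Inverse of sf_dest_index, using the extraction formula."""
    k_atoms = k_sf_padded // 4
    atom, within = divmod(idx, 512)
    f5, f2 = atom % k_atoms, atom // k_atoms     # tile_k, tile_m
    f0, rest = divmod(within, 16)                # inner_m
    f1, f4 = divmod(rest, 4)                     # sub_m, sub_k
    return f0 + f1 * 32 + f2 * 128, f4 + f5 * 4  # (m, k_sf)


def remap_row_major(src, m, k_sf):
    """Row-major (m, k_sf) list -> zero-padded list in the modeled layout."""
    m_pad = -(-m // 128) * 128
    k_pad = -(-k_sf // 4) * 4
    dst = [0] * (m_pad * k_pad)  # zero-init: padding slots stay 0
    for r in range(m):
        for c in range(k_sf):
            dst[sf_dest_index(r, c, k_pad)] = src[r * k_sf + c]
    return dst
```

Iterating source positions (as here) or destination indices (as the CUDA kernel does) gives the same mapping; the zero-initialized `dst` plays the role of the `cudaMemset` for the padding slots.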

This was verified with 6 independent probes, each setting a single source scale to 2.0 and checking which output rows change:

| Probe | Expected effect | Observed | Conclusion |
|-------|-----------------|----------|------------|
| SFA[1, 0] = 2.0 | row 1 changes | ✅ only row 1 | Confirms f0 term |
| SFA[32, 0] = 2.0 | row 32 changes | ✅ only row 32 | Confirms f1*32, rules out f0*4+f1 |
| SFA[128, 0] = 2.0 | row 128 changes | ✅ only row 128 | Confirms f2*128 |
| SFA[0, 1] = 2.0 | row 0 changes (k_sf=1) | ✅ only row 0 | Confirms f4 term |
| SFA[0, 4] = 2.0 | row 0 changes (k_sf=4) | ✅ only row 0 | Confirms f5*4 term |
| SFA[0, 100] = 2.0 | row 0 changes (k_sf=100) | ✅ only row 0 | Confirms large tile_k values (f4=0, f5=25) |

#### Why the previous remap was broken

The previous code used `cute::get<0>(flat)` and `cute::get<1>(flat)` to extract (m, k). Since flatten produces `(inner_m, sub_m, tile_m, ...)` in order, `get<0>` and `get<1>` are both **M sub-indices**: they carry no K information. As a result only `k_group=0` worked; all other K-groups were silently mapped to the wrong source offset.

Additionally, the dest buffer must be zero-initialized before the remap, because CUTLASS pads to tile boundaries (128 rows in M, 64 elements, i.e. 4 scale factors, in K), which makes the dest buffer larger than `M * K_sf`. Without zero-init, the unmapped padding slots held garbage, which caused sporadic wrong results.

---

## Bugs Found & Fixed

### 1. unpack_ue4m3_u32: value cast vs bit reinterpret
**File:** `nvfp4_mega_moe.py`
**Bug:** `(x_u32 & 0xFF).to(torch.int32).to(torch.float8_e4m3fn)` is a value cast: the integer 63 (bit pattern 0x3F) becomes the nearest float8 value, 64.0.
**Fix:** `(x_u32 & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn)` reinterprets the bit pattern, so 0x3F decodes as float8 1.875 (exponent 0111, mantissa 111).
**Also:** PyTorch's `uint32` lacks CUDA bitwise ops, so cast to `int32` first.
**Impact:** Corrupted every activation scale fed to the L1 GEMM. Weight scales were fine (already float8 from `weight_transform.py`). A classic recipe for "structured garbage" output.
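
The difference between the two casts is easy to reproduce without a GPU. This sketch decodes E4M3 (OCP FP8, `fn` variant: bias 7, no infinities) by hand; `decode_e4m3fn` and `nearest_e4m3fn` are hypothetical helpers standing in for `.view(torch.float8_e4m3fn)` and `.to(torch.float8_e4m3fn)` respectively.

```python
def decode_e4m3fn(byte: int) -> float:
    """Bit reinterpret: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0xF and man == 0x7:
        return float("nan")                      # S.1111.111 is NaN; no infinities
    if exp == 0:
        return sign * (man / 8.0) * 2.0 ** -6    # subnormal
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)


def nearest_e4m3fn(x: float) -> float:
    """Value cast: round x to the nearest finite E4M3FN value (ties not handled specially)."""
    finite = [decode_e4m3fn(b) for b in range(256) if (b & 0x7F) != 0x7F]
    return min(finite, key=lambda v: abs(v - x))


print(decode_e4m3fn(0x3F))   # the fix: 1.875
print(nearest_e4m3fn(63.0))  # the bug: 64.0
```

UE4M3 scales never set the sign bit, so only the positive half of this table matters for them.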

### 2. stage_activation: three independent bugs
**File:** `nvfp4_mega_moe.py`
**Bug A:** `clamp(0, 15)` zeroed every negative value. E2M1 is sign-magnitude 4-bit (bit3 = sign, bits 2:0 = magnitude).
**Bug B:** Stored `block_max` as the scale but divided by `block_max/6.0`, so the stored scale was 6× too large.
**Bug C:** A uniform 0.5 step doesn't match the E2M1 values {0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6}, which are non-uniform above ±2.
**Fix:** Rewrote with proper nearest-neighbor E2M1 quantization.
**Impact:** Roughly half the L1→L2 activation (every negative value) was zeroed, on top of a 6× scale mismatch and extra quantization noise.
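
A scalar sketch of the corrected per-block quantization, with each bug's fix marked in the comments. It is hypothetical code, not the actual `stage_activation`: it keeps the scale as a Python float instead of rounding it to UE4M3, and it breaks ties toward the smaller magnitude.

```python
E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # magnitudes for codes 0..7
E2M1_MAX = 6.0


def quantize_block_e2m1(block):
    """Quantize one block (up to 16 floats) to E2M1 codes plus one float scale."""
    amax = max(abs(v) for v in block)
    # Bug B fix: the scale we divide by is the scale we return/store.
    scale = amax / E2M1_MAX if amax > 0 else 1.0
    codes = []
    for v in block:
        mag = abs(v) / scale
        # Bug C fix: nearest neighbour on the non-uniform grid (first index wins ties).
        idx = min(range(len(E2M1_GRID)), key=lambda i: abs(E2M1_GRID[i] - mag))
        # Bug A fix: sign-magnitude, bit3 = sign, so negatives survive.
        codes.append(idx | (0x8 if v < 0 else 0))
    return codes, scale


def dequantize_block_e2m1(codes, scale):
    return [(-1.0 if c & 0x8 else 1.0) * E2M1_GRID[c & 0x7] * scale for c in codes]
```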

### 3. _fold_global_scale: logical_widths branch
**File:** `weight_transform.py`
**Bug:** With `logical_widths=[3072, 3072]`, the function applied expert 0's global scale to the gate half and expert 1's global scale to the up half of ALL experts, discarding every other expert's global scale.
**Fix:** Removed the `logical_widths` branch entirely. The `else` branch correctly broadcasts each expert's own `(E, 1)` global scale across `(E, N, K//16)`.
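
The corrected behavior is a per-expert broadcast multiply. Below is a pure-Python sketch with a hypothetical name; in PyTorch the same operation on an `(E, N, K//16)` tensor would be `weight_scale.float() * global_scale.view(-1, 1, 1)`.

```python
def fold_global_scale(weight_scale, global_scale):
    """Fold each expert's global scale into its own block scales.

    weight_scale: nested list shaped (E, N, K//16) of block-scale values.
    global_scale: length-E list, one scalar per expert (the (E, 1) tensor).
    Every block of expert e is multiplied by global_scale[e] and nothing else.
    """
    assert len(weight_scale) == len(global_scale)
    return [
        [[s * g for s in row] for row in expert]
        for expert, g in zip(weight_scale, global_scale)
    ]
```

Since the folded scales are stored as `float8_e4m3fn`, the products also have to stay within that format's finite range (largest finite value 448).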

### 4. L1 weight interleave removed
**File:** `weight_transform.py`
**Bug:** `_interleave_l1_weights` assumed gate/up were pre-interleaved in groups of 16 and that the kernel used the 2CTA UMMA layout. vLLM uses a plain concat `[gate; up]` along the output dim, and our CUTLASS kernel uses `ClusterShape<1, 1, 1>`.
**Fix:** Removed entirely. `l1_weight_out = l1_weight.contiguous()`.

### 5. SF remap: idx2crd+flatten coordinate extraction
**File:** `cutlass_nvfp4_gemm.cu`
**Bug:** `cute::flatten(coord)` produces 8 sub-indices (flat_rank=8). `get<0>` and `get<1>` are both M sub-indices (inner_m, sub_m), carrying zero K information. Only k_group=0 worked; all other K-groups were silently wrong.
**Fix:** Correct extraction: `m = f0 + f1*32 + f2*128`, `k_sf = f4 + f5*4`. Zero-init the dest buffer before the remap.
**Diagnostic trail:** Constant-scale test (all SF=1.0) → cosine 1.0 proved the FP4 path was correct. Real scales → cosine 0.83 proved the SF remap was broken. Single-element probes (SFA[0,0] vs SFA[0,3]) proved only k_group=0 worked. A printf dump of flat coordinates at specific indices revealed flat_rank=8 and the correct extraction formula.

### Diagnostic: constant-scale test (smoking gun for SF bugs)

When all scale factors are set to UE4M3(1.0):
- **Cosine = 1.0000, MSE = 0.19** (expected FP4 quantization noise)

With real (variable) scale factors and the broken remap:
- **Cosine = 0.83** → scales are misaligned, not fundamentally broken

After the fix with correct coordinate extraction:
- **Cosine = 1.0000, MSE = 0.0** → perfect match with dequantized reference
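
Both metrics are cheap to compute. Here is a hypothetical pure-Python version on flat lists (not the test harness used above). The checks show why both are reported: cosine ignores a uniform scale error entirely, while MSE does not, and per-block misalignment lowers cosine as well.

```python
import math


def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb)


def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


ref = [1.0, 2.0, 3.0, 4.0]
uniform = [6.0 * v for v in ref]        # every scale 6x too large
misaligned = [1.0, 2.0, 6.0, 8.0]       # only the second "block" scaled
print(cosine_similarity(ref, uniform), mse(ref, uniform))  # ~1.0, 187.5
print(cosine_similarity(ref, misaligned))                  # < 1.0
```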

---

## Build & Deploy (B200)

```bash
# On the B200 host: CUTLASS must be cloned and mounted
cd /root/nvidia-meeting/deepseek-v4-quant/

# Rebuild the container (CUTLASS is host-mounted at /root/cutlass)
KERNEL_CACHE_BUSTER=$(date +%s) docker compose build --no-cache
docker compose up -d
```

The CUTLASS extension builds inside the container during `pip install` of the nvfp4-megamoe-kernel package. It needs:
- CUDA 13.0 toolkit (in the vllm/vllm-openai:nightly image)
- CUTLASS headers at `/root/cutlass/include/`
- CCCL headers at `/usr/local/cuda-13.0/targets/x86_64-linux/include/cccl/`
- A device with SM100 compute capability (B200)

---

## Known Issues

1. **MoE dispatch is slow:** `cutlass_grouped_nvfp4_gemm` uses a Python loop over 48 experts with per-token scatter/gather. It needs a proper grouped GEMM, or at least CUDA-side dispatch.

2. **stage_activation is Python:** Re-quantization of the L1 BF16 output to FP4 for the L2 input runs in PyTorch. It should use the Triton staging kernel, for speed and for consistency with vLLM's built-in staging.

3. **SF remap allocates every call:** `cudaMemset` + the remap kernel run on every GEMM invocation. For the weight scales, which are static, the CUTLASS-layout buffer could be precomputed once during the weight transform.
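
One common first step toward a grouped GEMM is to sort the routed (token, slot) pairs by expert so that each expert's rows are contiguous. Here is a counting-sort sketch in plain Python; `build_expert_segments` is hypothetical and not part of this repo.

```python
def build_expert_segments(local_assignments, num_local_experts):
    """Counting sort of (token, slot, expert) triples by expert.

    Returns (offsets, order): entries order[offsets[e]:offsets[e + 1]] belong to
    expert e, so expert e's GEMM has M = offsets[e + 1] - offsets[e].
    """
    counts = [0] * num_local_experts
    for _, _, e in local_assignments:            # pass 1: histogram
        counts[e] += 1
    offsets = [0] * (num_local_experts + 1)
    for e in range(num_local_experts):           # pass 2: exclusive prefix sum
        offsets[e + 1] = offsets[e] + counts[e]
    cursor = offsets[:-1]                        # copy: next free slot per expert
    order = [None] * len(local_assignments)
    for t, slot, e in local_assignments:         # pass 3: scatter
        order[cursor[e]] = (t, slot)
        cursor[e] += 1
    return offsets, order
```

On the GPU these three passes typically become a histogram, a prefix scan, and a scatter; `offsets` then gives each expert's M and row range for a single grouped launch.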

---

## Environment Variables

| Variable | Default | Description |
|----------|---------|-------------|
| `MEGA_MOE_STATIC` | 0 | Set to 1 to skip the MoE kernel entirely (returns zeros) |
| `MEGA_MOE_DEBUG` | 0 | Set to 1 for verbose logging |
| `SKIP_ATTENTION` | 0 | Skip attention layers (debug) |

---

## Repos

- **Kernel:** `sweetapi.com/biondizzle/nvfp4-megamoe-kernel` (branch: master)
- **Deployment:** `sweetapi.com/biondizzle/deepseek-v4-quant` (branch: modelopt-nvfp4)
- **Local:** `~/dev/nvfp4-mojo/`, `~/dev/deepseek-v4-quant/`