biondizzle 128ff84358 fix: 384 experts (not 256), clarify cross-rank reduce is in caller
DeepSeek-V4-Pro has 384 routed experts, 48 per rank (384/8).
The cross-rank all-reduce happens in the parent DeepseekV4MoE.forward,
not in our kernel. Our kernel writes local output; caller does reduce.
Fixed README, nvfp4_mega_moe.py comments.
2026-05-14 17:33:59 +00:00

# nvfp4-megamoe-kernel
Native NVFP4 block-scaled MoE kernel for DeepSeek-V4-Pro on NVIDIA Blackwell (SM100).
Replaces the broken `fp8_nvfp4_mega_moe` kernel from DeepGEMM with a working CUTLASS-based implementation that emits real `SM100_MMA_MXF4_SS` tensor core instructions.
---
## Architecture
DeepSeek-V4-Pro is a 384-expert MoE model with expert parallelism across 8 ranks (B200 GPUs). Each rank handles 48 experts. For each token, the router picks the top-6 experts.
### The MoE Forward Pass
```
Input hidden states (BF16)
┌─────────────────┐
│ Shared Experts │ ← vLLM native FlashInfer CUTLASS NVFP4 path
│ (gate + up → │ (not our kernel)
│ SiLU * up → │
│ down) │
└─────────────────┘
Staging Kernel (vLLM built-in)
BF16 → packed E2M1 (int8) + UE4M3 block-16 scales (uint32)
Writes to SymmBuffer.x / SymmBuffer.x_sf
Router (vLLM built-in)
Writes topk_ids / topk_weights to SymmBuffer
┌─────────────────────────────────────────┐
│ nvfp4_mega_moe_full │ ← nvfp4_mega_moe.py
│ │
│ 1. Read staged activation from buffer │
│ 2. L1 GEMM: gate_up_proj │ ← CUTLASS NVFP4 block-scaled
│ E2M1 × E2M1 + UE4M3 scales │ SM100_MMA_MXF4_SS PTX
│ → BF16 output (6144-wide) │
│ 3. SiLU(gate) * up (activation) │
│ 4. stage_activation: BF16 → FP4 │ ← proper E2M1 quantization
│ 5. L2 GEMM: down_proj │ ← CUTLASS NVFP4 block-scaled
│ E2M1 × E2M1 + UE4M3 scales │ SM100_MMA_MXF4_SS PTX
│ → BF16 output (7168-wide) │
│ 6. Write to output tensor │ ← caller handles cross-rank all-reduce
└─────────────────────────────────────────┘
```
### vLLM Startup Sequence (how our code plugs in)
```
1. vLLM engine init
└─ ModelOptNvFp4Config selected (NVFP4 quantization scheme)
└─ FlashInferCutlassNvFp4LinearKernel for linear layers
2. Model construction
└─ DeepseekV4ForCausalLM → DeepseekV4MoE → DeepseekV4DecoderLayer
Each layer has: attention + MoE block
MoE block has: shared experts + 384 routed experts
3. Weight loading
└─ 95 safetensor shards loaded
└─ weight, weight_scale, weight_scale_2 loaded per linear
4. process_weights_after_loading ← THIS IS WHERE WE HOOK IN
└─ ModelOptNvFp4LinearMethod swizzles/pads weights for CUTLASS
└─ finalize_mega_moe_weights()
└─ weight_transform.py: transform_nvfp4_weights_for_mega_moe()
• Folds weight_scale_2 (global scale) into weight_scale (block scale)
• UE4M3 block-16 scales: returned as float8_e4m3fn (not packed uint32)
• Returns ((l1_w, l1_sf), (l2_w, l2_sf)) per rank
5. SymmBuffer allocation
└─ symm_buffer.py: get_symm_buffer_for_nvfp4_mega_moe()
• Pre-allocates GPU buffers for:
- x: int8 packed E2M1 activations
- x_sf: uint32 packed UE4M3 activation scales
- topk_idx: int32 expert indices
- topk_weights: float32 routing weights
- buffer: BF16 all-reduce buffer
6. Profile run (warmup)
└─ First forward pass to allocate KV cache, etc.
└─ This is where the CUTLASS GEMM first executes
7. Ready to serve
```
---
## File Map
```
nvfp4_megamoe_kernel/
├── __init__.py # Public API exports
├── nvfp4_mega_moe.py # Main kernel: nvfp4_mega_moe_full, nvfp4_mega_moe_l1/l2, stage_activation
├── weight_transform.py # Weight prep: fold global scale into UE4M3 block scales
├── symm_buffer.py # GPU buffer allocation for MoE dispatch
└── cutlass_nvfp4_gemm/ # CUTLASS CUDA extension (the actual hardware kernel)
├── cutlass_nvfp4_gemm.cu # CUDA: CUTLASS GEMM + SF remap kernel
├── pytorch_binding.cpp # PyTorch C++ binding (_C.forward)
├── kernel.py # Python: cutlass_grouped_nvfp4_gemm (per-expert loop)
├── sf_layout.py # CUTLASS SF layout reference docs
├── setup.py # Build config (nvcc, CUTLASS include paths)
├── build.sh # Build script
├── test_gemm.py # Standalone test
└── README.md
```
### What each file does (in call order)
| File | When it runs | What it does |
|------|-------------|--------------|
| `weight_transform.py` | Once at startup (weight loading) | Takes raw NVFP4 checkpoint weights, folds global scales into block scales. Returns scales as `float8_e4m3fn` (not packed uint32). Output: `((l1_w, l1_sf), (l2_w, l2_sf))` |
| `symm_buffer.py` | Once at startup (buffer alloc) | Pre-allocates GPU tensors for activations, scales, routing data, and all-reduce. These persist across forward passes. |
| `nvfp4_mega_moe.py` | Every forward pass | Orchestrates the MoE: reads from symm buffer → L1 GEMM → activation → re-quantize → L2 GEMM → output. Contains `stage_activation` (BF16→FP4 quantize for L1→L2) and `unpack_ue4m3_u32` (uint32 packed scales → float8). |
| `cutlass_nvfp4_gemm/kernel.py` | Every forward pass (called by nvfp4_mega_moe) | Per-expert loop: gather tokens for each expert, call CUTLASS GEMM, scatter results with routing weights. |
| `cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu` | Every forward pass (CUDA kernel) | The actual CUTLASS kernel: native NVFP4 block-scaled GEMM + GPU-side scale factor remap (row-major → CUTLASS interleaved layout). |
| `cutlass_nvfp4_gemm/sf_layout.py` | Reference only | Documents the CUTLASS SfAtom layout. Not used at runtime (remap is in CUDA). |
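
The per-expert loop described for `cutlass_nvfp4_gemm/kernel.py` can be sketched in plain Python. This is an illustration only, not the repository's code: `moe_dispatch` and `expert_fn` are hypothetical names, `expert_fn` stands in for the two NVFP4 GEMMs plus activation, expert ids are assumed to be local to the rank already, and the expert is assumed to return rows of the same width as its input.

```python
def moe_dispatch(x, topk_ids, topk_weights, expert_fn, num_local_experts):
    """Per-expert loop: gather tokens, run the expert, scatter weighted results.

    x:            list of token vectors (list of lists of floats)
    topk_ids:     per token, the local expert indices chosen by the router
    topk_weights: per token, the routing weight for each chosen expert
    expert_fn:    hypothetical callable (expert_id, rows) -> rows
    """
    out = [[0.0] * len(x[0]) for _ in x]
    for e in range(num_local_experts):
        # Gather: (token index, routing weight) for every token routed to e.
        hits = [(t, w)
                for t, (ids, ws) in enumerate(zip(topk_ids, topk_weights))
                for i, w in zip(ids, ws) if i == e]
        if not hits:
            continue  # no tokens for this expert in this step
        rows = expert_fn(e, [x[t] for t, _ in hits])
        # Scatter: accumulate the weighted expert output into each token's row.
        for (t, w), row in zip(hits, rows):
            out[t] = [o + w * r for o, r in zip(out[t], row)]
    return out
```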
---
## Data Formats
### Weights
- **Packed E2M1** (`int8`): 2 FP4 values per byte. Shape: `(E_per_rank, N, K//2)`, K-major layout.
- **UE4M3 block scales** (`float8_e4m3fn`): 1 scale per 16 FP4 values (group_size=16). Shape: `(E_per_rank, N, K//16)`. Returned as `float8_e4m3fn` from `weight_transform.py` — NOT packed uint32. The CUTLASS GEMM consumes float8 directly.
### Activations (after staging kernel)
- **Packed E2M1** (`int8`): Shape: `(num_tokens, K//2)`.
- **UE4M3 scales** (`uint32`): 4 UE4M3 values packed per uint32. Shape: `(num_tokens, K//64)`. Unpacked to `float8_e4m3fn` via `unpack_ue4m3_u32` before reaching the CUTLASS GEMM.
### GEMM dimensions (DeepSeek-V4-Pro)
- **L1 (gate_up_proj):** M×6144×7168 (per expert)
- **L2 (down_proj):** M×7168×3072 (per expert)
- 48 experts per rank (384 total / 8 ranks), top-6 routing
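
The shapes listed in this section follow from three divisors: 2 FP4 values per byte, 16 values per block scale, and 4 packed scales per `uint32`. A small sketch of that arithmetic, where `nvfp4_shapes` is a hypothetical helper and not part of the package:

```python
def nvfp4_shapes(n, k, num_tokens, num_experts=384, num_ranks=8):
    """Tensor shapes for one GEMM layer, per rank."""
    experts_per_rank = num_experts // num_ranks
    return {
        "weight": (experts_per_rank, n, k // 2),      # 2 FP4 values per int8
        "weight_sf": (experts_per_rank, n, k // 16),  # 1 scale per 16 values
        "act": (num_tokens, k // 2),                  # packed E2M1 activations
        "act_sf_u32": (num_tokens, k // 64),          # 4 scales per uint32
    }

l1 = nvfp4_shapes(n=6144, k=7168, num_tokens=4)  # gate_up_proj
l2 = nvfp4_shapes(n=7168, k=3072, num_tokens=4)  # down_proj
print(l1["weight_sf"])  # (48, 6144, 448)
```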
---
## CUTLASS Scale Factor Remap
CUTLASS's `Sm1xxBlockScaledConfig` expects scale factors in a specific interleaved layout, not simple row-major. The SfAtom is:
```
Atom Shape: Shape<Shape<32, 4>, Shape<16, 4>>
Atom Stride: Stride<Stride<16, 4>, Stride<0, 1>>
Tiling: Step<_2, _1> (M tiled with step 2, K with step 1)
```
Our source data is row-major `(M, K_sf)` where `K_sf = K / 16`. The remap kernel (`remap_sf_to_cutlass_kernel` in `cutlass_nvfp4_gemm.cu`) converts from row-major to CUTLASS's interleaved layout.
### How the remap works
The kernel iterates over CUTLASS destination indices, uses `cute::idx2crd` to get the hierarchical coordinate, then `cute::flatten` to get a flat tuple of 8 sub-indices. From those, we extract logical `(m, k_sf)` and read from the row-major source.
### Flattened coordinate decomposition (flat_rank=8)
From the SfAtom layout with Step<_2, _1> tiling, `flatten(idx2crd(idx, ...))` produces 8 values:
```
f0 = inner_m (0..31) — varies fastest within M atom
f1 = sub_m (0..3) — second M sub-coordinate
f2 = tile_m (0..) — M tile index
f3 = step_m stride — degenerate (always = sfa_size, not a coordinate)
f4 = sub_k (0..3) — K sub-coordinate within atom
f5 = tile_k (0..) — K tile index
f6 = 0 — unused
f7 = 0 — unused
```
#### Empirical coordinate dump (MN=8192, K_sf=448, T = sfa_size = 58720256)
| idx | f0 | f1 | f2 | f3 | f4 | f5 | f6 | f7 |
| ----- | --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | 0 | 0 | 0 | T | 0 | 0 | 0 | 0 |
| 1 | 0 | 0 | 0 | T | 1 | 0 | 0 | 0 |
| 4 | 0 | 1 | 0 | T | 0 | 0 | 0 | 0 |
| 16 | 1 | 0 | 0 | T | 0 | 0 | 0 | 0 |
| 511 | 31 | 3 | 0 | T | 3 | 0 | 0 | 0 |
| 512 | 0 | 0 | 0 | T | 0 | 1 | 0 | 0 |
| 1024 | 0 | 0 | 0 | T | 0 | 2 | 0 | 0 |
| 2048 | 0 | 0 | 0 | T | 0 | 4 | 0 | 0 |
| 4096 | 0 | 0 | 0 | T | 0 | 8 | 0 | 0 |
| 8192 | 0 | 0 | 0 | T | 0 | 16 | 0 | 0 |
| 65536 | 0 | 0 | 1 | T | 0 | 16 | 0 | 0 |
| 131072 | 0 | 0 | 2 | T | 0 | 32 | 0 | 0 |
#### Extraction formula
CuTe uses "first sub varies fastest" for `Shape<32, 4>`:
```cpp
m = f0 + f1 * 32 + f2 * 128;
k_sf = f4 + f5 * 4;
```
This was verified with 6 independent probes:
| Probe | Expected | Observed | Confirms |
|-------|--------|----------|--------|
| SFA[1, 0] = 2.0 | row 1 changes | ✅ only row 1 | Confirms f0 term |
| SFA[32, 0] = 2.0 | row 32 changes | ✅ only row 32 | Confirms f1*32, rules out f0*4+f1 |
| SFA[128, 0] = 2.0 | row 128 changes | ✅ only row 128 | Confirms f2*128 |
| SFA[0, 1] = 2.0 | row 0 changes (k=1) | ✅ only row 0 | Confirms f4 term |
| SFA[0, 4] = 2.0 | row 0 changes (k=4) | ✅ only row 0 | Confirms f5*4 term |
| SFA[0, 100] = 2.0 | row 0 changes (k=100) | ✅ only row 0 | Confirms tile-overflow range |
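
The extraction formula can be cross-checked against the empirical coordinate dump in plain Python. The decomposition below is inferred from the table and the atom strides in this README, not taken from CUTLASS source. It assumes `K_sf` is a multiple of 4 and that K tiles vary faster than M tiles, and `decompose` and `logical_coord` are hypothetical names.

```python
ATOM = 512  # 32 * 4 * 4 scale factors per SfAtom

def decompose(idx, k_sf):
    """Split a flat destination index into (f0, f1, f2, f4, f5)."""
    k_tiles = k_sf // 4
    tile, r = divmod(idx, ATOM)
    f2, f5 = divmod(tile, k_tiles)  # tile_m, tile_k
    f0, r = divmod(r, 16)           # inner_m has stride 16
    f1, f4 = divmod(r, 4)           # sub_m has stride 4, sub_k has stride 1
    return f0, f1, f2, f4, f5

def logical_coord(idx, k_sf):
    """Apply the extraction formula to get the row-major (m, k_sf) source."""
    f0, f1, f2, f4, f5 = decompose(idx, k_sf)
    return f0 + f1 * 32 + f2 * 128, f4 + f5 * 4
```

With `k_sf=448`, `decompose(65536, 448)` gives `tile_m = 1` and `tile_k = 16`, matching the corresponding row of the dump.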
#### Why the previous remap was broken
The previous code used `cute::get<0>(flat)` and `cute::get<1>(flat)` to extract (m, k). Since flatten produces `(inner_m, sub_m, tile_m, ...)` in order, `get<0>` and `get<1>` are both **M sub-indices** — they carry no K information. This caused only `k_group=0` to work; all other K-groups were silently mapped to the wrong source offset.
Additionally, the dest buffer must be zero-initialized before the remap. CUTLASS pads to tile boundaries (128 × 64), so the dest buffer is larger than `M * K_sf`. Padding slots that the remap never writes would otherwise hold uninitialized memory, which caused sporadic wrong results.
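
A host-side sketch of the remap with zero-initialized padding is shown below. It is a reference model only: it iterates over source coordinates rather than destination indices as the CUDA kernel does, the index arithmetic is inferred from the coordinate dump above, and `remap_sf` is a hypothetical name.

```python
def remap_sf(src, m, k_sf):
    """Row-major (m, k_sf) scales -> interleaved layout, zero-padded."""
    m_pad = -(-m // 128) * 128     # round M up to the 128-row atom
    k_pad = -(-k_sf // 4) * 4      # round K_sf up to 4 scales (64 FP4 values)
    k_tiles = k_pad // 4
    dst = [0.0] * (m_pad * k_pad)  # zero-init: padding slots stay 0
    for row in range(m):
        for col in range(k_sf):
            tile = (row // 128) * k_tiles + col // 4
            idx = (tile * 512 + (row % 32) * 16
                   + (row // 32 % 4) * 4 + col % 4)
            dst[idx] = src[row * k_sf + col]
    return dst
```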
---
## Bugs Found & Fixed
### 1. unpack_ue4m3_u32: value cast vs bit reinterpret
**File:** `nvfp4_mega_moe.py`
**Bug:** `(x_u32 & 0xFF).to(torch.int32).to(torch.float8_e4m3fn)` is a value cast: the integer 63 is converted to the nearest representable float8 value (64.0, since E4M3 has only 3 mantissa bits).
**Fix:** `(x_u32 & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn)` reinterprets the bit pattern: 0x3F → float8(1.875).
**Also:** `uint32` lacks CUDA bitwise ops, so cast to `int32` first.
**Impact:** Corrupted every activation scale fed to the L1 GEMM. Weight scales were unaffected (already float8 from `weight_transform.py`). The result was "structured garbage": output with plausible structure but wrong values.
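
The difference between a value cast and a bit reinterpretation can be seen with a pure-Python E4M3 decoder. Under bit reinterpretation the byte 0x38 decodes to 1.0, whereas a value cast would treat the same byte as the integer 56. This is a sketch for illustration: `e4m3fn_from_bits` and `unpack_ue4m3_word` are hypothetical names, and the byte order inside the packed word (lowest byte first) is an assumption.

```python
def e4m3fn_from_bits(byte):
    """Decode one float8_e4m3fn bit pattern (exponent bias 7)."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0xF and man == 0x7:
        return float("nan")                   # e4m3fn has NaN but no inf
    if exp == 0:
        return sign * (man / 8) * 2.0 ** -6   # subnormal
    return sign * (1 + man / 8) * 2.0 ** (exp - 7)

def unpack_ue4m3_word(word):
    """Split one packed 32-bit word into four scale values.

    Assumes the lowest byte holds the first scale.
    """
    return [e4m3fn_from_bits((word >> (8 * i)) & 0xFF) for i in range(4)]
```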
### 2. stage_activation: three independent bugs
**File:** `nvfp4_mega_moe.py`
**Bug A:** `clamp(0, 15)` zeroed every negative value. E2M1 is sign-magnitude 4-bit (bit3=sign, bits2:0=mag).
**Bug B:** Stored `block_max` but divided by `block_max/6.0` → stored scale was 6× too large.
**Bug C:** Uniform 0.5 step doesn't match E2M1 values {0, ±0.5, ±1, ±1.5, ±2, ±3, ±4, ±6} — non-uniform above ±2.
**Fix:** Rewrote with proper nearest-neighbor E2M1 quantization.
**Impact:** Half the L1→L2 activation was zeroed, 6× scale mismatch, quantization noise on top.
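
A minimal sketch of nearest-neighbor E2M1 quantization for one block of 16 values, addressing the three bugs above (sign bit, scale of `block_max / 6.0`, non-uniform value grid). It is not the repository's `stage_activation`: the scale is kept as a Python float instead of being rounded to UE4M3, ties resolve to the smaller magnitude, and the function names are hypothetical.

```python
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def quantize_block(values):
    """Quantize one block of floats to 4-bit E2M1 codes plus one scale."""
    block_max = max(abs(v) for v in values)
    # Map the largest magnitude in the block onto 6.0, the E2M1 maximum.
    scale = block_max / 6.0 if block_max > 0 else 1.0
    codes = []
    for v in values:
        target = abs(v) / scale
        mag = min(range(8), key=lambda i: abs(E2M1_MAGNITUDES[i] - target))
        sign = 0x8 if v < 0 else 0x0  # bit 3 is the sign bit
        codes.append(sign | mag)
    return codes, scale

def dequantize_block(codes, scale):
    """Inverse mapping, used here only to check the round trip."""
    return [(-1.0 if c & 0x8 else 1.0) * E2M1_MAGNITUDES[c & 0x7] * scale
            for c in codes]
```

Negative inputs keep their sign through the round trip, which is the property the original `clamp(0, 15)` destroyed.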
### 3. _fold_global_scale: logical_widths branch
**File:** `weight_transform.py`
**Bug:** `logical_widths=[3072, 3072]` caused the function to apply expert 0's scale to gate half and expert 1's scale to up half of ALL experts. All other experts' global scales were discarded.
**Fix:** Removed the `logical_widths` branch entirely. The `else` branch correctly broadcasts each expert's own `(E, 1)` global scale across `(E, N, K//16)`.
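
The intended broadcast can be sketched with nested lists. This assumes that folding means multiplying each block scale by its own expert's global scale, and it omits the rounding back to `float8_e4m3fn`; `fold_global_scale` here is a stand-in, not the function in `weight_transform.py`.

```python
def fold_global_scale(block_sf, global_sf):
    """block_sf: [E][N][K//16] block scales; global_sf: [E] per-expert scales.

    Every expert uses its own global scale for all of its rows.
    """
    return [[[s * g for s in row] for row in expert]
            for expert, g in zip(block_sf, global_sf)]

folded = fold_global_scale(
    [[[1.0, 2.0]], [[1.0, 2.0]], [[1.0, 2.0]]],  # 3 experts, N=1, 2 blocks
    [2.0, 3.0, 4.0],
)
print(folded)  # [[[2.0, 4.0]], [[3.0, 6.0]], [[4.0, 8.0]]]
```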
### 4. L1 weight interleave removed
**File:** `weight_transform.py`
**Bug:** `_interleave_l1_weights` assumed gate/up were pre-interleaved in groups of 16 and that the kernel used 2CTA UMMA layout. vLLM uses plain concat `[gate; up]` along the output dim, and our CUTLASS kernel uses `ClusterShape<1, 1, 1>`.
**Fix:** Removed entirely. `l1_weight_out = l1_weight.contiguous()`.
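
With the plain concat layout, the activation step only needs to split each L1 output row in half. A sketch under that assumption, where `silu_mul_concat` is a hypothetical name and the row is a plain list of floats rather than a BF16 tensor:

```python
import math

def silu_mul_concat(row):
    """Compute SiLU(gate) * up for a row laid out as [gate; up]."""
    half = len(row) // 2
    gate, up = row[:half], row[half:]
    # SiLU(g) = g * sigmoid(g) = g / (1 + exp(-g))
    return [g / (1.0 + math.exp(-g)) * u for g, u in zip(gate, up)]
```

For a 6144-wide L1 output this yields the 3072-wide input to the L2 GEMM.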
### 5. SF remap: idx2crd+flatten coordinate extraction
**File:** `cutlass_nvfp4_gemm.cu`
**Bug:** `cute::flatten(coord)` produces 8 sub-indices (flat_rank=8). `get<0>` and `get<1>` are both M sub-indices (inner_m, sub_m), carrying zero K information. Only k_group=0 worked; all other K-groups were silently wrong.
**Fix:** Correct extraction: `m = f0 + f1*32 + f2*128`, `k_sf = f4 + f5*4`. Zero-init dest buffer before remap.
**Diagnostic trail:** Constant-scale test (all SF=1.0) → cosine 1.0 proved FP4 path was correct. Real scales → cosine 0.83 proved SF remap was broken. Single-element probes (SFA[0,0] vs SFA[0,3]) proved only k_group=0 worked. Printf dump of flat coordinates at specific indices revealed flat_rank=8 and the correct extraction formula.
### Diagnostic: constant-scale test (smoking gun for SF bugs)
When all scale factors are set to UE4M3(1.0):
- **Cosine = 1.0000, MSE = 0.19** (expected FP4 quantization noise)
With real (variable) scale factors and the broken remap:
- **Cosine = 0.83** → scales are misaligned, not fundamentally broken
After the fix with correct coordinate extraction:
- **Cosine = 1.0000, MSE = 0.0** → perfect match with dequantized reference
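
The two metrics used in this diagnostic can be computed as follows. This is a generic sketch over flat lists of floats; how the repository's tests compute them is not shown in this README, and `cosine_and_mse` is a hypothetical name.

```python
import math

def cosine_and_mse(out, ref):
    """Cosine similarity and mean squared error between two flat vectors."""
    dot = sum(a * b for a, b in zip(out, ref))
    norm = (math.sqrt(sum(a * a for a in out))
            * math.sqrt(sum(b * b for b in ref)))
    cosine = dot / norm if norm else 0.0
    mse = sum((a - b) ** 2 for a, b in zip(out, ref)) / len(ref)
    return cosine, mse
```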
---
## Build & Deploy (B200)
```bash
# On B200 host — CUTLASS must be cloned and mounted
cd /root/nvidia-meeting/deepseek-v4-quant/
# Rebuild container (CUTLASS is host-mounted at /root/cutlass)
KERNEL_CACHE_BUSTER=$(date +%s) docker compose build --no-cache
docker compose up -d
```
The CUTLASS extension builds inside the container during `pip install` of the nvfp4-megamoe-kernel package. It needs:
- CUDA 13.0 toolkit (in the vllm/vllm-openai:nightly image)
- CUTLASS headers at `/root/cutlass/include/`
- CCCL headers at `/usr/local/cuda-13.0/targets/x86_64-linux/include/cccl/`
- Device with SM100 compute capability (B200)
---
## Known Issues
1. **MoE dispatch is slow.** `cutlass_grouped_nvfp4_gemm` uses a Python loop over 48 experts with per-token scatter/gather. Needs a proper grouped GEMM or at least CUDA-side dispatch.
2. **stage_activation is Python.** Re-quantization from L1 BF16 output to FP4 for L2 input runs in PyTorch. Should use the Triton staging kernel for speed and consistency with vLLM's built-in staging.
3. **SF remap allocates every call.** `cudaMemset` + remap kernel runs per GEMM invocation. Could pre-compute the CUTLASS-layout buffer once during weight transform.
---
## Environment Variables
| Variable | Default | Description |
|----------|---------|-------------|
| `MEGA_MOE_STATIC` | 0 | Set to 1 to skip MoE kernel entirely (return zeros) |
| `MEGA_MOE_DEBUG` | 0 | Set to 1 for verbose logging |
| `SKIP_ATTENTION` | 0 | Skip attention layers (debug) |
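
One way such flags could be read is sketched below. The parsing rule (only the string `1` enables a flag) is an assumption based on the table, and `read_flag` is a hypothetical helper, not the package's actual code.

```python
import os

def read_flag(name, default="0"):
    """Return True only when the environment variable is set to '1'."""
    return os.environ.get(name, default) == "1"

static_mode = read_flag("MEGA_MOE_STATIC")  # skip the MoE kernel entirely
debug_mode = read_flag("MEGA_MOE_DEBUG")    # verbose logging
```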
---
## Repos
- **Kernel:** `sweetapi.com/biondizzle/nvfp4-megamoe-kernel` (branch: master)
- **Deployment:** `sweetapi.com/biondizzle/deepseek-v4-quant` (branch: modelopt-nvfp4)
- **Local:** `~/dev/nvfp4-mojo/`, `~/dev/deepseek-v4-quant/`