# NVFP4 MegaMoE Kernel

NVFP4 block-scaled Mixture-of-Experts kernel for DeepSeek-V4 on NVIDIA Blackwell (SM100). It uses CuTeDSL, NVIDIA's Python-based CUTLASS DSL, to build a native NVFP4 pipeline that takes full advantage of Blackwell's TMA, MMA, and epilogue overlap.

## What This Is

A fused MoE FFN kernel that runs the entire expert forward pass in NVFP4:

```
BF16 input → quantize to NVFP4
L1 GEMM: NVFP4 × NVFP4 → BF16 (gate + up)
SiLU(gate) * up → BF16 (the only nonlinearity; BF16 is unavoidable here)
Re-quantize → NVFP4
L2 GEMM: NVFP4 × NVFP4 → BF16 (down_proj)
Scatter with routing weights → BF16 output
```

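The data flow above can be checked against a plain reference. Below is a minimal pure-Python sketch in full precision (NVFP4 quantization is left out), assuming top-k routing with one weight per selected expert and gate and up stacked in a single L1 weight (rows `[0, I)` are gate, rows `[I, 2I)` are up). `moe_grouped` sorts token copies by expert and builds cumulative offsets, the role `compute_expert_offsets()` plays for the grouped GEMM. Every name here is hypothetical and none of it is the repo's API; the grouped form should match the token-by-token loop up to floating-point summation order.

```python
import math
import random

Matrix = list[list[float]]


def silu(v: float) -> float:
    return v / (1.0 + math.exp(-v))


def matvec(w: Matrix, x: list[float]) -> list[float]:
    # w is laid out [out_features][in_features]
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def expert_ffn(w1: Matrix, w2: Matrix, x: list[float], inter: int) -> list[float]:
    h = matvec(w1, x)                                  # L1 GEMM: gate rows, then up rows (assumed)
    gate, up = h[:inter], h[inter:]
    act = [silu(g) * u for g, u in zip(gate, up)]      # the only nonlinearity
    return matvec(w2, act)                             # L2 GEMM: down_proj


def moe_naive(x, topk_ids, topk_w, w1s, w2s, inter):
    # Token-by-token reference: loop over each token's selected experts.
    out = []
    for t, row in enumerate(x):
        acc = [0.0] * len(row)
        for e, wt in zip(topk_ids[t], topk_w[t]):
            y = expert_ffn(w1s[e], w2s[e], row, inter)
            acc = [a + wt * yi for a, yi in zip(acc, y)]
        out.append(acc)
    return out


def moe_grouped(x, topk_ids, topk_w, w1s, w2s, inter):
    # Grouped form: sort (expert, token, weight) triples by expert, build offsets,
    # run each expert over its contiguous slice, then scatter back with weights.
    triples = sorted(
        ((e, t, wt) for t in range(len(x)) for e, wt in zip(topk_ids[t], topk_w[t])),
        key=lambda p: p[0],
    )
    counts = [0] * len(w1s)
    for e, _, _ in triples:
        counts[e] += 1
    offsets = [0]
    for c in counts:
        offsets.append(offsets[-1] + c)                # expert e owns [offsets[e], offsets[e+1])
    out = [[0.0] * len(x[0]) for _ in x]
    for e in range(len(w1s)):
        for _, t, wt in triples[offsets[e]:offsets[e + 1]]:
            y = expert_ffn(w1s[e], w2s[e], x[t], inter)
            out[t] = [o + wt * yi for o, yi in zip(out[t], y)]  # weighted scatter
    return out, offsets


# Tiny example: 5 tokens, hidden 4, intermediate 3, 3 experts, top-2 routing.
rng = random.Random(0)
H, I, E, T = 4, 3, 3, 5
w1s = [[[rng.uniform(-1, 1) for _ in range(H)] for _ in range(2 * I)] for _ in range(E)]
w2s = [[[rng.uniform(-1, 1) for _ in range(I)] for _ in range(H)] for _ in range(E)]
x = [[rng.uniform(-1, 1) for _ in range(H)] for _ in range(T)]
topk_ids = [rng.sample(range(E), 2) for _ in range(T)]
topk_w = [[0.6, 0.4] for _ in range(T)]

ref = moe_naive(x, topk_ids, topk_w, w1s, w2s, I)
got, offsets = moe_grouped(x, topk_ids, topk_w, w1s, w2s, I)
print("offsets:", offsets)
print("max |diff|:", max(abs(a - b) for r, g in zip(ref, got) for a, b in zip(r, g)))
```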
Both GEMMs are fully NVFP4: A and B in `float4_e2m1fn_x2`, block scales in `float8_e4m3fn`, global scales in `float32`. BF16 appears only where it cannot be avoided: the GEMM outputs, the SiLU activation, and the final scatter.

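To make the three-level format concrete, here is a minimal pure-Python sketch of NVFP4 quantization for one tensor. It assumes the common convention `dequant = e2m1 × block_scale × global_scale`, 16-element blocks, and a dynamic global scale of `amax / (6 × 448)` (6 is the largest E2M1 magnitude and 448 the largest finite `float8_e4m3fn`). It works on Python floats rather than packed `float4_e2m1fn_x2` bytes and is not the repo's `quantize_to_nvfp4()`.

```python
import math
import random

E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # non-negative E2M1 values, code order
E4M3_MAX = 448.0   # largest finite float8_e4m3fn value
BLOCK = 16         # NVFP4 block size along the packed dimension


def round_e2m1(v: float) -> float:
    # Nearest E2M1 value, saturating at 6; ties go to the even code (round-half-even).
    a = min(abs(v), 6.0)
    code = min(range(8), key=lambda i: (abs(a - E2M1_GRID[i]), i % 2))
    return math.copysign(E2M1_GRID[code], v)


def round_e4m3(v: float) -> float:
    # Nearest non-negative float8_e4m3fn value, saturating at 448.
    if v <= 0.0:
        return 0.0
    if v >= E4M3_MAX:
        return E4M3_MAX
    e = max(math.floor(math.log2(v)), -6)  # below 2**-6 the spacing stays 2**-9 (subnormals)
    step = 2.0 ** (e - 3)                  # 3 mantissa bits
    return min(round(v / step) * step, E4M3_MAX)


def quantize_nvfp4(x: list[float]):
    amax = max(abs(v) for v in x)
    global_scale = amax / (6.0 * E4M3_MAX) if amax > 0 else 1.0
    values, scales = [], []
    for start in range(0, len(x), BLOCK):
        blk = x[start:start + BLOCK]
        bs = round_e4m3(max(abs(v) for v in blk) / 6.0 / global_scale)
        scales.append(bs)
        denom = bs * global_scale
        values.extend(round_e2m1(v / denom) if denom > 0 else 0.0 for v in blk)
    return values, scales, global_scale


def dequantize_nvfp4(values, scales, global_scale):
    return [v * scales[i // BLOCK] * global_scale for i, v in enumerate(values)]


def cosine(a, b):
    dot = sum(p * q for p, q in zip(a, b))
    return dot / math.sqrt(sum(p * p for p in a) * sum(q * q for q in b))


rng = random.Random(0)
x = [rng.gauss(0.0, 1.0) for _ in range(64)]
values, scales, gs = quantize_nvfp4(x)
x_hat = dequantize_nvfp4(values, scales, gs)
print(f"block scales: {scales}, cosine: {cosine(x, x_hat):.4f}")
```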
## How We Got Here

### The C++ CUTLASS Kernel Was Broken

The original kernel was a C++ `.cu` file using CUTLASS's C++ API directly. It passed all the simple tests (uniform data → exact output, scale-factor (SF) remap verifier → 0 errors) but produced **cosine 0.05** on random, non-uniform data. After weeks of debugging the SF remap (8+ iterations, all producing the same 0.2 cosine against a wrong reference), we discovered three things:

1. **The BF16 reference comparison was wrong.** Our Python dequantization didn't match CUTLASS's internal FP4 handling, and a wrong reference is worse than no reference. We chased ghosts through 8+ SF remap rewrites because the 0.2 cosine was never about the remap.

2. **The C++ CUTLASS kernel misinterpreted FP4 data.** Even with the SF remap verified correct (0 byte errors), the GEMM produced garbage on non-uniform data. The problem was in how CUTLASS's C++ API handles FP4 packing and tiling internally, which we couldn't easily debug or fix.

3. **The checkpoint `input_scale` was a red herring.** We tried using the checkpoint's calibration scale as the activation normalization scale, and it saturated every block scale at 448.0, the largest finite `float8_e4m3fn` value. The `input_scale` is a calibration constant for the alpha computation, not a normalization scale.

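The failure in item 3 can be reproduced in a few lines. This pure-Python sketch assumes the convention `dequant = e2m1 × block_scale × global_scale` with 16-element blocks, and skips E4M3 rounding. A dynamic global scale of `amax / (6 × 448)` keeps every block scale at or below 448, while a far-too-small fixed constant (the `1e-5` below is hypothetical, not the checkpoint's actual `input_scale`) pushes every raw block scale past 448, so they all clamp to the same value and the per-block information is lost.

```python
import random

E4M3_MAX = 448.0  # largest finite float8_e4m3fn value
E2M1_MAX = 6.0    # largest E2M1 magnitude
BLOCK = 16        # NVFP4 block size


def block_scales(x: list[float], global_scale: float) -> list[float]:
    # Raw per-block scale clamped to the float8_e4m3fn range (E4M3 rounding omitted).
    out = []
    for i in range(0, len(x), BLOCK):
        amax_blk = max(abs(v) for v in x[i:i + BLOCK])
        out.append(min(amax_blk / E2M1_MAX / global_scale, E4M3_MAX))
    return out


rng = random.Random(0)
acts = [rng.gauss(0.0, 1.0) for _ in range(128)]

dynamic_gs = max(abs(v) for v in acts) / (E2M1_MAX * E4M3_MAX)  # from this call's amax
fixed_gs = 1e-5  # hypothetical fixed constant, far too small for these activations

dyn = block_scales(acts, dynamic_gs)
fixed = block_scales(acts, fixed_gs)
print("dynamic:", [round(s, 1) for s in dyn])
print("fixed:  ", fixed)  # every entry is clamped to 448.0
```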
### The CuTeDSL Kernel Works

NVIDIA's CuTeDSL approach (Python-based CUTLASS kernels compiled via MLIR → PTX) is what the CUTLASS team recommends for Blackwell. Their official MoE scaled grouped GEMM example (`torch_scaled_grouped_mm.py`) supports NVFP4 out of the box, and we adapted it.

**Results with real DeepSeek-V4 layer 0 weights:**

- L1 GEMM alone: cosine 0.995
- Full MoE pipeline (L1→SiLU→L2→scatter): cosine 0.989
- Weight loading: **0% loss**; a direct uint8 → `float4_e2m1fn_x2` view-cast is bit-identical to the checkpoint
- Activation quantization: ~1.1% cosine loss (dynamic BF16 → NVFP4, inherent to the format)
- GEMM kernel: 0% loss (CuTeDSL is correct)

The 0.989 cosine (a loss of about 1.1%) is entirely due to activation quantization. The weights are bit-identical to the checkpoint: no BF16 round-trip, no precision loss.

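The bit-identical claim follows from the format: each checkpoint byte already holds two 4-bit E2M1 codes, so reinterpreting a `uint8` tensor as `float4_e2m1fn_x2` (in recent PyTorch releases, a `Tensor.view(torch.float4_e2m1fn_x2)` call) changes only how the bits are read. The pure-Python sketch below decodes bytes by hand to show that the mapping is lossless. It assumes element 0 sits in the low nibble, the usual packing order but worth checking against the checkpoint; the helper names are hypothetical.

```python
import math

E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def decode_e2m1(nibble: int) -> float:
    # Bit 3 = sign, bits 2..1 = exponent (bias 1), bit 0 = mantissa.
    sign = -1.0 if nibble & 0x8 else 1.0
    exp = (nibble >> 1) & 0x3
    man = nibble & 0x1
    if exp == 0:
        mag = 0.5 * man                                # subnormal: 0.0 or 0.5
    else:
        mag = (1.0 + 0.5 * man) * 2.0 ** (exp - 1)     # normal: 1.0 .. 6.0
    return sign * mag


def encode_e2m1(value: float) -> int:
    sign = 0x8 if math.copysign(1.0, value) < 0 else 0x0  # keeps -0.0 distinct
    return sign | E2M1_MAGNITUDES.index(abs(value))


def unpack_fp4x2(byte: int) -> tuple[float, float]:
    # Assumed element order: element 0 in the low nibble, element 1 in the high nibble.
    return decode_e2m1(byte & 0x0F), decode_e2m1(byte >> 4)


def pack_fp4x2(code0: int, code1: int) -> int:
    return (code0 & 0x0F) | ((code1 & 0x0F) << 4)


# Every possible checkpoint byte survives decode -> re-encode unchanged.
lossless = all(
    pack_fp4x2(*(encode_e2m1(v) for v in unpack_fp4x2(b))) == b for b in range(256)
)
print(unpack_fp4x2(0x21), unpack_fp4x2(0xF7), lossless)
```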
### Key Lessons

1. **A wrong reference is worse than no reference.** The 0.2 cosine against a broken BF16 dequant sent us chasing SF remap bugs for weeks.
2. **The C++ CUTLASS API is a footgun for FP4.** CuTeDSL handles tensor layouts, tiling, and SF construction correctly by construction.
3. **Test with real data early.** Uniform tests pass even with broken kernels; random data reveals real bugs.
4. **Separate the GEMM from the pipeline.** Our `layertest.py` runs without vLLM, Docker, or tensor parallelism, and it caught the kernel bug that vLLM's integration layers masked.

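Lesson 3 is easy to demonstrate. This pure-Python sketch models a hypothetical layout bug, a GEMM that pairs column `k` of A with row `k + 1` of B (a stand-in for a scale-factor or packing mix-up, not the actual C++ bug). With uniform inputs the buggy and correct GEMMs agree exactly, so a uniform test passes; with random inputs their cosine similarity drops far below 1.

```python
import math
import random


def matmul(a, b, k_order=None):
    # a is [M][K], b is [K][N]; k_order[k] picks the row of b paired with column k of a.
    order = list(range(len(b))) if k_order is None else k_order
    return [
        [sum(a[i][k] * b[order[k]][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def cosine(x, y):
    dot = sum(p * q for p, q in zip(x, y))
    return dot / math.sqrt(sum(p * p for p in x) * sum(q * q for q in y))


def flat(m):
    return [v for row in m for v in row]


M, K, N = 8, 32, 8
bad_order = [(k + 1) % K for k in range(K)]  # hypothetical bug: B rows read off by one

uniform_a = [[1.0] * K for _ in range(M)]
uniform_b = [[0.5] * N for _ in range(K)]
rng = random.Random(0)
rand_a = [[rng.uniform(-1, 1) for _ in range(K)] for _ in range(M)]
rand_b = [[rng.uniform(-1, 1) for _ in range(N)] for _ in range(K)]

uniform_cos = cosine(flat(matmul(uniform_a, uniform_b)), flat(matmul(uniform_a, uniform_b, bad_order)))
random_cos = cosine(flat(matmul(rand_a, rand_b)), flat(matmul(rand_a, rand_b, bad_order)))
print(f"uniform: {uniform_cos:.3f}  random: {random_cos:.3f}")
```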
## Project Structure

```
nvfp4-megamoe-kernel/
├── cutedsl/                             # CuTeDSL kernel + bridge layer
│   ├── bridge.py                        # Tensor layout conversion, quantization, kernel launch
│   ├── moe_pipeline.py                  # Full MoE pipeline (L1→SiLU→L2→scatter)
│   └── kernel/moe/                      # NVIDIA's ScaledGroupedGemmKernel (untouched)
│       ├── torch_scaled_grouped_mm.py   # The working kernel (3900 lines)
│       ├── moe_utils.py
│       ├── moe_persistent_scheduler.py
│       └── moe_sched_extension.py
├── src/nvfp4_megamoe_kernel/            # OLD Python pipeline (being replaced)
│   ├── nvfp4_mega_moe.py                # Old pipeline; calls the broken C++ kernel
│   └── cutlass_nvfp4_gemm/              # OLD C++ CUTLASS extension (BROKEN)
├── vllm/                                # vLLM integration
│   └── patches/
│       └── deepseek_v4.py               # DeepSeek-V4 model patch
├── tests/
│   ├── layertest.py                     # Layer 0 comparison: CuTeDSL vs BF16 (✅ cosine 0.989)
│   ├── test_cutedsl.py                  # Small standalone CuTeDSL test (✅ cosine 0.991)
│   ├── test_uniform_fp4.py              # Uniform-data GEMM test
│   ├── test_b_layout.py                 # B matrix column layout test
│   └── test_quick_rand.py               # Quick random GEMM sanity check
├── reference/                           # Reference files for study
└── REWRITE_PLAN.md                      # Original rewrite plan
```

## The Bridge Layer (`cutedsl/bridge.py`)

This module handles all tensor layout conversion from our pipeline to what the CuTeDSL kernel expects:

| Function | What it does |
|----------|--------------|
| `quantize_to_nvfp4()` | BF16 → `float4_e2m1fn_x2` + `float8_e4m3fn` block scales + `float32` global scale |
| `quantize_weight_to_nvfp4()` | Same, but for weight matrices with K as the packed dimension |
| `assemble_scales_2d_side()` | Pad and swizzle activation scale factors (2D×3D case, A side) |
| `assemble_scales_3d_side()` | Pad and swizzle weight scale factors (2D×3D case, B side) |
| `make_b_k_major()` | Convert B from N-major to K-major strides (required by the kernel) |
| `compute_expert_offsets()` | Compute cumulative per-expert token offsets for the grouped GEMM |
| `run_nvfp4_grouped_gemm()` | Full kernel launch (compile + run) |

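The pad-and-swizzle steps target Blackwell's tiled scale-factor layout. As a sketch, assuming the 128×4 layout NVIDIA documents for cuBLAS block-scaled GEMMs (tiles of 128 rows by 4 scale columns, each stored as 512 contiguous bytes, tile columns fastest), the position of scale `(row, col)` can be computed directly. This illustrates that documented layout; it is not the code in `assemble_scales_2d_side()`, the helper names are hypothetical, and per-expert padding in the grouped case is not modeled.

```python
def ceil_div(a: int, b: int) -> int:
    return -(-a // b)


def swizzled_offset(row: int, col: int, n_cols: int) -> int:
    # row indexes M (or N); col indexes scale columns along K (one scale per 16 elements).
    # 128x4 tiles are stored contiguously (512 bytes each), tile columns fastest.
    n_col_tiles = ceil_div(n_cols, 4)
    tile = (row // 128) * n_col_tiles + col // 4
    r = row % 128
    return tile * 512 + (r % 32) * 16 + (r // 32) * 4 + col % 4


def swizzle_scales(scales: list[list[int]]) -> list[int]:
    # Pad rows to a multiple of 128 and columns to a multiple of 4 (zeros), then scatter.
    rows, n_cols = len(scales), len(scales[0])
    out = [0] * (ceil_div(rows, 128) * 128 * ceil_div(n_cols, 4) * 4)
    for r in range(rows):
        for c in range(n_cols):
            out[swizzled_offset(r, c, n_cols)] = scales[r][c]
    return out


# 200 rows x 6 scale columns (K = 96) pads to 256 x 8.
demo = [[(r * 6 + c) % 256 for c in range(6)] for r in range(200)]
flat = swizzle_scales(demo)
print(len(flat), swizzled_offset(1, 0, 6), swizzled_offset(32, 0, 6), swizzled_offset(0, 4, 6))
```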
## Running Tests

On the B200:

```bash
cd /root/nvfp4-megamoe-kernel/tests
source .venv/bin/activate

# Small standalone test
python3 test_cutedsl.py

# Full layer 0 comparison with real weights
python3 layertest.py
```

## Plan

### Phase 1: Kernel ✅ DONE

- CuTeDSL ScaledGroupedGemmKernel works with NVFP4
- Bridge layer handles all tensor layout conversion
- Full MoE pipeline (L1→SiLU→L2→scatter) produces cosine 0.989 vs BF16

### Phase 2: vLLM Integration (IN PROGRESS)

- Wire `cutedsl/moe_pipeline.py` into the vLLM DeepSeek-V4 model
- Replace the `nvfp4_mega_moe_full()` call with `CuTeDSLMoERunner.run()`
- Weight loading: checkpoint uint8 → `float4_e2m1fn_x2` view-cast (bit-preserving, no BF16 round-trip)
- Block scales (`float8_e4m3fn`) and global scales (`float32`) pass through directly from the checkpoint
- L1 dual global scale handling: gate and up arrive with separate global scales, so normalize both to max(gate_gs, up_gs) and fold each projection's ratio into its block scales
- Remove the C++ CUTLASS extension build from the Dockerfile
- Add the CuTeDSL dependency to the Docker build

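The dual-global-scale item can be sketched as follows. The fused L1 weight stacks gate and up, which come with separate global scales, but the kernel takes one. Assuming `dequant = e2m1 × block_scale × global_scale`, choosing `gs = max(gate_gs, up_gs)` and multiplying each projection's block scales by `own_gs / gs` keeps every block's effective scale (`block_scale × global_scale`) unchanged up to E4M3 rounding. Because each ratio is at most 1, the rescaled scales can only shrink, so none can exceed the 448 limit of `float8_e4m3fn` (normalizing to the minimum could overflow it). The rescaled scales must be re-rounded to E4M3, which is exact when the ratio is a power of two and otherwise moves each scale by at most 1/16 relative in the normal range. The scale values below are made up and the function name is hypothetical.

```python
import math

E4M3_MAX = 448.0


def round_e4m3(v: float) -> float:
    # Nearest non-negative float8_e4m3fn value (round-half-even), saturating at 448.
    if v <= 0.0:
        return 0.0
    if v >= E4M3_MAX:
        return E4M3_MAX
    e = max(math.floor(math.log2(v)), -6)  # subnormal range keeps the 2**-9 spacing
    step = 2.0 ** (e - 3)                  # 3 mantissa bits
    return min(round(v / step) * step, E4M3_MAX)


def fold_dual_global_scale(gate_bs, gate_gs, up_bs, up_gs):
    # Share one global scale for the fused [gate; up] weight. Both ratios are <= 1,
    # so rescaled block scales can only shrink and never overflow 448.
    gs = max(gate_gs, up_gs)
    gate_ratio, up_ratio = gate_gs / gs, up_gs / gs
    fused = [round_e4m3(s * gate_ratio) for s in gate_bs] + [round_e4m3(s * up_ratio) for s in up_bs]
    return fused, gs


gate_bs, gate_gs = [448.0, 96.0, 12.0], 2.0e-4   # hypothetical values
up_bs, up_gs = [448.0, 64.0, 3.5], 5.0e-4        # hypothetical values
fused_bs, gs = fold_dual_global_scale(gate_bs, gate_gs, up_bs, up_gs)
print(gs, fused_bs)
```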
### Phase 3: Optimization

- Explore larger tile sizes for better occupancy
- Profile end-to-end inference on the full model

### Phase 4: Production

- Clean up debug artifacts
- Remove old C++ kernel code
- Add proper error handling and logging
- Benchmark vs BF16 baseline