# NVFP4 MegaMoE Kernel

Full NVFP4 inference pipeline for DeepSeek-V4 on NVIDIA Blackwell (SM100). MoE experts, shared experts, and attention projections run in native NVFP4 with zero dequantization overhead. Attention compute itself stays in BF16 over an FP8 KV cache, because an NVFP4 Q×K^T GEMM proved too lossy (see "Does NOT Work").

## ⚠️ READ THIS FIRST: THE #1 RULE

**YOU MUST BUILD YOUR OWN KERNELS. ALL OF THEM. DO NOT PATCH vLLM.**

Mike was right: we need our own kernels. Not just for the NVFP4 GEMMs, but for the **ENTIRE attention pipeline**. The current approach of patching individual vLLM functions is a house of cards. Every patch leads to another crash, and every workaround reveals three more broken things. FlashMLA, fp8_ds_mla, the fused C++ kernels, the Triton compressor, and the indexer are all deeply coupled. You cannot swap one piece and expect the rest to work.

**THE ONLY PATH FORWARD:**
1. Build CuTeDSL kernels for EVERYTHING: attention, KV cache, RoPE, the whole stack
2. Test each kernel standalone on the B200 venv BEFORE touching the container
3. Wire them together into a proper vLLM attention backend
4. THEN and ONLY THEN test in the container

**DO NOT:**
- ❌ Try to patch vLLM's FlashMLA code to "work" on Blackwell
- ❌ Use pure PyTorch as a "temporary workaround"; it produces garbage
- ❌ Skip the KV cache write and hope for the best
- ❌ Assume you can mix our kernels with vLLM's existing attention backend
- ❌ Touch the container until ALL kernels pass standalone tests

**DO:**
- ✅ Build CuTeDSL kernels in `cutedsl/`
- ✅ Test each one in `tests/` on the B200 venv
- ✅ Compare against a BF16 reference (cosine >= 0.98 or it's broken)
- ✅ Wire them into a proper attention backend class
- ✅ Only test in the container once everything passes standalone

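The "cosine >= 0.98 or it's broken" gate from the checklist above can be sketched in a few lines. This is a minimal, dependency-free illustration that assumes the outputs have been flattened to plain Python floats; `cosine_similarity` and `passes_reference_check` are hypothetical helper names, not functions from this repo, and the real tests presumably work on tensors.

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity between two equal-length sequences of floats."""
    if len(a) != len(b):
        raise ValueError("inputs must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)

def passes_reference_check(kernel_out, reference_out, threshold=0.98):
    """Return True when the kernel output is close enough to the reference."""
    values = list(kernel_out) + list(reference_out)
    if any(math.isnan(v) or math.isinf(v) for v in values):
        return False  # NaN/Inf always fails, whatever the cosine would be
    return cosine_similarity(kernel_out, reference_out) >= threshold

reference = [0.5, -1.25, 2.0, 0.75]
good = [0.5, -1.0, 2.0, 1.0]     # small quantization-style error
bad = [-0.5, 1.25, 2.0, -0.75]   # sign flipped on three elements

print(passes_reference_check(good, reference))  # passes the 0.98 gate
print(passes_reference_check(bad, reference))   # fails the 0.98 gate
```

Rejecting NaN/Inf before computing the cosine keeps a poisoned output from slipping through as an ambiguous comparison.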
---

## What This Is

A native NVFP4 inference stack for DeepSeek-V4:

**MoE Experts** (CuTeDSL ScaledGroupedGemmKernel) ✅:
```
BF16 input → quantize to NVFP4
L1 GEMM: NVFP4 × NVFP4 → BF16 (gate + up)
SiLU(gate) * up → BF16 (the only nonlinearity; BF16 is unavoidable here)
Re-quantize to NVFP4
L2 GEMM: NVFP4 × NVFP4 → BF16 (down_proj)
Scatter with routing weights → BF16 output
```

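To make the quantize and re-quantize steps in the pipeline above concrete, here is a rough numerical sketch of block-scaled FP4 "fake quantization" (quantize, then immediately dequantize) plus the SiLU gate. It relies on two facts about NVFP4: values are FP4 E2M1 and one scale is shared per 16-element block. Everything else is a simplification: the block scale stays in full precision instead of FP8, there is no per-tensor scale, and rounding is plain nearest-value. The function names are illustrative and are not the repo's `bridge.py` API.

```python
import math

# Magnitudes representable in FP4 E2M1 (the sign is a separate bit)
E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
BLOCK = 16  # NVFP4 shares one scale across 16 consecutive elements

def fake_quant_nvfp4(values):
    """Quantize to the E2M1 grid per 16-element block, then dequantize."""
    out = []
    for start in range(0, len(values), BLOCK):
        block = values[start:start + BLOCK]
        amax = max(abs(v) for v in block)
        # Map the block maximum onto the largest FP4 magnitude (6.0).
        # Simplification: the scale is kept as a Python float here.
        scale = amax / 6.0 if amax > 0 else 1.0
        for v in block:
            mag = min(E2M1_GRID, key=lambda g: abs(abs(v) / scale - g))
            out.append(math.copysign(mag * scale, v))
    return out

def silu(x):
    return x / (1.0 + math.exp(-x))

def gated_activation(gate, up):
    """SiLU(gate) * up, the step that runs in BF16 between the two GEMMs."""
    return [silu(g) * u for g, u in zip(gate, up)]

block = [6.0, 2.6, -1.4, 0.2] + [0.0] * 12
print(fake_quant_nvfp4(block)[:4])  # each value snaps to the nearest grid point
```

Because the grid is coarse, the error introduced per element can be a sizeable fraction of the block scale, which is why each stage is compared against a BF16 reference instead of being assumed correct.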
**Attention Projections** (CuTeDSL NVFP4 GEMM) ✅:
- `q_a_proj`, `q_b_proj`, `kv_proj`, `wo_b_proj`: native NVFP4, cosine 0.995 vs BF16
- `wo_a`: BF16 BMM (`o_a_proj` weights are BF16 in the checkpoint)
- All verified with `tests/test_full_layer_b200.py`

**Shared Experts** (CuTeDSL NVFP4 GEMM) ✅:
- `gate_up_proj`, `down_proj`: native NVFP4, cosine 0.990 vs BF16

**Attention Pipeline**: ✅ verified standalone, 🔧 vLLM integration blocked by NaN:
- KV cache write (RoPE → FP8 quant → paged cache): cosine 0.999
- KV cache read (paged cache → FP8 dequant → BF16): cosine 0.999
- Decode attention (1 query vs N cached KVs): cosine 0.9998
- Full pipeline (inv RoPE + o_a BMM + o_b): cosine 0.996–0.999
- All 5 tested layers, covering all 3 layer types (C128A, C4A, SWA): cosine ≥ 0.996

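The cache write, cache read, and decode steps listed above can be illustrated with a toy paged cache and a single-query softmax attention. This sketch leaves out RoPE, FP8 quantization, and KV compression entirely, and stores plain `(key, value)` tuples. `PagedKVCache` and `decode_attention` are hypothetical names and do not reflect the layout used in `blackwell_attention.py`.

```python
import math

class PagedKVCache:
    """Toy paged cache: fixed-size blocks plus a block table for one sequence."""

    def __init__(self, num_blocks, block_size):
        self.block_size = block_size
        self.blocks = [[None] * block_size for _ in range(num_blocks)]
        self.free = list(range(num_blocks))
        self.block_table = []  # logical block index -> physical block id
        self.length = 0

    def append(self, kv):
        if self.length % self.block_size == 0:
            # Current block is full (or none exists yet): allocate another one
            self.block_table.append(self.free.pop())
        block = self.block_table[self.length // self.block_size]
        self.blocks[block][self.length % self.block_size] = kv
        self.length += 1

    def read_all(self):
        """Gather entries in token order by walking the block table."""
        return [
            self.blocks[self.block_table[i // self.block_size]][i % self.block_size]
            for i in range(self.length)
        ]

def decode_attention(query, keys, values):
    """Softmax attention for one query over N cached keys and values."""
    scale = 1.0 / math.sqrt(len(query))
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    peak = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]

cache = PagedKVCache(num_blocks=4, block_size=2)
for kv in [([1.0, 0.0], [1.0, 2.0]), ([0.0, 1.0], [3.0, 4.0]), ([1.0, 1.0], [5.0, 6.0])]:
    cache.append(kv)
entries = cache.read_all()
out = decode_attention([1.0, 0.0], [k for k, _ in entries], [v for _, v in entries])
print(cache.block_table, out)
```

Physical blocks are handed out from the end of the free list, so the block table is deliberately non-contiguous; a read path that ignores the table would return tokens in the wrong order.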
## Architecture: DeepSeek-V4-Pro

**MegaMoE (384 experts, top-6) with CSA + HCA + mHC:**

- **CSA (Compress Ratio 4)**: Compressed Sparse Attention. KV is compressed 4x with overlap (coff=2). The indexer finds per-layer top-k.
- **HCA (Compress Ratio 128)**: Heavily Compressed Attention. KV is compressed 128x. Top-k indices are pre-computed during metadata build.
- **mHC**: Manifold-Constrained Hyper-Connections. Replaces standard residual connections with learned mixing under Sinkhorn normalization.
- **SWA**: Sliding Window Attention. Local window (compress_ratio=0, last layer only).

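For readers unfamiliar with the Sinkhorn normalization mentioned for mHC, the generic algorithm alternates row and column normalization of a positive matrix until it is approximately doubly stochastic. The sketch below shows only that textbook iteration on a small square matrix. How DeepSeek-V4 parameterizes the mixing matrix, and how many iterations it runs, is not stated here, so the `sinkhorn` function and its `num_iters` default are assumptions.

```python
import math

def sinkhorn(logits, num_iters=50):
    """Turn a square matrix of logits into an approximately doubly stochastic matrix."""
    n = len(logits)
    m = [[math.exp(x) for x in row] for row in logits]  # make all entries positive
    for _ in range(num_iters):
        # Row step: every row sums to 1
        row_sums = [sum(row) for row in m]
        m = [[m[i][j] / row_sums[i] for j in range(n)] for i in range(n)]
        # Column step: every column sums to 1
        col_sums = [sum(m[i][j] for i in range(n)) for j in range(n)]
        m = [[m[i][j] / col_sums[j] for j in range(n)] for i in range(n)]
    return m

mix = sinkhorn([[0.0, 1.0, -1.0], [0.5, 0.0, 0.2], [-0.3, 0.8, 0.0]])
print([round(sum(row), 6) for row in mix])  # row sums after normalization
```

A doubly stochastic mixing matrix redistributes the residual streams without changing their total, which is the usual motivation for this kind of constraint.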
**Compress Ratios (from config.json):**
```
Layer 0: 128 (HCA)   Layer 1: 128 (HCA)   Layer 2: 4 (CSA)   Layer 3: 128 (HCA)
Layer 4: 4 (CSA)     ...alternating 4/128...                 Layer 60: 0 (SWA)
```

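The ratio-to-layer-type mapping in the listing above is simple enough to encode directly, which is handy when dispatching each layer to the right attention path. This is a small sketch under the assumption that the ratios have already been read into a Python list; the config.json key that holds them is not shown in this README, so loading is left out, and `layer_types` is a hypothetical helper.

```python
from collections import Counter

# Mapping taken from the listing above: 128 -> HCA, 4 -> CSA, 0 -> SWA
LAYER_TYPE_BY_RATIO = {128: "HCA", 4: "CSA", 0: "SWA"}

def layer_types(compress_ratios):
    """Map each layer's compress ratio to its attention type."""
    types = []
    for layer, ratio in enumerate(compress_ratios):
        if ratio not in LAYER_TYPE_BY_RATIO:
            raise ValueError(f"layer {layer}: unexpected compress ratio {ratio}")
        types.append(LAYER_TYPE_BY_RATIO[ratio])
    return types

# Hand-written ratios for layers 0 to 4, copied from the listing above
first_layers = [128, 128, 4, 128, 4]
print(layer_types(first_layers))
print(Counter(layer_types(first_layers)))
```

Raising on an unknown ratio makes a config change show up at load time instead of as a silent fall-through to the wrong attention path.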
**Expert intermediate size: 3072** (NOT 18432; that's 6×3072 for top-6)

**DeepGEMM MegaMoE**: DeepSeek's persistent grouped GEMM for MoE uses per-expert TMA tensormap updates with a variable block_m (16–192) chosen from the expected tokens per expert. Our CuTeDSL runner uses `run_nvfp4_grouped_gemm` (simpler, but proven correct in standalone tests).

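The "expected tokens per expert" quantity is plain arithmetic: with balanced routing, each of the 384 experts receives `num_tokens × 6 / 384` tokens on average. The sketch below computes that and then applies a deliberately simple, HYPOTHETICAL rule for choosing block_m (round up to a multiple of 16, clamp to the 16–192 range quoted above). DeepGEMM's actual selection logic is not described in this README, so `pick_block_m` should be read as an illustration of the idea only.

```python
import math

NUM_EXPERTS = 384
TOP_K = 6
EXPERT_INTERMEDIATE = 3072

def expected_tokens_per_expert(num_tokens, top_k=TOP_K, num_experts=NUM_EXPERTS):
    """Average tokens routed to each expert under perfectly balanced routing."""
    return num_tokens * top_k / num_experts

def pick_block_m(expected_tokens, step=16, min_block=16, max_block=192):
    """HYPOTHETICAL heuristic: round up to a multiple of `step`, then clamp."""
    rounded = math.ceil(expected_tokens / step) * step
    return max(min_block, min(max_block, rounded))

for num_tokens in (64, 1024, 8192, 65536):
    expected = expected_tokens_per_expert(num_tokens)
    print(num_tokens, expected, pick_block_m(expected))

# The 18432 figure is the combined width of the 6 active experts, not one expert
print(TOP_K * EXPERT_INTERMEDIATE)
```

Small decode batches leave most experts with at most a handful of tokens, which is the case a small block_m is meant to serve.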
## Current Status

### ✅ Verified (B200 venv, real weights, zero NaN)

| Component | Test | Result (cosine vs BF16) |
|-----------|------|-------------------------|
| CuTeDSL NVFP4 Linear (q_a, kv, q_b, wo_b) | `test_full_layer_b200.py` | 0.994+ |
| CuTeDSL NVFP4 MoE (L1 gate+up, SiLU, L2 down) | `layertest.py` | 0.988 |
| FP8 KV quantize/dequant | `test_kv_cache_b200.py` | 0.9997 |
| Paged KV cache read/write | `test_kv_cache_b200.py` | 1.0 |
| CSA sparse attention (cr=4) | `test_sparse_attn_b200.py` | works, no NaN |
| HCA sparse attention (cr=128) | `test_sparse_attn_b200.py` | works, no NaN |
| Full attention pipeline (all layer types) | `test_v4_attention_b200.py` | 0.981–0.995 |
| KV cache write + decode attention | `test_decode_attention_b200.py` | 0.9998 |
| Decode vs prefill consistency (5 layers) | `test_decode_vs_prefill_b200.py` | 0.996–0.999 |
| E2E 61-layer model (shared experts) | `test_e2e_decode_b200.py` | healthy logits |
| MoE runner (grouped GEMM, 16 experts) | `test_moe_runner_nan_b200.py` | no NaN, all sizes |
| Full layer (attention + MoE) | `test_full_layer_nan_b200.py` | no NaN |
| Multi-layer chain (3 layers) | `test_full_layer_nan_b200.py` | no NaN |

### ❌ Container: NaN in vLLM compiled execution

The container produces empty/garbage output. Debug logs show NaN in `hidden_states` from the first forward pass. **The NaN is NOT from our kernels.** It comes from vLLM's compiled execution infrastructure (see CURRENT_BUG.md for the full investigation).

Most likely sources:
1. `attn_gemm_parallel_execute`: fused parallel GEMM (NOT our CuTeDSL kernel)
2. `fused_q_kv_rmsnorm`: CUDA kernel that may produce NaN on Blackwell
3. Weight packing during model loading
4. `torch.compile` + cudagraph interaction with CuTeDSL buffers

### ❌ Does NOT Work

- **NVFP4 Q×K^T GEMM**: cosine 0.86, too lossy for attention scores. Keep attention in BF16.
- **Patching vLLM's FlashMLA path**: house of cards. Don't do it.

## Test Files

| Test | What it does | Status |
|------|-------------|--------|
| `tests/test_full_layer_b200.py` | All NVFP4 projections vs BF16 | ✅ 0.994+ |
| `tests/layertest.py` | MoE layer test | ✅ 0.988 |
| `tests/cudagraph_test.py` | CUDAGraph compatibility | ✅ PASS |
| `tests/test_v4_attention_b200.py` | All 3 layer types (SWA, C128A, C4A) | ✅ 0.981–0.995 |
| `tests/test_kv_cache_b200.py` | FP8/NVFP4 KV cache + paged cache | ✅ 0.9997 |
| `tests/test_sparse_attn_b200.py` | CSA/HCA sparse + SWA merged | ✅ works |
| `tests/test_decode_attention_b200.py` | Prefill + decode with KV cache | ✅ 0.9998 |
| `tests/test_decode_vs_prefill_b200.py` | Decode vs prefill consistency | ✅ 0.996–0.999 |
| `tests/test_e2e_decode_b200.py` | 61-layer E2E (shared experts) | ✅ healthy logits |
| `tests/test_moe_nan_b200.py` | Single expert NaN check | ✅ no NaN |
| `tests/test_moe_runner_nan_b200.py` | MoE grouped GEMM NaN check | ✅ no NaN |
| `tests/test_full_layer_nan_b200.py` | Full layer + multi-layer NaN check | ✅ no NaN |

## Project Structure

```
nvfp4-megamoe-kernel/
├── cutedsl/                          # CuTeDSL kernel + bridge layer
│   ├── bridge.py                     # Tensor layout conversion, quantization, kernel launch
│   ├── nvfp4_linear.py               # CuTeDSLNvfp4Linear: NVFP4 GEMM runner
│   ├── runner.py                     # CuTeDSLMoERunner: grouped GEMM MoE
│   ├── blackwell_attention.py        # KV cache + attention (standalone, works)
│   ├── csa_attention.py              # CSA/HCA attention (BF16 SDPA)
│   ├── custom_ops.py                 # torch.autograd wrappers
│   └── kernel/moe/                   # NVIDIA's ScaledGroupedGemmKernel
├── vllm/                             # vLLM integration
│   ├── nvfp4_cutedsl.py              # CuTeDSLMoERunner (vLLM wrapper)
│   ├── cutedsl_quant_method.py       # CuTeDSLNvfp4LinearMethod
│   └── patches/
│       ├── deepseek_v4_attention.py  # Attention patch (Blackwell dispatch)
│       ├── deepseek_compressor.py    # Compressor patch (skip fused kernel on Blackwell)
│       ├── patch_kv_cache_utils.py   # KV cache page size fix
│       ├── patch_swa_cache.py        # SWA cache alignment fix
│       └── layers/
│           ├── csa_attention.py      # BF16 SDPA + KV cache (our Blackwell path)
│           └── deepseek_compressor.py # Skip fused kernel on Blackwell
├── tests/                            # Standalone tests (run on B200 venv)
├── Dockerfile                        # Container build
├── README.md                         # This file
└── CURRENT_BUG.md                    # Current bug investigation
```

## Plan

### Phase 1: MoE Kernel ✅ DONE
### Phase 2: NVFP4 Linear Kernels ✅ DONE
### Phase 3: Attention Pipeline ✅ DONE (standalone, all tests pass)
### Phase 4: vLLM Integration 🔧 IN PROGRESS (blocked by NaN from vLLM infrastructure)

**Current blocker:** NaN in the vLLM container's compiled execution. Our kernels produce zero NaN standalone. The suspected sources are vLLM's `attn_gemm_parallel_execute` or `fused_q_kv_rmsnorm` CUDA kernels, weight packing, or the `torch.compile` interaction.

**Next:**
1. Install vLLM in the B200 venv and test the exact parallel GEMM path
2. Test with `torch.compile` disabled in the container
3. Add NaN checks inside the parallel GEMM wrapper
4. If the parallel GEMM is the source, replace it with our CuTeDSL kernels (path of least resistance)

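Step 3 above (NaN checks inside the wrapper) is most useful when the check distinguishes "NaN came in" from "NaN was produced here". Below is a minimal sketch of that idea as a decorator over plain Python floats. `nan_guard`, `contains_nan`, and `toy_parallel_gemm` are hypothetical names, and the toy function merely stands in for the real wrapper. With real tensors the predicate would be something like `torch.isnan(t).any()`, which forces a host sync, so a guard like this probably only fits eager-mode debugging and not captured CUDA graphs.

```python
import functools
import math

def contains_nan(obj):
    """Recursively look for NaN in floats and nested lists/tuples."""
    if isinstance(obj, float):
        return math.isnan(obj)
    if isinstance(obj, (list, tuple)):
        return any(contains_nan(item) for item in obj)
    return False

def nan_guard(stage):
    """Decorator reporting whether NaN entered a stage or was produced by it."""
    def decorate(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if contains_nan(args) or contains_nan(tuple(kwargs.values())):
                raise FloatingPointError(f"{stage}: NaN already present in inputs")
            out = fn(*args, **kwargs)
            if contains_nan(out):
                raise FloatingPointError(f"{stage}: NaN produced by this stage")
            return out
        return wrapper
    return decorate

@nan_guard("parallel_gemm")
def toy_parallel_gemm(x, w):
    # Stand-in for the real wrapper: elementwise product of two float lists
    return [a * b for a, b in zip(x, w)]

print(toy_parallel_gemm([1.0, 2.0], [3.0, 4.0]))
try:
    toy_parallel_gemm([float("inf")], [0.0])  # inf * 0 yields NaN
except FloatingPointError as err:
    print(err)
```

Guarding each suspect stage separately narrows the search to the first stage whose inputs were clean but whose output was not.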
### Phase 5: Production
- End-to-end benchmarking
- Optimize tile sizes
- Clean up