# DeepSeek V4 Pro → NVFP4 Quantization + vLLM Serving

Full NVFP4 quantization of DeepSeek V4 Pro and vLLM serving on 8× NVIDIA B200 GPUs.

## Quick Status

| Component | Status |
|-----------|--------|
| NVFP4 Quantization | ✅ 881 GB (Run 11), modelopt 0.45.0.dev64 |
| Weight Loading | ✅ 95 safetensors shards, all 8 TP ranks |
| NVFP4→FP8 Conversion (wo_a) | ✅ DeepGEMM block-scale format |
| NVFP4→BF16 Dequantization | ✅ 305 attn/shared layers, 91 compressor layers |
| Compressor Reconstruction | ✅ Separate kv_proj/gate_proj → fused_wkv_wgate |
| MoE Expert Serving | ✅ FusedMoE NVFP4 (FLASHINFER_TRTLLM backend) |
| Profile/Warmup Run | ✅ Passes |
| API Server | ✅ Running on port 8000 |
| Output Quality | 🔧 Garbled; likely a remaining dequant/scale bug |

## B200 Node

- **IP**: `45.76.247.107`
- **User**: `root`
- **Password**: see `.env`
- **GPUs**: 8× NVIDIA B200 (SM100)
- **RAM**: ~2.7 TB
- **Model weights**: `/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4/`
- **BF16 reference**: `/root/nvidia-meeting/DeepSeek-V4-Pro-BF16/`

## Architecture

```
DeepSeek V4 Pro (1.2T params, 61 layers)
├── MLA Attention (61 layers)
│   ├── fused_wqa_wkv → BF16 (UnquantizedLinearMethod)
│   ├── wo_a → FP8 (DeepGEMM block-scale, BMM einsum)
│   ├── wo_b → BF16 (UnquantizedLinearMethod)
│   └── compressor.fused_wkv_wgate → BF16 (reconstructed from NVFP4)
├── MoE Experts (384 experts, 61 layers)
│   ├── w13_weight → NVFP4 (FusedMoE, FLASHINFER_TRTLLM backend)
│   └── w2_weight → NVFP4 (FusedMoE, FLASHINFER_TRTLLM backend)
└── Shared Expert → FP8 (Fp8LinearMethod, DeepGEMM)
```

## The NVFP4 → vLLM Gap

ModelOpt quantizes to NVFP4 (4-bit E2M1 values with block scales). vLLM's DeepSeek V4 attention code expects FP8 weights for its DeepGEMM block-scale einsum. The two were **never integrated**; we're ahead of NVIDIA on this. Key gaps we had to bridge:

### 1. wo_a: NVFP4 → FP8 + DeepGEMM Block Scale

**Problem**: `wo_a` uses `deepseek_v4_fp8_einsum` (BMM with DeepGEMM), which expects:

- Weight: `float8_e4m3fn` in 3D shape `(g, r, d)` for batched matmul
- Scale: DeepGEMM-formatted block scale tensor (not a per-tensor scalar)

Our NVFP4 weights are uint8 tensors with two FP4 (E2M1) values packed per byte, plus separate block and global scales.

**Solution** (`_convert_nvfp4_to_fp8`):

1. Unpack NVFP4 uint8 → BF16 using an E2M1 lookup table
2. Dequantize: `weight_bf16 * block_scale * global_scale` (NO `input_scale`; it applies to activations)
3. Re-quantize BF16 → FP8 e4m3 with a per-tensor scale (`w_amax / fp8_max`)
4. Create a block scale tensor filled with `fp8_scale` (same scale for every 128×128 block)
5. Call `deepgemm_post_process_fp8_weight_block(wq, ws, quant_block_shape=(128,128), use_e8m0=True, is_bmm=True, bmm_batch_size=N)`
6. Store: `weight_scale_inv = dg_ws` (DeepGEMM-formatted scale), `weight = w_fp8` (3D BMM shape)

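Steps 1-2 can be sketched in plain Python. This is a hypothetical illustration, not the patch's `_convert_nvfp4_to_fp8`: it assumes two E2M1 codes per byte with the low nibble holding the earlier element, bit 3 of each code as the sign, and one block scale per 16 consecutive values along the input dimension (the NVFP4 block size). The nibble order in particular is worth checking against the checkpoint, since a swapped order scrambles every weight.

```python
# Hypothetical sketch of steps 1-2, not the patch's _convert_nvfp4_to_fp8.
# E2M1 magnitudes indexed by the low 3 bits of a 4-bit code; bit 3 is the sign.
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
NVFP4_BLOCK = 16  # values sharing one block scale along the input dimension


def unpack_e2m1(packed):
    """Decode uint8 bytes into floats, two E2M1 codes per byte.

    Assumption: the low nibble holds the earlier element.
    """
    values = []
    for byte in packed:
        for code in (byte & 0x0F, byte >> 4):
            magnitude = E2M1_MAGNITUDES[code & 0x07]
            values.append(-magnitude if code & 0x08 else magnitude)
    return values


def dequant_nvfp4_row(packed_row, block_scales, global_scale):
    """weight = e2m1 * block_scale * global_scale; input_scale is deliberately absent."""
    values = unpack_e2m1(packed_row)
    if len(values) != len(block_scales) * NVFP4_BLOCK:
        raise ValueError("need exactly one block scale per 16 unpacked values")
    return [
        v * block_scales[i // NVFP4_BLOCK] * global_scale
        for i, v in enumerate(values)
    ]
```

For example, `unpack_e2m1([0x21])` gives `[0.5, 1.0]`: the low nibble `0x1` decodes to 0.5 and the high nibble `0x2` to 1.0, while `unpack_e2m1([0xF7])` gives `[6.0, -6.0]` because `0xF` has the sign bit set.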
**Why `weight_scale_inv`?** The attention forward reads `self.wo_a.weight_scale_inv` as `b_scale` for `deepseek_v4_fp8_einsum` → DeepGEMM `fp8_einsum`. This must be the DeepGEMM block-scale tensor, not a per-tensor scalar.

**Why `fp8_scale` in the block scale (not all-ones)?** DeepGEMM applies the block scale at runtime to dequantize the FP8 weights. With an all-ones block scale, the FP8 values are never rescaled by the per-tensor scale they were quantized with, so every output is off by a factor of `fp8_scale` and the model produces garbage. Each block needs the actual per-tensor scale value.

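The scale arithmetic in steps 3-4, and why an all-ones block scale breaks the output, can be checked with a small plain-Python sketch. It assumes `w_fp8 = w / fp8_scale` (so dequantization is `w_fp8 * fp8_scale`), ignores FP8 rounding, and uses hypothetical helper names; it is not the patch code.

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn value
DG_BLOCK = 128        # DeepGEMM weight block edge, as in quant_block_shape=(128, 128)


def per_tensor_fp8_scale(weights):
    """Step 3: fp8_scale = amax / fp8_max, guarded against an all-zero tensor."""
    amax = max(abs(w) for w in weights)
    return max(amax, 1e-12) / FP8_E4M3_MAX


def quantize_scaled(weights, scale):
    """Divide by the scale and clamp to the FP8 range (FP8 rounding is not modeled)."""
    return [max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, w / scale)) for w in weights]


def fill_block_scales(rows, cols, scale, block=DG_BLOCK):
    """Step 4: one entry per 128x128 block, every entry equal to the per-tensor scale."""
    n_row_blocks = math.ceil(rows / block)
    n_col_blocks = math.ceil(cols / block)
    return [[scale] * n_col_blocks for _ in range(n_row_blocks)]


weights = [0.5, -896.0, 224.0]
fp8_scale = per_tensor_fp8_scale(weights)       # 2.0
w_fp8 = quantize_scaled(weights, fp8_scale)     # [0.25, -448.0, 112.0]
dequant_ok = [v * fp8_scale for v in w_fp8]     # recovers weights
dequant_all_ones = [v * 1.0 for v in w_fp8]     # off by a factor of fp8_scale
```

For the BMM case the grid from `fill_block_scales` would presumably be repeated once per group, matching `bmm_batch_size`.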
### 2. Attention Layers: NVFP4 → BF16

**Problem**: `fused_wqa_wkv` and `wo_b` go through a standard `torch.nn.functional.linear`, which cannot consume NVFP4 weights (packed uint8) directly.

**Solution** (`_convert_nvfp4_to_bf16`):

1. Unpack NVFP4 → BF16
2. Dequantize with block/global scales (`input_scale` is for activations, not weights)
3. Replace `mod.weight` with a BF16 parameter
4. Set `quant_method = UnquantizedLinearMethod()`
5. Remove NVFP4 scale attributes (`weight_scale`, `weight_scale_2`, `input_scale`)

### 3. Compressor: Reconstructing fused_wkv_wgate from NVFP4

**Problem**: The compressor's `fused_wkv_wgate` is a `MergedColumnParallelLinear` with `disable_tp=True`. NVFP4 uint8 data can't be loaded into the BF16 parameter (shape mismatch: two FP4 values are packed per byte, so the uint8 tensor has half the input dimension). The default weight loader silently skips these weights, leaving the parameter uninitialized.

**Solution** (`_reconstruct_compressor_weight`):

1. Read the original `kv_proj.weight` and `gate_proj.weight` directly from the safetensors shards
2. Unpack NVFP4 → BF16 and dequantize with the scales
3. Concatenate: `fused = cat([wkv, wgate], dim=0)`
4. Replace the uninitialized parameter

**Critical detail**: The **indexer** compressor lives at a different checkpoint path:

- Main: `model.layers.N.self_attn.compressor.{kv_proj,gate_proj}.weight`
- Indexer: `model.layers.N.self_attn.compressor.indexer.{kv_proj,gate_proj}.weight`

Using the wrong prefix loads the main compressor weight into the indexer's `fused_wkv_wgate`, causing a 4× shape mismatch and a `split_with_sizes` crash.

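The key selection and fusion in `_reconstruct_compressor_weight` can be sketched as below. The function names and the `expected_rows` check (standing in for the split sizes later passed to `split_with_sizes`) are hypothetical; in the patch the tensors come from the safetensors shards and are torch tensors rather than nested lists.

```python
def compressor_keys(layer, sub_path=""):
    """Checkpoint keys for one compressor; pass sub_path=".indexer" for the indexer."""
    base = f"model.layers.{layer}.self_attn.compressor{sub_path}"
    return f"{base}.kv_proj.weight", f"{base}.gate_proj.weight"


def fuse_wkv_wgate(wkv, wgate, expected_rows):
    """Row-wise concat, i.e. cat([wkv, wgate], dim=0), with a shape check first.

    expected_rows = (kv_rows, gate_rows) stands in for the sizes the layer
    later passes to split_with_sizes (hypothetical parameter).
    """
    got = (len(wkv), len(wgate))
    if got != tuple(expected_rows):
        raise ValueError(f"got {got} rows, expected {tuple(expected_rows)}")
    return [list(row) for row in wkv] + [list(row) for row in wgate]
```

With the `sub_path` fix, `compressor_keys(N, ".indexer")` resolves to the `compressor.indexer.*` keys. Without it, both compressors read the main keys and the row check fails, which is the plain-Python analogue of the `split_with_sizes` crash.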
### 4. MoE Experts: NVFP4 FusedMoE

**Problem**: vLLM's DeepSeek V4 uses `DeepseekV4MegaMoEExperts` with DeepGEMM grouped GEMM. NVFP4 experts need a different kernel path.

**Solution**: The existing `ModelOptNvFp4LinearMethod` + `FusedMoE` infrastructure handles NVFP4 experts natively. We just need to:

- Keep expert weights as NVFP4 uint8 + block/global scales
- Use the `FLASHINFER_TRTLLM` MoE backend (auto-selected)
- Skip any conversion in `process_weights_after_loading`

### 5. BF16 wo_a Layers: BF16 → FP8

**Problem**: Some `wo_a` layers were NOT quantized by modelopt (BF16 in the checkpoint). The attention forward still reads them as FP8 for the einsum path.

**Solution** (`_convert_bf16_to_fp8`): Same as #1 but skip the NVFP4 unpack step; quantize BF16 → FP8 directly with a block scale.

## Bugs Found and Fixed

### DeepGEMM `sf.dim()` Assertion (layout.hpp:94)

- **Root cause**: `weight_scale_inv` was a 1D per-tensor scale `(g,)`. DeepGEMM expects a 2D/3D block-scale tensor formatted by `transform_sf_into_required_layout`.
- **Fix**: Use `deepgemm_post_process_fp8_weight_block` to produce correctly formatted block scales, and store the result in `weight_scale_inv`.

### Block Scale dtype (`float8_e4m3fn` vs `float32`)

- **Root cause**: `deepgemm_post_process_fp8_weight_block` expects `float32` or `float8_e8m0fnu` block scales. We initially used `float8_e4m3fn`.
- **Fix**: Create the block scale with `dtype=torch.float32`.

### Missing `deepgemm_post_process` args

- **Root cause**: The function signature changed to require `quant_block_shape` and `use_e8m0`.
- **Fix**: Pass `quant_block_shape=(128, 128)` and `use_e8m0=True`.

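`use_e8m0=True` refers to UE8M0 scale factors, an exponent-only 8-bit format that can only represent powers of two. The sketch below, under that assumption, rounds a scale up to the next power of two and shows why FP8 values must be quantized against the rounded scale, not the original one, for the product to reproduce the weight. It is plain arithmetic with a hypothetical helper name and makes no claim about what `deepgemm_post_process_fp8_weight_block` does internally with `use_e8m0=True`.

```python
import math


def round_scale_to_ue8m0(scale):
    """Round a positive scale up to a power of two, the only values UE8M0 encodes.

    Rounding up (not to nearest) keeps |w| / scale within the FP8 range.
    """
    if scale <= 0:
        raise ValueError("scale must be positive")
    return 2.0 ** math.ceil(math.log2(scale))


weight, fp8_scale = 3.0, 0.75
pow2_scale = round_scale_to_ue8m0(fp8_scale)      # 1.0
stale = (weight / fp8_scale) * pow2_scale         # 4.0: FP8 value from the old scale
requantized = (weight / pow2_scale) * pow2_scale  # 3.0: FP8 value from the rounded scale
```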
### Compressor Indexer Shape Mismatch

- **Root cause**: `_reconstruct_compressor_weight` used the same checkpoint prefix for both the main and indexer compressors. The indexer's keys have `.indexer.` in the path.
- **Fix**: Add a `sub_path` parameter; pass `".indexer"` for indexer compressors.

### All-Ones Block Scale → Garbage Output

- **Root cause**: The block scale was `torch.ones(...)` (scale = 1.0). DeepGEMM applies the block scale at runtime to dequantize the FP8 weights, so the outputs were never rescaled by the actual per-tensor scale, producing incoherent text.
- **Fix**: Use `torch.full(..., fp8_scale.item())` to fill the block scale with the correct per-tensor FP8 quantization scale.

## Running

```bash
# On B200 node
cd /root/nvidia-meeting
docker compose up -d

# Check logs
docker logs -f nvidia-meeting-vllm-1

# Test
curl http://localhost:8000/v1/models
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "/model", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 50}'
```

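For scripted checks, the curl test above can be reproduced with the Python standard library. This is a sketch assuming the server listens on `localhost:8000` and serves the model as `/model`, as in the curl example; the helper names and timeout values are arbitrary choices, not part of the project.

```python
import json
import time
import urllib.request

BASE_URL = "http://localhost:8000"  # port used by the curl example above


def build_chat_request(prompt, max_tokens=50, model="/model"):
    """Build the same POST as the curl example (OpenAI-compatible chat completions)."""
    body = json.dumps({
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
    }).encode("utf-8")
    return urllib.request.Request(
        f"{BASE_URL}/v1/chat/completions",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def wait_until_ready(timeout_s=1800, poll_s=10):
    """Poll /v1/models until the server answers; loading the weights takes a while."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            with urllib.request.urlopen(f"{BASE_URL}/v1/models", timeout=5) as resp:
                if resp.status == 200:
                    return True
        except OSError:  # connection refused, reset, or timed out while starting
            pass
        time.sleep(poll_s)
    return False


def smoke_test(prompt="Hello"):
    """Wait for the server, send one chat request, and return the reply text."""
    if not wait_until_ready():
        raise TimeoutError("vLLM server did not become ready")
    with urllib.request.urlopen(build_chat_request(prompt), timeout=300) as resp:
        reply = json.load(resp)
    return reply["choices"][0]["message"]["content"]
```

Call `smoke_test()` after `docker compose up -d`; printing the returned text is a quick way to see whether the output is still garbled.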
## Files

| File | Purpose |
|------|---------|
| `patches/deepseek_v4.py` | Main patch: NVFP4 post-load conversion, weight reconstruction, DeepGEMM block-scale |
| `patches/modelopt.py` | ModelOpt FP4 config patches for weight loading |
| `.env` | B200 node credentials |
| `docker-compose.yml` | Container config (8 GPU, TP=8, EP=8, NVFP4 quant) |

## Conversion Flow

```
Checkpoint (NVFP4 safetensors)
    │
    ├── [weight loader] ──→ vLLM model (NVFP4 uint8 params)
    │
    └── [process_weights_after_loading]
            ├── wo_a (is_bmm=True):
            │     NVFP4→BF16→FP8 + DeepGEMM block scale
            │     weight_scale_inv = dg_ws, weight = 3D FP8
            │
            ├── fused_wqa_wkv, wo_b, shared_expert:
            │     NVFP4→BF16, UnquantizedLinearMethod
            │
            ├── compressor.fused_wkv_wgate:
            │     Read kv_proj+gate_proj from checkpoint
            │     NVFP4→BF16, cat into fused weight
            │
            └── MoE experts: stay NVFP4 (FusedMoE backend)
```

## Bugs Found and Fixed (continued)

### `input_scale` Multiplied into Weight Dequantization (CRITICAL)

- **Root cause**: `_convert_nvfp4_to_bf16`, `_convert_nvfp4_to_fp8`, and `_reconstruct_compressor_weight` all multiplied by `input_scale` during weight dequantization. `input_scale` is for **activations**, not weights. The correct formula is `weight_bf16 = e2m1 * block_scale * global_scale` (NO `input_scale`). Including it made the weights ~5000× too small, causing garbage output.
- **Fix**: Removed `* input_scale` from all three dequant paths.

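Written per element, the corrected weight dequantization is shown below. This assumes the standard NVFP4 layout of one FP8 block scale per 16 consecutive values along the input dimension $j$ of output row $i$, with $q_{ij}$ the 4-bit code, $s^{\text{blk}}$ the block scale and $s^{\text{glob}}$ the per-tensor global scale:

```math
\hat{W}_{ij} = \operatorname{E2M1}(q_{ij}) \cdot s^{\text{blk}}_{i,\lfloor j/16 \rfloor} \cdot s^{\text{glob}}
```

`input_scale` has no place in this product: it belongs to the activation side of the NVFP4 GEMM, so folding it into $\hat{W}$ rescales every weight by the same constant factor.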
### `fused_skip_regex` Skipping Non-Fused Layer Scales (CRITICAL)

- **Root cause**: The skip list included the `q_b_proj`, `o_a_proj`, and `o_b_proj` weight scales. These are **NOT fused/stacked**: they are individual Linear layers (`wq_b`, `wo_a`, `wo_b`) converted in place. Skipping their scales caused `process_weights_after_loading` to read uninitialized `torch.empty()` memory as `weight_scale_inv`, producing garbled output.
- **Fix**: Removed the `q_b_proj`, `o_a_proj`, and `o_b_proj` scale entries from `fused_skip_regex`. Only truly stacked params remain skipped:
  - `compressor.{kv_proj,gate_proj}` → `fused_wkv_wgate`
  - `self_attn.{kv_proj,q_a_proj}` → `fused_wqa_wkv`
  - `shared_experts.{gate_proj,up_proj}` → `gate_up_proj`

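A hypothetical regex capturing the corrected skip list is sketched below; the real `fused_skip_regex` in `patches/deepseek_v4.py` may be written differently, and the example keys in the check only assume the checkpoint naming shown in this document.

```python
import re

# Hypothetical reconstruction: only scales of projections that vLLM stacks
# into one fused parameter are skipped; individually converted layers load.
FUSED_SKIP_RE = re.compile(
    r"\.(?:"
    r"compressor\.(?:kv_proj|gate_proj)"       # -> compressor.fused_wkv_wgate
    r"|self_attn\.(?:kv_proj|q_a_proj)"        # -> fused_wqa_wkv
    r"|shared_experts\.(?:gate_proj|up_proj)"  # -> gate_up_proj
    r")\.weight_scale(?:_2)?$"
)


def should_skip(key):
    """True if the loader should skip this scale key instead of loading it in place."""
    return FUSED_SKIP_RE.search(key) is not None
```

Keys such as `self_attn.o_a_proj.weight_scale` fall through to the normal loader, so `process_weights_after_loading` sees real scale values instead of `torch.empty()` memory.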
## Version Banner

The patch prints a version banner at import time (visible in `docker logs`):

```
======================================================================
DeepSeek V4 NVFP4 Patch
Commit: 26aaaba
Loaded: 2026-05-11 04:25:00 UTC
Node: ...

Architecture: ...
Bugs fixed: #1-#6
======================================================================
```

This ensures you can always verify what's running inside the container.

## Known Issues

1. **Output quality**: The model produces tokens, but they are garbled/incoherent. All known bugs listed above are fixed; the remaining issue is under investigation and is likely a subtle dequantization bug (sign handling, scale ordering, or an E2M1 unpack edge case). The version banner in the logs shows which patch version is active, which helps when debugging.

2. **Runtime performance**: Not yet benchmarked. The DeepGEMM einsum + FusedMoE path should be efficient on B200, but the BF16 layers go through `UnquantizedLinearMethod`, which may be slower than dedicated quantized kernels.

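One way to narrow down the remaining output-quality bug, sketched here as an assumption rather than the project's tooling: dequantize a single layer with the patch's code path, load the same layer from the BF16 reference checkpoint, and compare the two. The helper names are hypothetical and operate on flattened lists; on real tensors the same three numbers are straightforward to compute in torch.

```python
import math


def _norm(xs):
    return math.sqrt(sum(x * x for x in xs))


def compare_weights(test, ref):
    """Compare a flattened dequantized weight with the same weight from the BF16 reference."""
    if len(test) != len(ref):
        raise ValueError("shape mismatch between test and reference")
    diff = [t - r for t, r in zip(test, ref)]
    dot = sum(t * r for t, r in zip(test, ref))
    return {
        # ~0 when the dequantization is right.
        "rel_error": _norm(diff) / _norm(ref),
        # A constant scale bug (e.g. a stray input_scale) gives a ratio far from 1.
        "norm_ratio": _norm(test) / _norm(ref),
        # Near -1 hints at flipped signs; well below 1 hints at misplaced values
        # (e.g. wrong nibble order or block-scale indexing).
        "cosine": dot / (_norm(test) * _norm(ref)),
    }
```

A layer with `norm_ratio` near 1 but low `cosine` points at value placement (unpack order, scale indexing), while a `norm_ratio` far from 1 with `cosine` near 1 points at a missing or extra scale factor.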
## Quantization Details

- **Model**: DeepSeek V4 Pro (1.2T parameters)
- **Format**: NVIDIA NVFP4 (4-bit E2M1 floating point with an FP8 E4M3 scale per 16-element block, plus a per-tensor FP32 global scale)
- **Tool**: modelopt 0.45.0.dev64 + transformers 5.8.0.dev0
- **Run**: Run 11 (881 GB), 8× B200, ~$161/run
- **Checkpoint**: 95 safetensors shards