DeepSeek V4 Pro → NVFP4 Quantization + vLLM Serving
Full NVFP4 quantization of DeepSeek V4 Pro and vLLM serving on 8× NVIDIA B200 GPUs.
Quick Status
| Component | Status |
|---|---|
| NVFP4 Quantization | ✅ 881GB (Run 11), modelopt 0.45.0.dev64 |
| Weight Loading | ✅ 95 safetensors shards, all 8 TP ranks |
| Dequant Verification | ✅ Bit-exact match against official dequant (0.0 relative error) |
| NVFP4→FP8 Conversion (wo_a) | ✅ DeepGEMM block-scale format |
| NVFP4→BF16 Dequantization | ✅ 305 attn/shared, 91 compressor layers |
| Compressor Reconstruction | ✅ Separate kv_proj/gate_proj → fused_wkv_wgate |
| MoE Expert Serving (MegaMoE) | 🔧 Kernel builds & runs on sm_100a, debugging illegal CUDA access |
| Output Quality | 🔧 Under investigation |
B200 Node
- IP: 45.76.247.107
- User: root
- Password: see .env
- GPUs: 8× NVIDIA B200 (SM100a)
- Model weights: /root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4/
- BF16 reference: /root/nvidia-meeting/DeepSeek-V4-Pro-BF16/
Repositories
| Repo | Branch | Purpose |
|---|---|---|
| deepseek-v4-quant | modelopt-nvfp4 | Main repo: patches, quantize, serve scripts |
| DeepGEMM | nvfp4-mega-moe | NVFP4 mega_moe kernel fork |
Architecture
DeepSeek V4 Pro (1.2T params, 61 layers)
├── MLA Attention (61 layers)
│ ├── fused_wqa_wkv → BF16 (UnquantizedLinearMethod)
│ ├── wo_a → FP8 (DeepGEMM block-scale, BMM einsum)
│ ├── wo_b → BF16 (UnquantizedLinearMethod)
│ └── compressor.fused_wkv_wgate → BF16 (reconstructed from NVFP4)
├── MoE Experts (256 experts per layer, 61 layers)
│ └── MegaMoE path → NVFP4 (DeepGEMM mxf4nvf4, native block16)
└── Shared Expert → FP8 (Fp8LinearMethod, DeepGEMM)
NVFP4 Format (Confirmed)
| Field | Format | Notes |
|---|---|---|
| Weights | E2M1 packed uint8 | 2 values per byte |
| Block scales | torch.float8_e4m3fn (UE4M3) | Standard NVFP4 spec, group_size=16 |
| Global scales | torch.float32 (weight_scale_2) | Scalar per expert (torch.Size([])) |
| Dequant | value = packed_E2M1 * block_scale * global_scale | Block scale range [0, 448] |
Key finding: The checkpoint stores block scales as torch.float8_e4m3fn (UE4M3), NOT UE8M0.
.to(torch.float32) is the correct conversion. The shift-by-23 trick was wrong — it was
applying an E8M0→float conversion to E4M3 bytes, producing garbage.
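The gap between the two decodings can be checked byte by byte without a GPU. The following is a pure-Python sketch, not the repo's code: `e4m3fn_to_float` and `e8m0_shift23_to_float` are hypothetical helper names, and the first one follows the usual float8_e4m3fn layout (1 sign bit, 4 exponent bits, 3 mantissa bits, bias 7).

```python
import struct


def e4m3fn_to_float(byte: int) -> float:
    """Decode one float8_e4m3fn byte (sign 1, exponent 4, mantissa 3, bias 7)."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0xF and man == 0x7:
        return float("nan")  # the only NaN pattern; e4m3fn has no infinities
    if exp == 0:
        return sign * (man / 8.0) * 2.0 ** -6  # subnormal range
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)


def e8m0_shift23_to_float(byte: int) -> float:
    """The shift-by-23 trick: drop the byte into the float32 exponent field."""
    bits = (byte & 0xFF) << 23
    return struct.unpack("<f", struct.pack("<I", bits))[0]


# The same byte read both ways: 0x38 is 1.0 as E4M3 but 2**-71 as E8M0.
for b in (0x38, 0x7E):
    print(hex(b), e4m3fn_to_float(b), e8m0_shift23_to_float(b))
```

Reading 0x7E as E4M3 gives 448, the top of the block-scale range listed in the table, which is a quick sanity check that the bytes are E4M3 and not exponent-only.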
Dequant Verification
We verified the dequant path is bit-exact against the official reference:
W_bf16 = dequantize_fp4_weight(W_int, S)
y_ours = W_bf16 @ x.bfloat16()
y_ref = official_expert_forward(W_int, S, x)
print((y_ours - y_ref).abs().max() / y_ref.abs().mean())
Result:
Max abs diff: 0.00000000
Mean abs diff: 0.00000000
Relative error: 0.000000
Matmul max diff: 0.00000000
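For readers who want to see what the dequant formula does to a single block, here is a minimal pure-Python sketch. It is not the `dequantize_fp4_weight` used above: `decode_e2m1` and `dequant_nvfp4_block` are hypothetical names, the block scale is passed in already converted to float, and the nibble order (low nibble holds the even element) is an assumption that should be checked against the checkpoint layout.

```python
# E2M1 magnitudes indexed by the low 3 bits of a nibble; bit 3 is the sign.
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def decode_e2m1(nibble: int) -> float:
    """Map a 4-bit E2M1 code to its float value."""
    mag = E2M1_MAGNITUDES[nibble & 0x7]
    return -mag if nibble & 0x8 else mag


def dequant_nvfp4_block(packed: bytes, block_scale: float, global_scale: float) -> list:
    """Dequantize one 16-element block stored as 8 packed bytes.

    Assumption: the low nibble of each byte is the even element and the
    high nibble is the odd element.
    """
    if len(packed) != 8:
        raise ValueError("a group_size=16 block is exactly 8 packed bytes")
    out = []
    for byte in packed:
        for nibble in (byte & 0xF, byte >> 4):
            out.append(decode_e2m1(nibble) * block_scale * global_scale)
    return out


block = bytes([0x21, 0xF7] + [0x00] * 6)
print(dequant_nvfp4_block(block, block_scale=2.0, global_scale=0.5)[:4])
```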
Running
1. Quantize
# On B200 node, in screen
screen -S quantize
cd /root/nvidia-meeting
bash run_quantize_nvfp4.sh
# ~7 hours, $161 per run
2. Build Container
# From this repo
bash build_push.sh
# Always build in screen: screen -S build
The Dockerfile:
- Extends atl.vultrcr.com/vllm/vllm-with-lmcache:dream-build
- Clones DeepGEMM (nvfp4-mega-moe branch) and builds
- Copies patches/deepseek_v4.py over vLLM's model file
3. Serve
# On B200 node
cd /root/nvidia-meeting
docker compose up -d
# Check logs
docker logs -f nvidia-meeting-vllm-1
# Test
curl http://localhost:8000/v1/models
curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{"model": "/model", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 50}'
vLLM Flags
--trust-remote-code
--kv-cache-dtype fp8
--block-size 256
--enable-expert-parallel
--tensor-parallel-size 8
--compilation-config {"cudagraph_mode":"FULL_AND_PIECEWISE","custom_ops":["all"]}
--attention_config.use_fp4_indexer_cache=True
--tokenizer-mode deepseek_v4
--tool-call-parser deepseek_v4
--enable-auto-tool-choice
--reasoning-parser deepseek_v4
--speculative_config {"method":"mtp","num_speculative_tokens":2}
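Two of the flags above carry JSON, which is easy to mangle when it passes through a shell. The sketch below builds the argument list in Python so each JSON value stays a single argument. It assumes the server is started with `vllm serve /model` (the model path is taken from the curl example; the launch command itself is an assumption about how docker-compose.yml invokes vLLM), and `build_vllm_args` is a hypothetical helper, not part of this repo.

```python
import json
import shlex


def build_vllm_args(model_path: str = "/model") -> list:
    """Return an argv list carrying the flags listed above."""
    compilation = {"cudagraph_mode": "FULL_AND_PIECEWISE", "custom_ops": ["all"]}
    speculative = {"method": "mtp", "num_speculative_tokens": 2}
    compact = (",", ":")  # no spaces, matching the flag list above
    return [
        "vllm", "serve", model_path,  # assumed launch command
        "--trust-remote-code",
        "--kv-cache-dtype", "fp8",
        "--block-size", "256",
        "--enable-expert-parallel",
        "--tensor-parallel-size", "8",
        "--compilation-config", json.dumps(compilation, separators=compact),
        "--attention_config.use_fp4_indexer_cache=True",
        "--tokenizer-mode", "deepseek_v4",
        "--tool-call-parser", "deepseek_v4",
        "--enable-auto-tool-choice",
        "--reasoning-parser", "deepseek_v4",
        "--speculative_config", json.dumps(speculative, separators=compact),
    ]


# shlex.join adds the quoting a shell needs around the JSON values.
print(shlex.join(build_vllm_args()))
```

Passing the list straight to `subprocess.run` avoids shell quoting entirely; the printed form is only useful when pasting into a terminal or a compose file.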
NVFP4 Mega MoE Kernel
What We Built
A native NVFP4 mega_moe kernel in our DeepGEMM fork. Weights stay in E2M1 packed format
and use kind::mxf4nvf4.block_scale.scale_vec::4X MMA directly on SM100a (B200).
To our knowledge this is novel: we are not aware of an existing NVFP4→vLLM integration from NVIDIA for this path.
Kernel Architecture
| Parameter | Value |
|---|---|
| PTX instruction | tcgen05.mma.kind::mxf4nvf4.block_scale.scale_vec::4X |
| kGranK | 16 (NVFP4 native block_size) |
| Weight format | E2M1 packed uint8 (unchanged from checkpoint) |
| Block scales | UE4M3 (float8_e4m3fn), native — no conversion needed |
| Global scales | Folded into block scales before packing |
| Instruction desc | float_ue4m3_t |
| SF layout | block16, scale_vec::4X |
| UTCCP stride | i*8 (4X layout) |
| kNumSFUint32 | kHidden / 64 (4 UE4M3 per int32) |
| recipe | (1, 1, 16) |
| Target arch | sm_100a (the a suffix is required) |
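The size relations in the table follow from two facts: two E2M1 values share a byte, and one UE4M3 scale covers 16 values, so four scales (64 values) fit in one int32. A small sketch of that arithmetic is below; `nvfp4_buffer_sizes` is a hypothetical helper, and the hidden size 7168 is only an illustrative number, not a value read from the checkpoint.

```python
def nvfp4_buffer_sizes(hidden: int, group_size: int = 16) -> dict:
    """Per-row buffer sizes for a K dimension of length `hidden`."""
    if hidden % (4 * group_size) != 0:
        raise ValueError("hidden must be a multiple of 64 to fill whole int32 words")
    num_groups = hidden // group_size
    return {
        "packed_e2m1_bytes": hidden // 2,   # 2 values per byte
        "block_scales": num_groups,         # 1 UE4M3 scale per 16 values
        "sf_uint32_words": num_groups // 4, # 4 scales per int32 == hidden // 64
    }


print(nvfp4_buffer_sizes(7168))  # 7168 is an illustrative hidden size
```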
Python API
- fp8_nvfp4_mega_moe(): entry point, recipe=(1,1,16)
- transform_nvfp4_weights_for_mega_moe(): fold global scales, pack UE4M3→int32, TMA-align
- get_symm_buffer_for_nvfp4_mega_moe(): 2× SF buffer vs MXFP4
C++ Bindings
- csrc/apis/mega_nvfp4.hpp: kGranK=16, SF stride K/16, packed E2M1 hidden/2
- csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp: host-side TMA descriptors
- deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh: kernel
Full FP4 Pipeline
The mxf4nvf4 instruction is FP4×FP4 — both activations (A) and weights (B) must be E2M1 packed.
A Triton staging kernel quantizes BF16 activations → E2M1 packed uint8 + UE4M3 block16 scales
before the GEMM. The L1 epilogue outputs UE4M3 activation scales directly (float→e4m3 cast).
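What the staging step computes for one block of 16 activations can be sketched on the CPU. This is a simplified pure-Python illustration, not the Triton kernel: `encode_e2m1` and `quantize_block16` are hypothetical names, the scale is left as a Python float (the real kernel rounds it to UE4M3), ties round toward the smaller magnitude (which may differ from the hardware rounding mode), and the low-nibble-first packing order is an assumption.

```python
E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def encode_e2m1(x: float) -> int:
    """Round x to the nearest E2M1 magnitude; return the 4-bit code (bit 3 = sign)."""
    mag = min(abs(x), 6.0)  # saturate at the largest representable magnitude
    code = min(range(8), key=lambda i: abs(E2M1_GRID[i] - mag))
    return code | (0x8 if x < 0 else 0x0)


def quantize_block16(values):
    """Quantize 16 floats to 8 packed bytes plus one block scale."""
    if len(values) != 16:
        raise ValueError("NVFP4 blocks hold exactly 16 values")
    amax = max(abs(v) for v in values)
    scale = min(amax / 6.0, 448.0)  # simplification: not rounded to UE4M3 here
    inv = 1.0 / scale if scale > 0 else 0.0
    codes = [encode_e2m1(v * inv) for v in values]
    # Even element in the low nibble, odd element in the high nibble (assumed order).
    packed = bytes(codes[i] | (codes[i + 1] << 4) for i in range(0, 16, 2))
    return packed, scale


packed, scale = quantize_block16([6.0, -3.0, 1.5, 0.0] + [0.0] * 12)
print(packed.hex(), scale)
```

The output is 8 bytes for 16 inputs, which is the BLOCK_K//2 write size the packed buffer expects.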
Bugs Found and Fixed
| # | Bug | Impact | Fix |
|---|---|---|---|
| 1 | DeepGEMM sf.dim() crash | Server crash | deepgemm_post_process_fp8_weight_block for block-scale format |
| 2 | Block scale dtype float8_e4m3fn | Crash | Use float32 for block-scale tensor |
| 3 | Missing deepgemm_post_process args | Crash | Pass quant_block_shape, use_e8m0 |
| 4 | Compressor indexer shape mismatch | Crash | .indexer. sub-path in checkpoint keys |
| 5 | All-ones block scale | Garbage output | torch.full(..., fp8_scale) not torch.ones |
| 6 | fused_skip_regex skipping q_b/o_a/o_b scales | Garbage output | Remove non-fused scale entries from skip list |
| 7 | UE8M0 shift-by-23 applied to E4M3 scales | Garbled output | Checkpoint is standard UE4M3 — use .to(float32) (shift-by-23 was wrong) |
| 8 | wo_a BF16→NVFP4 on-the-fly used UE8M0 encoding | Scrambled attention | Produce UE4M3 directly: .clamp(0, 448).to(float8_e4m3fn) |
| 9 | FP8 activations fed to mxf4nvf4 (FP4×FP4 instruction) | Crash/garbled | Full FP4 pipeline: activations are E2M1 packed + UE4M3 scales |
| 10 | Staging kernel SF pack: shift ≥32 is UB | Half the activation scales zeroed | Split into 2 int32 writes per k_block (groups 0-3, 4-7) |
| 11 | Staging kernel wrote unpacked E2M1 (1 byte/elem) into packed buffer | 2× buffer overflow | Pack even/odd nibble pairs, write BLOCK_K//2 bytes |
| 12 | compute-sanitizer build running during debug | Slow (50×), masking timing | Remove sanitizer, rebuild |
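Bug 10 is easy to reproduce on paper. Eight one-byte scale factors need shifts of 0 through 56 bits if they go into a single accumulator, and a 32-bit integer cannot be shifted by 32 or more. The sketch below shows the two-word split in plain Python; `pack_sf_words` and `pack_sf_one_word_truncated` are hypothetical names, the byte order within a word (group 0 in the lowest byte) is an assumption, and the second function models only one possible outcome of the undefined shift.

```python
def pack_sf_words(scales: bytes) -> tuple:
    """Pack 8 scale bytes into two 32-bit words: groups 0-3, then groups 4-7."""
    if len(scales) != 8:
        raise ValueError("expected 8 scale bytes per k_block")
    lo = 0
    hi = 0
    for g in range(4):
        lo |= scales[g] << (8 * g)      # shifts 0, 8, 16, 24
        hi |= scales[g + 4] << (8 * g)  # same shifts, never 32 or more
    return lo & 0xFFFFFFFF, hi & 0xFFFFFFFF


def pack_sf_one_word_truncated(scales: bytes) -> int:
    """Model of the broken pack: one 32-bit accumulator for all 8 groups.

    The real result of an oversized shift is undefined; this models only
    the outcome where every bit past bit 31 is lost.
    """
    word = 0
    for g in range(8):
        word |= scales[g] << (8 * g)
    return word & 0xFFFFFFFF


scales = bytes([1, 2, 3, 4, 5, 6, 7, 8])
print([hex(w) for w in pack_sf_words(scales)])
print(hex(pack_sf_one_word_truncated(scales)))  # groups 4-7 are gone
```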
Files
| File | Purpose |
|---|---|
| patches/deepseek_v4.py | Main patch: NVFP4 weight loading, dequant, staging kernel, MegaMoE |
| patches/staging_kernel.py | Reference copy of Triton staging kernel (live copy is in deepseek_v4.py) |
| scripts/dequant_fp8_to_bf16.py | BF16 dequantization utility |
| scripts/quantize_nvfp4.py | NVFP4 quantization runner |
| scripts/serve_vllm.py | Standalone vLLM server launcher |
| Dockerfile | Container build (extends dream-build with DeepGEMM + patch) |
| docker-compose.yml | Production serve config |
| build_push.sh | Build, push to CR, update docker-compose |
HARD RULES
- NEVER convert DeepSeek MoE experts to MXFP4. Experts stay in NVFP4. Period.
- The checkpoint is UE4M3 (float8_e4m3fn), NOT UE8M0. Never use shift-by-23 on these bytes.
- Target sm_100a, not sm_100. The a suffix is required for mxf4nvf4 instructions.