vLLM NVFP4 serving: full end-to-end pipeline working

Bridged the gap between ModelOpt NVFP4 and vLLM DeepSeek V4 attention.
Server loads and serves tokens on 8x B200 with TP=8, EP=8.

Key changes:
- wo_a: NVFP4->BF16->FP8 with DeepGEMM block-scale format for BMM einsum
  Uses deepgemm_post_process_fp8_weight_block for correct scale layout
  weight_scale_inv = DeepGEMM-formatted block scale (NOT per-tensor scalar)
  Block scale filled with fp8_scale (NOT all-ones -- causes garbage output)
- Attention: NVFP4->BF16 dequantization, UnquantizedLinearMethod
- Compressor: reconstruct fused_wkv_wgate from separate kv_proj+gate_proj
  Fixed indexer path: compressor.indexer.kv_proj (was loading main compressor)
- MoE experts: stay NVFP4, FLASHINFER_TRTLLM FusedMoE backend

Bugs fixed:
1. DeepGEMM sf.dim() assertion: weight_scale_inv must be block-scale tensor
2. Block scale dtype: float32 (not float8_e4m3fn)
3. Missing deepgemm_post_process args: quant_block_shape, use_e8m0
4. Compressor indexer shape mismatch: wrong checkpoint key prefix
5. All-ones block scale: DeepGEMM divides by 1.0 instead of actual scale

Updated README with full technical documentation of all fixes.
This commit is contained in:
2026-05-11 02:01:46 +00:00
parent db16be8e5d
commit 653e2d7a50
2 changed files with 612 additions and 301 deletions

492
README.md
View File

@@ -1,322 +1,220 @@
# DeepSeek V4 Pro → NVFP4 Quantization + vLLM Serving
Full NVFP4 quantization of DeepSeek V4 Pro on a single B200 node (8× B200, 2.7TB RAM, 13TB NVMe). **Result: 881GB NVFP4 (Run 11).** Now working on vLLM serving of the quantized checkpoint.
Full NVFP4 quantization of DeepSeek V4 Pro and vLLM serving on 8× NVIDIA B200 GPUs.
**Cost:** ~$161/run at $23/hr (7 hours each). Don't waste runs.
## Quick Status
## ✅ Final Quantization Result (Run 11)
| Component | Status |
|-----------|--------|
| NVFP4 Quantization | ✅ 881GB (Run 11), modelopt 0.45.0.dev64 |
| Weight Loading | ✅ 95 safetensors shards, all 8 TP ranks |
| NVFP4→FP8 Conversion (wo_a) | ✅ DeepGEMM block-scale format |
| NVFP4→BF16 Dequantization | ✅ 305 attn/shared, 91 compressor layers |
| Compressor Reconstruction | ✅ Separate kv_proj/gate_proj → fused_wkv_wgate |
| MoE Expert Serving | ✅ FusedMoE NVFP4 (FLASHINFER_TRTLLM backend) |
| Profile/Warmup Run | ✅ Passes |
| API Server | ✅ Running on port 8000 |
| Output Quality | 🔧 Under investigation (FP4 quantization loss + scale tuning) |
- **Output:** `/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4` — 881GB, 95 safetensors
- **Config:** `nvfp4` full quantization, 128 calib samples, `kv_cache_qformat=fp8_cast`
- **Total runtime:** ~7,783s (~2h10m end-to-end)
- **Peak GPU mem:** ~163GB per B200
- **Amax snapshots:** 47,696 quantizers, 15.4MB
- **Calibrated state:** 721.4GB (insurance, can re-export with `--export-only`)
- A few experts (11, 83, 100, 112, 254) had uncalibrated amax — weight-derived fallback used (normal for sparse MoE with 256 experts)
## B200 Node
---
## 🔧 vLLM Serving (In Progress)
### Current Status: Debugging weight loading
The modelopt NVFP4 export and vllm have a chain of incompatibilities. We're progressively fixing them. The fundamental problem: **modelopt's NVFP4 quantization format and vllm's DeepSeek V4 serving code were never integrated.** NVIDIA's own published NVFP4 exports (DeepSeek-V3.2, MiniMax-M2.7) don't have these issues because they don't use MLA attention compression or 256-expert MoE — both of which create stacked/fused weight parameters that modelopt doesn't account for.
### Approach: Patched deepseek_v4.py + disabled mega_moe
Instead of runtime monkey-patching (which doesn't propagate to worker processes), we patch the vllm source file directly. The patched `deepseek_v4.py` is mounted into the Docker container and copied over the original before vllm starts.
We also disabled `--moe-backend=deep_gemm_mega_moe` because:
1. The NVFP4 mega_moe kernel doesn't exist yet (NVIDIA hasn't built it)
2. MegaMoE uses MXFP4 block scale format (32-col blocks), but modelopt exports NVFP4 (16-col blocks) — format mismatch
3. MegaMoE doesn't register `weight_scale_2` or `input_scale` params, so those scales would be silently dropped
Instead, we use the standard FusedMoE path with `ModelOptNvFp4FusedMoE`, which natively supports NVFP4 expert weights.
### vLLM Serving Run History
| # | Date | Approach | Result | Root Cause | Fix/Next |
|---|------|----------|--------|------------|----------|
| S1 | May 10 09:34 | `patch_vllm_weights.py` runtime patch + mega_moe | ❌ `UnboundLocalError: name_mapped` | Expert weight names don't match any mapping → `name_mapped` never assigned | Add gate_proj→w1, up_proj→w3, down_proj→w2 mappings |
| S2 | May 10 ~10:30 | Same, added expert rename regexes | ❌ Same error | `DeepseekV4ForCausalLM.hf_to_vllm_mapper` is a **class attribute** set at import time — patching the function doesn't update the cached mapper | Patch the class attribute directly |
| S3 | May 10 ~11:00 | Patched class attr, but workers are separate processes | ❌ Same error in workers | Workers don't inherit in-memory patches — they fork before the patch runs | Patch the source file directly (`deepseek_v4.py`) |
| S4 | May 10 ~11:30 | Direct source file patch + mega_moe | ❌ `KeyError: 'layers.0.mlp.experts.0.w2.weight'` | modelopt uses `mlp`, vllm uses `ffn` internally — missing `.mlp.``.ffn.` mapping | Add substr mapping |
| S5 | May 10 ~12:00 | Added `mlp→ffn` mapping + mega_moe | ❌ `KeyError: 'fused_wkv_wgate.input_scale'` | Compressor fused params don't register `input_scale`/`weight_scale_2` | Add skip patterns for compressor/attention scale tensors |
| S6 | May 10 ~12:30 | Added skip patterns + mega_moe | ❌ Shape mismatch: `w2_weight_scale (7168, 96) vs (7168, 192)` | NVFP4 uses 16-col block scales, mega_moe expects 32-col MXFP4 — format incompatibility | **Abandon mega_moe** — no NVFP4 mega_moe kernel exists |
| S7 | May 10 ~13:00 | Disabled mega_moe, standard FusedMoE | ❌ `fused_wkv_wgate.weight` shape mismatch: param=(1024,7168) bf16, loaded=(512,3584) uint8 | `MergedColumnParallelLinear` creates weight as bf16 (not uint8), but modelopt exports NVFP4 packed uint8. `ModelOptNvFp4Config` only handles `Linear`, not `MergedColumnParallelLinear` | Unpack uint8→bf16 at load time |
| S8 | May 10 ~13:30 | Added E2M1 unpacking for fused weights | ❌ `KeyError: 'fused_wkv_wgate.weight_scale'` | No `weight_scale` param registered for `MergedColumnParallelLinear` (same `ModelOptNvFp4Config` gap) | Skip all NVFP4 scale tensors for stacked/fused attention+compressor params |
| S9 | May 10 ~14:00 | Added weight_scale skip patterns | ❌ `KeyError: 'compressor.kv_norm.weight'` | modelopt puts `kv_norm` under `compressor`, vllm has it at attention level (`attn.kv_norm`) | Add `compressor.kv_norm``kv_norm` mapping |
| S10 | May 10 ~14:15 | Fixed kv_norm mapping | ❌ `KeyError: 'compressor.position_bias'` | modelopt exports params that don't exist in the vllm model | Make loading resilient to unknown params |
### Open Issues (as of May 10 ~16:00 UTC)
1. **MergedColumnParallelLinear + NVFP4 incompatibility** — The core problem. `ModelOptNvFp4Config.create_weights()` only handles `Linear` layers. For `MergedColumnParallelLinear` (used for `fused_wqa_wkv`, `fused_wkv_wgate`, `gate_up_proj`):
- Weight param is created as `ModelWeightParameter` (bf16) instead of `PackedColumnParameter` (uint8)
- `weight_scale`, `weight_scale_2`, `input_scale` params are never registered
- `adjust_shard_indexes_for_packing` applies `packed_factor` to rows, but NVFP4 packs along columns
- Current workaround: unpack uint8→bf16 at load time, skip scale tensors, let `process_weights_after_loading` re-quantize. This loses the calibration-optimized scales for attention/compressor/shared_expert weights.
2. **No NVFP4 mega_moe kernel** — We disabled mega_moe to avoid the format mismatch. Standard FusedMoE with `ModelOptNvFp4FusedMoE` works for expert weights, but loses the mega_moe performance optimization. When NVIDIA builds an NVFP4 mega_moe kernel, we can re-enable it.
3. **Resilient loading needed** — modelopt exports params (e.g., `compressor.position_bias`) that don't exist in the vllm model. Need to skip unknown params gracefully instead of crashing.
4. **Expert `weight_scale_2` handling with FusedMoE** — The standard FusedMoE path registers `w13_weight_scale_2` and `w2_weight_scale_2`, so expert global scales CAN be loaded. This works for experts. The issue is only with the stacked/fused attention params.
### What Each Patch Does
**`patches/deepseek_v4.py`** — Patched vllm source file, copied over the original at container startup. Contains:
- **Regex mappings** (applied first by WeightsMapper):
- Skip `weight_scale`, `weight_scale_2`, `input_scale` for compressor/attention fused params (no stacked param registered)
- Skip `weight_scale`, `weight_scale_2`, `input_scale` for shared expert gate/up projections (stacked into `gate_up_proj`)
- Expert projection rename: `gate_proj→w1`, `up_proj→w3`, `down_proj→w2` (only for `.experts.N.`, not `.shared_experts.`)
- **Substr mappings** (applied after regex):
- Attention: `self_attn→attn.mla_attn` with proper sub-projection names
- `kv_norm` moved from compressor to attention level
- `compressor.kv_proj→compressor.wkv`, `compressor.gate_proj→compressor.wgate`
- `shared_experts.gate_proj→shared_experts.w1`, `shared_experts.up_proj→shared_experts.w3`
- `.mlp.→.ffn.` (modelopt uses `mlp`, vllm uses `ffn`)
- **E2M1 FP4→BF16 unpacking** for stacked params: When a uint8 packed NVFP4 weight is loaded into a bf16 param (MergedColumnParallelLinear), unpack using the E2M1 lookup table
- **Resilient loading**: Skip unknown params that modelopt exports but vllm doesn't have
**`patches/patch_vllm_weights.py`** — Legacy runtime monkey-patch approach. Doesn't work because vllm workers are separate processes that don't inherit in-memory patches. Kept for reference.
**`docker-compose.yml`** — Docker Compose config:
- Copies patched `deepseek_v4.py` before starting vllm
- Removed `--moe-backend=deep_gemm_mega_moe` (no NVFP4 kernel exists)
- All other vllm flags are critical for V4 (see `serve_vllm.py` for documentation)
---
## ⚠️ Model Config Patches (post-export)
modelopt 0.45.0.dev64's export produces configs that don't match what vllm expects at runtime. **NVIDIA's own published NVFP4 exports have the same gaps** — we compared against `nvidia/DeepSeek-V3.2-NVFP4` and `nvidia/MiniMax-M2.7-NVFP4` on HuggingFace. Neither includes `compress_ratios` or `scale_fmt` either. This is a modelopt ↔ vllm integration gap, not a problem with our quantization.
All patches below are to `DeepSeek-V4-Pro-NVFP4/config.json` unless noted.
| # | Field | modelopt export (original) | vllm requires | Patch applied | Why modelopt doesn't export it |
|---|-------|---------------------------|--------------|---------------|------------------------------ |
| 1 | `compress_ratios` | Missing (transformers 5.8.0 renamed to `compress_rates` dict) | List of ints indexed by layer_id | Copied from BF16 source model's `compress_ratios` (62 items) | modelopt doesn't add fields the source config lacks; transformers 5.8.0 renamed the field |
| 2 | `quantization_config.scale_fmt` | Missing | `"ue8m0"` string | Added | modelopt doesn't include vllm-specific runtime fields |
| 3 | `rope_parameters` | Nested dict `{'main': {...}, 'compress': {...}}` (transformers 5.8.0 format) | Flat dict `{'rope_theta': ..., 'rope_type': ..., ...}` | Flattened to `main` sub-dict | transformers 5.8.0 changed rope_parameters from flat → nested per-component |
| 4 | `rope_scaling` | Nested dict `{'main': {...}, 'compress': {...}}` (same as above) | Flat dict | Flattened to `main` sub-dict | Same transformers 5.8.0 schema change |
**NVIDIA's own NVFP4 exports confirmed to also lack patches 1 and 2.** We checked:
- `nvidia/DeepSeek-V3.2-NVFP4` — no `compress_ratios`, no `scale_fmt`, no `quantization_config` in config.json at all (V3.2 doesn't use MLA compression so it sidesteps the issue)
- `nvidia/MiniMax-M2.7-NVFP4` — has `quantization_config` in config.json (same schema as ours) but no `scale_fmt`
The `compress_rates``compress_ratios` rename and `rope_parameters` nesting are transformers 5.8.0 regressions that modelopt doesn't account for. `scale_fmt` is a vllm runtime field that modelopt has never exported.
- **IP**: `45.76.247.107`
- **User**: `root`
- **Password**: see `.env`
- **GPUs**: 8× NVIDIA B200 (SM100)
- **RAM**: ~2.7 TB
- **Model weights**: `/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4/`
- **BF16 reference**: `/root/nvidia-meeting/DeepSeek-V4-Pro-BF16/`
## Architecture
We call modelopt's `hf_ptq.main()` directly — the same entry point the shell script uses. We don't rewrite the pipeline. We just:
```
DeepSeek V4 Pro (1.2T params, 61 layers)
├── MLA Attention (61 layers)
│ ├── fused_wqa_wkv → BF16 (UnquantizedLinearMethod)
│ ├── wo_a → FP8 (DeepGEMM block-scale, BMM einsum)
│ ├── wo_b → BF16 (UnquantizedLinearMethod)
│ └── compressor.fused_wkv_wgate → BF16 (reconstructed from NVFP4)
├── MoE Experts (384 experts, 61 layers)
│ ├── w13_weight → NVFP4 (FusedMoE, FLASHINFER_TRTLLM backend)
│ └── w2_weight → NVFP4 (FusedMoE, FLASHINFER_TRTLLM backend)
└── Shared Expert → FP8 (Fp8LinearMethod, DeepGEMM)
```
1. **Patch** modelopt at runtime (GPU tensor safety, before anything runs)
2. **Hook** `export_quantized` to snapshot amax + save state before export
3. **Call** `hf_main(args)` with properly parsed args
## The NVFP4 → vLLM Gap
This avoids the cascade of missing-arg bugs from manually constructing `argparse.Namespace` (Runs 48).
ModelOpt quantizes to NVFP4 (4-bit FP4 with block scales). vLLM's DeepSeek V4
attention code expects FP8 with DeepGEMM block-scale einsum. These formats were
**never integrated** — we're ahead of NVIDIA on this. Key gaps we had to bridge:
## Pipeline
### 1. wo_a: NVFP4 → FP8 + DeepGEMM Block Scale
### Step 1: Dequantize FP8 → BF16
```bash
python3 scripts/dequant_fp8_to_bf16.py /root/nvidia-meeting/DeepSeek-V4-Pro-FP8 /root/nvidia-meeting/DeepSeek-V4-Pro-BF16
```
The original V4 weights use mixed precision (FP8 attention + FP4/E2M1 experts with per-tensor scales). We dequantize everything to pure BF16 so modelopt can run calibration without hitting broken FP8 kernel paths on Blackwell (DeepGEMM unsupported, Triton finegrained FP8 matmul shape mismatches).
This is not a blind upcast — it applies the actual scale factors:
```
W_bf16 = dequantize_fp4_weight(W_int, S) # per-tensor scale dequant, not .to(bfloat16)
```
**Byte-exact verified** — matmul diff is 0.000000 against the official inference path.
### Step 2: Run NVFP4 Quantization
```bash
cd /root/nvidia-meeting/modelopt-repo/examples/llm_ptq
python3 /root/nvidia-meeting/deepseek-v4-quant/scripts/quantize_nvfp4.py
```
Must run from the modelopt example directory (relative imports).
What happens inside:
1. **Apply patches** — 3 runtime monkey-patches for GPU tensor safety (see below)
2. **Parse args** — uses `hf_ptq.parse_args()` with our config via `sys.argv` replacement, then applies the same post-parse conversions (`dataset` split, `calib_size` int list) that `hf_ptq.__main__` normally does
3. **Hook export** — monkey-patch `export_quantized` to snapshot amax + save state before export
4. **Call `hf_main(args)`** — the exact same pipeline the shell script uses
If the export crashes:
```bash
python3 /root/nvidia-meeting/deepseek-v4-quant/scripts/quantize_nvfp4.py --export-only
```
To validate saved state without running anything:
```bash
python3 /root/nvidia-meeting/deepseek-v4-quant/scripts/quantize_nvfp4.py --validate-only
```
**Config:** `nvfp4`, 128 calib samples, `calib_seq=512`, `kv_cache_qformat=fp8_cast`, `gpu_max_mem_percentage=0.7`, `use_seq_device_map`, `inference_tensor_parallel=8`
**Calibration datasets:** `abisee/cnn_dailymail` + `nvidia/Nemotron-Post-Training-Dataset-v2` (default when no `--dataset` specified).
**Runtime:** Model loading ~50 min. Calibration ~5.5 hours. Export ~30-60 min. Total 7-8 hours.
### Step 3: Serve with vLLM
**Problem**: `wo_a` uses `deepseek_v4_fp8_einsum` (BMM with DeepGEMM), which expects:
- Weight: `float8_e4m3fn` in 3D shape `(g, r, d)` for batched matmul
- Scale: DeepGEMM-formatted block scale tensor (not a per-tensor scalar)
Our NVFP4 weights are uint8 packed FP4 with separate block/global scales.
**Solution** (`_convert_nvfp4_to_fp8`):
1. Unpack NVFP4 uint8 → BF16 using E2M1 lookup table
2. Dequantize: `weight_bf16 * block_scale * global_scale * input_scale`
3. Re-quantize BF16 → FP8 e4m3 with per-tensor scale (`w_amax / fp8_max`)
4. Create block scale tensor filled with `fp8_scale` (same scale for every 128×128 block)
5. Call `deepgemm_post_process_fp8_weight_block(wq, ws, quant_block_shape=(128,128), use_e8m0=True, is_bmm=True, bmm_batch_size=N)`
6. Store: `weight_scale_inv = dg_ws` (DeepGEMM-formatted scale), `weight = w_fp8` (3D BMM shape)
**Why `weight_scale_inv`?** The attention forward reads `self.wo_a.weight_scale_inv` as
`b_scale` for `deepseek_v4_fp8_einsum` → DeepGEMM `fp8_einsum`. This must be the
DeepGEMM block-scale tensor, not a per-tensor scalar.
**Why `fp8_scale` in the block scale (not all-ones)?** DeepGEMM divides by the block
scale at runtime. If the block scale is all-ones, it divides by 1.0, producing garbage.
Each block needs the actual per-tensor scale value.
### 2. Attention Layers: NVFP4 → BF16
**Problem**: `fused_wqa_wkv`, `wo_b` use standard `torch.nn.functional.linear`.
NVFP4 weights (uint8) can't be used directly.
**Solution** (`_convert_nvfp4_to_bf16`):
1. Unpack NVFP4 → BF16
2. Dequantize with block/global/input scales
3. Replace `mod.weight` with BF16 parameter
4. Set `quant_method = UnquantizedLinearMethod()`
5. Remove NVFP4 scale attributes (`weight_scale`, `weight_scale_2`, `input_scale`)
### 3. Compressor: Reconstructing fused_wkv_wgate from NVFP4
**Problem**: The compressor's `fused_wkv_wgate` is a `MergedColumnParallelLinear`
with `disable_tp=True`. NVFP4 uint8 data can't be loaded into the BF16 parameter
(shape mismatch: uint8 is half the input dim). The default weight loader silently
skips these weights, leaving the parameter uninitialized.
**Solution** (`_reconstruct_compressor_weight`):
1. Read original `kv_proj.weight` and `gate_proj.weight` directly from safetensors
2. Unpack NVFP4 → BF16, dequantize with scales
3. Concatenate: `fused = cat([wkv, wgate], dim=0)`
4. Replace the uninitialized parameter
**Critical detail**: The **indexer** compressor is at a different checkpoint path:
- Main: `model.layers.N.self_attn.compressor.{kv_proj,gate_proj}.weight`
- Indexer: `model.layers.N.self_attn.compressor.indexer.{kv_proj,gate_proj}.weight`
Using the wrong prefix loads the main compressor weight into the indexer's
`fused_wkv_wgate`, causing a 4× shape mismatch and `split_with_sizes` crash.
### 4. MoE Experts: NVFP4 FusedMoE
**Problem**: vLLM's DeepSeek V4 uses `DeepseekV4MegaMoEExperts` with DeepGEMM
grouped GEMM. NVFP4 experts need a different kernel path.
**Solution**: The existing `ModelOptNvFp4LinearMethod` + `FusedMoE` infrastructure
handles NVFP4 experts natively. We just need to:
- Keep expert weights as NVFP4 uint8 + block/global scales
- Use `FLASHINFER_TRTLLM` MoE backend (auto-selected)
- Skip any conversion in `process_weights_after_loading`
### 5. BF16 wo_a Layers: BF16 → FP8
**Problem**: Some `wo_a` layers were NOT quantized by modelopt (BF16 in checkpoint).
The attention forward still reads them as FP8 for the einsum path.
**Solution** (`_convert_bf16_to_fp8`): Same as #1 but skip the NVFP4 unpack step.
Directly quantize BF16 → FP8 with block scale.
## Bugs Found and Fixed
### DeepGEMM `sf.dim()` Assertion (layout.hpp:94)
- **Root cause**: `weight_scale_inv` was a 1D per-tensor scale `(g,)`. DeepGEMM expects
2D/3D block-scale tensor formatted by `transform_sf_into_required_layout`.
- **Fix**: Use `deepgemm_post_process_fp8_weight_block` to produce correctly formatted
block scales, store result in `weight_scale_inv`.
### Block Scale dtype (`float8_e4m3fn` vs `float32`)
- **Root cause**: `deepgemm_post_process_fp8_weight_block` expects `float32` or
`float8_e8m0fnu` block scales. We initially used `float8_e4m3fn`.
- **Fix**: Create block scale as `dtype=torch.float32`.
### Missing `deepgemm_post_process` args
- **Root cause**: Function signature changed to require `quant_block_shape` and `use_e8m0`.
- **Fix**: Pass `quant_block_shape=(128, 128)` and `use_e8m0=True`.
### Compressor Indexer Shape Mismatch
- **Root cause**: `_reconstruct_compressor_weight` used the same checkpoint prefix
for both main and indexer compressors. The indexer's keys have `.indexer.` in the path.
- **Fix**: Add `sub_path` parameter; pass `".indexer"` for indexer compressors.
### All-Ones Block Scale → Garbage Output
- **Root cause**: Block scale was `torch.ones(...)` (scale=1.0). DeepGEMM divides by
the block scale at runtime, so the output was divided by 1.0 instead of the actual
per-tensor scale, producing incoherent text.
- **Fix**: Use `torch.full(..., fp8_scale.item())` to fill the block scale with the
correct per-tensor FP8 quantization scale.
## Running
```bash
# On B200 node
cd /root/nvidia-meeting
docker compose up -d
# Check logs
docker logs -f nvidia-meeting-vllm-1
# Test
curl http://localhost:8000/v1/models
curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{"model": "/model", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 50}'
```
Or without Docker:
## Files
```bash
source /root/nvidia-meeting/venv/bin/activate
python3 /root/nvidia-meeting/deepseek-v4-quant/scripts/serve_vllm.py
```
| File | Purpose |
|------|---------|
| `patches/deepseek_v4.py` | Main patch: NVFP4 post-load conversion, weight reconstruction, DeepGEMM block-scale |
| `patches/modelopt.py` | ModelOpt FP4 config patches for weight loading |
| `.env` | B200 node credentials |
| `docker-compose.yml` | Container config (8 GPU, TP=8, EP=8, NVFP4 quant) |
**Note:** `serve_vllm.py` still references `--moe-backend=deep_gemm_mega_moe`. This needs to be removed when mega_moe support is ready. For now, use the Docker Compose setup which has it removed.
## Quantization Run History
| Run | Date | Commit | Calib | Result | Root Cause | Fix |
|-----|------|--------|-------|--------|------------|-----|
| 1 | May 7 | shell wrapper | 256 | ❌ Batch probing crash | `o_b_proj` shape mismatch — finegrained_fp8 wraps MLA projections incorrectly with FP8 source | Use BF16 source (dequantized) |
| 2 | May 8-9 | shell wrapper | 128 | ❌ Export crash (calib ✅) | `get_activation_scaling_factor` reads stale GPU amax → CUDA illegal memory access | Snapshot amax to CPU after calibration |
| 3 | May 9 06:10 | `3907838` | 128 | ❌ Model loading OOM | `AutoModelForCausalLM.from_pretrained` OOM during expert weight `torch.cat` | Use modelopt `get_model()` with `max_memory` |
| 4 | May 9 ~07:00 | `86dd8df` | 128 | ❌ Import error | `mtq.KV_QUANT_CFG_CHOICES` doesn't exist — it's `hf_ptq.KV_QUANT_CFG_CHOICES` | Import from `hf_ptq`, not `mtq` |
| 5 | May 9 ~08:05 | `f9bbef8` | 128 | ❌ Same as Run 4 | Fix wasn't synced properly | Properly synced |
| 6 | May 9 ~09:25 | `6c1bff6` | 128 | ❌ Dataloader crash | `make_calib_dataloader` AttributeError — missing args | Added args to Namespace |
| 7 | May 9 ~13:40 | `25b4d8d` | 128 | ❌ Dataloader crash | `dataset=None`, `len()` on None | Provided dataset list |
| 8 | May 9 ~14:00 | `b2849a8` | 128 | ❌ Argparse crash | Wrong flag names (shell script names vs `hf_ptq.py` names) | Use `hf_ptq.py` flag names |
| 9 | May 9 ~14:30 | `a300302` | 128 | ❌ TypeError | Skipped `__main__` post-parse conversions (`calib_size` still string, not int list) | Apply same conversions after `parse_args()` |
| 10 | May 9 ~15:30 | `5a72da7` | 128 | ❌ Export crash (calib ✅) | `get_weight_scaling_factor` reads stale GPU weight → `cudaErrorIllegalAddress` | Patch `_export_quantized_weight` to force weight to CPU at entry point |
| 11 | May 9 ~22:50 | `07cd50e` | 128 | ✅ **SUCCESS** | — | 8 patches covering full export chain |
### Key Lessons (Quantization)
**Run 2 — Stale GPU tensors:** `use_seq_device_map` shuffles layers through GPU for calibration. Quantizer amax tensors sit in VRAM for 5+ hours while CUDA's allocator churns memory. By export time, the GPU tensor metadata is valid but the underlying memory has been recycled — reading it triggers `cudaErrorIllegalAddress`. Fix: copy amax to CPU immediately after calibration.
**Run 3 — Expert weight OOM:** `AutoModelForCausalLM.from_pretrained` does `torch.cat` on GPU for expert `gate_up_proj` (31.5GB alloc, 25.9GB free). Fix: use modelopt's `get_model()` which sets `max_memory` per GPU before loading. (Note: Run 10 uses `hf_main()` which calls `get_model()` internally.)
**Runs 48 — Pipeline rewriting trap:** Trying to reconstruct hf_ptq's pipeline by importing individual functions and building a fake `argparse.Namespace` causes an endless stream of missing-attribute and type errors. Each fix reveals the next bug. Fix: call `hf_main(args)` directly with a properly parsed args object.
**Run 9 — `__main__` gap:** `hf_ptq.py` does critical type conversions in its `__main__` block (string → list for `dataset`, string → int list for `calib_size`). When calling `main()` directly, these are skipped. Fix: apply the same conversions after `parse_args()`.
**Run 10 — Stale GPU weight tensors in export:** The amax patches (Patch 1-3) only cover quantizer state. The model *weights* themselves are also on stale GPU. `get_weight_scaling_factor` does `weight_scaling_factor_2.to(weight.device)` which triggers `cudaErrorIllegalAddress` because `weight` is on stale GPU. Fix: patch `_export_quantized_weight` (the entry point for each module's export) to force `weight` to CPU before any downstream code reads it. This covers the entire chain: `get_weight_scaling_factor`, `get_weights_scaling_factor_from_quantizer`, `to_quantized_weight`, `weight.to(dtype)` — all resolve to CPU because `weight.device` is CPU.
### Do NOT Repeat These Mistakes
- Don't use FP8 source model — kernel issues on Blackwell (Run 1)
- Don't use `--low_memory_mode` with V4 — meta device errors
- Don't use `calib_size=256` — OOMs with 3TB BF16 on CPU offload
- Don't use `AutoModelForCausalLM.from_pretrained` directly — OOM during expert weight concat (Run 3)
- Don't assume GPU tensor integrity after 5+ hours of sequential calibration (Run 2, Run 10)
- Don't rewrite the hf_ptq pipeline — call `hf_main()` directly (Runs 48)
- Don't skip the `__main__` post-parse conversions — `calib_size` must be int list, `dataset` must be list (Run 9)
- Don't use shell script arg names (`--quant`, `--calib`, `--kv_cache_quant`, `--tp`) — use `hf_ptq.py` names (`--qformat`, `--calib_size`, `--kv_cache_qformat`, `--inference_tensor_parallel`)
- Don't patch individual export functions one at a time — patch the entry point (`_export_quantized_weight`) so weight is on CPU for the entire chain (Run 10)
- Don't use runtime monkey-patching for vllm serving — workers are separate processes that don't inherit patches. Patch the source file directly instead.
## Runtime Patches Applied by quantize_nvfp4.py
These are monkey-patches applied at runtime — no modelopt source files are modified.
### Calibration-time patches (applied before pipeline runs)
1. **`TensorQuantizer.load_calib_amax`** — After calibration writes `_amax` to GPU, immediately moves it to CPU. Prevents stale GPU tensors.
2. **`TensorQuantizer.export_amax`** — If `_amax` is still on GPU at export time, moves to CPU before reading. Safety net.
3. **`NVFP4QTensor.get_activation_scaling_factor`** — Moves amax to CPU, clamps bad values instead of hard assert. Prevents crash on garbage from GPU corruption.
### Export-time patches (force stale GPU tensors to CPU at entry points)
4. **`_export_quantized_weight`** (KEY PATCH) — Forces weight + all quantizer state to CPU *before* any downstream code reads them. This is the entry point for exporting each linear layer. By forcing weight to CPU here, every downstream `.to(weight.device)` resolves to CPU, covering the entire chain: `get_weight_scaling_factor`, `get_weights_scaling_factor_from_quantizer`, `to_quantized_weight`, `weight.to(dtype)`.
5. **`_export_fused_experts`** — Same treatment for MoE expert weights (DeepseekV4Experts go through this path). Forces expert weights, buffers, and quantizer state to CPU.
6. **`to_quantized_weight`** — Forces weight and scaling factors to CPU. Redundant if Patch 4 works, but catches any code path that reaches this function without going through `_export_quantized_weight`.
7. **`get_weight_scaling_factor`** — Forces weight + quantizer to CPU. Redundant if Patch 4 works.
8. **`get_weight_scaling_factor_2`** — Forces quantizer state to CPU. Redundant if Patch 4 works.
Patches 6-8 are belt-and-suspenders. Patch 4 is the one that matters — it moves weight to CPU at the earliest possible point in the export chain, making all downstream stale GPU reads impossible.
### Post-Calibration Hook
`export_quantized` is monkey-patched to run these steps before the real export:
4. **`snapshot_amax_to_cpu()`** — Walks all quantizers, copies `_amax` to CPU, saves to disk (~50MB). Insurance policy.
5. **`force_all_amax_to_cpu()`** — Moves `_pre_quant_scale`, `_global_amax` to CPU too. Nuclear option.
6. **`save_calibrated_state()`** — Saves full model state dict to disk (~1.5TB). Enables `--export-only` recovery if export crashes.
## Bugs Found (V4 + modelopt 0.45.0.dev64)
1. ~~`QuantDeepseekV4Experts` AttributeError~~**Already fixed** in modelopt 0.45.0.dev64 (handles `nn.ModuleList` quantizers natively).
2. `--low_memory_mode` → meta device error. Don't use with V4.
3. Missing `kernels` package for FP8 ops. `pip install -U kernels`.
4. ~~Shell script arg names~~ — Resolved by calling `hf_main()` directly.
5. **Export crash — stale GPU tensors in `export_amax()`.** After hours of calibration, quantizer `_amax` on GPU becomes unreadable. Fixed by patching `export_amax` to move `_amax` to CPU before reading.
6. **Export crash — `assert torch.all(activation_scaling_factor > 0)`.** Amax values from stale GPU reads are garbage (zeros, negatives, NaN). Fixed by clamping instead of asserting, plus snapshotting valid amax to CPU before corruption can occur.
7. **Model loading OOM during expert weight conversion.** `AutoModelForCausalLM.from_pretrained` does `torch.cat` on GPU for expert `gate_up_proj` (31.5GB alloc), but only 25.9GB free with `device_map="sequential"`. Fixed by using modelopt's `get_model()` which sets `max_memory` per GPU before loading.
8. **Export crash — stale GPU weight tensors in `get_weight_scaling_factor`.** Patches 1-3 only covered quantizer amax. The model weights themselves are also on stale GPU. `weight_scaling_factor_2.to(weight.device)` triggers `cudaErrorIllegalAddress`. Fixed by patching `_export_quantized_weight` to force weight to CPU at the entry point, covering the entire export chain.
### Bugs Found (V4 NVFP4 + vLLM serving)
1. **modelopt uses `mlp`, vllm uses `ffn`** — Module naming mismatch. Fixed with substr mapping.
2. **modelopt uses `gate_proj`/`up_proj`/`down_proj`, vllm expects `w1`/`w3`/`w2`** — Expert weight naming mismatch. Fixed with regex mapping (only for `.experts.N.`, not `.shared_experts.`).
3. **modelopt uses `self_attn` prefix, vllm uses `attn.mla_attn`** — Attention module naming. Fixed with substr mapping.
4. **`kv_proj` maps to `wkv`, not `kv_proj`** — vllm stacks `wkv` + `wq_a` into `fused_wqa_wkv`. Fixed with substr mapping.
5. **`compressor.kv_proj``compressor.wkv`** — Similar stacking for compressor. Fixed with substr mapping.
6. **`compressor.kv_norm``attn.kv_norm`** — modelopt puts `kv_norm` under compressor, vllm has it at attention level. Fixed with substr mapping (must come before general compressor mapping).
7. **`MergedColumnParallelLinear` + NVFP4 incompatibility** — `ModelOptNvFp4Config.create_weights()` only handles `Linear`, not `MergedColumnParallelLinear`. This causes:
- Weight param created as bf16 instead of uint8 (PackedColumnParameter)
- `weight_scale`/`weight_scale_2`/`input_scale` not registered for stacked params
- `adjust_shard_indexes_for_packing` applies packed_factor to rows, but NVFP4 packs along columns
- **Workaround:** Unpack uint8→bf16 at load time, skip scale tensors, rely on `process_weights_after_loading` re-quantization
8. **No NVFP4 mega_moe kernel**`DeepseekV4MegaMoEExperts` expects MXFP4 (32-col blocks), modelopt exports NVFP4 (16-col blocks). No kernel exists. **Abandoned mega_moe**, using standard FusedMoE instead.
9. **`DeepseekV4ForCausalLM.hf_to_vllm_mapper` is a class attribute** — Runtime monkey-patching the factory function doesn't update the cached class attribute. Must patch the source file directly or update the class attribute explicitly.
10. **vllm workers are separate processes** — In-memory monkey-patches don't propagate to workers. Must patch the source file directly.
11. **modelopt exports params vllm doesn't have** — e.g., `compressor.position_bias`. Need resilient loading that skips unknown params.
## Dependencies (pinned versions)
- **nvidia-modelopt:** `0.45.0.dev64+g579fc6c31` (installed from git, not PyPI)
- **transformers:** `5.8.0.dev0` (from git, required for DeepSeekV4 support)
- **kernels:** latest (`pip install -U kernels` — needed for finegrained FP8 ops)
- **Python:** 3.10
The patches in `quantize_nvfp4.py` are for **modelopt 0.45.0.dev64** specifically. Later versions may include fixes natively — check before applying.
## Key Notes
- V4 is NOT BF16 — it ships as mixed-precision FP8/FP4. You MUST dequantize to BF16 first (Step 1).
- `--low_memory_mode` causes meta device errors with V4 — don't use.
- modelopt has no explicit V4 support — relies on auto-detection of fused experts.
- The calibration state save (`v4_nvfp4_calibrated_state.pt`) is ~1.5TB. It lives on NVMe, not in git.
- The amax snapshot (`v4_nvfp4_amax_snapshots.pt`) is ~50MB. Small, critical, cheap insurance.
- The script calls `hf_main(args)` — the exact same entry point as the shell script. No pipeline divergence.
- Must run from `/root/nvidia-meeting/modelopt-repo/examples/llm_ptq` (relative imports).
- For vllm serving, the patched `deepseek_v4.py` must be mounted into the container — workers don't inherit in-memory patches.
- We disabled `--moe-backend=deep_gemm_mega_moe` because no NVFP4 mega_moe kernel exists yet. Standard FusedMoE with `ModelOptNvFp4FusedMoE` handles expert weights correctly.
## File Layout
## Conversion Flow
```
scripts/
dequant_fp8_to_bf16.py — Step 1: FP8/FP4 → BF16 dequantization
quantize_nvfp4.py — Step 2: NVFP4 quantization (patches + hf_main)
serve_vllm.py — Step 3: vLLM serving (legacy, still has mega_moe flag)
patches/
deepseek_v4.py — Patched vllm source file (copied over original at container startup)
patch_vllm_weights.py — Legacy runtime monkey-patch (doesn't work with workers, kept for reference)
quant_module_patched.py — (legacy) quant module patches
patch_finegrained_fp8_blackwell.py — (legacy) FP8 kernel patches for Blackwell
docker-compose.yml — Docker Compose config for serving (uses patched deepseek_v4.py, no mega_moe)
Checkpoint (NVFP4 safetensors)
├── [weight loader] ──→ vLLM model (NVFP4 uint8 params)
└── [process_weights_after_loading]
├── wo_a (is_bmm=True):
NVFP4→BF16→FP8 + DeepGEMM block scale
│ weight_scale_inv = dg_ws, weight = 3D FP8
├── fused_wqa_wkv, wo_b, shared_expert:
│ NVFP4→BF16, UnquantizedLinearMethod
├── compressor.fused_wkv_wgate:
│ Read kv_proj+gate_proj from checkpoint
│ NVFP4→BF16, cat into fused weight
└── MoE experts: stay NVFP4 (FusedMoE backend)
```
The `patches/` directory contains earlier approaches that modified modelopt source files directly. The current approach (`quantize_nvfp4.py`) uses runtime monkey-patching instead — no source files are modified.
## Known Issues
1. **Output quality**: FP4 is very aggressive quantization. The model produces
tokens but they may be incoherent. This could be:
- Normal FP4 quality degradation
- Subtle dequantization bugs (sign handling, scale ordering)
- The per-tensor FP8 requantization of wo_a losing per-block precision
2. **Runtime performance**: Not yet benchmarked. The DeepGEMM einsum + FusedMoE
path should be efficient on B200, but the BF16 layers go through
`UnquantizedLinearMethod` which may be slower than dedicated kernels.
## Quantization Details
- **Model**: DeepSeek V4 Pro (1.2T parameters)
- **Format**: NVIDIA NVFP4 (4-bit floating point with 128-element block scales)
- **Tool**: modelopt 0.45.0.dev64 + transformers 5.8.0.dev0
- **Run**: Run 11 (881GB), 8× B200, ~$161/run
- **Checkpoint**: 95 safetensors shards

View File

@@ -5,6 +5,7 @@ from collections.abc import Callable, Iterable
from itertools import islice
import regex as re
import os
import torch
import torch.nn as nn
@@ -1597,7 +1598,413 @@ class DeepseekV4Model(nn.Module):
for layer in islice(self.layers, self.start_layer, self.end_layer):
layer.ffn.finalize_mega_moe_weights()
def _convert_nvfp4_post_load(self):
"""Post-load conversion of NVFP4 weights for vLLM compatibility.
Strategy:
- wo_a: Convert to FP8 (attention forward reads weight/weight_scale_inv
directly and passes to deepseek_v4_fp8_einsum, bypassing quant_method)
- fused_wqa_wkv, wq_b, wo_b: Dequant NVFP4->bf16 (called via
.forward() which goes through quant_method; FP8 would dtype-mismatch)
- compressor.fused_wkv_wgate: Dequant NVFP4->bf16 (used via direct
torch.mm in attention parallel stream)
- shared_experts (gate_up_proj, down_proj): Dequant NVFP4->bf16
- MoE experts: Stay in native NVFP4 (ModelOptNvFp4FusedMoE)
"""
E2M1_LUT = torch.tensor(
[0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.bfloat16
)
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max
# wo_a: attention forward reads .weight and .weight_scale_inv directly
# for fp8_einsum. Only layer that needs FP8 conversion.
fp8_proj_names = {"wo_a"}
# Attention layers called via .forward() — need bf16
bf16_proj_names = {"fused_wqa_wkv", "wq_b", "wo_b"}
# Shared expert layers called via .forward() — need bf16
bf16_shared_names = {"gate_up_proj", "down_proj"}
fp8_converted = 0
fp8_from_bf16 = 0
bf16_converted = 0
compressor_converted = 0
for layer_idx, layer in enumerate(self.layers):
attn = layer.attn
# FP8 conversion: only wo_a
for proj_name in fp8_proj_names:
if not hasattr(attn, proj_name):
continue
mod = getattr(attn, proj_name)
if not hasattr(mod, "weight"):
continue
if mod.weight.dtype == torch.uint8:
# NVFP4 -> dequant to bf16 -> requant to FP8
self._convert_nvfp4_to_fp8(mod, E2M1_LUT, FP8_MAX)
fp8_converted += 1
elif mod.weight.dtype == torch.bfloat16:
# modelopt did NOT quantize o_a_proj — it's bf16 already.
# Convert bf16 -> FP8 directly for fp8_einsum path.
self._convert_bf16_to_fp8(mod, FP8_MAX)
fp8_from_bf16 += 1
# BF16 conversion: attention layers via .forward()
for proj_name in bf16_proj_names:
if not hasattr(attn, proj_name):
continue
mod = getattr(attn, proj_name)
if not hasattr(mod, "weight") or mod.weight.dtype != torch.uint8:
continue
self._dequant_nvfp4_to_bf16(mod, E2M1_LUT)
bf16_converted += 1
# Compressor: fused_wkv_wgate used via direct torch.mm
# Compressor weights were SKIPPED during loading (skip patterns)
# because the stacking weight_loader corrupts NVFP4 uint8 data.
# We reconstruct the bf16 weight from the individual sub-weights
# that were loaded separately before stacking.
# Note: compressor.kv_proj.weight and compressor.gate_proj.weight
# are skipped, so fused_wkv_wgate.weight is zeros (empty tensor).
# We need to manually create it.
mla_attn = getattr(attn, "mla_attn", None)
if mla_attn is not None:
compressor = getattr(mla_attn, "compressor", None)
if compressor is not None and hasattr(compressor, "fused_wkv_wgate"):
compressor_converted += self._reconstruct_compressor_weight(
compressor.fused_wkv_wgate, attn, layer_idx, E2M1_LUT)
# Indexer compressor (C4A layers only)
indexer = getattr(mla_attn, "indexer", None)
if indexer is not None:
idx_compressor = getattr(indexer, "compressor", None)
if idx_compressor is not None and hasattr(idx_compressor, "fused_wkv_wgate"):
compressor_converted += self._reconstruct_compressor_weight(
idx_compressor.fused_wkv_wgate, indexer, layer_idx, E2M1_LUT, sub_path=".indexer")
# Shared experts
ffn = layer.ffn
if hasattr(ffn, "shared_experts") and ffn.shared_experts is not None:
for proj_name in bf16_shared_names:
if not hasattr(ffn.shared_experts, proj_name):
continue
mod = getattr(ffn.shared_experts, proj_name)
if not hasattr(mod, "weight") or mod.weight.dtype != torch.uint8:
continue
self._dequant_nvfp4_to_bf16(mod, E2M1_LUT)
bf16_converted += 1
total_fp8 = fp8_converted + fp8_from_bf16
total_bf16 = bf16_converted + compressor_converted
if total_fp8 > 0 or total_bf16 > 0:
print(f"NVFP4 post-load: {fp8_converted} NVFP4->FP8, "
f"{fp8_from_bf16} BF16->FP8, "
f"{bf16_converted} attn/shared->BF16, "
f"{compressor_converted} compressor->BF16, "
f"MoE experts stay NVFP4")
def _dequant_nvfp4_to_bf16(self, mod, e2m1_lut):
"""Dequantize NVFP4 weight to bf16 for normal .forward() path."""
w_uint8 = mod.weight.data
device = w_uint8.device
w_bf16 = self._unpack_nvfp4_to_bf16(w_uint8, e2m1_lut, device)
# Dequantize with scales
if hasattr(mod, "weight_scale") and hasattr(mod, "weight_scale_2"):
block_scale = mod.weight_scale.data.to(torch.float32)
if block_scale.dim() == 2 and w_bf16.dim() == 2:
block_size = w_bf16.shape[1] // block_scale.shape[1]
block_scale_expanded = block_scale.unsqueeze(-1).expand(
-1, -1, block_size
).reshape(w_bf16.shape)
else:
block_scale_expanded = block_scale
global_scale = mod.weight_scale_2.data.max().item()
input_scale = (
mod.input_scale.data.max().item()
if hasattr(mod, "input_scale")
else 1.0
)
w_dequant = w_bf16.float() * block_scale_expanded * global_scale * input_scale
w_dequant = w_dequant.to(torch.bfloat16)
else:
w_dequant = w_bf16
# Replace weight with bf16 version
mod.weight = torch.nn.Parameter(w_dequant, requires_grad=False)
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
mod.quant_method = UnquantizedLinearMethod()
for attr in ("weight_scale", "weight_scale_2", "input_scale",
"weight_scale_inv"):
if hasattr(mod, attr):
delattr(mod, attr)
def _convert_nvfp4_to_fp8(self, mod, e2m1_lut, fp8_max):
"""Convert NVFP4 weight to FP8 for fp8_einsum path (wo_a only).
Uses DeepGEMM's deepgemm_post_process_fp8_weight_block to ensure
correct weight and scale format for fp8_einsum with BMM.
"""
w_uint8 = mod.weight.data
device = w_uint8.device
w_bf16 = self._unpack_nvfp4_to_bf16(w_uint8, e2m1_lut, device)
# Dequantize with scales
if hasattr(mod, "weight_scale") and hasattr(mod, "weight_scale_2"):
block_scale = mod.weight_scale.data.to(torch.float32)
if block_scale.dim() == 2 and w_bf16.dim() == 2:
block_size = w_bf16.shape[1] // block_scale.shape[1]
block_scale_expanded = block_scale.unsqueeze(-1).expand(
-1, -1, block_size
).reshape(w_bf16.shape)
else:
block_scale_expanded = block_scale
global_scale = mod.weight_scale_2.data.max().item()
input_scale = (
mod.input_scale.data.max().item()
if hasattr(mod, "input_scale")
else 1.0
)
w_dequant = w_bf16.float() * block_scale_expanded * global_scale * input_scale
w_dequant = w_dequant.to(torch.bfloat16)
else:
w_dequant = w_bf16
# Re-quantize bf16 -> FP8 e4m3 with block quantization
# DeepGEMM expects block-scale format: weight_scale (FP8 e4m3 block scale)
# and weight_scale_inv (per-tensor scale).
# We do per-tensor quantization, so block_scale is all-ones.
w_amax = w_dequant.abs().amax()
if w_amax == 0:
w_amax = torch.tensor(1.0, device=device)
fp8_scale = w_amax / fp8_max
w_fp8 = (w_dequant / fp8_scale).to(torch.float8_e4m3fn)
# Create block scale filled with the per-tensor fp8_scale value.
# DeepGEMM divides by the block scale, so each block gets fp8_scale.
BLOCK_SIZE = 128
is_bmm = getattr(mod, "is_bmm", False)
bmm_batch_size = getattr(mod, "bmm_batch_size", 0)
# Weight is 2D (output_size, input_size) before BMM reshape
# Block scale shape: (output_size / BLOCK_SIZE, input_size / BLOCK_SIZE)
rows = w_fp8.size(0)
cols = w_fp8.size(1)
block_rows = rows // BLOCK_SIZE
block_cols = cols // BLOCK_SIZE
# Fill block scale with the per-tensor fp8_scale (NOT all-ones!)
# This is correct because we requantized with a single per-tensor scale,
# so every 128x128 block has the same scale = fp8_scale.
ws = torch.full((block_rows, block_cols), fp8_scale.item(), dtype=torch.float32, device=device)
# Use DeepGEMM's post-processing for proper layout transformation
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
deepgemm_post_process_fp8_weight_block,
)
w_fp8, ws = deepgemm_post_process_fp8_weight_block(
wq=w_fp8,
ws=ws,
quant_block_shape=(BLOCK_SIZE, BLOCK_SIZE),
use_e8m0=True, # scale_fmt=ue8m0
is_bmm=is_bmm,
bmm_batch_size=bmm_batch_size,
)
mod.weight = torch.nn.Parameter(w_fp8, requires_grad=False)
# weight_scale_inv is what the attention runtime reads as b_scale
# for deepseek_v4_fp8_einsum -> DeepGEMM fp8_einsum.
# It must be the DeepGEMM-formatted block scale (dg_ws), NOT the
# per-tensor scalar. See: deepseek_v4_attention.py line 319.
mod.weight_scale_inv = torch.nn.Parameter(ws, requires_grad=False)
# weight_scale is not used at runtime for BMM layers; remove it
# to avoid confusing other code paths.
for attr in ("weight_scale", "weight_scale_2", "input_scale"):
if hasattr(mod, attr):
delattr(mod, attr)
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
mod.quant_method = UnquantizedLinearMethod()
def _reconstruct_compressor_weight(self, fused_mod, parent_mod, layer_idx, e2m1_lut, sub_path=""):
"""Reconstruct compressor fused_wkv_wgate from checkpoint.
Compressor weights are SKIPPED during loading because NVFP4 uint8 data
can't be loaded into bf16 MergedColumnParallelLinear params (shape mismatch).
We read the original uint8 data from the safetensors checkpoint, unpack
E2M1, dequantize, and stack into the fused weight param.
"""
import glob
from safetensors.torch import load_file
# Find the checkpoint directory
# The model weights are mounted at /model in Docker
ckpt_dir = "/model"
if not os.path.isdir(ckpt_dir):
print(f"WARNING: layer {layer_idx} compressor: checkpoint dir {ckpt_dir} not found")
return 0
# Determine the layer's compressor key prefix in the checkpoint
# Before mapper: model.layers.N.self_attn.compressor.{kv_proj,gate_proj}
# After mapper: model.layers.N.attn.mla_attn.compressor.{wkv,wgate}
# We read from checkpoint (before mapper), so use original names
layer_prefix = f"model.layers.{layer_idx}.self_attn.compressor{sub_path}"
# Find which shard contains this layer's compressor weights
wkv_key = f"{layer_prefix}.kv_proj.weight"
wgate_key = f"{layer_prefix}.gate_proj.weight"
wkv_scale_key = f"{layer_prefix}.kv_proj.weight_scale"
wgate_scale_key = f"{layer_prefix}.gate_proj.weight_scale"
wkv_scale2_key = f"{layer_prefix}.kv_proj.weight_scale_2"
wgate_scale2_key = f"{layer_prefix}.gate_proj.weight_scale_2"
wkv_iscale_key = f"{layer_prefix}.kv_proj.input_scale"
wgate_iscale_key = f"{layer_prefix}.gate_proj.input_scale"
# Load from safetensors
wkv_uint8 = None
wgate_uint8 = None
wkv_block_scale = None
wgate_block_scale = None
wkv_global_scale = None
wgate_global_scale = None
wkv_input_scale = None
wgate_input_scale = None
shard_files = sorted(glob.glob(os.path.join(ckpt_dir, "model-*.safetensors")))
for shard_file in shard_files:
try:
shard_data = load_file(shard_file)
except Exception:
continue
if wkv_key in shard_data:
wkv_uint8 = shard_data[wkv_key]
wkv_block_scale = shard_data.get(wkv_scale_key)
wkv_global_scale = shard_data.get(wkv_scale2_key)
wkv_input_scale = shard_data.get(wkv_iscale_key)
if wgate_key in shard_data:
wgate_uint8 = shard_data[wgate_key]
wgate_block_scale = shard_data.get(wgate_scale_key)
wgate_global_scale = shard_data.get(wgate_scale2_key)
wgate_input_scale = shard_data.get(wgate_iscale_key)
if wkv_uint8 is not None and wgate_uint8 is not None:
break
if wkv_uint8 is None or wgate_uint8 is None:
# Layer might not have a compressor (compress_ratio=1 layers)
return 0
device = fused_mod.weight.device
wkv_uint8 = wkv_uint8.to(device)
wgate_uint8 = wgate_uint8.to(device)
# Unpack E2M1 FP4→bf16
wkv_bf16 = self._unpack_nvfp4_to_bf16(wkv_uint8, e2m1_lut, device)
wgate_bf16 = self._unpack_nvfp4_to_bf16(wgate_uint8, e2m1_lut, device)
# Dequantize with scales
def _dequant(w_bf16, block_scale, global_scale, input_scale):
if block_scale is not None and global_scale is not None:
block_scale = block_scale.to(device).to(torch.float32)
if block_scale.dim() == 2 and w_bf16.dim() == 2:
block_size = w_bf16.shape[1] // block_scale.shape[1]
block_scale_exp = block_scale.unsqueeze(-1).expand(
-1, -1, block_size
).reshape(w_bf16.shape)
else:
block_scale_exp = block_scale
gs = global_scale.to(device).max().item()
inp_s = input_scale.to(device).max().item() if input_scale is not None else 1.0
w = w_bf16.float() * block_scale_exp * gs * inp_s
return w.to(torch.bfloat16)
return w_bf16
wkv_dequant = _dequant(wkv_bf16, wkv_block_scale, wkv_global_scale, wkv_input_scale)
wgate_dequant = _dequant(wgate_bf16, wgate_block_scale, wgate_global_scale, wgate_input_scale)
# Stack: concatenate along output dim (dim 0)
# fused_wkv_wgate.weight = cat([wkv, wgate], dim=0) → (2*head_dim, hidden_size)
w_fused = torch.cat([wkv_dequant, wgate_dequant], dim=0)
# DEBUG: log shapes to diagnose compressor weight mismatch
print(f"NVFP4 compressor layer {layer_idx}: wkv={wkv_dequant.shape}, wgate={wgate_dequant.shape}, fused={w_fused.shape}, existing_param={fused_mod.weight.shape}")
# Replace the weight
fused_mod.weight = torch.nn.Parameter(w_fused, requires_grad=False)
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
fused_mod.quant_method = UnquantizedLinearMethod()
for attr in ("weight_scale", "weight_scale_2", "input_scale", "weight_scale_inv"):
if hasattr(fused_mod, attr):
delattr(fused_mod, attr)
return 1
return 0
def _convert_bf16_to_fp8(self, mod, fp8_max):
"""Convert BF16 weight to FP8 for fp8_einsum path.
Used for wo_a which modelopt did NOT quantize (bf16 in checkpoint)
but which the attention forward reads as FP8 for deepseek_v4_fp8_einsum.
Uses DeepGEMM's post-processing for proper BMM + scale format.
"""
w_bf16 = mod.weight.data
device = w_bf16.device
# Re-quantize bf16 -> FP8 e4m3 with block quantization
w_amax = w_bf16.abs().amax()
if w_amax == 0:
w_amax = torch.tensor(1.0, device=device)
fp8_scale = w_amax / fp8_max
w_fp8 = (w_bf16 / fp8_scale).to(torch.float8_e4m3fn)
BLOCK_SIZE = 128
is_bmm = getattr(mod, "is_bmm", False)
bmm_batch_size = getattr(mod, "bmm_batch_size", 0)
rows = w_fp8.size(0)
cols = w_fp8.size(1)
block_rows = rows // BLOCK_SIZE
block_cols = cols // BLOCK_SIZE
# Fill block scale with per-tensor fp8_scale (NOT all-ones!)
ws = torch.full((block_rows, block_cols), fp8_scale.item(), dtype=torch.float32, device=device)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
deepgemm_post_process_fp8_weight_block,
)
w_fp8, ws = deepgemm_post_process_fp8_weight_block(
wq=w_fp8,
ws=ws,
quant_block_shape=(BLOCK_SIZE, BLOCK_SIZE),
use_e8m0=True, # scale_fmt=ue8m0
is_bmm=is_bmm,
bmm_batch_size=bmm_batch_size,
)
mod.weight = torch.nn.Parameter(w_fp8, requires_grad=False)
# weight_scale_inv is what the attention runtime reads as b_scale
# for deepseek_v4_fp8_einsum -> DeepGEMM fp8_einsum.
# It must be the DeepGEMM-formatted block scale (dg_ws), NOT the
# per-tensor scalar. See: deepseek_v4_attention.py line 319.
mod.weight_scale_inv = torch.nn.Parameter(ws, requires_grad=False)
# weight_scale is not used at runtime for BMM layers; remove it
# to avoid confusing other code paths.
for attr in ("weight_scale", "weight_scale_2", "input_scale"):
if hasattr(mod, attr):
delattr(mod, attr)
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
mod.quant_method = UnquantizedLinearMethod()
def _unpack_nvfp4_to_bf16(self, w_uint8, e2m1_lut, device):
"""Unpack NVFP4 uint8 packed weights to bf16 using E2M1 format."""
# Extract 4-bit FP4 values (0-15, bit 3 = sign)
even_raw = (w_uint8 & 0x0F).int()
odd_raw = ((w_uint8 >> 4) & 0x0F).int()
# Sign: 0-7 = positive, 8-15 = negative
even_sign = torch.where(even_raw >= 8, -1.0, 1.0).to(torch.bfloat16)
odd_sign = torch.where(odd_raw >= 8, -1.0, 1.0).to(torch.bfloat16)
# Magnitude index: lower 3 bits (0-7)
even_vals = even_sign * e2m1_lut.to(device)[even_raw & 0x07]
odd_vals = odd_sign * e2m1_lut.to(device)[odd_raw & 0x07]
# Interleave and flatten
w_bf16 = torch.stack([even_vals, odd_vals], dim=-1)
w_bf16 = w_bf16.reshape(w_uint8.shape[0], -1).to(torch.bfloat16)
return w_bf16
@torch.compile(backend=current_platform.simple_compile_backend)
def hc_head(
hidden_states: torch.Tensor,
@@ -1663,10 +2070,15 @@ def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper:
# process_weights_after_loading re-quantize them.
# Must match ORIGINAL checkpoint key names (before substr renaming).
fused_skip_regex = {
# Compressor projections → fused_wkv_wgate (stacked)
# Compressor uses UnquantizedLinearMethod (quant_config=None),
# so it only has a bf16 weight param — no scale params registered.
# We unpack the NVFP4 uint8 weights to bf16 at load time.
# Compressor: SKIP ALL tensors. The compressor uses quant_config=None,
# so MergedColumnParallelLinear creates bf16 weight params. NVFP4 uint8
# checkpoint data can't be loaded into these params (shape mismatch:
# uint8 (head_dim, hidden_size//2) vs bf16 (head_dim, hidden_size)).
# The stacking weight_loader silently skips the sub-weights, leaving
# random bf16 initialization. We reconstruct the compressor weights
# manually in post-load conversion by reading from the checkpoint.
re.compile(r"\.compressor\.kv_proj\.weight$"): None,
re.compile(r"\.compressor\.gate_proj\.weight$"): None,
re.compile(r"\.compressor\.kv_proj\.weight_scale$"): None,
re.compile(r"\.compressor\.gate_proj\.weight_scale$"): None,
re.compile(r"\.compressor\.kv_proj\.weight_scale_2$"): None,
@@ -1793,6 +2205,7 @@ class DeepseekV4ForCausalLM(nn.Module):
loader = AutoWeightsLoader(self, skip_substrs=["mtp."])
loaded_params = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
self.model.finalize_mega_moe_weights()
self.model._convert_nvfp4_post_load()
return loaded_params
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: