DeepSeek V4 Pro → NVFP4 Quantization + vLLM Serving

Full NVFP4 quantization of DeepSeek V4 Pro on a single B200 node (8× B200, 2.7TB RAM, 13TB NVMe). Result: 881GB NVFP4 (Run 11). Now working on vLLM serving of the quantized checkpoint.

Cost: ~$161/run at $23/hr (7 hours each). Don't waste runs.

Final Quantization Result (Run 11)

  • Output: /root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4 — 881GB, 95 safetensors
  • Config: nvfp4 full quantization, 128 calib samples, kv_cache_qformat=fp8_cast
  • Total runtime: ~7,783s (~2h10m end-to-end)
  • Peak GPU mem: ~163GB per B200
  • Amax snapshots: 47,696 quantizers, 15.4MB
  • Calibrated state: 721.4GB (insurance, can re-export with --export-only)
  • A few experts (11, 83, 100, 112, 254) had uncalibrated amax — weight-derived fallback used (normal for sparse MoE with 256 experts)

🔧 vLLM Serving (In Progress)

Current Status: Debugging weight loading

The modelopt NVFP4 export and vllm have a chain of incompatibilities. We're progressively fixing them. The fundamental problem: modelopt's NVFP4 quantization format and vllm's DeepSeek V4 serving code were never integrated. NVIDIA's own published NVFP4 exports (DeepSeek-V3.2, MiniMax-M2.7) don't have these issues because they don't use MLA attention compression or 256-expert MoE — both of which create stacked/fused weight parameters that modelopt doesn't account for.

Approach: Patched deepseek_v4.py + disabled mega_moe

Instead of runtime monkey-patching (which doesn't propagate to worker processes), we patch the vllm source file directly. The patched deepseek_v4.py is mounted into the Docker container and copied over the original before vllm starts.
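Why an in-memory patch never reaches the workers can be reproduced in a small stand-alone sketch. It does not touch vllm: the json module stands in for deepseek_v4.py, PATCHED_MARKER is a made-up attribute, and the child is simply a fresh interpreter, which is an assumption about how the workers start.

```python
import json  # stand-in for the module being patched (hypothetical)
import subprocess
import sys
import textwrap

# Parent process: apply an in-memory monkey-patch.
json.PATCHED_MARKER = True

# Child process: a fresh interpreter imports the module from disk,
# so it never sees attributes that exist only in the parent's memory.
child_code = textwrap.dedent(
    """
    import json
    print(hasattr(json, "PATCHED_MARKER"))
    """
)
result = subprocess.run(
    [sys.executable, "-c", child_code],
    capture_output=True,
    text=True,
    check=True,
)

parent_sees_patch = hasattr(json, "PATCHED_MARKER")
child_sees_patch = result.stdout.strip() == "True"
print(f"parent sees patch: {parent_sees_patch}")
print(f"child sees patch:  {child_sees_patch}")
```

Patching the file on disk sidesteps this, because every process imports the same patched source.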

We also disabled --moe-backend=deep_gemm_mega_moe because:

  1. The NVFP4 mega_moe kernel doesn't exist yet (NVIDIA hasn't built it)
  2. MegaMoE uses MXFP4 block scale format (32-col blocks), but modelopt exports NVFP4 (16-col blocks) — format mismatch
  3. MegaMoE doesn't register weight_scale_2 or input_scale params, so those scales would be silently dropped

Instead, we use the standard FusedMoE path with ModelOptNvFp4FusedMoE, which natively supports NVFP4 expert weights.
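The block-size mismatch in reason 2 is plain arithmetic, sketched below. The helper name is hypothetical, and the weight shape (7168, 3072) is an assumption inferred from the scale shapes reported in run S6, not read from the model config.

```python
def block_scale_shape(out_features: int, in_features: int, block_cols: int) -> tuple:
    """Shape of a scale tensor holding one scale per `block_cols` input columns."""
    if in_features % block_cols != 0:
        raise ValueError("in_features must be divisible by block_cols")
    return (out_features, in_features // block_cols)


NVFP4_BLOCK_COLS = 16  # what modelopt exports
MXFP4_BLOCK_COLS = 32  # what mega_moe expects

# Assumed w2 weight shape; 3072 is inferred from the S6 error, not confirmed.
OUT_FEATURES, IN_FEATURES = 7168, 3072

nvfp4_scales = block_scale_shape(OUT_FEATURES, IN_FEATURES, NVFP4_BLOCK_COLS)
mxfp4_scales = block_scale_shape(OUT_FEATURES, IN_FEATURES, MXFP4_BLOCK_COLS)
print("NVFP4 scale shape:", nvfp4_scales)
print("MXFP4 scale shape:", mxfp4_scales)
```

Under that assumption the two shapes come out as (7168, 192) and (7168, 96), the same pair that appears in the S6 shape-mismatch error.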

vLLM Serving Run History

# | Date | Approach | Result | Root Cause | Fix/Next
S1 | May 10 09:34 | patch_vllm_weights.py runtime patch + mega_moe | UnboundLocalError: name_mapped | Expert weight names don't match any mapping → name_mapped never assigned | Add gate_proj→w1, up_proj→w3, down_proj→w2 mappings
S2 | May 10 ~10:30 | Same, added expert rename regexes | Same error | DeepseekV4ForCausalLM.hf_to_vllm_mapper is a class attribute set at import time; patching the function doesn't update the cached mapper | Patch the class attribute directly
S3 | May 10 ~11:00 | Patched class attr, but workers are separate processes | Same error in workers | Workers don't inherit in-memory patches; they fork before the patch runs | Patch the source file directly (deepseek_v4.py)
S4 | May 10 ~11:30 | Direct source file patch + mega_moe | KeyError: 'layers.0.mlp.experts.0.w2.weight' | modelopt uses mlp, vllm uses ffn internally; missing .mlp.→.ffn. mapping | Add substr mapping
S5 | May 10 ~12:00 | Added mlp→ffn mapping + mega_moe | KeyError: 'fused_wkv_wgate.input_scale' | Compressor fused params don't register input_scale/weight_scale_2 | Add skip patterns for compressor/attention scale tensors
S6 | May 10 ~12:30 | Added skip patterns + mega_moe | Shape mismatch: w2_weight_scale (7168, 96) vs (7168, 192) | NVFP4 uses 16-col block scales, mega_moe expects 32-col MXFP4 (format incompatibility) | Abandon mega_moe: no NVFP4 mega_moe kernel exists
S7 | May 10 ~13:00 | Disabled mega_moe, standard FusedMoE | fused_wkv_wgate.weight shape mismatch: param=(1024,7168) bf16, loaded=(512,3584) uint8 | MergedColumnParallelLinear creates weight as bf16 (not uint8), but modelopt exports NVFP4 packed uint8. ModelOptNvFp4Config only handles Linear, not MergedColumnParallelLinear | Unpack uint8→bf16 at load time
S8 | May 10 ~13:30 | Added E2M1 unpacking for fused weights | KeyError: 'fused_wkv_wgate.weight_scale' | No weight_scale param registered for MergedColumnParallelLinear (same ModelOptNvFp4Config gap) | Skip all NVFP4 scale tensors for stacked/fused attention+compressor params
S9 | May 10 ~14:00 | Added weight_scale skip patterns | KeyError: 'compressor.kv_norm.weight' | modelopt puts kv_norm under compressor, vllm has it at attention level (attn.kv_norm) | Add compressor.kv_norm→kv_norm mapping
S10 | May 10 ~14:15 | Fixed kv_norm mapping | KeyError: 'compressor.position_bias' | modelopt exports params that don't exist in the vllm model | Make loading resilient to unknown params

Open Issues (as of May 10 ~16:00 UTC)

  1. MergedColumnParallelLinear + NVFP4 incompatibility — The core problem. ModelOptNvFp4Config.create_weights() only handles Linear layers. For MergedColumnParallelLinear (used for fused_wqa_wkv, fused_wkv_wgate, gate_up_proj):

    • Weight param is created as ModelWeightParameter (bf16) instead of PackedColumnParameter (uint8)
    • weight_scale, weight_scale_2, input_scale params are never registered
    • adjust_shard_indexes_for_packing applies packed_factor to rows, but NVFP4 packs along columns
    • Current workaround: unpack uint8→bf16 at load time, skip scale tensors, let process_weights_after_loading re-quantize. This loses the calibration-optimized scales for attention/compressor/shared_expert weights.
  2. No NVFP4 mega_moe kernel — We disabled mega_moe to avoid the format mismatch. Standard FusedMoE with ModelOptNvFp4FusedMoE works for expert weights, but loses the mega_moe performance optimization. When NVIDIA builds an NVFP4 mega_moe kernel, we can re-enable it.

  3. Resilient loading needed — modelopt exports params (e.g., compressor.position_bias) that don't exist in the vllm model. Need to skip unknown params gracefully instead of crashing.

  4. Expert weight_scale_2 handling with FusedMoE — The standard FusedMoE path registers w13_weight_scale_2 and w2_weight_scale_2, so expert global scales CAN be loaded. This works for experts. The issue is only with the stacked/fused attention params.
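The resilient loading described in item 3 could look roughly like the sketch below. It is not vllm's load_weights: the function name, the dict-based parameter store, and the sample tensor names are all hypothetical stand-ins.

```python
from typing import Any, Dict, Iterable, List, Tuple


def load_known_weights(
    params: Dict[str, Any],
    weights: Iterable[Tuple[str, Any]],
) -> Tuple[List[str], List[str]]:
    """Copy checkpoint tensors into `params`, skipping names the model lacks."""
    loaded: List[str] = []
    skipped: List[str] = []
    for name, tensor in weights:
        if name not in params:
            # Unknown to the model (e.g. compressor.position_bias): record and move on.
            skipped.append(name)
            continue
        params[name] = tensor
        loaded.append(name)
    return loaded, skipped


# Hypothetical model parameters and checkpoint contents.
model_params = {"attn.kv_norm.weight": None}
checkpoint = [
    ("attn.kv_norm.weight", [1.0, 1.0]),
    ("compressor.position_bias", [0.0]),
]
loaded, skipped = load_known_weights(model_params, checkpoint)
print("loaded: ", loaded)
print("skipped:", skipped)
```

Returning the skipped names (rather than dropping them silently) keeps a genuine mapping bug visible: an unexpectedly long skip list is a signal that a rename rule is missing.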

What Each Patch Does

patches/deepseek_v4.py — Patched vllm source file, copied over the original at container startup. Contains:

  • Regex mappings (applied first by WeightsMapper):
    • Skip weight_scale, weight_scale_2, input_scale for compressor/attention fused params (no stacked param registered)
    • Skip weight_scale, weight_scale_2, input_scale for shared expert gate/up projections (stacked into gate_up_proj)
    • Expert projection rename: gate_proj→w1, up_proj→w3, down_proj→w2 (only for .experts.N., not .shared_experts.)
  • Substr mappings (applied after regex):
    • Attention: self_attn→attn.mla_attn with proper sub-projection names
    • kv_norm moved from compressor to attention level
    • compressor.kv_proj→compressor.wkv, compressor.gate_proj→compressor.wgate
    • shared_experts.gate_proj→shared_experts.w1, shared_experts.up_proj→shared_experts.w3
    • .mlp.→.ffn. (modelopt uses mlp, vllm uses ffn)
  • E2M1 FP4→BF16 unpacking for stacked params: When a uint8 packed NVFP4 weight is loaded into a bf16 param (MergedColumnParallelLinear), unpack using the E2M1 lookup table
  • Resilient loading: Skip unknown params that modelopt exports but vllm doesn't have
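The E2M1 unpacking step can be illustrated with a pure-Python sketch. The lookup table is the standard E2M1 value set; the nibble order (low nibble holds the even-index element) is an assumption about the export layout, and the function names are hypothetical. The sketch returns raw FP4 values only and does not apply any block or global scale.

```python
from typing import List

# Magnitudes for the low 3 bits of each 4-bit E2M1 code; bit 3 is the sign.
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def decode_e2m1(code: int) -> float:
    """Decode one 4-bit E2M1 code into a Python float."""
    magnitude = E2M1_MAGNITUDES[code & 0b0111]
    return -magnitude if code & 0b1000 else magnitude


def unpack_fp4_row(packed: bytes) -> List[float]:
    """Unpack one row of packed FP4 bytes into twice as many floats.

    Assumption: the low nibble is the even-index element, the high nibble the odd one.
    """
    values: List[float] = []
    for byte in packed:
        values.append(decode_e2m1(byte & 0x0F))
        values.append(decode_e2m1(byte >> 4))
    return values


row = unpack_fp4_row(bytes([0x21, 0xF7]))
print(row)
```

Each packed byte yields two values, which is why a (512, 3584) uint8 tensor unpacks to 7168 columns.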

patches/patch_vllm_weights.py — Legacy runtime monkey-patch approach. Doesn't work because vllm workers are separate processes that don't inherit in-memory patches. Kept for reference.

docker-compose.yml — Docker Compose config:

  • Copies patched deepseek_v4.py before starting vllm
  • Removed --moe-backend=deep_gemm_mega_moe (no NVFP4 kernel exists)
  • All other vllm flags are critical for V4 (see serve_vllm.py for documentation)

⚠️ Model Config Patches (post-export)

modelopt 0.45.0.dev64's export produces configs that don't match what vllm expects at runtime. NVIDIA's own published NVFP4 exports have the same gaps: we compared against nvidia/DeepSeek-V3.2-NVFP4 and nvidia/MiniMax-M2.7-NVFP4 on HuggingFace, and neither includes compress_ratios or scale_fmt. This is a modelopt ↔ vllm integration gap, not a problem with our quantization.

All patches below are to DeepSeek-V4-Pro-NVFP4/config.json unless noted.

# | Field | modelopt export (original) | vllm requires | Patch applied | Why modelopt doesn't export it
1 | compress_ratios | Missing (transformers 5.8.0 renamed to compress_rates dict) | List of ints indexed by layer_id | Copied from BF16 source model's compress_ratios (62 items) | modelopt doesn't add fields the source config lacks; transformers 5.8.0 renamed the field
2 | quantization_config.scale_fmt | Missing | "ue8m0" string | Added | modelopt doesn't include vllm-specific runtime fields
3 | rope_parameters | Nested dict {'main': {...}, 'compress': {...}} (transformers 5.8.0 format) | Flat dict {'rope_theta': ..., 'rope_type': ..., ...} | Flattened to main sub-dict | transformers 5.8.0 changed rope_parameters from flat → nested per-component
4 | rope_scaling | Nested dict {'main': {...}, 'compress': {...}} (same as above) | Flat dict | Flattened to main sub-dict | Same transformers 5.8.0 schema change

We confirmed that NVIDIA's own NVFP4 exports also lack the fields added by patches 1 and 2. We checked:

  • nvidia/DeepSeek-V3.2-NVFP4 — no compress_ratios, no scale_fmt, no quantization_config in config.json at all (V3.2 doesn't use MLA compression so it sidesteps the issue)
  • nvidia/MiniMax-M2.7-NVFP4 — has quantization_config in config.json (same schema as ours) but no scale_fmt

The compress_ratios→compress_rates rename and rope_parameters nesting are transformers 5.8.0 regressions that modelopt doesn't account for. scale_fmt is a vllm runtime field that modelopt has never exported.
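The four config patches can be expressed as one small function over plain dicts. This is a sketch, not the script we ran: patch_config is a hypothetical name, and the sample values (two compress ratios, a rope_theta of 10000.0) are placeholders rather than the real config contents.

```python
import copy


def patch_config(export_cfg: dict, source_cfg: dict) -> dict:
    """Return a patched copy of the exported config.json contents."""
    cfg = copy.deepcopy(export_cfg)

    # Patch 1: copy the per-layer list from the BF16 source config.
    cfg["compress_ratios"] = list(source_cfg["compress_ratios"])

    # Patch 2: add the scale format field.
    cfg.setdefault("quantization_config", {})["scale_fmt"] = "ue8m0"

    # Patches 3 and 4: flatten the nested dicts to their "main" sub-dict.
    for key in ("rope_parameters", "rope_scaling"):
        value = cfg.get(key)
        if isinstance(value, dict) and "main" in value:
            cfg[key] = value["main"]
    return cfg


# Placeholder inputs; the real configs are far larger.
exported = {
    "quantization_config": {"quant_algo": "NVFP4"},
    "rope_parameters": {"main": {"rope_theta": 10000.0}, "compress": {}},
    "rope_scaling": {"main": {"rope_type": "yarn"}, "compress": {}},
}
source = {"compress_ratios": [1, 4]}

patched = patch_config(exported, source)
print(patched)
```

Working on a deep copy leaves the exported dict untouched, so the original config.json contents can be kept alongside the patched version for comparison.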

Architecture

We call modelopt's hf_ptq.main() directly — the same entry point the shell script uses. We don't rewrite the pipeline. We just:

  1. Patch modelopt at runtime (GPU tensor safety, before anything runs)
  2. Hook export_quantized to snapshot amax + save state before export
  3. Call hf_main(args) with properly parsed args

This avoids the cascade of missing-arg bugs from manually constructing argparse.Namespace (Runs 4-8).

Pipeline

Step 1: Dequantize FP8 → BF16

python3 scripts/dequant_fp8_to_bf16.py /root/nvidia-meeting/DeepSeek-V4-Pro-FP8 /root/nvidia-meeting/DeepSeek-V4-Pro-BF16

The original V4 weights use mixed precision (FP8 attention + FP4/E2M1 experts with per-tensor scales). We dequantize everything to pure BF16 so modelopt can run calibration without hitting broken FP8 kernel paths on Blackwell (DeepGEMM unsupported, Triton finegrained FP8 matmul shape mismatches).

This is not a blind upcast — it applies the actual scale factors:

W_bf16 = dequantize_fp4_weight(W_int, S)  # per-tensor scale dequant, not .to(bfloat16)

Byte-exact verified — matmul diff is 0.000000 against the official inference path.

Step 2: Run NVFP4 Quantization

cd /root/nvidia-meeting/modelopt-repo/examples/llm_ptq
python3 /root/nvidia-meeting/deepseek-v4-quant/scripts/quantize_nvfp4.py

Must run from the modelopt example directory (relative imports).

What happens inside:

  1. Apply patches — 3 runtime monkey-patches for GPU tensor safety (see below)
  2. Parse args — uses hf_ptq.parse_args() with our config via sys.argv replacement, then applies the same post-parse conversions (dataset split, calib_size int list) that hf_ptq.__main__ normally does
  3. Hook export — monkey-patch export_quantized to snapshot amax + save state before export
  4. Call hf_main(args) — the exact same pipeline the shell script uses
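Step 2 (sys.argv replacement followed by the post-parse conversions) can be sketched without modelopt installed. Here parse_args is a hypothetical stand-in for hf_ptq.parse_args() with only three flags, replaced_argv and post_parse are made-up helper names, and the dataset names are placeholders.

```python
import argparse
import sys
from contextlib import contextmanager


@contextmanager
def replaced_argv(argv):
    """Temporarily swap sys.argv so an argv-reading parser sees our config."""
    saved = sys.argv
    sys.argv = argv
    try:
        yield
    finally:
        sys.argv = saved


def parse_args() -> argparse.Namespace:
    """Hypothetical stand-in for hf_ptq.parse_args(); reads sys.argv."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--qformat", default="nvfp4")
    parser.add_argument("--calib_size", default="512")
    parser.add_argument("--dataset", default=None)
    return parser.parse_args()


def post_parse(args: argparse.Namespace) -> argparse.Namespace:
    """Conversions that a __main__ block would normally apply after parsing."""
    if args.dataset is not None:
        args.dataset = args.dataset.split(",")
    args.calib_size = [int(n) for n in args.calib_size.split(",")]
    return args


argv_before = list(sys.argv)
config_argv = [
    "quantize_nvfp4.py",
    "--qformat", "nvfp4",
    "--calib_size", "128",
    "--dataset", "dataset_a,dataset_b",  # placeholder names
]
with replaced_argv(config_argv):
    args = post_parse(parse_args())
argv_restored = sys.argv == argv_before

print(args.qformat, args.calib_size, args.dataset)
```

Letting the real parser build the Namespace means every default the pipeline expects is present, which is the point of not constructing the Namespace by hand.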

If the export crashes:

python3 /root/nvidia-meeting/deepseek-v4-quant/scripts/quantize_nvfp4.py --export-only

To validate saved state without running anything:

python3 /root/nvidia-meeting/deepseek-v4-quant/scripts/quantize_nvfp4.py --validate-only

Config: nvfp4, 128 calib samples, calib_seq=512, kv_cache_qformat=fp8_cast, gpu_max_mem_percentage=0.7, use_seq_device_map, inference_tensor_parallel=8

Calibration datasets: abisee/cnn_dailymail + nvidia/Nemotron-Post-Training-Dataset-v2 (default when no --dataset specified).

Runtime: Model loading ~50 min. Calibration ~5.5 hours. Export ~30-60 min. Total 7-8 hours.

Step 3: Serve with vLLM

cd /root/nvidia-meeting
docker compose up -d

Or without Docker:

source /root/nvidia-meeting/venv/bin/activate
python3 /root/nvidia-meeting/deepseek-v4-quant/scripts/serve_vllm.py

Note: serve_vllm.py still passes --moe-backend=deep_gemm_mega_moe. Remove that flag before using the script; it can be restored once an NVFP4 mega_moe kernel exists. For now, use the Docker Compose setup, which already omits it.

Quantization Run History

Run | Date | Commit | Calib | Result | Root Cause | Fix
1 | May 7 | shell wrapper | 256 | Batch probing crash | o_b_proj shape mismatch: finegrained_fp8 wraps MLA projections incorrectly with FP8 source | Use BF16 source (dequantized)
2 | May 8-9 | shell wrapper | 128 | Export crash (calibration OK) | get_activation_scaling_factor reads stale GPU amax → CUDA illegal memory access | Snapshot amax to CPU after calibration
3 | May 9 06:10 | 3907838 | 128 | Model loading OOM | AutoModelForCausalLM.from_pretrained OOM during expert weight torch.cat | Use modelopt get_model() with max_memory
4 | May 9 ~07:00 | 86dd8df | 128 | Import error | mtq.KV_QUANT_CFG_CHOICES doesn't exist; it's hf_ptq.KV_QUANT_CFG_CHOICES | Import from hf_ptq, not mtq
5 | May 9 ~08:05 | f9bbef8 | 128 | Same as Run 4 | Fix wasn't synced properly | Properly synced
6 | May 9 ~09:25 | 6c1bff6 | 128 | Dataloader crash | make_calib_dataloader AttributeError: missing args | Added args to Namespace
7 | May 9 ~13:40 | 25b4d8d | 128 | Dataloader crash | dataset=None, len() on None | Provided dataset list
8 | May 9 ~14:00 | b2849a8 | 128 | Argparse crash | Wrong flag names (shell script names vs hf_ptq.py names) | Use hf_ptq.py flag names
9 | May 9 ~14:30 | a300302 | 128 | TypeError | Skipped __main__ post-parse conversions (calib_size still string, not int list) | Apply same conversions after parse_args()
10 | May 9 ~15:30 | 5a72da7 | 128 | Export crash (calibration OK) | get_weight_scaling_factor reads stale GPU weight → cudaErrorIllegalAddress | Patch _export_quantized_weight to force weight to CPU at entry point
11 | May 9 ~22:50 | 07cd50e | 128 | SUCCESS | - | 8 patches covering full export chain

Key Lessons (Quantization)

Run 2 — Stale GPU tensors: use_seq_device_map shuffles layers through GPU for calibration. Quantizer amax tensors sit in VRAM for 5+ hours while CUDA's allocator churns memory. By export time, the GPU tensor metadata is valid but the underlying memory has been recycled — reading it triggers cudaErrorIllegalAddress. Fix: copy amax to CPU immediately after calibration.

Run 3 — Expert weight OOM: AutoModelForCausalLM.from_pretrained does torch.cat on GPU for expert gate_up_proj (31.5GB alloc, 25.9GB free). Fix: use modelopt's get_model() which sets max_memory per GPU before loading. (Note: Run 10 uses hf_main() which calls get_model() internally.)

Runs 4-8 (pipeline rewriting trap): Trying to reconstruct hf_ptq's pipeline by importing individual functions and building a fake argparse.Namespace causes an endless stream of missing-attribute and type errors. Each fix reveals the next bug. Fix: call hf_main(args) directly with a properly parsed args object.

Run 9 — __main__ gap: hf_ptq.py does critical type conversions in its __main__ block (string → list for dataset, string → int list for calib_size). When calling main() directly, these are skipped. Fix: apply the same conversions after parse_args().

Run 10 — Stale GPU weight tensors in export: The amax patches (Patch 1-3) only cover quantizer state. The model weights themselves are also on stale GPU. get_weight_scaling_factor does weight_scaling_factor_2.to(weight.device) which triggers cudaErrorIllegalAddress because weight is on stale GPU. Fix: patch _export_quantized_weight (the entry point for each module's export) to force weight to CPU before any downstream code reads it. This covers the entire chain: get_weight_scaling_factor, get_weights_scaling_factor_from_quantizer, to_quantized_weight, weight.to(dtype) — all resolve to CPU because weight.device is CPU.

Do NOT Repeat These Mistakes

  • Don't use FP8 source model — kernel issues on Blackwell (Run 1)
  • Don't use --low_memory_mode with V4 — meta device errors
  • Don't use calib_size=256 — OOMs with 3TB BF16 on CPU offload
  • Don't use AutoModelForCausalLM.from_pretrained directly — OOM during expert weight concat (Run 3)
  • Don't assume GPU tensor integrity after 5+ hours of sequential calibration (Run 2, Run 10)
  • Don't rewrite the hf_ptq pipeline; call hf_main() directly (Runs 4-8)
  • Don't skip the __main__ post-parse conversions — calib_size must be int list, dataset must be list (Run 9)
  • Don't use shell script arg names (--quant, --calib, --kv_cache_quant, --tp) — use hf_ptq.py names (--qformat, --calib_size, --kv_cache_qformat, --inference_tensor_parallel)
  • Don't patch individual export functions one at a time — patch the entry point (_export_quantized_weight) so weight is on CPU for the entire chain (Run 10)
  • Don't use runtime monkey-patching for vllm serving — workers are separate processes that don't inherit patches. Patch the source file directly instead.
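The flag-name pitfall from the list above can be guarded against with a small translation table. The four name pairs are the ones listed there; translate_flags itself is a hypothetical helper, not part of hf_ptq.py or the shell script.

```python
from typing import List

# Shell-script flag names on the left, hf_ptq.py flag names on the right.
SHELL_TO_HF_PTQ = {
    "--quant": "--qformat",
    "--calib": "--calib_size",
    "--kv_cache_quant": "--kv_cache_qformat",
    "--tp": "--inference_tensor_parallel",
}


def translate_flags(argv: List[str]) -> List[str]:
    """Rename shell-style flags, handling both `--flag value` and `--flag=value`."""
    translated = []
    for token in argv:
        flag, sep, value = token.partition("=")
        translated.append(SHELL_TO_HF_PTQ.get(flag, flag) + sep + value)
    return translated


shell_args = ["--quant", "nvfp4", "--calib=128", "--tp", "8"]
hf_args = translate_flags(shell_args)
print(hf_args)
```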

Runtime Patches Applied by quantize_nvfp4.py

These are monkey-patches applied at runtime — no modelopt source files are modified.

Calibration-time patches (applied before pipeline runs)

  1. TensorQuantizer.load_calib_amax — After calibration writes _amax to GPU, immediately moves it to CPU. Prevents stale GPU tensors.
  2. TensorQuantizer.export_amax — If _amax is still on GPU at export time, moves to CPU before reading. Safety net.
  3. NVFP4QTensor.get_activation_scaling_factor — Moves amax to CPU, clamps bad values instead of hard assert. Prevents crash on garbage from GPU corruption.

Export-time patches (force stale GPU tensors to CPU at entry points)

  4. _export_quantized_weight (KEY PATCH): Forces weight + all quantizer state to CPU before any downstream code reads them. This is the entry point for exporting each linear layer. By forcing weight to CPU here, every downstream .to(weight.device) resolves to CPU, covering the entire chain: get_weight_scaling_factor, get_weights_scaling_factor_from_quantizer, to_quantized_weight, weight.to(dtype).
  5. _export_fused_experts: Same treatment for MoE expert weights (DeepseekV4Experts go through this path). Forces expert weights, buffers, and quantizer state to CPU.
  6. to_quantized_weight: Forces weight and scaling factors to CPU. Redundant if Patch 4 works, but catches any code path that reaches this function without going through _export_quantized_weight.
  7. get_weight_scaling_factor: Forces weight + quantizer to CPU. Redundant if Patch 4 works.
  8. get_weight_scaling_factor_2: Forces quantizer state to CPU. Redundant if Patch 4 works.

Patches 6-8 are belt-and-suspenders. Patch 4 is the one that matters — it moves weight to CPU at the earliest possible point in the export chain, making all downstream stale GPU reads impossible.

Post-Calibration Hook

export_quantized is monkey-patched to run these steps before the real export:

  1. snapshot_amax_to_cpu() — Walks all quantizers, copies _amax to CPU, saves to disk (~50MB). Insurance policy.
  2. force_all_amax_to_cpu() — Moves _pre_quant_scale, _global_amax to CPU too. Nuclear option.
  3. save_calibrated_state() — Saves full model state dict to disk (~1.5TB). Enables --export-only recovery if export crashes.
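The hook pattern itself is small. The sketch below wraps a function so that a list of pre-steps runs before the original; hook_before is a hypothetical helper, and a SimpleNamespace with a fake export_quantized stands in for the real modelopt module.

```python
import types


def hook_before(owner, name, pre_steps):
    """Replace owner.<name> with a wrapper that runs pre_steps, then the original."""
    original = getattr(owner, name)

    def wrapper(*args, **kwargs):
        for step in pre_steps:
            step(*args, **kwargs)
        return original(*args, **kwargs)

    setattr(owner, name, wrapper)
    return original  # returned so the caller can restore it later


calls = []

# Stand-in for the module that owns export_quantized (hypothetical).
fake_export_module = types.SimpleNamespace(
    export_quantized=lambda model: calls.append("export_quantized")
)

pre_steps = [
    lambda model: calls.append("snapshot_amax_to_cpu"),
    lambda model: calls.append("force_all_amax_to_cpu"),
    lambda model: calls.append("save_calibrated_state"),
]

hook_before(fake_export_module, "export_quantized", pre_steps)
fake_export_module.export_quantized("model")
print(calls)
```

A wrapper installed this way only affects callers that look the function up through the owner at call time; code that bound the original earlier (for example via a from-import) keeps calling the unpatched version.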

Bugs Found (V4 + modelopt 0.45.0.dev64)

  1. QuantDeepseekV4Experts AttributeError: already fixed in modelopt 0.45.0.dev64 (handles nn.ModuleList quantizers natively).
  2. --low_memory_mode → meta device error. Don't use with V4.
  3. Missing kernels package for FP8 ops. pip install -U kernels.
  4. Shell script arg names — Resolved by calling hf_main() directly.
  5. Export crash — stale GPU tensors in export_amax(). After hours of calibration, quantizer _amax on GPU becomes unreadable. Fixed by patching export_amax to move _amax to CPU before reading.
  6. Export crash — assert torch.all(activation_scaling_factor > 0). Amax values from stale GPU reads are garbage (zeros, negatives, NaN). Fixed by clamping instead of asserting, plus snapshotting valid amax to CPU before corruption can occur.
  7. Model loading OOM during expert weight conversion. AutoModelForCausalLM.from_pretrained does torch.cat on GPU for expert gate_up_proj (31.5GB alloc), but only 25.9GB free with device_map="sequential". Fixed by using modelopt's get_model() which sets max_memory per GPU before loading.
  8. Export crash — stale GPU weight tensors in get_weight_scaling_factor. Patches 1-3 only covered quantizer amax. The model weights themselves are also on stale GPU. weight_scaling_factor_2.to(weight.device) triggers cudaErrorIllegalAddress. Fixed by patching _export_quantized_weight to force weight to CPU at the entry point, covering the entire export chain.

Bugs Found (V4 NVFP4 + vLLM serving)

  1. modelopt uses mlp, vllm uses ffn — Module naming mismatch. Fixed with substr mapping.
  2. modelopt uses gate_proj/up_proj/down_proj, vllm expects w1/w3/w2 — Expert weight naming mismatch. Fixed with regex mapping (only for .experts.N., not .shared_experts.).
  3. modelopt uses self_attn prefix, vllm uses attn.mla_attn — Attention module naming. Fixed with substr mapping.
  4. kv_proj maps to wkv, not kv_proj — vllm stacks wkv + wq_a into fused_wqa_wkv. Fixed with substr mapping.
  5. compressor.kv_proj→compressor.wkv: Similar stacking for compressor. Fixed with substr mapping.
  6. compressor.kv_norm→attn.kv_norm: modelopt puts kv_norm under compressor, vllm has it at attention level. Fixed with substr mapping (must come before general compressor mapping).
  7. MergedColumnParallelLinear + NVFP4 incompatibility: ModelOptNvFp4Config.create_weights() only handles Linear, not MergedColumnParallelLinear. This causes:
    • Weight param created as bf16 instead of uint8 (PackedColumnParameter)
    • weight_scale/weight_scale_2/input_scale not registered for stacked params
    • adjust_shard_indexes_for_packing applies packed_factor to rows, but NVFP4 packs along columns
    • Workaround: Unpack uint8→bf16 at load time, skip scale tensors, rely on process_weights_after_loading re-quantization
  8. No NVFP4 mega_moe kernel: DeepseekV4MegaMoEExperts expects MXFP4 (32-col blocks), modelopt exports NVFP4 (16-col blocks). No kernel exists. Abandoned mega_moe, using standard FusedMoE instead.
  9. DeepseekV4ForCausalLM.hf_to_vllm_mapper is a class attribute — Runtime monkey-patching the factory function doesn't update the cached class attribute. Must patch the source file directly or update the class attribute explicitly.
  10. vllm workers are separate processes — In-memory monkey-patches don't propagate to workers. Must patch the source file directly.
  11. modelopt exports params vllm doesn't have — e.g., compressor.position_bias. Need resilient loading that skips unknown params.
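The naming fixes in items 1, 2 and 5 follow a two-stage pattern: regex rules first, then substring rules. The sketch below mimics that ordering with plain re and str.replace. It is not vllm's WeightsMapper, the rule set is a small subset, and the tensor paths are simplified, illustrative names rather than real checkpoint keys.

```python
import re
from typing import Optional

# Stage 1: regex rules. A replacement of None means "skip this tensor".
REGEX_RULES = [
    # Scales for shared-expert gate/up have no registered stacked param.
    (re.compile(r"\.shared_experts\.(gate_proj|up_proj)\."
                r"(weight_scale|weight_scale_2|input_scale)$"), None),
    # Routed experts only: `.experts.N.`, which never matches `.shared_experts.`.
    (re.compile(r"(\.experts\.\d+\.)gate_proj\."), r"\1w1."),
    (re.compile(r"(\.experts\.\d+\.)up_proj\."), r"\1w3."),
    (re.compile(r"(\.experts\.\d+\.)down_proj\."), r"\1w2."),
]

# Stage 2: substring rules, applied after the regex stage.
SUBSTR_RULES = [
    (".compressor.kv_proj.", ".compressor.wkv."),
    (".compressor.gate_proj.", ".compressor.wgate."),
    (".shared_experts.gate_proj.", ".shared_experts.w1."),
    (".shared_experts.up_proj.", ".shared_experts.w3."),
    (".mlp.", ".ffn."),
]


def map_name(name: str) -> Optional[str]:
    """Map a checkpoint tensor name to the model's name, or None to skip it."""
    for pattern, replacement in REGEX_RULES:
        if pattern.search(name):
            if replacement is None:
                return None
            name = pattern.sub(replacement, name)
    for old, new in SUBSTR_RULES:
        name = name.replace(old, new)
    return name


for sample in (
    "layers.0.mlp.experts.0.down_proj.weight",
    "layers.0.mlp.shared_experts.gate_proj.weight",
    "layers.0.mlp.shared_experts.up_proj.weight_scale_2",
):
    print(sample, "->", map_name(sample))
```

Running the skip rules in the regex stage matters: once a substring rule has renamed gate_proj to w1, a skip pattern written against the original name would no longer match.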

Dependencies (pinned versions)

  • nvidia-modelopt: 0.45.0.dev64+g579fc6c31 (installed from git, not PyPI)
  • transformers: 5.8.0.dev0 (from git, required for DeepSeekV4 support)
  • kernels: latest (pip install -U kernels — needed for finegrained FP8 ops)
  • Python: 3.10

The patches in quantize_nvfp4.py are for modelopt 0.45.0.dev64 specifically. Later versions may include fixes natively — check before applying.

Key Notes

  • V4 is NOT BF16 — it ships as mixed-precision FP8/FP4. You MUST dequantize to BF16 first (Step 1).
  • --low_memory_mode causes meta device errors with V4 — don't use.
  • modelopt has no explicit V4 support — relies on auto-detection of fused experts.
  • The calibration state save (v4_nvfp4_calibrated_state.pt) is ~1.5TB. It lives on NVMe, not in git.
  • The amax snapshot (v4_nvfp4_amax_snapshots.pt) is ~50MB. Small, critical, cheap insurance.
  • The script calls hf_main(args) — the exact same entry point as the shell script. No pipeline divergence.
  • Must run from /root/nvidia-meeting/modelopt-repo/examples/llm_ptq (relative imports).
  • For vllm serving, the patched deepseek_v4.py must be mounted into the container — workers don't inherit in-memory patches.
  • We disabled --moe-backend=deep_gemm_mega_moe because no NVFP4 mega_moe kernel exists yet. Standard FusedMoE with ModelOptNvFp4FusedMoE handles expert weights correctly.

File Layout

scripts/
  dequant_fp8_to_bf16.py   — Step 1: FP8/FP4 → BF16 dequantization
  quantize_nvfp4.py         — Step 2: NVFP4 quantization (patches + hf_main)
  serve_vllm.py             — Step 3: vLLM serving (legacy, still has mega_moe flag)

patches/
  deepseek_v4.py            — Patched vllm source file (copied over original at container startup)
  patch_vllm_weights.py     — Legacy runtime monkey-patch (doesn't work with workers, kept for reference)
  quant_module_patched.py   — (legacy) quant module patches
  patch_finegrained_fp8_blackwell.py  — (legacy) FP8 kernel patches for Blackwell

docker-compose.yml           — Docker Compose config for serving (uses patched deepseek_v4.py, no mega_moe)

Apart from deepseek_v4.py (the current vllm source patch), the patches/ directory contains earlier approaches; the entries marked (legacy) modified source files directly. The current quantization approach (quantize_nvfp4.py) uses runtime monkey-patching instead, so no modelopt source files are modified.
