Add ModelOpt NVFP4 pipeline: patch, run script, README
- Patch fixes iter_weights_for_calibration() for DeepseekV4Experts (ModuleList quantizers vs singular) - Run script uses official NVIDIA hf_ptq.py with FP8 source - Documents flags to avoid (--low_memory_mode, wrong arg names)
This commit is contained in:
38
README_modelopt_nvfp4.md
Normal file
38
README_modelopt_nvfp4.md
Normal file
@@ -0,0 +1,38 @@
|
||||
# DeepSeek V4 Pro NVFP4 via NVIDIA ModelOpt
|
||||
|
||||
## What this does
|
||||
Quantizes DeepSeek V4 Pro (FP8 weights) to full NVFP4 format using NVIDIA's official ModelOpt pipeline.
|
||||
Target output: ~600GB (vs 840GB from custom Path A converter).
|
||||
|
||||
## Prerequisites
|
||||
- B200 node (8× B200, 2.7TB RAM) — NVFP4 requires Blackwell GPUs
|
||||
- modelopt 0.45.0+ from git
|
||||
- transformers 5.8.0.dev0 (for DeepSeekV4 support)
|
||||
- kernels package (for FP8 dequantization during calibration)
|
||||
|
||||
## Critical Patch
|
||||
modelopt has a bug with DeepSeekV4Experts — the `iter_weights_for_calibration()` method
|
||||
doesn't handle ModuleList quantizers (plural `gate_up_proj_weight_quantizers`).
|
||||
Apply the patch before running:
|
||||
|
||||
```bash
|
||||
cp patches/quant_module_patched.py <venv-path>/lib/python3.10/site-packages/modelopt/torch/quantization/nn/modules/quant_module.py
|
||||
```
|
||||
|
||||
## Do NOT use these flags
|
||||
- `--low_memory_mode`: causes meta device error with V4
|
||||
- `--calib_size`: wrong arg name (use `--calib`)
|
||||
|
||||
## Run
|
||||
```bash
|
||||
bash scripts/run_modelopt_nvfp4.sh
|
||||
```
|
||||
|
||||
## Output
|
||||
`/root/nvidia-meeting/modelopt-repo/examples/llm_ptq/saved_models_DeepSeek-V4-Pro-FP8_nvfp4_kv_fp8_cast`
|
||||
|
||||
## Notes
|
||||
- Use FP8 source (`DeepSeek-V4-Pro-FP8`), NOT mixed-precision BF16 (`DeepSeek-V4-Pro`)
|
||||
- V4's mixed precision causes "wonky shit" — FP8 is clean
|
||||
- Calibration takes hours with CPU offload (`--use_seq_device_map`)
|
||||
- Expected calibration time: several hours for 256 samples
|
||||
295
patches/quant_module_patched.py
Normal file
295
patches/quant_module_patched.py
Normal file
@@ -0,0 +1,295 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Base class for quantization modules."""
|
||||
|
||||
import contextlib
|
||||
import warnings
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from modelopt.torch.opt.dynamic import DynamicModule, _DMRegistryCls
|
||||
from modelopt.torch.utils.distributed import ParallelState
|
||||
|
||||
from ...tensor_quant import QUANT_DESC_8BIT_PER_TENSOR
|
||||
from ...utils import is_torch_export_mode
|
||||
from .tensor_quantizer import SequentialQuantizer, TensorQuantizer
|
||||
|
||||
__all__ = [
|
||||
"QuantInputBase",
|
||||
"QuantLinearConvBase",
|
||||
"QuantModule",
|
||||
"QuantModuleRegistry",
|
||||
]
|
||||
|
||||
|
||||
class QuantModule(DynamicModule):
|
||||
"""A base class for quantized modules.
|
||||
|
||||
In addition, the class also provides ``parallel_state`` attribute that can be used to access
|
||||
the parallel state of the module.
|
||||
"""
|
||||
|
||||
_parallel_state: ParallelState
|
||||
|
||||
@classmethod
|
||||
@torch.no_grad()
|
||||
def convert(cls, module: nn.Module, **setup_kwargs: Any) -> "QuantModule":
|
||||
"""Convert the module to a dynamic module."""
|
||||
module = super().convert(module, **setup_kwargs)
|
||||
|
||||
# setup parallel state now that the module is converted
|
||||
if module.parallel_state is None:
|
||||
module._initialize_parallel_state()
|
||||
|
||||
return module
|
||||
|
||||
@property
|
||||
def parallel_state(self) -> ParallelState | None:
|
||||
"""Return the parallel state of the quant module."""
|
||||
return getattr(self, "_parallel_state", None)
|
||||
|
||||
@parallel_state.setter
|
||||
def parallel_state(self, parallel_state: ParallelState):
|
||||
"""Set the parallel state of the dynamic module."""
|
||||
assert isinstance(parallel_state, ParallelState), (
|
||||
"parallel_state must be a ParallelState object!"
|
||||
)
|
||||
self._parallel_state = parallel_state
|
||||
|
||||
def _initialize_parallel_state(self):
|
||||
"""Initialize the parallel state of the dynamic module.
|
||||
|
||||
This method is called only if the `QuantModule` does not have a `parallel_state` attribute
|
||||
after `_setup` is called.
|
||||
"""
|
||||
if torch.distributed.is_initialized():
|
||||
warnings.warn(
|
||||
f"Distributed training is initialized but no parallel_state is set for {type(self)}. "
|
||||
"Using default parallel_state which has data_parallel_group set to the default process group and "
|
||||
"tensor_parallel_group is unspecified. "
|
||||
"If you are using tensor parallelism for this module, you should set the parallel_state "
|
||||
"in its `_setup` method."
|
||||
)
|
||||
|
||||
self.parallel_state = ParallelState(data_parallel_group=None)
|
||||
|
||||
def modelopt_post_restore(self, prefix: str = ""):
|
||||
"""Post-restore to correctly configure the TensorQuantizer states.
|
||||
|
||||
TensorQuantizer states are restored to their shape before saving. Now we need to further configure them.
|
||||
1. For non-sharded modules this simply involves moving the TensorQuantizer states to the right device.
|
||||
This applies for regular Pytorch models and HuggingFace models.
|
||||
2. For sharded modules the restored states of TensorQuantizer could be incorrect. This is because
|
||||
parallelism such as TP might have been changed between saving and resoring. So we need to re-calculate
|
||||
the state shapes. Hence such modules should override this and implement their own logic.
|
||||
"""
|
||||
# Get a parameter or buffer that does not belong to a TensorQuantizer
|
||||
non_tq_param_or_buffer = None
|
||||
for name, param_or_buffer in self.state_dict().items():
|
||||
parent = self.get_submodule(name.rsplit(".", 1)[0]) if "." in name else self
|
||||
if not isinstance(parent, TensorQuantizer):
|
||||
non_tq_param_or_buffer = param_or_buffer
|
||||
break
|
||||
|
||||
if non_tq_param_or_buffer is None:
|
||||
warnings.warn(
|
||||
f"Could not identify the device for TensorQuantizer states of {prefix}. "
|
||||
"Please move the model to the right device now. This can be done by calling "
|
||||
"`model.to(device)`."
|
||||
)
|
||||
return
|
||||
|
||||
# Move the TensorQuantizer states to the right device (dtype should have been restored).
|
||||
for module in self.modules():
|
||||
if isinstance(module, TensorQuantizer):
|
||||
module.to(non_tq_param_or_buffer.device)
|
||||
|
||||
def iter_weights_for_calibration(self):
|
||||
"""Yield ``(weight, weight_quantizer)`` pairs for weight-only calibration."""
|
||||
import torch.nn as nn
|
||||
from modelopt.torch.quantization.utils import quantizer_attr_names, weight_attr_names
|
||||
|
||||
for weight_name in weight_attr_names(self):
|
||||
qname = quantizer_attr_names(weight_name).weight_quantizer
|
||||
qattr = getattr(self, qname, None)
|
||||
weight = getattr(self, weight_name)
|
||||
if qattr is not None:
|
||||
# Singular quantizer
|
||||
yield weight, qattr
|
||||
else:
|
||||
# Try plural (ModuleList) - e.g. _QuantFusedExperts
|
||||
plural = qname + "s"
|
||||
qattr = getattr(self, plural, None)
|
||||
if isinstance(qattr, nn.ModuleList):
|
||||
# Yield per-expert slices for 3-D fused weights
|
||||
if weight.dim() == 3:
|
||||
for idx, q in enumerate(qattr):
|
||||
yield weight[idx], q
|
||||
else:
|
||||
for q in qattr:
|
||||
yield weight, q
|
||||
else:
|
||||
raise AttributeError(
|
||||
f"Cannot find weight quantizer {qname} or {plural} on {type(self).__name__}"
|
||||
)
|
||||
|
||||
def fold_weight(self, keep_attrs: bool = False):
|
||||
"""Fold the weight for faster eval."""
|
||||
# Handle all attributes that end with _weight_quantizer
|
||||
for name in dir(self):
|
||||
attr = getattr(self, name)
|
||||
if (
|
||||
name.endswith("weight_quantizer")
|
||||
and isinstance(attr, TensorQuantizer)
|
||||
and attr.fake_quant
|
||||
):
|
||||
# Get the corresponding weight name by removing _weight_quantizer suffix
|
||||
weight_name = name[:-10]
|
||||
|
||||
assert hasattr(self, weight_name), (
|
||||
f"{name} doesn't have a corresponding {weight_name} in {self.__class__.__name__}"
|
||||
)
|
||||
weight = getattr(self, weight_name)
|
||||
weight.data.copy_(attr(weight.float()).to(weight.dtype))
|
||||
attr.disable()
|
||||
if not keep_attrs:
|
||||
_attrs = [
|
||||
"_pre_quant_scale",
|
||||
"_amax",
|
||||
]
|
||||
for attr_name in _attrs:
|
||||
if hasattr(attr, attr_name):
|
||||
delattr(attr, attr_name)
|
||||
|
||||
|
||||
QuantModuleRegistry = _DMRegistryCls("Quant", QuantModule)
|
||||
|
||||
|
||||
class QuantInputBase(QuantModule):
|
||||
"""Base class for modules where the input is quantized."""
|
||||
|
||||
input_quantizer: TensorQuantizer
|
||||
output_quantizer: TensorQuantizer
|
||||
default_quant_desc_input = QUANT_DESC_8BIT_PER_TENSOR
|
||||
default_quant_desc_output = QUANT_DESC_8BIT_PER_TENSOR
|
||||
|
||||
def forward(self, input, *args, **kwargs):
|
||||
"""Quantize the input before calling the original forward method."""
|
||||
input = self.input_quantizer(input)
|
||||
# Check MR: https://github.com/NVIDIA/Model-Optimizer/pull/824
|
||||
if hasattr(self, "_forward_pre_dm"):
|
||||
pre_fwd = getattr(self, "_forward_pre_dm")
|
||||
|
||||
def _is_forward_in_mro(bound_or_func) -> bool:
|
||||
# If this is a bound method, compare its underlying function to any `forward`
|
||||
# implementation in the current MRO. If it matches, it's not an external monkey-patch.
|
||||
if hasattr(bound_or_func, "__func__"):
|
||||
fn = bound_or_func.__func__
|
||||
for cls in type(self).mro():
|
||||
if cls.__dict__.get("forward") is fn:
|
||||
return True
|
||||
return False
|
||||
|
||||
if pre_fwd is getattr(self, "forward") or _is_forward_in_mro(pre_fwd):
|
||||
output = super().forward(input, *args, **kwargs)
|
||||
else:
|
||||
output = pre_fwd(input, *args, **kwargs)
|
||||
else:
|
||||
output = super().forward(input, *args, **kwargs)
|
||||
if isinstance(output, tuple):
|
||||
return (self.output_quantizer(output[0]), *output[1:])
|
||||
return self.output_quantizer(output)
|
||||
|
||||
def _setup(self):
|
||||
"""Patch the module's forward method to quantize the input."""
|
||||
self._register_temp_attribute(
|
||||
"input_quantizer", TensorQuantizer(self.default_quant_desc_input)
|
||||
)
|
||||
self._register_temp_attribute(
|
||||
"output_quantizer", TensorQuantizer(self.default_quant_desc_output)
|
||||
)
|
||||
self.output_quantizer.disable()
|
||||
|
||||
|
||||
class QuantLinearConvBase(QuantInputBase):
|
||||
"""Base class for quantized linear modules.
|
||||
|
||||
Quantized linear modules are modules where both the input and the weight are quantized.
|
||||
"""
|
||||
|
||||
weight_quantizer: TensorQuantizer | SequentialQuantizer
|
||||
_enable_weight_quantization: bool
|
||||
default_quant_desc_weight = QUANT_DESC_8BIT_PER_TENSOR
|
||||
|
||||
@contextlib.contextmanager
|
||||
def quantize_weight(self):
|
||||
"""Context in which `self.weight` is quantized."""
|
||||
self._enable_weight_quantization = True
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._enable_weight_quantization = False
|
||||
|
||||
@staticmethod
|
||||
def _get_quantized_weight(module: "QuantLinearConvBase", weight: torch.Tensor) -> torch.Tensor:
|
||||
if module._enable_weight_quantization or is_torch_export_mode():
|
||||
return module.weight_quantizer(weight)
|
||||
return weight
|
||||
|
||||
def forward(self, input, *args, **kwargs):
|
||||
"""Quantize the input and the weight before calling the original forward method."""
|
||||
# self.quntize_weight() setting attributes is not allowed for torch.export.
|
||||
if is_torch_export_mode():
|
||||
return super().forward(input, *args, **kwargs)
|
||||
|
||||
with self.quantize_weight():
|
||||
return super().forward(input, *args, **kwargs)
|
||||
|
||||
def _setup(self):
|
||||
super()._setup()
|
||||
self._register_temp_attribute(
|
||||
"weight_quantizer", TensorQuantizer(self.default_quant_desc_weight)
|
||||
)
|
||||
self._register_temp_attribute("_enable_weight_quantization", False)
|
||||
self._register_dynamic_attribute("weight", self._get_quantized_weight)
|
||||
|
||||
|
||||
class _LegacyQuantInputBaseMixin:
|
||||
"""A mixin to support legacy quantized modules which needs to have an __init__ method."""
|
||||
|
||||
_quantized_cls = QuantInputBase
|
||||
default_quant_desc_input = QUANT_DESC_8BIT_PER_TENSOR
|
||||
default_quant_desc_output = QUANT_DESC_8BIT_PER_TENSOR
|
||||
|
||||
def __init__(self, *args, quant_desc_input=None, **kwargs):
|
||||
"""Initialize the module with its original __init__ and patch its forward."""
|
||||
self.default_quant_desc_input = quant_desc_input or self.default_quant_desc_input
|
||||
super().__init__(*args, **kwargs)
|
||||
QuantModuleRegistry.convert(self)
|
||||
|
||||
|
||||
class _LegacyQuantLinearConvBaseMixin(_LegacyQuantInputBaseMixin):
|
||||
"""A mixin to support legacy quantized modules which needs to have an __init__ method."""
|
||||
|
||||
_quantized_cls = QuantLinearConvBase
|
||||
default_quant_desc_weight = QUANT_DESC_8BIT_PER_TENSOR
|
||||
|
||||
def __init__(self, *args, quant_desc_input=None, quant_desc_weight=None, **kwargs):
|
||||
"""Initialize the module with its original __init__ and patch its forward."""
|
||||
self.default_quant_desc_weight = quant_desc_weight or self.default_quant_desc_weight
|
||||
super().__init__(*args, quant_desc_input=quant_desc_input, **kwargs)
|
||||
25
scripts/run_modelopt_nvfp4.sh
Executable file
25
scripts/run_modelopt_nvfp4.sh
Executable file
@@ -0,0 +1,25 @@
|
||||
#!/bin/bash
|
||||
# DeepSeek V4 Pro FP8 → NVFP4 via NVIDIA ModelOpt
|
||||
# Run from: /root/nvidia-meeting/modelopt-repo/examples/llm_ptq
|
||||
#
|
||||
# Prerequisites:
|
||||
# - modelopt 0.45.0+ from git: pip install "nvidia-modelopt[hf] @ git+https://github.com/NVIDIA/Model-Optimizer.git"
|
||||
# - transformers 5.8.0.dev0: pip install git+https://github.com/huggingface/transformers.git
|
||||
# - kernels: pip install -U kernels
|
||||
# - Patch modelopt: cp patches/quant_module_patched.py <venv>/lib/python3.10/site-packages/modelopt/torch/quantization/nn/modules/quant_module.py
|
||||
#
|
||||
# Source weights: /root/nvidia-meeting/DeepSeek-V4-Pro-FP8
|
||||
|
||||
set -e
|
||||
cd /root/nvidia-meeting/modelopt-repo/examples/llm_ptq
|
||||
source /root/nvidia-meeting/venv/bin/activate
|
||||
|
||||
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
|
||||
bash scripts/huggingface_example.sh \
|
||||
--model /root/nvidia-meeting/DeepSeek-V4-Pro-FP8 \
|
||||
--quant nvfp4 \
|
||||
--tp 8 \
|
||||
--calib 256 \
|
||||
--kv_cache_quant fp8_cast \
|
||||
--trust_remote_code \
|
||||
--use_seq_device_map
|
||||
Reference in New Issue
Block a user