# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Base class for quantization modules."""

import contextlib
import warnings
from typing import Any

import torch
import torch.nn as nn

from modelopt.torch.opt.dynamic import DynamicModule, _DMRegistryCls
from modelopt.torch.utils.distributed import ParallelState

from ...tensor_quant import QUANT_DESC_8BIT_PER_TENSOR
from ...utils import is_torch_export_mode
from .tensor_quantizer import SequentialQuantizer, TensorQuantizer

__all__ = [
    "QuantInputBase",
    "QuantLinearConvBase",
    "QuantModule",
    "QuantModuleRegistry",
]


class QuantModule(DynamicModule):
    """A base class for quantized modules.

    In addition, the class also provides ``parallel_state`` attribute that can be used to access
    the parallel state of the module.
    """

    _parallel_state: ParallelState

    @classmethod
    @torch.no_grad()
    def convert(cls, module: nn.Module, **setup_kwargs: Any) -> "QuantModule":
        """Convert the module to a dynamic module.

        After the base-class conversion, a default parallel state is installed if the
        module's ``_setup`` did not provide one.

        Args:
            module: The module to convert.
            **setup_kwargs: Forwarded unchanged to ``DynamicModule.convert``.

        Returns:
            The converted module.
        """
        module = super().convert(module, **setup_kwargs)

        # setup parallel state now that the module is converted
        if module.parallel_state is None:
            module._initialize_parallel_state()

        return module

    @property
    def parallel_state(self) -> ParallelState | None:
        """Return the parallel state of the quant module, or ``None`` if it has not been set."""
        return getattr(self, "_parallel_state", None)

    @parallel_state.setter
    def parallel_state(self, parallel_state: ParallelState):
        """Set the parallel state of the dynamic module.

        Raises:
            AssertionError: If ``parallel_state`` is not a ``ParallelState`` instance.
        """
        assert isinstance(parallel_state, ParallelState), (
            "parallel_state must be a ParallelState object!"
        )
        self._parallel_state = parallel_state

    def _initialize_parallel_state(self):
        """Initialize the parallel state of the dynamic module with a default.

        This method is called only if the ``QuantModule`` does not have a ``parallel_state``
        attribute after ``_setup`` is called. A warning is emitted when ``torch.distributed``
        is initialized, because the default may be wrong for tensor-parallel modules.
        """
        if torch.distributed.is_initialized():
            warnings.warn(
                f"Distributed training is initialized but no parallel_state is set for {type(self)}. "
                "Using default parallel_state which has data_parallel_group set to the default process group and "
                "tensor_parallel_group is unspecified. "
                "If you are using tensor parallelism for this module, you should set the parallel_state "
                "in its `_setup` method."
            )
        self.parallel_state = ParallelState(data_parallel_group=None)

    def modelopt_post_restore(self, prefix: str = ""):
        """Post-restore to correctly configure the TensorQuantizer states.

        TensorQuantizer states are restored to their shape before saving. Now we need to further
        configure them.

        1. For non-sharded modules this simply involves moving the TensorQuantizer states to the
           right device. This applies for regular PyTorch models and HuggingFace models.
        2. For sharded modules the restored states of TensorQuantizer could be incorrect. This is
           because parallelism such as TP might have been changed between saving and restoring.
           So we need to re-calculate the state shapes. Hence such modules should override this
           and implement their own logic.

        Args:
            prefix: Name of this module, used only in the warning message.
        """
        # Get a parameter or buffer that does not belong to a TensorQuantizer; its device is
        # taken as the target device for all TensorQuantizer states.
        non_tq_param_or_buffer = None
        for name, param_or_buffer in self.state_dict().items():
            parent = self.get_submodule(name.rsplit(".", 1)[0]) if "." in name else self
            if not isinstance(parent, TensorQuantizer):
                non_tq_param_or_buffer = param_or_buffer
                break

        if non_tq_param_or_buffer is None:
            warnings.warn(
                f"Could not identify the device for TensorQuantizer states of {prefix}. "
                "Please move the model to the right device now. This can be done by calling "
                "`model.to(device)`."
            )
            return

        # Move the TensorQuantizer states to the right device (dtype should have been restored).
        for module in self.modules():
            if isinstance(module, TensorQuantizer):
                module.to(non_tq_param_or_buffer.device)

    def iter_weights_for_calibration(self):
        """Yield ``(weight, weight_quantizer)`` pairs for weight-only calibration.

        For every weight name reported by ``weight_attr_names``, the matching singular quantizer
        attribute is used if present. Otherwise the plural form (``<quantizer name> + "s"``) is
        looked up as an ``nn.ModuleList``. If the weight is 3-D, one slice ``weight[idx]`` is
        yielded per quantizer; otherwise the full weight is yielded with each quantizer.

        Raises:
            AttributeError: If neither the singular nor the plural quantizer attribute exists.
        """
        # Imported lazily; presumably to avoid a circular import at module load time.
        from modelopt.torch.quantization.utils import quantizer_attr_names, weight_attr_names

        for weight_name in weight_attr_names(self):
            qname = quantizer_attr_names(weight_name).weight_quantizer
            qattr = getattr(self, qname, None)
            weight = getattr(self, weight_name)
            if qattr is not None:
                # Singular quantizer
                yield weight, qattr
                continue

            # Try plural (ModuleList) - e.g. _QuantFusedExperts
            plural = qname + "s"
            qattr = getattr(self, plural, None)
            if not isinstance(qattr, nn.ModuleList):
                raise AttributeError(
                    f"Cannot find weight quantizer {qname} or {plural} on {type(self).__name__}"
                )
            if weight.dim() == 3:
                # Yield per-expert slices for 3-D fused weights
                for idx, q in enumerate(qattr):
                    yield weight[idx], q
            else:
                for q in qattr:
                    yield weight, q

    def fold_weight(self, keep_attrs: bool = False):
        """Fold the weight for faster eval.

        Handles every attribute whose name ends with ``weight_quantizer`` and which is a
        :class:`TensorQuantizer` with ``fake_quant`` set. The corresponding weight is the
        attribute name with the trailing ``_quantizer`` removed (e.g. ``weight_quantizer`` ->
        ``weight``). That weight is overwritten in place with its fake-quantized value, computed
        in float32 and cast back to the weight's dtype. The quantizer is then disabled.

        Args:
            keep_attrs: If False (default), the quantizer's ``_pre_quant_scale`` and ``_amax``
                attributes are deleted after folding.
        """
        for name in dir(self):
            # Filter on the name before calling getattr so we do not evaluate every
            # attribute/property that ``dir`` exposes (some may be expensive or raise).
            if not name.endswith("weight_quantizer"):
                continue
            attr = getattr(self, name)
            if not (isinstance(attr, TensorQuantizer) and attr.fake_quant):
                continue

            # Get the corresponding weight name by removing the trailing "_quantizer",
            # e.g. "weight_quantizer" -> "weight".
            weight_name = name.removesuffix("_quantizer")
            assert hasattr(self, weight_name), (
                f"{name} doesn't have a corresponding {weight_name} in {self.__class__.__name__}"
            )
            weight = getattr(self, weight_name)
            weight.data.copy_(attr(weight.float()).to(weight.dtype))
            attr.disable()
            if not keep_attrs:
                for attr_name in ("_pre_quant_scale", "_amax"):
                    if hasattr(attr, attr_name):
                        delattr(attr, attr_name)


QuantModuleRegistry = _DMRegistryCls("Quant", QuantModule)


class QuantInputBase(QuantModule):
    """Base class for modules where the input is quantized."""

    input_quantizer: TensorQuantizer
    output_quantizer: TensorQuantizer
    default_quant_desc_input = QUANT_DESC_8BIT_PER_TENSOR
    default_quant_desc_output = QUANT_DESC_8BIT_PER_TENSOR

    def forward(self, input, *args, **kwargs):
        """Quantize the input before calling the original forward method.

        If the output is a tuple, only its first element is passed through
        ``output_quantizer``; the remaining elements are returned unchanged.
        """
        input = self.input_quantizer(input)
        # Check MR: https://github.com/NVIDIA/Model-Optimizer/pull/824
        if hasattr(self, "_forward_pre_dm"):
            # NOTE(review): ``_forward_pre_dm`` presumably holds a forward that was attached to
            # the instance before dynamic conversion (e.g. a monkey-patch) - confirm.
            pre_fwd = getattr(self, "_forward_pre_dm")

            def _is_forward_in_mro(bound_or_func) -> bool:
                # If this is a bound method, compare its underlying function to any `forward`
                # implementation in the current MRO. If it matches, it's not an external monkey-patch.
                if hasattr(bound_or_func, "__func__"):
                    fn = bound_or_func.__func__
                    for cls in type(self).mro():
                        if cls.__dict__.get("forward") is fn:
                            return True
                return False

            if pre_fwd is getattr(self, "forward") or _is_forward_in_mro(pre_fwd):
                # Not an external patch: calling it would recurse, so use the class chain.
                output = super().forward(input, *args, **kwargs)
            else:
                output = pre_fwd(input, *args, **kwargs)
        else:
            output = super().forward(input, *args, **kwargs)

        if isinstance(output, tuple):
            return (self.output_quantizer(output[0]), *output[1:])
        return self.output_quantizer(output)

    def _setup(self):
        """Patch the module's forward method to quantize the input.

        Registers ``input_quantizer`` and ``output_quantizer``; the output quantizer starts
        disabled.
        """
        self._register_temp_attribute(
            "input_quantizer", TensorQuantizer(self.default_quant_desc_input)
        )
        self._register_temp_attribute(
            "output_quantizer", TensorQuantizer(self.default_quant_desc_output)
        )
        self.output_quantizer.disable()


class QuantLinearConvBase(QuantInputBase):
    """Base class for quantized linear modules.

    Quantized linear modules are modules where both the input and the weight are quantized.
    """

    weight_quantizer: TensorQuantizer | SequentialQuantizer
    _enable_weight_quantization: bool
    default_quant_desc_weight = QUANT_DESC_8BIT_PER_TENSOR

    @contextlib.contextmanager
    def quantize_weight(self):
        """Context in which `self.weight` is quantized."""
        self._enable_weight_quantization = True
        try:
            yield
        finally:
            self._enable_weight_quantization = False

    @staticmethod
    def _get_quantized_weight(module: "QuantLinearConvBase", weight: torch.Tensor) -> torch.Tensor:
        """Return ``weight`` passed through ``weight_quantizer`` when quantization is active.

        Quantization is active inside ``quantize_weight()`` or in torch-export mode;
        otherwise the raw weight is returned.
        """
        if module._enable_weight_quantization or is_torch_export_mode():
            return module.weight_quantizer(weight)
        return weight

    def forward(self, input, *args, **kwargs):
        """Quantize the input and the weight before calling the original forward method."""
        # self.quantize_weight() setting attributes is not allowed for torch.export.
        if is_torch_export_mode():
            return super().forward(input, *args, **kwargs)
        with self.quantize_weight():
            return super().forward(input, *args, **kwargs)

    def _setup(self):
        """Register the weight quantizer and make ``weight`` a dynamically quantized attribute."""
        super()._setup()
        self._register_temp_attribute(
            "weight_quantizer", TensorQuantizer(self.default_quant_desc_weight)
        )
        self._register_temp_attribute("_enable_weight_quantization", False)
        self._register_dynamic_attribute("weight", self._get_quantized_weight)


class _LegacyQuantInputBaseMixin:
    """A mixin to support legacy quantized modules which needs to have an __init__ method."""

    _quantized_cls = QuantInputBase
    default_quant_desc_input = QUANT_DESC_8BIT_PER_TENSOR
    default_quant_desc_output = QUANT_DESC_8BIT_PER_TENSOR

    def __init__(self, *args, quant_desc_input=None, **kwargs):
        """Initialize the module with its original __init__ and patch its forward.

        Args:
            quant_desc_input: Optional override for ``default_quant_desc_input``; a falsy value
                keeps the class default.
        """
        self.default_quant_desc_input = quant_desc_input or self.default_quant_desc_input
        super().__init__(*args, **kwargs)
        QuantModuleRegistry.convert(self)


class _LegacyQuantLinearConvBaseMixin(_LegacyQuantInputBaseMixin):
    """A mixin to support legacy quantized modules which needs to have an __init__ method."""

    _quantized_cls = QuantLinearConvBase
    default_quant_desc_weight = QUANT_DESC_8BIT_PER_TENSOR

    def __init__(self, *args, quant_desc_input=None, quant_desc_weight=None, **kwargs):
        """Initialize the module with its original __init__ and patch its forward.

        Args:
            quant_desc_input: Optional override for ``default_quant_desc_input``.
            quant_desc_weight: Optional override for ``default_quant_desc_weight``; a falsy value
                keeps the class default.
        """
        self.default_quant_desc_weight = quant_desc_weight or self.default_quant_desc_weight
        super().__init__(*args, quant_desc_input=quant_desc_input, **kwargs)