[Quantization] - Added uses_meta_device_weights to quant config (#34645)
Signed-off-by: Josephasafg <ajgard7@gmail.com>
commit 1faa8cb73c
parent e89a91d927
@@ -18,6 +18,11 @@ else:
 class QuantizeMethodBase(ABC):
     """Base class for different quantized methods."""

+    # Whether this method creates weights on meta device for online quantization.
+    # When True, weights are created on meta device and quantized layer-wise
+    # in process_weights_after_loading, reducing peak memory during loading.
+    uses_meta_device: bool = False
+
     @abstractmethod
     def create_weights(
         self, layer: torch.nn.Module, *weight_args, **extra_weight_attrs
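For context on the new flag, here is a minimal sketch (not part of this diff) of how a quantize method with uses_meta_device = True could behave: create_weights allocates placeholders on the meta device, and process_weights_after_loading materializes and quantizes them one layer at a time, which is what keeps peak memory low. The class name and method signatures below are illustrative assumptions, not vLLM APIs, and the float8 cast assumes a recent PyTorch with float8 dtypes.

import torch


class ToyMetaDeviceLinearMethod:
    # Mirrors the flag introduced above (illustrative class, not vLLM's).
    uses_meta_device: bool = True

    def create_weights(self, layer: torch.nn.Module,
                       out_features: int, in_features: int) -> None:
        # Placeholder on the meta device: shape/dtype only, no memory allocated.
        layer.weight = torch.nn.Parameter(
            torch.empty(out_features, in_features,
                        dtype=torch.bfloat16, device="meta"),
            requires_grad=False,
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Materialize and quantize layer-wise; only one layer's high-precision
        # weight lives in memory at a time. (Random data stands in for the
        # loaded checkpoint shard here.)
        hp_weight = torch.randn(layer.weight.shape, dtype=torch.bfloat16)
        layer.weight = torch.nn.Parameter(
            hp_weight.to(torch.float8_e4m3fn), requires_grad=False
        )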
@@ -527,6 +527,8 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
     """Online version of Fp8LinearMethod that loads the fp16/bf16 checkpoint
     and quantizes the weights during loading."""

+    uses_meta_device: bool = True
+
     def create_weights(
         self,
         layer: torch.nn.Module,
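What "quantizes the weights during loading" can look like for FP8 is sketched below: a per-tensor scale is derived from the weight's absolute maximum so values fit the float8_e4m3fn range. This is a generic illustration of online FP8 quantization, not the exact routine Fp8OnlineLinearMethod uses, and the helper name is an assumption.

import torch


def quantize_fp8_per_tensor(weight: torch.Tensor):
    # Per-tensor FP8 quantization sketch: scale the weight so its abs-max
    # lands on the largest representable float8_e4m3fn value (448.0).
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = weight.abs().max().clamp(min=1e-12) / finfo.max
    qweight = (weight / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return qweight, scale


# Example: quantize a bf16 checkpoint tensor at load time.
w = torch.randn(4096, 4096, dtype=torch.bfloat16)
w_fp8, w_scale = quantize_fp8_per_tensor(w.float())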
@@ -1039,6 +1041,8 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
         quant_config: The quantization config.
     """

+    uses_meta_device: bool = True
+
     def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
         super().__init__(quant_config, layer)
         assert not quant_config.is_checkpoint_fp8_serialized
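The MoE variant follows the same pattern; the assert enforces that the checkpoint is fp16/bf16 rather than already FP8-serialized, since quantization happens online. Below is a hedged sketch of per-expert online quantization, using an assumed stacked-weight layout that may not match vLLM's actual expert weight shapes.

import torch

# Assumed layout: one stacked bf16 weight per expert (illustrative only).
num_experts, out_f, in_f = 8, 256, 512
w_experts = torch.randn(num_experts, out_f, in_f, dtype=torch.bfloat16)

finfo = torch.finfo(torch.float8_e4m3fn)
q_experts = torch.empty_like(w_experts, dtype=torch.float8_e4m3fn)
scales = torch.empty(num_experts)
for e in range(num_experts):
    # Each expert gets its own scale, computed from its own abs-max.
    s = w_experts[e].float().abs().max().clamp(min=1e-12) / finfo.max
    q_experts[e] = (w_experts[e].float() / s).clamp(
        finfo.min, finfo.max).to(torch.float8_e4m3fn)
    scales[e] = s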
@@ -1092,16 +1092,20 @@ def initialize_dummy_weights(
     is fixed, the random values generated by this function only depend on
     the parameter's number of elements and its data type.
     """
-    # TODO(future PR): make the check below more generic as more online
-    # quant backends are added
-    is_fp8_py_quant = model_config.quantization == "fp8"
+    # Check if any module uses online quantization with meta device weights.
+    # If so, we'll skip initializing params on meta device since they'll be
+    # handled in `process_weights_after_loading`.
+    def uses_meta_device(module: torch.nn.Module) -> bool:
+        quant_method = getattr(module, "quant_method", None)
+        return getattr(quant_method, "uses_meta_device", False)
+
+    has_online_quant = any(uses_meta_device(m) for m in model.modules())

     for param in model.state_dict().values():
-        if is_fp8_py_quant and param.device == torch.device("meta"):
-            # for fp8.py's online quantization, dummy weight init will happen
-            # in `process_weights_after_loading`.
-            # TODO(future PR): consider refactoring dummy model init to compose
-            # better with online quantization
+        if has_online_quant and param.device == torch.device("meta"):
+            # For online quantization, weights are created on meta device and
+            # dummy weight init will happen in `process_weights_after_loading`.
             continue

         initialize_single_dummy_weight(param, low, high, seed)
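To make the new control flow concrete, the sketch below builds a toy model in which one submodule carries a quant_method with uses_meta_device = True, and shows the duck-typed check skipping meta-device params during dummy init. The toy class and the quant_method attribute wiring are assumptions for illustration; only the uses_meta_device/has_online_quant logic mirrors the diff.

import torch

class _ToyOnlineQuantMethod:
    uses_meta_device = True  # same flag the diff checks for

linear = torch.nn.Linear(16, 16, device="meta")
linear.quant_method = _ToyOnlineQuantMethod()  # attribute assumed for the demo
model = torch.nn.Sequential(linear, torch.nn.LayerNorm(16))

def uses_meta_device(module: torch.nn.Module) -> bool:
    quant_method = getattr(module, "quant_method", None)
    return getattr(quant_method, "uses_meta_device", False)

has_online_quant = any(uses_meta_device(m) for m in model.modules())

for name, param in model.state_dict().items():
    if has_online_quant and param.device == torch.device("meta"):
        # Skipped: real (quantized) values appear later in
        # process_weights_after_loading.
        print(f"skip meta param: {name}")
        continue
    print(f"would dummy-init: {name}")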