[Quantization] - Added uses_meta_device_weights to quant config (#34645)

Signed-off-by: Josephasafg <ajgard7@gmail.com>
This commit is contained in:
Asaf Joseph Gardin
2026-02-18 09:43:44 +02:00
committed by GitHub
parent e89a91d927
commit 1faa8cb73c
3 changed files with 21 additions and 8 deletions

View File

@@ -18,6 +18,11 @@ else:
class QuantizeMethodBase(ABC):
"""Base class for different quantized methods."""
# Whether this method creates weights on meta device for online quantization.
# When True, weights are created on meta device and quantized layer-wise
# in process_weights_after_loading, reducing peak memory during loading.
uses_meta_device: bool = False
@abstractmethod
def create_weights(
self, layer: torch.nn.Module, *weight_args, **extra_weight_attrs

View File

@@ -527,6 +527,8 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
"""Online version of Fp8LinearMethod, loads the fp16/bf16 checkpoint
and quantized the weights during loading."""
uses_meta_device: bool = True
def create_weights(
self,
layer: torch.nn.Module,
@@ -1039,6 +1041,8 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
quant_config: The quantization config.
"""
uses_meta_device: bool = True
def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
super().__init__(quant_config, layer)
assert not quant_config.is_checkpoint_fp8_serialized

View File

@@ -1092,16 +1092,20 @@ def initialize_dummy_weights(
is fixed, the random values generated by this function only depend on
the parameter's number of elements and its data type.
"""
# TODO(future PR): make the check below more generic as more online
# quant backends are added
is_fp8_py_quant = model_config.quantization == "fp8"
# Check if any module uses online quantization with meta device weights.
# If so, we'll skip initializing params on meta device since they'll be
# handled in `process_weights_after_loading`.
def uses_meta_device(module: torch.nn.Module) -> bool:
    """Return True if *module*'s quant method creates its weights on meta device."""
    method = getattr(module, "quant_method", None)
    if method is None:
        # No quant method attached: definitely no meta-device weights.
        return False
    return getattr(method, "uses_meta_device", False)
has_online_quant = any(uses_meta_device(m) for m in model.modules())
for param in model.state_dict().values():
if is_fp8_py_quant and param.device == torch.device("meta"):
# for fp8.py's online quantization, dummy weight init will happen
# in `process_weights_after_loading`.
# TODO(future PR): consider refactoring dummy model init to compose
# better with online quantization
if has_online_quant and param.device == torch.device("meta"):
# For online quantization, weights are created on meta device and
# dummy weight init will happen in `process_weights_after_loading`.
continue
initialize_single_dummy_weight(param, low, high, seed)