[QeRL] Compose online quantization with quantized reloading (#38032)

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Author: Kyle Sayers
Date: 2026-03-27 16:22:33 -04:00 (committed by GitHub)
Parent: 7ba425e916
Commit: 648edcf729
10 changed files with 184 additions and 260 deletions


@@ -73,7 +73,9 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_fp8_supported,
normalize_e4m3fn_to_e4m3fnuz,
)
from vllm.model_executor.model_loader.weight_utils import initialize_single_dummy_weight
from vllm.model_executor.model_loader.reload.layerwise import (
initialize_online_processing,
)
from vllm.model_executor.parameter import (
BlockQuantScaleParameter,
ModelWeightParameter,
@@ -496,8 +498,8 @@ class Fp8LinearMethod(LinearMethodBase):
class Fp8OnlineLinearMethod(Fp8LinearMethod):
"""Online version of Fp8LinearMethod, loads the fp16/bf16 checkpoint
and quantized the weights during loading."""
"""Online version of Fp8LinearMethod which loads a full precision checkpoint
and quantizes weights during loading."""
uses_meta_device: bool = True
@@ -519,84 +521,25 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
layer.orig_dtype = params_dtype
layer.weight_block_size = None
# WEIGHT
def patched_weight_loader(param, loaded_weight, *args, **kwargs):
# track how many elements we have updated
if not hasattr(layer, "_loaded_numel"):
layer._loaded_numel = 0
# when the first `loaded_weight` is about to be
# loaded to `param`, materialize `param` just-in-time
weight = ModelWeightParameter(
data=torch.empty_like(layer.weight, device=layer._load_device),
input_dim=1,
output_dim=0,
weight_loader=patched_weight_loader,
)
_copy_missing_attrs(layer.weight, weight)
layer.register_parameter("weight", weight)
del layer._load_device
# refresh the reference to `param` to reflect just-in-time
# materialization
param = layer.weight
# load the current weight chunk
copy_numel_counter = CopyNumelCounter()
with copy_numel_counter:
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]
layer._loaded_numel += copy_numel_counter.copied_numel
# if we have loaded all of the elements, call
# process_weights_after_loading
target_loaded_numel = layer.weight.numel()
if layer._loaded_numel == target_loaded_numel:
self.process_weights_after_loading(layer)
# Prevent the usual `process_weights_after_loading` call from doing
# anything
layer._already_called_process_weights_after_loading = True
# Note that we keep `layer._loaded_numel` around just in case
# there is logic added to vllm in the future which calls a
# weight loader twice - we do not want to re-initialize in
# that case.
return res
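(Reviewer note: the wrapper removed above boils down to a counting pattern. A toy sketch with hypothetical argument names, not vLLM's API; the real code counted elements actually copied via CopyNumelCounter rather than shard sizes:)

import torch

def make_counting_loader(layer, loader, on_complete):
    # Toy sketch of the removed patched_weight_loader pattern.
    def counting_loader(param, loaded_weight: torch.Tensor, *args, **kwargs):
        result = loader(param, loaded_weight, *args, **kwargs)
        layer._loaded_numel = getattr(layer, "_loaded_numel", 0) + loaded_weight.numel()
        if layer._loaded_numel == layer.weight.numel():
            on_complete(layer)  # e.g. process_weights_after_loading
        return result
    return counting_loader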
weight = ModelWeightParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition,
# materialized just-in-time in `patched_weight_loader`
device="meta",
device="meta", # materialized and processed during loading
dtype=params_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=patched_weight_loader,
weight_loader=weight_loader,
)
# stash the correct device for `patched_weight_loader`
layer._load_device = torch.get_default_device()
layer.register_parameter("weight", weight)
initialize_online_processing(layer)
def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False):
return
# deferred initialization of randomly initialized weights for the
# `--load_format dummy` feature
if layer.weight.device == torch.device("meta"):
weight = ModelWeightParameter(
data=torch.empty_like(layer.weight, device=layer._load_device),
input_dim=1,
output_dim=0,
weight_loader=layer.weight.weight_loader,
)
_copy_missing_attrs(layer.weight, weight)
layer.register_parameter("weight", weight)
initialize_single_dummy_weight(layer.weight)
# TODO(future): support block_quant in online quant path
assert not self.block_quant
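(Reviewer note: the processing step here is per-tensor FP8 quantization of the freshly loaded full-precision weight. A minimal stand-alone sketch; vLLM uses fused ops such as scaled_fp8_quant instead:)

import torch

def fp8_quantize_per_tensor(weight: torch.Tensor):
    # scale so the largest magnitude maps to the FP8 E4M3 max
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = weight.abs().amax().clamp(min=1e-12).float() / finfo.max
    quantized = (weight.float() / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return quantized, scale

quantized, scale = fp8_quantize_per_tensor(torch.randn(128, 256, dtype=torch.bfloat16))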
@@ -845,9 +788,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False):
return
# Allow for accessing weights and scales in standard way.
w13 = layer.w13_weight
w2 = layer.w2_weight
@@ -892,9 +832,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
)
# Prevent duplicate processing (e.g., during weight reload)
layer._already_called_process_weights_after_loading = True
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
@@ -1013,86 +950,12 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
layer.orig_dtype = params_dtype
layer.weight_block_size = None
# We are doing online quantization; patch the weight loader
# to call `process_weights_after_loading` in a streaming fashion
# as soon as the last weight chunk is loaded.
weight_loader = extra_weight_attrs["weight_loader"]
# create a new holder to avoid modifying the behavior of any other
# objects which might depend on the old one
new_extra_weight_attrs = dict(extra_weight_attrs)
def patched_weight_loader(param, loaded_weight, *args, **kwargs):
# add a counter to track how many elements we have updated
if not hasattr(layer, "_loaded_numel"):
layer._loaded_numel = 0
# save the ids of the original w13 and w2 so that we can
# distinguish which one `param` should map to later in this
# loader
layer._w13_weight_orig_id = id(layer.w13_weight)
layer._w2_weight_orig_id = id(layer.w2_weight)
# when the first `loaded_weight` is about to be
# loaded to `param`, materialize `param` just-in-time
w13_weight = torch.nn.Parameter(
torch.empty_like(layer.w13_weight, device=layer._load_device),
requires_grad=False,
)
set_weight_attrs(w13_weight, extra_weight_attrs)
_copy_missing_attrs(layer.w13_weight, w13_weight)
layer.register_parameter("w13_weight", w13_weight)
w2_weight = torch.nn.Parameter(
torch.empty_like(layer.w2_weight, device=layer._load_device),
requires_grad=False,
)
set_weight_attrs(w2_weight, extra_weight_attrs)
_copy_missing_attrs(layer.w2_weight, w2_weight)
layer.register_parameter("w2_weight", w2_weight)
del layer._load_device
# refresh the reference to `param` to reflect just-in-time
# materialization
if id(param) == layer._w13_weight_orig_id:
param = layer.w13_weight
elif id(param) == layer._w2_weight_orig_id:
param = layer.w2_weight
# load the current weight chunk
copy_numel_counter = CopyNumelCounter()
with copy_numel_counter:
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]
layer._loaded_numel += copy_numel_counter.copied_numel
# if we have loaded all of the elements, call
# process_weights_after_loading
target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
if layer._loaded_numel == target_loaded_numel:
self.process_weights_after_loading(layer)
# Prevent the usual `process_weights_after_loading` call
# from doing anything
layer._already_called_process_weights_after_loading = True
# Note that we keep `layer._loaded_numel`,
# `layer._w13_weight_orig_id` and `layer._w2_weight_orig_id`
# around because if EP is on, weight loaders for non-local
# experts will run but not actually copy any elements, and we
# need to not re-initialize in that case.
return res
new_extra_weight_attrs["weight_loader"] = patched_weight_loader
extra_weight_attrs = new_extra_weight_attrs
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
# materialized just-in-time in `patched_weight_loader`
device="meta",
dtype=params_dtype,
),
@@ -1106,91 +969,53 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
num_experts,
hidden_size,
intermediate_size_per_partition,
# materialized just-in-time in `patched_weight_loader`
device="meta",
device="meta", # materialized and processed during loading
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# stash the correct device for `patched_weight_loader`
layer._load_device = torch.get_default_device()
# BIASES (for models like GPT-OSS that have biased MoE)
if self.moe.has_bias:
# Use the original weight_loader (not patched) for biases
orig_extra_weight_attrs = dict(extra_weight_attrs)
orig_extra_weight_attrs["weight_loader"] = weight_loader
w13_bias = torch.nn.Parameter(
torch.zeros(
num_experts,
2 * intermediate_size_per_partition,
device="meta", # materialized and processed during loading
dtype=layer.orig_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_bias", w13_bias)
set_weight_attrs(w13_bias, orig_extra_weight_attrs)
set_weight_attrs(w13_bias, extra_weight_attrs)
w2_bias = torch.nn.Parameter(
torch.zeros(num_experts, hidden_size, dtype=layer.orig_dtype),
torch.zeros(
num_experts,
hidden_size,
device="meta", # materialized and processed during loading
dtype=layer.orig_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_bias", w2_bias)
set_weight_attrs(w2_bias, orig_extra_weight_attrs)
set_weight_attrs(w2_bias, extra_weight_attrs)
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
layer.w13_input_scale = None
layer.w2_input_scale = None
initialize_online_processing(layer)
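(Reviewer note: creating weights on the meta device means no memory is allocated at construction time; materialization later swaps in a real tensor of the same shape and dtype. A toy illustration with hypothetical sizes:)

import torch
from torch import nn

layer = nn.Module()
w13 = nn.Parameter(
    torch.empty(8, 2 * 64, 32, device="meta", dtype=torch.bfloat16),
    requires_grad=False,
)
layer.register_parameter("w13_weight", w13)  # nothing allocated yet
layer.w13_weight = nn.Parameter(  # materialize once the first shard arrives
    torch.empty_like(layer.w13_weight, device="cpu"), requires_grad=False
)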
def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False):
return
# deferred initialization of randomly initialized weights for the
# `--load_format dummy` feature
if layer.w13_weight.device == torch.device("meta"):
w13_weight = torch.nn.Parameter(
torch.empty_like(layer.w13_weight, device=layer._load_device),
requires_grad=False,
)
set_weight_attrs(
w13_weight, {"weight_loader": layer.w13_weight.weight_loader}
)
_copy_missing_attrs(layer.w13_weight, w13_weight)
layer.register_parameter("w13_weight", w13_weight)
initialize_single_dummy_weight(layer.w13_weight)
if layer.w2_weight.device == torch.device("meta"):
w2_weight = torch.nn.Parameter(
torch.empty_like(layer.w2_weight, device=layer._load_device),
requires_grad=False,
)
set_weight_attrs(
w2_weight, {"weight_loader": layer.w2_weight.weight_loader}
)
_copy_missing_attrs(layer.w2_weight, w2_weight)
layer.register_parameter("w2_weight", w2_weight)
initialize_single_dummy_weight(layer.w2_weight)
# If checkpoint is fp16, quantize in place.
fp8_dtype = current_platform.fp8_dtype()
w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
w13_scale = layer.w13_weight_scale
w2_scale = layer.w2_weight_scale
w13_scale = torch.ones(layer.num_experts, dtype=torch.float32)
w2_scale = torch.ones(layer.num_experts, dtype=torch.float32)
layer.w13_input_scale = None
layer.w2_input_scale = None
for expert in range(layer.local_num_experts):
w13[expert, :, :], w13_scale[expert] = ops.scaled_fp8_quant(
@@ -1207,8 +1032,8 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
w2,
w13_scale,
w2_scale,
layer.w13_input_scale,
layer.w2_input_scale,
w13_input_scale=layer.w13_input_scale,
w2_input_scale=layer.w2_input_scale,
)
# Prevent duplicate processing (e.g., during weight reload)
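(Reviewer note: the loop above produces one FP8 scale per local expert. A plain-torch sketch of that layout, using a hypothetical helper; the real code calls ops.scaled_fp8_quant expert by expert:)

import torch

def quantize_per_expert(w: torch.Tensor):
    # w: [num_experts, rows, cols] -> fp8 weights plus one scale per expert
    finfo = torch.finfo(torch.float8_e4m3fn)
    quantized = torch.empty_like(w, dtype=torch.float8_e4m3fn)
    scales = torch.ones(w.shape[0], dtype=torch.float32)
    for expert in range(w.shape[0]):
        scale = w[expert].abs().amax().clamp(min=1e-12).float() / finfo.max
        quantized[expert] = (
            (w[expert].float() / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
        )
        scales[expert] = scale
    return quantized, scales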


@@ -337,6 +337,8 @@ class Mxfp8OnlineMoEMethod(Fp8OnlineMoEMethod):
w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
w13_scale = layer.w13_weight_scale
w2_scale = layer.w2_weight_scale
layer.w13_input_scale = None
layer.w2_input_scale = None
w13, w13_scale = self._quantize_mxfp8_moe_weight(layer.w13_weight)
w2, w2_scale = self._quantize_mxfp8_moe_weight(layer.w2_weight)


@@ -9,6 +9,7 @@ import vllm.envs as envs
from vllm.config import ModelConfig, VllmConfig
from vllm.config.load import LoadConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.reload import finalize_layerwise_processing
from vllm.model_executor.model_loader.utils import (
initialize_model,
process_weights_after_loading,
@@ -49,16 +50,13 @@ class BaseModelLoader(ABC):
device_config.device if load_config.device is None else load_config.device
)
target_device = torch.device(load_device)
with set_default_torch_dtype(model_config.dtype):
with target_device:
model = initialize_model(
vllm_config=vllm_config, model_config=model_config, prefix=prefix
)
with set_default_torch_dtype(model_config.dtype), target_device:
model = initialize_model(
vllm_config=vllm_config, model_config=model_config, prefix=prefix
)
log_model_inspection(model)
logger.debug("Loading weights on %s ...", load_device)
# Quantization does not happen in `load_weights` but after it
self.load_weights(model, model_config)
# Log peak GPU memory after loading weights. This is needed
@@ -71,6 +69,11 @@ class BaseModelLoader(ABC):
scope="local",
)
# Process weights into kernel format. Note that when using online
# quantization, weights are (typically) quantized as they are loaded.
if _has_online_quant(model):
finalize_layerwise_processing(model, model_config)
process_weights_after_loading(model, model_config, target_device)
return model.eval()
@@ -84,3 +87,12 @@ def log_model_inspection(model: nn.Module) -> None:
from vllm.model_inspection import format_model_inspection
logger.info("vLLM model structure:\n%s", format_model_inspection(model))
def _has_online_quant(model: nn.Module):
for module in model.modules():
quant_method = getattr(module, "quant_method", None)
if getattr(quant_method, "uses_meta_device", False):
return True
return False
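(Reviewer note: _has_online_quant duck-types on the uses_meta_device marker set by the online quant methods. A self-contained toy showing the detection; names are illustrative:)

import torch.nn as nn

class ToyOnlineQuantMethod:
    uses_meta_device = True  # the marker _has_online_quant looks for

class ToyLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant_method = ToyOnlineQuantMethod()

model = nn.Sequential(ToyLinear(), nn.ReLU())
assert any(
    getattr(getattr(m, "quant_method", None), "uses_meta_device", False)
    for m in model.modules()
)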


@@ -1,10 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn as nn
from vllm.config import ModelConfig
from vllm.config.load import LoadConfig
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.reload.meta import materialize_meta_tensor
from vllm.model_executor.model_loader.reload.utils import get_layer_tensors
from vllm.model_executor.model_loader.weight_utils import initialize_dummy_weights
@@ -23,6 +26,12 @@ class DummyModelLoader(BaseModelLoader):
pass # Nothing to download
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
# materialize meta tensors as part of online quantization lifecycle
for layer in model.modules():
for name, param in get_layer_tensors(layer).items():
if param.device == torch.device("meta"):
setattr(layer, name, materialize_meta_tensor(param))
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights(model, model_config)


@@ -21,12 +21,14 @@ Limitations:
__all__ = [
"record_metadata_for_reloading",
"initialize_layerwise_reload",
"finalize_layerwise_processing",
"finalize_layerwise_reload",
"set_torchao_reload_attrs",
"support_quantized_model_reload_from_hp_weights",
]
from .layerwise import (
finalize_layerwise_processing,
finalize_layerwise_reload,
initialize_layerwise_reload,
record_metadata_for_reloading,


@@ -28,6 +28,7 @@ __all__ = [
"get_layerwise_info",
"record_metadata_for_reloading",
"initialize_layerwise_reload",
"finalize_layerwise_processing",
"finalize_layerwise_reload",
]
@@ -89,7 +90,7 @@ def initialize_layerwise_reload(model: torch.nn.Module):
info = get_layerwise_info(layer)
# Skip if the layer has already been initialized
if info.can_process():
if info.can_load():
continue
# Save current tensors for later copying
@@ -98,15 +99,21 @@ def initialize_layerwise_reload(model: torch.nn.Module):
# Restore layer parameters/buffers onto meta device
restore_layer_on_meta(layer, info)
# Track loading progress to determine when to process/copy
info.load_numel = 0
info.load_numel_total = get_layer_size(layer)
initialize_online_processing(layer)
# Wrap each parameter's weight loader
# Note that nested wrapping will occur for shared tensors
for name, tensor in get_layer_tensors(layer).items():
if _get_weight_loader(tensor).__name__ != "online_process_loader":
tensor.weight_loader = make_online_process_loader(layer, name)
def initialize_online_processing(layer: torch.nn.Module):
info = get_layerwise_info(layer)
# Track loading progress to determine when to process/copy
info.load_numel = 0
info.load_numel_total = get_layer_size(layer)
# Wrap each parameter's weight loader
# Note that nested wrapping will occur for shared tensors
for name, tensor in get_layer_tensors(layer).items():
if _get_weight_loader(tensor).__name__ != "online_process_loader":
tensor.weight_loader = make_online_process_loader(layer, name)
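(Reviewer note: the __name__ check makes wrapping idempotent, which matters because shared tensors are visited more than once. A condensed sketch of that trick:)

from functools import wraps

def wrap_once(loader):
    if getattr(loader, "__name__", "") == "online_process_loader":
        return loader  # already wrapped; avoid nesting on shared tensors
    # omit __name__ from `assigned` so the wrapper keeps its sentinel name
    @wraps(loader, assigned=("__doc__", "__annotations__"))
    def online_process_loader(*args, **kwargs):
        return loader(*args, **kwargs)
    return online_process_loader

wrapped = wrap_once(lambda param, w: None)
assert wrap_once(wrapped) is wrapped  # the second wrap is a no-op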
def make_online_process_loader(layer: torch.nn.Module, param_name: str) -> Callable:
@@ -118,7 +125,7 @@ def make_online_process_loader(layer: torch.nn.Module, param_name: str) -> Callable:
@wraps(original_loader, assigned=("__doc__", "__annotations__"))
def online_process_loader(*args, **kwargs):
if not info.can_process():
if not info.can_load():
# Unfortunately, some qconfigs are set up to load the same weight
# multiple times. For example, CT_WNA16 loads `weight_shape` for
# each of the qkv partitions. This results in layers loading extra
@@ -140,7 +147,7 @@ def make_online_process_loader(layer: torch.nn.Module, param_name: str) -> Callable:
bound_args = loader_signature.bind(*args, **kwargs)
bound_args.apply_defaults()
# Cache loaded weights, track loading progress
# Buffer loaded weights, track loading progress
info.loaded_weights.append((param_name, bound_args))
num_loaded, ret = get_numel_loaded(original_loader, bound_args)
info.load_numel += num_loaded
@@ -163,19 +170,26 @@ def make_online_process_loader(layer: torch.nn.Module, param_name: str) -> Calla
return online_process_loader
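(Reviewer note: buffering works by binding the loader's arguments with inspect.signature so they can be replayed verbatim after materialization. A minimal sketch:)

import inspect

def make_buffering_loader(original_loader, buffered):
    signature = inspect.signature(original_loader)
    def loader(*args, **kwargs):
        bound = signature.bind(*args, **kwargs)
        bound.apply_defaults()
        buffered.append(bound)  # replay later: original_loader(*bound.args, **bound.kwargs)
    return loader

buffered = []
loader = make_buffering_loader(lambda param, loaded_weight: None, buffered)
loader("param", loaded_weight=1)
assert buffered[0].arguments == {"param": "param", "loaded_weight": 1}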
def finalize_layerwise_reload(model: torch.nn.Module, model_config: ModelConfig):
def finalize_layerwise_processing(model: torch.nn.Module, model_config: ModelConfig):
"""
Remove the outermost layer of weight loading wrappers.
Apply processing to any layers which were not processed layerwise during loading.
This includes attention layers and layers whose weight elements are not all
loaded (due to padding).
This function should be applied after `initialize_layerwise_reload` in order to
unwrap the layerwise weight loaders.
Also processes Attention/MLA layers, which must be processed after all other layers.
:param model: model to finalize processing for
:param model_config: config needed for applying processing to attention layers
"""
model._do_torchao_reload = model._original_do_torchao_reload
if hasattr(model, "_original_do_torchao_reload"):
model._do_torchao_reload = model._original_do_torchao_reload
for layer in model.modules():
info = get_layerwise_info(layer)
if not info.can_load():
info.reset()
continue
# Attention/MLA layers are processed after all other layers
if isinstance(layer, (Attention, MLAAttention)):
@@ -184,17 +198,29 @@ def finalize_layerwise_reload(model: torch.nn.Module, model_config: ModelConfig)
"Layerwise reloading of Q/K/V scale weights is not implemented yet"
)
elif info.kernel_tensors is None:
raise NotImplementedError(
"Layerwise loading of Q/K/V scale weights is not implemented yet"
)
else:
_place_kernel_tensors(layer, info)
layer.process_weights_after_loading(model_config.dtype)
# No weights were loaded, place kernel tensors back
elif info.can_process() and info.load_numel <= 0:
_place_kernel_tensors(layer, info)
# No weights were loaded
elif info.load_numel <= 0:
# first load, but no weights were received; this happens on dummy load
if info.kernel_tensors is None:
materialize_layer(layer)
# reloading: place kernel tensors back as a fallback
else:
logger.warning("%s: Failed to load weights", layer.__class__.__name__)
_place_kernel_tensors(layer, info)
# Process non-attention layers which did not load all elements. This can happen
# if the created weight has extra padding elements which are not loaded
# Having too many of these delayed layers can lead to execess memory usage
# Having too many of these delayed layers can lead to excess memory usage
# see Limitations(4)
elif info.load_numel > 0 and info.load_numel < info.load_numel_total: # type: ignore[operator]
logger.debug("%s: Delayed processing", layer.__class__.__name__)
@@ -203,20 +229,24 @@ def finalize_layerwise_reload(model: torch.nn.Module, model_config: ModelConfig)
info.reset()
def finalize_layerwise_reload(*args, **kwargs):
finalize_layerwise_processing(*args, **kwargs)
def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
"""
Finalize layer loading after all weights have been cached.
Finalize layer loading after all weights have been buffered.
This function:
1. Materializes the layer onto the target device
2. Loads all cached weights
2. Loads all buffered weights
3. Runs quantization processing if applicable
4. Copies processed values back to original tensor storage
"""
# Materialize layer tensors onto device
materialize_layer(layer)
# Reset FP8 online quantization flag so process_weights_after_loading
# Reset online quantization flag so process_weights_after_loading
# will run again during reload
if hasattr(layer, "_already_called_process_weights_after_loading"):
delattr(layer, "_already_called_process_weights_after_loading")
@@ -225,7 +255,7 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
for param in get_layer_tensors(layer).values():
param.weight_loader = _get_original_loader(param)
# Load all cached weights into materialized layer (using original loaders)
# Load all buffered weights into materialized layer (using original loaders)
for name, args in info.loaded_weights:
param = getattr(layer, name)
args.arguments["param"] = param
@@ -239,13 +269,14 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
# Copy processed values into original tensor storage (preserves cudagraph refs)
# this code is a no-op if not reloading (because kernel tensors is empty)
parameters, buffers = info.kernel_tensors
for name, param in parameters.items():
param.data.copy_(getattr(layer, name))
for name, buffer in buffers.items():
buffer.data.copy_(getattr(layer, name))
if info.kernel_tensors is not None:
parameters, buffers = info.kernel_tensors
for name, param in parameters.items():
param.data.copy_(getattr(layer, name))
for name, buffer in buffers.items():
buffer.data.copy_(getattr(layer, name))
_place_kernel_tensors(layer, info)
_place_kernel_tensors(layer, info)
info.reset()
logger.debug("%s: Processed", layer.__class__.__name__)
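(Reviewer note: the copy-back uses in-place copy_ so tensor storage captured elsewhere, e.g. by CUDA graphs, stays valid; only the values change. A small illustration:)

import torch

kernel_weight = torch.zeros(4, 4)    # storage already referenced elsewhere
captured_ref = kernel_weight         # stands in for a cudagraph reference
processed = torch.randn(4, 4)
kernel_weight.data.copy_(processed)  # in-place: the storage pointer is kept
assert torch.equal(captured_ref, processed)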
@@ -268,6 +299,7 @@ def _place_kernel_tensors(layer: torch.nn.Module, info: LayerReloadingInfo):
for name in get_layer_tensors(layer):
delattr(layer, name)
assert info.kernel_tensors is not None
parameters, buffers = info.kernel_tensors
for name, param in parameters.items():
layer.register_parameter(name, param)


@@ -104,7 +104,7 @@ def materialize_layer(layer: torch.nn.Module) -> None:
setattr(layer, name, materialize_meta_tensor(tensor))
class MetaCopyCounter(TorchDispatchMode):
class CopyCounter(TorchDispatchMode):
"""
Tracks total number of elements modified with `copy_`.
@@ -122,7 +122,7 @@ class MetaCopyCounter(TorchDispatchMode):
if kwargs is None:
kwargs = {}
if func is torch.ops.aten.copy_.default and args[0].device.type == "meta":
if func is torch.ops.aten.copy_.default:
assert args[0].numel() == args[1].numel()
self.copied_numel += args[0].numel()
@@ -140,7 +140,6 @@ def get_numel_loaded(
:return: number of elements loaded by the weight loader, the return value of the
weight loader
"""
assert args.arguments["param"].device.type == "meta"
with MetaCopyCounter() as counter:
with CopyCounter() as counter:
return_value = weight_loader(*args.args, **args.kwargs)
return counter.copied_numel, return_value
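(Reviewer note: a runnable usage sketch of the renamed CopyCounter, condensed from the class above; it is a dispatch mode intercepting aten.copy_:)

import torch
from torch.utils._python_dispatch import TorchDispatchMode

class CopyCounter(TorchDispatchMode):
    def __init__(self):
        super().__init__()
        self.copied_numel = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if func is torch.ops.aten.copy_.default:
            self.copied_numel += args[0].numel()
        return func(*args, **(kwargs or {}))

dst, src = torch.empty(3, 4), torch.ones(3, 4)
with CopyCounter() as counter:
    dst.copy_(src)
assert counter.copied_numel == 12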


@@ -16,8 +16,8 @@ class LayerReloadingInfo:
# model format (meta), populated by `record_metadata_for_reloading`
restore_metadata: LayerTensors = field(default_factory=lambda: ({}, {}))
# kernel format (device)
kernel_tensors: LayerTensors = field(default_factory=lambda: ({}, {}))
# kernel format (device), used to copy into when reloading only
kernel_tensors: LayerTensors | None = None
# track how many restored elements are ready for loading
load_numel: int = 0
@@ -29,5 +29,5 @@ class LayerReloadingInfo:
def reset(self):
self.__init__(restore_metadata=self.restore_metadata) # type: ignore[misc]
def can_process(self) -> bool:
def can_load(self) -> bool:
return self.load_numel_total is not None
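(Reviewer note: the lifecycle of these fields can be condensed into a small sketch; types are simplified, as the real LayerTensors is a pair of name-to-tensor dicts:)

from dataclasses import dataclass, field

@dataclass
class Info:  # condensed sketch of LayerReloadingInfo
    restore_metadata: dict = field(default_factory=dict)
    kernel_tensors: tuple | None = None   # populated only when reloading
    load_numel: int = 0                   # elements copied so far
    load_numel_total: int | None = None   # set by initialize_online_processing

    def can_load(self) -> bool:
        return self.load_numel_total is not None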


@@ -1323,25 +1323,11 @@ def initialize_dummy_weights(
is fixed, the random values generated by this function depend only on
the parameter's number of elements and its data type.
"""
# Check if any module uses online quantization with meta device weights.
# If so, we'll skip initializing params on meta device since they'll be
# handled in `process_weights_after_loading`.
def uses_meta_device(module: torch.nn.Module) -> bool:
quant_method = getattr(module, "quant_method", None)
return getattr(quant_method, "uses_meta_device", False)
has_online_quant = any(uses_meta_device(m) for m in model.modules())
for param in model.state_dict().values():
if has_online_quant and param.device == torch.device("meta"):
# For online quantization, weights are created on meta device and
# dummy weight init will happen in `process_weights_after_loading`.
continue
initialize_single_dummy_weight(param, low, high, seed)
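(Reviewer note: a condensed, runnable sketch of per-tensor dummy initialization under the assumptions above; re-seeding per call keeps the values independent of iteration order:)

import torch

def dummy_init_(param: torch.Tensor, low=-1e-3, high=1e-3, seed=1234) -> None:
    # floating-point tensors only; with a fixed seed the result depends
    # solely on the tensor's own numel and dtype
    generator = torch.Generator(device=param.device).manual_seed(seed)
    param.uniform_(low, high, generator=generator)

dummy_init_(torch.empty(4, 4))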
@torch.no_grad()
def initialize_single_dummy_weight(
param: torch.Tensor,
low: float = -1e-3,