[QeRL] Layerwise Reloading (#32133)

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Kyle Sayers
2026-01-30 10:50:05 -05:00
committed by GitHub
parent 74898a7015
commit f857a03f6b
17 changed files with 923 additions and 314 deletions

View File

@@ -1,275 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import types
from collections.abc import Iterable
import torch
from torch import nn
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.utils import process_weights_after_loading
logger = init_logger(__name__)
# Notes for Online Quantization
# In terms of checkpoint state, quantization config, and their
# correspondence to online quantization:
# | Use Case | Checkpoints | model_config.quantization |
# | no quant | high precision | None |
# | offline quant | quantized | fp8, torchao etc. |
# | online quant | high precision | torchao etc. |
#
# The process for loading a non-quantized checkpoint
# 1. load non-quantized weights (load_weights)
# 2. do any additional post processing (process_weights_after_loading)
#
# The process for loading an offline-quantized checkpoint
# 1. load offline-quantized weights (load_weights)
# 2. do any additional post processing (process_weights_after_loading)
# The process for unquantized model reloading
# (repeated run in RL training loop)
# first run
# UI1. load_weights: load bfloat16 weights
# UI2. process_weights_after_loading: any additional post processing
# subsequent run
# UC1: load_weights: load bfloat16 weights
# (shouldn't be any issues since we didn't change any attributes
# of the weights)
# UC2: process_weights_after_loading: any additional post processing
# The process for weight reloading with online quantization
# (repeated run in RL training loop)
# first run
# I1. load_weights: load bfloat16 weights
# I2. process_weights_after_loading:
# record weight metadata and attributes for R1 and R2
# quantize weights to fp8
# subsequent run
# (beginning model weight is in fp8)
# load_weights:
# R1. restore bfloat16 model weight metadata
# R2. restore the model weight attributes
# R3. reload bfloat16 weights
# R4. quantize weights (by calling process_weights_after_loading),
# also set `process_weights_after_loading_already_called` to
# True to stop it from running again
# R5. (workaround for cudagraph): we restore the weight params to the original
#   quantized weight params, and use original_weight_param.copy_(updated_weight_param)
#   so that the weight update works well with cudagraph
# process_weights_after_loading (if called):
# this will be skipped since it's already ran in
# load_weights
def maybe_save_metadata_and_attributes_for_weight_reloading(
model: nn.Module, model_config: ModelConfig
):
# The following supports on-the-fly quantization; currently this is only
# supported for torchao
if model_config.quantization != "torchao":
return
from vllm.model_executor.model_loader.weight_utils import get_quant_config
quant_config = get_quant_config(model_config, None)
# If the checkpoint is already torchao serialized, this is the
# pre-quantized (offline) case, so we skip saving the metadata.
# Otherwise, this is Step I2 of the initialization steps of
# online quantization.
# This step records the weight metadata and attributes so we can
# restore the bfloat16 model weights during the reload step (R1 and R2);
# see Notes in online_quantization.py for more details
if not (
hasattr(quant_config, "is_checkpoint_torchao_serialized")
and not quant_config.is_checkpoint_torchao_serialized
):
return
# This is the I2 step of online quantization that saves the
# metadata and attributes of weights so they can be used in the R1 and
# R2 steps; note that we only save these during initialization.
# This includes two things:
# 1. save floating point metadata (shape, dtype, device) for init
# 2. save weight attributes, e.g. `output_dim`, `weight_loader` for init
if getattr(model, "weight_metadata_and_attr_saved", False):
return
# save the dtype, shape and device of each model parameter, used to
# restore the model's high precision parameters before
# reloading the weights
assert not hasattr(model, "original_weights_rebuild_keys")
model.original_weights_rebuild_keys = {}
for name, p in model.named_parameters():
model.original_weights_rebuild_keys[name] = {
"shape": p.shape,
"dtype": p.dtype,
"device": p.device,
}
# record the weight attributes (loader functions etc.)
# so these can be recovered later when we reload the weights
# structure: {"weight_name": {"weight_attr_key": attr}}
assert not hasattr(model, "recorded_weight_attr")
model.recorded_weight_attr = {}
for name, param in model.named_parameters():
model.recorded_weight_attr[name] = {}
for key in param.__dict__:
if hasattr(param, key):
attr = getattr(param, key)
if not callable(attr):
model.recorded_weight_attr[name][key] = attr
elif hasattr(attr, "__self__") and param is attr.__self__:
# if attr is a method bound to the param instance
# (attr.__self__ points to param),
# we'll record the underlying function object
model.recorded_weight_attr[name][key] = attr.__func__
else:
model.recorded_weight_attr[name][key] = attr
# mark the metadata and attributes saved so we don't run it again
model._model_config = model_config
model.weight_metadata_and_attr_saved = True
def _bond_method_to_cls(func, obj):
if hasattr(func, "__self__") or not callable(func):
# If the function is already bound to an instance, return it as is
return func
else:
return types.MethodType(func, obj)
def support_quantized_model_reload_from_hp_weights(original_load_weights):
"""Decorator for `load_weights` method for AutoWeightsLoader.load_weights to support
reloading high precision (bfloat16/float16/float32) weight for an already quantized
model, this involves restoring the weights to a high precision weights and
then online quantize the weights
"""
# online quantization, right now only enabled for
# torchao
# R1, R2, R3, R4, R5 in the Notes
def patched_model_load_weights(
auto_weight_loader, weights: Iterable[tuple[str, torch.Tensor]], *, mapper=None
) -> set[str]:
model = auto_weight_loader.module
offline_quantization_or_first_run_of_online_quantization = not getattr(
model, "weight_metadata_and_attr_saved", False
)
# If `model.weight_metadata_and_attr_saved` is not defined and
# set to True, this is either the offline quantization case
# or the first run of online quantization;
# see Notes in this file for more details
if offline_quantization_or_first_run_of_online_quantization:
# case 1: offline quantized checkpoint
# case 2: Step I1 first run of weight loading with
# online quantization
return original_load_weights(auto_weight_loader, weights, mapper=mapper)
model_config = model._model_config
# TODO: Add fp8 support
assert model_config.quantization == "torchao", (
"online quantization is only enabled for torchao currently"
)
# TODO: use create_weights to restore the weights to original state
# Step R1: First restore the quantized weights to original bfloat16
# weights, with original metadata (shape, dtype, device)
# and attributes, so that bfloat16 weights can be loaded properly
# TODO: maybe set remove_duplicate to True?
original_quantized_weight_dict = dict(
model.named_parameters(remove_duplicate=False)
)
named_modules = dict(model.named_modules(remove_duplicate=False))
model_device = None
for name, d in model.original_weights_rebuild_keys.items():
_shape = d["shape"]
_dtype = d["dtype"]
_device = d["device"]
if model_device is not None:
assert model_device == _device, (
"Expecting all weights "
"to be in the same device for now, got both: "
f"{model_device} and {_device}"
)
else:
model_device = _device
if name in original_quantized_weight_dict:
module_name, weight_name = name.rsplit(".", 1)
module = named_modules[module_name]
setattr(
module,
weight_name,
torch.nn.Parameter(
torch.empty(_shape, dtype=_dtype, device=_device),
requires_grad=False,
),
)
# Step R2: restore the weight attributes to their state before the first load
# recorded_weight_attr is
# {"weight_name": {"weight_attr_key": attr}}
# e.g.
# {
#     "layer.0.weight": {
#         "weight_loader": weight_loader_function_object,
#         "input_dim": 0, ...
#     },
#     "layer.1.weight": ...,
# }
for full_weight_name, weight_attr_dict in model.recorded_weight_attr.items():
for attr_name, attr in weight_attr_dict.items():
module_name, weight_name = full_weight_name.rsplit(".", 1)
module = named_modules[module_name]
weight = getattr(module, weight_name)
if not hasattr(weight, attr_name):
setattr(weight, attr_name, _bond_method_to_cls(attr, weight))
# Step R3: reload bfloat16 / high precision weights
updated_params = original_load_weights(
auto_weight_loader, weights, mapper=mapper
)
# Step R4: online quantize the weights
# manually process weights after loading
model.process_weights_after_loading_already_called = False
if model_device is not None:
process_weights_after_loading(model, model_config, model_device)
else:
logger.warning_once(
"model_device is None, skip calling process_weights_after_loading"
)
# Step R5 (workaround for cudagraph): restore the original quantized weights
# and copy_ the current weights into the original weights
updated_quantized_weights = dict(model.named_parameters(remove_duplicate=False))
for name in model.original_weights_rebuild_keys:
if name in original_quantized_weight_dict:
original_quantized_weight = original_quantized_weight_dict[name]
updated_quantized_weight = updated_quantized_weights[name]
module_name, weight_name = name.rsplit(".", 1)
module = named_modules[module_name]
setattr(module, weight_name, original_quantized_weight)
with torch.no_grad():
original_quantized_weight.copy_(updated_quantized_weight)
del original_quantized_weight_dict
del named_modules
del updated_quantized_weights
model.process_weights_after_loading_already_called = True
return updated_params
return patched_model_load_weights
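
A minimal standalone sketch of the R5 swap-and-copy pattern described in the Notes above, using a toy nn.Linear as a stand-in for a quantized layer (all names here are illustrative, not part of vLLM):

import torch
from torch import nn

# Toy stand-in for a layer whose weight storage a cudagraph may have captured
layer = nn.Linear(4, 4, bias=False)
original_weight_param = layer.weight

# Suppose reloading + requantization produced a brand-new parameter object
updated_weight_param = nn.Parameter(torch.randn(4, 4), requires_grad=False)
layer.weight = updated_weight_param

# R5: reinstall the original parameter object, then copy the new values
# into its existing storage so captured graph references stay valid
layer.weight = original_weight_param
with torch.no_grad():
    original_weight_param.copy_(updated_weight_param)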

View File

@@ -0,0 +1,37 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Layerwise weight reloading utilities for vLLM.
This module provides functionality to reload model weights layer-by-layer,
which is useful for weight updates without full model reconstruction.
Limitations:
1. Composition with CPU offloading has not been implemented
2. Reloading Attention/MLA weights (q_scale, k_scale, v_scale) has not been implemented
3. Tied parameters will only reflect processing from one of the parent layers (for
example, only processing from embed_tokens will have an effect)
4. This design assumes that the number of weights loaded from disk is the same as the
number of weights created at model init time. This is not true for quant methods
which (1) pad weights or (2) load qkv weights into the same parameter. Both of these
cases are non-issues for today's quant methods, but future quant methods may
cause reloading to fail
"""
__all__ = [
"record_metadata_for_reloading",
"initialize_layerwise_reload",
"finalize_layerwise_reload",
"set_torchao_reload_attrs",
"support_quantized_model_reload_from_hp_weights",
]
from .layerwise import (
finalize_layerwise_reload,
initialize_layerwise_reload,
record_metadata_for_reloading,
)
from .torchao_decorator import (
set_torchao_reload_attrs,
support_quantized_model_reload_from_hp_weights,
)
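
A hedged usage sketch of this module's exports for one manual reload cycle; `model`, `model_config`, and `weights_iter` are assumed to be provided by the caller (the torchao decorator below automates the initialize/finalize pair):

from vllm.model_executor.model_loader.reload import (
    finalize_layerwise_reload,
    initialize_layerwise_reload,
    record_metadata_for_reloading,
)

record_metadata_for_reloading(model)   # once, right after model construction

initialize_layerwise_reload(model)     # before each reload
model.load_weights(weights_iter)       # layers process as soon as they fill
finalize_layerwise_reload(model, model_config)  # unwrap loaders, handle Attention/MLA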

View File

@@ -0,0 +1,270 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import inspect
from collections.abc import Callable
from functools import wraps
from weakref import WeakKeyDictionary
import torch
from vllm.attention.layer import Attention, MLAAttention
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from .meta import (
capture_layer_to_meta,
get_numel_loaded,
materialize_layer,
restore_layer_on_meta,
)
from .types import LayerReloadingInfo
from .utils import get_layer_params_buffers, get_layer_size, get_layer_tensors
logger = init_logger(__name__)
__all__ = [
"get_layerwise_info",
"record_metadata_for_reloading",
"initialize_layerwise_reload",
"finalize_layerwise_reload",
]
# Global dict storing information used for layerwise restoring, loading, and processing.
# For more information regarding what info is stored when, see `LayerReloadingInfo`
#
# Use a weak ref dictionary so that modules can be freed when the model is freed.
# Values are sanitized from references to the layer key in order to avoid circular refs
LAYERWISE_INFO: WeakKeyDictionary[torch.nn.Module, LayerReloadingInfo] = (
WeakKeyDictionary()
)
def get_layerwise_info(layer: torch.nn.Module) -> LayerReloadingInfo:
"""
Get information related to restoring and layerwise processing. If no previous
information exists, a new entry is constructed.
"""
if layer not in LAYERWISE_INFO:
LAYERWISE_INFO[layer] = LayerReloadingInfo()
return LAYERWISE_INFO[layer]
def record_metadata_for_reloading(model: torch.nn.Module):
"""
Record layer metadata needed for later reloading.
Stores parameter and buffer metadata as meta tensors for restoration.
Must be called before `initialize_layerwise_reload`.
"""
for layer in model.modules():
info = get_layerwise_info(layer)
info.restore_metadata = capture_layer_to_meta(layer)
@torch.no_grad()
def initialize_layerwise_reload(model: torch.nn.Module):
"""
Set up layerwise weight loading with deferred processing.
Must be called after `record_metadata_for_reloading`. This function:
1. Saves current kernel tensors for later copying
2. Restores layer parameters/buffers from metadata (on meta device)
3. Wraps weight loaders to defer processing until all weights are loaded
When all weights for a layer are loaded, the wrapped loaders will:
1. Materialize the layer onto the target device
2. Load all cached weights
3. Run quantization processing if applicable
4. Copy processed values back to original tensor storage
"""
# disable torchao reloading to avoid infinite recursion
model._original_do_torchao_reload = getattr(model, "_do_torchao_reload", False)
model._do_torchao_reload = False
for layer in model.modules():
info = get_layerwise_info(layer)
# Skip if the layer has already been initialized
if info.can_process():
continue
# Save current tensors for later copying
info.kernel_tensors = get_layer_params_buffers(layer)
# Restore layer parameters/buffers onto meta device
restore_layer_on_meta(layer, info)
# Track loading progress to determine when to process/copy
info.load_numel = 0
info.load_numel_total = get_layer_size(layer)
# Wrap each parameter's weight loader
# Note that nested wrapping will occur for shared tensors
for name, tensor in get_layer_tensors(layer).items():
if _get_weight_loader(tensor).__name__ != "online_process_loader":
tensor.weight_loader = make_online_process_loader(layer, name)
def make_online_process_loader(layer: torch.nn.Module, param_name: str) -> Callable:
"""Create a wrapped weight loader that defers processing."""
info = get_layerwise_info(layer)
param = getattr(layer, param_name)
original_loader = _get_original_loader(param)
loader_signature = inspect.signature(original_loader)
@wraps(original_loader, assigned=("__doc__", "__annotations__"))
def online_process_loader(*args, **kwargs):
if not info.can_process():
# Unfortunately, some qconfigs are set up to load the same weight
# multiple times. For example, CT_WNA16 loads `weight_shape` for
# each of the qkv partitions. This results in layers loading extra
# weights (beyond load_numel_total) after it's already processed.
#
# Best solution is to ensure that `load_numel_total` reflects the
# actual number of weights loaded, either by modifying qconfigs to
# create as many weights as loaded (see padding issue as well)
# or maybe capturing how many weights are loaded on first pass
#
# For now, `load_numel_total` is still safe to use as long as
# there's no way to reach `load_numel_total` without loading all
# necessary weights. `weight_shape` is very small, so this is safe.
# see Limitations(4)
logger.debug("%s: Excessive loading", layer.__class__.__name__)
return
# Bind and normalize arguments
bound_args = loader_signature.bind(*args, **kwargs)
bound_args.apply_defaults()
# Cache loaded weights, track loading progress
info.loaded_weights.append((param_name, bound_args))
num_loaded, ret = get_numel_loaded(original_loader, bound_args)
info.load_numel += num_loaded
logger.debug(
"%s: %d / %d",
layer.__class__.__name__,
info.load_numel,
info.load_numel_total,
)
# Process and copy when all weights are loaded
if info.load_numel >= info.load_numel_total and not isinstance( # type: ignore[operator]
layer, (Attention, MLAAttention)
):
_layerwise_process(layer, info)
return ret
return online_process_loader
def finalize_layerwise_reload(model: torch.nn.Module, model_config: ModelConfig):
"""
Remove the outermost layer of weight loading wrappers.
This function should be called after `initialize_layerwise_reload` to
unwrap the layerwise weight loaders.
It also processes Attention/MLA layers, which must be processed after all
other layers.
"""
model._do_torchao_reload = model._original_do_torchao_reload
for layer in model.modules():
info = get_layerwise_info(layer)
# Attention/MLA layers are processed after all other layers
if isinstance(layer, (Attention, MLAAttention)):
if info.load_numel > 0:
raise NotImplementedError(
"Layerwise reloading of Q/K/V scale weights is not implemented yet"
)
else:
_place_kernel_tensors(layer, info)
layer.process_weights_after_loading(model_config.dtype)
# No weights were loaded, place kernel tensors back
elif info.can_process() and info.load_numel <= 0:
_place_kernel_tensors(layer, info)
# Process non-attention layers which did not load all elements. This can happen
# if the created weight has extra padding elements which are not loaded.
# Having too many of these delayed layers can lead to excess memory usage;
# see Limitations(4)
elif info.load_numel > 0 and info.load_numel < info.load_numel_total: # type: ignore[operator]
logger.debug("%s: Delayed processing", layer.__class__.__name__)
_layerwise_process(layer, info)
info.reset()
def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
"""
Finalize layer loading after all weights have been cached.
This function:
1. Materializes the layer onto the target device
2. Loads all cached weights
3. Runs quantization processing if applicable
4. Copies processed values back to original tensor storage
"""
# Materialize layer tensors onto device
materialize_layer(layer)
# Unwrap layerwise loading wrappers
for param in get_layer_tensors(layer).values():
param.weight_loader = _get_original_loader(param)
# Load all cached weights into materialized layer (using original loaders)
for name, args in info.loaded_weights:
param = getattr(layer, name)
args.arguments["param"] = param
param.weight_loader(*args.args, **args.kwargs)
# Process weights (quantization, repacking, etc.)
# Attention/MLA are processed in `finalize_layerwise_reload`
quant_method = getattr(layer, "quant_method", None)
if isinstance(quant_method, QuantizeMethodBase):
quant_method.process_weights_after_loading(layer)
# Copy processed values into original tensor storage (preserves cudagraph refs)
# this code is a no-op if not reloading (because `kernel_tensors` is empty)
parameters, buffers = info.kernel_tensors
for name, param in parameters.items():
param.data.copy_(getattr(layer, name))
for name, buffer in buffers.items():
buffer.data.copy_(getattr(layer, name))
_place_kernel_tensors(layer, info)
info.reset()
logger.debug("%s: Processed", layer.__class__.__name__)
def _get_original_loader(tensor: torch.Tensor) -> Callable:
"""Return the weight loader with any layerwise wrappers removed"""
loader = _get_weight_loader(tensor)
while loader.__name__ == "online_process_loader":
loader = loader.__wrapped__ # type: ignore[union-attr]
return loader
def _get_weight_loader(tensor: torch.Tensor):
return getattr(tensor, "weight_loader", default_weight_loader)
def _place_kernel_tensors(layer: torch.nn.Module, info: LayerReloadingInfo):
for name in get_layer_tensors(layer):
delattr(layer, name)
parameters, buffers = info.kernel_tensors
for name, param in parameters.items():
layer.register_parameter(name, param)
for name, buffer in buffers.items():
layer.register_buffer(name, buffer)
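
A toy demonstration (standard-library behavior, not vLLM-specific) of why `LAYERWISE_INFO` is a `WeakKeyDictionary`: per-layer entries are evicted once the module itself is garbage collected, so bookkeeping never outlives the model:

import gc
from weakref import WeakKeyDictionary

import torch

info: WeakKeyDictionary = WeakKeyDictionary()
layer = torch.nn.Linear(2, 2)
info[layer] = "reloading info"
assert len(info) == 1

del layer      # drop the only strong reference to the module
gc.collect()   # collect any reference cycles
assert len(info) == 0  # the entry was evicted automatically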

View File

@@ -0,0 +1,146 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import inspect
from collections.abc import Callable
import torch
from torch.utils._python_dispatch import TorchDispatchMode
from .sanitize import restore_layer_refs, sanitize_layer_refs
from .types import LayerReloadingInfo, LayerTensors
from .utils import get_layer_params_buffers, get_layer_tensors
__all__ = [
"to_meta_tensor",
"materialize_meta_tensor",
"capture_layer_to_meta",
"restore_layer_on_meta",
"materialize_layer",
"get_numel_loaded",
]
SKIP_MODULES: set[str] = {"HadamardTransform"}
SKIP_TENSORS: set[str] = {
"_expert_map",
"expert_mask",
"expert_global_to_physical",
"expert_physical_to_global",
"expert_local_to_global",
}
def to_meta_tensor(tensor: torch.Tensor) -> torch.Tensor:
"""Convert a tensor to a meta tensor while preserving class and attributes."""
meta_tensor = tensor.data.to("meta")
meta_tensor.__class__ = tensor.__class__
meta_tensor.__dict__ = tensor.__dict__.copy()
return meta_tensor
def materialize_meta_tensor(meta_tensor: torch.Tensor) -> torch.Tensor:
"""
Materialize a meta tensor into an actual tensor on the current device.
Should be called within the torch device context for the given rank.
"""
tensor = torch.empty_strided(
size=tuple(meta_tensor.size()),
stride=tuple(meta_tensor.stride()),
dtype=meta_tensor.dtype,
requires_grad=False,
)
tensor.__class__ = meta_tensor.__class__
tensor.__dict__ = meta_tensor.__dict__.copy()
return tensor
def capture_layer_to_meta(layer: torch.nn.Module) -> LayerTensors:
if layer.__class__.__name__ in SKIP_MODULES:
return ({}, {})
params, buffers = get_layer_params_buffers(layer)
return (
{
name: sanitize_layer_refs(to_meta_tensor(param), layer)
for name, param in params.items()
if name not in SKIP_TENSORS
},
{
name: sanitize_layer_refs(to_meta_tensor(buffer), layer)
for name, buffer in buffers.items()
if name not in SKIP_TENSORS
},
)
def restore_layer_on_meta(layer: torch.nn.Module, info: LayerReloadingInfo):
"""Restore a layer to model format with tensors on the meta device"""
if layer.__class__.__name__ in SKIP_MODULES:
return
for name in get_layer_tensors(layer):
if name not in SKIP_TENSORS:
delattr(layer, name)
restore_params, restore_buffers = info.restore_metadata
for name, param in restore_params.items():
if name not in SKIP_TENSORS:
param = restore_layer_refs(param, layer)
layer.register_parameter(name, param)
for name, buffer in restore_buffers.items():
if name not in SKIP_TENSORS:
buffer = restore_layer_refs(buffer, layer)
layer.register_buffer(name, buffer)
def materialize_layer(layer: torch.nn.Module) -> None:
"""Materialize all meta tensors in a layer to actual tensors."""
if layer.__class__.__name__ in SKIP_MODULES:
return
for name, tensor in get_layer_tensors(layer).items():
if name not in SKIP_TENSORS:
setattr(layer, name, materialize_meta_tensor(tensor))
class MetaCopyCounter(TorchDispatchMode):
"""
Tracks total number of elements modified with `copy_`.
Useful for keeping track of weight loading where underlying weights can be
arbitrarily transformed (such as with `narrow`) before calling copy.
Note: Assumes that copy kwargs are not used.
"""
def __init__(self):
super().__init__()
self.copied_numel = 0
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func is torch.ops.aten.copy_.default and args[0].device.type == "meta":
assert args[0].numel() == args[1].numel()
self.copied_numel += args[0].numel()
return func(*args, **kwargs)
def get_numel_loaded(
weight_loader: Callable, args: inspect.BoundArguments
) -> tuple[int, object]:
"""
Determine how many elements would be loaded by a weight loader call.
:param weight_loader: the loader used to load weights
:param args: bound arguments to weight loader
:return: number of elements loaded by the weight loader, the return value of the
weight loader
"""
assert args.arguments["param"].device.type == "meta"
with MetaCopyCounter() as counter:
return_value = weight_loader(*args.args, **args.kwargs)
return counter.copied_numel, return_value
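
A small sketch of how `MetaCopyCounter` and `get_numel_loaded` are expected to behave: copies into meta tensors move no data, but the dispatch mode still observes and tallies them, even through views such as `narrow`:

import torch

dst = torch.empty(8, device="meta")  # restored meta parameter
src = torch.randn(8)                 # weight chunk from disk

with MetaCopyCounter() as counter:
    # a loader may narrow into a shard before copying, as noted above
    dst.narrow(0, 0, 4).copy_(src.narrow(0, 0, 4))
    dst.narrow(0, 4, 4).copy_(src.narrow(0, 4, 4))

assert counter.copied_numel == 8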

View File

@@ -0,0 +1,50 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from types import MethodType
import torch
__all__ = ["sanitize_layer_refs", "restore_layer_refs"]
layer_ref_sentinel = object()
def sanitize_layer_refs(tensor: torch.Tensor, layer: torch.nn.Module) -> torch.Tensor:
"""
Removes references to layer held by tensor attributes. Specifically, removes the
`__self__` attribute of weight loader methods attached to the tensor.
Used by `capture_layer_to_meta` to avoid circular references to layers in
`LAYERWISE_INFO`, which would lead to modules never being cleaned up. Without
sanitization, tensors would hold references to their layers, and the
WeakKeyDictionary would never evict entries, even after the model is deleted.
:param tensor: tensor to be sanitized
:param layer: layer whose references should be removed
:return: sanitized tensor
"""
for key, value in tensor.__dict__.items():
if isinstance(value, MethodType) and value.__self__ is layer:
tensor.__dict__[key] = value.__func__.__get__(layer_ref_sentinel)
return tensor
def restore_layer_refs(tensor: torch.Tensor, layer: torch.nn.Module) -> torch.Tensor:
"""
Restores references to layer held by tensor attributes.
Used by `restore_layer_on_meta` to add back layer references, allowing for proper
weight loading.
:param tensor: tensor to be restored
:param layer: layer whose references should be restored
:return: restored tensor
"""
for key, value in tensor.__dict__.items():
if isinstance(value, MethodType) and value.__self__ is layer_ref_sentinel:
tensor.__dict__[key] = value.__func__.__get__(layer)
return tensor
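
A round-trip sketch of the sentinel rebinding above, attaching an illustrative bound method to a tensor (the `loader` name is made up for the example):

from types import MethodType

import torch

layer = torch.nn.Linear(2, 2)

def loader(self, weight):  # will become a method bound to `layer`
    return weight

tensor = torch.empty(2)
tensor.weight_loader = MethodType(loader, layer)

tensor = sanitize_layer_refs(tensor, layer)
assert tensor.weight_loader.__self__ is not layer  # strong ref to layer removed

tensor = restore_layer_refs(tensor, layer)
assert tensor.weight_loader.__self__ is layer      # rebound for weight loading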

View File

@@ -0,0 +1,58 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from functools import wraps
from types import FunctionType
from typing import TYPE_CHECKING
import torch
from vllm.config import ModelConfig
from .layerwise import (
finalize_layerwise_reload,
initialize_layerwise_reload,
)
if TYPE_CHECKING:
from vllm.model_executor.models.utils import AutoWeightsLoader
__all__ = ["set_torchao_reload_attrs", "support_quantized_model_reload_from_hp_weights"]
def set_torchao_reload_attrs(model: torch.nn.Module, model_config: ModelConfig):
model._do_torchao_reload = True
model._model_config = model_config
def support_quantized_model_reload_from_hp_weights(original_load_weights: FunctionType):
"""
Decorator for `AutoWeightsLoader.load_weights` to support reloading
high precision (bfloat16/float16/float32) weights into an already quantized
model. This involves restoring the weights to high precision and then
quantizing the weights online.
Only applies to torchao quantized models. Assumes that all model weights are
loaded within a single weights iterator (batched updates are not supported).
"""
@wraps(original_load_weights)
def patched_model_load_weights(
self: "AutoWeightsLoader",
weights: Iterable[tuple[str, torch.Tensor]],
*args,
**kwargs,
):
model = self.module
if not getattr(model, "_do_torchao_reload", False):
return original_load_weights(self, weights, *args, **kwargs)
initialize_layerwise_reload(model)
loaded_weights = original_load_weights(self, weights, *args, **kwargs)
finalize_layerwise_reload(model, model._model_config)
return loaded_weights
return patched_model_load_weights
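
A hedged sketch of how the decorator composes with `AutoWeightsLoader` (in-tree it is applied at the method definition; explicit patching is shown here only for illustration):

from vllm.model_executor.models.utils import AutoWeightsLoader

# Once `set_torchao_reload_attrs(model, model_config)` has flagged the model,
# every later call to load_weights runs the initialize/finalize pair around
# the original loading logic; unflagged models take the fast path unchanged.
AutoWeightsLoader.load_weights = support_quantized_model_reload_from_hp_weights(
    AutoWeightsLoader.load_weights
)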

View File

@@ -0,0 +1,33 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from inspect import BoundArguments
import torch
__all__ = ["LayerTensors", "LayerReloadingInfo"]
# encodes both parameters and buffers separately
LayerTensors = tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]
@dataclass
class LayerReloadingInfo:
# model format (meta), populated by `record_metadata_for_reloading`
restore_metadata: LayerTensors = field(default_factory=lambda: ({}, {}))
# kernel format (device)
kernel_tensors: LayerTensors = field(default_factory=lambda: ({}, {}))
# track how many restored elements are ready for loading
load_numel: int = 0
load_numel_total: int | None = None
# stores arguments and tensors ready for loading
loaded_weights: list[tuple[str, BoundArguments]] = field(default_factory=list)
def reset(self):
self.__init__(restore_metadata=self.restore_metadata) # type: ignore[misc]
def can_process(self) -> bool:
return self.load_numel_total is not None
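
An illustrative check of the `reset` contract: per-reload state is cleared while the once-recorded restore metadata survives for the next cycle:

import torch

info = LayerReloadingInfo()
info.restore_metadata = ({"weight": torch.empty(4, device="meta")}, {})
info.load_numel_total = 16  # makes can_process() True

info.reset()
assert not info.can_process()                 # per-reload state cleared
assert "weight" in info.restore_metadata[0]   # restore metadata preserved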

View File

@@ -0,0 +1,31 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from .types import LayerTensors
__all__ = [
"get_layer_tensors",
"get_layer_params_buffers",
"get_layer_size",
]
def get_layer_tensors(layer: torch.nn.Module) -> dict[str, torch.Tensor]:
"""Get all parameters and buffers from a module as a dict."""
params, buffers = get_layer_params_buffers(layer)
return params | buffers
def get_layer_params_buffers(layer: torch.nn.Module) -> LayerTensors:
"""Get all parameters and buffers of a module as a tuple of dicts."""
return (
{name: param for name, param in layer._parameters.items() if param is not None},
{name: buffer for name, buffer in layer._buffers.items() if buffer is not None},
)
def get_layer_size(layer: torch.nn.Module) -> int:
"""Calculate total number of elements across all tensors in a layer."""
return sum(tensor.numel() for tensor in get_layer_tensors(layer).values())
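
For example, these helpers only look at a module's direct (non-recursive) parameters and buffers:

import torch

layer = torch.nn.Linear(4, 8)        # weight: 8 * 4 = 32, bias: 8
assert get_layer_size(layer) == 40

model = torch.nn.Sequential(layer)   # child tensors are not counted
assert get_layer_size(model) == 0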

View File

@@ -18,6 +18,10 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.model_loader.reload import (
record_metadata_for_reloading,
set_torchao_reload_attrs,
)
from vllm.model_executor.models.interfaces import SupportsQuant
from vllm.utils.platform_utils import is_pin_memory_available
@@ -45,7 +49,9 @@ def initialize_model(
if "vllm_config" in all_params and "prefix" in all_params:
# new-style model class
with set_current_vllm_config(vllm_config, check_compile=True, prefix=prefix):
return model_class(vllm_config=vllm_config, prefix=prefix)
model = model_class(vllm_config=vllm_config, prefix=prefix)
record_metadata_for_reloading(model)
return model
msg = (
"vLLM model class should accept `vllm_config` and `prefix` as "
@@ -75,27 +81,15 @@ def initialize_model(
if "scheduler_config" in all_params:
kwargs["scheduler_config"] = vllm_config.scheduler_config
with set_current_vllm_config(vllm_config, check_compile=True, prefix=prefix):
return model_class(**kwargs)
model = model_class(**kwargs)
record_metadata_for_reloading(model)
return model
def process_weights_after_loading(
model: nn.Module, model_config: ModelConfig, target_device: torch.device
) -> None:
if getattr(model, "process_weights_after_loading_already_called", False):
# In case `process_weights_after_loading` is called multiple times
# we'll skip it at later times
logger.debug_once(
"process_weights_after_loading already called for model %s", model
)
return
# to avoid circular dependency
from vllm.model_executor.model_loader.online_quantization import (
maybe_save_metadata_and_attributes_for_weight_reloading,
)
maybe_save_metadata_and_attributes_for_weight_reloading(model, model_config)
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if isinstance(quant_method, QuantizeMethodBase):
@@ -117,6 +111,11 @@ def process_weights_after_loading(
# of process_weights_after_loading
module.process_weights_after_loading(model_config.dtype)
# Needed for torchao model reloading via model.reload_weights
# @kylesayrs @jerryzh168 this can be removed if callers move to `reload_weights`
if model_config.quantization == "torchao":
set_torchao_reload_attrs(model, model_config)
@contextmanager
def device_loading_context(module: torch.nn.Module, target_device: torch.device):

View File

@@ -22,7 +22,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
)
from vllm.model_executor.model_loader.online_quantization import (
from vllm.model_executor.model_loader.reload import (
support_quantized_model_reload_from_hp_weights,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

View File

@@ -522,8 +522,7 @@ class SharedWeightParameter(BasevLLMParameter):
@property
def data(self):
raise ValueError(
"Accessing `data` of a "
"`PartitionedModelWeightParameter` is not allowed. "
"Accessing `data` of a `SharedWeightParameter` is not allowed. "
"Instead, use `get_partition` to get the weight of "
"the particular partition you want to access"
)