Move online quantization to model.load_weights (#26327)

Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>
Author: Jerry Zhang
Date: 2025-11-18 16:52:41 -08:00
Committed by: GitHub
Parent: 1395461f5f
Commit: da94c7c0eb

6 changed files with 309 additions and 108 deletions

vllm/model_executor/model_loader/default_loader.py

@@ -22,6 +22,7 @@ from vllm.model_executor.model_loader.weight_utils import (
     fastsafetensors_weights_iterator,
     filter_duplicate_safetensors_files,
     filter_files_not_needed_for_inference,
+    get_quant_config,
     maybe_download_from_modelscope,
     multi_thread_pt_weights_iterator,
     multi_thread_safetensors_weights_iterator,
@@ -273,42 +274,17 @@ class DefaultModelLoader(BaseModelLoader):
         )

     def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
-        if model_config.quantization == "torchao" and torchao_version_at_least(
-            "0.14.0"
-        ):
-            self.load_config.safetensors_load_strategy = "torchao"
+        if model_config.quantization == "torchao":
+            quant_config = get_quant_config(model_config, self.load_config)
+            if (
+                hasattr(quant_config, "is_checkpoint_torchao_serialized")
+                and quant_config.is_checkpoint_torchao_serialized
+                and torchao_version_at_least("0.14.0")
+            ):
+                self.load_config.safetensors_load_strategy = "torchao"
         weights_to_load = {name for name, _ in model.named_parameters()}
-        # if `model.weight_metadata_and_attr_saved` is not defined or not set
-        # to True, this is either the offline quantization case or the first
-        # run of online quantization
-        # see online_quantization.py for detailed notes
-        offline_quantization_or_first_run_of_online_quantization = not getattr(
-            model, "weight_metadata_and_attr_saved", False
-        )
-        if model_config.quantization is None:
-            # model is not quantized
-            loaded_weights = model.load_weights(
-                self.get_all_weights(model_config, model)
-            )
-        elif offline_quantization_or_first_run_of_online_quantization:
-            # case 1: offline quantized checkpoint
-            # case 2: Step I1, first run of weight loading with
-            #   online quantization
-            # see online_quantization.py for detailed notes
-            loaded_weights = model.load_weights(
-                self.get_all_weights(model_config, model)
-            )
-        else:
-            # to avoid a circular dependency
-            from vllm.model_executor.model_loader.online_quantization import (
-                load_weights_and_online_quantize,
-            )
-
-            # subsequent runs of weight loading with online quantization
-            loaded_weights = load_weights_and_online_quantize(
-                self, model, model_config
-            )
+        loaded_weights = model.load_weights(self.get_all_weights(model_config, model))
         self.counter_after_loading_weights = time.perf_counter()
         logger.info_once(

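The removed branching keyed off a flag stashed on the model (`weight_metadata_and_attr_saved`) to tell the first load apart from later reloads; after this commit that dispatch lives behind the model's own load_weights (see online_quantization.py below). A minimal standalone sketch of the flag idiom, using a hypothetical ToyModel rather than vLLM internals:

import torch
from torch import nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def load_weights(self, weights):
        # copy incoming tensors into matching parameters by name
        params = dict(self.named_parameters())
        loaded = set()
        for name, tensor in weights:
            with torch.no_grad():
                params[name].copy_(tensor)
            loaded.add(name)
        return loaded

model = ToyModel()
# first run: flag absent, take the plain load path
assert not getattr(model, "weight_metadata_and_attr_saved", False)
model.load_weights([("linear.weight", torch.zeros(4, 4))])
model.weight_metadata_and_attr_saved = True
# later runs: flag present, take the restore-and-requantize path
assert getattr(model, "weight_metadata_and_attr_saved", False)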
vllm/model_executor/model_loader/online_quantization.py

@@ -2,13 +2,13 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import types
+from collections.abc import Iterable

 import torch
 from torch import nn

 from vllm.config import ModelConfig
 from vllm.logger import init_logger
-from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
 from vllm.model_executor.model_loader.utils import process_weights_after_loading

 logger = init_logger(__name__)
@@ -56,6 +56,9 @@ logger = init_logger(__name__)
 #   R4. quantize weights (by calling process_weights_after_loading),
 #       also set `process_weights_after_loading_already_called` to
 #       True to stop it from running again
+#   R5. (workaround for cudagraph) restore the weight params to the original
+#       quantized weight params and use original_weight_param.copy_(updated_weight_param)
+#       so that the weight update works well with cudagraph
 # process_weights_after_loading (if called):
 #   this will be skipped since it has already run in
 #   load_weights
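Step R5 exists because CUDA graphs capture fixed tensor addresses at capture time: swapping in a new parameter object after capture would leave the graph reading stale storage, while an in-place copy_ keeps the captured address valid. A standalone sketch of the idea (CPU-only, illustrative; not vLLM code):

import torch

param = torch.nn.Parameter(torch.zeros(4), requires_grad=False)
addr_before = param.data_ptr()

updated = torch.ones(4)  # stands in for the freshly re-quantized weight
with torch.no_grad():
    param.copy_(updated)  # in-place: same storage, new values

assert param.data_ptr() == addr_before  # address unchanged, graph-safe
print(param)  # tensor([1., 1., 1., 1.])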
@@ -69,14 +72,6 @@ def maybe_save_metadata_and_attributes_for_weight_reloading(
     if model_config.quantization != "torchao":
         return

-    if getattr(model, "process_weights_after_loading_already_called", False):
-        # In case `process_weights_after_loading` is called multiple times,
-        # we'll skip it on later calls
-        logger.warning(
-            "process_weights_after_loading already called for model %s", model
-        )
-        return
-
     from vllm.model_executor.model_loader.weight_utils import get_quant_config

     quant_config = get_quant_config(model_config, None)
@@ -137,6 +132,7 @@ def maybe_save_metadata_and_attributes_for_weight_reloading(
         else:
             model.recorded_weight_attr[name][key] = attr

     # mark the metadata and attributes saved so we don't run this again
+    model._model_config = model_config
     model.weight_metadata_and_attr_saved = True
@@ -148,77 +144,132 @@ def _bond_method_to_cls(func, obj):
     return types.MethodType(func, obj)


-def load_weights_and_online_quantize(
-    model_loader: DefaultModelLoader, model: nn.Module, model_config: ModelConfig
-) -> set[str]:
-    # online quantization, right now only enabled for
-    # torchao
-    # R1, R2, R3, R4 in the Notes
-    # TODO: Add fp8 support
-    assert model_config.quantization == "torchao", (
-        "online quantization is only enabled for torchao currently"
-    )
-    # TODO: use create_weights to restore the weights to original state
-
-    # Step R1: First restore the quantized weights to the original bfloat16
-    # weights, with the original metadata (shape, dtype, device)
-    # and attributes, so that the bfloat16 weights can be loaded properly
-    existing_param_names = dict(model.named_parameters(remove_duplicate=False)).keys()
-    named_modules = dict(model.named_modules(remove_duplicate=False))
-    model_device = None
-
-    # Step R2: recover the parameters to the state before first loading
-    for name, d in model.original_weights_rebuild_keys.items():
-        _shape = d["shape"]
-        _dtype = d["dtype"]
-        _device = d["device"]
-        if model_device is not None:
-            assert model_device == _device, (
-                "Expecting all weights "
-                "to be on the same device for now, got both: "
-                f"{model_device} and {_device}"
-            )
-        else:
-            model_device = _device
-        if name in existing_param_names:
-            module_name, weight_name = name.rsplit(".", 1)
-            module = named_modules[module_name]
-            setattr(
-                module,
-                weight_name,
-                torch.nn.Parameter(torch.empty(_shape, dtype=_dtype, device=_device)),
-            )
-
-    # recorded_weight_attr is
-    # {"weight_name": {"weight_attr_key": attr}}
-    # e.g.
-    # {
-    #     "layer.0.weight": {
-    #         "weight_loader": weight_loader_function_object,
-    #         "input_dim": 0, ...
-    #     },
-    #     "layer.1.weight": ...,
-    # }
-    for full_weight_name, weight_attr_dict in model.recorded_weight_attr.items():
-        for attr_name, attr in weight_attr_dict.items():
-            module_name, weight_name = full_weight_name.rsplit(".", 1)
-            module = named_modules[module_name]
-            weight = getattr(module, weight_name)
-            if not hasattr(weight, attr_name):
-                setattr(weight, attr_name, _bond_method_to_cls(attr, weight))
-
-    # Step I1: reload bfloat16 / high precision weights
-    loaded_weights = model.load_weights(
-        model_loader.get_all_weights(model_config, model)
-    )
-
-    # Step I2: online quantize the weights
-    # manually process weights after loading
-    model.process_weights_after_loading_already_called = False
-    process_weights_after_loading(model, model_config, model_device)
-    model.process_weights_after_loading_already_called = True
-    return loaded_weights
+def support_quantized_model_reload_from_hp_weights(original_load_weights):
+    """Decorator for AutoWeightsLoader.load_weights that supports reloading
+    high precision (bfloat16/float16/float32) weights into an already
+    quantized model: it restores the weights to high precision and then
+    online quantizes them.
+    """
+
+    def patched_model_load_weights(
+        auto_weight_loader, weights: Iterable[tuple[str, torch.Tensor]], *, mapper=None
+    ) -> set[str]:
+        model = auto_weight_loader.module
+        offline_quantization_or_first_run_of_online_quantization = not getattr(
+            model, "weight_metadata_and_attr_saved", False
+        )
+
+        # if `model.weight_metadata_and_attr_saved` is not defined or not set
+        # to True, this is either the offline quantization case or the first
+        # run of online quantization
+        # see the Notes in this file for more details
+        if offline_quantization_or_first_run_of_online_quantization:
+            # case 1: offline quantized checkpoint
+            # case 2: Step I1, first run of weight loading with
+            #   online quantization
+            return original_load_weights(auto_weight_loader, weights, mapper=mapper)
+
+        # online quantization, right now only enabled for
+        # torchao
+        # R1, R2, R3, R4, R5 in the Notes
+        model_config = model._model_config
+        # TODO: Add fp8 support
+        assert model_config.quantization == "torchao", (
+            "online quantization is only enabled for torchao currently"
+        )
+        # TODO: use create_weights to restore the weights to original state
+
+        # Step R1: First restore the quantized weights to the original bfloat16
+        # weights, with the original metadata (shape, dtype, device)
+        # and attributes, so that the bfloat16 weights can be loaded properly
+        # TODO: maybe set remove_duplicate to True?
+        original_quantized_weight_dict = dict(
+            model.named_parameters(remove_duplicate=False)
+        )
+        named_modules = dict(model.named_modules(remove_duplicate=False))
+        model_device = None
+
+        for name, d in model.original_weights_rebuild_keys.items():
+            _shape = d["shape"]
+            _dtype = d["dtype"]
+            _device = d["device"]
+            if model_device is not None:
+                assert model_device == _device, (
+                    "Expecting all weights "
+                    "to be on the same device for now, got both: "
+                    f"{model_device} and {_device}"
+                )
+            else:
+                model_device = _device
+            if name in original_quantized_weight_dict:
+                module_name, weight_name = name.rsplit(".", 1)
+                module = named_modules[module_name]
+                setattr(
+                    module,
+                    weight_name,
+                    torch.nn.Parameter(
+                        torch.empty(_shape, dtype=_dtype, device=_device),
+                        requires_grad=False,
+                    ),
+                )
+
+        # Step R2: recover the weight attributes to the state before first loading
+        # recorded_weight_attr is
+        # {"weight_name": {"weight_attr_key": attr}}
+        # e.g.
+        # {
+        #     "layer.0.weight": {
+        #         "weight_loader": weight_loader_function_object,
+        #         "input_dim": 0, ...
+        #     },
+        #     "layer.1.weight": ...,
+        # }
+        for full_weight_name, weight_attr_dict in model.recorded_weight_attr.items():
+            for attr_name, attr in weight_attr_dict.items():
+                module_name, weight_name = full_weight_name.rsplit(".", 1)
+                module = named_modules[module_name]
+                weight = getattr(module, weight_name)
+                if not hasattr(weight, attr_name):
+                    setattr(weight, attr_name, _bond_method_to_cls(attr, weight))
+
+        # Step R3: reload bfloat16 / high precision weights
+        updated_params = original_load_weights(
+            auto_weight_loader, weights, mapper=mapper
+        )
+
+        # Step R4: online quantize the weights
+        # manually process weights after loading
+        model.process_weights_after_loading_already_called = False
+        if model_device is not None:
+            process_weights_after_loading(model, model_config, model_device)
+        else:
+            logger.warning_once(
+                "model_device is None, skip calling process_weights_after_loading"
+            )
+
+        # Step R5 (workaround for cudagraph): restore the original quantized weight
+        # params and copy_ the current weights into the original weights
+        updated_quantized_weights = dict(model.named_parameters(remove_duplicate=False))
+        for name in model.original_weights_rebuild_keys:
+            if name in original_quantized_weight_dict:
+                original_quantized_weight = original_quantized_weight_dict[name]
+                updated_quantized_weight = updated_quantized_weights[name]
+                module_name, weight_name = name.rsplit(".", 1)
+                module = named_modules[module_name]
+                setattr(module, weight_name, original_quantized_weight)
+                with torch.no_grad():
+                    original_quantized_weight.copy_(updated_quantized_weight)
+        del original_quantized_weight_dict
+        del named_modules
+        del updated_quantized_weight
+
+        model.process_weights_after_loading_already_called = True
+        return updated_params
+
+    return patched_model_load_weights
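For reference, the shape of the decorator above: the wrapper receives the AutoWeightsLoader instance as its first argument, so it can reach the wrapped model via auto_weight_loader.module before choosing a path. A toy sketch under that assumption; MiniLoader and passthrough_decorator are hypothetical names:

from collections.abc import Iterable
import torch

def passthrough_decorator(original_load_weights):
    def patched(auto_weight_loader, weights, *, mapper=None):
        # decide between plain load and restore-reload-requantize here
        return original_load_weights(auto_weight_loader, weights, mapper=mapper)
    return patched

class MiniLoader:
    def __init__(self, module):
        self.module = module

    @passthrough_decorator
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]], *, mapper=None):
        params = dict(self.module.named_parameters())
        return {name for name, t in weights if name in params}

loader = MiniLoader(torch.nn.Linear(2, 2))
print(loader.load_weights([("weight", torch.zeros(2, 2))]))  # {'weight'}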

vllm/model_executor/model_loader/utils.py

@@ -88,6 +88,14 @@ def initialize_model(
 def process_weights_after_loading(
     model: nn.Module, model_config: ModelConfig, target_device: torch.device
 ) -> None:
+    if getattr(model, "process_weights_after_loading_already_called", False):
+        # In case `process_weights_after_loading` is called multiple times,
+        # we'll skip it on later calls
+        logger.debug_once(
+            "process_weights_after_loading already called for model %s", model
+        )
+        return
+
     # to avoid a circular dependency
     from vllm.model_executor.model_loader.online_quantization import (
         maybe_save_metadata_and_attributes_for_weight_reloading,
vllm/model_executor/models/utils.py

@@ -21,6 +21,9 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig,
 )
+from vllm.model_executor.model_loader.online_quantization import (
+    support_quantized_model_reload_from_hp_weights,
+)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import supports_any_eagle
 from vllm.multimodal import NestedTensors
@@ -316,6 +319,7 @@ class AutoWeightsLoader:
             )
             raise ValueError(msg)

+    @support_quantized_model_reload_from_hp_weights
     def load_weights(
         self,
         weights: Iterable[tuple[str, torch.Tensor]],
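Taken together, a reload through the decorated AutoWeightsLoader.load_weights walks steps R1-R5: rebuild high-precision parameters, reload weights, re-quantize, then copy the result back into the original parameter objects. A toy end-to-end version with a float16 cast standing in for torchao quantization (all names illustrative):

import torch
from torch import nn

model = nn.Linear(4, 4).to(torch.float16)    # stands in for the quantized model
original = dict(model.named_parameters())     # R5 bookkeeping: keep the old params

# R1/R2: rebuild a high-precision parameter with the recorded shape/dtype
model.weight = nn.Parameter(
    torch.empty(4, 4, dtype=torch.float32), requires_grad=False
)

# R3: "reload" high-precision weights
with torch.no_grad():
    model.weight.copy_(torch.randn(4, 4))

# R4: "quantize" (here just a cast standing in for torchao quantization)
requantized = model.weight.to(torch.float16)

# R5: restore the original parameter object and copy the new values into it,
# so any captured tensor address stays valid
model.weight = original["weight"]
with torch.no_grad():
    model.weight.copy_(requantized)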