Move online quantization to model.load_weights (#26327)
Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>
This commit is contained in:
@@ -22,6 +22,7 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
fastsafetensors_weights_iterator,
|
||||
filter_duplicate_safetensors_files,
|
||||
filter_files_not_needed_for_inference,
|
||||
get_quant_config,
|
||||
maybe_download_from_modelscope,
|
||||
multi_thread_pt_weights_iterator,
|
||||
multi_thread_safetensors_weights_iterator,
|
||||
@@ -273,42 +274,17 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
)
|
||||
|
||||
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
|
||||
if model_config.quantization == "torchao" and torchao_version_at_least(
|
||||
"0.14.0"
|
||||
):
|
||||
self.load_config.safetensors_load_strategy = "torchao"
|
||||
if model_config.quantization == "torchao":
|
||||
quant_config = get_quant_config(model_config, self.load_config)
|
||||
if (
|
||||
hasattr(quant_config, "is_checkpoint_torchao_serialized")
|
||||
and quant_config.is_checkpoint_torchao_serialized
|
||||
and torchao_version_at_least("0.14.0")
|
||||
):
|
||||
self.load_config.safetensors_load_strategy = "torchao"
|
||||
|
||||
weights_to_load = {name for name, _ in model.named_parameters()}
|
||||
|
||||
# if we don't have `model.weight_metadata_and_attr_saved` defined and
|
||||
# set to True, it means that this is either offline quantization case
|
||||
# or the first run of online quantization
|
||||
# see online_quantization.py for detailed notes
|
||||
offline_quantization_or_first_run_of_online_quantization = not getattr(
|
||||
model, "weight_metadata_and_attr_saved", False
|
||||
)
|
||||
|
||||
if model_config.quantization is None:
|
||||
# model is not quantized
|
||||
loaded_weights = model.load_weights(
|
||||
self.get_all_weights(model_config, model)
|
||||
)
|
||||
elif offline_quantization_or_first_run_of_online_quantization:
|
||||
# case 1: offline quantized checkpoint
|
||||
# case 2: Step I1 first run of weight loading with
|
||||
# online quantization
|
||||
# see online_quantization.py for detailed notes
|
||||
loaded_weights = model.load_weights(
|
||||
self.get_all_weights(model_config, model)
|
||||
)
|
||||
else:
|
||||
# to avoid circular dependency
|
||||
from vllm.model_executor.model_loader.online_quantization import (
|
||||
load_weights_and_online_quantize,
|
||||
)
|
||||
|
||||
# subsequent runs of weight loading with online
|
||||
# quantization
|
||||
loaded_weights = load_weights_and_online_quantize(self, model, model_config)
|
||||
loaded_weights = model.load_weights(self.get_all_weights(model_config, model))
|
||||
|
||||
self.counter_after_loading_weights = time.perf_counter()
|
||||
logger.info_once(
|
||||
|
||||
@@ -2,13 +2,13 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import types
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
|
||||
from vllm.model_executor.model_loader.utils import process_weights_after_loading
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -56,6 +56,9 @@ logger = init_logger(__name__)
|
||||
# R4. quantize weights (by calling process_weights_after_loading),
|
||||
# also set `process_weights_after_loading_already_called` to
|
||||
# True to stop it from running again
|
||||
# R5. (workaround for cudagraph), we restore the weight params to original quantized
|
||||
# weights params, and use original_weight_param.copy_(updated_weight_param) so that
|
||||
# the weight update works well with cudagraph
|
||||
# process_weights_after_loading (if called):
|
||||
# this will be skipped since it has already run in
|
||||
# load_weights
|
||||
@@ -69,14 +72,6 @@ def maybe_save_metadata_and_attributes_for_weight_reloading(
|
||||
if model_config.quantization != "torchao":
|
||||
return
|
||||
|
||||
if getattr(model, "process_weights_after_loading_already_called", False):
|
||||
# In case `process_weights_after_loading` is called multiple times
|
||||
# we'll skip it at later times
|
||||
logger.warning(
|
||||
"process_weights_after_loading already called for model %s", model
|
||||
)
|
||||
return
|
||||
|
||||
from vllm.model_executor.model_loader.weight_utils import get_quant_config
|
||||
|
||||
quant_config = get_quant_config(model_config, None)
|
||||
@@ -137,6 +132,7 @@ def maybe_save_metadata_and_attributes_for_weight_reloading(
|
||||
else:
|
||||
model.recorded_weight_attr[name][key] = attr
|
||||
# mark the metadata and attributes saved so we don't run it again
|
||||
model._model_config = model_config
|
||||
model.weight_metadata_and_attr_saved = True
|
||||
|
||||
|
||||
@@ -148,77 +144,132 @@ def _bond_method_to_cls(func, obj):
|
||||
return types.MethodType(func, obj)
|
||||
|
||||
|
||||
def load_weights_and_online_quantize(
|
||||
model_loader: DefaultModelLoader, model: nn.Module, model_config: ModelConfig
|
||||
) -> set[str]:
|
||||
def support_quantized_model_reload_from_hp_weights(original_load_weights):
|
||||
"""Decorator for `load_weights` method for AutoWeightsLoader.load_weights to support
|
||||
reloading high precision (bfloat16/float16/float32) weight for an already quantized
|
||||
model, this involves restoring the weights to a high precision weights and
|
||||
then online quantize the weights
|
||||
"""
|
||||
# online quantization, right now only enabled for
|
||||
# torchao
|
||||
# R1, R2, R3, R4 in the Notes
|
||||
# R1, R2, R3, R4, R5 in the Notes
|
||||
|
||||
# TODO: Add fp8 support
|
||||
assert model_config.quantization == "torchao", (
|
||||
"online quantization is only enabled for torchao currently"
|
||||
)
|
||||
# TODO: use create_weights to restore the weights to original state
|
||||
def patched_model_load_weights(
|
||||
auto_weight_loader, weights: Iterable[tuple[str, torch.Tensor]], *, mapper=None
|
||||
) -> set[str]:
|
||||
model = auto_weight_loader.module
|
||||
offline_quantization_or_first_run_of_online_quantization = not getattr(
|
||||
model, "weight_metadata_and_attr_saved", False
|
||||
)
|
||||
|
||||
# Step R1: First restore the quantized weights to original bfloat16
|
||||
# weights, with original metadata (shape, dtype, device)
|
||||
# and attributes, so that bfloat16 weights can be loaded properly
|
||||
existing_param_names = dict(model.named_parameters(remove_duplicate=False)).keys()
|
||||
named_modules = dict(model.named_modules(remove_duplicate=False))
|
||||
model_device = None
|
||||
# if we don't have `model.weight_metadata_and_attr_saved` defined and
|
||||
# set to True, it means that this is either offline quantization case
|
||||
# or the first run of online quantization
|
||||
# see Notes in this file for more details
|
||||
if offline_quantization_or_first_run_of_online_quantization:
|
||||
# case 1: offline quantized checkpoint
|
||||
# case 2: Step I1 first run of weight loading with
|
||||
# online quantization
|
||||
return original_load_weights(auto_weight_loader, weights, mapper=mapper)
|
||||
|
||||
# Step R2: recover the parameter to the state before first loading
|
||||
for name, d in model.original_weights_rebuild_keys.items():
|
||||
_shape = d["shape"]
|
||||
_dtype = d["dtype"]
|
||||
_device = d["device"]
|
||||
model_config = model._model_config
|
||||
|
||||
# TODO: Add fp8 support
|
||||
assert model_config.quantization == "torchao", (
|
||||
"online quantization is only enabled for torchao currently"
|
||||
)
|
||||
# TODO: use create_weights to restore the weights to original state
|
||||
|
||||
# Step R1: First restore the quantized weights to original bfloat16
|
||||
# weights, with original metadata (shape, dtype, device)
|
||||
# and attributes, so that bfloat16 weights can be loaded properly
|
||||
# TODO: maybe set remove_duplicate to True?
|
||||
original_quantized_weight_dict = dict(
|
||||
model.named_parameters(remove_duplicate=False)
|
||||
)
|
||||
named_modules = dict(model.named_modules(remove_duplicate=False))
|
||||
model_device = None
|
||||
|
||||
for name, d in model.original_weights_rebuild_keys.items():
|
||||
_shape = d["shape"]
|
||||
_dtype = d["dtype"]
|
||||
_device = d["device"]
|
||||
if model_device is not None:
|
||||
assert model_device == _device, (
|
||||
"Expecting all weights "
|
||||
"to be in the same device for now, got both: "
|
||||
f"{model_device} and {_device}"
|
||||
)
|
||||
else:
|
||||
model_device = _device
|
||||
|
||||
if name in original_quantized_weight_dict:
|
||||
module_name, weight_name = name.rsplit(".", 1)
|
||||
module = named_modules[module_name]
|
||||
setattr(
|
||||
module,
|
||||
weight_name,
|
||||
torch.nn.Parameter(
|
||||
torch.empty(_shape, dtype=_dtype, device=_device),
|
||||
requires_grad=False,
|
||||
),
|
||||
)
|
||||
|
||||
# Step R2: recover the weight attributes to the state before first loading
|
||||
# recorded_weight_attr is
|
||||
# {"weight_name": {"weight_attr_key": attr}}
|
||||
# e.g.
|
||||
# {
|
||||
# {
|
||||
# "layer.0.weight": {
|
||||
# "weight_loader": weight_loader_function_object,
|
||||
# "input_dim": 0, ...
|
||||
# },
|
||||
# "layer.1.weight": ...,
|
||||
# }
|
||||
# }
|
||||
for full_weight_name, weight_attr_dict in model.recorded_weight_attr.items():
|
||||
for attr_name, attr in weight_attr_dict.items():
|
||||
module_name, weight_name = full_weight_name.rsplit(".", 1)
|
||||
module = named_modules[module_name]
|
||||
weight = getattr(module, weight_name)
|
||||
if not hasattr(weight, attr_name):
|
||||
setattr(weight, attr_name, _bond_method_to_cls(attr, weight))
|
||||
|
||||
# Step R3: reload bfloat16 / high precision weights
|
||||
updated_params = original_load_weights(
|
||||
auto_weight_loader, weights, mapper=mapper
|
||||
)
|
||||
|
||||
# Step R4: online quantize the weights
|
||||
# manually process weights after loading
|
||||
model.process_weights_after_loading_already_called = False
|
||||
if model_device is not None:
|
||||
assert model_device == _device, (
|
||||
"Expecting all weights "
|
||||
"to be in the same device for now, got both: "
|
||||
f"{model_device} and {_device}"
|
||||
)
|
||||
process_weights_after_loading(model, model_config, model_device)
|
||||
else:
|
||||
model_device = _device
|
||||
|
||||
if name in existing_param_names:
|
||||
module_name, weight_name = name.rsplit(".", 1)
|
||||
module = named_modules[module_name]
|
||||
setattr(
|
||||
module,
|
||||
weight_name,
|
||||
torch.nn.Parameter(torch.empty(_shape, dtype=_dtype, device=_device)),
|
||||
logger.warning_once(
|
||||
"model_device is None, skip calling process_weights_after_loading"
|
||||
)
|
||||
|
||||
# recorded_weight_attr is
|
||||
# {"weight_name": {"weight_attr_key": attr}}
|
||||
# e.g.
|
||||
# {
|
||||
# {
|
||||
# "layer.0.weight": {
|
||||
# "weight_loader": weight_loader_function_object,
|
||||
# "input_dim": 0, ...
|
||||
# },
|
||||
# "layer.1.weight": ...,
|
||||
# }
|
||||
# }
|
||||
for full_weight_name, weight_attr_dict in model.recorded_weight_attr.items():
|
||||
for attr_name, attr in weight_attr_dict.items():
|
||||
module_name, weight_name = full_weight_name.rsplit(".", 1)
|
||||
module = named_modules[module_name]
|
||||
weight = getattr(module, weight_name)
|
||||
if not hasattr(weight, attr_name):
|
||||
setattr(weight, attr_name, _bond_method_to_cls(attr, weight))
|
||||
# Step R5 (workaround for cudagraph): restore the original quantized weights
|
||||
# and do a copy_ of the current weights to the original weights
|
||||
updated_quantized_weights = dict(model.named_parameters(remove_duplicate=False))
|
||||
for name in model.original_weights_rebuild_keys:
|
||||
if name in original_quantized_weight_dict:
|
||||
original_quantized_weight = original_quantized_weight_dict[name]
|
||||
updated_quantized_weight = updated_quantized_weights[name]
|
||||
|
||||
# Step I1: reload bfloat16 / high precision weights
|
||||
loaded_weights = model.load_weights(
|
||||
model_loader.get_all_weights(model_config, model)
|
||||
)
|
||||
module_name, weight_name = name.rsplit(".", 1)
|
||||
module = named_modules[module_name]
|
||||
setattr(module, weight_name, original_quantized_weight)
|
||||
with torch.no_grad():
|
||||
original_quantized_weight.copy_(updated_quantized_weight)
|
||||
|
||||
# Step I2: online quantize the weights
|
||||
# manually process weights after loading
|
||||
model.process_weights_after_loading_already_called = False
|
||||
process_weights_after_loading(model, model_config, model_device)
|
||||
model.process_weights_after_loading_already_called = True
|
||||
return loaded_weights
|
||||
del original_quantized_weight_dict
|
||||
del named_modules
|
||||
del updated_quantized_weight
|
||||
|
||||
model.process_weights_after_loading_already_called = True
|
||||
return updated_params
|
||||
|
||||
return patched_model_load_weights
|
||||
|
||||
@@ -88,6 +88,14 @@ def initialize_model(
|
||||
def process_weights_after_loading(
|
||||
model: nn.Module, model_config: ModelConfig, target_device: torch.device
|
||||
) -> None:
|
||||
if getattr(model, "process_weights_after_loading_already_called", False):
|
||||
# In case `process_weights_after_loading` is called multiple times
|
||||
# we'll skip it at later times
|
||||
logger.debug_once(
|
||||
"process_weights_after_loading already called for model %s", model
|
||||
)
|
||||
return
|
||||
|
||||
# to avoid circular dependency
|
||||
from vllm.model_executor.model_loader.online_quantization import (
|
||||
maybe_save_metadata_and_attributes_for_weight_reloading,
|
||||
|
||||
@@ -21,6 +21,9 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
)
|
||||
from vllm.model_executor.model_loader.online_quantization import (
|
||||
support_quantized_model_reload_from_hp_weights,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.interfaces import supports_any_eagle
|
||||
from vllm.multimodal import NestedTensors
|
||||
@@ -316,6 +319,7 @@ class AutoWeightsLoader:
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
@support_quantized_model_reload_from_hp_weights
|
||||
def load_weights(
|
||||
self,
|
||||
weights: Iterable[tuple[str, torch.Tensor]],
|
||||
|
||||
Reference in New Issue
Block a user