Move online quantization to model.load_weights (#26327)
Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>
This commit is contained in:
@@ -22,6 +22,7 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
fastsafetensors_weights_iterator,
|
||||
filter_duplicate_safetensors_files,
|
||||
filter_files_not_needed_for_inference,
|
||||
get_quant_config,
|
||||
maybe_download_from_modelscope,
|
||||
multi_thread_pt_weights_iterator,
|
||||
multi_thread_safetensors_weights_iterator,
|
||||
@@ -273,42 +274,17 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
)
|
||||
|
||||
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
|
||||
if model_config.quantization == "torchao" and torchao_version_at_least(
|
||||
"0.14.0"
|
||||
):
|
||||
self.load_config.safetensors_load_strategy = "torchao"
|
||||
if model_config.quantization == "torchao":
|
||||
quant_config = get_quant_config(model_config, self.load_config)
|
||||
if (
|
||||
hasattr(quant_config, "is_checkpoint_torchao_serialized")
|
||||
and quant_config.is_checkpoint_torchao_serialized
|
||||
and torchao_version_at_least("0.14.0")
|
||||
):
|
||||
self.load_config.safetensors_load_strategy = "torchao"
|
||||
|
||||
weights_to_load = {name for name, _ in model.named_parameters()}
|
||||
|
||||
# if we don't have `model.weight_metadata_and_attr_saved` defined and
|
||||
# set to True, it means that this is either offline quantization case
|
||||
# or the first run of online quantization
|
||||
# see online_quantization.py for detailed notes
|
||||
offline_quantization_or_first_run_of_online_quantization = not getattr(
|
||||
model, "weight_metadata_and_attr_saved", False
|
||||
)
|
||||
|
||||
if model_config.quantization is None:
|
||||
# model is not quantized
|
||||
loaded_weights = model.load_weights(
|
||||
self.get_all_weights(model_config, model)
|
||||
)
|
||||
elif offline_quantization_or_first_run_of_online_quantization:
|
||||
# case 1: offline quantized checkpoint
|
||||
# case 2: Step I1 first run of weight loading with
|
||||
# online quantization
|
||||
# see online_quantization.py for detailed notes
|
||||
loaded_weights = model.load_weights(
|
||||
self.get_all_weights(model_config, model)
|
||||
)
|
||||
else:
|
||||
# to avoid circular dependency
|
||||
from vllm.model_executor.model_loader.online_quantization import (
|
||||
load_weights_and_online_quantize,
|
||||
)
|
||||
|
||||
# subsequent runs of weight loading with online
|
||||
# quantization
|
||||
loaded_weights = load_weights_and_online_quantize(self, model, model_config)
|
||||
loaded_weights = model.load_weights(self.get_all_weights(model_config, model))
|
||||
|
||||
self.counter_after_loading_weights = time.perf_counter()
|
||||
logger.info_once(
|
||||
|
||||
Reference in New Issue
Block a user