[Core] Support inplace model weights loading (#18745)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
This commit is contained in:
@@ -12,11 +12,9 @@ from torch import nn
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
||||
|
||||
from vllm import envs
|
||||
from vllm.config import LoadConfig, LoadFormat, ModelConfig, VllmConfig
|
||||
from vllm.config import LoadConfig, LoadFormat, ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
initialize_model, process_weights_after_loading, set_default_torch_dtype)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
download_safetensors_index_file_from_hf, download_weights_from_hf,
|
||||
fastsafetensors_weights_iterator, filter_duplicate_safetensors_files,
|
||||
@@ -264,32 +262,20 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
fall_back_to_pt=True,
|
||||
allow_patterns_overrides=None)
|
||||
|
||||
def load_model(self, vllm_config: VllmConfig,
               model_config: ModelConfig) -> nn.Module:
    """Instantiate the model and populate it with checkpoint weights.

    The model is created under ``model_config.dtype`` as the default torch
    dtype and on the device named by ``vllm_config.device_config.device``.
    Weights produced by ``self.get_all_weights`` are then fed to
    ``model.load_weights`` and the elapsed loading time is logged.

    Args:
        vllm_config: Engine configuration; its ``device_config.device``
            selects the target device.
        model_config: Model configuration; supplies ``dtype`` and
            ``quantization``.

    Returns:
        The result of ``model.eval()`` on the loaded model.

    Raises:
        ValueError: If the model is not quantized, ``model.load_weights``
            reported which weights it loaded, and some named parameter is
            missing from that report.
    """
    # NOTE(review): nesting below is reconstructed from syntax; confirm
    # against the upstream file that loading happens inside the dtype
    # context and ``eval()`` outside it.
    target_device = torch.device(vllm_config.device_config.device)
    with set_default_torch_dtype(model_config.dtype):
        with target_device:
            model = initialize_model(vllm_config=vllm_config,
                                     model_config=model_config)

        # Snapshot parameter names before loading so they can be compared
        # with what the model reports as loaded.
        expected_names = set(dict(model.named_parameters()))
        weight_source = self.get_all_weights(model_config, model)
        loaded_names = model.load_weights(weight_source)

        # counter_before_loading_weights is read from self here; it is
        # presumably set while the weight iterator is produced -- verify.
        self.counter_after_loading_weights = time.perf_counter()
        elapsed = (self.counter_after_loading_weights -
                   self.counter_before_loading_weights)
        logger.info("Loading weights took %.2f seconds", elapsed)

        # The strict completeness check only applies to non-quantized
        # models whose load_weights() reports what it loaded.
        strict_check = (model_config.quantization is None
                        and loaded_names is not None)
        if strict_check:
            missing_names = expected_names - loaded_names
            if missing_names:
                raise ValueError(
                    "Following weights were not initialized from "
                    f"checkpoint: {missing_names}")

        process_weights_after_loading(model, model_config, target_device)

    return model.eval()
|
||||
def load_weights(self, model: nn.Module,
                 model_config: ModelConfig) -> None:
    """Load checkpoint weights into an already-constructed model in place.

    Feeds the iterator from ``self.get_all_weights`` to
    ``model.load_weights``, records the finish time on
    ``self.counter_after_loading_weights`` and logs the elapsed time.

    Args:
        model: The module whose parameters are to be filled.
        model_config: Model configuration; ``quantization`` decides
            whether the strict completeness check runs.

    Raises:
        ValueError: If the model is not quantized, ``model.load_weights``
            reported which weights it loaded, and some named parameter is
            missing from that report.
    """
    # Capture the parameter names up front, before any loading happens.
    expected_names = set(dict(model.named_parameters()))
    weight_source = self.get_all_weights(model_config, model)
    loaded_names = model.load_weights(weight_source)

    # counter_before_loading_weights is read from self here; it is
    # presumably set while the weight iterator is produced -- verify.
    self.counter_after_loading_weights = time.perf_counter()
    elapsed = (self.counter_after_loading_weights -
               self.counter_before_loading_weights)
    logger.info("Loading weights took %.2f seconds", elapsed)

    # Skip the strict check for quantized models and for models that do
    # not report which weights they loaded.
    if model_config.quantization is not None or loaded_names is None:
        return

    missing_names = expected_names - loaded_names
    if missing_names:
        raise ValueError("Following weights were not initialized from "
                         f"checkpoint: {missing_names}")
|
||||
|
||||
Reference in New Issue
Block a user