[Core] Support inplace model weights loading (#18745)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
This commit is contained in:
@@ -14,7 +14,7 @@ from huggingface_hub import HfApi
|
||||
from torch import nn
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
||||
|
||||
from vllm.config import LoadConfig, ModelConfig, VllmConfig
|
||||
from vllm.config import LoadConfig, ModelConfig
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
# yapf: enable
|
||||
@@ -28,7 +28,6 @@ from vllm.model_executor.layers.linear import (LinearBase,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.utils import (ParamMapping,
|
||||
initialize_model,
|
||||
set_default_torch_dtype)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
download_safetensors_index_file_from_hf, download_weights_from_hf,
|
||||
@@ -408,8 +407,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
), "vllm currently does not support BNB quantization for"
|
||||
f" {type(model).__name__}"
|
||||
|
||||
def _load_weights(self, model_config: ModelConfig,
|
||||
model: nn.Module) -> None:
|
||||
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
|
||||
if not hasattr(model, "load_weights"):
|
||||
raise AttributeError(
|
||||
"The required method 'load_weights' is not defined in class"
|
||||
@@ -568,15 +566,3 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
|
||||
def download_model(self, model_config: ModelConfig) -> None:
|
||||
self._prepare_weights(model_config.model, model_config.revision)
|
||||
|
||||
def load_model(self, vllm_config: VllmConfig,
|
||||
model_config: ModelConfig) -> nn.Module:
|
||||
device_config = vllm_config.device_config
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with torch.device(device_config.device):
|
||||
|
||||
model = initialize_model(vllm_config=vllm_config)
|
||||
|
||||
self._load_weights(model_config, model)
|
||||
|
||||
return model.eval()
|
||||
|
||||
Reference in New Issue
Block a user