[Core] Support inplace model weights loading (#18745)

Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Author: 22quinn
Date: 2025-06-02 02:38:50 -07:00
Committed by: GitHub
Parent: b9f61e1387
Commit: 9760fd8f6a
13 changed files with 240 additions and 288 deletions


@@ -14,7 +14,7 @@ from huggingface_hub import HfApi
 from torch import nn
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 
-from vllm.config import LoadConfig, ModelConfig, VllmConfig
+from vllm.config import LoadConfig, ModelConfig
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 # yapf: enable
@@ -28,7 +28,6 @@ from vllm.model_executor.layers.linear import (LinearBase,
                                                RowParallelLinear)
 from vllm.model_executor.model_loader.base_loader import BaseModelLoader
 from vllm.model_executor.model_loader.utils import (ParamMapping,
-                                                    initialize_model,
                                                     set_default_torch_dtype)
 from vllm.model_executor.model_loader.weight_utils import (
     download_safetensors_index_file_from_hf, download_weights_from_hf,
@@ -408,8 +407,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             ), "vllm currently does not support BNB quantization for"
             f" {type(model).__name__}"
 
-    def _load_weights(self, model_config: ModelConfig,
-                      model: nn.Module) -> None:
+    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
         if not hasattr(model, "load_weights"):
             raise AttributeError(
                 "The required method 'load_weights' is not defined in class"
@@ -568,15 +566,3 @@ class BitsAndBytesModelLoader(BaseModelLoader):
 
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model, model_config.revision)
-
-    def load_model(self, vllm_config: VllmConfig,
-                   model_config: ModelConfig) -> nn.Module:
-        device_config = vllm_config.device_config
-        with set_default_torch_dtype(model_config.dtype):
-            with torch.device(device_config.device):
-
-                model = initialize_model(vllm_config=vllm_config)
-
-                self._load_weights(model_config, model)
-
-        return model.eval()
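
Deleting load_model here (together with the now-unused initialize_model and VllmConfig imports above) suggests the model-construction half of loading moves up into the shared BaseModelLoader, so each concrete loader only implements download_model and load_weights. Below is a hedged sketch of that presumed base-class flow, reconstructed from the method deleted above rather than copied from base_loader.py; the class name _BaseModelLoaderSketch is a placeholder.

# Hedged sketch of the presumed shared load_model flow in the base loader;
# reconstructed from the deleted method, not the verbatim vLLM source.
import torch
from torch import nn

from vllm.config import ModelConfig, VllmConfig
from vllm.model_executor.model_loader.utils import (initialize_model,
                                                    set_default_torch_dtype)


class _BaseModelLoaderSketch:

    def load_weights(self, model: nn.Module,
                     model_config: ModelConfig) -> None:
        raise NotImplementedError  # provided by each concrete loader

    def load_model(self, vllm_config: VllmConfig,
                   model_config: ModelConfig) -> nn.Module:
        device_config = vllm_config.device_config
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                # Build the model skeleton once, on the target device ...
                model = initialize_model(vllm_config=vllm_config)
                # ... then delegate to the subclass hook that fills in the
                # weights; the same hook can be called again later for an
                # in-place weight reload without rebuilding the module.
                self.load_weights(model, model_config)
        return model.eval()
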