[Core] Support inplace model weights loading (#18745)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
This commit is contained in:
@@ -28,7 +28,7 @@ from vllm.forward_context import (DPMetadata, get_forward_context,
|
||||
set_forward_context)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||
from vllm.model_executor.model_loader import TensorizerLoader, get_model
|
||||
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.multimodal.utils import group_mm_inputs_by_modality
|
||||
@@ -1564,7 +1564,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
logger.info("Starting to load model %s...", self.model_config.model)
|
||||
with DeviceMemoryProfiler() as m: # noqa: SIM117
|
||||
time_before_load = time.perf_counter()
|
||||
self.model = get_model(vllm_config=self.vllm_config)
|
||||
model_loader = get_model_loader(self.load_config)
|
||||
if not hasattr(self, "model"):
|
||||
logger.info("Loading model from scratch...")
|
||||
self.model = model_loader.load_model(
|
||||
vllm_config=self.vllm_config,
|
||||
model_config=self.model_config)
|
||||
else:
|
||||
logger.info(
|
||||
"Model was already initialized. Loading weights inplace..."
|
||||
)
|
||||
model_loader.load_weights(self.model,
|
||||
model_config=self.model_config)
|
||||
if self.lora_config:
|
||||
self.model = self.load_lora_model(self.model,
|
||||
self.model_config,
|
||||
|
||||
@@ -21,7 +21,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.layers import BaseLayerWithLoRA
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.model_executor.model_loader import get_model_loader
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
|
||||
PlaceholderRange)
|
||||
@@ -171,7 +171,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.encoder_cache_size = encoder_cache_size
|
||||
|
||||
# Lazy initialization
|
||||
# self.model: nn.Module # Set after load_model
|
||||
self.model: nn.Module # Set after load_model
|
||||
self.kv_caches: list[torch.Tensor] = []
|
||||
# req_id -> (input_id -> encoder_output)
|
||||
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
|
||||
@@ -419,7 +419,6 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
||||
return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
assert self.model is not None
|
||||
return self.model
|
||||
|
||||
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
||||
@@ -936,7 +935,18 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
||||
"vllm.model_executor.layers.vocab_parallel_embedding."
|
||||
"get_tensor_model_parallel_rank",
|
||||
return_value=xm_tp_rank):
|
||||
model = get_model(vllm_config=self.vllm_config)
|
||||
# model = get_model(vllm_config=self.vllm_config)
|
||||
model_loader = get_model_loader(self.load_config)
|
||||
if not hasattr(self, "model"):
|
||||
logger.info("Loading model from scratch...")
|
||||
model = model_loader.load_model(vllm_config=self.vllm_config,
|
||||
model_config=self.model_config)
|
||||
else:
|
||||
logger.info(
|
||||
"Model was already initialized. Loading weights inplace..."
|
||||
)
|
||||
model_loader.load_weights(self.model,
|
||||
model_config=self.model_config)
|
||||
if self.lora_config is not None:
|
||||
model = self.load_lora_model(model, self.model_config,
|
||||
self.scheduler_config,
|
||||
@@ -947,7 +957,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
||||
# loading.
|
||||
xm.mark_step()
|
||||
xm.wait_device_ops()
|
||||
self.model = model
|
||||
if not hasattr(self, "model"):
|
||||
self.model = model
|
||||
self.sampler = TPUSampler()
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
Reference in New Issue
Block a user