diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index bf226f661..c88af56e1 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -1330,11 +1330,14 @@ class GGUFModelLoader(BaseModelLoader): local_model_path, gguf_weights_map): model_config.hf_config.update({"tie_word_embeddings": True}) + target_device = torch.device(device_config.device) with set_default_torch_dtype(model_config.dtype): - with torch.device(device_config.device): + with target_device: model = _initialize_model(vllm_config=vllm_config) model.load_weights( self._get_weights_iterator(local_model_path, gguf_weights_map)) + + _process_weights_after_loading(model, model_config, target_device) return model