diff --git a/vllm/model_executor/model_loader/base_loader.py b/vllm/model_executor/model_loader/base_loader.py index f68405d05..d6c38664f 100644 --- a/vllm/model_executor/model_loader/base_loader.py +++ b/vllm/model_executor/model_loader/base_loader.py @@ -50,10 +50,14 @@ class BaseModelLoader(ABC): device_config.device if load_config.device is None else load_config.device ) target_device = torch.device(load_device) - with set_default_torch_dtype(model_config.dtype), target_device: - model = initialize_model( - vllm_config=vllm_config, model_config=model_config, prefix=prefix - ) + with set_default_torch_dtype(model_config.dtype): + with target_device: + model = initialize_model( + vllm_config=vllm_config, + model_config=model_config, + prefix=prefix, + ) + log_model_inspection(model) logger.debug("Loading weights on %s ...", load_device)