[Core] Refactor model loading code (#4097)

This commit is contained in:
Antoni Baum
2024-04-16 11:34:39 -07:00
committed by GitHub
parent 05434764cd
commit 69e1d2fb69
67 changed files with 1054 additions and 963 deletions

View File

@@ -147,6 +147,7 @@ class RayGPUExecutor(ExecutorBase):
model_config = copy.deepcopy(self.model_config)
parallel_config = copy.deepcopy(self.parallel_config)
scheduler_config = copy.deepcopy(self.scheduler_config)
load_config = copy.deepcopy(self.load_config)
device_config = copy.deepcopy(self.device_config)
lora_config = copy.deepcopy(self.lora_config)
cache_config = copy.deepcopy(self.cache_config)
@@ -165,12 +166,12 @@ class RayGPUExecutor(ExecutorBase):
scheduler_config=scheduler_config,
device_config=device_config,
cache_config=cache_config,
load_config=load_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
lora_config=lora_config,
vision_language_config=vision_language_config,
tensorizer_config=self.tensorizer_config,
))
# Initialize the driver worker with the Worker class.
@@ -187,7 +188,7 @@ class RayGPUExecutor(ExecutorBase):
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
tensorizer_config=self.tensorizer_config,
load_config=self.load_config,
is_driver_worker=True,
)