[Core] Refactor model loading code (#4097)
This commit is contained in:
@@ -147,6 +147,7 @@ class RayGPUExecutor(ExecutorBase):
|
||||
model_config = copy.deepcopy(self.model_config)
|
||||
parallel_config = copy.deepcopy(self.parallel_config)
|
||||
scheduler_config = copy.deepcopy(self.scheduler_config)
|
||||
load_config = copy.deepcopy(self.load_config)
|
||||
device_config = copy.deepcopy(self.device_config)
|
||||
lora_config = copy.deepcopy(self.lora_config)
|
||||
cache_config = copy.deepcopy(self.cache_config)
|
||||
@@ -165,12 +166,12 @@ class RayGPUExecutor(ExecutorBase):
|
||||
scheduler_config=scheduler_config,
|
||||
device_config=device_config,
|
||||
cache_config=cache_config,
|
||||
load_config=load_config,
|
||||
local_rank=local_rank,
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
lora_config=lora_config,
|
||||
vision_language_config=vision_language_config,
|
||||
tensorizer_config=self.tensorizer_config,
|
||||
))
|
||||
|
||||
# Initialize the driver worker with the Worker class.
|
||||
@@ -187,7 +188,7 @@ class RayGPUExecutor(ExecutorBase):
|
||||
distributed_init_method=distributed_init_method,
|
||||
lora_config=self.lora_config,
|
||||
vision_language_config=self.vision_language_config,
|
||||
tensorizer_config=self.tensorizer_config,
|
||||
load_config=self.load_config,
|
||||
is_driver_worker=True,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user