[Core] Centralize GPU Worker construction (#4419)

This commit is contained in:
Nick Hill
2024-04-30 18:06:34 -07:00
committed by GitHub
parent ee37328da0
commit 2e240c69a9
2 changed files with 47 additions and 68 deletions

View File

@@ -153,29 +153,14 @@ class RayGPUExecutor(DistributedGPUExecutor):
distributed_init_method = get_distributed_init_method(
driver_ip, get_open_port())
def collect_arg_helper_func(**kwargs):
# avoid writing `{"name": value}` manually
return kwargs
# Initialize the actual workers inside worker wrapper.
init_worker_all_kwargs = []
for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
local_rank = node_workers[node_id].index(rank)
init_worker_all_kwargs.append(
collect_arg_helper_func(
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
load_config=self.load_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
is_driver_worker=rank == 0,
))
init_worker_all_kwargs = [
self._get_worker_kwargs(
local_rank=node_workers[node_id].index(rank),
rank=rank,
distributed_init_method=distributed_init_method,
) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
]
self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
self._run_workers("init_device")
@@ -201,8 +186,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
# Only the driver worker returns the sampling results.
output = all_outputs[0]
return output
return all_outputs[0]
def _run_workers(
self,