[Bugfix] Fix for multinode crash on 4 PP (#6495)

Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Author: Murali Andoorveedu
Date: 2024-07-17 04:25:10 -04:00
Committed by: GitHub
Parent commit: 5bf35a91e4
Commit: 5fa6e9876e
2 changed files with 17 additions and 5 deletions

@@ -224,13 +224,27 @@ class RayGPUExecutor(DistributedGPUExecutor):
        # broadcasted to.
        self.non_driver_workers: List[RayWorkerWrapper] = []
        tp_driver_worker_ranks = []
        non_driver_worker_ranks = []
        for idx, rank in enumerate(worker_ranks[1:]):
            # We need to skip the driver worker, which we do by
            # skipping worker_ranks[0], which is always 0.
            if rank % self.parallel_config.tensor_parallel_size == 0:
                self.tp_driver_workers.append(self.workers[idx])
                tp_driver_worker_ranks.append(rank)
            else:
                self.non_driver_workers.append(self.workers[idx])
                non_driver_worker_ranks.append(rank)
        # Enforce rank order so that the correct rank returns the
        # final output.
        self.tp_driver_workers = [
            worker for _, worker in sorted(
                zip(tp_driver_worker_ranks, self.tp_driver_workers))
        ]
        self.non_driver_workers = [
            worker for _, worker in sorted(
                zip(non_driver_worker_ranks, self.non_driver_workers))
        ]

    def _driver_execute_model(
        self, execute_model_req: Optional[ExecuteModelRequest]
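
The fix above appears to address workers arriving from Ray in an order that need not match their distributed ranks when spread across multiple nodes: it records each worker's rank while partitioning workers into TP-group drivers (rank divisible by tensor_parallel_size) and non-drivers, then sorts each list by rank so the correct rank returns the final output. Below is a minimal, standalone sketch of that partition-and-sort logic. The shuffled ranks and string stand-ins for the Ray worker handles are hypothetical, and TP_SIZE and partition_and_sort are illustrative names rather than vLLM API.

from typing import List, Tuple

TP_SIZE = 2  # stands in for tensor_parallel_size; with PP=4 this gives ranks 0..7

def partition_and_sort(
        worker_ranks: List[int],
        workers: List[str],
) -> Tuple[List[str], List[str]]:
    """Split workers (excluding the rank-0 driver) into TP-group
    drivers and non-drivers, then restore rank order in each list."""
    tp_driver_workers: List[str] = []
    tp_driver_ranks: List[int] = []
    non_driver_workers: List[str] = []
    non_driver_ranks: List[int] = []
    # workers[idx] corresponds to worker_ranks[1:][idx]: the rank-0
    # driver process itself is not in the workers list.
    for idx, rank in enumerate(worker_ranks[1:]):
        if rank % TP_SIZE == 0:
            tp_driver_workers.append(workers[idx])
            tp_driver_ranks.append(rank)
        else:
            non_driver_workers.append(workers[idx])
            non_driver_ranks.append(rank)
    # Without this sort, a shuffled placement order would make the
    # wrong worker be treated as a given pipeline stage's TP driver.
    tp_driver_workers = [
        w for _, w in sorted(zip(tp_driver_ranks, tp_driver_workers))
    ]
    non_driver_workers = [
        w for _, w in sorted(zip(non_driver_ranks, non_driver_workers))
    ]
    return tp_driver_workers, non_driver_workers

# Hypothetical out-of-order arrival (rank 0, the driver, is first).
ranks = [0, 3, 1, 6, 2, 7, 5, 4]
workers = [f"worker-rank-{r}" for r in ranks[1:]]
tp_drivers, non_drivers = partition_and_sort(ranks, workers)
assert tp_drivers == ["worker-rank-2", "worker-rank-4", "worker-rank-6"]
assert non_drivers == [
    "worker-rank-1", "worker-rank-3", "worker-rank-5", "worker-rank-7"
]

Note that the partition alone preserves arrival order, not rank order; the sorted(zip(...)) pass is what the commit adds, pairing each worker with its rank so both lists come out rank-ordered regardless of how the workers were handed back.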