[Bugfix] Fix for multinode crash on 4 PP (#6495)
Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
parent 5bf35a91e4
commit 5fa6e9876e
@@ -224,13 +224,27 @@ class RayGPUExecutor(DistributedGPUExecutor):
        # broadcasted to.
        self.non_driver_workers: List[RayWorkerWrapper] = []

        tp_driver_worker_ranks = []
        non_driver_worker_ranks = []
        for idx, rank in enumerate(worker_ranks[1:]):
            # We need to skip the driver worker, which we
            # do by skipping worker_ranks[0] which is always 0.
            if rank % self.parallel_config.tensor_parallel_size == 0:
                self.tp_driver_workers.append(self.workers[idx])
                tp_driver_worker_ranks.append(rank)
            else:
                self.non_driver_workers.append(self.workers[idx])
                non_driver_worker_ranks.append(rank)

        # Enforce rank order for correct rank to return final output.
        self.tp_driver_workers = [
            worker for _, worker in sorted(
                zip(tp_driver_worker_ranks, self.tp_driver_workers))
        ]
        self.non_driver_workers = [
            worker for _, worker in sorted(
                zip(non_driver_worker_ranks, self.non_driver_workers))
        ]

    def _driver_execute_model(
            self, execute_model_req: Optional[ExecuteModelRequest]
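To illustrate the partition-and-sort logic above, here is a minimal, self-contained sketch. The ranks, list contents, and placeholder worker names are hypothetical (plain strings stand in for Ray worker handles); it only mimics how the commit splits non-driver workers into per-TP-group drivers versus the rest, then re-sorts each list by rank. Assume a 2-way TP, 2-way PP deployment where the global driver holds rank 0 and the remaining workers may be discovered out of rank order.

    from typing import List

    # Hypothetical setup: 4 ranks total (TP=2, PP=2); rank 0 is the driver.
    tensor_parallel_size = 2
    worker_ranks = [0, 3, 2, 1]
    # Placeholder workers aligned with worker_ranks[1:] (driver excluded).
    workers = ["worker_rank3", "worker_rank2", "worker_rank1"]

    tp_driver_workers: List[str] = []
    non_driver_workers: List[str] = []
    tp_driver_worker_ranks = []
    non_driver_worker_ranks = []

    for idx, rank in enumerate(worker_ranks[1:]):
        # The driver (worker_ranks[0] == 0) is skipped by the [1:] slice.
        if rank % tensor_parallel_size == 0:
            # The first rank of each TP group acts as that group's driver.
            tp_driver_workers.append(workers[idx])
            tp_driver_worker_ranks.append(rank)
        else:
            non_driver_workers.append(workers[idx])
            non_driver_worker_ranks.append(rank)

    # Re-sort both lists by rank, mirroring the commit's fix so that the
    # correct rank returns the final output.
    tp_driver_workers = [
        w for _, w in sorted(zip(tp_driver_worker_ranks, tp_driver_workers))
    ]
    non_driver_workers = [
        w for _, w in sorted(zip(non_driver_worker_ranks, non_driver_workers))
    ]

    print(tp_driver_workers)   # ['worker_rank2']
    print(non_driver_workers)  # ['worker_rank1', 'worker_rank3']

Without the sort, the list order would follow discovery order (here rank 3 before rank 1), which is the kind of rank misordering the commit guards against on multi-node 4-way PP setups.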