[Core][Distributed] enable multiple tp group (#4512)
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
@@ -232,6 +232,7 @@ class NCCLCommunicator:
|
||||
assert dist.get_backend(group) != dist.Backend.NCCL, (
|
||||
"NCCLCommunicator should be attached to a non-NCCL group.")
|
||||
self.group = group
|
||||
# note: this rank is the rank in the group
|
||||
self.rank = dist.get_rank(group)
|
||||
self.world_size = dist.get_world_size(group)
|
||||
if self.rank == 0:
|
||||
@@ -239,7 +240,9 @@ class NCCLCommunicator:
|
||||
else:
|
||||
self.unique_id = NcclUniqueId()
|
||||
tensor = torch.ByteTensor(list(self.unique_id.internal))
|
||||
dist.broadcast(tensor, src=0, group=group)
|
||||
ranks = dist.get_process_group_ranks(group)
|
||||
# arg `src` in `broadcast` is the global rank
|
||||
dist.broadcast(tensor, src=ranks[0], group=group)
|
||||
byte_list = tensor.tolist()
|
||||
for i, byte in enumerate(byte_list):
|
||||
self.unique_id.internal[i] = byte
|
||||
|
||||
Reference in New Issue
Block a user