custom allreduce + torch.compile (#10121)
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
@@ -70,14 +70,12 @@ def gpu_worker(rank, WORLD_SIZE, port1, port2):
                                       rank=rank,
                                       world_size=WORLD_SIZE)
    pynccl1 = PyNcclCommunicator(pg1, device=rank)
    pynccl1.disabled = False
    if rank <= 2:
        pg2 = StatelessProcessGroup.create(host="127.0.0.1",
                                           port=port2,
                                           rank=rank,
                                           world_size=3)
        pynccl2 = PyNcclCommunicator(pg2, device=rank)
        pynccl2.disabled = False
    data = torch.tensor([rank]).cuda()
    pynccl1.all_reduce(data)
    pg1.barrier()
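For readers unfamiliar with the pattern this test exercises: each worker process rendezvouses over TCP via StatelessProcessGroup (without touching torch.distributed's global state), wraps the group in a PyNcclCommunicator, and runs an NCCL all-reduce; ranks 0-2 additionally form a second, overlapping group on another port. Below is a minimal self-contained sketch of that pattern. The import paths and the in-place all_reduce semantics mirror this commit's test but are assumptions that may differ across vLLM versions; WORLD_SIZE, PORT, and the spawn-based launcher are illustrative choices, not part of the original test.

# Minimal sketch, assuming vLLM import paths at the time of this commit:
#   vllm.distributed.utils.StatelessProcessGroup
#   vllm.distributed.device_communicators.pynccl.PyNcclCommunicator
# Requires one CUDA device per rank.
import multiprocessing

import torch

from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup

WORLD_SIZE = 2   # illustrative; the test above uses 4 ranks and 2 groups
PORT = 29500     # assumed-free TCP port for the rendezvous


def gpu_worker(rank: int):
    torch.cuda.set_device(rank)
    # TCP rendezvous that does not use torch.distributed's global state.
    pg = StatelessProcessGroup.create(host="127.0.0.1",
                                      port=PORT,
                                      rank=rank,
                                      world_size=WORLD_SIZE)
    pynccl = PyNcclCommunicator(pg, device=rank)
    pynccl.disabled = False

    data = torch.tensor([rank]).cuda()
    # As in the test above, all_reduce is assumed to run in place here;
    # newer vLLM versions instead return the reduced tensor.
    pynccl.all_reduce(data)
    pg.barrier()
    torch.cuda.synchronize()
    assert data.item() == sum(range(WORLD_SIZE))


if __name__ == "__main__":
    ctx = multiprocessing.get_context("spawn")
    procs = [ctx.Process(target=gpu_worker, args=(r,))
             for r in range(WORLD_SIZE)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
        assert p.exitcode == 0

Because the rendezvous is stateless, several overlapping groups (pg1 and pg2 in the hunk above) can coexist in one process, which a single global torch.distributed world would not allow as directly.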