custom allreduce + torch.compile (#10121)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
Sage Moore
2024-11-26 00:00:16 -06:00
committed by GitHub
parent 519e8e4182
commit 9a88f89799
6 changed files with 62 additions and 104 deletions

View File

@@ -70,14 +70,12 @@ def gpu_worker(rank, WORLD_SIZE, port1, port2):
rank=rank,
world_size=WORLD_SIZE)
pynccl1 = PyNcclCommunicator(pg1, device=rank)
pynccl1.disabled = False
if rank <= 2:
pg2 = StatelessProcessGroup.create(host="127.0.0.1",
port=port2,
rank=rank,
world_size=3)
pynccl2 = PyNcclCommunicator(pg2, device=rank)
pynccl2.disabled = False
data = torch.tensor([rank]).cuda()
pynccl1.all_reduce(data)
pg1.barrier()