[Distributed] Add send and recv helpers (#5719)
This commit is contained in:
committed by
GitHub
parent
6c916ac8a8
commit
5d4d90536f
@@ -168,9 +168,13 @@ def send_recv_worker_fn():
|
||||
dtype=torch.float32).cuda(pynccl_comm.rank)
|
||||
with pynccl_comm.change_state(enable=True):
|
||||
if pynccl_comm.rank == 0:
|
||||
pynccl_comm.send(tensor)
|
||||
pynccl_comm.send(tensor,
|
||||
dst=(pynccl_comm.rank + 1) %
|
||||
pynccl_comm.world_size)
|
||||
else:
|
||||
pynccl_comm.recv(tensor)
|
||||
pynccl_comm.recv(tensor,
|
||||
src=(pynccl_comm.rank - 1) %
|
||||
pynccl_comm.world_size)
|
||||
result = tensor.mean().cpu().item()
|
||||
assert result == 1
|
||||
|
||||
@@ -203,9 +207,13 @@ def multiple_send_recv_worker_fn():
|
||||
device=device)
|
||||
with pynccl_comm.change_state(enable=True):
|
||||
if torch.distributed.get_rank() in [0, 1]:
|
||||
pynccl_comm.send(tensor)
|
||||
pynccl_comm.send(tensor,
|
||||
dst=(pynccl_comm.rank + 1) %
|
||||
pynccl_comm.world_size)
|
||||
else:
|
||||
pynccl_comm.recv(tensor)
|
||||
pynccl_comm.recv(tensor,
|
||||
src=(pynccl_comm.rank - 1) %
|
||||
pynccl_comm.world_size)
|
||||
result = tensor.mean().cpu().item()
|
||||
if torch.distributed.get_rank() in [0, 2]:
|
||||
assert result == 1
|
||||
|
||||
Reference in New Issue
Block a user