[Distributed] Add send and recv helpers (#5719)

This commit is contained in:
Murali Andoorveedu
2024-06-23 17:42:28 -04:00
committed by GitHub
parent 6c916ac8a8
commit 5d4d90536f
6 changed files with 278 additions and 24 deletions

View File

@@ -121,10 +121,7 @@ class PyNcclCommunicator:
ncclRedOpTypeEnum.from_torch(op), self.comm,
cudaStream_t(stream.cuda_stream))
def send(self,
tensor: torch.Tensor,
dst: Optional[int] = None,
stream=None):
def send(self, tensor: torch.Tensor, dst: int, stream=None):
if self.disabled:
return
assert tensor.device == self.device, (
@@ -132,16 +129,11 @@ class PyNcclCommunicator:
f"but the input tensor is on {tensor.device}")
if stream is None:
stream = self.stream
if dst is None:
dst = (self.rank + 1) % self.world_size
self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), dst,
self.comm, cudaStream_t(stream.cuda_stream))
def recv(self,
tensor: torch.Tensor,
src: Optional[int] = None,
stream=None):
def recv(self, tensor: torch.Tensor, src: int, stream=None):
if self.disabled:
return
assert tensor.device == self.device, (
@@ -149,8 +141,6 @@ class PyNcclCommunicator:
f"but the input tensor is on {tensor.device}")
if stream is None:
stream = self.stream
if src is None:
src = (self.rank - 1) % self.world_size
self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), src,
self.comm, cudaStream_t(stream.cuda_stream))