[Core][1/N] Support send/recv in PyNCCL Groups (#4988)

Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Author: Murali Andoorveedu
Date: 2024-05-23 09:54:48 -07:00
Committed by: GitHub
Parent: 2ba80bed27
Commit: 5eda2ea02a
5 changed files with 170 additions and 17 deletions

@@ -126,6 +126,40 @@ class PyNcclCommunicator:
                           ncclRedOpTypeEnum.from_torch(op), self.comm,
                           cudaStream_t(stream.cuda_stream))

    def send(self,
             tensor: torch.Tensor,
             dst: Optional[int] = None,
             stream=None):
        if self.disabled:
            return
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")
        if stream is None:
            stream = self.stream
        # default destination: the next rank, so send/recv form a ring
        if dst is None:
            dst = (self.rank + 1) % self.world_size
        self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
                           ncclDataTypeEnum.from_torch(tensor.dtype), dst,
                           self.comm, cudaStream_t(stream.cuda_stream))

    def recv(self,
             tensor: torch.Tensor,
             src: Optional[int] = None,
             stream=None):
        if self.disabled:
            return
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")
        if stream is None:
            stream = self.stream
        # default source: the previous rank, mirroring send()
        if src is None:
            src = (self.rank - 1) % self.world_size
        self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
                           ncclDataTypeEnum.from_torch(tensor.dtype), src,
                           self.comm, cudaStream_t(stream.cuda_stream))

    @contextmanager
    def change_state(self,
                     enable: Optional[bool] = None,
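
For context, a minimal sketch of how the new methods pair up across ranks. This is not part of the diff: the import path, the constructor arguments, and the need to enable the communicator via change_state are assumptions based on the surrounding class; only send/recv and their neighbor defaults come from the code above.

import torch
import torch.distributed as dist

from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator

# Launch with two processes, e.g.: torchrun --nproc-per-node=2 this_script.py
dist.init_process_group(backend="gloo")
rank = dist.get_rank()
device = torch.device(f"cuda:{rank}")

# Assumption: the communicator takes a (CPU) process group and a device.
comm = PyNcclCommunicator(dist.group.WORLD, device=device)
tensor = torch.ones(1024, dtype=torch.float32, device=device) * rank

# send/recv are no-ops while self.disabled is set, so we assume the
# change_state context manager shown above is needed to enable them.
with comm.change_state(enable=True):
    if rank == 0:
        comm.send(tensor)  # dst defaults to (rank + 1) % world_size -> 1
    else:
        comm.recv(tensor)  # src defaults to (rank - 1) % world_size -> 0

torch.cuda.synchronize()
# rank 1 now holds rank 0's buffer (all zeros)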

@@ -151,6 +151,22 @@ class NCCLLibrary:
            ncclRedOp_t, ncclComm_t, cudaStream_t
        ]),
        # ncclResult_t ncclSend(
        #     const void* sendbuff, size_t count, ncclDataType_t datatype,
        #     int dest, ncclComm_t comm, cudaStream_t stream);
        Function("ncclSend", ncclResult_t, [
            buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
            ncclComm_t, cudaStream_t
        ]),
        # ncclResult_t ncclRecv(
        #     void* recvbuff, size_t count, ncclDataType_t datatype,
        #     int src, ncclComm_t comm, cudaStream_t stream);
        Function("ncclRecv", ncclResult_t, [
            buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
            ncclComm_t, cudaStream_t
        ]),
        # be cautious! this is a collective call, it will block until all
        # processes in the communicator have called this function.
        # because Python object destruction can happen in random order,
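
Each Function record above is eventually materialized through plain ctypes. A standalone sketch of that pattern follows; the type aliases and library path are assumptions for illustration, not code from this commit.

import ctypes

# Aliases mirroring the wrapper's type names (assumed definitions).
ncclResult_t = ctypes.c_int
ncclDataType_t = ctypes.c_int
buffer_type = ctypes.c_void_p
ncclComm_t = ctypes.c_void_p
cudaStream_t = ctypes.c_void_p

lib = ctypes.CDLL("libnccl.so.2")  # library path is an assumption

# Bind ncclSend the way the Function("ncclSend", ...) entry describes:
# set the C return type and argument types on the exported symbol.
ncclSend = getattr(lib, "ncclSend")
ncclSend.restype = ncclResult_t
ncclSend.argtypes = [
    buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,
    ncclComm_t, cudaStream_t
]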
@@ -248,6 +264,16 @@ class NCCLLibrary:
                                                 datatype, op, comm,
                                                 stream))

    def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int,
                 dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
        self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype,
                                                dest, comm, stream))

    def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int,
                 src: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
        self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src,
                                                comm, stream))

    def ncclCommDestroy(self, comm: ncclComm_t) -> None:
        self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
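
Both wrappers route the returned ncclResult_t through NCCL_CHECK, which is defined elsewhere in this class. A sketch of what such a check typically does, assuming the class also binds ncclGetErrorString (this is an assumed implementation, not part of the diff):

def NCCL_CHECK(self, result: int) -> None:
    # ncclSuccess is 0; any other code is raised as a Python exception,
    # translated to text via the (assumed) ncclGetErrorString binding.
    if result != 0:
        error_str = self.ncclGetErrorString(result)
        raise RuntimeError(f"NCCL error: {error_str}")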