[distributed] remove pynccl's redundant stream (#11744)

Author: cennn
Committed by: GitHub
Date: 2025-01-05 23:09:11 +08:00
Parent: 4068f4b5b5
Commit: 635b897246
3 changed files with 12 additions and 24 deletions

View File

@@ -137,9 +137,8 @@ def worker_fn_with_cudagraph():
     # run something in the default stream to initialize torch engine
     a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
     torch.cuda.synchronize()
-    with torch.cuda.graph(
-            graph, stream=pynccl_comm.stream), pynccl_comm.change_state(
-                enable=True):
+    with torch.cuda.graph(graph), \
+            pynccl_comm.change_state(enable=True):
         a_out = pynccl_comm.all_reduce(a)
     torch.cuda.synchronize()
     graph.replay()
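
Why the explicit stream argument can be dropped here: torch.cuda.graph() makes its capture stream the current stream for the duration of the context, so any kernel launched on torch.cuda.current_stream() inside the block is recorded into the graph automatically. A minimal, self-contained sketch of that behaviour, independent of vLLM (the tensor shapes and the doubling op are illustrative only):

import torch

# Static buffers: CUDA graphs replay kernels against fixed memory addresses.
x = torch.ones(4, 4, device="cuda")
y = torch.zeros_like(x)

# Warm up on a side stream before capture, as the PyTorch docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    y.copy_(x * 2)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    # torch.cuda.current_stream() inside this block is the capture stream,
    # so nothing has to be passed in explicitly.
    y.copy_(x * 2)

x.fill_(3.0)
graph.replay()              # re-runs the captured kernels on the new x values
torch.cuda.synchronize()    # y now holds 6.0 everywhere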

View File

@@ -51,7 +51,6 @@ class PyNcclCommunicator:
         if self.world_size == 1:
             self.available = False
             self.disabled = True
-            self.stream = None
             return
         try:
             self.nccl = NCCLLibrary(library_path)
@@ -60,7 +59,6 @@ class PyNcclCommunicator:
             # e.g. in a non-GPU environment
             self.available = False
             self.disabled = True
-            self.stream = None
             return

         self.available = True
@@ -98,12 +96,12 @@ class PyNcclCommunicator:
         with torch.cuda.device(device):
             self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
                 self.world_size, self.unique_id, self.rank)
-            self.stream = torch.cuda.Stream()
+            stream = torch.cuda.current_stream()

             # A small all_reduce for warmup.
             data = torch.zeros(1, device=device)
             self.all_reduce(data)
-            self.stream.synchronize()
+            stream.synchronize()
             del data

     def all_reduce(self,
@@ -122,7 +120,7 @@ class PyNcclCommunicator:
         out_tensor = torch.empty_like(in_tensor)

         if stream is None:
-            stream = self.stream
+            stream = torch.cuda.current_stream()
         self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
                                 buffer_type(out_tensor.data_ptr()),
                                 in_tensor.numel(),
@@ -144,7 +142,7 @@ class PyNcclCommunicator:
f"this nccl communicator is created to work on {self.device}, " f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}") f"but the input tensor is on {input_tensor.device}")
if stream is None: if stream is None:
stream = self.stream stream = torch.cuda.current_stream()
self.nccl.ncclAllGather( self.nccl.ncclAllGather(
buffer_type(input_tensor.data_ptr()), buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()), input_tensor.numel(), buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
@@ -165,7 +163,7 @@ class PyNcclCommunicator:
f"this nccl communicator is created to work on {self.device}, " f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}") f"but the input tensor is on {input_tensor.device}")
if stream is None: if stream is None:
stream = self.stream stream = torch.cuda.current_stream()
self.nccl.ncclReduceScatter( self.nccl.ncclReduceScatter(
buffer_type(input_tensor.data_ptr()), buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()), output_tensor.numel(), buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
@@ -180,7 +178,7 @@ class PyNcclCommunicator:
f"this nccl communicator is created to work on {self.device}, " f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}") f"but the input tensor is on {tensor.device}")
if stream is None: if stream is None:
stream = self.stream stream = torch.cuda.current_stream()
self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(), self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), dst, ncclDataTypeEnum.from_torch(tensor.dtype), dst,
self.comm, cudaStream_t(stream.cuda_stream)) self.comm, cudaStream_t(stream.cuda_stream))
@@ -192,7 +190,7 @@ class PyNcclCommunicator:
f"this nccl communicator is created to work on {self.device}, " f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}") f"but the input tensor is on {tensor.device}")
if stream is None: if stream is None:
stream = self.stream stream = torch.cuda.current_stream()
self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(), self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), src, ncclDataTypeEnum.from_torch(tensor.dtype), src,
self.comm, cudaStream_t(stream.cuda_stream)) self.comm, cudaStream_t(stream.cuda_stream))
@@ -204,7 +202,7 @@ class PyNcclCommunicator:
f"this nccl communicator is created to work on {self.device}, " f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}") f"but the input tensor is on {tensor.device}")
if stream is None: if stream is None:
stream = self.stream stream = torch.cuda.current_stream()
if src == self.rank: if src == self.rank:
sendbuff = buffer_type(tensor.data_ptr()) sendbuff = buffer_type(tensor.data_ptr())
# NCCL requires the sender also to have a receive buffer # NCCL requires the sender also to have a receive buffer
@@ -217,9 +215,7 @@ class PyNcclCommunicator:
                                 self.comm, cudaStream_t(stream.cuda_stream))

     @contextmanager
-    def change_state(self,
-                     enable: Optional[bool] = None,
-                     stream: Optional[torch.cuda.Stream] = None):
+    def change_state(self, enable: Optional[bool] = None):
         """
         A context manager to change the state of the communicator.
         """
@@ -227,15 +223,9 @@ class PyNcclCommunicator:
             # guess a default value when not specified
             enable = self.available

-        if stream is None:
-            stream = self.stream
-
         old_disable = self.disabled
-        old_stream = self.stream
-        self.stream = stream
         self.disabled = not enable

         yield

         self.disabled = old_disable
-        self.stream = old_stream
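
With the communicator's private stream removed, every collective runs on the caller's current CUDA stream, and change_state() only toggles the enabled flag. The helper below is a hypothetical usage sketch, not part of this patch: the function name is invented and the pynccl_comm setup is elided; it assumes only the all_reduce() and change_state() signatures visible in the hunks above.

import torch

def all_reduce_on_side_stream(pynccl_comm, t: torch.Tensor) -> torch.Tensor:
    # Run a pynccl all-reduce on a dedicated stream by making it current;
    # after this change the communicator picks up torch.cuda.current_stream()
    # at call time instead of carrying its own stream around.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())   # order after prior work
    with pynccl_comm.change_state(enable=True), torch.cuda.stream(side):
        out = pynccl_comm.all_reduce(t)             # enqueued on `side`
    torch.cuda.current_stream().wait_stream(side)   # later work waits on it
    return out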

View File

@@ -310,8 +310,7 @@ class GroupCoordinator:
             if not pynccl_comm:
                 maybe_pynccl_context = nullcontext()
             else:
-                maybe_pynccl_context = pynccl_comm.change_state(
-                    stream=torch.cuda.current_stream())
+                maybe_pynccl_context = pynccl_comm.change_state()

             with maybe_pynccl_context:
                 yield graph_capture_context
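
Net effect on graph capture: there is no longer a stream to thread through change_state(), so the capture path only has to enable the communicator when one exists, and the NCCL kernels follow whatever stream is current inside torch.cuda.graph(). A condensed restatement of the hunk above (the helper name is invented, not vLLM API):

from contextlib import nullcontext

def _capture_comm_context(pynccl_comm):
    # Enable the communicator for capture when present; otherwise a no-op.
    # Its collectives will run on the current (capture) stream automatically.
    return pynccl_comm.change_state() if pynccl_comm else nullcontext()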