[distributed] remove pynccl's redundant stream (#11744)

@@ -137,9 +137,8 @@ def worker_fn_with_cudagraph():
     # run something in the default stream to initialize torch engine
     a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
     torch.cuda.synchronize()
-    with torch.cuda.graph(
-            graph, stream=pynccl_comm.stream), pynccl_comm.change_state(
-                enable=True):
+    with torch.cuda.graph(graph), \
+            pynccl_comm.change_state(enable=True):
         a_out = pynccl_comm.all_reduce(a)
     torch.cuda.synchronize()
     graph.replay()
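Since the communicator no longer owns a stream, the test no longer passes an explicit stream to torch.cuda.graph: during capture, torch.cuda.graph swaps in its own capture stream as the current stream, and pynccl now launches NCCL work on whatever the current stream is. A minimal, self-contained sketch of that capture pattern (not the vLLM test itself; the shapes and the `* 2` kernel are placeholders for the all-reduce):

import torch

def capture_on_current_stream():
    # Run something on the default stream first so torch/CUDA is initialized.
    a = torch.ones((4, 4), device="cuda")
    torch.cuda.synchronize()

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        # Inside this block torch.cuda.current_stream() is the capture stream,
        # so kernels (or NCCL calls) issued on the current stream are recorded.
        a_out = a * 2
    torch.cuda.synchronize()

    graph.replay()  # re-issues the captured kernels
    torch.cuda.synchronize()
    return a_out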
@@ -51,7 +51,6 @@ class PyNcclCommunicator:
         if self.world_size == 1:
             self.available = False
             self.disabled = True
-            self.stream = None
             return
         try:
             self.nccl = NCCLLibrary(library_path)
@@ -60,7 +59,6 @@ class PyNcclCommunicator:
             # e.g. in a non-GPU environment
             self.available = False
             self.disabled = True
-            self.stream = None
             return

         self.available = True
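Both early-exit paths in __init__ now only flip the availability flags; there is no stream attribute left to reset. A rough sketch of that structure with a stand-in loader (the real NCCLLibrary wraps libnccl; here it is a hypothetical stub that always fails, just to exercise the fallback path):

class NCCLLibrary:
    """Stand-in for the real ctypes wrapper; always fails in this sketch."""

    def __init__(self, library_path=None):
        raise RuntimeError("libnccl not available in this sketch")


class CommInitSketch:
    def __init__(self, world_size: int, library_path=None):
        self.world_size = world_size
        if self.world_size == 1:
            # Nothing to synchronize with: stay importable but inert.
            self.available = False
            self.disabled = True
            return
        try:
            self.nccl = NCCLLibrary(library_path)
        except Exception:
            # e.g. in a non-GPU environment
            self.available = False
            self.disabled = True
            return
        self.available = True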
@@ -98,12 +96,12 @@ class PyNcclCommunicator:
         with torch.cuda.device(device):
             self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
                 self.world_size, self.unique_id, self.rank)
-            self.stream = torch.cuda.Stream()

+            stream = torch.cuda.current_stream()
             # A small all_reduce for warmup.
             data = torch.zeros(1, device=device)
             self.all_reduce(data)
-            self.stream.synchronize()
+            stream.synchronize()
             del data

     def all_reduce(self,
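The warm-up now runs on the caller's current stream instead of a communicator-private torch.cuda.Stream(), and only that stream is synchronized. A small sketch of the same pattern outside the class (the `data += 1` stands in for the warm-up all-reduce):

import torch

def warmup(device: torch.device) -> None:
    with torch.cuda.device(device):
        stream = torch.cuda.current_stream()
        # A small operation for warmup, enqueued on the current stream.
        data = torch.zeros(1, device=device)
        data += 1
        stream.synchronize()  # wait only for this stream, not the whole device
        del data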
@@ -122,7 +120,7 @@ class PyNcclCommunicator:
         out_tensor = torch.empty_like(in_tensor)

         if stream is None:
-            stream = self.stream
+            stream = torch.cuda.current_stream()
         self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
                                 buffer_type(out_tensor.data_ptr()),
                                 in_tensor.numel(),
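Every collective and point-to-point method now follows the same convention: an explicit stream argument wins, otherwise the call is enqueued on the caller's current stream. That is the core of the change, condensed into a helper (hypothetical; the real methods inline the two lines shown in the hunks):

from typing import Optional
import torch

def resolve_stream(stream: Optional[torch.cuda.Stream] = None) -> torch.cuda.Stream:
    # None means "whatever stream the caller is currently using", which is
    # also what makes these calls capturable inside torch.cuda.graph().
    return stream if stream is not None else torch.cuda.current_stream()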
@@ -144,7 +142,7 @@ class PyNcclCommunicator:
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {input_tensor.device}")
         if stream is None:
-            stream = self.stream
+            stream = torch.cuda.current_stream()
         self.nccl.ncclAllGather(
             buffer_type(input_tensor.data_ptr()),
             buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
@@ -165,7 +163,7 @@ class PyNcclCommunicator:
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {input_tensor.device}")
         if stream is None:
-            stream = self.stream
+            stream = torch.cuda.current_stream()
         self.nccl.ncclReduceScatter(
             buffer_type(input_tensor.data_ptr()),
             buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
@@ -180,7 +178,7 @@ class PyNcclCommunicator:
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = self.stream
+            stream = torch.cuda.current_stream()
         self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
                            ncclDataTypeEnum.from_torch(tensor.dtype), dst,
                            self.comm, cudaStream_t(stream.cuda_stream))
@@ -192,7 +190,7 @@ class PyNcclCommunicator:
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = self.stream
+            stream = torch.cuda.current_stream()
         self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
                            ncclDataTypeEnum.from_torch(tensor.dtype), src,
                            self.comm, cudaStream_t(stream.cuda_stream))
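For send/recv (and the other wrappers), the chosen stream reaches NCCL as a raw handle: Stream.cuda_stream exposes the underlying cudaStream_t as an integer, which the ctypes layer wraps before calling into libnccl. A tiny sketch of that conversion, assuming cudaStream_t is a ctypes.c_void_p alias like the one used by the wrapper in this diff:

import ctypes
import torch

cudaStream_t = ctypes.c_void_p  # assumed alias, mirroring the pynccl wrapper

stream = torch.cuda.current_stream()
raw = cudaStream_t(stream.cuda_stream)  # integer handle -> ctypes pointer
print(stream.cuda_stream, raw)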
@@ -204,7 +202,7 @@ class PyNcclCommunicator:
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = self.stream
+            stream = torch.cuda.current_stream()
         if src == self.rank:
             sendbuff = buffer_type(tensor.data_ptr())
             # NCCL requires the sender also to have a receive buffer
@@ -217,9 +215,7 @@ class PyNcclCommunicator:
                                 self.comm, cudaStream_t(stream.cuda_stream))

     @contextmanager
-    def change_state(self,
-                     enable: Optional[bool] = None,
-                     stream: Optional[torch.cuda.Stream] = None):
+    def change_state(self, enable: Optional[bool] = None):
         """
         A context manager to change the state of the communicator.
         """
@@ -227,15 +223,9 @@ class PyNcclCommunicator:
             # guess a default value when not specified
             enable = self.available

-        if stream is None:
-            stream = self.stream
-
         old_disable = self.disabled
-        old_stream = self.stream

-        self.stream = stream
         self.disabled = not enable
         yield

         self.disabled = old_disable
-        self.stream = old_stream
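With the stream argument gone, change_state is reduced to toggling the disabled flag for the duration of a with block. A condensed sketch of that behavior on a bare object (a try/finally is added here so the sketch restores the flag even if the body raises; the hunk above restores it right after the yield):

from contextlib import contextmanager
from typing import Optional

class StateSketch:
    def __init__(self, available: bool = True):
        self.available = available
        self.disabled = True  # disabled by default, as in the communicator

    @contextmanager
    def change_state(self, enable: Optional[bool] = None):
        if enable is None:
            # guess a default value when not specified
            enable = self.available
        old_disable = self.disabled
        self.disabled = not enable
        try:
            yield
        finally:
            self.disabled = old_disable

Used as `with comm.change_state(enable=True): ...`, exactly like the CUDA-graph test at the top of this diff.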
@@ -310,8 +310,7 @@ class GroupCoordinator:
         if not pynccl_comm:
             maybe_pynccl_context = nullcontext()
         else:
-            maybe_pynccl_context = pynccl_comm.change_state(
-                stream=torch.cuda.current_stream())
+            maybe_pynccl_context = pynccl_comm.change_state()
         with maybe_pynccl_context:
             yield graph_capture_context
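On the GroupCoordinator side, graph capture no longer needs to tell the communicator which stream to use; it only enters change_state(), which enables the communicator when it is available. A sketch of how a caller would combine the two, assuming pynccl_comm is a PyNcclCommunicator instance or None:

from contextlib import nullcontext

def pynccl_capture_context(pynccl_comm):
    # With no stream to swap in, capture only needs the enable/disable toggle;
    # enable defaults to pynccl_comm.available inside change_state().
    if not pynccl_comm:
        return nullcontext()
    return pynccl_comm.change_state()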