[distributed] remove pynccl's redundant change_state (#11749)
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Union
|
||||
|
||||
# ===================== import region =====================
|
||||
@@ -213,19 +212,3 @@ class PyNcclCommunicator:
|
||||
self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype), src,
|
||||
self.comm, cudaStream_t(stream.cuda_stream))
|
||||
|
||||
@contextmanager
def change_state(self, enable: Optional[bool] = None):
    """
    Context manager that temporarily enables/disables this communicator.

    Args:
        enable: Desired enabled-state for the duration of the ``with``
            block. When ``None``, falls back to ``self.available``.

    Yields:
        None. While the block runs, ``self.disabled`` is ``not enable``.

    On exit the previous ``disabled`` flag is always restored — including
    when the with-body raises (the original version skipped restoration
    on exceptions, leaving the communicator stuck in the wrong state).
    """
    if enable is None:
        # guess a default value when not specified
        enable = self.available

    old_disable = self.disabled
    self.disabled = not enable
    try:
        yield
    finally:
        # Restore prior state unconditionally; without try/finally an
        # exception in the with-body would leak the temporary state.
        self.disabled = old_disable
|
||||
|
||||
@@ -305,14 +305,7 @@ class GroupCoordinator:
|
||||
stream.wait_stream(curr_stream)
|
||||
|
||||
with torch.cuda.stream(stream), maybe_ca_context:
|
||||
pynccl_comm = self.pynccl_comm
|
||||
maybe_pynccl_context: Any
|
||||
if not pynccl_comm:
|
||||
maybe_pynccl_context = nullcontext()
|
||||
else:
|
||||
maybe_pynccl_context = pynccl_comm.change_state()
|
||||
with maybe_pynccl_context:
|
||||
yield graph_capture_context
|
||||
yield graph_capture_context
|
||||
|
||||
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user