[distributed] remove pynccl's redundant change_state (#11749)

This commit is contained in:
cennn
2025-01-06 09:05:48 +08:00
committed by GitHub
parent 33fc1e2e86
commit 9e764e7b10
3 changed files with 28 additions and 62 deletions

View File

@@ -1,4 +1,3 @@
from contextlib import contextmanager
from typing import Optional, Union
# ===================== import region =====================
@@ -213,19 +212,3 @@ class PyNcclCommunicator:
self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), src,
self.comm, cudaStream_t(stream.cuda_stream))
@contextmanager
def change_state(self, enable: Optional[bool] = None):
    """
    A context manager to change the state of the communicator.

    Temporarily sets the communicator's ``disabled`` flag for the
    duration of the ``with`` body, restoring the prior flag on normal
    exit. When *enable* is not given, the communicator's ``available``
    attribute is used as the default.
    """
    # No explicit request: fall back to whether the communicator
    # reports itself as available.
    target = self.available if enable is None else enable
    previous = self.disabled
    self.disabled = not target
    yield
    # NOTE: restored only on normal exit (no try/finally), matching
    # the original behavior exactly.
    self.disabled = previous

View File

@@ -305,14 +305,7 @@ class GroupCoordinator:
stream.wait_stream(curr_stream)
with torch.cuda.stream(stream), maybe_ca_context:
pynccl_comm = self.pynccl_comm
maybe_pynccl_context: Any
if not pynccl_comm:
maybe_pynccl_context = nullcontext()
else:
maybe_pynccl_context = pynccl_comm.change_state()
with maybe_pynccl_context:
yield graph_capture_context
yield graph_capture_context
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
"""