[Mypy] Part 3 fix typing for nested directories for most of directory (#4161)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from typing import List, Optional
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -18,7 +18,7 @@ except ImportError:
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_CA_HANDLE = None
|
||||
_CA_HANDLE: Optional["CustomAllreduce"] = None
|
||||
_IS_CAPTURING = False
|
||||
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
|
||||
|
||||
@@ -51,7 +51,7 @@ def init_custom_ar() -> None:
|
||||
"Cannot test GPU P2P because not all GPUs are visible to the "
|
||||
"current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
|
||||
" is set.")
|
||||
return False
|
||||
return
|
||||
# test nvlink first, this will filter out most of the cases
|
||||
# where custom allreduce is not supported
|
||||
if "CUDA_VISIBLE_DEVICES" in os.environ:
|
||||
@@ -117,7 +117,7 @@ def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
ca_handle = get_handle()
|
||||
# when custom allreduce is disabled, this will be None
|
||||
if ca_handle is None:
|
||||
return
|
||||
return None
|
||||
if is_capturing():
|
||||
if torch.cuda.is_current_stream_capturing():
|
||||
if ca_handle.should_custom_ar(input):
|
||||
@@ -135,6 +135,8 @@ def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
if ca_handle.should_custom_ar(input):
|
||||
return ca_handle.all_reduce_unreg(input)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _nvml():
|
||||
@@ -224,14 +226,14 @@ class CustomAllreduce:
|
||||
return self._gather_ipc_meta(shard_data)
|
||||
|
||||
def _gather_ipc_meta(self, shard_data):
|
||||
all_data = [None] * self.world_size
|
||||
all_data: List[Optional[Any]] = [None] * self.world_size
|
||||
dist.all_gather_object(all_data, shard_data)
|
||||
|
||||
handles = []
|
||||
offsets = []
|
||||
for i in range(len(all_data)):
|
||||
handles.append(all_data[i][0])
|
||||
offsets.append(all_data[i][1])
|
||||
handles.append(all_data[i][0]) # type: ignore
|
||||
offsets.append(all_data[i][1]) # type: ignore
|
||||
return handles, offsets
|
||||
|
||||
def register_buffer(self, inp: torch.Tensor):
|
||||
|
||||
@@ -107,9 +107,10 @@ _c_ncclCommInitRank.argtypes = [
|
||||
ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, NcclUniqueId, ctypes.c_int
|
||||
]
|
||||
|
||||
ncclDataType_t = ctypes.c_int
|
||||
|
||||
# enums
|
||||
class ncclDataType_t(ctypes.c_int):
|
||||
|
||||
class ncclDataTypeEnum:
|
||||
ncclInt8 = 0
|
||||
ncclChar = 0
|
||||
ncclUint8 = 1
|
||||
@@ -128,7 +129,7 @@ class ncclDataType_t(ctypes.c_int):
|
||||
ncclNumTypes = 10
|
||||
|
||||
@classmethod
|
||||
def from_torch(cls, dtype: torch.dtype) -> 'ncclDataType_t':
|
||||
def from_torch(cls, dtype: torch.dtype) -> int:
|
||||
if dtype == torch.int8:
|
||||
return cls.ncclInt8
|
||||
if dtype == torch.uint8:
|
||||
@@ -148,7 +149,10 @@ class ncclDataType_t(ctypes.c_int):
|
||||
raise ValueError(f"Unsupported dtype: {dtype}")
|
||||
|
||||
|
||||
class ncclRedOp_t(ctypes.c_int):
|
||||
ncclRedOp_t = ctypes.c_int
|
||||
|
||||
|
||||
class ncclRedOpTypeEnum:
|
||||
ncclSum = 0
|
||||
ncclProd = 1
|
||||
ncclMax = 2
|
||||
@@ -157,7 +161,7 @@ class ncclRedOp_t(ctypes.c_int):
|
||||
ncclNumOps = 5
|
||||
|
||||
@classmethod
|
||||
def from_torch(cls, op: ReduceOp) -> 'ncclRedOp_t':
|
||||
def from_torch(cls, op: ReduceOp) -> int:
|
||||
if op == ReduceOp.SUM:
|
||||
return cls.ncclSum
|
||||
if op == ReduceOp.PRODUCT:
|
||||
@@ -180,8 +184,8 @@ class ncclRedOp_t(ctypes.c_int):
|
||||
_c_ncclAllReduce = nccl.ncclAllReduce
|
||||
_c_ncclAllReduce.restype = ctypes.c_int
|
||||
_c_ncclAllReduce.argtypes = [
|
||||
ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclDataType_t,
|
||||
ncclRedOp_t, ctypes.c_void_p, ctypes.c_void_p
|
||||
ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclRedOp_t,
|
||||
ncclDataType_t, ctypes.c_void_p, ctypes.c_void_p
|
||||
]
|
||||
|
||||
# equivalent to c declaration:
|
||||
@@ -251,8 +255,8 @@ class NCCLCommunicator:
|
||||
result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
|
||||
ctypes.c_void_p(tensor.data_ptr()),
|
||||
tensor.numel(),
|
||||
ncclDataType_t.from_torch(tensor.dtype),
|
||||
ncclRedOp_t.from_torch(op), self.comm,
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype),
|
||||
ncclRedOpTypeEnum.from_torch(op), self.comm,
|
||||
ctypes.c_void_p(stream.cuda_stream))
|
||||
assert result == 0
|
||||
|
||||
|
||||
@@ -30,6 +30,7 @@ def is_initialized() -> bool:
|
||||
def set_pynccl_stream(stream: torch.cuda.Stream):
|
||||
"""Set the cuda stream for communication"""
|
||||
try:
|
||||
assert comm is not None
|
||||
comm.stream = stream
|
||||
yield
|
||||
finally:
|
||||
@@ -52,6 +53,7 @@ def init_process_group(world_size: int,
|
||||
def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
|
||||
"""All-reduces the input tensor across the process group."""
|
||||
assert input_.is_cuda, f"{input_} should be a cuda tensor"
|
||||
assert comm is not None
|
||||
comm.all_reduce(input_, op)
|
||||
|
||||
|
||||
@@ -62,8 +64,9 @@ def destroy_process_group() -> None:
|
||||
|
||||
def get_world_size() -> int:
|
||||
"""Returns the world size."""
|
||||
assert comm is not None
|
||||
return comm.world_size
|
||||
|
||||
|
||||
def get_nccl_backend():
|
||||
def get_nccl_backend() -> Optional["NCCLCommunicator"]:
|
||||
return comm
|
||||
|
||||
Reference in New Issue
Block a user