[Core][Distributed] refactor pynccl (#4591)
[Core][Distributed] refactor pynccl to hold multiple communicators (#4591)
This commit is contained in:
@@ -1,26 +1,4 @@
|
||||
# This file is a pure Python wrapper for the NCCL library.
|
||||
# The main purpose is to use NCCL combined with CUDA graph.
|
||||
# Before writing this script, we tried the following approach:
|
||||
# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself
|
||||
# often gets stuck when initializing the NCCL communicator.
|
||||
# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
|
||||
# contains many other potential cuda APIs, that are not allowed during
|
||||
# capturing the CUDA graph. For further details, please check
|
||||
# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
|
||||
#
|
||||
# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
|
||||
# doable, but we often encounter issues related with nccl versions, and need
|
||||
# to switch between different versions of NCCL. See
|
||||
# https://github.com/NVIDIA/nccl/issues/1234 for more details.
|
||||
# A C/C++ binding is not flexible enough to handle this. It requires
|
||||
# recompilation of the code every time we want to switch between different
|
||||
# versions. This current implementation, with a **pure** Python wrapper, is
|
||||
# more flexible. We can easily switch between different versions of NCCL by
|
||||
# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file`
|
||||
# variable in the code.
|
||||
|
||||
import ctypes
|
||||
import platform
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Union
|
||||
|
||||
# ===================== import region =====================
|
||||
@@ -28,217 +6,70 @@ import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup, ReduceOp
|
||||
|
||||
from vllm.distributed.device_communicators.pynccl_wrapper import (
|
||||
NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
|
||||
ncclRedOpTypeEnum, ncclUniqueId)
|
||||
from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import find_nccl_library, nccl_integrity_check
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
so_file = find_nccl_library()
|
||||
|
||||
try:
|
||||
# load the library in another process.
|
||||
# if it core dumps, it will not crash the current process
|
||||
nccl_integrity_check(so_file)
|
||||
nccl = ctypes.CDLL(so_file)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to load NCCL library from %s ."
|
||||
"It is expected if you are not running on NVIDIA/AMD GPUs."
|
||||
"Otherwise, the nccl library might not exist, be corrupted "
|
||||
"or it does not support the current platform %s."
|
||||
"One solution is to download libnccl2 version 2.18 from "
|
||||
"https://developer.download.nvidia.com/compute/cuda/repos/ "
|
||||
"and extract the libnccl.so.2 file. If you already have the "
|
||||
"library, please set the environment variable VLLM_NCCL_SO_PATH"
|
||||
" to point to the correct nccl library path.", so_file,
|
||||
platform.platform())
|
||||
raise e
|
||||
|
||||
# === export types and functions from nccl to Python ===
|
||||
# for the original nccl definition, please check
|
||||
# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in
|
||||
|
||||
ncclResult_t = ctypes.c_int
|
||||
|
||||
_c_ncclGetErrorString = nccl.ncclGetErrorString
|
||||
_c_ncclGetErrorString.restype = ctypes.c_char_p
|
||||
_c_ncclGetErrorString.argtypes = [ncclResult_t]
|
||||
|
||||
|
||||
def NCCL_CHECK(result: ncclResult_t) -> None:
|
||||
if result != 0:
|
||||
error_str = _c_ncclGetErrorString(result)
|
||||
error_str = error_str.decode("utf-8")
|
||||
raise RuntimeError(f"NCCL error: {error_str}")
|
||||
|
||||
|
||||
# equivalent to c declaration:
|
||||
# ncclResult_t ncclGetVersion(int *version);
|
||||
_c_ncclGetVersion = nccl.ncclGetVersion
|
||||
_c_ncclGetVersion.restype = ctypes.c_int
|
||||
_c_ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
|
||||
|
||||
|
||||
def ncclGetVersion() -> str:
|
||||
version = ctypes.c_int()
|
||||
NCCL_CHECK(_c_ncclGetVersion(ctypes.byref(version)))
|
||||
# something like 21903 --> "2.19.3"
|
||||
version_str = str(version.value)
|
||||
major = version_str[0].lstrip("0")
|
||||
minor = version_str[1:3].lstrip("0")
|
||||
patch = version_str[3:].lstrip("0")
|
||||
return f"{major}.{minor}.{patch}"
|
||||
|
||||
|
||||
class NcclUniqueId(ctypes.Structure):
|
||||
_fields_ = [("internal", ctypes.c_byte * 128)]
|
||||
|
||||
|
||||
# equivalent to c declaration:
|
||||
# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
|
||||
_c_ncclGetUniqueId = nccl.ncclGetUniqueId
|
||||
_c_ncclGetUniqueId.restype = ctypes.c_int
|
||||
_c_ncclGetUniqueId.argtypes = [ctypes.POINTER(NcclUniqueId)]
|
||||
|
||||
|
||||
def ncclGetUniqueId() -> NcclUniqueId:
|
||||
unique_id = NcclUniqueId()
|
||||
NCCL_CHECK(_c_ncclGetUniqueId(ctypes.byref(unique_id)))
|
||||
return unique_id
|
||||
|
||||
|
||||
# equivalent to c declaration:
|
||||
# ncclResult_t ncclCommInitRank(
|
||||
# ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
|
||||
# note that ncclComm_t is a pointer type, so the first argument
|
||||
# is a pointer to a pointer
|
||||
_c_ncclCommInitRank = nccl.ncclCommInitRank
|
||||
_c_ncclCommInitRank.restype = ctypes.c_int
|
||||
_c_ncclCommInitRank.argtypes = [
|
||||
ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, NcclUniqueId, ctypes.c_int
|
||||
]
|
||||
|
||||
ncclDataType_t = ctypes.c_int
|
||||
|
||||
|
||||
class ncclDataTypeEnum:
|
||||
ncclInt8 = 0
|
||||
ncclChar = 0
|
||||
ncclUint8 = 1
|
||||
ncclInt32 = 2
|
||||
ncclInt = 2
|
||||
ncclUint32 = 3
|
||||
ncclInt64 = 4
|
||||
ncclUint64 = 5
|
||||
ncclFloat16 = 6
|
||||
ncclHalf = 6
|
||||
ncclFloat32 = 7
|
||||
ncclFloat = 7
|
||||
ncclFloat64 = 8
|
||||
ncclDouble = 8
|
||||
ncclBfloat16 = 9
|
||||
ncclNumTypes = 10
|
||||
|
||||
@classmethod
|
||||
def from_torch(cls, dtype: torch.dtype) -> int:
|
||||
if dtype == torch.int8:
|
||||
return cls.ncclInt8
|
||||
if dtype == torch.uint8:
|
||||
return cls.ncclUint8
|
||||
if dtype == torch.int32:
|
||||
return cls.ncclInt32
|
||||
if dtype == torch.int64:
|
||||
return cls.ncclInt64
|
||||
if dtype == torch.float16:
|
||||
return cls.ncclFloat16
|
||||
if dtype == torch.float32:
|
||||
return cls.ncclFloat32
|
||||
if dtype == torch.float64:
|
||||
return cls.ncclFloat64
|
||||
if dtype == torch.bfloat16:
|
||||
return cls.ncclBfloat16
|
||||
raise ValueError(f"Unsupported dtype: {dtype}")
|
||||
|
||||
|
||||
ncclRedOp_t = ctypes.c_int
|
||||
|
||||
|
||||
class ncclRedOpTypeEnum:
|
||||
ncclSum = 0
|
||||
ncclProd = 1
|
||||
ncclMax = 2
|
||||
ncclMin = 3
|
||||
ncclAvg = 4
|
||||
ncclNumOps = 5
|
||||
|
||||
@classmethod
|
||||
def from_torch(cls, op: ReduceOp) -> int:
|
||||
if op == ReduceOp.SUM:
|
||||
return cls.ncclSum
|
||||
if op == ReduceOp.PRODUCT:
|
||||
return cls.ncclProd
|
||||
if op == ReduceOp.MAX:
|
||||
return cls.ncclMax
|
||||
if op == ReduceOp.MIN:
|
||||
return cls.ncclMin
|
||||
if op == ReduceOp.AVG:
|
||||
return cls.ncclAvg
|
||||
raise ValueError(f"Unsupported op: {op}")
|
||||
|
||||
|
||||
# equivalent to c declaration:
|
||||
# ncclResult_t ncclAllReduce(
|
||||
# const void* sendbuff, void* recvbuff, size_t count,
|
||||
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
|
||||
# udaStream_t stream);
|
||||
# note that cudaStream_t is a pointer type, so the last argument is a pointer
|
||||
_c_ncclAllReduce = nccl.ncclAllReduce
|
||||
_c_ncclAllReduce.restype = ctypes.c_int
|
||||
_c_ncclAllReduce.argtypes = [
|
||||
ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclRedOp_t,
|
||||
ncclDataType_t, ctypes.c_void_p, ctypes.c_void_p
|
||||
]
|
||||
|
||||
# be cautious! this is a collective call, it will block until all
|
||||
# processes in the communicator have called this function.
|
||||
# because Python object destruction can happen in random order,
|
||||
# it is better not to call it at all.
|
||||
# equivalent to c declaration:
|
||||
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
|
||||
_c_ncclCommDestroy = nccl.ncclCommDestroy
|
||||
_c_ncclCommDestroy.restype = ctypes.c_int
|
||||
_c_ncclCommDestroy.argtypes = [ctypes.c_void_p]
|
||||
|
||||
|
||||
class NCCLCommunicator:
|
||||
class PyNcclCommunicator:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
group: Optional[ProcessGroup] = None,
|
||||
device: Optional[Union[int, str, torch.device]] = None,
|
||||
library_path: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
group: the process group to work on. If None, it will use the
|
||||
default process group.
|
||||
device: the device to bind the NCCLCommunicator to. If None,
|
||||
device: the device to bind the PyNcclCommunicator to. If None,
|
||||
it will be bind to f"cuda:{local_rank}".
|
||||
library_path: the path to the NCCL library. If None, it will
|
||||
use the default library path.
|
||||
It is the caller's responsibility to make sure each communicator
|
||||
is bind to a unique device.
|
||||
"""
|
||||
assert dist.is_initialized()
|
||||
group = get_cpu_world_group() if group is None else group
|
||||
assert dist.get_backend(group) != dist.Backend.NCCL, (
|
||||
"NCCLCommunicator should be attached to a non-NCCL group.")
|
||||
"PyNcclCommunicator should be attached to a non-NCCL group.")
|
||||
self.group = group
|
||||
# note: this rank is the rank in the group
|
||||
self.rank = dist.get_rank(group)
|
||||
self.world_size = dist.get_world_size(group)
|
||||
|
||||
# if world_size == 1, no need to create communicator
|
||||
if self.world_size == 1:
|
||||
self.available = False
|
||||
self.disabled = True
|
||||
self.stream = None
|
||||
return
|
||||
try:
|
||||
self.nccl = NCCLLibrary(library_path)
|
||||
except Exception:
|
||||
# disable because of missing NCCL library
|
||||
# e.g. in a non-GPU environment
|
||||
self.available = False
|
||||
self.disabled = True
|
||||
self.stream = None
|
||||
return
|
||||
|
||||
self.available = True
|
||||
self.disabled = False
|
||||
|
||||
logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion())
|
||||
|
||||
if self.rank == 0:
|
||||
self.unique_id = ncclGetUniqueId()
|
||||
# get the unique id from NCCL
|
||||
self.unique_id = self.nccl.ncclGetUniqueId()
|
||||
else:
|
||||
self.unique_id = NcclUniqueId()
|
||||
# construct an empty unique id
|
||||
self.unique_id = ncclUniqueId()
|
||||
tensor = torch.ByteTensor(list(self.unique_id.internal))
|
||||
ranks = dist.get_process_group_ranks(group)
|
||||
# arg `src` in `broadcast` is the global rank
|
||||
@@ -246,7 +77,6 @@ class NCCLCommunicator:
|
||||
byte_list = tensor.tolist()
|
||||
for i, byte in enumerate(byte_list):
|
||||
self.unique_id.internal[i] = byte
|
||||
self.comm = ctypes.c_void_p()
|
||||
if device is None:
|
||||
local_rank = get_local_rank()
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
@@ -261,15 +91,25 @@ class NCCLCommunicator:
|
||||
# `torch.cuda.device` is a context manager that changes the
|
||||
# current cuda device to the specified one
|
||||
with torch.cuda.device(device):
|
||||
NCCL_CHECK(
|
||||
_c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
|
||||
self.unique_id, self.rank))
|
||||
self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
|
||||
self.world_size, self.unique_id, self.rank)
|
||||
self.stream = torch.cuda.Stream()
|
||||
|
||||
# A small all_reduce for warmup.
|
||||
self.all_reduce(torch.zeros(1, device=device))
|
||||
self.stream.synchronize()
|
||||
|
||||
# by default it is disabled, e.g. in profiling models and prefill phase.
|
||||
# to use it, use under `with obj.change_state(enable=True)`, usually
|
||||
# when we are using CUDA graph.
|
||||
self.disabled = True
|
||||
|
||||
def all_reduce(self,
|
||||
tensor: torch.Tensor,
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
stream=None):
|
||||
if self.disabled:
|
||||
return
|
||||
# nccl communicator created on a specific device
|
||||
# will only work on tensors on the same device
|
||||
# otherwise it will cause "illegal memory access"
|
||||
@@ -278,10 +118,32 @@ class NCCLCommunicator:
|
||||
f"but the input tensor is on {tensor.device}")
|
||||
if stream is None:
|
||||
stream = self.stream
|
||||
NCCL_CHECK(
|
||||
_c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
|
||||
ctypes.c_void_p(tensor.data_ptr()),
|
||||
tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype),
|
||||
ncclRedOpTypeEnum.from_torch(op), self.comm,
|
||||
ctypes.c_void_p(stream.cuda_stream)))
|
||||
self.nccl.ncclAllReduce(buffer_type(tensor.data_ptr()),
|
||||
buffer_type(tensor.data_ptr()), tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype),
|
||||
ncclRedOpTypeEnum.from_torch(op), self.comm,
|
||||
cudaStream_t(stream.cuda_stream))
|
||||
|
||||
@contextmanager
|
||||
def change_state(self,
|
||||
enable: Optional[bool] = None,
|
||||
stream: Optional[torch.cuda.Stream] = None):
|
||||
"""
|
||||
A context manager to change the state of the communicator.
|
||||
"""
|
||||
if enable is None:
|
||||
# guess a default value when not specified
|
||||
enable = self.available
|
||||
|
||||
if stream is None:
|
||||
stream = self.stream
|
||||
|
||||
old_disable = self.disabled
|
||||
old_stream = self.stream
|
||||
|
||||
self.stream = stream
|
||||
self.disabled = not enable
|
||||
yield
|
||||
|
||||
self.disabled = old_disable
|
||||
self.stream = old_stream
|
||||
|
||||
@@ -1,66 +0,0 @@
|
||||
import contextlib
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup, ReduceOp
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
try:
|
||||
from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
|
||||
ncclGetVersion)
|
||||
except Exception as e:
|
||||
# in non-NVIDIA environments, we can't import the nccl module
|
||||
# e.g. when running on machines with AMD GPUs
|
||||
logger.info("Failed to import NCCL library: %s", e)
|
||||
logger.info("It is expected if you are not running on NVIDIA GPUs.")
|
||||
pass
|
||||
|
||||
comm: Optional["NCCLCommunicator"] = None
|
||||
|
||||
|
||||
def is_initialized() -> bool:
|
||||
"""Returns whether the NCCL backend is initialized."""
|
||||
return comm is not None
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_pynccl_stream(stream: torch.cuda.Stream):
|
||||
"""Set the cuda stream for communication"""
|
||||
try:
|
||||
assert comm is not None
|
||||
comm.stream = stream
|
||||
yield
|
||||
finally:
|
||||
pass
|
||||
|
||||
|
||||
def init_process_group(group: Optional[ProcessGroup] = None) -> None:
|
||||
assert not is_initialized()
|
||||
global comm
|
||||
logger.info("vLLM is using nccl==%s", ncclGetVersion())
|
||||
comm = NCCLCommunicator(group=group)
|
||||
|
||||
|
||||
def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
|
||||
"""All-reduces the input tensor across the process group."""
|
||||
assert input_.is_cuda, f"{input_} should be a cuda tensor"
|
||||
assert comm is not None
|
||||
comm.all_reduce(input_, op)
|
||||
|
||||
|
||||
def destroy_process_group() -> None:
|
||||
global comm
|
||||
comm = None
|
||||
|
||||
|
||||
def get_world_size() -> int:
|
||||
"""Returns the world size."""
|
||||
assert comm is not None
|
||||
return comm.world_size
|
||||
|
||||
|
||||
def get_nccl_backend() -> Optional["NCCLCommunicator"]:
|
||||
return comm
|
||||
258
vllm/distributed/device_communicators/pynccl_wrapper.py
Normal file
258
vllm/distributed/device_communicators/pynccl_wrapper.py
Normal file
@@ -0,0 +1,258 @@
|
||||
# This file is a pure Python wrapper for the NCCL library.
|
||||
# The main purpose is to use NCCL combined with CUDA graph.
|
||||
# Before writing this script, we tried the following approach:
|
||||
# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself
|
||||
# often gets stuck when initializing the NCCL communicator.
|
||||
# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
|
||||
# contains many other potential cuda APIs, that are not allowed during
|
||||
# capturing the CUDA graph. For further details, please check
|
||||
# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
|
||||
#
|
||||
# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
|
||||
# doable, but we often encounter issues related with nccl versions, and need
|
||||
# to switch between different versions of NCCL. See
|
||||
# https://github.com/NVIDIA/nccl/issues/1234 for more details.
|
||||
# A C/C++ binding is not flexible enough to handle this. It requires
|
||||
# recompilation of the code every time we want to switch between different
|
||||
# versions. This current implementation, with a **pure** Python wrapper, is
|
||||
# more flexible. We can easily switch between different versions of NCCL by
|
||||
# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file`
|
||||
# variable in the code.
|
||||
|
||||
import ctypes
|
||||
import platform
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed import ReduceOp
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import find_nccl_library, nccl_integrity_check
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# === export types and functions from nccl to Python ===
|
||||
# for the original nccl definition, please check
|
||||
# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in
|
||||
|
||||
ncclResult_t = ctypes.c_int
|
||||
ncclComm_t = ctypes.c_void_p
|
||||
|
||||
|
||||
class ncclUniqueId(ctypes.Structure):
|
||||
_fields_ = [("internal", ctypes.c_byte * 128)]
|
||||
|
||||
|
||||
cudaStream_t = ctypes.c_void_p
|
||||
buffer_type = ctypes.c_void_p
|
||||
|
||||
ncclDataType_t = ctypes.c_int
|
||||
|
||||
|
||||
class ncclDataTypeEnum:
|
||||
ncclInt8 = 0
|
||||
ncclChar = 0
|
||||
ncclUint8 = 1
|
||||
ncclInt32 = 2
|
||||
ncclInt = 2
|
||||
ncclUint32 = 3
|
||||
ncclInt64 = 4
|
||||
ncclUint64 = 5
|
||||
ncclFloat16 = 6
|
||||
ncclHalf = 6
|
||||
ncclFloat32 = 7
|
||||
ncclFloat = 7
|
||||
ncclFloat64 = 8
|
||||
ncclDouble = 8
|
||||
ncclBfloat16 = 9
|
||||
ncclNumTypes = 10
|
||||
|
||||
@classmethod
|
||||
def from_torch(cls, dtype: torch.dtype) -> int:
|
||||
if dtype == torch.int8:
|
||||
return cls.ncclInt8
|
||||
if dtype == torch.uint8:
|
||||
return cls.ncclUint8
|
||||
if dtype == torch.int32:
|
||||
return cls.ncclInt32
|
||||
if dtype == torch.int64:
|
||||
return cls.ncclInt64
|
||||
if dtype == torch.float16:
|
||||
return cls.ncclFloat16
|
||||
if dtype == torch.float32:
|
||||
return cls.ncclFloat32
|
||||
if dtype == torch.float64:
|
||||
return cls.ncclFloat64
|
||||
if dtype == torch.bfloat16:
|
||||
return cls.ncclBfloat16
|
||||
raise ValueError(f"Unsupported dtype: {dtype}")
|
||||
|
||||
|
||||
ncclRedOp_t = ctypes.c_int
|
||||
|
||||
|
||||
class ncclRedOpTypeEnum:
|
||||
ncclSum = 0
|
||||
ncclProd = 1
|
||||
ncclMax = 2
|
||||
ncclMin = 3
|
||||
ncclAvg = 4
|
||||
ncclNumOps = 5
|
||||
|
||||
@classmethod
|
||||
def from_torch(cls, op: ReduceOp) -> int:
|
||||
if op == ReduceOp.SUM:
|
||||
return cls.ncclSum
|
||||
if op == ReduceOp.PRODUCT:
|
||||
return cls.ncclProd
|
||||
if op == ReduceOp.MAX:
|
||||
return cls.ncclMax
|
||||
if op == ReduceOp.MIN:
|
||||
return cls.ncclMin
|
||||
if op == ReduceOp.AVG:
|
||||
return cls.ncclAvg
|
||||
raise ValueError(f"Unsupported op: {op}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class Function:
|
||||
name: str
|
||||
restype: Any
|
||||
argtypes: List[Any]
|
||||
|
||||
|
||||
class NCCLLibrary:
|
||||
exported_functions = [
|
||||
# const char* ncclGetErrorString(ncclResult_t result)
|
||||
Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]),
|
||||
# ncclResult_t ncclGetVersion(int *version);
|
||||
Function("ncclGetVersion", ncclResult_t,
|
||||
[ctypes.POINTER(ctypes.c_int)]),
|
||||
# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
|
||||
Function("ncclGetUniqueId", ncclResult_t,
|
||||
[ctypes.POINTER(ncclUniqueId)]),
|
||||
# ncclResult_t ncclCommInitRank(
|
||||
# ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
|
||||
# note that ncclComm_t is a pointer type, so the first argument
|
||||
# is a pointer to a pointer
|
||||
Function("ncclCommInitRank", ncclResult_t, [
|
||||
ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId,
|
||||
ctypes.c_int
|
||||
]),
|
||||
# ncclResult_t ncclAllReduce(
|
||||
# const void* sendbuff, void* recvbuff, size_t count,
|
||||
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
|
||||
# cudaStream_t stream);
|
||||
# note that cudaStream_t is a pointer type, so the last argument
|
||||
# is a pointer
|
||||
Function("ncclAllReduce", ncclResult_t, [
|
||||
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
|
||||
ncclRedOp_t, ncclComm_t, cudaStream_t
|
||||
]),
|
||||
|
||||
# be cautious! this is a collective call, it will block until all
|
||||
# processes in the communicator have called this function.
|
||||
# because Python object destruction can happen in random order,
|
||||
# it is better not to call it at all.
|
||||
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
|
||||
Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
|
||||
]
|
||||
|
||||
# class attribute to store the mapping from the path to the library
|
||||
# to avoid loading the same library multiple times
|
||||
path_to_library_cache: Dict[str, Any] = {}
|
||||
|
||||
# class attribute to store the mapping from library path
|
||||
# to the corresponding dictionary
|
||||
path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
def __init__(self, so_file: Optional[str] = None):
|
||||
|
||||
so_file = so_file or find_nccl_library()
|
||||
|
||||
try:
|
||||
# load the library in another process.
|
||||
# if it core dumps, it will not crash the current process
|
||||
nccl_integrity_check(so_file)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to load NCCL library from %s ."
|
||||
"It is expected if you are not running on NVIDIA/AMD GPUs."
|
||||
"Otherwise, the nccl library might not exist, be corrupted "
|
||||
"or it does not support the current platform %s."
|
||||
"One solution is to download libnccl2 version 2.18 from "
|
||||
"https://developer.download.nvidia.com/compute/cuda/repos/ "
|
||||
"and extract the libnccl.so.2 file. If you already have the "
|
||||
"library, please set the environment variable VLLM_NCCL_SO_PATH"
|
||||
" to point to the correct nccl library path.", so_file,
|
||||
platform.platform())
|
||||
raise e
|
||||
|
||||
if so_file not in NCCLLibrary.path_to_dict_mapping:
|
||||
lib = ctypes.CDLL(so_file)
|
||||
NCCLLibrary.path_to_library_cache[so_file] = lib
|
||||
self.lib = NCCLLibrary.path_to_library_cache[so_file]
|
||||
|
||||
if so_file not in NCCLLibrary.path_to_dict_mapping:
|
||||
_funcs = {}
|
||||
for func in NCCLLibrary.exported_functions:
|
||||
f = getattr(self.lib, func.name)
|
||||
f.restype = func.restype
|
||||
f.argtypes = func.argtypes
|
||||
_funcs[func.name] = f
|
||||
NCCLLibrary.path_to_dict_mapping[so_file] = _funcs
|
||||
self._funcs = NCCLLibrary.path_to_dict_mapping[so_file]
|
||||
|
||||
def ncclGetErrorString(self, result: ncclResult_t) -> str:
|
||||
return self._funcs["ncclGetErrorString"](result).decode("utf-8")
|
||||
|
||||
def NCCL_CHECK(self, result: ncclResult_t) -> None:
|
||||
if result != 0:
|
||||
error_str = self.ncclGetErrorString(result)
|
||||
raise RuntimeError(f"NCCL error: {error_str}")
|
||||
|
||||
def ncclGetVersion(self) -> str:
|
||||
version = ctypes.c_int()
|
||||
self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
|
||||
version_str = str(version.value)
|
||||
# something like 21903 --> "2.19.3"
|
||||
major = version_str[0].lstrip("0")
|
||||
minor = version_str[1:3].lstrip("0")
|
||||
patch = version_str[3:].lstrip("0")
|
||||
return f"{major}.{minor}.{patch}"
|
||||
|
||||
def ncclGetUniqueId(self) -> ncclUniqueId:
|
||||
unique_id = ncclUniqueId()
|
||||
self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](
|
||||
ctypes.byref(unique_id)))
|
||||
return unique_id
|
||||
|
||||
def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId,
|
||||
rank: int) -> ncclComm_t:
|
||||
comm = ncclComm_t()
|
||||
self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm),
|
||||
world_size, unique_id,
|
||||
rank))
|
||||
return comm
|
||||
|
||||
def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
|
||||
count: int, datatype: int, op: int, comm: ncclComm_t,
|
||||
stream: cudaStream_t) -> None:
|
||||
# `datatype` actually should be `ncclDataType_t`
|
||||
# and `op` should be `ncclRedOp_t`
|
||||
# both are aliases of `ctypes.c_int`
|
||||
# when we pass int to a function, it will be converted to `ctypes.c_int`
|
||||
# by ctypes automatically
|
||||
self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count,
|
||||
datatype, op, comm,
|
||||
stream))
|
||||
|
||||
def ncclCommDestroy(self, comm: ncclComm_t) -> None:
|
||||
self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
|
||||
|
||||
|
||||
__all__ = [
|
||||
"NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId",
|
||||
"ncclComm_t", "cudaStream_t", "buffer_type"
|
||||
]
|
||||
Reference in New Issue
Block a user