[Core][Distributed] refactor custom allreduce to support multiple tp groups (#4754)

This commit is contained in:
youkaichao
2024-05-12 17:47:59 -07:00
committed by GitHub
parent a7be4d0072
commit 702bee461f
10 changed files with 327 additions and 226 deletions

View File

@@ -1,155 +1,43 @@
from contextlib import contextmanager
from typing import Any, List, Optional
from typing import Any, List, Optional, Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
import vllm.envs as envs
from vllm.distributed.parallel_state import (
get_local_rank, get_tensor_model_parallel_cpu_group)
from vllm.logger import init_logger
try:
import pynvml
from vllm._C import custom_ar
@contextmanager
def _nvml():
try:
pynvml.nvmlInit()
yield
finally:
pynvml.nvmlShutdown()
except ImportError:
# For AMD GPUs
custom_ar = None
pynvml = None
@contextmanager
def _nvml():
try:
yield
finally:
pass
logger = init_logger(__name__)
_CA_HANDLE: Optional["CustomAllreduce"] = None
_IS_CAPTURING = False
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
def init_custom_ar() -> None:
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
global _CA_HANDLE
if _CA_HANDLE is not None:
return
rank = get_tensor_model_parallel_rank()
world_size = get_tensor_model_parallel_world_size()
if world_size == 1:
# No need to initialize custom allreduce for single GPU case.
return
if world_size not in _SUPPORTED_WORLD_SIZES:
logger.warning(
"Custom allreduce is disabled due to an unsupported world size: "
"%d. Supported world sizes: %s. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly.", world_size,
str(_SUPPORTED_WORLD_SIZES))
return
num_dev = torch.cuda.device_count()
# note: num dev can be larger than world_size if we're only using
# first few GPUs
if num_dev < world_size:
logger.warning(
"Cannot test GPU P2P because not all GPUs are visible to the "
"current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
" is set.")
return
# we only use a subset of GPUs here
# so we only need to check the nvlink connectivity of these GPUs
num_dev = world_size
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices:
device_ids = list(map(int, cuda_visible_devices.split(",")))
else:
device_ids = list(range(num_dev))
# this checks hardware and driver support for NVLink
full_nvlink = _is_full_nvlink(device_ids)
if world_size > 2 and not full_nvlink:
logger.warning(
"Custom allreduce is disabled because it's not supported on more"
" than two PCIe-only GPUs. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly.")
return
# test P2P capability, this checks software/cudaruntime support
# this is expensive to compute at the first time
# then we cache the result
if not _can_p2p(rank, world_size):
logger.warning(
"Custom allreduce is disabled because your platform lacks GPU P2P"
" capability or P2P test failed. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly.")
return
_CA_HANDLE = CustomAllreduce(rank, world_size, full_nvlink)
def begin_capture() -> None:
global _IS_CAPTURING
_IS_CAPTURING = True
def end_capture() -> None:
global _IS_CAPTURING
_IS_CAPTURING = False
def is_capturing() -> bool:
return _IS_CAPTURING and _CA_HANDLE is not None
def get_handle() -> Optional["CustomAllreduce"]:
return _CA_HANDLE
def is_initialized() -> bool:
return _CA_HANDLE is not None
@contextmanager
def capture():
try:
begin_capture()
yield
finally:
end_capture()
handle = get_handle()
if handle is not None:
handle.register_graph_buffers()
def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]:
ca_handle = get_handle()
# when custom allreduce is disabled, this will be None
if ca_handle is None:
return None
if is_capturing():
if torch.cuda.is_current_stream_capturing():
if ca_handle.should_custom_ar(input):
return ca_handle.all_reduce_reg(input)
else:
if ca_handle.should_custom_ar(input):
# if warm up, mimic the allocation pattern
# since custom allreduce is out-of-place
return torch.empty_like(input)
else:
# note: outside of cuda graph context,
# custom allreduce incurs a cost of cudaMemcpy, which should
# be small(<=1% of overall latency) compared to the performance
# gains of using custom kernels
if ca_handle.should_custom_ar(input):
return ca_handle.all_reduce_unreg(input)
return None
@contextmanager
def _nvml():
try:
pynvml.nvmlInit()
yield
finally:
pynvml.nvmlShutdown()
@_nvml()
def _is_full_nvlink(device_ids: List[int]) -> bool:
@@ -188,22 +76,112 @@ def _can_p2p(rank: int, world_size: int) -> bool:
class CustomAllreduce:
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
# max_size: max supported allreduce size
def __init__(self,
rank,
world_size,
full_nvlink,
group: Optional[ProcessGroup] = None,
device: Optional[Union[int, str, torch.device]] = None,
max_size=8192 * 1024) -> None:
"""
Args:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the CustomAllreduce to. If None,
it will be bind to f"cuda:{local_rank}".
It is the caller's responsibility to make sure each communicator
is bind to a unique device, and all communicators in this group
are in the same node.
"""
self._IS_CAPTURING = False
self.disabled = True
if custom_ar is None:
# disable because of missing custom allreduce library
# e.g. in a non-cuda environment
return
group = group or get_tensor_model_parallel_cpu_group()
self.group = group
assert dist.get_backend(group) != dist.Backend.NCCL, (
"CustomAllreduce should be attached to a non-NCCL group.")
rank = dist.get_rank(group=self.group)
world_size = dist.get_world_size(group=self.group)
if world_size == 1:
# No need to initialize custom allreduce for single GPU case.
return
if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES:
logger.warning(
"Custom allreduce is disabled due to an unsupported world"
" size: %d. Supported world sizes: %s. To silence this "
"warning, specify disable_custom_all_reduce=True explicitly.",
world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES))
return
if device is None:
local_rank = get_local_rank()
device = torch.device(f"cuda:{local_rank}")
elif isinstance(device, int):
device = torch.device(f"cuda:{device}")
elif isinstance(device, str):
device = torch.device(device)
# now `device` is a `torch.device` object
assert isinstance(device, torch.device)
self.device = device
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices:
device_ids = list(map(int, cuda_visible_devices.split(",")))
else:
device_ids = list(range(torch.cuda.device_count()))
physical_device_id = device_ids[device.index]
tensor = torch.tensor([physical_device_id],
dtype=torch.int,
device="cpu")
gather_list = [
torch.tensor([0], dtype=torch.int, device="cpu")
for _ in range(world_size)
]
dist.all_gather(gather_list, tensor, group=self.group)
physical_device_ids = [t.item() for t in gather_list]
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
# this checks hardware and driver support for NVLink
full_nvlink = _is_full_nvlink(physical_device_ids)
if world_size > 2 and not full_nvlink:
logger.warning(
"Custom allreduce is disabled because it's not supported on"
" more than two PCIe-only GPUs. To silence this warning, "
"specify disable_custom_all_reduce=True explicitly.")
return
# test P2P capability, this checks software/cudaruntime support
# this is expensive to compute at the first time
# then we cache the result
if not _can_p2p(rank, world_size):
logger.warning(
"Custom allreduce is disabled because your platform lacks "
"GPU P2P capability or P2P test failed. To silence this "
"warning, specify disable_custom_all_reduce=True explicitly.")
return
self.disabled = False
# buffers memory are owned by this Python class and passed to C++
# meta data composes of two parts: meta data for synchronization
# (256 bytes) and a temporary buffer for storing intermediate
# allreduce results.
self.meta = torch.zeros(custom_ar.meta_size() + max_size,
dtype=torch.uint8,
device="cuda")
device=self.device)
# This is a pre-registered IPC buffer. In eager mode, input tensors
# are first copied into this buffer before allreduce is performed
self.buffer = torch.empty(max_size, dtype=torch.uint8, device="cuda")
self.buffer = torch.empty(max_size,
dtype=torch.uint8,
device=self.device)
# This is a buffer for storing the tuples of pointers pointing to
# IPC buffers from all ranks. Each registered tuple has size of
# 8*world_size bytes where world_size is at most 8. Allocating 8MB
@@ -211,8 +189,9 @@ class CustomAllreduce:
# needs less than 10000 of registered tuples.
self.rank_data = torch.empty(8 * 1024 * 1024,
dtype=torch.uint8,
device="cuda")
device=self.device)
self.max_size = max_size
self.rank = rank
self.world_size = world_size
handles, offsets = self._get_ipc_meta(self.meta)
self.full_nvlink = full_nvlink
@@ -221,6 +200,21 @@ class CustomAllreduce:
self.full_nvlink)
self.register_buffer(self.buffer)
@contextmanager
def capture(self):
"""
The main responsibility of this context manager is the
`register_graph_buffers` call at the end of the context.
It records all the buffer addresses used in the CUDA graph.
"""
try:
self._IS_CAPTURING = True
yield
finally:
self._IS_CAPTURING = False
if not self.disabled:
self.register_graph_buffers()
def _get_ipc_meta(self, inp: torch.Tensor):
data = inp.untyped_storage()._share_cuda_()
shard_data = (
@@ -230,14 +224,29 @@ class CustomAllreduce:
return self._gather_ipc_meta(shard_data)
def _gather_ipc_meta(self, shard_data):
all_data: List[Optional[Any]] = [None] * self.world_size
dist.all_gather_object(all_data, shard_data)
# Note: don't use `[[None]] * self.world_size` here
# because it will create a list of the same reference
all_data: List[Optional[Any]] = [[None]
for i in range(self.world_size)]
all_data[self.rank][0] = shard_data
ranks = dist.get_process_group_ranks(group=self.group)
ranks.sort()
for i, rank in enumerate(ranks):
dist.broadcast_object_list(all_data[i],
src=rank,
group=self.group,
device="cpu")
# we cannot directly use `dist.all_gather_object` here
# because it is incompatible with `gloo` backend under inference mode.
# see https://github.com/pytorch/pytorch/issues/126032 for details.
handles = []
offsets = []
for i in range(len(all_data)):
handles.append(all_data[i][0]) # type: ignore
offsets.append(all_data[i][1]) # type: ignore
handles.append(all_data[i][0][0]) # type: ignore
offsets.append(all_data[i][0][1]) # type: ignore
return handles, offsets
def register_buffer(self, inp: torch.Tensor):
@@ -269,8 +278,31 @@ class CustomAllreduce:
custom_ar.all_reduce_unreg(self._ptr, inp, self.buffer, out)
return out
def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
# when custom allreduce is disabled, this will be None
if self.disabled:
return None
if self._IS_CAPTURING:
if torch.cuda.is_current_stream_capturing():
if self.should_custom_ar(input):
return self.all_reduce_reg(input)
else:
if self.should_custom_ar(input):
# if warm up, mimic the allocation pattern
# since custom allreduce is out-of-place
return torch.empty_like(input)
else:
# note: outside of cuda graph context,
# custom allreduce incurs a cost of cudaMemcpy, which should
# be small(<=1% of overall latency) compared to the performance
# gains of using custom kernels
if self.should_custom_ar(input):
return self.all_reduce_unreg(input)
return None
def close(self):
if self._ptr:
if not self.disabled and self._ptr:
custom_ar.dispose(self._ptr)
self._ptr = 0

View File

@@ -96,8 +96,10 @@ class PyNcclCommunicator:
self.stream = torch.cuda.Stream()
# A small all_reduce for warmup.
self.all_reduce(torch.zeros(1, device=device))
data = torch.zeros(1, device=device)
self.all_reduce(data)
self.stream.synchronize()
del data
# by default it is disabled, e.g. in profiling models and prefill phase.
# to use it, use under `with obj.change_state(enable=True)`, usually