[PERF] PyTorch Symmetric Memory All-Reduce (#20759)
Signed-off-by: ilmarkov <imarkov@redhat.com>
Signed-off-by: ilmarkov <markovilya197@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: ilmarkov <imarkov@redhat.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
@@ -23,6 +23,39 @@ from vllm.utils import (cuda_device_count_stateless,
 
 logger = init_logger(__name__)
 
+MiB = 1024 * 1024
+# Max size for each world size in case symmetric memory is available
+# For different SM architectures
+CUSTOM_ALL_REDUCE_MAX_SIZES = {
+    "9.0": {
+        2: 64 * MiB,  # 64 MB
+        4: 32 * MiB,  # 32 MB
+        6: MiB // 2,  # 512 KB
+        8: MiB // 4,  # 256 KB
+    },
+    "10.0": {
+        2: 2 * MiB,  # 2 MB
+        4: 2 * MiB,  # 2 MB
+        6: 2 * MiB,  # 2 MB
+        8: 2 * MiB,  # 2 MB
+    }
+}
+
+SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
+    "9.0": {
+        2: 64 * MiB,  # 64 MB
+        4: 32 * MiB,  # 32 MB
+        6: 64 * MiB,  # 64 MB
+        8: 64 * MiB,  # 64 MB
+    },
+    "10.0": {
+        2: 8 * MiB,  # 8 MB
+        4: 32 * MiB,  # 32 MB
+        6: 128 * MiB,  # 128 MB
+        8: 128 * MiB,  # 128 MB
+    }
+}
+
 
 def producer(batch_src: Sequence[int],
              producer_queue,
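
Read together, the two tables cap custom all-reduce aggressively on SM 9.0 (Hopper) at larger world sizes, while the symmetric-memory path keeps handling much bigger messages; on SM 10.0 (Blackwell) the symmetric-memory caps grow with world size instead. A quick sanity check of the lookup, repeating only the relevant entries from the tables above:

MiB = 1024 * 1024

# Relevant entries copied from the diff above.
CUSTOM_ALL_REDUCE_MAX_SIZES = {"9.0": {8: MiB // 4}}
SYMM_MEM_ALL_REDUCE_MAX_SIZES = {"9.0": {8: 64 * MiB}}

capability, world_size = "9.0", 8
print(CUSTOM_ALL_REDUCE_MAX_SIZES[capability][world_size])    # 262144 bytes (256 KB)
print(SYMM_MEM_ALL_REDUCE_MAX_SIZES[capability][world_size])  # 67108864 bytes (64 MB)
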
@@ -44,6 +44,8 @@ class CudaCommunicator(DeviceCommunicatorBase):
             PyNcclCommunicator)
         from vllm.distributed.device_communicators.quick_all_reduce import (
             QuickAllReduce)
+        from vllm.distributed.device_communicators.symm_mem import (
+            SymmMemCommunicator)
 
         self.pynccl_comm: Optional[PyNcclCommunicator] = None
         if use_pynccl and self.world_size > 1:
@@ -54,6 +56,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
 
         self.ca_comm: Optional[CustomAllreduce] = None
         self.qr_comm: Optional[QuickAllReduce] = None
+        self.symm_mem_comm: Optional[SymmMemCommunicator] = None
         if use_custom_allreduce and self.world_size > 1:
             # Initialize a custom fast all-reduce implementation.
             self.ca_comm = CustomAllreduce(
@@ -69,6 +72,12 @@ class CudaCommunicator(DeviceCommunicatorBase):
             # currently be an MI300 series.
             self.qr_comm = QuickAllReduce(group=self.cpu_group,
                                           device=self.device)
+        if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda():
+            self.symm_mem_comm = SymmMemCommunicator(
+                group=self.cpu_group,
+                device=self.device,
+            )
+
         if self.use_all2all:
             all2all_backend = envs.VLLM_ALL2ALL_BACKEND
             if all2all_backend == "naive":
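
The new communicator is opt-in: it is only constructed when VLLM_ALLREDUCE_USE_SYMM_MEM is set and the platform is CUDA. A minimal sketch of enabling it, assuming the flag defaults to off:

import os

# Assumption: VLLM_ALLREDUCE_USE_SYMM_MEM is off by default; export it (or set
# it here) before vLLM creates its device communicators.
os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "1"
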
@@ -105,6 +114,12 @@ class CudaCommunicator(DeviceCommunicatorBase):
             out = ca_comm.custom_all_reduce(input_)
             assert out is not None
             return out
+        symm_mem_comm = self.symm_mem_comm
+        if symm_mem_comm is not None and \
+                symm_mem_comm.should_use_symm_mem(input_):
+            out = symm_mem_comm.all_reduce(input_)
+            assert out is not None
+            return out
         pynccl_comm = self.pynccl_comm
         assert pynccl_comm is not None
         out = pynccl_comm.all_reduce(input_)
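
This slots symmetric memory into the middle of the dispatch chain: custom all-reduce is tried first, then symmetric memory, then PyNccl as the fallback. A condensed paraphrase of the resulting control flow (not the literal vLLM code; should_custom_ar stands in for whatever eligibility check guards the first branch):

def dispatch_all_reduce(input_, ca_comm, symm_mem_comm, pynccl_comm):
    # 1. Custom all-reduce, if present and the tensor qualifies.
    if ca_comm is not None and ca_comm.should_custom_ar(input_):
        return ca_comm.custom_all_reduce(input_)
    # 2. Symmetric memory, gated by the dtype/size checks in
    #    should_use_symm_mem.
    if (symm_mem_comm is not None
            and symm_mem_comm.should_use_symm_mem(input_)):
        return symm_mem_comm.all_reduce(input_)
    # 3. PyNccl as the final fallback.
    return pynccl_comm.all_reduce(input_)
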
@@ -10,8 +10,8 @@ from torch.distributed import ProcessGroup
 
 import vllm.envs as envs
 from vllm import _custom_ops as ops
-from vllm.distributed.device_communicators.custom_all_reduce_utils import (
-    gpu_p2p_access_check)
+from vllm.distributed.device_communicators.all_reduce_utils import (
+    CUSTOM_ALL_REDUCE_MAX_SIZES, gpu_p2p_access_check)
 from vllm.distributed.parallel_state import in_the_same_node_as
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -109,7 +109,13 @@ class CustomAllreduce:
         # now `device` is a `torch.device` object
         assert isinstance(device, torch.device)
         self.device = device
 
+        device_capability = current_platform.get_device_capability(
+        ).as_version_str()
+        if (current_platform.is_cuda() and envs.VLLM_ALLREDUCE_USE_SYMM_MEM
+                and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES):
+            max_size = min(
+                CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size],
+                max_size)
 
         cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
         if cuda_visible_devices:
             device_ids = list(map(int, cuda_visible_devices.split(",")))
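
The effect of the clamp: when the symmetric-memory path is enabled, whatever buffer size CustomAllreduce was configured with is cut down to the table value, leaving larger messages to symmetric memory. A worked example for SM 9.0 with 8 ranks (the 8 MiB starting value is illustrative, not vLLM's actual default):

MiB = 1024 * 1024

max_size = 8 * MiB                  # illustrative configured value
max_size = min(MiB // 4, max_size)  # CUSTOM_ALL_REDUCE_MAX_SIZES["9.0"][8]
print(max_size)                     # 262144 (256 KB)
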
vllm/distributed/device_communicators/symm_mem.py (new file, 111 lines)
@@ -0,0 +1,111 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Optional, Union
+
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+from vllm.distributed.device_communicators.all_reduce_utils import (
+    SYMM_MEM_ALL_REDUCE_MAX_SIZES)
+from vllm.logger import init_logger
+from vllm.platforms import current_platform
+
+try:
+    import torch.distributed._symmetric_memory as torch_symm_mem
+
+    symm_mem_available = True
+except ImportError:
+    symm_mem_available = False
+
+logger = init_logger(__name__)
+
+
+class SymmMemCommunicator:
+
+    _WORLD_SIZES_MULTIMEM = {
+        "9.0": [4, 6, 8],
+        "10.0": [6, 8],
+    }
+
+    def __init__(self, group: ProcessGroup, device: Union[int, str,
+                                                          torch.device]):
+        self.disabled = True
+
+        if not symm_mem_available:
+            return
+
+        if not current_platform.is_cuda():
+            logger.warning("SymmMemCommunicator: symmetric "
+                           "memory is not available.")
+            return
+        if isinstance(device, int):
+            device = torch.device(f"cuda:{device}")
+        elif isinstance(device, str):
+            device = torch.device(device)
+        torch.cuda.set_device(device)
+        self.dtype = torch.bfloat16
+        self.device = device
+        self.group = group
+        self.world_size = dist.get_world_size(self.group)
+        self.device_capability = current_platform.get_device_capability(
+        ).as_version_str()
+        if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES:
+            logger.warning(
+                "SymmMemCommunicator: Device capability %s not supported, "
+                "communicator is not available.",
+                self.device_capability,
+            )
+            return
+        if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[
+                self.device_capability]:
+            logger.warning(
+                "SymmMemCommunicator: World size %d not supported, "
+                "communicator is not available.",
+                self.world_size,
+            )
+            return
+        self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][
+            self.world_size]
+        self.buffer = torch_symm_mem.empty(
+            self.max_size // self.dtype.itemsize,
+            device=self.device,
+            dtype=self.dtype,
+        )
+        handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
+        if handle.multicast_ptr == 0:
+            logger.warning("SymmMemCommunicator: symmetric memory "
+                           "multicast operations are not supported.")
+            return
+        self.disabled = False
+
+    def should_use_symm_mem(self, inp: torch.Tensor):
+        if self.disabled:
+            return False
+        if inp.dtype != self.dtype:
+            return False
+        inp_size = inp.numel() * inp.element_size()
+        if inp_size % 4 != 0:
+            return False
+        return inp_size < self.max_size
+
+    def all_reduce(
+            self,
+            inp: torch.Tensor,
+            *,
+            out: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]:
+        if not self.should_use_symm_mem(inp):
+            return None
+        if out is None:
+            out = torch.empty_like(inp)
+        self.buffer[:inp.numel()].copy_(inp.view(-1))
+        if self.world_size in self._WORLD_SIZES_MULTIMEM[
+                self.device_capability]:
+            torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()],
+                                                    "sum",
+                                                    self.group.group_name)
+        else:
+            torch.ops.symm_mem.two_shot_all_reduce_(self.buffer[:inp.numel()],
+                                                    "sum",
+                                                    self.group.group_name)
+        out.copy_(self.buffer[:inp.numel()].view(out.shape))
+        return out
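
For reference, a minimal sketch of driving the new communicator directly under torchrun; the WORLD group, the per-rank device wiring, and the script name are assumptions for illustration, not how CudaCommunicator constructs it:

# Launch with e.g.: torchrun --nproc-per-node=2 symm_mem_demo.py
import torch
import torch.distributed as dist

from vllm.distributed.device_communicators.symm_mem import SymmMemCommunicator

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
comm = SymmMemCommunicator(group=dist.group.WORLD, device=rank)

# bfloat16 only; the byte size must be a multiple of 4 and under max_size.
x = torch.ones(1024, dtype=torch.bfloat16, device=f"cuda:{rank}")
out = comm.all_reduce(x)  # returns None if the communicator is disabled
if out is not None:
    assert torch.allclose(out, torch.full_like(x, dist.get_world_size()))
dist.destroy_process_group()
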