[Core][Distributed] refactor custom allreduce to support multiple tp groups (#4754)

youkaichao
2024-05-12 17:47:59 -07:00
committed by GitHub
parent a7be4d0072
commit 702bee461f
10 changed files with 327 additions and 226 deletions


@@ -1,5 +1,5 @@
 from collections import namedtuple
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
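Note: nullcontext, newly imported above, is used later in this diff as a no-op stand-in whenever a per-group communicator is absent. A minimal sketch of the pattern; the comm object here is hypothetical and not part of the diff:

from contextlib import nullcontext

comm = None  # e.g. a per-TP-group communicator that may never be initialized

# Choose a real context manager when the resource exists, a no-op otherwise,
# so the `with` block below needs no special-casing.
ctx = nullcontext() if comm is None else comm.change_state(enable=True)
with ctx:
    pass  # collective ops would run here in either case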
@@ -9,12 +9,13 @@ from .parallel_state import (get_cpu_world_group,
                              get_tensor_model_parallel_group,
                              get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
+                             get_tp_ca_communicator,
                              get_tp_pynccl_communicator)
 
 
 @contextmanager
-def graph_capture_mode():
-    # In graph capture, we have to be very careful about the collective
+def graph_mode():
+    # In graph mode, we have to be very careful about the collective
     # operations. The current status is:
     # allreduce \ Mode   |  Eager  |  Graph  |
     # --------------------------------------------
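The new get_tp_ca_communicator import mirrors the existing get_tp_pynccl_communicator accessor. Its definition is not shown in this diff; a plausible sketch of the pattern in parallel_state.py, with all names assumed rather than confirmed:

from typing import Any, Optional

# Module-level handle, populated when the TP group is initialized.
_TP_CA_COMMUNICATOR: Optional[Any] = None

def get_tp_ca_communicator() -> Optional[Any]:
    # May be None, e.g. when custom allreduce is disabled or unsupported;
    # every caller in this diff guards against that.
    return _TP_CA_COMMUNICATOR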
@@ -24,10 +25,32 @@ def graph_capture_mode():
     #
     # Note that custom allreduce has a runtime check: if the tensor size
     # is too large, it will fall back to the next available option.
+    # In summary: when using CUDA graph, we use either the custom
+    # all-reduce kernel or pynccl. When not using CUDA graph, we use
+    # either the custom all-reduce kernel or PyTorch NCCL. We always
+    # prioritize the custom all-reduce kernel, but fall back to
+    # PyTorch NCCL or pynccl if it is disabled or not supported.
     pynccl_comm = get_tp_pynccl_communicator()
-    assert pynccl_comm is not None
-    with pynccl_comm.change_state(enable=True,
-                                  stream=torch.cuda.current_stream()):
+    if pynccl_comm is None:
+        context = nullcontext()
+    else:
+        context = pynccl_comm.change_state(enable=True,
+                                           stream=torch.cuda.current_stream())
+    with context:
         yield
+
+
+@contextmanager
+def graph_capture():
+    """
+    `graph_capture` is a context manager that should wrap the code
+    capturing the CUDA graph. Its main purpose is to ensure that some
+    operations are run after the graph is captured and before it is
+    replayed.
+    """
+    ca_comm = get_tp_ca_communicator()
+    context = nullcontext() if ca_comm is None else ca_comm.capture()
+    with context:
+        yield
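Taken together, graph_capture wraps the whole capture phase while graph_mode switches the allreduce backend during capture. A rough usage sketch, not taken from this diff: it assumes an already-initialized distributed environment, and the tensor x and the omitted warmup/memory-pool handling are simplifications:

import torch
from vllm.distributed.communication_op import (
    graph_capture, graph_mode, tensor_model_parallel_all_reduce)

x = torch.zeros(1024, device="cuda")  # hypothetical input
graph = torch.cuda.CUDAGraph()
with graph_capture():              # one-time setup around the capture phase
    with torch.cuda.graph(graph):  # begin CUDA graph capture
        with graph_mode():         # route allreduce to custom kernel / pynccl
            out = tensor_model_parallel_all_reduce(x)
graph.replay()                     # replays the captured allreduce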
@@ -43,15 +66,15 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
     TLDR: always assume this function modifies its input, but use the return
     value as the output.
     """
-    from vllm.distributed.device_communicators.custom_all_reduce import (
-        custom_all_reduce)
+    ca_comm = get_tp_ca_communicator()
 
     # Bypass the function if we are using only 1 GPU.
     if get_tensor_model_parallel_world_size() == 1:
         return input_
-    out = custom_all_reduce(input_)
-    if out is not None:
-        return out
+    if ca_comm is not None:
+        out = ca_comm.custom_all_reduce(input_)
+        if out is not None:
+            return out
     pynccl_comm = get_tp_pynccl_communicator()
     if (pynccl_comm is not None and not pynccl_comm.disabled):
         pynccl_comm.all_reduce(input_)
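A consequence of the dispatch above: depending on which backend wins, the input is either mutated in place (pynccl) or left untouched while a fresh tensor is returned (custom allreduce), hence the TLDR in the docstring. A short caller-side sketch; the tensor contents are hypothetical:

import torch
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce

t = torch.randn(1024, device="cuda")  # a hypothetical activation shard
# Always rebind to the return value; whether `t` itself was updated in place
# depends on which backend handled the reduction.
t = tensor_model_parallel_all_reduce(t)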