[Core][Distributed] refactor custom allreduce to support multiple tp groups (#4754)
@@ -1,5 +1,5 @@
 from collections import namedtuple
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
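
The newly imported nullcontext is the standard library's no-op context manager; the rewritten functions below use it so callers can always write `with context:` even when an optional communicator is absent. A minimal illustration:

    from contextlib import nullcontext

    # nullcontext() enters and exits without any setup or teardown, so it
    # can stand in for a communicator's context when none is configured.
    with nullcontext():
        pass
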
@@ -9,12 +9,13 @@ from .parallel_state import (get_cpu_world_group,
                              get_tensor_model_parallel_group,
                              get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
+                             get_tp_ca_communicator,
                              get_tp_pynccl_communicator)


 @contextmanager
-def graph_capture_mode():
-    # In graph capture, we have to be very careful about the collective
+def graph_mode():
+    # In graph mode, we have to be very careful about the collective
     # operations. The current status is:
     #     allreduce \ Mode   |  Eager  |  Graph  |
     # --------------------------------------------
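
The new get_tp_ca_communicator import is what lets each tensor-parallel group carry its own custom-allreduce communicator rather than relying on a module-level singleton, which is the point of supporting multiple TP groups. A rough sketch of the accessor pattern this implies (a simplification, not vLLM's actual parallel_state internals):

    # Hypothetical simplification of the per-group accessor pattern.
    _TP_CA_COMMUNICATOR = None  # set when the TP group is initialized

    def get_tp_ca_communicator():
        # May legitimately be None: custom allreduce can be disabled or
        # unsupported, so every caller must handle that case.
        return _TP_CA_COMMUNICATOR
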
@@ -24,10 +25,32 @@ def graph_capture_mode():
     #
     # Note that custom allreduce has a runtime check; if the tensor size
     # is too large, it will fall back to the next available option.
+    # In summary: when using CUDA graph, we use
+    #     either the custom all-reduce kernel or pynccl. When not using CUDA
+    #     graph, we use either the custom all-reduce kernel or PyTorch NCCL.
+    #     We always prioritize the custom all-reduce kernel but fall back
+    #     to PyTorch NCCL or pynccl if it is disabled or not supported.
     pynccl_comm = get_tp_pynccl_communicator()
-    assert pynccl_comm is not None
-    with pynccl_comm.change_state(enable=True,
-                                  stream=torch.cuda.current_stream()):
+    if pynccl_comm is None:
+        context = nullcontext()
+    else:
+        context = pynccl_comm.change_state(enable=True,
+                                           stream=torch.cuda.current_stream())
+    with context:
         yield


+@contextmanager
+def graph_capture():
+    """
+    `graph_capture` is a context manager which should surround the code that
+    is capturing the CUDA graph. Its main purpose is to ensure that some
+    operations will be run after the graph is captured, before the graph
+    is replayed.
+    """
+    ca_comm = get_tp_ca_communicator()
+    context = nullcontext() if ca_comm is None else ca_comm.capture()
+    with context:
+        yield
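
Taken together, graph_capture() wraps the region where CUDA graphs are captured (letting the custom-allreduce communicator do its capture-time bookkeeping), while graph_mode() enables pynccl for the collectives recorded into the graph. A hypothetical capture-and-replay sketch; `model` and `dummy_input` are placeholders, not names from this commit:

    import torch

    with graph_capture():
        graph = torch.cuda.CUDAGraph()
        with graph_mode():
            # Collectives issued here go through the custom kernel or
            # pynccl, so they are safe to record into the CUDA graph.
            with torch.cuda.graph(graph):
                output = model(dummy_input)
    # Replaying re-executes the recorded kernels, including collectives.
    graph.replay()
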
@@ -43,15 +66,15 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
     TLDR: always assume this function modifies its input, but use the return
     value as the output.
     """
-    from vllm.distributed.device_communicators.custom_all_reduce import (
-        custom_all_reduce)
+    ca_comm = get_tp_ca_communicator()

     # Bypass the function if we are using only 1 GPU.
     if get_tensor_model_parallel_world_size() == 1:
         return input_
-    out = custom_all_reduce(input_)
-    if out is not None:
-        return out
+    if ca_comm is not None:
+        out = ca_comm.custom_all_reduce(input_)
+        if out is not None:
+            return out
     pynccl_comm = get_tp_pynccl_communicator()
     if (pynccl_comm is not None and not pynccl_comm.disabled):
         pynccl_comm.all_reduce(input_)
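
Per the TLDR in the docstring, callers should treat the argument as clobbered and keep the return value: the custom kernel returns a fresh tensor, while the pynccl and torch.distributed paths reduce in place, so which behavior you get depends on the backend that handled the call. A hypothetical call site on an initialized TP group:

    hidden = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
    # Reassign: the old tensor may or may not hold the reduced values.
    hidden = tensor_model_parallel_all_reduce(hidden)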