[Core][Distributed] refactor pynccl (#4591)

[Core][Distributed] refactor pynccl to hold multiple communicators (#4591)
This commit is contained in:
youkaichao
2024-05-09 19:48:43 -07:00
committed by GitHub
parent c833101740
commit 208b71bcc1
8 changed files with 466 additions and 433 deletions

View File

@@ -1,4 +1,5 @@
from collections import namedtuple
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
@@ -8,7 +9,26 @@ from .parallel_state import (get_cpu_world_group,
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
is_pynccl_enabled_for_all_reduce)
get_tp_pynccl_communicator)
@contextmanager
def graph_capture_mode():
# In graph capture, we have to be very careful about the collective
# operations. The current status is:
# allreduce \ Mode | Eager | Graph |
# --------------------------------------------
# custom allreduce | enabled | enabled |
# PyNccl | disabled| enabled |
# torch.distributed | enabled | disabled|
#
# Note that custom allreduce will have a runtime check, if the tensor size
# is too large, it will fallback to the next available option.
pynccl_comm = get_tp_pynccl_communicator()
assert pynccl_comm is not None
with pynccl_comm.change_state(enable=True,
stream=torch.cuda.current_stream()):
yield
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
@@ -23,7 +43,6 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
TLDR: always assume this function modifies its input, but use the return
value as the output.
"""
from vllm.distributed.device_communicators import pynccl_utils
from vllm.distributed.device_communicators.custom_all_reduce import (
custom_all_reduce)
@@ -33,8 +52,9 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
out = custom_all_reduce(input_)
if out is not None:
return out
if is_pynccl_enabled_for_all_reduce():
pynccl_utils.all_reduce(input_)
pynccl_comm = get_tp_pynccl_communicator()
if (pynccl_comm is not None and not pynccl_comm.disabled):
pynccl_comm.all_reduce(input_)
else:
torch.distributed.all_reduce(input_,
group=get_tensor_model_parallel_group())