[Core][Distributed] refactor pynccl to hold multiple communicators (#4591)
@@ -1,4 +1,5 @@
 from collections import namedtuple
+from contextlib import contextmanager
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -8,7 +9,26 @@ from .parallel_state import (get_cpu_world_group,
                              get_tensor_model_parallel_group,
                              get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
-                             is_pynccl_enabled_for_all_reduce)
+                             get_tp_pynccl_communicator)
+
+
+@contextmanager
+def graph_capture_mode():
+    # In graph capture, we have to be very careful about the collective
+    # operations. The current status is:
+    # allreduce \ Mode   |  Eager  |  Graph  |
+    # --------------------------------------------
+    # custom allreduce   | enabled | enabled |
+    # PyNccl             | disabled| enabled |
+    # torch.distributed  | enabled | disabled|
+    #
+    # Note that custom allreduce will have a runtime check, if the tensor size
+    # is too large, it will fallback to the next available option.
+    pynccl_comm = get_tp_pynccl_communicator()
+    assert pynccl_comm is not None
+    with pynccl_comm.change_state(enable=True,
+                                  stream=torch.cuda.current_stream()):
+        yield
 
 
 def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
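For context, a minimal usage sketch (not part of this commit) of how graph_capture_mode() is meant to wrap CUDA graph capture, so that allreduce goes through PyNccl inside the graph as described by the table above. The import path, model, and example_input are assumptions for illustration:

import torch

# Assumed import path for the context manager added in this diff.
from vllm.distributed.communication_op import graph_capture_mode


def capture_forward(model, example_input: torch.Tensor) -> torch.cuda.CUDAGraph:
    # Warm-up run outside of capture so lazy initialization (e.g. kernel
    # autotuning) does not end up inside the captured graph.
    model(example_input)
    torch.cuda.synchronize()

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        # Entering graph_capture_mode() inside torch.cuda.graph() binds the
        # pynccl communicator to the capture stream, so its allreduce calls
        # are recorded into the graph instead of going to torch.distributed.
        with graph_capture_mode():
            model(example_input)
    torch.cuda.synchronize()
    return graph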
@@ -23,7 +43,6 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
     TLDR: always assume this function modifies its input, but use the return
     value as the output.
     """
-    from vllm.distributed.device_communicators import pynccl_utils
     from vllm.distributed.device_communicators.custom_all_reduce import (
         custom_all_reduce)
 
@@ -33,8 +52,9 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
     out = custom_all_reduce(input_)
     if out is not None:
         return out
-    if is_pynccl_enabled_for_all_reduce():
-        pynccl_utils.all_reduce(input_)
+    pynccl_comm = get_tp_pynccl_communicator()
+    if (pynccl_comm is not None and not pynccl_comm.disabled):
+        pynccl_comm.all_reduce(input_)
     else:
         torch.distributed.all_reduce(input_,
                                      group=get_tensor_model_parallel_group())
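The refactor replaces the module-level pynccl_utils state with a per-group communicator object returned by get_tp_pynccl_communicator(). Inferred purely from the call sites in this diff (change_state(), disabled, all_reduce()), that object needs roughly the following shape; this is an illustrative sketch, not the PR's actual implementation:

from contextlib import contextmanager
from typing import Optional

import torch


class PyNcclCommunicatorSketch:
    # Illustrative stand-in for the per-group pynccl communicator that
    # get_tp_pynccl_communicator() is assumed to return.

    def __init__(self) -> None:
        # Disabled by default: in eager mode the allreduce above falls
        # through to torch.distributed (see the dispatch table).
        self.disabled = True
        self.stream: Optional[torch.cuda.Stream] = None

    @contextmanager
    def change_state(self, enable: bool, stream: torch.cuda.Stream):
        # Temporarily enable the communicator and bind it to a stream
        # (e.g. the CUDA graph capture stream), restoring state on exit.
        old_disabled, old_stream = self.disabled, self.stream
        self.disabled, self.stream = not enable, stream
        try:
            yield
        finally:
            self.disabled, self.stream = old_disabled, old_stream

    def all_reduce(self, input_: torch.Tensor) -> None:
        # The real communicator would issue an NCCL allreduce on self.stream.
        raise NotImplementedError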