[Core][1/N] Support send/recv in PyNCCL Groups (#4988)

Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
This commit is contained in:
Murali Andoorveedu
2024-05-23 09:54:48 -07:00
committed by GitHub
parent 2ba80bed27
commit 5eda2ea02a
5 changed files with 170 additions and 17 deletions

View File

@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch.distributed import ProcessGroup
from .parallel_state import (get_cpu_world_group,
from .parallel_state import (get_cpu_world_group, get_pp_pynccl_communicator,
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
@@ -54,13 +54,19 @@ def graph_capture():
# graph, we use either custom all-reduce kernel or PyTorch NCCL.
# We always prioritize using custom all-reduce kernel but fall back
# to PyTorch or pynccl if it is disabled or not supported.
pynccl_comm = get_tp_pynccl_communicator()
if pynccl_comm is None:
maybe_pynccl_context = nullcontext()
tp_pynccl_comm = get_tp_pynccl_communicator()
pp_pynccl_comm = get_pp_pynccl_communicator()
if not tp_pynccl_comm:
maybe_tp_pynccl_context = nullcontext()
else:
maybe_pynccl_context = pynccl_comm.change_state(
maybe_tp_pynccl_context = tp_pynccl_comm.change_state(
enable=True, stream=torch.cuda.current_stream())
with maybe_pynccl_context:
if not pp_pynccl_comm:
maybe_pp_pynccl_context = nullcontext()
else:
maybe_pp_pynccl_context = pp_pynccl_comm.change_state(
enable=True, stream=torch.cuda.current_stream())
with maybe_tp_pynccl_context, maybe_pp_pynccl_context:
yield graph_capture_context