[Core][1/N] Support send/recv in PyNCCL Groups (#4988)
Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
This commit is contained in:
committed by
GitHub
parent
2ba80bed27
commit
5eda2ea02a
@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from .parallel_state import (get_cpu_world_group,
|
||||
from .parallel_state import (get_cpu_world_group, get_pp_pynccl_communicator,
|
||||
get_tensor_model_parallel_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
@@ -54,13 +54,19 @@ def graph_capture():
|
||||
# graph, we use either custom all-reduce kernel or PyTorch NCCL.
|
||||
# We always prioritize using custom all-reduce kernel but fall back
|
||||
# to PyTorch or pynccl if it is disabled or not supported.
|
||||
pynccl_comm = get_tp_pynccl_communicator()
|
||||
if pynccl_comm is None:
|
||||
maybe_pynccl_context = nullcontext()
|
||||
tp_pynccl_comm = get_tp_pynccl_communicator()
|
||||
pp_pynccl_comm = get_pp_pynccl_communicator()
|
||||
if not tp_pynccl_comm:
|
||||
maybe_tp_pynccl_context = nullcontext()
|
||||
else:
|
||||
maybe_pynccl_context = pynccl_comm.change_state(
|
||||
maybe_tp_pynccl_context = tp_pynccl_comm.change_state(
|
||||
enable=True, stream=torch.cuda.current_stream())
|
||||
with maybe_pynccl_context:
|
||||
if not pp_pynccl_comm:
|
||||
maybe_pp_pynccl_context = nullcontext()
|
||||
else:
|
||||
maybe_pp_pynccl_context = pp_pynccl_comm.change_state(
|
||||
enable=True, stream=torch.cuda.current_stream())
|
||||
with maybe_tp_pynccl_context, maybe_pp_pynccl_context:
|
||||
yield graph_capture_context
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user