
[Core][Distributed] add coordinator to reduce code duplication in tp and pp (#5293)
Author: youkaichao
Date: 2024-06-12 17:27:08 -07:00
Committed by: GitHub
Commit: ea3890a5f0 (parent: 2135cacb45)
12 changed files with 625 additions and 585 deletions


@@ -6,10 +6,11 @@ import torch
 import torch.distributed
 from vllm.distributed.communication_op import ( # noqa
-    graph_capture, tensor_model_parallel_all_reduce)
+    tensor_model_parallel_all_reduce)
 from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
 from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
 from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
+                                             get_world_group, graph_capture,
                                              init_distributed_environment)
 from vllm.utils import update_environment_variables
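
In this hunk, graph_capture moves out of vllm.distributed.communication_op and is now imported from vllm.distributed.parallel_state, alongside the newly used get_world_group helper. A minimal sketch of the post-change imports, using only the names visible in the hunk above:

    # Sketch of the updated imports after this commit; module paths are as
    # shown in the diff above.
    from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
    from vllm.distributed.parallel_state import (get_world_group, graph_capture,
                                                 init_distributed_environment)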
@@ -53,7 +54,8 @@ def worker_fn_wrapper(fn):
 @worker_fn_wrapper
 def worker_fn():
-    pynccl_comm = PyNcclCommunicator()
+    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
+                                     device=get_world_group().device)
     tensor = torch.ones(16, 1024, 1024,
                         dtype=torch.float32).cuda(pynccl_comm.rank)
     with pynccl_comm.change_state(enable=True):
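
Each test worker now builds the communicator from an explicit process group and device instead of relying on implicit global state. A hedged sketch of the pattern (assuming the distributed environment and world group are already initialized, as in the surrounding tests; the helper name below is hypothetical):

    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.parallel_state import get_world_group

    def make_pynccl_comm():
        # Build the NCCL communicator from the world group's CPU process group
        # and its device, mirroring the constructor call in the hunk above.
        group = get_world_group()
        return PyNcclCommunicator(group.cpu_group, device=group.device)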
@@ -129,7 +131,8 @@ def test_pynccl_multiple_allreduce_with_vllm():
 def worker_fn_with_cudagraph():
     with torch.no_grad():
         graph = torch.cuda.CUDAGraph()
-        pynccl_comm = PyNcclCommunicator()
+        pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
+                                         device=get_world_group().device)
         # run something in the default stream to initialize torch engine
         a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
         torch.cuda.synchronize()
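
The CUDA-graph worker warms up the default stream before capturing. A torch-only sketch of that capture/replay flow (the in-place add is a placeholder for the pynccl collective, which sits outside this excerpt):

    import torch

    def cudagraph_capture_sketch(device: str = "cuda:0"):
        graph = torch.cuda.CUDAGraph()
        a = torch.ones((4, 4), device=device)
        torch.cuda.synchronize()         # initialize the default stream first
        with torch.cuda.graph(graph):    # capture subsequent kernels
            a += 1                       # placeholder for the collective op
        graph.replay()                   # re-run the captured kernels
        return a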
@@ -154,7 +157,8 @@ def test_pynccl_with_cudagraph():
 @worker_fn_wrapper
 def send_recv_worker_fn():
-    pynccl_comm = PyNcclCommunicator()
+    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
+                                     device=get_world_group().device)
     if pynccl_comm.rank == 0:
         tensor = torch.ones(16, 1024, 1024,
                             dtype=torch.float32).cuda(pynccl_comm.rank)