[Core][1/N] Support send/recv in PyNCCL Groups (#4988)

Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
This commit is contained in:
Murali Andoorveedu
2024-05-23 09:54:48 -07:00
committed by GitHub
parent 2ba80bed27
commit 5eda2ea02a
5 changed files with 170 additions and 17 deletions

View File

@@ -3,6 +3,7 @@ import os
import pytest
import torch
import torch.distributed
from vllm.distributed.communication_op import ( # noqa
graph_capture, tensor_model_parallel_all_reduce)
@@ -68,7 +69,7 @@ def test_pynccl():
@worker_fn_wrapper
def multiple_tp_worker_fn():
def multiple_allreduce_worker_fn():
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
groups = [
torch.distributed.new_group(ranks=[0, 1], backend="gloo"),
@@ -92,14 +93,14 @@ def multiple_tp_worker_fn():
@pytest.mark.skipif(torch.cuda.device_count() < 4,
reason="Need at least 4 GPUs to run the test.")
def test_pynccl_multiple_tp():
def test_pynccl_multiple_allreduce():
# this tests pynccl for multiple tp groups, in a standalone way
# i.e. call `pynccl_comm.all_reduce` directly
distributed_run(multiple_tp_worker_fn, 4)
distributed_run(multiple_allreduce_worker_fn, 4)
@worker_fn_wrapper
def multiple_tp_with_vllm_worker_fn():
def multiple_allreduce_with_vllm_worker_fn():
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
ensure_model_parallel_initialized(2, 2)
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
@@ -118,10 +119,10 @@ def multiple_tp_with_vllm_worker_fn():
@pytest.mark.skipif(torch.cuda.device_count() < 4,
reason="Need at least 4 GPUs to run the test.")
def test_pynccl_multiple_tp_with_vllm():
def test_pynccl_multiple_allreduce_with_vllm():
# this tests pynccl for multiple tp groups, together with vllm
# i.e. call `tensor_model_parallel_all_reduce`
distributed_run(multiple_tp_with_vllm_worker_fn, 4)
distributed_run(multiple_allreduce_with_vllm_worker_fn, 4)
@worker_fn_wrapper
@@ -151,6 +152,68 @@ def test_pynccl_with_cudagraph():
distributed_run(worker_fn_with_cudagraph, 2)
@worker_fn_wrapper
def send_recv_worker_fn():
pynccl_comm = PyNcclCommunicator()
if pynccl_comm.rank == 0:
tensor = torch.ones(16, 1024, 1024,
dtype=torch.float32).cuda(pynccl_comm.rank)
else:
tensor = torch.empty(16, 1024, 1024,
dtype=torch.float32).cuda(pynccl_comm.rank)
with pynccl_comm.change_state(enable=True):
if pynccl_comm.rank == 0:
pynccl_comm.send(tensor)
else:
pynccl_comm.recv(tensor)
result = tensor.mean().cpu().item()
assert result == 1
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
def test_pynccl_send_recv():
distributed_run(send_recv_worker_fn, 2)
@worker_fn_wrapper
def multiple_send_recv_worker_fn():
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
groups = [
torch.distributed.new_group(ranks=[0, 2], backend="gloo"),
torch.distributed.new_group(ranks=[1, 3], backend="gloo")
]
group = groups[0] if torch.distributed.get_rank() in [0, 2] else groups[1]
pynccl_comm = PyNcclCommunicator(group=group, device=device)
if torch.distributed.get_rank() == 0:
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
elif torch.distributed.get_rank() == 1:
tensor = 2 * torch.ones(
16, 1024, 1024, dtype=torch.float32, device=device)
else:
tensor = torch.empty(16,
1024,
1024,
dtype=torch.float32,
device=device)
with pynccl_comm.change_state(enable=True):
if torch.distributed.get_rank() in [0, 1]:
pynccl_comm.send(tensor)
else:
pynccl_comm.recv(tensor)
result = tensor.mean().cpu().item()
if torch.distributed.get_rank() in [0, 2]:
assert result == 1
else:
assert result == 2
@pytest.mark.skipif(torch.cuda.device_count() < 4,
reason="Need at least 4 GPUs to run the test.")
def test_pynccl_multiple_send_recv():
distributed_run(multiple_send_recv_worker_fn, 4)
def test_ncclGetUniqueId():
lib = NCCLLibrary()
unique_id = lib.ncclGetUniqueId()