[core][distributed] add stateless process group (#10216)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2024-11-11 09:02:14 -08:00
committed by GitHub
parent 36fc439de0
commit e6de9784d2
3 changed files with 206 additions and 101 deletions

View File

@@ -1,10 +1,10 @@
import pytest
import ray
import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm.distributed.utils import stateless_init_process_group
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup
from vllm.utils import (cuda_device_count_stateless,
update_environment_variables)
@@ -41,42 +41,45 @@ def test_cuda_device_count_stateless():
def cpu_worker(rank, WORLD_SIZE):
    """Exercise object broadcast over two overlapping stateless groups.

    Every rank joins ``pg1`` (rendezvous on 127.0.0.1:29500, size
    ``WORLD_SIZE``). Ranks 0-2 additionally join ``pg2`` (rendezvous on
    127.0.0.1:29501, size 3). Rank 2 is the broadcast source in both
    groups.

    Args:
        rank: Index of this worker process.
        WORLD_SIZE: Total number of worker processes joining ``pg1``.
    """
    pg1 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29500",
                                       rank=rank,
                                       world_size=WORLD_SIZE)
    if rank <= 2:
        pg2 = StatelessProcessGroup.create(
            init_method="tcp://127.0.0.1:29501", rank=rank, world_size=3)

    # Each rank starts with its own rank; after the broadcast every rank
    # (the source included) must hold rank 2's tensor.
    data = torch.tensor([rank])
    data = pg1.broadcast_obj(data, src=2)
    assert data.item() == 2

    if rank <= 2:
        # Rank 2 now sends tensor([3]), a value different from the first
        # broadcast's result.
        data = torch.tensor([rank + 1])
        data = pg2.broadcast_obj(data, src=2)
        assert data.item() == 3
        pg2.barrier()
    pg1.barrier()
def gpu_worker(rank, WORLD_SIZE):
pg1 = stateless_init_process_group(init_method="tcp://127.0.0.1:29502",
rank=rank,
world_size=WORLD_SIZE,
backend="nccl")
if rank <= 2:
pg2 = stateless_init_process_group(init_method="tcp://127.0.0.1:29503",
rank=rank,
world_size=3,
backend="nccl")
torch.cuda.set_device(rank)
data = torch.tensor([rank]).cuda()
dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg1)
pg1 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29502",
rank=rank,
world_size=WORLD_SIZE)
pynccl1 = PyNcclCommunicator(pg1, device=rank)
pynccl1.disabled = False
if rank <= 2:
dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg2)
pg2 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29503",
rank=rank,
world_size=3)
pynccl2 = PyNcclCommunicator(pg2, device=rank)
pynccl2.disabled = False
data = torch.tensor([rank]).cuda()
pynccl1.all_reduce(data)
pg1.barrier()
torch.cuda.synchronize()
if rank <= 2:
pynccl2.all_reduce(data)
pg2.barrier()
torch.cuda.synchronize()
item = data[0].item()
print(f"rank: {rank}, item: {item}")
if rank == 3:
@@ -85,9 +88,31 @@ def gpu_worker(rank, WORLD_SIZE):
assert item == 18
def broadcast_worker(rank, WORLD_SIZE):
    """Check that a plain Python object can be broadcast from rank 2.

    All ranks join one stateless group (rendezvous on 127.0.0.1:29504).
    Rank 2 sends the string ``"secret"``; every other rank passes ``None``
    and asserts that it received ``"secret"``. All ranks then meet at a
    barrier.

    Args:
        rank: Index of this worker process.
        WORLD_SIZE: Total number of worker processes in the group.
    """
    group = StatelessProcessGroup.create(
        init_method="tcp://127.0.0.1:29504",
        rank=rank,
        world_size=WORLD_SIZE,
    )
    if rank != 2:
        # Receivers supply a placeholder and validate what comes back.
        received = group.broadcast_obj(None, src=2)
        assert received == "secret"
    else:
        group.broadcast_obj("secret", src=2)
    group.barrier()
def allgather_worker(rank, WORLD_SIZE):
    """Check that all-gathering each rank's index yields ``[0..WORLD_SIZE)``.

    All ranks join one stateless group (rendezvous on 127.0.0.1:29505),
    contribute their own ``rank`` to ``all_gather_obj`` and assert that the
    result equals ``list(range(WORLD_SIZE))``. All ranks then meet at a
    barrier.

    Args:
        rank: Index of this worker process.
        WORLD_SIZE: Total number of worker processes in the group.
    """
    group = StatelessProcessGroup.create(
        init_method="tcp://127.0.0.1:29505",
        rank=rank,
        world_size=WORLD_SIZE,
    )
    gathered = group.all_gather_obj(rank)
    expected = list(range(WORLD_SIZE))
    assert gathered == expected
    group.barrier()
@multi_gpu_test(num_gpus=4)
@pytest.mark.parametrize("worker", [cpu_worker, gpu_worker])
def test_stateless_init_process_group(worker):
@pytest.mark.parametrize(
"worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker])
def test_stateless_process_group(worker):
WORLD_SIZE = 4
from multiprocessing import get_context
ctx = get_context("fork")