[core][distributed] add stateless process group (#10216)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
import pytest
|
||||
import ray
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.utils import stateless_init_process_group
|
||||
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
from vllm.utils import (cuda_device_count_stateless,
|
||||
update_environment_variables)
|
||||
|
||||
@@ -41,42 +41,45 @@ def test_cuda_device_count_stateless():
|
||||
|
||||
|
||||
def cpu_worker(rank, WORLD_SIZE):
|
||||
pg1 = stateless_init_process_group(init_method="tcp://127.0.0.1:29500",
|
||||
pg1 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29500",
|
||||
rank=rank,
|
||||
world_size=WORLD_SIZE,
|
||||
backend="gloo")
|
||||
world_size=WORLD_SIZE)
|
||||
if rank <= 2:
|
||||
pg2 = stateless_init_process_group(init_method="tcp://127.0.0.1:29501",
|
||||
pg2 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29501",
|
||||
rank=rank,
|
||||
world_size=3,
|
||||
backend="gloo")
|
||||
world_size=3)
|
||||
data = torch.tensor([rank])
|
||||
dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg1)
|
||||
data = pg1.broadcast_obj(data, src=2)
|
||||
assert data.item() == 2
|
||||
if rank <= 2:
|
||||
dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg2)
|
||||
item = data[0].item()
|
||||
print(f"rank: {rank}, item: {item}")
|
||||
if rank == 3:
|
||||
assert item == 6
|
||||
else:
|
||||
assert item == 18
|
||||
data = torch.tensor([rank + 1])
|
||||
data = pg2.broadcast_obj(data, src=2)
|
||||
assert data.item() == 3
|
||||
pg2.barrier()
|
||||
pg1.barrier()
|
||||
|
||||
|
||||
def gpu_worker(rank, WORLD_SIZE):
|
||||
pg1 = stateless_init_process_group(init_method="tcp://127.0.0.1:29502",
|
||||
rank=rank,
|
||||
world_size=WORLD_SIZE,
|
||||
backend="nccl")
|
||||
if rank <= 2:
|
||||
pg2 = stateless_init_process_group(init_method="tcp://127.0.0.1:29503",
|
||||
rank=rank,
|
||||
world_size=3,
|
||||
backend="nccl")
|
||||
torch.cuda.set_device(rank)
|
||||
data = torch.tensor([rank]).cuda()
|
||||
dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg1)
|
||||
pg1 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29502",
|
||||
rank=rank,
|
||||
world_size=WORLD_SIZE)
|
||||
pynccl1 = PyNcclCommunicator(pg1, device=rank)
|
||||
pynccl1.disabled = False
|
||||
if rank <= 2:
|
||||
dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg2)
|
||||
pg2 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29503",
|
||||
rank=rank,
|
||||
world_size=3)
|
||||
pynccl2 = PyNcclCommunicator(pg2, device=rank)
|
||||
pynccl2.disabled = False
|
||||
data = torch.tensor([rank]).cuda()
|
||||
pynccl1.all_reduce(data)
|
||||
pg1.barrier()
|
||||
torch.cuda.synchronize()
|
||||
if rank <= 2:
|
||||
pynccl2.all_reduce(data)
|
||||
pg2.barrier()
|
||||
torch.cuda.synchronize()
|
||||
item = data[0].item()
|
||||
print(f"rank: {rank}, item: {item}")
|
||||
if rank == 3:
|
||||
@@ -85,9 +88,31 @@ def gpu_worker(rank, WORLD_SIZE):
|
||||
assert item == 18
|
||||
|
||||
|
||||
def broadcast_worker(rank, WORLD_SIZE):
|
||||
pg1 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29504",
|
||||
rank=rank,
|
||||
world_size=WORLD_SIZE)
|
||||
if rank == 2:
|
||||
pg1.broadcast_obj("secret", src=2)
|
||||
else:
|
||||
obj = pg1.broadcast_obj(None, src=2)
|
||||
assert obj == "secret"
|
||||
pg1.barrier()
|
||||
|
||||
|
||||
def allgather_worker(rank, WORLD_SIZE):
|
||||
pg1 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29505",
|
||||
rank=rank,
|
||||
world_size=WORLD_SIZE)
|
||||
data = pg1.all_gather_obj(rank)
|
||||
assert data == list(range(WORLD_SIZE))
|
||||
pg1.barrier()
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=4)
|
||||
@pytest.mark.parametrize("worker", [cpu_worker, gpu_worker])
|
||||
def test_stateless_init_process_group(worker):
|
||||
@pytest.mark.parametrize(
|
||||
"worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker])
|
||||
def test_stateless_process_group(worker):
|
||||
WORLD_SIZE = 4
|
||||
from multiprocessing import get_context
|
||||
ctx = get_context("fork")
|
||||
|
||||
Reference in New Issue
Block a user