[core][distributed] use tcp store directly (#10275)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao
Date: 2024-11-12 17:36:08 -08:00 (committed by GitHub)
Parent: 112fa0bbe5
Commit: 0d4ea3fb5c

2 changed files with 29 additions and 25 deletions
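
What changed: StatelessProcessGroup.create no longer takes a tcp:// init_method URL; it takes an explicit host/port pair and, per the commit title, builds the TCP store directly. A minimal sketch of what that setup looks like, assuming the standard torch.distributed.TCPStore constructor (illustrative, not the exact vLLM code):

    # Illustrative sketch: creating the TCP store directly from host/port,
    # instead of routing through an init_method URL and rendezvous.
    from datetime import timedelta

    from torch.distributed import TCPStore

    def create_store(host: str, port: int, rank: int, world_size: int) -> TCPStore:
        # Rank 0 hosts the store; all other ranks connect to it as clients.
        return TCPStore(host_name=host,
                        port=port,
                        world_size=world_size,
                        is_master=(rank == 0),
                        timeout=timedelta(seconds=300))

Skipping the URL-parsing and rendezvous indirection narrows the surface where initialization can hang, which is plausibly why the flaky-test skip is removed at the bottom of this diff.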


@@ -43,12 +43,15 @@ def test_cuda_device_count_stateless():
 
 
 def cpu_worker(rank, WORLD_SIZE, port1, port2):
-    pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}",
+    pg1 = StatelessProcessGroup.create(host="127.0.0.1",
+                                       port=port1,
                                        rank=rank,
                                        world_size=WORLD_SIZE)
     if rank <= 2:
-        pg2 = StatelessProcessGroup.create(
-            init_method=f"tcp://127.0.0.1:{port2}", rank=rank, world_size=3)
+        pg2 = StatelessProcessGroup.create(host="127.0.0.1",
+                                           port=port2,
+                                           rank=rank,
+                                           world_size=3)
     data = torch.tensor([rank])
     data = pg1.broadcast_obj(data, src=2)
     assert data.item() == 2
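
broadcast_obj, exercised above, is an object-level broadcast over the store. The general pattern behind it is a pickle round-trip through the key-value store; a hedged sketch (the key name is illustrative, not vLLM's actual scheme):

    # General pattern for broadcasting a Python object through a
    # TCPStore-like key-value store; `key` is an illustrative name.
    import pickle

    def broadcast_obj(store, obj, rank: int, src: int, key: str = "bcast/0"):
        if rank == src:
            store.set(key, pickle.dumps(obj))
            return obj
        # store.get() blocks until the source rank has set the key.
        return pickle.loads(store.get(key))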
@@ -62,14 +65,17 @@ def cpu_worker(rank, WORLD_SIZE, port1, port2):
 
 def gpu_worker(rank, WORLD_SIZE, port1, port2):
     torch.cuda.set_device(rank)
-    pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}",
+    pg1 = StatelessProcessGroup.create(host="127.0.0.1",
+                                       port=port1,
                                        rank=rank,
                                        world_size=WORLD_SIZE)
     pynccl1 = PyNcclCommunicator(pg1, device=rank)
     pynccl1.disabled = False
     if rank <= 2:
-        pg2 = StatelessProcessGroup.create(
-            init_method=f"tcp://127.0.0.1:{port2}", rank=rank, world_size=3)
+        pg2 = StatelessProcessGroup.create(host="127.0.0.1",
+                                           port=port2,
+                                           rank=rank,
+                                           world_size=3)
         pynccl2 = PyNcclCommunicator(pg2, device=rank)
         pynccl2.disabled = False
     data = torch.tensor([rank]).cuda()
@@ -89,7 +95,8 @@ def gpu_worker(rank, WORLD_SIZE, port1, port2):
 
 
 def broadcast_worker(rank, WORLD_SIZE, port1, port2):
-    pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}",
+    pg1 = StatelessProcessGroup.create(host="127.0.0.1",
+                                       port=port1,
                                        rank=rank,
                                        world_size=WORLD_SIZE)
     if rank == 2:
@@ -101,7 +108,8 @@ def broadcast_worker(rank, WORLD_SIZE, port1, port2):
 
 
 def allgather_worker(rank, WORLD_SIZE, port1, port2):
-    pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}",
+    pg1 = StatelessProcessGroup.create(host="127.0.0.1",
+                                       port=port1,
                                        rank=rank,
                                        world_size=WORLD_SIZE)
     data = pg1.all_gather_obj(rank)
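
all_gather_obj follows the same store-backed pattern: every rank publishes its object under a per-rank key, then reads everyone else's. A hedged sketch (again, key names are illustrative, not vLLM's internals):

    # General pattern for all-gather over a key-value store.
    import pickle

    def all_gather_obj(store, obj, rank: int, world_size: int, prefix: str = "ag/0"):
        store.set(f"{prefix}/{rank}", pickle.dumps(obj))
        # Each get() blocks until the corresponding rank has published.
        return [pickle.loads(store.get(f"{prefix}/{r}")) for r in range(world_size)]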
@@ -109,8 +117,6 @@ def allgather_worker(rank, WORLD_SIZE, port1, port2):
     pg1.barrier()
 
 
-# TODO: investigate why this test is flaky. It hangs during initialization.
-@pytest.mark.skip("Skip the test because it is flaky.")
 @multi_gpu_test(num_gpus=4)
 @pytest.mark.parametrize(
     "worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker])