[Distributed] Add send and recv helpers (#5719)

This commit is contained in:
Murali Andoorveedu
2024-06-23 17:42:28 -04:00
committed by GitHub
parent 6c916ac8a8
commit 5d4d90536f
6 changed files with 278 additions and 24 deletions

View File

@@ -8,12 +8,11 @@ import pytest
import ray
import torch
from vllm.distributed import (broadcast_tensor_dict,
from vllm.distributed import (broadcast_tensor_dict, get_pp_group,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from ..utils import (init_test_distributed_environment,
multi_process_tensor_parallel)
from ..utils import init_test_distributed_environment, multi_process_parallel
@ray.remote(num_gpus=1, max_calls=1)
@@ -105,6 +104,68 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
assert torch.allclose(recv_dict["f"], test_dict["f"])
@ray.remote(num_gpus=1, max_calls=1)
def send_recv_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
distributed_init_port: str):
del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port)
test_dict = {
# device tensor
"a": torch.arange(8, dtype=torch.float32, device="cuda"),
# CPU tensor
"b": torch.arange(16, dtype=torch.int8, device="cpu"),
"c": "test",
"d": [1, 2, 3],
"e": {
"a": 1,
"b": 2
},
# empty tensor
"f": torch.tensor([], dtype=torch.float32, device="cuda"),
}
if not get_pp_group().is_first_rank:
recv_dict = get_pp_group().recv_tensor_dict()
if not get_pp_group().is_last_rank:
get_pp_group().send_tensor_dict(test_dict)
if not get_pp_group().is_first_rank:
assert len(recv_dict) == len(test_dict)
assert torch.allclose(recv_dict["a"], test_dict["a"])
assert torch.allclose(recv_dict["b"], test_dict["b"])
assert recv_dict["c"] == test_dict["c"]
assert recv_dict["d"] == test_dict["d"]
assert recv_dict["e"] == test_dict["e"]
assert torch.allclose(recv_dict["f"], test_dict["f"])
@ray.remote(num_gpus=1, max_calls=1)
def send_recv_test_worker(tp_size: int, pp_size: int, rank: int,
distributed_init_port: str):
del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port)
size = 64
test_tensor = torch.arange(64, dtype=torch.float32, device="cuda")
if not get_pp_group().is_first_rank:
recv_tensor = get_pp_group().recv(size, dtype=torch.float32)
if not get_pp_group().is_last_rank:
get_pp_group().send(test_tensor)
if not get_pp_group().is_first_rank:
assert torch.allclose(test_tensor, recv_tensor)
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("tp_size", [2])
@@ -113,4 +174,13 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
broadcast_tensor_dict_test_worker
])
def test_multi_process_tensor_parallel(tp_size, test_target):
multi_process_tensor_parallel(tp_size, 1, test_target)
multi_process_parallel(tp_size, 1, test_target)
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("pp_size", [2])
@pytest.mark.parametrize(
"test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker])
def test_multi_process_pipeline_parallel(pp_size, test_target):
multi_process_parallel(1, pp_size, test_target)