[Core][Distributed] refactor custom allreduce to support multiple tp groups (#4754)

youkaichao authored 2024-05-12 17:47:59 -07:00; committed by GitHub
parent a7be4d0072
commit 702bee461f
10 changed files with 327 additions and 226 deletions


@@ -6,8 +6,10 @@ import ray
 import torch
 import torch.distributed as dist
 
-from vllm.distributed import tensor_model_parallel_all_reduce
-from vllm.distributed.device_communicators import custom_all_reduce
+from vllm.distributed.communication_op import (  # noqa
+    graph_capture, tensor_model_parallel_all_reduce)
+from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
+                                             get_tp_ca_communicator)
 from vllm.test_utils import (init_test_distributed_environment,
                              multi_process_tensor_parallel)
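For orientation: the refactor drops the module-level custom_all_reduce helpers in favor of per-group state, so tests now enter graph_capture() for CUDA-graph capture and fetch the tensor-parallel custom-allreduce communicator via get_tp_ca_communicator(). A minimal sketch of the new call pattern, assuming vLLM's distributed environment is already initialized; the tensor size and device string are illustrative, not from the commit:

# Sketch only: assumes torch.distributed / vLLM parallel state is already
# initialized (init_test_distributed_environment does this in the test below).
import torch
from vllm.distributed.communication_op import (
    graph_capture, tensor_model_parallel_all_reduce)

inp = torch.ones(1024, device="cuda")
with graph_capture():
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        # inside graph_capture(), the allreduce is routed through the current
        # rank's tensor-parallel group (custom allreduce kernel when available)
        out = tensor_model_parallel_all_reduce(inp)
graph.replay()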
@@ -18,17 +20,36 @@ for i, v in enumerate(test_sizes):
 @ray.remote(num_gpus=1, max_calls=1)
-def graph_allreduce(world_size, rank, distributed_init_port):
+def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
     del os.environ["CUDA_VISIBLE_DEVICES"]
     device = torch.device(f"cuda:{rank}")
     torch.cuda.set_device(device)
-    init_test_distributed_environment(1, world_size, rank,
+    init_test_distributed_environment(tp_size, pp_size, rank,
                                       distributed_init_port)
 
-    custom_all_reduce.init_custom_ar()
+    group = get_tensor_model_parallel_group()
+
+    # A small all_reduce for warmup.
+    # this is needed because device communicators might be created lazily
+    # (e.g. NCCL). This will ensure that the communicator is initialized
+    # before any communication happens, so that this group can be used for
+    # graph capture immediately.
+    data = torch.zeros(1)
+    data = data.to(device=device)
+    torch.distributed.all_reduce(data, group=group)
+    torch.cuda.synchronize()
+    del data
+
+    # we use the first group to communicate once
+    # and the second group to communicate twice
+    # and so on
+    # this is used to demonstrate that each group can
+    # communicate independently
+    num_communication = rank // tp_size + 1
+
     for sz in test_sizes:
         for dtype in [torch.float32, torch.float16, torch.bfloat16]:
-            with custom_all_reduce.capture():
+            with graph_capture():
                 # use integers so result matches NCCL exactly
                 inp1 = torch.randint(1,
                                      16, (sz, ),
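The num_communication = rank // tp_size + 1 line deliberately gives each tensor-parallel group a different number of allreduces inside the captured graph, which only passes if the TP groups really are independent of one another. A small worked example of that arithmetic, assuming tp_size=2 and pp_size=2 (four ranks; the print is purely illustrative):

# With tp_size=2 and pp_size=2, ranks 0-1 form the first TP group and
# ranks 2-3 form the second:
#   rank 0: 0 // 2 + 1 = 1   allreduce  (first group communicates once)
#   rank 1: 1 // 2 + 1 = 1   allreduce
#   rank 2: 2 // 2 + 1 = 2   allreduces (second group communicates twice)
#   rank 3: 3 // 2 + 1 = 2   allreduces
tp_size = 2
for rank in range(4):
    num_communication = rank // tp_size + 1
    print(f"rank {rank}: TP group {rank // tp_size}, {num_communication} allreduce(s)")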
@@ -41,44 +62,52 @@ def graph_allreduce(world_size, rank, distributed_init_port):
                 torch.cuda.synchronize()
                 graph = torch.cuda.CUDAGraph()
                 with torch.cuda.graph(graph):
-                    out1 = tensor_model_parallel_all_reduce(inp1)
-                    # the input buffer is immediately modified to test
-                    # synchronization
-                    dist.all_reduce(inp1)
-                    out2 = tensor_model_parallel_all_reduce(inp2)
-                    dist.all_reduce(inp2)
+                    for i in range(num_communication):
+                        out1 = tensor_model_parallel_all_reduce(inp1)
+                        # the input buffer is immediately modified to test
+                        # synchronization
+                        dist.all_reduce(inp1, group=group)
+                        out2 = tensor_model_parallel_all_reduce(inp2)
+                        dist.all_reduce(inp2, group=group)
             graph.replay()
             assert torch.allclose(out1, inp1)
             assert torch.allclose(out2, inp2)
 
 
 @ray.remote(num_gpus=1, max_calls=1)
-def eager_allreduce(world_size, rank, distributed_init_port):
+def eager_allreduce(tp_size, pp_size, rank, distributed_init_port):
     del os.environ["CUDA_VISIBLE_DEVICES"]
     device = torch.device(f"cuda:{rank}")
     torch.cuda.set_device(device)
-    init_test_distributed_environment(1, world_size, rank,
+    init_test_distributed_environment(tp_size, pp_size, rank,
                                       distributed_init_port)
 
+    # we use the first group to communicate once
+    # and the second group to communicate twice
+    # and so on
+    # this is used to demonstrate that each group can
+    # communicate independently
+    num_communication = rank // tp_size + 1
     sz = 1024
-    custom_all_reduce.init_custom_ar()
-    fa = custom_all_reduce.get_handle()
+    fa = get_tp_ca_communicator()
     inp = torch.ones(sz, dtype=torch.float32, device=device)
-    out = fa.all_reduce_unreg(inp)
-    assert torch.allclose(out, inp * world_size)
+    out = inp
+    for _ in range(num_communication):
+        out = fa.all_reduce_unreg(out)
+    assert torch.allclose(out, inp * (tp_size**num_communication))
 
     inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device)
-    out = fa.all_reduce_unreg(inp)
-    assert torch.allclose(out, inp * world_size)
+    out = inp
+    for _ in range(num_communication):
+        out = fa.all_reduce_unreg(out)
+    assert torch.allclose(out, inp * (tp_size**num_communication))
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
                     reason="Need at least 2 GPUs to run the test.")
-@pytest.mark.parametrize("tensor_parallel_size", [2])
+@pytest.mark.parametrize("tp_size", [2])
+@pytest.mark.parametrize("pipeline_parallel_size", [1, 2])
 @pytest.mark.parametrize("test_target", [eager_allreduce, graph_allreduce])
-def test_multi_process_tensor_parallel(tensor_parallel_size, test_target):
-    multi_process_tensor_parallel(tensor_parallel_size, test_target)
-
-
-if __name__ == "__main__":
-    multi_process_tensor_parallel(2, graph_allreduce)
+def test_custom_allreduce(tp_size, pipeline_parallel_size, test_target):
+    world_size = tp_size * pipeline_parallel_size
+    if world_size > torch.cuda.device_count():
+        pytest.skip("Not enough GPUs to run the test.")
+    multi_process_tensor_parallel(tp_size, pipeline_parallel_size, test_target)
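The rewritten eager-path assertion follows from the same counting: one allreduce of an all-ones tensor over a tp_size-rank group multiplies every element by tp_size, so chaining num_communication of them gives tp_size**num_communication. A quick self-contained check of that expectation (illustrative only, not part of the commit):

# Starting from ones, each allreduce over tp_size identical tensors scales the
# value by tp_size, so after k chained allreduces the value is tp_size ** k.
tp_size = 2
for num_communication in (1, 2):
    value = 1.0
    for _ in range(num_communication):
        value *= tp_size  # one allreduce over a group of tp_size ranks
    assert value == tp_size ** num_communication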