[Core][Refactor] move parallel_utils into vllm/distributed (#3950)

[WIP][Core][Refactor] move vllm/model_executor/parallel_utils into vllm/distributed and vllm/device_communicators (#3950)
This commit is contained in:
youkaichao
2024-04-10 15:33:30 -07:00
committed by GitHub
parent 934d3662f7
commit 63e7176f26
52 changed files with 111 additions and 141 deletions

View File

@@ -6,9 +6,8 @@ import ray
import torch
import torch.distributed as dist
from vllm.model_executor.parallel_utils import custom_all_reduce as custom_ar
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce)
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.device_communicators import custom_all_reduce
from vllm.test_utils import (init_test_distributed_environment,
multi_process_tensor_parallel)
@@ -26,10 +25,10 @@ def graph_allreduce(world_size, rank, distributed_init_port):
init_test_distributed_environment(1, world_size, rank,
distributed_init_port)
custom_ar.init_custom_ar()
custom_all_reduce.init_custom_all_reduce()
for sz in test_sizes:
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
with custom_ar.capture():
with custom_all_reduce.capture():
# use integers so result matches NCCL exactly
inp1 = torch.randint(1,
16, (sz, ),
@@ -62,8 +61,8 @@ def eager_allreduce(world_size, rank, distributed_init_port):
distributed_init_port)
sz = 1024
custom_ar.init_custom_ar()
fa = custom_ar.get_handle()
custom_all_reduce.init_custom_all_reduce()
fa = custom_all_reduce.get_handle()
inp = torch.ones(sz, dtype=torch.float32, device=device)
out = fa.all_reduce_unreg(inp)
assert torch.allclose(out, inp * world_size)