Implement custom all reduce kernels (#2192)

This commit is contained in:
Hanzhi Zhou
2024-01-28 04:46:35 +08:00
committed by GitHub
parent 220a47627b
commit 380170038e
18 changed files with 1453 additions and 65 deletions

View File

@@ -6,25 +6,13 @@ import pytest
import torch
import ray
from vllm.config import ParallelConfig
from vllm.utils import get_open_port
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce,
tensor_model_parallel_all_gather,
broadcast_tensor_dict,
)
from vllm.worker.worker import _init_distributed_environment
def init_test_distributed_environment(pipeline_parallel_size: int,
tensor_parallel_size: int, rank: int,
distributed_init_port: str):
parallel_config = ParallelConfig(pipeline_parallel_size,
tensor_parallel_size,
worker_use_ray=True)
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
_init_distributed_environment(parallel_config, rank,
distributed_init_method)
from vllm.test_utils import (init_test_distributed_environment,
multi_process_tensor_parallel)
@ray.remote(num_gpus=1, max_calls=1)
@@ -101,16 +89,4 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
broadcast_tensor_dict_test_worker
])
def test_multi_process_tensor_parallel(tensor_parallel_size, test_target):
# Using ray helps debugging the error when it failed
# as compared to multiprocessing.
ray.init()
distributed_init_port = get_open_port()
refs = []
for rank in range(tensor_parallel_size):
refs.append(
test_target.remote(tensor_parallel_size, rank,
distributed_init_port))
ray.get(refs)
ray.shutdown()
multi_process_tensor_parallel(tensor_parallel_size, test_target)