Signed-off-by: ilmarkov <markovilya197@gmail.com> Signed-off-by: Markov Ilya <markovilya19@gmail.com> Co-authored-by: Markov Ilya <markovilya19@gmail.com>
92 lines
2.8 KiB
Python
92 lines
2.8 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import atexit
|
|
import os
|
|
import random
|
|
|
|
import pytest
|
|
import torch
|
|
import torch.multiprocessing as mp
|
|
|
|
from vllm.config import VllmConfig, set_current_vllm_config
|
|
from vllm.distributed.parallel_state import (
|
|
init_distributed_environment,
|
|
)
|
|
from vllm.utils.system_utils import update_environment_variables
|
|
|
|
mp.set_start_method("spawn", force=True)
|
|
|
|
|
|
def _distributed_worker_wrapper(fn, env, world_size, args, rank, skip_queue):
|
|
try:
|
|
fn(env, world_size, *args)
|
|
except BaseException as exc:
|
|
if isinstance(exc, pytest.skip.Exception):
|
|
skip_queue.put((rank, str(exc)))
|
|
return
|
|
raise
|
|
|
|
|
|
def distributed_run(fn, world_size, *args):
|
|
number_of_processes = world_size
|
|
processes: list[mp.Process] = []
|
|
skip_queue: mp.SimpleQueue = mp.SimpleQueue()
|
|
for i in range(number_of_processes):
|
|
env: dict[str, str] = {}
|
|
env["RANK"] = str(i)
|
|
env["LOCAL_RANK"] = str(i)
|
|
env["WORLD_SIZE"] = str(number_of_processes)
|
|
env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
|
|
env["MASTER_ADDR"] = "localhost"
|
|
env["MASTER_PORT"] = "12345"
|
|
p = mp.Process(
|
|
target=_distributed_worker_wrapper,
|
|
args=(fn, env, world_size, args, i, skip_queue),
|
|
)
|
|
processes.append(p)
|
|
p.start()
|
|
|
|
for p in processes:
|
|
p.join()
|
|
|
|
skipped: list[tuple[int, str]] = []
|
|
while not skip_queue.empty():
|
|
rank, reason = skip_queue.get()
|
|
skipped.append((rank, reason))
|
|
|
|
if len(skipped) == number_of_processes:
|
|
reason = skipped[0][1]
|
|
pytest.skip(reason)
|
|
if 0 < len(skipped) < number_of_processes:
|
|
skipped_ranks = sorted(rank for rank, _ in skipped)
|
|
raise AssertionError(
|
|
"Distributed test had partial skips; expected either all ranks "
|
|
f"to skip or none. Skipped ranks: {skipped_ranks}, "
|
|
f"total ranks: {number_of_processes}"
|
|
)
|
|
|
|
for p in processes:
|
|
assert p.exitcode == 0
|
|
|
|
|
|
def set_env_vars_and_device(env: dict[str, str]) -> None:
|
|
update_environment_variables(env)
|
|
local_rank = os.environ["LOCAL_RANK"]
|
|
device = torch.device(f"cuda:{local_rank}")
|
|
torch.accelerator.set_device_index(device)
|
|
|
|
# Create a minimal vllm config for init_distributed_environment
|
|
vllm_config = VllmConfig()
|
|
with set_current_vllm_config(vllm_config):
|
|
init_distributed_environment()
|
|
atexit.register(_destroy_process_group_if_initialized)
|
|
# Ensure each worker process has the same random seed
|
|
random.seed(42)
|
|
torch.manual_seed(42)
|
|
|
|
|
|
def _destroy_process_group_if_initialized() -> None:
|
|
if torch.distributed.is_available() and torch.distributed.is_initialized():
|
|
torch.distributed.destroy_process_group()
|