[EPLB] Add alternative communication for EPLB weight exchange (#33176)
Signed-off-by: ilmarkov <markovilya197@gmail.com> Signed-off-by: Markov Ilya <markovilya19@gmail.com> Co-authored-by: Markov Ilya <markovilya19@gmail.com>
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import atexit
|
||||
import os
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
@@ -16,9 +18,20 @@ from vllm.utils.system_utils import update_environment_variables
|
||||
mp.set_start_method("spawn", force=True)
|
||||
|
||||
|
||||
def _distributed_worker_wrapper(fn, env, world_size, args, rank, skip_queue):
|
||||
try:
|
||||
fn(env, world_size, *args)
|
||||
except BaseException as exc:
|
||||
if isinstance(exc, pytest.skip.Exception):
|
||||
skip_queue.put((rank, str(exc)))
|
||||
return
|
||||
raise
|
||||
|
||||
|
||||
def distributed_run(fn, world_size, *args):
|
||||
number_of_processes = world_size
|
||||
processes: list[mp.Process] = []
|
||||
skip_queue: mp.SimpleQueue = mp.SimpleQueue()
|
||||
for i in range(number_of_processes):
|
||||
env: dict[str, str] = {}
|
||||
env["RANK"] = str(i)
|
||||
@@ -27,13 +40,32 @@ def distributed_run(fn, world_size, *args):
|
||||
env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
|
||||
env["MASTER_ADDR"] = "localhost"
|
||||
env["MASTER_PORT"] = "12345"
|
||||
p = mp.Process(target=fn, args=(env, world_size, *args))
|
||||
p = mp.Process(
|
||||
target=_distributed_worker_wrapper,
|
||||
args=(fn, env, world_size, args, i, skip_queue),
|
||||
)
|
||||
processes.append(p)
|
||||
p.start()
|
||||
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
||||
skipped: list[tuple[int, str]] = []
|
||||
while not skip_queue.empty():
|
||||
rank, reason = skip_queue.get()
|
||||
skipped.append((rank, reason))
|
||||
|
||||
if len(skipped) == number_of_processes:
|
||||
reason = skipped[0][1]
|
||||
pytest.skip(reason)
|
||||
if 0 < len(skipped) < number_of_processes:
|
||||
skipped_ranks = sorted(rank for rank, _ in skipped)
|
||||
raise AssertionError(
|
||||
"Distributed test had partial skips; expected either all ranks "
|
||||
f"to skip or none. Skipped ranks: {skipped_ranks}, "
|
||||
f"total ranks: {number_of_processes}"
|
||||
)
|
||||
|
||||
for p in processes:
|
||||
assert p.exitcode == 0
|
||||
|
||||
@@ -48,7 +80,12 @@ def set_env_vars_and_device(env: dict[str, str]) -> None:
|
||||
vllm_config = VllmConfig()
|
||||
with set_current_vllm_config(vllm_config):
|
||||
init_distributed_environment()
|
||||
|
||||
atexit.register(_destroy_process_group_if_initialized)
|
||||
# Ensure each worker process has the same random seed
|
||||
random.seed(42)
|
||||
torch.manual_seed(42)
|
||||
|
||||
|
||||
def _destroy_process_group_if_initialized() -> None:
|
||||
if torch.distributed.is_available() and torch.distributed.is_initialized():
|
||||
torch.distributed.destroy_process_group()
|
||||
|
||||
Reference in New Issue
Block a user