[Core][Distributed] use cpu/gloo to initialize pynccl (#4248)
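Before any NCCL communicator can be created, every rank needs the same NCCL unique ID, and that ID has to be exchanged over a channel that works before the GPUs are involved. This change bootstraps the exchange over a CPU-side gloo process group: each test worker now calls init_distributed_environment() first, so the ID produced by ncclGetUniqueId can travel over the gloo group. A minimal sketch of the idea, not vLLM's implementation; broadcast_nccl_unique_id and its details are illustrative assumptions:

    import torch
    import torch.distributed as dist

    NCCL_UNIQUE_ID_BYTES = 128  # fixed size of ncclUniqueId

    def broadcast_nccl_unique_id(unique_id):
        # gloo is a CPU backend, so it works before any CUDA context
        # exists; assumes RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT are
        # already set in the environment.
        if not dist.is_initialized():
            dist.init_process_group(backend="gloo")
        buf = torch.zeros(NCCL_UNIQUE_ID_BYTES, dtype=torch.uint8)  # CPU tensor
        if dist.get_rank() == 0:
            # rank 0 passes the raw 128-byte id from ncclGetUniqueId;
            # other ranks pass None and receive it below
            buf.copy_(torch.frombuffer(bytearray(unique_id), dtype=torch.uint8))
        dist.broadcast(buf, src=0)  # runs on CPU via gloo
        return bytes(buf.tolist())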
@@ -5,6 +5,7 @@ import torch
 
 from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
                                                           ncclGetUniqueId)
+from vllm.distributed.parallel_state import init_distributed_environment
 from vllm.utils import update_environment_variables
 
@@ -26,19 +27,23 @@ def distributed_run(fn, world_size):
     for p in processes:
         p.join()
 
     for p in processes:
         assert p.exitcode == 0
 
 
-def update_env(fn):
+def worker_fn_wrapper(fn):
     # `multiprocessing.Process` cannot accept environment variables directly
     # so we need to pass the environment variables as arguments
     # and update the environment variables in the function
-    def wrapper(env):
+    def wrapped_fn(env):
         update_environment_variables(env)
+        init_distributed_environment()
         fn()
 
-    return wrapper
+    return wrapped_fn
 
 
-@update_env
+@worker_fn_wrapper
 def worker_fn():
     comm = NCCLCommunicator()
     tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
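Because the wrapped worker takes the env dict as its only argument, the spawner can build a per-rank environment and hand it to multiprocessing.Process as an ordinary argument. A sketch of how distributed_run presumably drives this; the hunk above only shows its tail, and the env keys here are the standard torch.distributed ones, assumed for illustration:

    import multiprocessing

    def distributed_run_sketch(fn, world_size):
        # `fn` is already decorated with worker_fn_wrapper, so it
        # accepts the env dict as a plain argument
        processes = []
        for rank in range(world_size):
            env = {
                'RANK': str(rank),
                'WORLD_SIZE': str(world_size),
                'MASTER_ADDR': 'localhost',
                'MASTER_PORT': '29500',
            }
            p = multiprocessing.Process(target=fn, args=(env, ))
            processes.append(p)
            p.start()
        for p in processes:
            p.join()
        for p in processes:
            # a non-zero exit code means the worker raised
            assert p.exitcode == 0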
@@ -53,7 +58,7 @@ def test_pynccl():
     distributed_run(worker_fn, 2)
 
 
-@update_env
+@worker_fn_wrapper
 def worker_fn_with_cudagraph():
     with torch.no_grad():
         graph = torch.cuda.CUDAGraph()
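Only the first lines of the cudagraph worker appear in this hunk. The pattern it exercises is running an NCCL collective under CUDA graph capture and then replaying it; a sketch of that pattern, where comm.all_reduce and comm.stream are assumptions about NCCLCommunicator's interface:

    import torch

    def cudagraph_allreduce_sketch(comm):
        with torch.no_grad():
            graph = torch.cuda.CUDAGraph()
            a = torch.ones((4, 4), device=f'cuda:{comm.rank}')
            torch.cuda.synchronize()
            # real capture code typically warms the ops up on a side
            # stream before capturing
            with torch.cuda.graph(graph, stream=comm.stream):
                comm.all_reduce(a)  # recorded into the graph, not run eagerly
            graph.replay()  # re-runs the captured allreduce
            torch.cuda.synchronize()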