[core][distributed] add stateless_init_process_group (#10072)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao
Date: 2024-11-06 16:42:09 -08:00 (committed by GitHub)
Parent: 74f2f8a0f1
Commit: 719c1ca468
3 changed files with 147 additions and 3 deletions

@@ -5,6 +5,11 @@
from typing import Sequence, Tuple

import torch
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import (Backend, PrefixStore,
                                                _get_default_timeout,
                                                is_nccl_available)
from torch.distributed.rendezvous import rendezvous

import vllm.envs as envs
from vllm.logger import init_logger
@@ -84,3 +89,71 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int,
            end_layer = num_hidden_layers

    return (start_layer, end_layer)


def stateless_init_process_group(init_method: str, rank: int, world_size: int,
                                 backend: str) -> ProcessGroup:
    """A replacement for `torch.distributed.init_process_group` that does not
    pollute the global state.

    If process A and process B call `torch.distributed.init_process_group` to
    form a group, and we then want to form another group with processes A, B,
    C, and D, that is not possible in PyTorch, because A and B have already
    formed a group and C and D cannot join it. This function is a workaround
    for that limitation.

    `torch.distributed.init_process_group` mutates global state, while this
    function is stateless: it returns a `ProcessGroup` object that can be used
    for collective communication. With this function, processes A and B can
    call `stateless_init_process_group` to form one group, and processes A, B,
    C, and D can then call `stateless_init_process_group` to form another.
    """ # noqa
    backend = Backend(backend)  # it is basically a string
    timeout = _get_default_timeout(backend)

    store, rank, world_size = next(
        rendezvous(init_method, rank, world_size, timeout=timeout))
    store.set_timeout(timeout)

    group_rank = rank
    group_size = world_size

    # Use a PrefixStore to avoid accidental overrides of keys used by
    # different systems (e.g. RPC) in case the store is multi-tenant.
    prefix_store = PrefixStore(init_method, store)

    pg_options = ProcessGroup.Options(backend=backend, timeout=timeout)

    pg: ProcessGroup = ProcessGroup(
        prefix_store,
        group_rank,
        group_size,
        pg_options,
    )

    if backend == "gloo":
        from torch.distributed.distributed_c10d import ProcessGroupGloo
        backend_class = ProcessGroupGloo(prefix_store,
                                         group_rank,
                                         group_size,
                                         timeout=timeout)
        backend_type = ProcessGroup.BackendType.GLOO
        device = torch.device("cpu")
    elif backend == "nccl":
        assert is_nccl_available()
        from torch.distributed.distributed_c10d import ProcessGroupNCCL
        backend_options = ProcessGroupNCCL.Options()
        backend_options._timeout = timeout
        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
                                         backend_options)
        backend_type = ProcessGroup.BackendType.NCCL
        device = torch.device("cuda")

    backend_class._set_sequence_number_for_group()

    pg._register_backend(device, backend_type, backend_class)

    return pg
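
For illustration only (not part of the diff): a minimal usage sketch of the scenario the docstring describes, where ranks 0 and 1 first form a two-process group and all four ranks then form a second, independent group. The import path `vllm.distributed.utils`, the TCP ports, and the `worker`/`mp.spawn` driver are assumptions for this example; the collective pattern follows the gloo path above.

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from vllm.distributed.utils import stateless_init_process_group  # assumed path


def worker(rank: int):
    small_pg = None
    if rank <= 1:
        # "process A and process B" form their own group first
        small_pg = stateless_init_process_group(
            init_method="tcp://127.0.0.1:29500",  # assumed free port
            rank=rank,
            world_size=2,
            backend="gloo")
    # all four processes then form a second, independent group
    big_pg = stateless_init_process_group(
        init_method="tcp://127.0.0.1:29501",  # assumed free port
        rank=rank,
        world_size=4,
        backend="gloo")

    data = torch.tensor([float(rank)])
    dist.all_reduce(data, group=big_pg)   # 0 + 1 + 2 + 3 = 6 on every rank
    if rank <= 1:
        dist.all_reduce(data, group=small_pg)  # 6 + 6 = 12 on ranks 0 and 1
    print(f"rank {rank}: {data.item()}")


if __name__ == "__main__":
    mp.spawn(worker, nprocs=4)

Because each call builds its own store and `ProcessGroup`, neither group touches the default group that `torch.distributed.init_process_group` would have created, so the two groups coexist within the same processes.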