[core][distributed] add stateless process group (#10216)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -2,13 +2,13 @@
|
||||
# Adapted from
|
||||
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
from typing import Sequence, Tuple
|
||||
import dataclasses
|
||||
import pickle
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import Any, Deque, Dict, Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.distributed.distributed_c10d import (Backend, PrefixStore,
|
||||
_get_default_timeout,
|
||||
is_nccl_available)
|
||||
from torch.distributed.rendezvous import rendezvous
|
||||
|
||||
import vllm.envs as envs
|
||||
@@ -91,69 +91,139 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int,
|
||||
return (start_layer, end_layer)
|
||||
|
||||
|
||||
def stateless_init_process_group(init_method: str, rank: int, world_size: int,
|
||||
backend: str) -> ProcessGroup:
|
||||
"""A replacement for `torch.distributed.init_process_group` that does not
|
||||
pollute the global state.
|
||||
@dataclasses.dataclass
|
||||
class StatelessProcessGroup:
|
||||
"""A dataclass to hold a metadata store, and the rank, world_size of the
|
||||
group. Only use it to communicate metadata between processes.
|
||||
For data-plane communication, create NCCL-related objects.
|
||||
"""
|
||||
prefix: str
|
||||
rank: int
|
||||
world_size: int
|
||||
store: torch._C._distributed_c10d.Store
|
||||
data_expiration_seconds: int = 3600 # 1 hour
|
||||
|
||||
If we have process A and process B called `torch.distributed.init_process_group`
|
||||
to form a group, and then we want to form another group with process A, B, C,
|
||||
D, it is not possible in PyTorch, because process A and process B have already
|
||||
formed a group, and process C and process D cannot join that group. This
|
||||
function is a workaround for this issue.
|
||||
# dst rank -> counter
|
||||
send_dst_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
|
||||
# src rank -> counter
|
||||
recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
|
||||
broadcast_send_counter: int = 0
|
||||
broadcast_recv_src_counter: Dict[int, int] = dataclasses.field(
|
||||
default_factory=dict)
|
||||
|
||||
`torch.distributed.init_process_group` is a global call, while this function
|
||||
is a stateless call. It will return a `ProcessGroup` object that can be used
|
||||
for collective communication. With this function, process A and process B
|
||||
can call `stateless_init_process_group` to form a group, and then process A, B,
|
||||
C, and D can call `stateless_init_process_group` to form another group.
|
||||
""" # noqa
|
||||
# A deque to store the data entries, with key and timestamp.
|
||||
entries: Deque[Tuple[str,
|
||||
float]] = dataclasses.field(default_factory=deque)
|
||||
|
||||
backend = Backend(backend) # it is basically string
|
||||
timeout = _get_default_timeout(backend)
|
||||
def __post_init__(self):
|
||||
assert self.rank < self.world_size
|
||||
self.send_dst_counter = {i: 0 for i in range(self.world_size)}
|
||||
self.recv_src_counter = {i: 0 for i in range(self.world_size)}
|
||||
self.broadcast_recv_src_counter = {
|
||||
i: 0
|
||||
for i in range(self.world_size)
|
||||
}
|
||||
|
||||
store, rank, world_size = next(
|
||||
rendezvous(init_method, rank, world_size, timeout=timeout))
|
||||
store.set_timeout(timeout)
|
||||
def send_obj(self, obj: Any, dst: int):
|
||||
"""Send an object to a destination rank."""
|
||||
self.expire_data()
|
||||
key = f"{self.prefix}/send_to/{dst}/{self.send_dst_counter[dst]}"
|
||||
self.store.set(key, pickle.dumps(obj))
|
||||
self.send_dst_counter[dst] += 1
|
||||
self.entries.append((key, time.time()))
|
||||
|
||||
group_rank = rank
|
||||
group_size = world_size
|
||||
def expire_data(self):
|
||||
"""Expire data that is older than `data_expiration_seconds` seconds."""
|
||||
while self.entries:
|
||||
# check the oldest entry
|
||||
key, timestamp = self.entries[0]
|
||||
if time.time() - timestamp > self.data_expiration_seconds:
|
||||
self.store.delete_key(key)
|
||||
self.entries.popleft()
|
||||
else:
|
||||
break
|
||||
|
||||
# Use a PrefixStore to avoid accidental overrides of keys used by
|
||||
# different systems (e.g. RPC) in case the store is multi-tenant.
|
||||
prefix_store = PrefixStore(init_method, store)
|
||||
def recv_obj(self, src: int) -> Any:
|
||||
"""Receive an object from a source rank."""
|
||||
obj = pickle.loads(
|
||||
self.store.get(
|
||||
f"{self.prefix}/send_to/{self.rank}/{self.recv_src_counter[src]}"
|
||||
))
|
||||
self.recv_src_counter[src] += 1
|
||||
return obj
|
||||
|
||||
pg_options = ProcessGroup.Options(backend=backend, timeout=timeout)
|
||||
def broadcast_obj(self, obj: Optional[Any], src: int) -> Any:
|
||||
"""Broadcast an object from a source rank to all other ranks.
|
||||
It does not clean up after all ranks have received the object.
|
||||
Use it for limited times, e.g., for initialization.
|
||||
"""
|
||||
if self.rank == src:
|
||||
self.expire_data()
|
||||
key = (f"{self.prefix}/broadcast_from/{src}/"
|
||||
f"{self.broadcast_send_counter}")
|
||||
self.store.set(key, pickle.dumps(obj))
|
||||
self.broadcast_send_counter += 1
|
||||
self.entries.append((key, time.time()))
|
||||
return obj
|
||||
else:
|
||||
key = (f"{self.prefix}/broadcast_from/{src}/"
|
||||
f"{self.broadcast_recv_src_counter[src]}")
|
||||
recv_obj = pickle.loads(self.store.get(key))
|
||||
self.broadcast_recv_src_counter[src] += 1
|
||||
return recv_obj
|
||||
|
||||
pg: ProcessGroup = ProcessGroup(
|
||||
prefix_store,
|
||||
group_rank,
|
||||
group_size,
|
||||
pg_options,
|
||||
)
|
||||
def all_gather_obj(self, obj: Any) -> list[Any]:
|
||||
"""All gather an object from all ranks."""
|
||||
gathered_objs = []
|
||||
for i in range(self.world_size):
|
||||
if i == self.rank:
|
||||
gathered_objs.append(obj)
|
||||
self.broadcast_obj(obj, src=self.rank)
|
||||
else:
|
||||
recv_obj = self.broadcast_obj(None, src=i)
|
||||
gathered_objs.append(recv_obj)
|
||||
return gathered_objs
|
||||
|
||||
if backend == "gloo":
|
||||
from torch.distributed.distributed_c10d import ProcessGroupGloo
|
||||
backend_class = ProcessGroupGloo(prefix_store,
|
||||
group_rank,
|
||||
group_size,
|
||||
timeout=timeout)
|
||||
backend_type = ProcessGroup.BackendType.GLOO
|
||||
device = torch.device("cpu")
|
||||
elif backend == "nccl":
|
||||
assert is_nccl_available()
|
||||
from torch.distributed.distributed_c10d import ProcessGroupNCCL
|
||||
def barrier(self):
|
||||
"""A barrier to synchronize all ranks."""
|
||||
for i in range(self.world_size):
|
||||
if i == self.rank:
|
||||
self.broadcast_obj(None, src=self.rank)
|
||||
else:
|
||||
self.broadcast_obj(None, src=i)
|
||||
|
||||
backend_options = ProcessGroupNCCL.Options()
|
||||
backend_options._timeout = timeout
|
||||
@staticmethod
|
||||
def create(
|
||||
init_method: str,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
data_expiration_seconds: int = 3600,
|
||||
) -> "StatelessProcessGroup":
|
||||
"""A replacement for `torch.distributed.init_process_group` that does not
|
||||
pollute the global state.
|
||||
|
||||
backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
|
||||
backend_options)
|
||||
backend_type = ProcessGroup.BackendType.NCCL
|
||||
device = torch.device("cuda")
|
||||
If we have process A and process B called `torch.distributed.init_process_group`
|
||||
to form a group, and then we want to form another group with process A, B, C,
|
||||
D, it is not possible in PyTorch, because process A and process B have already
|
||||
formed a group, and process C and process D cannot join that group. This
|
||||
function is a workaround for this issue.
|
||||
|
||||
backend_class._set_sequence_number_for_group()
|
||||
`torch.distributed.init_process_group` is a global call, while this function
|
||||
is a stateless call. It will return a `StatelessProcessGroup` object that can be
|
||||
used for exchanging metadata. With this function, process A and process B
|
||||
can call `StatelessProcessGroup.create` to form a group, and then process A, B,
|
||||
C, and D can call `StatelessProcessGroup.create` to form another group.
|
||||
""" # noqa
|
||||
from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT
|
||||
timeout = _DEFAULT_PG_TIMEOUT
|
||||
|
||||
pg._register_backend(device, backend_type, backend_class)
|
||||
store, rank, world_size = next(
|
||||
rendezvous(init_method, rank, world_size, timeout=timeout))
|
||||
store.set_timeout(timeout)
|
||||
|
||||
return pg
|
||||
return StatelessProcessGroup(
|
||||
prefix=init_method,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
store=store,
|
||||
data_expiration_seconds=data_expiration_seconds)
|
||||
|
||||
Reference in New Issue
Block a user