[core][distributed] initialization from StatelessProcessGroup (#10986)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao
Date: 2024-12-12 01:04:19 -08:00
Committed by: GitHub
Parent: 8195824206
Commit: 62de37a38e

5 changed files with 153 additions and 69 deletions
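
The change teaches both call sites below to accept a StatelessProcessGroup, so the shared-memory MessageQueue and the node-locality check can be initialized without a global torch.distributed process group. A minimal usage sketch, assuming the StatelessProcessGroup.create(host, port, rank, world_size) factory in vllm.distributed.utils and MessageQueue's module path vllm.distributed.device_communicators.shm_broadcast; the host/port values and environment variables are placeholders, not part of this commit:

import os

from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
from vllm.distributed.utils import StatelessProcessGroup

# Run one copy per rank; RANK/WORLD_SIZE come from whatever launcher is used.
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

# TCP-based rendezvous instead of torch.distributed initialization.
pg = StatelessProcessGroup.create(host="127.0.0.1", port=5000,
                                  rank=rank, world_size=world_size)

# Rank 0 (the default writer_rank) allocates the shared-memory ring buffer
# and broadcasts its handle; the other ranks attach to it.
mq = MessageQueue.create_from_process_group(pg,
                                            max_chunk_bytes=1 << 20,
                                            max_chunks=8)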

vllm/distributed/device_communicators/shm_broadcast.py

@@ -5,7 +5,7 @@ import time
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 from multiprocessing import shared_memory
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
 from unittest.mock import patch

 import torch
@@ -15,6 +15,7 @@ from zmq import IPV6  # type: ignore
 from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context  # type: ignore

 import vllm.envs as envs
+from vllm.distributed.utils import StatelessProcessGroup
 from vllm.logger import init_logger
 from vllm.utils import get_ip, get_open_port, is_valid_ipv6_address
@@ -476,13 +477,19 @@ class MessageQueue:
         return self.dequeue()

     @staticmethod
-    def create_from_process_group(pg: ProcessGroup,
+    def create_from_process_group(pg: Union[ProcessGroup,
+                                            StatelessProcessGroup],
                                   max_chunk_bytes,
                                   max_chunks,
                                   writer_rank=0) -> "MessageQueue":
-        group_rank = dist.get_rank(pg)
-        group_world_size = dist.get_world_size(pg)
-        global_ranks = dist.get_process_group_ranks(pg)
+        if isinstance(pg, ProcessGroup):
+            group_rank = dist.get_rank(pg)
+            group_world_size = dist.get_world_size(pg)
+            global_ranks = dist.get_process_group_ranks(pg)
+        else:
+            group_rank = pg.rank
+            group_world_size = pg.world_size
+            global_ranks = list(range(pg.world_size))

         from vllm.distributed.parallel_state import in_the_same_node_as
         status = in_the_same_node_as(pg, source_rank=writer_rank)
@@ -500,15 +507,21 @@ class MessageQueue:
                 max_chunks=max_chunks,
             )
             handle = buffer_io.export_handle()
-            dist.broadcast_object_list([handle],
-                                       src=global_ranks[writer_rank],
-                                       group=pg)
+            if isinstance(pg, ProcessGroup):
+                dist.broadcast_object_list([handle],
+                                           src=global_ranks[writer_rank],
+                                           group=pg)
+            else:
+                pg.broadcast_obj(handle, writer_rank)
         else:
-            recv = [None]
-            dist.broadcast_object_list(recv,
-                                       src=global_ranks[writer_rank],
-                                       group=pg)
-            handle = recv[0]  # type: ignore
+            if isinstance(pg, ProcessGroup):
+                recv = [None]
+                dist.broadcast_object_list(recv,
+                                           src=global_ranks[writer_rank],
+                                           group=pg)
+                handle = recv[0]  # type: ignore
+            else:
+                handle = pg.broadcast_obj(None, writer_rank)
         buffer_io = MessageQueue.create_from_handle(handle, group_rank)
         buffer_io.wait_until_ready()
         return buffer_io
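
Distilled from the hunk above (not part of the commit): the handshake relies on StatelessProcessGroup.broadcast_obj(obj, src) returning the object on every rank, with non-source ranks passing None. A sketch under that assumption, with illustrative names:

def exchange_handle(pg, buffer_io, writer_rank=0):
    if pg.rank == writer_rank:
        # The writer owns the shared-memory buffer and publishes its handle.
        handle = buffer_io.export_handle()
        pg.broadcast_obj(handle, writer_rank)
        return handle
    # Readers pass None and receive the writer's handle.
    return pg.broadcast_obj(None, writer_rank)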

vllm/distributed/parallel_state.py

@@ -37,6 +37,7 @@ from torch.distributed import Backend, ProcessGroup

 import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer
 import vllm.envs as envs
+from vllm.distributed.utils import StatelessProcessGroup
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op, supports_custom_op
@@ -1191,25 +1192,31 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
         torch.cuda.empty_cache()


-def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
+def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
+                        source_rank: int = 0) -> List[bool]:
     """
     This is a collective operation that returns if each rank is in the same
     node as the source rank. It tests if processes are attached to the same
     memory system (shared access to shared memory).
     """
-    assert torch.distributed.get_backend(
-        pg) != torch.distributed.Backend.NCCL, (
-            "in_the_same_node_as should be tested with a non-NCCL group.")
-    # local rank inside the group
-    rank = torch.distributed.get_rank(group=pg)
-    world_size = torch.distributed.get_world_size(group=pg)
+    if isinstance(pg, ProcessGroup):
+        assert torch.distributed.get_backend(
+            pg) != torch.distributed.Backend.NCCL, (
+                "in_the_same_node_as should be tested with a non-NCCL group.")
+        # local rank inside the group
+        rank = torch.distributed.get_rank(group=pg)
+        world_size = torch.distributed.get_world_size(group=pg)
+        # global ranks of the processes in the group
+        ranks = torch.distributed.get_process_group_ranks(pg)
+    else:
+        rank = pg.rank
+        world_size = pg.world_size
+        ranks = list(range(world_size))

     # local tensor in each process to store the result
     is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)

-    # global ranks of the processes in the group
-    ranks = torch.distributed.get_process_group_ranks(pg)
-
     magic_message = b"magic_message"
     shm = None
@@ -1219,17 +1226,21 @@ def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
             # create a shared memory segment
             shm = shared_memory.SharedMemory(create=True, size=128)
             shm.buf[:len(magic_message)] = magic_message
-            torch.distributed.broadcast_object_list([shm.name],
-                                                    src=ranks[source_rank],
-                                                    group=pg)
+            if isinstance(pg, ProcessGroup):
+                torch.distributed.broadcast_object_list(
+                    [shm.name], src=ranks[source_rank], group=pg)
+            else:
+                pg.broadcast_obj(shm.name, src=source_rank)
             is_in_the_same_node[rank] = 1
         else:
             # try to open the shared memory segment
-            recv = [None]
-            torch.distributed.broadcast_object_list(recv,
-                                                    src=ranks[source_rank],
-                                                    group=pg)
-            name = recv[0]
+            if isinstance(pg, ProcessGroup):
+                recv = [None]
+                torch.distributed.broadcast_object_list(
+                    recv, src=ranks[source_rank], group=pg)
+                name = recv[0]
+            else:
+                name = pg.broadcast_obj(None, src=source_rank)
             # fix to https://stackoverflow.com/q/62748654/9191338
             # Python incorrectly tracks shared memory even if it is not
             # created by the process. The following patch is a workaround.
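
The locality probe itself is unchanged: the source rank creates a named shared-memory segment holding a magic string, and every other rank records whether it can open that name and read the same bytes. A standalone sketch of the reader-side check (stdlib only; names are illustrative):

from multiprocessing import shared_memory

MAGIC = b"magic_message"

def can_open_segment(name: str) -> bool:
    # Only ranks attached to the same memory system as the creator can
    # open the segment by name and observe the magic bytes.
    try:
        shm = shared_memory.SharedMemory(name=name)
    except FileNotFoundError:
        return False
    try:
        return bytes(shm.buf[:len(MAGIC)]) == MAGIC
    finally:
        shm.close()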
@@ -1244,12 +1255,23 @@ def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
         if shm:
             shm.close()

-    torch.distributed.barrier(group=pg)
+    if isinstance(pg, ProcessGroup):
+        torch.distributed.barrier(group=pg)
+    else:
+        pg.barrier()

     # clean up the shared memory segment
     with contextlib.suppress(OSError):
         if rank == source_rank and shm:
             shm.unlink()
-    torch.distributed.all_reduce(is_in_the_same_node, group=pg)

-    return [x == 1 for x in is_in_the_same_node.tolist()]
+    if isinstance(pg, ProcessGroup):
+        torch.distributed.all_reduce(is_in_the_same_node, group=pg)
+        aggregated_data = is_in_the_same_node
+    else:
+        aggregated_data = torch.zeros_like(is_in_the_same_node)
+        for i in range(world_size):
+            rank_data = pg.broadcast_obj(is_in_the_same_node, src=i)
+            aggregated_data += rank_data
+
+    return [x == 1 for x in aggregated_data.tolist()]
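
A StatelessProcessGroup has no all_reduce, so the stateless branch emulates the sum with one broadcast_obj round per rank; that is O(world_size) store round-trips, which is acceptable for a one-time initialization check. An illustrative call, with the same placeholder host/port and rank variables as the sketch at the top:

pg = StatelessProcessGroup.create(host="127.0.0.1", port=5000,
                                  rank=rank, world_size=world_size)
# One bool per rank: True if that rank is on the same node as source_rank.
same_node = in_the_same_node_as(pg, source_rank=0)
if all(same_node):
    # every rank shares the node, so the shared-memory path is usable
    ...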