[core][distributed] initialization from StatelessProcessGroup (#10986)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -5,7 +5,7 @@ import time
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from multiprocessing import shared_memory
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
@@ -15,6 +15,7 @@ from zmq import IPV6 # type: ignore
|
||||
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import get_ip, get_open_port, is_valid_ipv6_address
|
||||
|
||||
@@ -476,13 +477,19 @@ class MessageQueue:
|
||||
return self.dequeue()
|
||||
|
||||
@staticmethod
|
||||
def create_from_process_group(pg: ProcessGroup,
|
||||
def create_from_process_group(pg: Union[ProcessGroup,
|
||||
StatelessProcessGroup],
|
||||
max_chunk_bytes,
|
||||
max_chunks,
|
||||
writer_rank=0) -> "MessageQueue":
|
||||
group_rank = dist.get_rank(pg)
|
||||
group_world_size = dist.get_world_size(pg)
|
||||
global_ranks = dist.get_process_group_ranks(pg)
|
||||
if isinstance(pg, ProcessGroup):
|
||||
group_rank = dist.get_rank(pg)
|
||||
group_world_size = dist.get_world_size(pg)
|
||||
global_ranks = dist.get_process_group_ranks(pg)
|
||||
else:
|
||||
group_rank = pg.rank
|
||||
group_world_size = pg.world_size
|
||||
global_ranks = list(range(pg.world_size))
|
||||
|
||||
from vllm.distributed.parallel_state import in_the_same_node_as
|
||||
status = in_the_same_node_as(pg, source_rank=writer_rank)
|
||||
@@ -500,15 +507,21 @@ class MessageQueue:
|
||||
max_chunks=max_chunks,
|
||||
)
|
||||
handle = buffer_io.export_handle()
|
||||
dist.broadcast_object_list([handle],
|
||||
src=global_ranks[writer_rank],
|
||||
group=pg)
|
||||
if isinstance(pg, ProcessGroup):
|
||||
dist.broadcast_object_list([handle],
|
||||
src=global_ranks[writer_rank],
|
||||
group=pg)
|
||||
else:
|
||||
pg.broadcast_obj(handle, writer_rank)
|
||||
else:
|
||||
recv = [None]
|
||||
dist.broadcast_object_list(recv,
|
||||
src=global_ranks[writer_rank],
|
||||
group=pg)
|
||||
handle = recv[0] # type: ignore
|
||||
if isinstance(pg, ProcessGroup):
|
||||
recv = [None]
|
||||
dist.broadcast_object_list(recv,
|
||||
src=global_ranks[writer_rank],
|
||||
group=pg)
|
||||
handle = recv[0] # type: ignore
|
||||
else:
|
||||
handle = pg.broadcast_obj(None, writer_rank)
|
||||
buffer_io = MessageQueue.create_from_handle(handle, group_rank)
|
||||
buffer_io.wait_until_ready()
|
||||
return buffer_io
|
||||
|
||||
@@ -37,6 +37,7 @@ from torch.distributed import Backend, ProcessGroup
|
||||
|
||||
import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op, supports_custom_op
|
||||
@@ -1191,25 +1192,31 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
|
||||
def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
|
||||
source_rank: int = 0) -> List[bool]:
|
||||
"""
|
||||
This is a collective operation that returns whether each rank is in the same node
|
||||
as the source rank. It tests if processes are attached to the same
|
||||
memory system (shared access to shared memory).
|
||||
"""
|
||||
assert torch.distributed.get_backend(
|
||||
pg) != torch.distributed.Backend.NCCL, (
|
||||
"in_the_same_node_as should be tested with a non-NCCL group.")
|
||||
# local rank inside the group
|
||||
rank = torch.distributed.get_rank(group=pg)
|
||||
world_size = torch.distributed.get_world_size(group=pg)
|
||||
if isinstance(pg, ProcessGroup):
|
||||
assert torch.distributed.get_backend(
|
||||
pg) != torch.distributed.Backend.NCCL, (
|
||||
"in_the_same_node_as should be tested with a non-NCCL group.")
|
||||
# local rank inside the group
|
||||
rank = torch.distributed.get_rank(group=pg)
|
||||
world_size = torch.distributed.get_world_size(group=pg)
|
||||
|
||||
# global ranks of the processes in the group
|
||||
ranks = torch.distributed.get_process_group_ranks(pg)
|
||||
else:
|
||||
rank = pg.rank
|
||||
world_size = pg.world_size
|
||||
ranks = list(range(world_size))
|
||||
|
||||
# local tensor in each process to store the result
|
||||
is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)
|
||||
|
||||
# global ranks of the processes in the group
|
||||
ranks = torch.distributed.get_process_group_ranks(pg)
|
||||
|
||||
magic_message = b"magic_message"
|
||||
shm = None
|
||||
|
||||
@@ -1219,17 +1226,21 @@ def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
|
||||
# create a shared memory segment
|
||||
shm = shared_memory.SharedMemory(create=True, size=128)
|
||||
shm.buf[:len(magic_message)] = magic_message
|
||||
torch.distributed.broadcast_object_list([shm.name],
|
||||
src=ranks[source_rank],
|
||||
group=pg)
|
||||
if isinstance(pg, ProcessGroup):
|
||||
torch.distributed.broadcast_object_list(
|
||||
[shm.name], src=ranks[source_rank], group=pg)
|
||||
else:
|
||||
pg.broadcast_obj(shm.name, src=source_rank)
|
||||
is_in_the_same_node[rank] = 1
|
||||
else:
|
||||
# try to open the shared memory segment
|
||||
recv = [None]
|
||||
torch.distributed.broadcast_object_list(recv,
|
||||
src=ranks[source_rank],
|
||||
group=pg)
|
||||
name = recv[0]
|
||||
if isinstance(pg, ProcessGroup):
|
||||
recv = [None]
|
||||
torch.distributed.broadcast_object_list(
|
||||
recv, src=ranks[source_rank], group=pg)
|
||||
name = recv[0]
|
||||
else:
|
||||
name = pg.broadcast_obj(None, src=source_rank)
|
||||
# fix to https://stackoverflow.com/q/62748654/9191338
|
||||
# Python incorrectly tracks shared memory even if it is not
|
||||
# created by the process. The following patch is a workaround.
|
||||
@@ -1244,12 +1255,23 @@ def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
|
||||
if shm:
|
||||
shm.close()
|
||||
|
||||
torch.distributed.barrier(group=pg)
|
||||
if isinstance(pg, ProcessGroup):
|
||||
torch.distributed.barrier(group=pg)
|
||||
else:
|
||||
pg.barrier()
|
||||
|
||||
# clean up the shared memory segment
|
||||
with contextlib.suppress(OSError):
|
||||
if rank == source_rank and shm:
|
||||
shm.unlink()
|
||||
torch.distributed.all_reduce(is_in_the_same_node, group=pg)
|
||||
|
||||
return [x == 1 for x in is_in_the_same_node.tolist()]
|
||||
if isinstance(pg, ProcessGroup):
|
||||
torch.distributed.all_reduce(is_in_the_same_node, group=pg)
|
||||
aggregated_data = is_in_the_same_node
|
||||
else:
|
||||
aggregated_data = torch.zeros_like(is_in_the_same_node)
|
||||
for i in range(world_size):
|
||||
rank_data = pg.broadcast_obj(is_in_the_same_node, src=i)
|
||||
aggregated_data += rank_data
|
||||
|
||||
return [x == 1 for x in aggregated_data.tolist()]
|
||||
|
||||
Reference in New Issue
Block a user