[core][distributed] add zmq fallback for broadcasting large objects (#6183)
Author: youkaichao
Date: 2024-07-09 18:49:11 -07:00 (committed by GitHub)
Commit: da78caecfa (parent 2416b26e11)
6 changed files with 274 additions and 80 deletions
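
The two files excerpted below carry the change: the shared-memory ring buffer is kept for small objects, and a ZeroMQ publish-subscribe channel is added as a fallback for objects too large to fit in a ring-buffer chunk, with a REQ/REP handshake so the writer only publishes once every reader is connected. A minimal standalone sketch of that zmq pattern (illustrative only, not vLLM code; ports and payload are arbitrary):

    import pickle
    import threading

    import zmq

    def writer(pub_port, sync_port, obj):
        ctx = zmq.Context()
        pub = ctx.socket(zmq.PUB)
        pub.bind(f"tcp://*:{pub_port}")
        rep = ctx.socket(zmq.REP)
        rep.bind(f"tcp://*:{sync_port}")
        assert rep.recv() == b"READY"  # reader announced itself
        rep.send(b"READY")
        pub.send(pickle.dumps(obj))  # no size limit, unlike a shm chunk

    def reader(pub_port, sync_port):
        ctx = zmq.Context()
        sub = ctx.socket(zmq.SUB)
        sub.setsockopt_string(zmq.SUBSCRIBE, "")
        sub.connect(f"tcp://localhost:{pub_port}")
        req = ctx.socket(zmq.REQ)
        req.connect(f"tcp://localhost:{sync_port}")
        req.send(b"READY")  # handshake before the writer publishes
        assert req.recv() == b"READY"
        print(len(pickle.loads(sub.recv())))

    t = threading.Thread(target=writer, args=(5555, 5556, bytes(32_000_000)))
    t.start()
    reader(5555, 5556)
    t.join()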

vllm/distributed/device_communicators/custom_all_reduce.py

@@ -9,7 +9,7 @@ import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.distributed.device_communicators.custom_all_reduce_utils import (
     gpu_p2p_access_check)
-from vllm.distributed.parallel_state import is_in_the_same_node
+from vllm.distributed.parallel_state import in_the_same_node_as
 from vllm.logger import init_logger
 from vllm.utils import cuda_device_count_stateless, is_full_nvlink
@@ -64,7 +64,7 @@ class CustomAllreduce:
         assert dist.get_backend(group) != dist.Backend.NCCL, (
             "CustomAllreduce should be attached to a non-NCCL group.")
-        if not is_in_the_same_node(group):
+        if not all(in_the_same_node_as(group, source_rank=0)):
             # No need to initialize custom allreduce for multi-node case.
             logger.warning(
                 "Custom allreduce is disabled because this process group"

vllm/distributed/device_communicators/shm_broadcast.py

@@ -1,16 +1,19 @@
 import pickle
 import time
 from contextlib import contextmanager
+from dataclasses import dataclass, field
 from multiprocessing import shared_memory
-from typing import Optional
+from typing import List, Optional
 from unittest.mock import patch
 
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
+from zmq import PUB, REP, REQ, SUB, SUBSCRIBE, Context  # type: ignore
 
 import vllm.envs as envs
 from vllm.logger import init_logger
+from vllm.utils import get_ip, get_open_port
 
 VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
@@ -135,18 +138,183 @@ class ShmRingBuffer:
             yield buf
 
 
-class ShmRingBufferIO:
+@dataclass
+class Handle:
+    connect_ip: str
+    local_reader_ranks: List[int] = field(default_factory=list)
 
-    def __init__(self, buffer: ShmRingBuffer, reader_rank: int):
-        self.buffer = buffer
-        self.reader_rank = reader_rank
-        self._is_writer = self.reader_rank == -1
-        self._is_reader = not self._is_writer
-        if self._is_reader:
-            assert 0 <= self.reader_rank < buffer.n_reader, \
-                (f"Invalid reader rank {self.reader_rank} for buffer"
-                 f" created with {buffer.n_reader} readers")
-        self.current_idx = 0
+    buffer: Optional[ShmRingBuffer] = None
+    local_subscribe_port: Optional[int] = None
+    local_sync_port: Optional[int] = None
+    remote_subscribe_port: Optional[int] = None
+    remote_sync_port: Optional[int] = None
+
+
+class MessageQueue:
+
+    def __init__(
+        self,
+        n_reader,  # number of all readers
+        n_local_reader,  # number of local readers through shared memory
+        local_reader_ranks: Optional[List[int]] = None,
+        max_chunk_bytes: int = 1024 * 1024 * 10,
+        max_chunks: int = 10,
+        connect_ip: Optional[str] = None,
+    ):
+        if local_reader_ranks is None:
+            local_reader_ranks = list(range(n_local_reader))
+        else:
+            assert len(local_reader_ranks) == n_local_reader
+        self.n_local_reader = n_local_reader
+        n_remote_reader = n_reader - n_local_reader
+        self.n_remote_reader = n_remote_reader
+
+        if connect_ip is None:
+            connect_ip = get_ip()
+        context = Context()
+
+        if n_local_reader > 0:
+            # for local readers, we will:
+            # 1. create a shared memory ring buffer to communicate small data
+            # 2. create a publish-subscribe socket to communicate large data
+            self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes,
+                                        max_chunks)
+
+            self.local_socket = context.socket(PUB)
+            local_subscribe_port = get_open_port()
+            self.local_socket.bind(f"tcp://*:{local_subscribe_port}")
+
+            self.local_sync_socket = context.socket(REP)
+            local_sync_port = get_open_port()
+            self.local_sync_socket.bind(f"tcp://*:{local_sync_port}")
+            self.current_idx = 0
+        else:
+            self.buffer = None  # type: ignore
+            local_subscribe_port = None
+            local_sync_port = None
+            self.local_socket = None
+            self.local_sync_socket = None
+            self.current_idx = -1
+
+        if n_remote_reader > 0:
+            # for remote readers, we will:
+            # create a publish-subscribe socket to communicate large data
+            self.remote_socket = context.socket(PUB)
+            remote_subscribe_port = get_open_port()
+            self.remote_socket.bind(f"tcp://*:{remote_subscribe_port}")
+
+            self.remote_sync_socket = context.socket(REP)
+            remote_sync_port = get_open_port()
+            self.remote_sync_socket.bind(f"tcp://*:{remote_sync_port}")
+        else:
+            remote_subscribe_port = None
+            remote_sync_port = None
+            self.remote_socket = None
+            self.remote_sync_socket = None
+
+        self._is_writer = True
+        self._is_local_reader = False
+        self.local_reader_rank = -1
+        # rank does not matter for remote readers
+        self._is_remote_reader = False
+
+        self.handle = Handle(
+            connect_ip=connect_ip,
+            local_reader_ranks=local_reader_ranks,
+            buffer=self.buffer,
+            local_subscribe_port=local_subscribe_port,
+            local_sync_port=local_sync_port,
+            remote_subscribe_port=remote_subscribe_port,
+            remote_sync_port=remote_sync_port,
+        )
+
+    def export_handle(self) -> Handle:
+        return self.handle
+
+    @staticmethod
+    def create_from_handle(handle: Handle, rank) -> "MessageQueue":
+        self = MessageQueue.__new__(MessageQueue)
+        self.handle = handle
+        self._is_writer = False
+
+        context = Context()
+
+        if rank in handle.local_reader_ranks:
+            assert handle.buffer is not None
+            self.buffer = handle.buffer
+            self.current_idx = 0
+            self.local_reader_rank = handle.local_reader_ranks.index(rank)
+            self._is_local_reader = True
+            self._is_remote_reader = False
+
+            self.local_socket = context.socket(SUB)
+            self.local_socket.setsockopt_string(SUBSCRIBE, "")
+            self.local_socket.connect(
+                f"tcp://{handle.connect_ip}:{handle.local_subscribe_port}")
+
+            self.local_sync_socket = context.socket(REQ)
+            self.local_sync_socket.connect(
+                f"tcp://{handle.connect_ip}:{handle.local_sync_port}")
+
+            self.remote_socket = None
+            self.remote_sync_socket = None
+        else:
+            self.buffer = None  # type: ignore
+            self.current_idx = -1
+            self.local_reader_rank = -1
+            self._is_local_reader = False
+            self._is_remote_reader = True
+
+            self.local_socket = None
+            self.local_sync_socket = None
+
+            self.remote_socket = context.socket(SUB)
+            self.remote_socket.setsockopt_string(SUBSCRIBE, "")
+            self.remote_socket.connect(
+                f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}")
+
+            self.remote_sync_socket = context.socket(REQ)
+            self.remote_sync_socket.connect(
+                f"tcp://{handle.connect_ip}:{handle.remote_sync_port}")
+
+        return self
+
+    def wait_until_ready(self):
+        """This is a collective operation. All processes (including the
+        readers and the writer) should call this function.
+        """
+        if self._is_writer:
+            # wait for all readers to connect
+
+            # local readers
+            for i in range(self.n_local_reader):
+                recv = self.local_sync_socket.recv()
+                assert recv == b"READY"
+                self.local_sync_socket.send(b"READY")
+            if self.n_local_reader > 0:
+                self.local_socket.send(b"READY")
+
+            # remote readers
+            for i in range(self.n_remote_reader):
+                recv = self.remote_sync_socket.recv()
+                assert recv == b"READY"
+                self.remote_sync_socket.send(b"READY")
+            if self.n_remote_reader > 0:
+                self.remote_socket.send(b"READY")
+        elif self._is_local_reader:
+            self.local_sync_socket.send(b"READY")
+            recv = self.local_sync_socket.recv()
+            assert recv == b"READY"
+            recv = self.local_socket.recv()
+            assert recv == b"READY"
+        elif self._is_remote_reader:
+            self.remote_sync_socket.send(b"READY")
+            recv = self.remote_sync_socket.recv()
+            assert recv == b"READY"
+            recv = self.remote_socket.recv()
+            assert recv == b"READY"
 
     @contextmanager
     def acquire_write(self):
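
Taken together, the constructor, export_handle/create_from_handle, and wait_until_ready give the intended lifecycle: the writer builds the queue, ships the picklable Handle to every reader, and all processes meet in the readiness barrier. A hedged end-to-end sketch on a single node (assumes this module path is importable at this commit; all readers are local here):

    import multiprocessing as mp

    from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

    def reader_main(handle, rank):
        q = MessageQueue.create_from_handle(handle, rank)
        q.wait_until_ready()           # collective: every process calls it
        print(rank, q.dequeue())       # small object, via shared memory
        print(rank, len(q.dequeue()))  # large object, via the zmq fallback

    if __name__ == "__main__":
        writer = MessageQueue(n_reader=2, n_local_reader=2)
        readers = [
            mp.Process(target=reader_main, args=(writer.export_handle(), r))
            for r in range(2)
        ]
        for p in readers:
            p.start()
        writer.wait_until_ready()  # blocks until both readers handshake
        writer.enqueue({"msg": "hi"})            # fits in a 10 MiB chunk
        writer.enqueue(bytes(32 * 1024 * 1024))  # exceeds it, overflow path
        for p in readers:
            p.join()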
@@ -201,12 +369,12 @@ class ShmRingBufferIO:
 
     @contextmanager
     def acquire_read(self):
-        assert self._is_reader, "Only readers can acquire read"
+        assert self._is_local_reader, "Only readers can acquire read"
         start_time = time.monotonic()
         n_warning = 1
         while True:
             with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
-                read_flag = metadata_buffer[self.reader_rank + 1]
+                read_flag = metadata_buffer[self.local_reader_rank + 1]
                 written_flag = metadata_buffer[0]
                 if not written_flag or read_flag:
                     # this block is either
@@ -236,7 +404,7 @@
                 # caller has read from the buffer
                 # set the read flag
-                metadata_buffer[self.reader_rank + 1] = 1
+                metadata_buffer[self.local_reader_rank + 1] = 1
                 self.current_idx = (self.current_idx +
                                     1) % self.buffer.max_chunks
                 break
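
For context, from the ShmRingBuffer code above this hunk (unchanged by the commit): each chunk carries n_reader + 1 metadata bytes, byte 0 being the written flag and byte 1 + i the read flag of reader i. The rename reindexes those flags by local reader rank, since after this commit only same-node readers touch the shared memory. Illustrative flag check (toy, not vLLM code):

    metadata = bytearray(1 + 3)  # written flag + read flags for 3 readers
    local_reader_rank = 2

    metadata[0] = 1  # writer marks the chunk as written
    read_flag = metadata[local_reader_rank + 1]
    if metadata[0] and not read_flag:        # written, and not yet read by us
        metadata[local_reader_rank + 1] = 1  # mark as read after consuming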
@@ -244,21 +412,36 @@
     def enqueue(self, obj):
         assert self._is_writer, "Only writers can enqueue"
         serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
-        if len(serialized_obj) > self.buffer.max_chunk_bytes:
-            raise RuntimeError(
-                f"{len(serialized_obj)=} larger than the allowed value "
-                f"{self.buffer.max_chunk_bytes},"
-                "Please increase the max_chunk_bytes parameter.")
-        with self.acquire_write() as buf:
-            buf[:len(serialized_obj)] = serialized_obj
+        if self.n_local_reader > 0:
+            if len(serialized_obj) >= self.buffer.max_chunk_bytes:
+                with self.acquire_write() as buf:
+                    buf[0] = 1  # overflow
+                self.local_socket.send(serialized_obj)
+            else:
+                with self.acquire_write() as buf:
+                    buf[0] = 0  # not overflow
+                    buf[1:len(serialized_obj) + 1] = serialized_obj
+        if self.n_remote_reader > 0:
+            self.remote_socket.send(serialized_obj)
 
     def dequeue(self):
-        assert self._is_reader, "Only readers can dequeue"
-        with self.acquire_read() as buf:
-            # no need to know the size of serialized object
-            # pickle format itself contains the size information internally
-            # see https://docs.python.org/3/library/pickle.html
-            obj = pickle.loads(buf)
+        if self._is_local_reader:
+            overflow = False
+            with self.acquire_read() as buf:
+                overflow = buf[0] == 1
+                if not overflow:
+                    # no need to know the size of serialized object
+                    # pickle format contains the size information internally
+                    # see https://docs.python.org/3/library/pickle.html
+                    obj = pickle.loads(buf[1:])
+            if overflow:
+                recv = self.local_socket.recv()
+                obj = pickle.loads(recv)
+        elif self._is_remote_reader:
+            recv = self.remote_socket.recv()
+            obj = pickle.loads(recv)
+        else:
+            raise RuntimeError("Only readers can dequeue")
         return obj
 
     def broadcast_object(self, obj=None):
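
The enqueue/dequeue pair above is the heart of the fallback: byte 0 of each chunk is now an overflow flag and the payload starts at byte 1, so anything that cannot fit inline is published on the zmq socket instead of raising. A self-contained toy of that framing (plain bytearray, no sockets, illustrative names):

    import pickle

    MAX_CHUNK_BYTES = 64
    chunk = bytearray(MAX_CHUNK_BYTES)
    side_channel = []  # stands in for the zmq PUB/SUB socket

    def toy_enqueue(obj):
        data = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
        if len(data) >= MAX_CHUNK_BYTES:  # >=: byte 0 is reserved for the flag
            chunk[0] = 1                  # overflow, payload goes out of band
            side_channel.append(data)
        else:
            chunk[0] = 0                  # inline, payload follows the flag
            chunk[1:len(data) + 1] = data

    def toy_dequeue():
        if chunk[0] == 1:
            return pickle.loads(side_channel.pop(0))
        # pickle knows its own length, so trailing chunk bytes are ignored
        return pickle.loads(bytes(chunk[1:]))

    toy_enqueue("small")
    assert toy_dequeue() == "small"
    toy_enqueue(list(range(100)))  # too big for the 64-byte chunk
    assert toy_dequeue() == list(range(100))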
@@ -272,24 +455,36 @@ class ShmRingBufferIO:
     def create_from_process_group(pg: ProcessGroup,
                                   max_chunk_bytes,
                                   max_chunks,
-                                  writer_rank=0) -> "ShmRingBufferIO":
+                                  writer_rank=0) -> "MessageQueue":
         group_rank = dist.get_rank(pg)
         group_world_size = dist.get_world_size(pg)
-        ranks_inside_group = list(range(group_world_size))
         global_ranks = dist.get_process_group_ranks(pg)
+        from vllm.distributed.parallel_state import in_the_same_node_as
+        status = in_the_same_node_as(pg, source_rank=writer_rank)
+        same_node_ranks = [i for i, s in enumerate(status) if s]
         n_reader = group_world_size - 1
-        buffer: ShmRingBuffer
+        n_local_reader = len(same_node_ranks) - 1
+        local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
+        buffer_io: MessageQueue
         if group_rank == writer_rank:
-            buffer = ShmRingBuffer(n_reader, max_chunk_bytes, max_chunks)
-            dist.broadcast_object_list([buffer],
+            buffer_io = MessageQueue(
+                n_reader=n_reader,
+                n_local_reader=n_local_reader,
+                local_reader_ranks=local_reader_ranks,
+                max_chunk_bytes=max_chunk_bytes,
+                max_chunks=max_chunks,
+            )
+            handle = buffer_io.export_handle()
+            dist.broadcast_object_list([handle],
                                        src=global_ranks[writer_rank],
                                        group=pg)
-            return ShmRingBufferIO(buffer, -1)
         else:
             recv = [None]
             dist.broadcast_object_list(recv,
                                        src=global_ranks[writer_rank],
                                        group=pg)
-            buffer = recv[0]  # type: ignore
-            rest_ranks = [r for r in ranks_inside_group if r != writer_rank]
-            return ShmRingBufferIO(buffer, rest_ranks.index(group_rank))
+            handle = recv[0]  # type: ignore
+            buffer_io = MessageQueue.create_from_handle(handle, group_rank)
+        buffer_io.wait_until_ready()
+        return buffer_io
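
Hedged usage sketch for the helper above (assumes the processes are launched with torchrun so init_process_group can rendezvous via the environment; gloo because the object broadcast runs on CPU):

    import torch.distributed as dist

    from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

    dist.init_process_group(backend="gloo")
    pg = dist.group.WORLD
    q = MessageQueue.create_from_process_group(pg,
                                               max_chunk_bytes=1 << 20,
                                               max_chunks=10,
                                               writer_rank=0)
    if dist.get_rank(pg) == 0:
        q.enqueue(bytes(4 * 1024 * 1024))  # 4 MiB, exceeds the 1 MiB chunk
    else:
        print(dist.get_rank(pg), len(q.dequeue()))

    # wait_until_ready() is already called inside create_from_process_group,
    # so ranks can enqueue/dequeue immediately after it returns.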