[core][distributed] add zmq fallback for broadcasting large objects (#6183)
vllm/distributed/device_communicators/custom_all_reduce.py

@@ -9,7 +9,7 @@ import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.distributed.device_communicators.custom_all_reduce_utils import (
     gpu_p2p_access_check)
-from vllm.distributed.parallel_state import is_in_the_same_node
+from vllm.distributed.parallel_state import in_the_same_node_as
 from vllm.logger import init_logger
 from vllm.utils import cuda_device_count_stateless, is_full_nvlink
 
@@ -64,7 +64,7 @@ class CustomAllreduce:
         assert dist.get_backend(group) != dist.Backend.NCCL, (
             "CustomAllreduce should be attached to a non-NCCL group.")
 
-        if not is_in_the_same_node(group):
+        if not all(in_the_same_node_as(group, source_rank=0)):
            # No need to initialize custom allreduce for multi-node case.
            logger.warning(
                "Custom allreduce is disabled because this process group"
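For context on this first change: the old `is_in_the_same_node(group)` returned a single boolean for the whole group, while the new `in_the_same_node_as(group, source_rank)` returns one flag per rank as seen from `source_rank`, so `all(...)` recovers the old "everyone shares my node" check. A minimal sketch of how the per-rank result is consumed; the flag values below are made up for illustration:

```python
# Hypothetical illustration of the API change (values are made up).
# Old API: one aggregate answer for the whole group.
#   is_in_the_same_node(group) -> True / False
# New API: a per-rank answer, relative to source_rank.
#   in_the_same_node_as(group, source_rank=0) -> [True, True, False, False]

status = [True, True, False, False]  # ranks 0 and 1 share a node with rank 0
same_node = all(status)              # old-style aggregate check
same_node_ranks = [i for i, s in enumerate(status) if s]

assert same_node is False
assert same_node_ranks == [0, 1]
```

The per-rank form is what lets the new MessageQueue below split readers into local (shared memory) and remote (zmq) groups.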
vllm/distributed/device_communicators/shm_broadcast.py

@@ -1,16 +1,19 @@
 import pickle
 import time
 from contextlib import contextmanager
+from dataclasses import dataclass, field
 from multiprocessing import shared_memory
-from typing import Optional
+from typing import List, Optional
 from unittest.mock import patch
 
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
+from zmq import PUB, REP, REQ, SUB, SUBSCRIBE, Context  # type: ignore
 
 import vllm.envs as envs
 from vllm.logger import init_logger
+from vllm.utils import get_ip, get_open_port
 
 VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
 
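The zmq socket types imported here map onto two roles: PUB/SUB carries the actual payload one-to-many, and REQ/REP gives a blocking handshake so the writer knows every subscriber is connected before it publishes (plain PUB/SUB silently drops messages sent before a subscriber has attached, the classic slow-joiner problem). Below is a minimal single-process sketch of that pairing, not vLLM code: the port numbers are arbitrary and threads stand in for separate reader processes.

```python
import threading

from zmq import PUB, REP, REQ, SUB, SUBSCRIBE, Context

ctx = Context()
pub = ctx.socket(PUB)
pub.bind("tcp://*:5701")      # payload channel (one-to-many)
rep = ctx.socket(REP)
rep.bind("tcp://*:5702")      # handshake channel (one-to-one)

def reader():
    rctx = Context()
    sub = rctx.socket(SUB)
    sub.setsockopt_string(SUBSCRIBE, "")   # subscribe to everything
    sub.connect("tcp://127.0.0.1:5701")
    req = rctx.socket(REQ)
    req.connect("tcp://127.0.0.1:5702")
    req.send(b"READY")                     # tell the writer we exist
    assert req.recv() == b"READY"          # wait for the acknowledgement
    print("reader got:", sub.recv())

t = threading.Thread(target=reader)
t.start()

assert rep.recv() == b"READY"              # block until the reader connects
rep.send(b"READY")
pub.send(b"hello large object")            # subscriber is attached by now
t.join()
```

This is the same REQ/REP-then-PUB sequence that `wait_until_ready()` in the hunk below runs once per reader.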
@@ -135,18 +138,183 @@ class ShmRingBuffer:
             yield buf
 
 
-class ShmRingBufferIO:
+@dataclass
+class Handle:
+    connect_ip: str
+    local_reader_ranks: List[int] = field(default_factory=list)
 
-    def __init__(self, buffer: ShmRingBuffer, reader_rank: int):
-        self.buffer = buffer
-        self.reader_rank = reader_rank
-        self._is_writer = self.reader_rank == -1
-        self._is_reader = not self._is_writer
-        if self._is_reader:
-            assert 0 <= self.reader_rank < buffer.n_reader, \
-                (f"Invalid reader rank {self.reader_rank} for buffer"
-                 f" created with {buffer.n_reader} readers")
-        self.current_idx = 0
+    buffer: Optional[ShmRingBuffer] = None
+    local_subscribe_port: Optional[int] = None
+    local_sync_port: Optional[int] = None
+    remote_subscribe_port: Optional[int] = None
+    remote_sync_port: Optional[int] = None
+
+
+class MessageQueue:
+
+    def __init__(
+        self,
+        n_reader,  # number of all readers
+        n_local_reader,  # number of local readers through shared memory
+        local_reader_ranks: Optional[List[int]] = None,
+        max_chunk_bytes: int = 1024 * 1024 * 10,
+        max_chunks: int = 10,
+        connect_ip: Optional[str] = None,
+    ):
+        if local_reader_ranks is None:
+            local_reader_ranks = list(range(n_local_reader))
+        else:
+            assert len(local_reader_ranks) == n_local_reader
+        self.n_local_reader = n_local_reader
+        n_remote_reader = n_reader - n_local_reader
+        self.n_remote_reader = n_remote_reader
+
+        if connect_ip is None:
+            connect_ip = get_ip()
+
+        context = Context()
+
+        if n_local_reader > 0:
+            # for local readers, we will:
+            # 1. create a shared memory ring buffer to communicate small data
+            # 2. create a publish-subscribe socket to communicate large data
+            self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes,
+                                        max_chunks)
+
+            self.local_socket = context.socket(PUB)
+            local_subscribe_port = get_open_port()
+            self.local_socket.bind(f"tcp://*:{local_subscribe_port}")
+
+            self.local_sync_socket = context.socket(REP)
+            local_sync_port = get_open_port()
+            self.local_sync_socket.bind(f"tcp://*:{local_sync_port}")
+            self.current_idx = 0
+
+        else:
+            self.buffer = None  # type: ignore
+            local_subscribe_port = None
+            local_sync_port = None
+            self.local_socket = None
+            self.local_sync_socket = None
+            self.current_idx = -1
+
+        if n_remote_reader > 0:
+            # for remote readers, we will:
+            # create a publish-subscribe socket to communicate large data
+            self.remote_socket = context.socket(PUB)
+            remote_subscribe_port = get_open_port()
+            self.remote_socket.bind(f"tcp://*:{remote_subscribe_port}")
+
+            self.remote_sync_socket = context.socket(REP)
+            remote_sync_port = get_open_port()
+            self.remote_sync_socket.bind(f"tcp://*:{remote_sync_port}")
+        else:
+            remote_subscribe_port = None
+            remote_sync_port = None
+            self.remote_socket = None
+            self.remote_sync_socket = None
+
+        self._is_writer = True
+        self._is_local_reader = False
+        self.local_reader_rank = -1
+        # rank does not matter for remote readers
+        self._is_remote_reader = False
+
+        self.handle = Handle(
+            connect_ip=connect_ip,
+            local_reader_ranks=local_reader_ranks,
+            buffer=self.buffer,
+            local_subscribe_port=local_subscribe_port,
+            local_sync_port=local_sync_port,
+            remote_subscribe_port=remote_subscribe_port,
+            remote_sync_port=remote_sync_port,
+        )
+
+    def export_handle(self) -> Handle:
+        return self.handle
+
+    @staticmethod
+    def create_from_handle(handle: Handle, rank) -> "MessageQueue":
+        self = MessageQueue.__new__(MessageQueue)
+        self.handle = handle
+        self._is_writer = False
+
+        context = Context()
+
+        if rank in handle.local_reader_ranks:
+            assert handle.buffer is not None
+            self.buffer = handle.buffer
+            self.current_idx = 0
+            self.local_reader_rank = handle.local_reader_ranks.index(rank)
+            self._is_local_reader = True
+            self._is_remote_reader = False
+
+            self.local_socket = context.socket(SUB)
+            self.local_socket.setsockopt_string(SUBSCRIBE, "")
+            self.local_socket.connect(
+                f"tcp://{handle.connect_ip}:{handle.local_subscribe_port}")
+
+            self.local_sync_socket = context.socket(REQ)
+            self.local_sync_socket.connect(
+                f"tcp://{handle.connect_ip}:{handle.local_sync_port}")
+
+            self.remote_socket = None
+            self.remote_sync_socket = None
+        else:
+            self.buffer = None  # type: ignore
+            self.current_idx = -1
+            self.local_reader_rank = -1
+            self._is_local_reader = False
+            self._is_remote_reader = True
+
+            self.local_socket = None
+            self.local_sync_socket = None
+
+            self.remote_socket = context.socket(SUB)
+            self.remote_socket.setsockopt_string(SUBSCRIBE, "")
+            self.remote_socket.connect(
+                f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}")
+
+            self.remote_sync_socket = context.socket(REQ)
+            self.remote_sync_socket.connect(
+                f"tcp://{handle.connect_ip}:{handle.remote_sync_port}")
+
+        return self
+
+    def wait_until_ready(self):
+        """This is a collective operation. All processes (including the
+        readers and the writer) should call this function.
+        """
+        if self._is_writer:
+            # wait for all readers to connect
+
+            # local readers
+            for i in range(self.n_local_reader):
+                recv = self.local_sync_socket.recv()
+                assert recv == b"READY"
+                self.local_sync_socket.send(b"READY")
+            if self.n_local_reader > 0:
+                self.local_socket.send(b"READY")
+
+            # remote readers
+            for i in range(self.n_remote_reader):
+                recv = self.remote_sync_socket.recv()
+                assert recv == b"READY"
+                self.remote_sync_socket.send(b"READY")
+            if self.n_remote_reader > 0:
+                self.remote_socket.send(b"READY")
+        elif self._is_local_reader:
+            self.local_sync_socket.send(b"READY")
+            recv = self.local_sync_socket.recv()
+            assert recv == b"READY"
+            recv = self.local_socket.recv()
+            assert recv == b"READY"
+        elif self._is_remote_reader:
+            self.remote_sync_socket.send(b"READY")
+            recv = self.remote_sync_socket.recv()
+            assert recv == b"READY"
+            recv = self.remote_socket.recv()
+            assert recv == b"READY"
 
     @contextmanager
     def acquire_write(self):
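The lifecycle implied by this hunk: the writer constructs the MessageQueue, exports a picklable Handle (the ports plus the shared-memory buffer descriptor), ships it to each reader over any side channel, each reader rebuilds its end with create_from_handle, and every process then joins the wait_until_ready() handshake. A hedged single-node sketch of that flow, assuming this commit's vllm is importable, that the Handle pickles across processes (the writer-side dist.broadcast_object_list in the last hunk relies on the same property), and using a multiprocessing.Queue as the side channel:

```python
from multiprocessing import Process, Queue

from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

def reader_main(handle_q, rank):
    handle = handle_q.get()                      # side channel for the handle
    mq = MessageQueue.create_from_handle(handle, rank)
    mq.wait_until_ready()                        # collective handshake
    print(f"reader {rank} got:", mq.dequeue())

if __name__ == "__main__":
    # one writer, two readers on the same node -> both use shared memory
    writer = MessageQueue(n_reader=2, n_local_reader=2)
    handle_q = Queue()
    procs = [Process(target=reader_main, args=(handle_q, r)) for r in (0, 1)]
    for p in procs:
        p.start()
        handle_q.put(writer.export_handle())
    writer.wait_until_ready()                    # blocks until readers connect
    writer.enqueue({"weights_version": 42})      # small -> rides the ring buffer
    for p in procs:
        p.join()
```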
@@ -201,12 +369,12 @@ class ShmRingBufferIO:
 
     @contextmanager
     def acquire_read(self):
-        assert self._is_reader, "Only readers can acquire read"
+        assert self._is_local_reader, "Only readers can acquire read"
         start_time = time.monotonic()
         n_warning = 1
         while True:
             with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
-                read_flag = metadata_buffer[self.reader_rank + 1]
+                read_flag = metadata_buffer[self.local_reader_rank + 1]
                 written_flag = metadata_buffer[0]
                 if not written_flag or read_flag:
                     # this block is either
@@ -236,7 +404,7 @@ class ShmRingBufferIO:
 
             # caller has read from the buffer
             # set the read flag
-            metadata_buffer[self.reader_rank + 1] = 1
+            metadata_buffer[self.local_reader_rank + 1] = 1
             self.current_idx = (self.current_idx +
                                 1) % self.buffer.max_chunks
             break
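Both renamed indices follow the same flag convention: each chunk's metadata block stores one written flag followed by one read flag per local reader, so `metadata_buffer[0]` says whether the chunk holds fresh data and `metadata_buffer[r + 1]` says whether local reader `r` has consumed it. A toy model of that layout (not vLLM code; the reset-on-write behavior is assumed from the acquire_write side, which this diff view does not show in full):

```python
# Toy model of one chunk's metadata block:
# byte 0             -> written flag, set by the writer
# bytes 1..n_readers -> one read flag per local reader
n_local_reader = 2
metadata = bytearray(1 + n_local_reader)

def writer_marks_written(md):
    md[0] = 1                      # fresh data available
    for r in range(n_local_reader):
        md[r + 1] = 0              # nobody has read it yet

def reader_consumes(md, local_reader_rank):
    assert md[0] == 1 and md[local_reader_rank + 1] == 0
    md[local_reader_rank + 1] = 1  # mark this reader done

writer_marks_written(metadata)
reader_consumes(metadata, 0)
reader_consumes(metadata, 1)
assert all(metadata[1:])           # chunk is now reusable by the writer
```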
@@ -244,21 +412,36 @@ class ShmRingBufferIO:
     def enqueue(self, obj):
         assert self._is_writer, "Only writers can enqueue"
         serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
-        if len(serialized_obj) > self.buffer.max_chunk_bytes:
-            raise RuntimeError(
-                f"{len(serialized_obj)=} larger than the allowed value "
-                f"{self.buffer.max_chunk_bytes},"
-                "Please increase the max_chunk_bytes parameter.")
-        with self.acquire_write() as buf:
-            buf[:len(serialized_obj)] = serialized_obj
+        if self.n_local_reader > 0:
+            if len(serialized_obj) >= self.buffer.max_chunk_bytes:
+                with self.acquire_write() as buf:
+                    buf[0] = 1  # overflow
+                self.local_socket.send(serialized_obj)
+            else:
+                with self.acquire_write() as buf:
+                    buf[0] = 0  # not overflow
+                    buf[1:len(serialized_obj) + 1] = serialized_obj
+        if self.n_remote_reader > 0:
+            self.remote_socket.send(serialized_obj)
 
     def dequeue(self):
-        assert self._is_reader, "Only readers can dequeue"
-        with self.acquire_read() as buf:
-            # no need to know the size of serialized object
-            # pickle format itself contains the size information internally
-            # see https://docs.python.org/3/library/pickle.html
-            obj = pickle.loads(buf)
+        if self._is_local_reader:
+            overflow = False
+            with self.acquire_read() as buf:
+                overflow = buf[0] == 1
+                if not overflow:
+                    # no need to know the size of serialized object
+                    # pickle format contains the size information internally
+                    # see https://docs.python.org/3/library/pickle.html
+                    obj = pickle.loads(buf[1:])
+            if overflow:
+                recv = self.local_socket.recv()
+                obj = pickle.loads(recv)
+        elif self._is_remote_reader:
+            recv = self.remote_socket.recv()
+            obj = pickle.loads(recv)
+        else:
+            raise RuntimeError("Only readers can dequeue")
        return obj
 
     def broadcast_object(self, obj=None):
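This hunk is the zmq fallback itself. For local readers the first byte of each shm chunk becomes an overflow marker: 0 means the pickled payload follows inline in the chunk, 1 means it was too big and must be fetched from the PUB socket instead (note the `>=` comparison, since the flag byte leaves only `max_chunk_bytes - 1` bytes of inline capacity). A small self-contained model of just the framing decision; the 64-byte limit and helper names are arbitrary:

```python
import pickle

MAX_CHUNK_BYTES = 64          # arbitrary toy limit; vLLM defaults to 10 MiB

def frame(obj, chunk: bytearray):
    """Return bytes to send over zmq, or None if obj fit in the chunk."""
    payload = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
    if len(payload) >= MAX_CHUNK_BYTES:
        chunk[0] = 1          # overflow: reader must recv() from the socket
        return payload
    chunk[0] = 0              # inline: payload lives in the chunk itself
    chunk[1:len(payload) + 1] = payload
    return None

def unframe(chunk: bytearray, socket_recv=lambda: b""):
    if chunk[0] == 1:
        return pickle.loads(socket_recv())
    # pickle knows its own length, so trailing bytes in the chunk are ignored
    return pickle.loads(bytes(chunk[1:]))

chunk = bytearray(MAX_CHUNK_BYTES)
assert frame("small", chunk) is None
assert unframe(chunk) == "small"

big = list(range(100))        # pickles to well over 64 bytes
overflow_payload = frame(big, chunk)
assert overflow_payload is not None and chunk[0] == 1
assert unframe(chunk, socket_recv=lambda: overflow_payload) == big
```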
@@ -272,24 +455,36 @@ class ShmRingBufferIO:
     def create_from_process_group(pg: ProcessGroup,
                                   max_chunk_bytes,
                                   max_chunks,
-                                  writer_rank=0) -> "ShmRingBufferIO":
+                                  writer_rank=0) -> "MessageQueue":
         group_rank = dist.get_rank(pg)
         group_world_size = dist.get_world_size(pg)
-        ranks_inside_group = list(range(group_world_size))
         global_ranks = dist.get_process_group_ranks(pg)
+
+        from vllm.distributed.parallel_state import in_the_same_node_as
+        status = in_the_same_node_as(pg, source_rank=writer_rank)
+        same_node_ranks = [i for i, s in enumerate(status) if s]
         n_reader = group_world_size - 1
-        buffer: ShmRingBuffer
+        n_local_reader = len(same_node_ranks) - 1
+        local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
+        buffer_io: MessageQueue
         if group_rank == writer_rank:
-            buffer = ShmRingBuffer(n_reader, max_chunk_bytes, max_chunks)
-            dist.broadcast_object_list([buffer],
+            buffer_io = MessageQueue(
+                n_reader=n_reader,
+                n_local_reader=n_local_reader,
+                local_reader_ranks=local_reader_ranks,
+                max_chunk_bytes=max_chunk_bytes,
+                max_chunks=max_chunks,
+            )
+            handle = buffer_io.export_handle()
+            dist.broadcast_object_list([handle],
                                        src=global_ranks[writer_rank],
                                        group=pg)
-            return ShmRingBufferIO(buffer, -1)
         else:
             recv = [None]
             dist.broadcast_object_list(recv,
                                        src=global_ranks[writer_rank],
                                        group=pg)
-            buffer = recv[0]  # type: ignore
-            rest_ranks = [r for r in ranks_inside_group if r != writer_rank]
-            return ShmRingBufferIO(buffer, rest_ranks.index(group_rank))
+            handle = recv[0]  # type: ignore
+            buffer_io = MessageQueue.create_from_handle(handle, group_rank)
+        buffer_io.wait_until_ready()
+        return buffer_io
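create_from_process_group is the piece callers see: every rank in a process group calls it, the writer rank broadcasts its Handle through the group, and each rank comes back with a ready MessageQueue (wait_until_ready() is already called before returning, per the hunk above). A hedged end-to-end sketch, assuming this commit's vllm is importable, a torchrun-style launch, and a gloo backend (CustomAllreduce above likewise requires a non-NCCL group for object broadcasts):

```python
# Launch with e.g.: torchrun --nproc-per-node=2 this_script.py
import torch.distributed as dist

from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

dist.init_process_group(backend="gloo")    # object broadcast needs non-NCCL
pg = dist.group.WORLD

mq = MessageQueue.create_from_process_group(
    pg,
    max_chunk_bytes=1024 * 1024,            # 1 MiB inline budget per chunk
    max_chunks=10,
    writer_rank=0,
)

if dist.get_rank() == 0:
    # small objects ride the ring buffer; oversized ones fall back to zmq
    mq.enqueue(["sampled", "token", "ids"])
else:
    print(dist.get_rank(), "received:", mq.dequeue())
```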