[Feature] Pipeline Parallel Async send/recv, 2.9% E2E throughput improvement (#33368)
Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -19,6 +19,8 @@ from vllm.distributed import (
|
||||
tensor_model_parallel_all_reduce,
|
||||
tensor_model_parallel_reduce_scatter,
|
||||
)
|
||||
from vllm.distributed.parallel_state import GroupCoordinator, TensorMetadata
|
||||
from vllm.v1.worker.gpu_worker import AsyncIntermediateTensors
|
||||
|
||||
from ..utils import (
|
||||
init_test_distributed_environment,
|
||||
@@ -200,6 +202,111 @@ def send_recv_tensor_dict_test_worker(
|
||||
torch.testing.assert_close(recv_dict["f"], test_dict["f"])
|
||||
|
||||
|
||||
class _DummyWork:
|
||||
def __init__(self) -> None:
|
||||
self.wait_calls = 0
|
||||
|
||||
def wait(self) -> None:
|
||||
self.wait_calls += 1
|
||||
|
||||
|
||||
class _DummyAllGatherGroup:
|
||||
def __init__(self, world_size: int, rank_in_group: int) -> None:
|
||||
self.world_size = world_size
|
||||
self.rank_in_group = rank_in_group
|
||||
|
||||
def all_gather(self, t: torch.Tensor, dim: int = 0) -> torch.Tensor:
|
||||
# duplicate local slice across ranks.
|
||||
assert dim == 0
|
||||
return torch.cat([t for _ in range(self.world_size)], dim=0)
|
||||
|
||||
|
||||
def _make_group_for_unit_test(
|
||||
rank_in_group: int = 0, world_size: int = 2
|
||||
) -> GroupCoordinator:
|
||||
# avoid running GroupCoordinator.__init__ (it wires up real process groups).
|
||||
g = GroupCoordinator.__new__(GroupCoordinator)
|
||||
g.world_size = world_size
|
||||
g.rank_in_group = rank_in_group
|
||||
g.ranks = list(range(world_size))
|
||||
g.use_cpu_custom_send_recv = False
|
||||
g.device_group = None
|
||||
g.cpu_group = None
|
||||
return g
|
||||
|
||||
|
||||
def test_irecv_tensor_dict_send_allgather_postprocess_binds_keys(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
def fake_irecv(t: torch.Tensor, *args: Any, **kwargs: Any) -> _DummyWork:
|
||||
t.fill_(1)
|
||||
return _DummyWork()
|
||||
|
||||
monkeypatch.setattr(torch.distributed, "is_initialized", lambda: True)
|
||||
monkeypatch.setattr(torch.distributed, "irecv", fake_irecv)
|
||||
|
||||
g = _make_group_for_unit_test(rank_in_group=0, world_size=2)
|
||||
# 2 tensors so we can catch late-binding bugs in postprocess closures.
|
||||
metadata_list = [
|
||||
("a", TensorMetadata("cpu", torch.int32, torch.Size([4]))),
|
||||
("b", TensorMetadata("cpu", torch.int32, torch.Size([4]))),
|
||||
]
|
||||
g.recv_object = lambda src=None: metadata_list # type: ignore[method-assign]
|
||||
|
||||
ag = _DummyAllGatherGroup(world_size=2, rank_in_group=0)
|
||||
td, handles, postprocess = g.irecv_tensor_dict(all_gather_group=ag)
|
||||
|
||||
assert td is not None
|
||||
assert len(handles) == 2
|
||||
assert len(postprocess) == 2
|
||||
|
||||
# before postprocess, dict holds the TP slice (shape 2).
|
||||
assert td["a"].shape == torch.Size([2])
|
||||
assert td["b"].shape == torch.Size([2])
|
||||
|
||||
# simulate worker-side "defer wait": wait + postprocess later.
|
||||
for handle in handles:
|
||||
handle.wait()
|
||||
for fn in postprocess:
|
||||
fn()
|
||||
|
||||
# after postprocess, dict values are reconstructed to full shape (shape 4),
|
||||
# and each key should be updated independently
|
||||
assert td["a"].shape == torch.Size([4])
|
||||
assert td["b"].shape == torch.Size([4])
|
||||
torch.testing.assert_close(td["a"], torch.ones(4, dtype=torch.int32))
|
||||
torch.testing.assert_close(td["b"], torch.ones(4, dtype=torch.int32))
|
||||
|
||||
|
||||
def test_async_intermediate_tensors_lazy_wait() -> None:
|
||||
work = _DummyWork()
|
||||
post_calls = {"n": 0}
|
||||
|
||||
def post() -> None:
|
||||
post_calls["n"] += 1
|
||||
|
||||
it = AsyncIntermediateTensors(
|
||||
{"x": torch.tensor([1])},
|
||||
comm_handles=[work],
|
||||
comm_postprocess=[post],
|
||||
)
|
||||
|
||||
# accessing non-tensor attributes should not trigger wait.
|
||||
assert it.kv_connector_output is None
|
||||
assert work.wait_calls == 0
|
||||
assert post_calls["n"] == 0
|
||||
|
||||
# first access of `.tensors` triggers wait + postprocess.
|
||||
_ = it.tensors
|
||||
assert work.wait_calls == 1
|
||||
assert post_calls["n"] == 1
|
||||
|
||||
# subsequent access should not re-wait.
|
||||
_ = it.tensors
|
||||
assert work.wait_calls == 1
|
||||
assert post_calls["n"] == 1
|
||||
|
||||
|
||||
@ray.remote(num_gpus=1, max_calls=1)
|
||||
def send_recv_test_worker(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
|
||||
@@ -33,7 +33,7 @@ from contextlib import contextmanager, nullcontext
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from multiprocessing import shared_memory
|
||||
from typing import Any
|
||||
from typing import Any, Protocol
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
@@ -64,6 +64,14 @@ class GraphCaptureContext:
|
||||
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
|
||||
|
||||
|
||||
class Handle(Protocol):
|
||||
"""Minimal async work handle used by P2P send/recv methods."""
|
||||
|
||||
def is_completed(self) -> bool: ...
|
||||
|
||||
def wait(self) -> None: ...
|
||||
|
||||
|
||||
def _split_tensor_dict(
|
||||
tensor_dict: dict[str, torch.Tensor | Any],
|
||||
) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]:
|
||||
@@ -780,6 +788,20 @@ class GroupCoordinator:
|
||||
async_handle.wait()
|
||||
return tensor_dict
|
||||
|
||||
def _should_use_all_gather(
|
||||
self,
|
||||
key: str,
|
||||
numel: int,
|
||||
all_gather_group: "GroupCoordinator | None",
|
||||
all_gather_tensors: dict[str, bool] | None,
|
||||
) -> bool:
|
||||
if all_gather_group is None:
|
||||
return False
|
||||
use_all_gather = numel % all_gather_group.world_size == 0
|
||||
if all_gather_tensors is not None:
|
||||
use_all_gather = all_gather_tensors.get(key, use_all_gather)
|
||||
return use_all_gather
|
||||
|
||||
def send_tensor_dict(
|
||||
self,
|
||||
tensor_dict: dict[str, torch.Tensor | Any],
|
||||
@@ -808,6 +830,35 @@ class GroupCoordinator:
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if not torch.distributed.is_initialized() or self.world_size == 1:
|
||||
return tensor_dict
|
||||
handles = self.isend_tensor_dict(
|
||||
tensor_dict,
|
||||
dst=dst,
|
||||
all_gather_group=all_gather_group,
|
||||
all_gather_tensors=all_gather_tensors,
|
||||
)
|
||||
for handle in handles:
|
||||
handle.wait()
|
||||
return None
|
||||
|
||||
def isend_tensor_dict(
|
||||
self,
|
||||
tensor_dict: dict[str, torch.Tensor | Any],
|
||||
dst: int | None = None,
|
||||
all_gather_group: "GroupCoordinator | None" = None,
|
||||
all_gather_tensors: dict[str, bool] | None = None,
|
||||
) -> list[Handle]:
|
||||
if self.world_size <= 1:
|
||||
return []
|
||||
|
||||
if self.use_cpu_custom_send_recv:
|
||||
if self.device_communicator is None:
|
||||
raise ValueError("No device communicator found")
|
||||
# custom device communicator path is synchronous
|
||||
self.device_communicator.send_tensor_dict( # type: ignore
|
||||
tensor_dict, dst
|
||||
)
|
||||
return []
|
||||
|
||||
all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size
|
||||
all_gather_rank = (
|
||||
0 if all_gather_group is None else all_gather_group.rank_in_group
|
||||
@@ -820,53 +871,31 @@ class GroupCoordinator:
|
||||
dst = (self.rank_in_group + 1) % self.world_size
|
||||
assert dst < self.world_size, f"Invalid dst rank ({dst})"
|
||||
|
||||
if self.use_cpu_custom_send_recv:
|
||||
if self.device_communicator is None:
|
||||
raise ValueError("No device communicator found")
|
||||
self.device_communicator.send_tensor_dict( # type: ignore
|
||||
tensor_dict, dst
|
||||
)
|
||||
return None
|
||||
|
||||
metadata_list: list[tuple[Any, Any]] = []
|
||||
assert isinstance(tensor_dict, dict), (
|
||||
f"Expecting a dictionary, got {type(tensor_dict)}"
|
||||
)
|
||||
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
|
||||
# `metadata_list` lives in CPU memory.
|
||||
# `send_object_list` has serialization & deserialization,
|
||||
# all happening on CPU. Therefore, we can use the CPU group.
|
||||
self.send_object(metadata_list, dst=dst)
|
||||
|
||||
tensor_keys = [k for k, v in tensor_dict.items() if isinstance(v, torch.Tensor)]
|
||||
assert len(tensor_keys) == len(tensor_list)
|
||||
|
||||
handles: list[Handle] = []
|
||||
for key, tensor in zip(tensor_keys, tensor_list):
|
||||
if tensor.numel() == 0:
|
||||
# Skip sending empty tensors.
|
||||
continue
|
||||
|
||||
# send-allgather: send only a slice, then do allgather.
|
||||
use_all_gather = (
|
||||
all_gather_group is not None and tensor.numel() % all_gather_size == 0
|
||||
)
|
||||
use_all_gather = (
|
||||
all_gather_tensors.get(key, use_all_gather)
|
||||
if all_gather_tensors
|
||||
else use_all_gather
|
||||
)
|
||||
if use_all_gather:
|
||||
if self._should_use_all_gather(
|
||||
key, tensor.numel(), all_gather_group, all_gather_tensors
|
||||
):
|
||||
tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
|
||||
|
||||
if tensor.is_cpu:
|
||||
# use metadata_group for CPU tensors
|
||||
torch.distributed.send(
|
||||
tensor, dst=self.ranks[dst], group=metadata_group
|
||||
)
|
||||
else:
|
||||
# use group for GPU tensors
|
||||
torch.distributed.send(tensor, dst=self.ranks[dst], group=group)
|
||||
return None
|
||||
comm_group = metadata_group if tensor.is_cpu else group
|
||||
handle = torch.distributed.isend(
|
||||
tensor, dst=self.ranks[dst], group=comm_group
|
||||
)
|
||||
if tensor.is_cuda:
|
||||
tensor.record_stream(torch.cuda.current_stream(tensor.device))
|
||||
handles.append(handle)
|
||||
|
||||
return handles
|
||||
|
||||
def recv_tensor_dict(
|
||||
self,
|
||||
@@ -895,6 +924,38 @@ class GroupCoordinator:
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if not torch.distributed.is_initialized() or self.world_size == 1:
|
||||
return None
|
||||
tensor_dict, handles, postprocess = self.irecv_tensor_dict(
|
||||
src=src,
|
||||
all_gather_group=all_gather_group,
|
||||
all_gather_tensors=all_gather_tensors,
|
||||
)
|
||||
for handle in handles:
|
||||
handle.wait()
|
||||
for fn in postprocess:
|
||||
fn()
|
||||
return tensor_dict
|
||||
|
||||
def irecv_tensor_dict(
|
||||
self,
|
||||
src: int | None = None,
|
||||
all_gather_group: "GroupCoordinator | None" = None,
|
||||
all_gather_tensors: dict[str, bool] | None = None,
|
||||
) -> tuple[
|
||||
dict[str, torch.Tensor | Any] | None,
|
||||
list[Handle],
|
||||
list[Callable[[], None]],
|
||||
]:
|
||||
if not torch.distributed.is_initialized() or self.world_size == 1:
|
||||
return None, [], []
|
||||
if self.use_cpu_custom_send_recv:
|
||||
if self.device_communicator is None:
|
||||
raise ValueError("No device communicator found")
|
||||
# custom device communicator path is synchronous
|
||||
sync_tensor_dict = self.device_communicator.recv_tensor_dict( # type: ignore
|
||||
src
|
||||
)
|
||||
return sync_tensor_dict, [], []
|
||||
|
||||
all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size
|
||||
all_gather_rank = (
|
||||
0 if all_gather_group is None else all_gather_group.rank_in_group
|
||||
@@ -907,57 +968,57 @@ class GroupCoordinator:
|
||||
src = (self.rank_in_group - 1) % self.world_size
|
||||
assert src < self.world_size, f"Invalid src rank ({src})"
|
||||
|
||||
if self.use_cpu_custom_send_recv:
|
||||
if self.device_communicator is None:
|
||||
raise ValueError("No device communicator found")
|
||||
return self.device_communicator.recv_tensor_dict( # type: ignore
|
||||
src
|
||||
)
|
||||
|
||||
recv_metadata_list = self.recv_object(src=src)
|
||||
tensor_dict: dict[str, Any] = {}
|
||||
handles: list[Handle] = []
|
||||
postprocess: list[Callable[[], None]] = []
|
||||
|
||||
for key, value in recv_metadata_list:
|
||||
if isinstance(value, TensorMetadata):
|
||||
tensor = torch.empty(value.size, dtype=value.dtype, device=value.device)
|
||||
if tensor.numel() == 0:
|
||||
# Skip broadcasting empty tensors.
|
||||
tensor_dict[key] = tensor
|
||||
full_tensor = torch.empty(
|
||||
value.size, dtype=value.dtype, device=value.device
|
||||
)
|
||||
if full_tensor.numel() == 0:
|
||||
tensor_dict[key] = full_tensor
|
||||
continue
|
||||
|
||||
# send-allgather: send only a slice, then do allgather.
|
||||
use_all_gather = (
|
||||
all_gather_group is not None
|
||||
and tensor.numel() % all_gather_size == 0
|
||||
)
|
||||
use_all_gather = (
|
||||
all_gather_tensors.get(key, use_all_gather)
|
||||
if all_gather_tensors
|
||||
else use_all_gather
|
||||
)
|
||||
|
||||
if use_all_gather:
|
||||
orig_shape = tensor.shape
|
||||
tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
|
||||
|
||||
if tensor.is_cpu:
|
||||
# use metadata_group for CPU tensors
|
||||
torch.distributed.recv(
|
||||
tensor, src=self.ranks[src], group=metadata_group
|
||||
if self._should_use_all_gather(
|
||||
key, full_tensor.numel(), all_gather_group, all_gather_tensors
|
||||
):
|
||||
orig_shape = full_tensor.shape
|
||||
slice_tensor = full_tensor.reshape(all_gather_size, -1)[
|
||||
all_gather_rank
|
||||
]
|
||||
comm_group = metadata_group if slice_tensor.is_cpu else group
|
||||
handle = torch.distributed.irecv(
|
||||
slice_tensor, src=self.ranks[src], group=comm_group
|
||||
)
|
||||
handles.append(handle)
|
||||
|
||||
def _postprocess(
|
||||
key: str = key,
|
||||
slice_tensor: torch.Tensor = slice_tensor,
|
||||
orig_shape: tuple[int, ...] = tuple(orig_shape),
|
||||
all_gather_group=all_gather_group,
|
||||
) -> None:
|
||||
assert all_gather_group is not None
|
||||
tensor_dict[key] = all_gather_group.all_gather(
|
||||
slice_tensor, dim=0
|
||||
).reshape(orig_shape)
|
||||
|
||||
postprocess.append(_postprocess)
|
||||
tensor_dict[key] = slice_tensor
|
||||
else:
|
||||
# use group for GPU tensors
|
||||
torch.distributed.recv(tensor, src=self.ranks[src], group=group)
|
||||
if use_all_gather:
|
||||
# do the allgather
|
||||
tensor = all_gather_group.all_gather( # type: ignore
|
||||
tensor, dim=0
|
||||
comm_group = metadata_group if full_tensor.is_cpu else group
|
||||
handle = torch.distributed.irecv(
|
||||
full_tensor, src=self.ranks[src], group=comm_group
|
||||
)
|
||||
tensor = tensor.reshape(orig_shape)
|
||||
|
||||
tensor_dict[key] = tensor
|
||||
handles.append(handle)
|
||||
tensor_dict[key] = full_tensor
|
||||
else:
|
||||
tensor_dict[key] = value
|
||||
return tensor_dict
|
||||
|
||||
return tensor_dict, handles, postprocess
|
||||
|
||||
def barrier(self):
|
||||
"""Barrier synchronization among the group.
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
import gc
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from contextlib import AbstractContextManager, nullcontext
|
||||
from types import NoneType
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
@@ -30,6 +31,7 @@ from vllm.distributed.kv_transfer import (
|
||||
has_kv_transfer_group,
|
||||
)
|
||||
from vllm.distributed.parallel_state import (
|
||||
Handle,
|
||||
get_pcp_group,
|
||||
get_pp_group,
|
||||
get_tp_group,
|
||||
@@ -68,6 +70,38 @@ if TYPE_CHECKING:
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
|
||||
class AsyncIntermediateTensors(IntermediateTensors):
|
||||
"""IntermediateTensors with lazy comm synchronization"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tensors: dict[str, torch.Tensor],
|
||||
comm_handles: list[Handle] | None = None,
|
||||
comm_postprocess: list[Callable[[], None]] | None = None,
|
||||
) -> None:
|
||||
super().__init__(tensors)
|
||||
self._comm_handles = comm_handles
|
||||
self._comm_postprocess = comm_postprocess
|
||||
self._comm_waited = False
|
||||
|
||||
def wait_for_comm(self) -> None:
|
||||
if self._comm_waited:
|
||||
return
|
||||
if self._comm_handles:
|
||||
for handle in self._comm_handles:
|
||||
handle.wait()
|
||||
if self._comm_postprocess:
|
||||
for fn in self._comm_postprocess:
|
||||
fn()
|
||||
self._comm_waited = True
|
||||
|
||||
def __getattribute__(self, name: str):
|
||||
# ensure `.tensors` is ready before use
|
||||
if name == "tensors" and not object.__getattribute__(self, "_comm_waited"):
|
||||
object.__getattribute__(self, "wait_for_comm")()
|
||||
return object.__getattribute__(self, name)
|
||||
|
||||
|
||||
class Worker(WorkerBase):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -113,6 +147,8 @@ class Worker(WorkerBase):
|
||||
raise ValueError(f"Unknown profiler type: {self.profiler_config.profiler}")
|
||||
|
||||
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
|
||||
# pending non-blocking PP send work from the previous iteration
|
||||
self._pp_send_work: list[Handle] = []
|
||||
|
||||
def sleep(self, level: int = 1) -> None:
|
||||
from vllm.device_allocator.cumem import CuMemAllocator
|
||||
@@ -600,6 +636,12 @@ class Worker(WorkerBase):
|
||||
def execute_model(
|
||||
self, scheduler_output: "SchedulerOutput"
|
||||
) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
|
||||
# ensure any previous non-blocking PP sends are complete
|
||||
if self._pp_send_work:
|
||||
for handle in self._pp_send_work:
|
||||
handle.wait()
|
||||
self._pp_send_work = []
|
||||
|
||||
intermediate_tensors = None
|
||||
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
|
||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
@@ -637,12 +679,18 @@ class Worker(WorkerBase):
|
||||
}
|
||||
|
||||
if forward_pass and not get_pp_group().is_first_rank:
|
||||
tensor_dict = get_pp_group().recv_tensor_dict(
|
||||
all_gather_group=get_tp_group(),
|
||||
all_gather_tensors=all_gather_tensors,
|
||||
tensor_dict, comm_handles, comm_postprocess = (
|
||||
get_pp_group().irecv_tensor_dict(
|
||||
all_gather_group=get_tp_group(),
|
||||
all_gather_tensors=all_gather_tensors,
|
||||
)
|
||||
)
|
||||
assert tensor_dict is not None
|
||||
intermediate_tensors = IntermediateTensors(tensor_dict)
|
||||
intermediate_tensors = AsyncIntermediateTensors(
|
||||
tensor_dict,
|
||||
comm_handles=comm_handles,
|
||||
comm_postprocess=comm_postprocess,
|
||||
)
|
||||
|
||||
with self.annotate_profile(scheduler_output):
|
||||
output = self.model_runner.execute_model(
|
||||
@@ -660,7 +708,8 @@ class Worker(WorkerBase):
|
||||
and not get_pp_group().is_last_rank
|
||||
)
|
||||
|
||||
get_pp_group().send_tensor_dict(
|
||||
# launch non-blocking send of intermediate tensors
|
||||
self._pp_send_work = get_pp_group().isend_tensor_dict(
|
||||
output.tensors,
|
||||
all_gather_group=get_tp_group(),
|
||||
all_gather_tensors=all_gather_tensors,
|
||||
|
||||
Reference in New Issue
Block a user