[Feature] Pipeline Parallel Async send/recv, 2.9% E2E throughput improvement (#33368)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Wentao Ye
2026-02-13 03:38:16 -05:00
committed by GitHub
parent dddbff4624
commit 3d2a026fd0
3 changed files with 298 additions and 81 deletions

View File

@@ -19,6 +19,8 @@ from vllm.distributed import (
tensor_model_parallel_all_reduce,
tensor_model_parallel_reduce_scatter,
)
from vllm.distributed.parallel_state import GroupCoordinator, TensorMetadata
from vllm.v1.worker.gpu_worker import AsyncIntermediateTensors
from ..utils import (
init_test_distributed_environment,
@@ -200,6 +202,111 @@ def send_recv_tensor_dict_test_worker(
torch.testing.assert_close(recv_dict["f"], test_dict["f"])
class _DummyWork:
def __init__(self) -> None:
self.wait_calls = 0
def wait(self) -> None:
self.wait_calls += 1
class _DummyAllGatherGroup:
def __init__(self, world_size: int, rank_in_group: int) -> None:
self.world_size = world_size
self.rank_in_group = rank_in_group
def all_gather(self, t: torch.Tensor, dim: int = 0) -> torch.Tensor:
# duplicate local slice across ranks.
assert dim == 0
return torch.cat([t for _ in range(self.world_size)], dim=0)
def _make_group_for_unit_test(
rank_in_group: int = 0, world_size: int = 2
) -> GroupCoordinator:
# avoid running GroupCoordinator.__init__ (it wires up real process groups).
g = GroupCoordinator.__new__(GroupCoordinator)
g.world_size = world_size
g.rank_in_group = rank_in_group
g.ranks = list(range(world_size))
g.use_cpu_custom_send_recv = False
g.device_group = None
g.cpu_group = None
return g
def test_irecv_tensor_dict_send_allgather_postprocess_binds_keys(
monkeypatch: pytest.MonkeyPatch,
) -> None:
def fake_irecv(t: torch.Tensor, *args: Any, **kwargs: Any) -> _DummyWork:
t.fill_(1)
return _DummyWork()
monkeypatch.setattr(torch.distributed, "is_initialized", lambda: True)
monkeypatch.setattr(torch.distributed, "irecv", fake_irecv)
g = _make_group_for_unit_test(rank_in_group=0, world_size=2)
# 2 tensors so we can catch late-binding bugs in postprocess closures.
metadata_list = [
("a", TensorMetadata("cpu", torch.int32, torch.Size([4]))),
("b", TensorMetadata("cpu", torch.int32, torch.Size([4]))),
]
g.recv_object = lambda src=None: metadata_list # type: ignore[method-assign]
ag = _DummyAllGatherGroup(world_size=2, rank_in_group=0)
td, handles, postprocess = g.irecv_tensor_dict(all_gather_group=ag)
assert td is not None
assert len(handles) == 2
assert len(postprocess) == 2
# before postprocess, dict holds the TP slice (shape 2).
assert td["a"].shape == torch.Size([2])
assert td["b"].shape == torch.Size([2])
# simulate worker-side "defer wait": wait + postprocess later.
for handle in handles:
handle.wait()
for fn in postprocess:
fn()
# after postprocess, dict values are reconstructed to full shape (shape 4),
# and each key should be updated independently
assert td["a"].shape == torch.Size([4])
assert td["b"].shape == torch.Size([4])
torch.testing.assert_close(td["a"], torch.ones(4, dtype=torch.int32))
torch.testing.assert_close(td["b"], torch.ones(4, dtype=torch.int32))
def test_async_intermediate_tensors_lazy_wait() -> None:
work = _DummyWork()
post_calls = {"n": 0}
def post() -> None:
post_calls["n"] += 1
it = AsyncIntermediateTensors(
{"x": torch.tensor([1])},
comm_handles=[work],
comm_postprocess=[post],
)
# accessing non-tensor attributes should not trigger wait.
assert it.kv_connector_output is None
assert work.wait_calls == 0
assert post_calls["n"] == 0
# first access of `.tensors` triggers wait + postprocess.
_ = it.tensors
assert work.wait_calls == 1
assert post_calls["n"] == 1
# subsequent access should not re-wait.
_ = it.tensors
assert work.wait_calls == 1
assert post_calls["n"] == 1
@ray.remote(num_gpus=1, max_calls=1)
def send_recv_test_worker(
monkeypatch: pytest.MonkeyPatch,

View File

@@ -33,7 +33,7 @@ from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from datetime import timedelta
from multiprocessing import shared_memory
from typing import Any
from typing import Any, Protocol
from unittest.mock import patch
import torch
@@ -64,6 +64,14 @@ class GraphCaptureContext:
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
class Handle(Protocol):
"""Minimal async work handle used by P2P send/recv methods."""
def is_completed(self) -> bool: ...
def wait(self) -> None: ...
def _split_tensor_dict(
tensor_dict: dict[str, torch.Tensor | Any],
) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]:
@@ -780,6 +788,20 @@ class GroupCoordinator:
async_handle.wait()
return tensor_dict
def _should_use_all_gather(
self,
key: str,
numel: int,
all_gather_group: "GroupCoordinator | None",
all_gather_tensors: dict[str, bool] | None,
) -> bool:
if all_gather_group is None:
return False
use_all_gather = numel % all_gather_group.world_size == 0
if all_gather_tensors is not None:
use_all_gather = all_gather_tensors.get(key, use_all_gather)
return use_all_gather
def send_tensor_dict(
self,
tensor_dict: dict[str, torch.Tensor | Any],
@@ -808,6 +830,35 @@ class GroupCoordinator:
# Bypass the function if we are using only 1 GPU.
if not torch.distributed.is_initialized() or self.world_size == 1:
return tensor_dict
handles = self.isend_tensor_dict(
tensor_dict,
dst=dst,
all_gather_group=all_gather_group,
all_gather_tensors=all_gather_tensors,
)
for handle in handles:
handle.wait()
return None
def isend_tensor_dict(
self,
tensor_dict: dict[str, torch.Tensor | Any],
dst: int | None = None,
all_gather_group: "GroupCoordinator | None" = None,
all_gather_tensors: dict[str, bool] | None = None,
) -> list[Handle]:
if self.world_size <= 1:
return []
if self.use_cpu_custom_send_recv:
if self.device_communicator is None:
raise ValueError("No device communicator found")
# custom device communicator path is synchronous
self.device_communicator.send_tensor_dict( # type: ignore
tensor_dict, dst
)
return []
all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size
all_gather_rank = (
0 if all_gather_group is None else all_gather_group.rank_in_group
@@ -820,53 +871,31 @@ class GroupCoordinator:
dst = (self.rank_in_group + 1) % self.world_size
assert dst < self.world_size, f"Invalid dst rank ({dst})"
if self.use_cpu_custom_send_recv:
if self.device_communicator is None:
raise ValueError("No device communicator found")
self.device_communicator.send_tensor_dict( # type: ignore
tensor_dict, dst
)
return None
metadata_list: list[tuple[Any, Any]] = []
assert isinstance(tensor_dict, dict), (
f"Expecting a dictionary, got {type(tensor_dict)}"
)
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
# `metadata_list` lives in CPU memory.
# `send_object_list` has serialization & deserialization,
# all happening on CPU. Therefore, we can use the CPU group.
self.send_object(metadata_list, dst=dst)
tensor_keys = [k for k, v in tensor_dict.items() if isinstance(v, torch.Tensor)]
assert len(tensor_keys) == len(tensor_list)
handles: list[Handle] = []
for key, tensor in zip(tensor_keys, tensor_list):
if tensor.numel() == 0:
# Skip sending empty tensors.
continue
# send-allgather: send only a slice, then do allgather.
use_all_gather = (
all_gather_group is not None and tensor.numel() % all_gather_size == 0
)
use_all_gather = (
all_gather_tensors.get(key, use_all_gather)
if all_gather_tensors
else use_all_gather
)
if use_all_gather:
if self._should_use_all_gather(
key, tensor.numel(), all_gather_group, all_gather_tensors
):
tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
if tensor.is_cpu:
# use metadata_group for CPU tensors
torch.distributed.send(
tensor, dst=self.ranks[dst], group=metadata_group
)
else:
# use group for GPU tensors
torch.distributed.send(tensor, dst=self.ranks[dst], group=group)
return None
comm_group = metadata_group if tensor.is_cpu else group
handle = torch.distributed.isend(
tensor, dst=self.ranks[dst], group=comm_group
)
if tensor.is_cuda:
tensor.record_stream(torch.cuda.current_stream(tensor.device))
handles.append(handle)
return handles
def recv_tensor_dict(
self,
@@ -895,6 +924,38 @@ class GroupCoordinator:
# Bypass the function if we are using only 1 GPU.
if not torch.distributed.is_initialized() or self.world_size == 1:
return None
tensor_dict, handles, postprocess = self.irecv_tensor_dict(
src=src,
all_gather_group=all_gather_group,
all_gather_tensors=all_gather_tensors,
)
for handle in handles:
handle.wait()
for fn in postprocess:
fn()
return tensor_dict
def irecv_tensor_dict(
self,
src: int | None = None,
all_gather_group: "GroupCoordinator | None" = None,
all_gather_tensors: dict[str, bool] | None = None,
) -> tuple[
dict[str, torch.Tensor | Any] | None,
list[Handle],
list[Callable[[], None]],
]:
if not torch.distributed.is_initialized() or self.world_size == 1:
return None, [], []
if self.use_cpu_custom_send_recv:
if self.device_communicator is None:
raise ValueError("No device communicator found")
# custom device communicator path is synchronous
sync_tensor_dict = self.device_communicator.recv_tensor_dict( # type: ignore
src
)
return sync_tensor_dict, [], []
all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size
all_gather_rank = (
0 if all_gather_group is None else all_gather_group.rank_in_group
@@ -907,57 +968,57 @@ class GroupCoordinator:
src = (self.rank_in_group - 1) % self.world_size
assert src < self.world_size, f"Invalid src rank ({src})"
if self.use_cpu_custom_send_recv:
if self.device_communicator is None:
raise ValueError("No device communicator found")
return self.device_communicator.recv_tensor_dict( # type: ignore
src
)
recv_metadata_list = self.recv_object(src=src)
tensor_dict: dict[str, Any] = {}
handles: list[Handle] = []
postprocess: list[Callable[[], None]] = []
for key, value in recv_metadata_list:
if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size, dtype=value.dtype, device=value.device)
if tensor.numel() == 0:
# Skip broadcasting empty tensors.
tensor_dict[key] = tensor
full_tensor = torch.empty(
value.size, dtype=value.dtype, device=value.device
)
if full_tensor.numel() == 0:
tensor_dict[key] = full_tensor
continue
# send-allgather: send only a slice, then do allgather.
use_all_gather = (
all_gather_group is not None
and tensor.numel() % all_gather_size == 0
)
use_all_gather = (
all_gather_tensors.get(key, use_all_gather)
if all_gather_tensors
else use_all_gather
)
if use_all_gather:
orig_shape = tensor.shape
tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
if tensor.is_cpu:
# use metadata_group for CPU tensors
torch.distributed.recv(
tensor, src=self.ranks[src], group=metadata_group
if self._should_use_all_gather(
key, full_tensor.numel(), all_gather_group, all_gather_tensors
):
orig_shape = full_tensor.shape
slice_tensor = full_tensor.reshape(all_gather_size, -1)[
all_gather_rank
]
comm_group = metadata_group if slice_tensor.is_cpu else group
handle = torch.distributed.irecv(
slice_tensor, src=self.ranks[src], group=comm_group
)
handles.append(handle)
def _postprocess(
key: str = key,
slice_tensor: torch.Tensor = slice_tensor,
orig_shape: tuple[int, ...] = tuple(orig_shape),
all_gather_group=all_gather_group,
) -> None:
assert all_gather_group is not None
tensor_dict[key] = all_gather_group.all_gather(
slice_tensor, dim=0
).reshape(orig_shape)
postprocess.append(_postprocess)
tensor_dict[key] = slice_tensor
else:
# use group for GPU tensors
torch.distributed.recv(tensor, src=self.ranks[src], group=group)
if use_all_gather:
# do the allgather
tensor = all_gather_group.all_gather( # type: ignore
tensor, dim=0
comm_group = metadata_group if full_tensor.is_cpu else group
handle = torch.distributed.irecv(
full_tensor, src=self.ranks[src], group=comm_group
)
tensor = tensor.reshape(orig_shape)
tensor_dict[key] = tensor
handles.append(handle)
tensor_dict[key] = full_tensor
else:
tensor_dict[key] = value
return tensor_dict
return tensor_dict, handles, postprocess
def barrier(self):
"""Barrier synchronization among the group.

View File

@@ -4,6 +4,7 @@
import gc
import os
from collections.abc import Callable
from contextlib import AbstractContextManager, nullcontext
from types import NoneType
from typing import TYPE_CHECKING, Any, cast
@@ -30,6 +31,7 @@ from vllm.distributed.kv_transfer import (
has_kv_transfer_group,
)
from vllm.distributed.parallel_state import (
Handle,
get_pcp_group,
get_pp_group,
get_tp_group,
@@ -68,6 +70,38 @@ if TYPE_CHECKING:
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
class AsyncIntermediateTensors(IntermediateTensors):
"""IntermediateTensors with lazy comm synchronization"""
def __init__(
self,
tensors: dict[str, torch.Tensor],
comm_handles: list[Handle] | None = None,
comm_postprocess: list[Callable[[], None]] | None = None,
) -> None:
super().__init__(tensors)
self._comm_handles = comm_handles
self._comm_postprocess = comm_postprocess
self._comm_waited = False
def wait_for_comm(self) -> None:
if self._comm_waited:
return
if self._comm_handles:
for handle in self._comm_handles:
handle.wait()
if self._comm_postprocess:
for fn in self._comm_postprocess:
fn()
self._comm_waited = True
def __getattribute__(self, name: str):
# ensure `.tensors` is ready before use
if name == "tensors" and not object.__getattribute__(self, "_comm_waited"):
object.__getattribute__(self, "wait_for_comm")()
return object.__getattribute__(self, name)
class Worker(WorkerBase):
def __init__(
self,
@@ -113,6 +147,8 @@ class Worker(WorkerBase):
raise ValueError(f"Unknown profiler type: {self.profiler_config.profiler}")
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
# pending non-blocking PP send work from the previous iteration
self._pp_send_work: list[Handle] = []
def sleep(self, level: int = 1) -> None:
from vllm.device_allocator.cumem import CuMemAllocator
@@ -600,6 +636,12 @@ class Worker(WorkerBase):
def execute_model(
self, scheduler_output: "SchedulerOutput"
) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
# ensure any previous non-blocking PP sends are complete
if self._pp_send_work:
for handle in self._pp_send_work:
handle.wait()
self._pp_send_work = []
intermediate_tensors = None
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@@ -637,12 +679,18 @@ class Worker(WorkerBase):
}
if forward_pass and not get_pp_group().is_first_rank:
tensor_dict = get_pp_group().recv_tensor_dict(
all_gather_group=get_tp_group(),
all_gather_tensors=all_gather_tensors,
tensor_dict, comm_handles, comm_postprocess = (
get_pp_group().irecv_tensor_dict(
all_gather_group=get_tp_group(),
all_gather_tensors=all_gather_tensors,
)
)
assert tensor_dict is not None
intermediate_tensors = IntermediateTensors(tensor_dict)
intermediate_tensors = AsyncIntermediateTensors(
tensor_dict,
comm_handles=comm_handles,
comm_postprocess=comm_postprocess,
)
with self.annotate_profile(scheduler_output):
output = self.model_runner.execute_model(
@@ -660,7 +708,8 @@ class Worker(WorkerBase):
and not get_pp_group().is_last_rank
)
get_pp_group().send_tensor_dict(
# launch non-blocking send of intermediate tensors
self._pp_send_work = get_pp_group().isend_tensor_dict(
output.tensors,
all_gather_group=get_tp_group(),
all_gather_tensors=all_gather_tensors,