[Feature] Pipeline Parallel Async send/recv, 2.9% E2E throughput improvement (#33368)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Wentao Ye
2026-02-13 03:38:16 -05:00
committed by GitHub
parent dddbff4624
commit 3d2a026fd0
3 changed files with 298 additions and 81 deletions

View File

@@ -19,6 +19,8 @@ from vllm.distributed import (
tensor_model_parallel_all_reduce,
tensor_model_parallel_reduce_scatter,
)
from vllm.distributed.parallel_state import GroupCoordinator, TensorMetadata
from vllm.v1.worker.gpu_worker import AsyncIntermediateTensors
from ..utils import (
init_test_distributed_environment,
@@ -200,6 +202,111 @@ def send_recv_tensor_dict_test_worker(
torch.testing.assert_close(recv_dict["f"], test_dict["f"])
class _DummyWork:
def __init__(self) -> None:
self.wait_calls = 0
def wait(self) -> None:
self.wait_calls += 1
class _DummyAllGatherGroup:
def __init__(self, world_size: int, rank_in_group: int) -> None:
self.world_size = world_size
self.rank_in_group = rank_in_group
def all_gather(self, t: torch.Tensor, dim: int = 0) -> torch.Tensor:
# duplicate local slice across ranks.
assert dim == 0
return torch.cat([t for _ in range(self.world_size)], dim=0)
def _make_group_for_unit_test(
rank_in_group: int = 0, world_size: int = 2
) -> GroupCoordinator:
# avoid running GroupCoordinator.__init__ (it wires up real process groups).
g = GroupCoordinator.__new__(GroupCoordinator)
g.world_size = world_size
g.rank_in_group = rank_in_group
g.ranks = list(range(world_size))
g.use_cpu_custom_send_recv = False
g.device_group = None
g.cpu_group = None
return g
def test_irecv_tensor_dict_send_allgather_postprocess_binds_keys(
monkeypatch: pytest.MonkeyPatch,
) -> None:
def fake_irecv(t: torch.Tensor, *args: Any, **kwargs: Any) -> _DummyWork:
t.fill_(1)
return _DummyWork()
monkeypatch.setattr(torch.distributed, "is_initialized", lambda: True)
monkeypatch.setattr(torch.distributed, "irecv", fake_irecv)
g = _make_group_for_unit_test(rank_in_group=0, world_size=2)
# 2 tensors so we can catch late-binding bugs in postprocess closures.
metadata_list = [
("a", TensorMetadata("cpu", torch.int32, torch.Size([4]))),
("b", TensorMetadata("cpu", torch.int32, torch.Size([4]))),
]
g.recv_object = lambda src=None: metadata_list # type: ignore[method-assign]
ag = _DummyAllGatherGroup(world_size=2, rank_in_group=0)
td, handles, postprocess = g.irecv_tensor_dict(all_gather_group=ag)
assert td is not None
assert len(handles) == 2
assert len(postprocess) == 2
# before postprocess, dict holds the TP slice (shape 2).
assert td["a"].shape == torch.Size([2])
assert td["b"].shape == torch.Size([2])
# simulate worker-side "defer wait": wait + postprocess later.
for handle in handles:
handle.wait()
for fn in postprocess:
fn()
# after postprocess, dict values are reconstructed to full shape (shape 4),
# and each key should be updated independently
assert td["a"].shape == torch.Size([4])
assert td["b"].shape == torch.Size([4])
torch.testing.assert_close(td["a"], torch.ones(4, dtype=torch.int32))
torch.testing.assert_close(td["b"], torch.ones(4, dtype=torch.int32))
def test_async_intermediate_tensors_lazy_wait() -> None:
work = _DummyWork()
post_calls = {"n": 0}
def post() -> None:
post_calls["n"] += 1
it = AsyncIntermediateTensors(
{"x": torch.tensor([1])},
comm_handles=[work],
comm_postprocess=[post],
)
# accessing non-tensor attributes should not trigger wait.
assert it.kv_connector_output is None
assert work.wait_calls == 0
assert post_calls["n"] == 0
# first access of `.tensors` triggers wait + postprocess.
_ = it.tensors
assert work.wait_calls == 1
assert post_calls["n"] == 1
# subsequent access should not re-wait.
_ = it.tensors
assert work.wait_calls == 1
assert post_calls["n"] == 1
@ray.remote(num_gpus=1, max_calls=1)
def send_recv_test_worker(
monkeypatch: pytest.MonkeyPatch,