[Feature] Pipeline Parallel Async send/recv, 2.9% E2E throughput improvement (#33368)
Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -19,6 +19,8 @@ from vllm.distributed import (
|
||||
tensor_model_parallel_all_reduce,
|
||||
tensor_model_parallel_reduce_scatter,
|
||||
)
|
||||
from vllm.distributed.parallel_state import GroupCoordinator, TensorMetadata
|
||||
from vllm.v1.worker.gpu_worker import AsyncIntermediateTensors
|
||||
|
||||
from ..utils import (
|
||||
init_test_distributed_environment,
|
||||
@@ -200,6 +202,111 @@ def send_recv_tensor_dict_test_worker(
|
||||
torch.testing.assert_close(recv_dict["f"], test_dict["f"])
|
||||
|
||||
|
||||
class _DummyWork:
|
||||
def __init__(self) -> None:
|
||||
self.wait_calls = 0
|
||||
|
||||
def wait(self) -> None:
|
||||
self.wait_calls += 1
|
||||
|
||||
|
||||
class _DummyAllGatherGroup:
|
||||
def __init__(self, world_size: int, rank_in_group: int) -> None:
|
||||
self.world_size = world_size
|
||||
self.rank_in_group = rank_in_group
|
||||
|
||||
def all_gather(self, t: torch.Tensor, dim: int = 0) -> torch.Tensor:
|
||||
# duplicate local slice across ranks.
|
||||
assert dim == 0
|
||||
return torch.cat([t for _ in range(self.world_size)], dim=0)
|
||||
|
||||
|
||||
def _make_group_for_unit_test(
|
||||
rank_in_group: int = 0, world_size: int = 2
|
||||
) -> GroupCoordinator:
|
||||
# avoid running GroupCoordinator.__init__ (it wires up real process groups).
|
||||
g = GroupCoordinator.__new__(GroupCoordinator)
|
||||
g.world_size = world_size
|
||||
g.rank_in_group = rank_in_group
|
||||
g.ranks = list(range(world_size))
|
||||
g.use_cpu_custom_send_recv = False
|
||||
g.device_group = None
|
||||
g.cpu_group = None
|
||||
return g
|
||||
|
||||
|
||||
def test_irecv_tensor_dict_send_allgather_postprocess_binds_keys(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
def fake_irecv(t: torch.Tensor, *args: Any, **kwargs: Any) -> _DummyWork:
|
||||
t.fill_(1)
|
||||
return _DummyWork()
|
||||
|
||||
monkeypatch.setattr(torch.distributed, "is_initialized", lambda: True)
|
||||
monkeypatch.setattr(torch.distributed, "irecv", fake_irecv)
|
||||
|
||||
g = _make_group_for_unit_test(rank_in_group=0, world_size=2)
|
||||
# 2 tensors so we can catch late-binding bugs in postprocess closures.
|
||||
metadata_list = [
|
||||
("a", TensorMetadata("cpu", torch.int32, torch.Size([4]))),
|
||||
("b", TensorMetadata("cpu", torch.int32, torch.Size([4]))),
|
||||
]
|
||||
g.recv_object = lambda src=None: metadata_list # type: ignore[method-assign]
|
||||
|
||||
ag = _DummyAllGatherGroup(world_size=2, rank_in_group=0)
|
||||
td, handles, postprocess = g.irecv_tensor_dict(all_gather_group=ag)
|
||||
|
||||
assert td is not None
|
||||
assert len(handles) == 2
|
||||
assert len(postprocess) == 2
|
||||
|
||||
# before postprocess, dict holds the TP slice (shape 2).
|
||||
assert td["a"].shape == torch.Size([2])
|
||||
assert td["b"].shape == torch.Size([2])
|
||||
|
||||
# simulate worker-side "defer wait": wait + postprocess later.
|
||||
for handle in handles:
|
||||
handle.wait()
|
||||
for fn in postprocess:
|
||||
fn()
|
||||
|
||||
# after postprocess, dict values are reconstructed to full shape (shape 4),
|
||||
# and each key should be updated independently
|
||||
assert td["a"].shape == torch.Size([4])
|
||||
assert td["b"].shape == torch.Size([4])
|
||||
torch.testing.assert_close(td["a"], torch.ones(4, dtype=torch.int32))
|
||||
torch.testing.assert_close(td["b"], torch.ones(4, dtype=torch.int32))
|
||||
|
||||
|
||||
def test_async_intermediate_tensors_lazy_wait() -> None:
|
||||
work = _DummyWork()
|
||||
post_calls = {"n": 0}
|
||||
|
||||
def post() -> None:
|
||||
post_calls["n"] += 1
|
||||
|
||||
it = AsyncIntermediateTensors(
|
||||
{"x": torch.tensor([1])},
|
||||
comm_handles=[work],
|
||||
comm_postprocess=[post],
|
||||
)
|
||||
|
||||
# accessing non-tensor attributes should not trigger wait.
|
||||
assert it.kv_connector_output is None
|
||||
assert work.wait_calls == 0
|
||||
assert post_calls["n"] == 0
|
||||
|
||||
# first access of `.tensors` triggers wait + postprocess.
|
||||
_ = it.tensors
|
||||
assert work.wait_calls == 1
|
||||
assert post_calls["n"] == 1
|
||||
|
||||
# subsequent access should not re-wait.
|
||||
_ = it.tensors
|
||||
assert work.wait_calls == 1
|
||||
assert post_calls["n"] == 1
|
||||
|
||||
|
||||
@ray.remote(num_gpus=1, max_calls=1)
|
||||
def send_recv_test_worker(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
|
||||
Reference in New Issue
Block a user