diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py
index ba80ee6fb..ce4c9c24e 100644
--- a/tests/distributed/test_comm_ops.py
+++ b/tests/distributed/test_comm_ops.py
@@ -19,6 +19,8 @@ from vllm.distributed import (
     tensor_model_parallel_all_reduce,
     tensor_model_parallel_reduce_scatter,
 )
+from vllm.distributed.parallel_state import GroupCoordinator, TensorMetadata
+from vllm.v1.worker.gpu_worker import AsyncIntermediateTensors
 
 from ..utils import (
     init_test_distributed_environment,
@@ -200,6 +202,111 @@ def send_recv_tensor_dict_test_worker(
     torch.testing.assert_close(recv_dict["f"], test_dict["f"])
 
 
+class _DummyWork:
+    def __init__(self) -> None:
+        self.wait_calls = 0
+
+    def wait(self) -> None:
+        self.wait_calls += 1
+
+
+class _DummyAllGatherGroup:
+    def __init__(self, world_size: int, rank_in_group: int) -> None:
+        self.world_size = world_size
+        self.rank_in_group = rank_in_group
+
+    def all_gather(self, t: torch.Tensor, dim: int = 0) -> torch.Tensor:
+        # duplicate local slice across ranks.
+        assert dim == 0
+        return torch.cat([t for _ in range(self.world_size)], dim=0)
+
+
+def _make_group_for_unit_test(
+    rank_in_group: int = 0, world_size: int = 2
+) -> GroupCoordinator:
+    # avoid running GroupCoordinator.__init__ (it wires up real process groups).
+    g = GroupCoordinator.__new__(GroupCoordinator)
+    g.world_size = world_size
+    g.rank_in_group = rank_in_group
+    g.ranks = list(range(world_size))
+    g.use_cpu_custom_send_recv = False
+    g.device_group = None
+    g.cpu_group = None
+    return g
+
+
+def test_irecv_tensor_dict_send_allgather_postprocess_binds_keys(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    def fake_irecv(t: torch.Tensor, *args: Any, **kwargs: Any) -> _DummyWork:
+        t.fill_(1)
+        return _DummyWork()
+
+    monkeypatch.setattr(torch.distributed, "is_initialized", lambda: True)
+    monkeypatch.setattr(torch.distributed, "irecv", fake_irecv)
+
+    g = _make_group_for_unit_test(rank_in_group=0, world_size=2)
+    # 2 tensors so we can catch late-binding bugs in postprocess closures.
+    metadata_list = [
+        ("a", TensorMetadata("cpu", torch.int32, torch.Size([4]))),
+        ("b", TensorMetadata("cpu", torch.int32, torch.Size([4]))),
+    ]
+    g.recv_object = lambda src=None: metadata_list  # type: ignore[method-assign]
+
+    ag = _DummyAllGatherGroup(world_size=2, rank_in_group=0)
+    td, handles, postprocess = g.irecv_tensor_dict(all_gather_group=ag)
+
+    assert td is not None
+    assert len(handles) == 2
+    assert len(postprocess) == 2
+
+    # before postprocess, dict holds the TP slice (shape 2).
+    assert td["a"].shape == torch.Size([2])
+    assert td["b"].shape == torch.Size([2])
+
+    # simulate worker-side "defer wait": wait + postprocess later.
+    for handle in handles:
+        handle.wait()
+    for fn in postprocess:
+        fn()
+
+    # after postprocess, dict values are reconstructed to full shape (shape 4),
+    # and each key should be updated independently
+    assert td["a"].shape == torch.Size([4])
+    assert td["b"].shape == torch.Size([4])
+    torch.testing.assert_close(td["a"], torch.ones(4, dtype=torch.int32))
+    torch.testing.assert_close(td["b"], torch.ones(4, dtype=torch.int32))
+
+
+def test_async_intermediate_tensors_lazy_wait() -> None:
+    work = _DummyWork()
+    post_calls = {"n": 0}
+
+    def post() -> None:
+        post_calls["n"] += 1
+
+    it = AsyncIntermediateTensors(
+        {"x": torch.tensor([1])},
+        comm_handles=[work],
+        comm_postprocess=[post],
+    )
+
+    # accessing non-tensor attributes should not trigger wait.
+    assert it.kv_connector_output is None
+    assert work.wait_calls == 0
+    assert post_calls["n"] == 0
+
+    # first access of `.tensors` triggers wait + postprocess.
+    _ = it.tensors
+    assert work.wait_calls == 1
+    assert post_calls["n"] == 1
+
+    # subsequent access should not re-wait.
+    _ = it.tensors
+    assert work.wait_calls == 1
+    assert post_calls["n"] == 1
+
+
 @ray.remote(num_gpus=1, max_calls=1)
 def send_recv_test_worker(
     monkeypatch: pytest.MonkeyPatch,
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index b8b2607ff..9994096bf 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -33,7 +33,7 @@ from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
 from datetime import timedelta
 from multiprocessing import shared_memory
-from typing import Any
+from typing import Any, Protocol
 from unittest.mock import patch
 
 import torch
@@ -64,6 +64,14 @@ class GraphCaptureContext:
 TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
 
 
+class Handle(Protocol):
+    """Minimal async work handle used by P2P send/recv methods."""
+
+    def is_completed(self) -> bool: ...
+
+    def wait(self) -> None: ...
+
+
 def _split_tensor_dict(
     tensor_dict: dict[str, torch.Tensor | Any],
 ) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]:
@@ -780,6 +788,20 @@ class GroupCoordinator:
             async_handle.wait()
         return tensor_dict
 
+    def _should_use_all_gather(
+        self,
+        key: str,
+        numel: int,
+        all_gather_group: "GroupCoordinator | None",
+        all_gather_tensors: dict[str, bool] | None,
+    ) -> bool:
+        if all_gather_group is None:
+            return False
+        use_all_gather = numel % all_gather_group.world_size == 0
+        if all_gather_tensors is not None:
+            use_all_gather = all_gather_tensors.get(key, use_all_gather)
+        return use_all_gather
+
     def send_tensor_dict(
         self,
         tensor_dict: dict[str, torch.Tensor | Any],
@@ -808,6 +830,47 @@ class GroupCoordinator:
         # Bypass the function if we are using only 1 GPU.
         if not torch.distributed.is_initialized() or self.world_size == 1:
             return tensor_dict
+        handles = self.isend_tensor_dict(
+            tensor_dict,
+            dst=dst,
+            all_gather_group=all_gather_group,
+            all_gather_tensors=all_gather_tensors,
+        )
+        for handle in handles:
+            handle.wait()
+        return None
+
+    def isend_tensor_dict(
+        self,
+        tensor_dict: dict[str, torch.Tensor | Any],
+        dst: int | None = None,
+        all_gather_group: "GroupCoordinator | None" = None,
+        all_gather_tensors: dict[str, bool] | None = None,
+    ) -> list[Handle]:
+        """Non-blocking variant of `send_tensor_dict`.
+
+        Returns the pending send handles; the caller must `wait()` on each
+        of them. The list is empty when nothing was sent asynchronously.
+        """
+        if self.world_size <= 1:
+            return []
+
+        # Resolve the default destination before the custom-communicator
+        # branch: that branch used to run after this resolution, so it must
+        # keep receiving a concrete rank rather than `None`.
+        if dst is None:
+            dst = (self.rank_in_group + 1) % self.world_size
+        assert dst < self.world_size, f"Invalid dst rank ({dst})"
+
+        if self.use_cpu_custom_send_recv:
+            if self.device_communicator is None:
+                raise ValueError("No device communicator found")
+            # custom device communicator path is synchronous
+            self.device_communicator.send_tensor_dict(  # type: ignore
+                tensor_dict, dst
+            )
+            return []
+
         all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size
         all_gather_rank = (
             0 if all_gather_group is None else all_gather_group.rank_in_group
@@ -820,53 +883,31 @@ class GroupCoordinator:
             dst = (self.rank_in_group + 1) % self.world_size
         assert dst < self.world_size, f"Invalid dst rank ({dst})"
 
-        if self.use_cpu_custom_send_recv:
-            if self.device_communicator is None:
-                raise ValueError("No device communicator found")
-            self.device_communicator.send_tensor_dict(  # type: ignore
-                tensor_dict, dst
-            )
-            return None
-
-        metadata_list: list[tuple[Any, Any]] = []
-        assert isinstance(tensor_dict, dict), (
-            f"Expecting a dictionary, got {type(tensor_dict)}"
-        )
         metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
-        # `metadata_list` lives in CPU memory.
-        # `send_object_list` has serialization & deserialization,
-        # all happening on CPU. Therefore, we can use the CPU group.
         self.send_object(metadata_list, dst=dst)
 
         tensor_keys = [k for k, v in tensor_dict.items() if isinstance(v, torch.Tensor)]
         assert len(tensor_keys) == len(tensor_list)
 
+        handles: list[Handle] = []
         for key, tensor in zip(tensor_keys, tensor_list):
             if tensor.numel() == 0:
-                # Skip sending empty tensors.
                 continue
 
-            # send-allgather: send only a slice, then do allgather.
-            use_all_gather = (
-                all_gather_group is not None and tensor.numel() % all_gather_size == 0
-            )
-            use_all_gather = (
-                all_gather_tensors.get(key, use_all_gather)
-                if all_gather_tensors
-                else use_all_gather
-            )
-            if use_all_gather:
+            if self._should_use_all_gather(
+                key, tensor.numel(), all_gather_group, all_gather_tensors
+            ):
                 tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
 
-            if tensor.is_cpu:
-                # use metadata_group for CPU tensors
-                torch.distributed.send(
-                    tensor, dst=self.ranks[dst], group=metadata_group
-                )
-            else:
-                # use group for GPU tensors
-                torch.distributed.send(tensor, dst=self.ranks[dst], group=group)
-        return None
+            comm_group = metadata_group if tensor.is_cpu else group
+            handle = torch.distributed.isend(
+                tensor, dst=self.ranks[dst], group=comm_group
+            )
+            if tensor.is_cuda:
+                tensor.record_stream(torch.cuda.current_stream(tensor.device))
+            handles.append(handle)
+
+        return handles
 
     def recv_tensor_dict(
         self,
@@ -895,6 +936,52 @@ class GroupCoordinator:
         # Bypass the function if we are using only 1 GPU.
         if not torch.distributed.is_initialized() or self.world_size == 1:
             return None
+        tensor_dict, handles, postprocess = self.irecv_tensor_dict(
+            src=src,
+            all_gather_group=all_gather_group,
+            all_gather_tensors=all_gather_tensors,
+        )
+        for handle in handles:
+            handle.wait()
+        for fn in postprocess:
+            fn()
+        return tensor_dict
+
+    def irecv_tensor_dict(
+        self,
+        src: int | None = None,
+        all_gather_group: "GroupCoordinator | None" = None,
+        all_gather_tensors: dict[str, bool] | None = None,
+    ) -> tuple[
+        dict[str, torch.Tensor | Any] | None,
+        list[Handle],
+        list[Callable[[], None]],
+    ]:
+        """Non-blocking variant of `recv_tensor_dict`.
+
+        Returns `(tensor_dict, handles, postprocess)`. Tensor values in
+        `tensor_dict` are only valid once every handle has been waited on
+        and every `postprocess` callback has then been run.
+        """
+        if not torch.distributed.is_initialized() or self.world_size == 1:
+            return None, [], []
+
+        # Resolve the default source before the custom-communicator branch:
+        # that branch used to run after this resolution, so it must keep
+        # receiving a concrete rank rather than `None`.
+        if src is None:
+            src = (self.rank_in_group - 1) % self.world_size
+        assert src < self.world_size, f"Invalid src rank ({src})"
+
+        if self.use_cpu_custom_send_recv:
+            if self.device_communicator is None:
+                raise ValueError("No device communicator found")
+            # custom device communicator path is synchronous
+            sync_tensor_dict = self.device_communicator.recv_tensor_dict(  # type: ignore
+                src
+            )
+            return sync_tensor_dict, [], []
+
         all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size
         all_gather_rank = (
             0 if all_gather_group is None else all_gather_group.rank_in_group
@@ -907,57 +994,57 @@ class GroupCoordinator:
             src = (self.rank_in_group - 1) % self.world_size
         assert src < self.world_size, f"Invalid src rank ({src})"
 
-        if self.use_cpu_custom_send_recv:
-            if self.device_communicator is None:
-                raise ValueError("No device communicator found")
-            return self.device_communicator.recv_tensor_dict(  # type: ignore
-                src
-            )
-
         recv_metadata_list = self.recv_object(src=src)
         tensor_dict: dict[str, Any] = {}
+        handles: list[Handle] = []
+        postprocess: list[Callable[[], None]] = []
+
         for key, value in recv_metadata_list:
             if isinstance(value, TensorMetadata):
-                tensor = torch.empty(value.size, dtype=value.dtype, device=value.device)
-                if tensor.numel() == 0:
-                    # Skip broadcasting empty tensors.
-                    tensor_dict[key] = tensor
+                full_tensor = torch.empty(
+                    value.size, dtype=value.dtype, device=value.device
+                )
+                if full_tensor.numel() == 0:
+                    tensor_dict[key] = full_tensor
                     continue
 
-                # send-allgather: send only a slice, then do allgather.
-                use_all_gather = (
-                    all_gather_group is not None
-                    and tensor.numel() % all_gather_size == 0
-                )
-                use_all_gather = (
-                    all_gather_tensors.get(key, use_all_gather)
-                    if all_gather_tensors
-                    else use_all_gather
-                )
-
-                if use_all_gather:
-                    orig_shape = tensor.shape
-                    tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
-
-                if tensor.is_cpu:
-                    # use metadata_group for CPU tensors
-                    torch.distributed.recv(
-                        tensor, src=self.ranks[src], group=metadata_group
+                if self._should_use_all_gather(
+                    key, full_tensor.numel(), all_gather_group, all_gather_tensors
+                ):
+                    orig_shape = full_tensor.shape
+                    slice_tensor = full_tensor.reshape(all_gather_size, -1)[
+                        all_gather_rank
+                    ]
+                    comm_group = metadata_group if slice_tensor.is_cpu else group
+                    handle = torch.distributed.irecv(
+                        slice_tensor, src=self.ranks[src], group=comm_group
                     )
+                    handles.append(handle)
+
+                    def _postprocess(
+                        key: str = key,
+                        slice_tensor: torch.Tensor = slice_tensor,
+                        orig_shape: tuple[int, ...] = tuple(orig_shape),
+                        all_gather_group=all_gather_group,
+                    ) -> None:
+                        assert all_gather_group is not None
+                        tensor_dict[key] = all_gather_group.all_gather(
+                            slice_tensor, dim=0
+                        ).reshape(orig_shape)
+
+                    postprocess.append(_postprocess)
+                    tensor_dict[key] = slice_tensor
                 else:
-                    # use group for GPU tensors
-                    torch.distributed.recv(tensor, src=self.ranks[src], group=group)
-                if use_all_gather:
-                    # do the allgather
-                    tensor = all_gather_group.all_gather(  # type: ignore
-                        tensor, dim=0
+                    comm_group = metadata_group if full_tensor.is_cpu else group
+                    handle = torch.distributed.irecv(
+                        full_tensor, src=self.ranks[src], group=comm_group
                     )
-                    tensor = tensor.reshape(orig_shape)
-
-                tensor_dict[key] = tensor
+                    handles.append(handle)
+                    tensor_dict[key] = full_tensor
             else:
                 tensor_dict[key] = value
-        return tensor_dict
+
+        return tensor_dict, handles, postprocess
 
     def barrier(self):
         """Barrier synchronization among the group.
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 2507b7f20..e35d0ef68 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -4,6 +4,7 @@
 
 import gc
 import os
+from collections.abc import Callable
 from contextlib import AbstractContextManager, nullcontext
 from types import NoneType
 from typing import TYPE_CHECKING, Any, cast
@@ -30,6 +31,7 @@ from vllm.distributed.kv_transfer import (
     has_kv_transfer_group,
 )
 from vllm.distributed.parallel_state import (
+    Handle,
     get_pcp_group,
     get_pp_group,
     get_tp_group,
@@ -68,6 +70,38 @@ if TYPE_CHECKING:
     from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
 
+class AsyncIntermediateTensors(IntermediateTensors):
+    """IntermediateTensors with lazy comm synchronization"""
+
+    def __init__(
+        self,
+        tensors: dict[str, torch.Tensor],
+        comm_handles: list[Handle] | None = None,
+        comm_postprocess: list[Callable[[], None]] | None = None,
+    ) -> None:
+        super().__init__(tensors)
+        self._comm_handles = comm_handles
+        self._comm_postprocess = comm_postprocess
+        self._comm_waited = False
+
+    def wait_for_comm(self) -> None:
+        if self._comm_waited:
+            return
+        if self._comm_handles:
+            for handle in self._comm_handles:
+                handle.wait()
+        if self._comm_postprocess:
+            for fn in self._comm_postprocess:
+                fn()
+        self._comm_waited = True
+
+    def __getattribute__(self, name: str):
+        # ensure `.tensors` is ready before use
+        if name == "tensors" and not object.__getattribute__(self, "_comm_waited"):
+            object.__getattribute__(self, "wait_for_comm")()
+        return object.__getattribute__(self, name)
+
+
 class Worker(WorkerBase):
     def __init__(
         self,
@@ -113,6 +147,8 @@ class Worker(WorkerBase):
             raise ValueError(f"Unknown profiler type: {self.profiler_config.profiler}")
 
         self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
+        # pending non-blocking PP send work from the previous iteration
+        self._pp_send_work: list[Handle] = []
 
     def sleep(self, level: int = 1) -> None:
         from vllm.device_allocator.cumem import CuMemAllocator
@@ -600,6 +636,12 @@ class Worker(WorkerBase):
     def execute_model(
         self, scheduler_output: "SchedulerOutput"
     ) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
+        # ensure any previous non-blocking PP sends are complete
+        if self._pp_send_work:
+            for handle in self._pp_send_work:
+                handle.wait()
+            self._pp_send_work = []
+
         intermediate_tensors = None
         forward_pass = scheduler_output.total_num_scheduled_tokens > 0
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@@ -637,12 +679,18 @@ class Worker(WorkerBase):
             }
 
         if forward_pass and not get_pp_group().is_first_rank:
-            tensor_dict = get_pp_group().recv_tensor_dict(
-                all_gather_group=get_tp_group(),
-                all_gather_tensors=all_gather_tensors,
+            tensor_dict, comm_handles, comm_postprocess = (
+                get_pp_group().irecv_tensor_dict(
+                    all_gather_group=get_tp_group(),
+                    all_gather_tensors=all_gather_tensors,
+                )
             )
             assert tensor_dict is not None
-            intermediate_tensors = IntermediateTensors(tensor_dict)
+            intermediate_tensors = AsyncIntermediateTensors(
+                tensor_dict,
+                comm_handles=comm_handles,
+                comm_postprocess=comm_postprocess,
+            )
 
         with self.annotate_profile(scheduler_output):
             output = self.model_runner.execute_model(
@@ -660,7 +708,8 @@ class Worker(WorkerBase):
             and not get_pp_group().is_last_rank
         )
 
-        get_pp_group().send_tensor_dict(
+        # launch non-blocking send of intermediate tensors
+        self._pp_send_work = get_pp_group().isend_tensor_dict(
             output.tensors,
             all_gather_group=get_tp_group(),
             all_gather_tensors=all_gather_tensors,