diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py
index 8bc361741..b09d5f44d 100644
--- a/vllm/distributed/device_communicators/base_device_communicator.py
+++ b/vllm/distributed/device_communicators/base_device_communicator.py
@@ -286,7 +286,10 @@ class DeviceCommunicatorBase:
         router_logits: torch.Tensor,
         is_sequence_parallel: bool = False,
         extra_tensors: list[torch.Tensor] | None = None,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
+    ) -> (
+        tuple[torch.Tensor, torch.Tensor]
+        | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
+    ):
         """
         Dispatch the hidden states and router logits to the appropriate device.
         This is a no-op in the base class.
diff --git a/vllm/distributed/device_communicators/cpu_communicator.py b/vllm/distributed/device_communicators/cpu_communicator.py
index fdfb74d7a..2a84418c1 100644
--- a/vllm/distributed/device_communicators/cpu_communicator.py
+++ b/vllm/distributed/device_communicators/cpu_communicator.py
@@ -8,11 +8,14 @@ import torch
 from torch.distributed import ProcessGroup
 
 from vllm.distributed.utils import pickle
+from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
 
 from .base_device_communicator import DeviceCommunicatorBase
 
+logger = init_logger(__name__)
+
 
 class CpuCommunicator(DeviceCommunicatorBase):
     def __init__(
@@ -32,6 +35,20 @@ class CpuCommunicator(DeviceCommunicatorBase):
         ):
             self.dist_module = _CPUSHMDistributed(self)
 
+        if self.use_all2all:
+            if self.all2all_backend != "naive":  # type: ignore[has-type]
+                logger.warning(
+                    "`%s` all2all manager is not supported on CPU. "
+                    "Falling back to `naive` all2all manager for CPU.",
+                    self.all2all_backend,  # type: ignore[has-type]
+                )
+                self.all2all_backend = "naive"
+            if self.all2all_backend == "naive":
+                from .all2all import NaiveAll2AllManager
+
+                self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
+                logger.info("Using naive all2all manager.")
+
     def all_reduce(self, input_):
         self.dist_module.all_reduce(input_, group=self.device_group)
         return input_
@@ -110,6 +127,30 @@ class CpuCommunicator(DeviceCommunicatorBase):
     ) -> dict[str, torch.Tensor | Any]:
        return self.dist_module.recv_tensor_dict(src)
 
+    def dispatch(  # type: ignore[override]
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        is_sequence_parallel: bool = False,
+        extra_tensors: list[torch.Tensor] | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        assert self.all2all_manager is not None
+        return self.all2all_manager.dispatch(
+            hidden_states,
+            router_logits,
+            is_sequence_parallel,
+            extra_tensors,  # type: ignore[call-arg]
+        )
+
+    def combine(
+        self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
+    ) -> torch.Tensor:
+        assert self.all2all_manager is not None
+        hidden_states = self.all2all_manager.combine(
+            hidden_states, is_sequence_parallel
+        )
+        return hidden_states
+
 
 class _CPUSHMDistributed:
     def __init__(self, communicator: CpuCommunicator):
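For context, here is a minimal, standalone sketch of the dispatch/combine round trip this patch wires up for CPU. This is not vllm code: `naive_dispatch` and `naive_combine` are hypothetical stand-ins for what a naive all2all manager is assumed to do (gather every rank's tokens on dispatch, reduce partial expert outputs and re-slice on combine). It uses a single-rank `gloo` group so it runs as-is, where the round trip should be an identity.

```python
# Hypothetical sketch of the naive dispatch/combine contract on CPU.
# Not vllm's implementation; function names are illustrative only.
import os

import torch
import torch.distributed as dist


def naive_dispatch(hidden_states: torch.Tensor, router_logits: torch.Tensor):
    # Gather all ranks' tokens so every rank sees the full batch.
    world = dist.get_world_size()
    gathered_h = [torch.empty_like(hidden_states) for _ in range(world)]
    gathered_l = [torch.empty_like(router_logits) for _ in range(world)]
    dist.all_gather(gathered_h, hidden_states)
    dist.all_gather(gathered_l, router_logits)
    return torch.cat(gathered_h), torch.cat(gathered_l)


def naive_combine(hidden_states: torch.Tensor) -> torch.Tensor:
    # Sum partial expert outputs, then keep only this rank's token slice.
    dist.all_reduce(hidden_states)
    world = dist.get_world_size()
    rank = dist.get_rank()
    chunk = hidden_states.shape[0] // world
    return hidden_states[rank * chunk : (rank + 1) * chunk]


if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)
    h = torch.randn(4, 8)
    logits = torch.randn(4, 2)
    h2, l2 = naive_dispatch(h, logits)
    out = naive_combine(h2.clone())
    assert torch.allclose(out, h)  # round trip is identity at world_size == 1
    dist.destroy_process_group()
```

With more than one rank, `dispatch` replicates tokens across the group and `combine` reduces the processed tokens back to their owning rank, which matches the contract the new `CpuCommunicator.dispatch`/`combine` overrides delegate to the naive all2all manager.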