diff --git a/vllm/distributed/device_communicators/cpu_communicator.py b/vllm/distributed/device_communicators/cpu_communicator.py
index 23be8fcfc..2bce5faa8 100644
--- a/vllm/distributed/device_communicators/cpu_communicator.py
+++ b/vllm/distributed/device_communicators/cpu_communicator.py
@@ -35,8 +35,15 @@ class CpuCommunicator(DeviceCommunicatorBase):
             )
             and hasattr(torch.ops._C, "init_shm_manager")
             and (unique_name.startswith("tp") or unique_name.startswith("pp"))
+            and self._all_group_ranks_share_shm_group_name()
         ):
             self.dist_module = _CPUSHMDistributed(self)
+        elif unique_name.startswith("tp") or unique_name.startswith("pp"):
+            logger.info(
+                "CPU SHM communicator disabled for group %s: ranks do not share "
+                "the same SHM group name, falling back to torch.distributed.",
+                unique_name,
+            )
 
         if self.use_all2all:
             if self.all2all_backend != "naive":  # type: ignore[has-type]
@@ -52,6 +59,20 @@ class CpuCommunicator(DeviceCommunicatorBase):
                 self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
                 logger.info("Using naive all2all manager.")
 
+    def _all_group_ranks_share_shm_group_name(self) -> bool:
+        """
+        CPUSHM requires all ranks in this group to agree on one SHM group name.
+        This is a lightweight consistency check for VLLM_DIST_IDENT/name inputs.
+        """
+        local_name = _CPUSHMDistributed.make_group_name(self)
+        names: list[str] = [""] * self.world_size
+        torch.distributed.all_gather_object(
+            names,
+            local_name,
+            group=self.device_group,
+        )
+        return len(set(names)) == 1
+
     def all_reduce(self, input_):
         self.dist_module.all_reduce(input_, group=self.device_group)
         return input_
@@ -193,16 +214,20 @@ class CpuCommunicator(DeviceCommunicatorBase):
 class _CPUSHMDistributed:
 
     def __init__(self, communicator: CpuCommunicator):
+        self.communicator = communicator
+
+        self.group_name = self.make_group_name(communicator)
+
+        self.handle = self._init_cpu_shm()
+
+    @staticmethod
+    def make_group_name(communicator: CpuCommunicator) -> str:
         instance_identifier = os.environ["VLLM_DIST_IDENT"]
         unique_name = communicator.unique_name
         instance_identifier = f"{instance_identifier}-{unique_name}"
-        self.communicator = communicator
-
-        group_ranks = [str(rank) for rank in self.communicator.ranks]
+        group_ranks = [str(rank) for rank in communicator.ranks]
         shm_group_identifier = f"[{'-'.join(group_ranks)}]"
-        self.group_name = f"{instance_identifier}-{shm_group_identifier}-cpushm"
-
-        self.handle = self._init_cpu_shm()
+        return f"{instance_identifier}-{shm_group_identifier}-cpushm"
 
     def _init_cpu_shm(self) -> int:
         thread_num_tensor = torch.tensor(