[CPU][Distributed] Fix: enable _CPUSHMDistributed only when TP/PP ranks share the same SHM group name (#34169)
Signed-off-by: Charles Ashby <charlesa.l@hotmail.com>
This commit is contained in:
@@ -35,8 +35,15 @@ class CpuCommunicator(DeviceCommunicatorBase):
|
||||
)
|
||||
and hasattr(torch.ops._C, "init_shm_manager")
|
||||
and (unique_name.startswith("tp") or unique_name.startswith("pp"))
|
||||
and self._all_group_ranks_share_shm_group_name()
|
||||
):
|
||||
self.dist_module = _CPUSHMDistributed(self)
|
||||
elif unique_name.startswith("tp") or unique_name.startswith("pp"):
|
||||
logger.info(
|
||||
"CPU SHM communicator disabled for group %s: ranks do not share "
|
||||
"the same SHM group name, falling back to torch.distributed.",
|
||||
unique_name,
|
||||
)
|
||||
|
||||
if self.use_all2all:
|
||||
if self.all2all_backend != "naive": # type: ignore[has-type]
|
||||
@@ -52,6 +59,20 @@ class CpuCommunicator(DeviceCommunicatorBase):
|
||||
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
|
||||
logger.info("Using naive all2all manager.")
|
||||
|
||||
def _all_group_ranks_share_shm_group_name(self) -> bool:
    """Return True when every rank in this group derives the same SHM group name.

    CPUSHM requires all ranks in this group to agree on one SHM group name.
    This is a lightweight consistency check for the VLLM_DIST_IDENT/name
    inputs that feed ``_CPUSHMDistributed.make_group_name``.

    The check gathers each rank's locally computed name with
    ``torch.distributed.all_gather_object`` on ``self.device_group``. That
    is a collective operation, so every rank of the group has to reach
    this call.

    Returns:
        True if all gathered names are identical, False otherwise.
    """
    local_name = _CPUSHMDistributed.make_group_name(self)
    # One slot per rank; all_gather_object overwrites every entry.
    names: list[str] = [""] * self.world_size
    torch.distributed.all_gather_object(
        names,
        local_name,
        group=self.device_group,
    )
    # A single distinct value means all ranks agree.
    return len(set(names)) == 1
|
||||
|
||||
def all_reduce(self, input_):
    """All-reduce ``input_`` over ``self.device_group`` and return it.

    The work is delegated to ``self.dist_module.all_reduce``. The backend's
    return value is discarded and the same ``input_`` object is handed
    back, so the backend is presumably expected to reduce in place
    (NOTE(review): confirm for every ``dist_module`` implementation).

    Args:
        input_: The object to reduce; forwarded unchanged to the backend.

    Returns:
        The same ``input_`` object that was passed in.
    """
    self.dist_module.all_reduce(input_, group=self.device_group)
    return input_
|
||||
@@ -193,16 +214,20 @@ class CpuCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
class _CPUSHMDistributed:
|
||||
def __init__(self, communicator: "CpuCommunicator"):
    """Bind to ``communicator`` and set up the shared-memory backend.

    Args:
        communicator: The owning CPU communicator; its ``unique_name`` and
            ``ranks`` are used to derive the SHM group name.
    """
    self.communicator = communicator

    # The name is computed by a static helper so that callers can derive
    # it without constructing this object.
    self.group_name = self.make_group_name(communicator)

    # Must run after group_name is set; stores whatever handle
    # _init_cpu_shm() returns.
    self.handle = self._init_cpu_shm()
|
||||
|
||||
@staticmethod
|
||||
def make_group_name(communicator: CpuCommunicator) -> str:
|
||||
instance_identifier = os.environ["VLLM_DIST_IDENT"]
|
||||
unique_name = communicator.unique_name
|
||||
instance_identifier = f"{instance_identifier}-{unique_name}"
|
||||
self.communicator = communicator
|
||||
|
||||
group_ranks = [str(rank) for rank in self.communicator.ranks]
|
||||
group_ranks = [str(rank) for rank in communicator.ranks]
|
||||
shm_group_identifier = f"[{'-'.join(group_ranks)}]"
|
||||
self.group_name = f"{instance_identifier}-{shm_group_identifier}-cpushm"
|
||||
|
||||
self.handle = self._init_cpu_shm()
|
||||
return f"{instance_identifier}-{shm_group_identifier}-cpushm"
|
||||
|
||||
def _init_cpu_shm(self) -> int:
|
||||
thread_num_tensor = torch.tensor(
|
||||
|
||||
Reference in New Issue
Block a user