[Bugfix][CPU] Fix thread num for shared memory communication (#33317)

Signed-off-by: jiang1.li <jiang1.li@intel.com>
Signed-off-by: Li, Jiang <bigpyj64@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Li, Jiang
2026-01-29 19:26:58 +08:00
committed by GitHub
parent 40c35038d2
commit 8311f083bd
3 changed files with 25 additions and 10 deletions

View File

@@ -205,10 +205,22 @@ class _CPUSHMDistributed:
self.handle = self._init_cpu_shm()
def _init_cpu_shm(self) -> int:
thread_num_tensor = torch.tensor(
[torch.get_num_threads()],
dtype=torch.int64,
)
torch.distributed.all_reduce(
thread_num_tensor,
op=torch.distributed.ReduceOp.MIN,
group=self.communicator.device_group,
)
thread_num = thread_num_tensor.item()
handle = torch.ops._C.init_shm_manager(
self.group_name,
self.communicator.world_size,
self.communicator.rank,
thread_num,
)
torch.distributed.barrier(self.communicator.device_group)
torch.ops._C.join_shm_manager(