[Bugfix][CPU] Fix thread num for shared memory communication (#33317)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
Signed-off-by: Li, Jiang <bigpyj64@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in the branches/tags listed on the repository page.
@@ -205,10 +205,22 @@ class _CPUSHMDistributed:
|
||||
self.handle = self._init_cpu_shm()
|
||||
|
||||
def _init_cpu_shm(self) -> int:
|
||||
thread_num_tensor = torch.tensor(
|
||||
[torch.get_num_threads()],
|
||||
dtype=torch.int64,
|
||||
)
|
||||
torch.distributed.all_reduce(
|
||||
thread_num_tensor,
|
||||
op=torch.distributed.ReduceOp.MIN,
|
||||
group=self.communicator.device_group,
|
||||
)
|
||||
thread_num = thread_num_tensor.item()
|
||||
|
||||
handle = torch.ops._C.init_shm_manager(
|
||||
self.group_name,
|
||||
self.communicator.world_size,
|
||||
self.communicator.rank,
|
||||
thread_num,
|
||||
)
|
||||
torch.distributed.barrier(self.communicator.device_group)
|
||||
torch.ops._C.join_shm_manager(
|
||||
|
||||
Reference in New Issue
Block a user