diff --git a/csrc/cpu/shm.cpp b/csrc/cpu/shm.cpp
index 01ca50109..a7fdd0c9d 100644
--- a/csrc/cpu/shm.cpp
+++ b/csrc/cpu/shm.cpp
@@ -237,10 +237,10 @@ struct ThreadSHMContext {
 class SHMManager {
  public:
   explicit SHMManager(const std::string& name, const int rank,
-                      const int group_size)
+                      const int group_size, const int thread_num)
       : _rank(rank),
         _group_size(group_size),
-        _thread_num(omp_get_max_threads()),
+        _thread_num(thread_num),
         _shm_names({""}),
         _shared_mem_ptrs({nullptr}),
         _shm_ctx(nullptr) {
@@ -282,11 +282,11 @@ class SHMManager {
   }
 
   static int64_t create_singleton_instance(const std::string& name,
-                                           const int group_size,
-                                           const int rank) {
+                                           const int group_size, const int rank,
+                                           const int thread_num) {
     std::lock_guard<std::mutex> guard(SingletonInstancesLock);
     SingletonInstances.emplace_back(
-        std::make_unique<SHMManager>(name, rank, group_size));
+        std::make_unique<SHMManager>(name, rank, group_size, thread_num));
     return static_cast<int64_t>(SingletonInstances.size() - 1);
   }
 
@@ -854,8 +854,9 @@ std::vector<torch::Tensor> shm_recv_tensor_list(int64_t handle, int64_t src) {
 }
 
 int64_t init_shm_manager(const std::string& name, const int64_t group_size,
-                         const int64_t rank) {
-  return SHMManager::create_singleton_instance(name, group_size, rank);
+                         const int64_t rank, const int64_t thread_num) {
+  return SHMManager::create_singleton_instance(name, group_size, rank,
+                                               thread_num);
 }
 
 std::string join_shm_manager(int64_t handle, const std::string& name) {
diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp
index 718ccc5e8..b54447b7d 100644
--- a/csrc/cpu/torch_bindings.cpp
+++ b/csrc/cpu/torch_bindings.cpp
@@ -35,7 +35,7 @@
 void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
                         torch::Tensor& block_tables, torch::Tensor& seq_lens);
 
 int64_t init_shm_manager(const std::string& name, const int64_t group_size,
-                         const int64_t rank);
+                         const int64_t rank, const int64_t thread_num);
 
 std::string join_shm_manager(int64_t handle, const std::string& name);
@@ -232,8 +232,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
 
   // SHM CCL
 #if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__))
-  ops.def("init_shm_manager(str name, int group_size, int rank) -> int",
-          &init_shm_manager);
+  ops.def(
+      "init_shm_manager(str name, int group_size, int rank, int thread_num) -> "
+      "int",
+      &init_shm_manager);
   ops.def("join_shm_manager(int handle, str name) -> str", &join_shm_manager);
   ops.def("shm_allreduce(int handle, Tensor! data) -> ()");
   ops.impl("shm_allreduce", torch::kCPU, &shm_allreduce);
diff --git a/vllm/distributed/device_communicators/cpu_communicator.py b/vllm/distributed/device_communicators/cpu_communicator.py
index b5fbdfcc3..23be8fcfc 100644
--- a/vllm/distributed/device_communicators/cpu_communicator.py
+++ b/vllm/distributed/device_communicators/cpu_communicator.py
@@ -205,10 +205,22 @@ class _CPUSHMDistributed:
         self.handle = self._init_cpu_shm()
 
     def _init_cpu_shm(self) -> int:
+        thread_num_tensor = torch.tensor(
+            [torch.get_num_threads()],
+            dtype=torch.int64,
+        )
+        torch.distributed.all_reduce(
+            thread_num_tensor,
+            op=torch.distributed.ReduceOp.MIN,
+            group=self.communicator.device_group,
+        )
+        thread_num = thread_num_tensor.item()
+
         handle = torch.ops._C.init_shm_manager(
             self.group_name,
             self.communicator.world_size,
             self.communicator.rank,
+            thread_num,
         )
         torch.distributed.barrier(self.communicator.device_group)
         torch.ops._C.join_shm_manager(