[Bugfix][CPU] Fix thread num for shared memory communication (#33317)
Signed-off-by: jiang1.li <jiang1.li@intel.com> Signed-off-by: Li, Jiang <bigpyj64@gmail.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -35,7 +35,7 @@ void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens);
|
||||
|
||||
int64_t init_shm_manager(const std::string& name, const int64_t group_size,
|
||||
const int64_t rank);
|
||||
const int64_t rank, const int64_t thread_num);
|
||||
|
||||
std::string join_shm_manager(int64_t handle, const std::string& name);
|
||||
|
||||
@@ -232,8 +232,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
|
||||
// SHM CCL
|
||||
#if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__))
|
||||
ops.def("init_shm_manager(str name, int group_size, int rank) -> int",
|
||||
&init_shm_manager);
|
||||
ops.def(
|
||||
"init_shm_manager(str name, int group_size, int rank, int thread_num) -> "
|
||||
"int",
|
||||
&init_shm_manager);
|
||||
ops.def("join_shm_manager(int handle, str name) -> str", &join_shm_manager);
|
||||
ops.def("shm_allreduce(int handle, Tensor! data) -> ()");
|
||||
ops.impl("shm_allreduce", torch::kCPU, &shm_allreduce);
|
||||
|
||||
Reference in New Issue
Block a user