[Bugfix][CPU] Fix thread num for shared memory communication (#33317)

Signed-off-by: jiang1.li <jiang1.li@intel.com>
Signed-off-by: Li, Jiang <bigpyj64@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Author: Li, Jiang
Date: 2026-01-29 19:26:58 +08:00
Committed by: GitHub
Parent: 40c35038d2
Commit: 8311f083bd
3 changed files with 25 additions and 10 deletions

@@ -35,7 +35,7 @@ void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
                         torch::Tensor& block_tables, torch::Tensor& seq_lens);
 int64_t init_shm_manager(const std::string& name, const int64_t group_size,
-                         const int64_t rank);
+                         const int64_t rank, const int64_t thread_num);
 std::string join_shm_manager(int64_t handle, const std::string& name);
@@ -232,8 +232,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // SHM CCL
 #if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__))
-  ops.def("init_shm_manager(str name, int group_size, int rank) -> int",
-          &init_shm_manager);
+  ops.def(
+      "init_shm_manager(str name, int group_size, int rank, int thread_num) -> "
+      "int",
+      &init_shm_manager);
   ops.def("join_shm_manager(int handle, str name) -> str", &join_shm_manager);
   ops.def("shm_allreduce(int handle, Tensor! data) -> ()");
   ops.impl("shm_allreduce", torch::kCPU, &shm_allreduce);