[CPU Backend] [Perf] Accelerate tensor-parallel/data-parallel inference across NUMA domains on Arm (#32792)
Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com>
This commit is contained in:
@@ -230,7 +230,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
#endif
|
||||
|
||||
// SHM CCL
|
||||
#ifdef __AVX512F__
|
||||
#if defined(__AVX512F__) || defined(__aarch64__)
|
||||
ops.def("init_shm_manager(str name, int group_size, int rank) -> int",
|
||||
&init_shm_manager);
|
||||
ops.def("join_shm_manager(int handle, str name) -> str", &join_shm_manager);
|
||||
@@ -250,7 +250,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.impl("shm_send_tensor_list", torch::kCPU, &shm_send_tensor_list);
|
||||
ops.def("shm_recv_tensor_list(int handle, int src) -> Tensor[](a)",
|
||||
&shm_recv_tensor_list);
|
||||
#endif
|
||||
#endif // #if defined(__AVX512F__) || defined(__aarch64__)
|
||||
|
||||
// sgl-kernels
|
||||
#if defined(__AVX512BF16__) && defined(__AVX512F__) && defined(__AVX512VNNI__)
|
||||
|
||||
Reference in New Issue
Block a user