[CPU Backend] [Perf] Accelerate tensor-parallel/data-parallel inference across NUMA domains on Arm (#32792)

Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com>
This commit is contained in:
Fadi Arafeh
2026-01-22 18:55:23 +00:00
committed by GitHub
parent 300622e609
commit 744ef30484
6 changed files with 164 additions and 6 deletions

View File

@@ -5,6 +5,10 @@
#include <sys/stat.h>
#include <unistd.h>
#ifdef __aarch64__
#include <atomic>
#endif
namespace {
#define MAX_SHM_RANK_NUM 8
#define PER_THREAD_SHM_BUFFER_BYTES (4 * 1024 * 1024)
@@ -34,8 +38,17 @@ struct KernelVecType<c10::Half> {
};
struct ThreadSHMContext {
#ifdef __aarch64__
// memory model is weaker on AArch64, so we use atomic variables for
// consumer (load-acquire) and producer (store-release) to make sure
// that a stamp cannot be ready before the corresponding data is ready.
std::atomic<char> _curr_thread_stamp[2];
std::atomic<char> _ready_thread_stamp[2];
static_assert(std::atomic<char>::is_always_lock_free);
#else
volatile char _curr_thread_stamp[2];
volatile char _ready_thread_stamp[2];
#endif // __aarch64__
int local_stamp_buffer_idx;
int remote_stamp_buffer_idx;
int thread_id;
@@ -62,10 +75,17 @@ struct ThreadSHMContext {
TORCH_CHECK(group_size <= MAX_SHM_RANK_NUM);
TORCH_CHECK((size_t)this % 64 == 0);
TORCH_CHECK((size_t)thread_shm_ptr % 64 == 0);
#ifdef __aarch64__
_curr_thread_stamp[0].store(1, std::memory_order_relaxed);
_curr_thread_stamp[1].store(1, std::memory_order_relaxed);
_ready_thread_stamp[0].store(0, std::memory_order_relaxed);
_ready_thread_stamp[1].store(0, std::memory_order_relaxed);
#else
_curr_thread_stamp[0] = 1;
_curr_thread_stamp[1] = 1;
_ready_thread_stamp[0] = 0;
_ready_thread_stamp[1] = 0;
#endif // __aarch64__
_thread_buffer_mask[0] = 0;
_thread_buffer_mask[1] = 0;
for (int i = 0; i < MAX_SHM_RANK_NUM; ++i) {
@@ -103,19 +123,43 @@ struct ThreadSHMContext {
_thread_buffer_mask[local_stamp_buffer_idx] ^= 0xFFFFFFFFFFFFFFFF;
}
char get_curr_stamp(int idx) const { return _curr_thread_stamp[idx]; }
char get_curr_stamp(int idx) const {
#ifdef __aarch64__
return _curr_thread_stamp[idx].load(std::memory_order_acquire);
#else
return _curr_thread_stamp[idx];
#endif // __aarch64__
}
char get_ready_stamp(int idx) const { return _ready_thread_stamp[idx]; }
char get_ready_stamp(int idx) const {
#ifdef __aarch64__
return _ready_thread_stamp[idx].load(std::memory_order_acquire);
#else
return _ready_thread_stamp[idx];
#endif // __aarch64__
}
void next_stamp() {
#ifdef __aarch64__
_curr_thread_stamp[local_stamp_buffer_idx].fetch_add(
1, std::memory_order_release);
#else
_mm_mfence();
_curr_thread_stamp[local_stamp_buffer_idx] += 1;
#endif // __aarch64__
}
void commit_ready_stamp() {
#ifdef __aarch64__
_ready_thread_stamp[local_stamp_buffer_idx].store(
_curr_thread_stamp[local_stamp_buffer_idx].load(
std::memory_order_relaxed),
std::memory_order_release);
#else
_mm_mfence();
_ready_thread_stamp[local_stamp_buffer_idx] =
_curr_thread_stamp[local_stamp_buffer_idx];
#endif // __aarch64__
}
int get_swizzled_rank(int idx) { return swizzled_ranks[idx]; }
@@ -142,7 +186,11 @@ struct ThreadSHMContext {
break;
}
++_spinning_count;
#ifdef __aarch64__
__asm__ __volatile__("yield");
#else
_mm_pause();
#endif // __aarch64__
}
}