[CPU Backend] [Perf] Accelerate tensor-parallel/data-parallel inference across NUMA domains on Arm (#32792)
Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com>
This commit is contained in:
@@ -5,6 +5,10 @@
|
||||
#include <sys/stat.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#ifdef __aarch64__
|
||||
#include <atomic>
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
#define MAX_SHM_RANK_NUM 8
|
||||
#define PER_THREAD_SHM_BUFFER_BYTES (4 * 1024 * 1024)
|
||||
@@ -34,8 +38,17 @@ struct KernelVecType<c10::Half> {
|
||||
};
|
||||
|
||||
struct ThreadSHMContext {
|
||||
#ifdef __aarch64__
|
||||
// memory model is weaker on AArch64, so we use atomic variables for
|
||||
// consumer (load-acquire) and producer (store-release) to make sure
|
||||
// that a stamp cannot be observed as ready before the corresponding data is visible.
|
||||
std::atomic<char> _curr_thread_stamp[2];
|
||||
std::atomic<char> _ready_thread_stamp[2];
|
||||
static_assert(std::atomic<char>::is_always_lock_free);
|
||||
#else
|
||||
volatile char _curr_thread_stamp[2];
|
||||
volatile char _ready_thread_stamp[2];
|
||||
#endif // __aarch64__
|
||||
int local_stamp_buffer_idx;
|
||||
int remote_stamp_buffer_idx;
|
||||
int thread_id;
|
||||
@@ -62,10 +75,17 @@ struct ThreadSHMContext {
|
||||
TORCH_CHECK(group_size <= MAX_SHM_RANK_NUM);
|
||||
TORCH_CHECK((size_t)this % 64 == 0);
|
||||
TORCH_CHECK((size_t)thread_shm_ptr % 64 == 0);
|
||||
#ifdef __aarch64__
|
||||
_curr_thread_stamp[0].store(1, std::memory_order_relaxed);
|
||||
_curr_thread_stamp[1].store(1, std::memory_order_relaxed);
|
||||
_ready_thread_stamp[0].store(0, std::memory_order_relaxed);
|
||||
_ready_thread_stamp[1].store(0, std::memory_order_relaxed);
|
||||
#else
|
||||
_curr_thread_stamp[0] = 1;
|
||||
_curr_thread_stamp[1] = 1;
|
||||
_ready_thread_stamp[0] = 0;
|
||||
_ready_thread_stamp[1] = 0;
|
||||
#endif // __aarch64__
|
||||
_thread_buffer_mask[0] = 0;
|
||||
_thread_buffer_mask[1] = 0;
|
||||
for (int i = 0; i < MAX_SHM_RANK_NUM; ++i) {
|
||||
@@ -103,19 +123,43 @@ struct ThreadSHMContext {
|
||||
_thread_buffer_mask[local_stamp_buffer_idx] ^= 0xFFFFFFFFFFFFFFFF;
|
||||
}
|
||||
|
||||
// Returns the current stamp stored in stamp slot `idx`.
//
// Only one definition of this accessor may exist in the struct; the
// architecture-specific read is selected by the preprocessor.
char get_curr_stamp(int idx) const {
#ifdef __aarch64__
  // Load-acquire: memory accesses that follow this read cannot be reordered
  // before it.
  return _curr_thread_stamp[idx].load(std::memory_order_acquire);
#else
  // Plain read of the volatile stamp.
  return _curr_thread_stamp[idx];
#endif  // __aarch64__
}
|
||||
|
||||
// Returns the ready stamp stored in stamp slot `idx`.
//
// Only one definition of this accessor may exist in the struct; the
// architecture-specific read is selected by the preprocessor.
char get_ready_stamp(int idx) const {
#ifdef __aarch64__
  // Load-acquire: memory accesses that follow this read cannot be reordered
  // before it, so they see what was written before the matching
  // store-release of the stamp.
  return _ready_thread_stamp[idx].load(std::memory_order_acquire);
#else
  // Plain read of the volatile stamp.
  return _ready_thread_stamp[idx];
#endif  // __aarch64__
}
|
||||
|
||||
// Advances the current stamp of the local stamp buffer by one.
void next_stamp() {
#ifdef __aarch64__
  // Atomic increment with release ordering: earlier writes cannot be
  // reordered after the stamp update.
  auto& curr = _curr_thread_stamp[local_stamp_buffer_idx];
  curr.fetch_add(1, std::memory_order_release);
#else
  // Fence first, then bump the volatile stamp (one read, one write).
  _mm_mfence();
  volatile char& curr = _curr_thread_stamp[local_stamp_buffer_idx];
  curr = static_cast<char>(curr + 1);
#endif  // __aarch64__
}
|
||||
|
||||
// Copies the local buffer's current stamp into its ready stamp.
void commit_ready_stamp() {
#ifdef __aarch64__
  // The relaxed load only fetches the value; the release store is what
  // orders earlier writes before the ready stamp changes.
  const char committed =
      _curr_thread_stamp[local_stamp_buffer_idx].load(std::memory_order_relaxed);
  _ready_thread_stamp[local_stamp_buffer_idx].store(committed,
                                                    std::memory_order_release);
#else
  // Fence first, then copy the volatile stamp across.
  _mm_mfence();
  const char committed = _curr_thread_stamp[local_stamp_buffer_idx];
  _ready_thread_stamp[local_stamp_buffer_idx] = committed;
#endif  // __aarch64__
}
|
||||
|
||||
// Returns the entry of `swizzled_ranks` at position `idx`.
int get_swizzled_rank(int idx) {
  const int rank = swizzled_ranks[idx];
  return rank;
}
|
||||
@@ -142,7 +186,11 @@ struct ThreadSHMContext {
|
||||
break;
|
||||
}
|
||||
++_spinning_count;
|
||||
#ifdef __aarch64__
|
||||
__asm__ __volatile__("yield");
|
||||
#else
|
||||
_mm_pause();
|
||||
#endif // __aarch64__
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user