[CPU Backend] [Perf] Accelerate tensor-parallel/data-parallel inference across NUMA domains on Arm (#32792)
Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com>
This commit is contained in:
@@ -379,6 +379,12 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
|||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
if (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND)
|
||||||
|
set(VLLM_EXT_SRC
|
||||||
|
"csrc/cpu/shm.cpp"
|
||||||
|
${VLLM_EXT_SRC})
|
||||||
|
endif()
|
||||||
|
|
||||||
if(USE_ONEDNN)
|
if(USE_ONEDNN)
|
||||||
set(VLLM_EXT_SRC
|
set(VLLM_EXT_SRC
|
||||||
"csrc/cpu/dnnl_kernels.cpp"
|
"csrc/cpu/dnnl_kernels.cpp"
|
||||||
|
|||||||
@@ -80,8 +80,10 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
|
|||||||
reg.val[1] = vld1q_f16(reinterpret_cast<const __fp16*>(ptr) + 8);
|
reg.val[1] = vld1q_f16(reinterpret_cast<const __fp16*>(ptr) + 8);
|
||||||
}
|
}
|
||||||
|
|
||||||
explicit FP16Vec16(const FP32Vec16& vec);
|
// ASIMD does not support non-temporal loads
|
||||||
|
explicit FP16Vec16(bool, const void* ptr) : FP16Vec16(ptr) {}
|
||||||
|
|
||||||
|
explicit FP16Vec16(const FP32Vec16& vec);
|
||||||
void save(void* ptr) const {
|
void save(void* ptr) const {
|
||||||
vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]);
|
vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]);
|
||||||
vst1q_f16(reinterpret_cast<__fp16*>(ptr) + 8, reg.val[1]);
|
vst1q_f16(reinterpret_cast<__fp16*>(ptr) + 8, reg.val[1]);
|
||||||
@@ -190,6 +192,9 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
|
|||||||
explicit BF16Vec16(const void* ptr)
|
explicit BF16Vec16(const void* ptr)
|
||||||
: reg(*reinterpret_cast<const bfloat16x8x2_t*>(ptr)) {};
|
: reg(*reinterpret_cast<const bfloat16x8x2_t*>(ptr)) {};
|
||||||
|
|
||||||
|
// ASIMD does not support non-temporal loads
|
||||||
|
explicit BF16Vec16(bool, const void* ptr) : BF16Vec16(ptr) {}
|
||||||
|
|
||||||
explicit BF16Vec16(bfloat16x8x2_t data) : reg(data) {};
|
explicit BF16Vec16(bfloat16x8x2_t data) : reg(data) {};
|
||||||
|
|
||||||
explicit BF16Vec16(const FP32Vec16&);
|
explicit BF16Vec16(const FP32Vec16&);
|
||||||
@@ -474,6 +479,9 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
|||||||
: reg({vld1q_f32(ptr), vld1q_f32(ptr + 4), vld1q_f32(ptr + 8),
|
: reg({vld1q_f32(ptr), vld1q_f32(ptr + 4), vld1q_f32(ptr + 8),
|
||||||
vld1q_f32(ptr + 12)}) {}
|
vld1q_f32(ptr + 12)}) {}
|
||||||
|
|
||||||
|
// ASIMD does not support non-temporal loads
|
||||||
|
explicit FP32Vec16(bool, const float* ptr) : FP32Vec16(ptr) {}
|
||||||
|
|
||||||
explicit FP32Vec16(float32x4x4_t data) : reg(data) {}
|
explicit FP32Vec16(float32x4x4_t data) : reg(data) {}
|
||||||
|
|
||||||
explicit FP32Vec16(const FP32Vec8& data) {
|
explicit FP32Vec16(const FP32Vec8& data) {
|
||||||
@@ -756,6 +764,96 @@ struct INT8Vec16 : public Vec<INT8Vec16> {
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct INT8Vec64 : public Vec<INT8Vec64> {
|
||||||
|
constexpr static int VEC_ELEM_NUM = 64;
|
||||||
|
union AliasReg {
|
||||||
|
int8x16x4_t reg;
|
||||||
|
int8_t values[VEC_ELEM_NUM];
|
||||||
|
};
|
||||||
|
int8x16x4_t reg;
|
||||||
|
|
||||||
|
explicit INT8Vec64(const int8_t* ptr) { reg = vld1q_s8_x4(ptr); }
|
||||||
|
|
||||||
|
// ASIMD does not support non-temporal loads
|
||||||
|
explicit INT8Vec64(bool, const int8_t* ptr) : INT8Vec64(ptr) {}
|
||||||
|
|
||||||
|
void save(int8_t* ptr) const { vst1q_s8_x4(ptr, reg); }
|
||||||
|
|
||||||
|
// masked store
|
||||||
|
void save(int8_t* p, int elem_num) const {
|
||||||
|
TORCH_CHECK(elem_num <= VEC_ELEM_NUM && elem_num > 0);
|
||||||
|
|
||||||
|
if (elem_num == VEC_ELEM_NUM) {
|
||||||
|
vst1q_s8_x4(p, reg);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int full_quadwords = elem_num / 16;
|
||||||
|
const int remaining_bytes = elem_num % 16;
|
||||||
|
|
||||||
|
for (int i = 0; i < full_quadwords; ++i) {
|
||||||
|
vst1q_s8(p + 16 * i, reg.val[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (remaining_bytes) {
|
||||||
|
const int8x16_t v = reg.val[full_quadwords];
|
||||||
|
int8_t* tail = p + 16 * full_quadwords;
|
||||||
|
switch (remaining_bytes) {
|
||||||
|
case 15:
|
||||||
|
tail[14] = vgetq_lane_s8(v, 14);
|
||||||
|
[[fallthrough]];
|
||||||
|
case 14:
|
||||||
|
tail[13] = vgetq_lane_s8(v, 13);
|
||||||
|
[[fallthrough]];
|
||||||
|
case 13:
|
||||||
|
tail[12] = vgetq_lane_s8(v, 12);
|
||||||
|
[[fallthrough]];
|
||||||
|
case 12:
|
||||||
|
tail[11] = vgetq_lane_s8(v, 11);
|
||||||
|
[[fallthrough]];
|
||||||
|
case 11:
|
||||||
|
tail[10] = vgetq_lane_s8(v, 10);
|
||||||
|
[[fallthrough]];
|
||||||
|
case 10:
|
||||||
|
tail[9] = vgetq_lane_s8(v, 9);
|
||||||
|
[[fallthrough]];
|
||||||
|
case 9:
|
||||||
|
tail[8] = vgetq_lane_s8(v, 8);
|
||||||
|
[[fallthrough]];
|
||||||
|
case 8:
|
||||||
|
tail[7] = vgetq_lane_s8(v, 7);
|
||||||
|
[[fallthrough]];
|
||||||
|
case 7:
|
||||||
|
tail[6] = vgetq_lane_s8(v, 6);
|
||||||
|
[[fallthrough]];
|
||||||
|
case 6:
|
||||||
|
tail[5] = vgetq_lane_s8(v, 5);
|
||||||
|
[[fallthrough]];
|
||||||
|
case 5:
|
||||||
|
tail[4] = vgetq_lane_s8(v, 4);
|
||||||
|
[[fallthrough]];
|
||||||
|
case 4:
|
||||||
|
tail[3] = vgetq_lane_s8(v, 3);
|
||||||
|
[[fallthrough]];
|
||||||
|
case 3:
|
||||||
|
tail[2] = vgetq_lane_s8(v, 2);
|
||||||
|
[[fallthrough]];
|
||||||
|
case 2:
|
||||||
|
tail[1] = vgetq_lane_s8(v, 1);
|
||||||
|
[[fallthrough]];
|
||||||
|
case 1:
|
||||||
|
tail[0] = vgetq_lane_s8(v, 0);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ASIMD does not support non-temporal stores
|
||||||
|
void nt_save(int8_t* ptr) const { save(ptr); }
|
||||||
|
}; // INT8Vec64
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct VecType {
|
struct VecType {
|
||||||
using vec_type = void;
|
using vec_type = void;
|
||||||
|
|||||||
@@ -5,6 +5,10 @@
|
|||||||
#include <sys/stat.h>
|
#include <sys/stat.h>
|
||||||
#include <unistd.h>
|
#include <unistd.h>
|
||||||
|
|
||||||
|
#ifdef __aarch64__
|
||||||
|
#include <atomic>
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
#define MAX_SHM_RANK_NUM 8
|
#define MAX_SHM_RANK_NUM 8
|
||||||
#define PER_THREAD_SHM_BUFFER_BYTES (4 * 1024 * 1024)
|
#define PER_THREAD_SHM_BUFFER_BYTES (4 * 1024 * 1024)
|
||||||
@@ -34,8 +38,17 @@ struct KernelVecType<c10::Half> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct ThreadSHMContext {
|
struct ThreadSHMContext {
|
||||||
|
#ifdef __aarch64__
|
||||||
|
// memory model is weaker on AArch64, so we use atomic variables for
|
||||||
|
// consumer (load-acquire) and producer (store-release) to make sure
|
||||||
|
// that a stamp cannot be ready before the corresponding data is ready.
|
||||||
|
std::atomic<char> _curr_thread_stamp[2];
|
||||||
|
std::atomic<char> _ready_thread_stamp[2];
|
||||||
|
static_assert(std::atomic<char>::is_always_lock_free);
|
||||||
|
#else
|
||||||
volatile char _curr_thread_stamp[2];
|
volatile char _curr_thread_stamp[2];
|
||||||
volatile char _ready_thread_stamp[2];
|
volatile char _ready_thread_stamp[2];
|
||||||
|
#endif // __aarch64__
|
||||||
int local_stamp_buffer_idx;
|
int local_stamp_buffer_idx;
|
||||||
int remote_stamp_buffer_idx;
|
int remote_stamp_buffer_idx;
|
||||||
int thread_id;
|
int thread_id;
|
||||||
@@ -62,10 +75,17 @@ struct ThreadSHMContext {
|
|||||||
TORCH_CHECK(group_size <= MAX_SHM_RANK_NUM);
|
TORCH_CHECK(group_size <= MAX_SHM_RANK_NUM);
|
||||||
TORCH_CHECK((size_t)this % 64 == 0);
|
TORCH_CHECK((size_t)this % 64 == 0);
|
||||||
TORCH_CHECK((size_t)thread_shm_ptr % 64 == 0);
|
TORCH_CHECK((size_t)thread_shm_ptr % 64 == 0);
|
||||||
|
#ifdef __aarch64__
|
||||||
|
_curr_thread_stamp[0].store(1, std::memory_order_relaxed);
|
||||||
|
_curr_thread_stamp[1].store(1, std::memory_order_relaxed);
|
||||||
|
_ready_thread_stamp[0].store(0, std::memory_order_relaxed);
|
||||||
|
_ready_thread_stamp[1].store(0, std::memory_order_relaxed);
|
||||||
|
#else
|
||||||
_curr_thread_stamp[0] = 1;
|
_curr_thread_stamp[0] = 1;
|
||||||
_curr_thread_stamp[1] = 1;
|
_curr_thread_stamp[1] = 1;
|
||||||
_ready_thread_stamp[0] = 0;
|
_ready_thread_stamp[0] = 0;
|
||||||
_ready_thread_stamp[1] = 0;
|
_ready_thread_stamp[1] = 0;
|
||||||
|
#endif // __aarch64__
|
||||||
_thread_buffer_mask[0] = 0;
|
_thread_buffer_mask[0] = 0;
|
||||||
_thread_buffer_mask[1] = 0;
|
_thread_buffer_mask[1] = 0;
|
||||||
for (int i = 0; i < MAX_SHM_RANK_NUM; ++i) {
|
for (int i = 0; i < MAX_SHM_RANK_NUM; ++i) {
|
||||||
@@ -103,19 +123,43 @@ struct ThreadSHMContext {
|
|||||||
_thread_buffer_mask[local_stamp_buffer_idx] ^= 0xFFFFFFFFFFFFFFFF;
|
_thread_buffer_mask[local_stamp_buffer_idx] ^= 0xFFFFFFFFFFFFFFFF;
|
||||||
}
|
}
|
||||||
|
|
||||||
char get_curr_stamp(int idx) const { return _curr_thread_stamp[idx]; }
|
char get_curr_stamp(int idx) const {
|
||||||
|
#ifdef __aarch64__
|
||||||
|
return _curr_thread_stamp[idx].load(std::memory_order_acquire);
|
||||||
|
#else
|
||||||
|
return _curr_thread_stamp[idx];
|
||||||
|
#endif // __aarch64__
|
||||||
|
}
|
||||||
|
|
||||||
char get_ready_stamp(int idx) const { return _ready_thread_stamp[idx]; }
|
char get_ready_stamp(int idx) const {
|
||||||
|
#ifdef __aarch64__
|
||||||
|
return _ready_thread_stamp[idx].load(std::memory_order_acquire);
|
||||||
|
#else
|
||||||
|
return _ready_thread_stamp[idx];
|
||||||
|
#endif // __aarch64__
|
||||||
|
}
|
||||||
|
|
||||||
void next_stamp() {
|
void next_stamp() {
|
||||||
|
#ifdef __aarch64__
|
||||||
|
_curr_thread_stamp[local_stamp_buffer_idx].fetch_add(
|
||||||
|
1, std::memory_order_release);
|
||||||
|
#else
|
||||||
_mm_mfence();
|
_mm_mfence();
|
||||||
_curr_thread_stamp[local_stamp_buffer_idx] += 1;
|
_curr_thread_stamp[local_stamp_buffer_idx] += 1;
|
||||||
|
#endif // __aarch64__
|
||||||
}
|
}
|
||||||
|
|
||||||
void commit_ready_stamp() {
|
void commit_ready_stamp() {
|
||||||
|
#ifdef __aarch64__
|
||||||
|
_ready_thread_stamp[local_stamp_buffer_idx].store(
|
||||||
|
_curr_thread_stamp[local_stamp_buffer_idx].load(
|
||||||
|
std::memory_order_relaxed),
|
||||||
|
std::memory_order_release);
|
||||||
|
#else
|
||||||
_mm_mfence();
|
_mm_mfence();
|
||||||
_ready_thread_stamp[local_stamp_buffer_idx] =
|
_ready_thread_stamp[local_stamp_buffer_idx] =
|
||||||
_curr_thread_stamp[local_stamp_buffer_idx];
|
_curr_thread_stamp[local_stamp_buffer_idx];
|
||||||
|
#endif // __aarch64__
|
||||||
}
|
}
|
||||||
|
|
||||||
int get_swizzled_rank(int idx) { return swizzled_ranks[idx]; }
|
int get_swizzled_rank(int idx) { return swizzled_ranks[idx]; }
|
||||||
@@ -142,7 +186,11 @@ struct ThreadSHMContext {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
++_spinning_count;
|
++_spinning_count;
|
||||||
|
#ifdef __aarch64__
|
||||||
|
__asm__ __volatile__("yield");
|
||||||
|
#else
|
||||||
_mm_pause();
|
_mm_pause();
|
||||||
|
#endif // __aarch64__
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -230,7 +230,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
// SHM CCL
|
// SHM CCL
|
||||||
#ifdef __AVX512F__
|
#if defined(__AVX512F__) || defined(__aarch64__)
|
||||||
ops.def("init_shm_manager(str name, int group_size, int rank) -> int",
|
ops.def("init_shm_manager(str name, int group_size, int rank) -> int",
|
||||||
&init_shm_manager);
|
&init_shm_manager);
|
||||||
ops.def("join_shm_manager(int handle, str name) -> str", &join_shm_manager);
|
ops.def("join_shm_manager(int handle, str name) -> str", &join_shm_manager);
|
||||||
@@ -250,7 +250,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
ops.impl("shm_send_tensor_list", torch::kCPU, &shm_send_tensor_list);
|
ops.impl("shm_send_tensor_list", torch::kCPU, &shm_send_tensor_list);
|
||||||
ops.def("shm_recv_tensor_list(int handle, int src) -> Tensor[](a)",
|
ops.def("shm_recv_tensor_list(int handle, int src) -> Tensor[](a)",
|
||||||
&shm_recv_tensor_list);
|
&shm_recv_tensor_list);
|
||||||
#endif
|
#endif // #if defined(__AVX512F__) || defined(__aarch64__)
|
||||||
|
|
||||||
// sgl-kernels
|
// sgl-kernels
|
||||||
#if defined(__AVX512BF16__) && defined(__AVX512F__) && defined(__AVX512VNNI__)
|
#if defined(__AVX512BF16__) && defined(__AVX512F__) && defined(__AVX512VNNI__)
|
||||||
|
|||||||
@@ -29,7 +29,10 @@ class CpuCommunicator(DeviceCommunicatorBase):
|
|||||||
self.dist_module = torch.distributed
|
self.dist_module = torch.distributed
|
||||||
|
|
||||||
if (
|
if (
|
||||||
(current_platform.get_cpu_architecture() == CpuArchEnum.X86)
|
(
|
||||||
|
current_platform.get_cpu_architecture() == CpuArchEnum.X86
|
||||||
|
or current_platform.get_cpu_architecture() == CpuArchEnum.ARM
|
||||||
|
)
|
||||||
and hasattr(torch.ops._C, "init_shm_manager")
|
and hasattr(torch.ops._C, "init_shm_manager")
|
||||||
and (unique_name.startswith("tp") or unique_name.startswith("pp"))
|
and (unique_name.startswith("tp") or unique_name.startswith("pp"))
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -66,6 +66,9 @@ class CPUWorker(Worker):
|
|||||||
self.local_omp_cpuid = self._get_autobind_cpu_ids(
|
self.local_omp_cpuid = self._get_autobind_cpu_ids(
|
||||||
lambda cpus: cpus[-1:]
|
lambda cpus: cpus[-1:]
|
||||||
)
|
)
|
||||||
|
elif cpu_arch == CpuArchEnum.ARM:
|
||||||
|
# For AArch64, no SMT
|
||||||
|
self.local_omp_cpuid = self._get_autobind_cpu_ids(lambda cpus: cpus)
|
||||||
else:
|
else:
|
||||||
self.local_omp_cpuid = "nobind"
|
self.local_omp_cpuid = "nobind"
|
||||||
elif omp_cpuids == "nobind":
|
elif omp_cpuids == "nobind":
|
||||||
|
|||||||
Reference in New Issue
Block a user