diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake
index 0af87fd7f..a50731d70 100644
--- a/cmake/cpu_extension.cmake
+++ b/cmake/cpu_extension.cmake
@@ -379,6 +379,12 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
     endif()
 endif()
 
+if (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND)
+    set(VLLM_EXT_SRC
+        "csrc/cpu/shm.cpp"
+        ${VLLM_EXT_SRC})
+endif()
+
 if(USE_ONEDNN)
     set(VLLM_EXT_SRC
         "csrc/cpu/dnnl_kernels.cpp"
diff --git a/csrc/cpu/cpu_types_arm.hpp b/csrc/cpu/cpu_types_arm.hpp
index 2251aac45..520c873df 100644
--- a/csrc/cpu/cpu_types_arm.hpp
+++ b/csrc/cpu/cpu_types_arm.hpp
@@ -80,8 +80,10 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
     reg.val[1] = vld1q_f16(reinterpret_cast<const __fp16*>(ptr) + 8);
   }
 
-  explicit FP16Vec16(const FP32Vec16& vec);
+  // ASIMD does not support non-temporal loads
+  explicit FP16Vec16(bool, const void* ptr) : FP16Vec16(ptr) {}
 
+  explicit FP16Vec16(const FP32Vec16& vec);
   void save(void* ptr) const {
     vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]);
     vst1q_f16(reinterpret_cast<__fp16*>(ptr) + 8, reg.val[1]);
@@ -190,6 +192,9 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
   explicit BF16Vec16(const void* ptr)
       : reg(*reinterpret_cast<const bfloat16x8x2_t*>(ptr)) {};
 
+  // ASIMD does not support non-temporal loads
+  explicit BF16Vec16(bool, const void* ptr) : BF16Vec16(ptr) {}
+
   explicit BF16Vec16(bfloat16x8x2_t data) : reg(data) {};
 
   explicit BF16Vec16(const FP32Vec16&);
@@ -474,6 +479,9 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
       : reg({vld1q_f32(ptr), vld1q_f32(ptr + 4), vld1q_f32(ptr + 8),
              vld1q_f32(ptr + 12)}) {}
 
+  // ASIMD does not support non-temporal loads
+  explicit FP32Vec16(bool, const float* ptr) : FP32Vec16(ptr) {}
+
   explicit FP32Vec16(float32x4x4_t data) : reg(data) {}
 
   explicit FP32Vec16(const FP32Vec8& data) {
@@ -756,6 +764,96 @@ struct INT8Vec16 : public Vec<INT8Vec16> {
   };
 };
 
+struct INT8Vec64 : public Vec<INT8Vec64> {
+  constexpr static int VEC_ELEM_NUM = 64;
+  union AliasReg {
+    int8x16x4_t reg;
+    int8_t values[VEC_ELEM_NUM];
+  };
+  int8x16x4_t reg;
+
+  explicit INT8Vec64(const int8_t* ptr) { reg = vld1q_s8_x4(ptr); }
+
+  // ASIMD does not support non-temporal loads
+  explicit INT8Vec64(bool, const int8_t* ptr) : INT8Vec64(ptr) {}
+
+  void save(int8_t* ptr) const { vst1q_s8_x4(ptr, reg); }
+
+  // masked store
+  void save(int8_t* p, int elem_num) const {
+    TORCH_CHECK(elem_num <= VEC_ELEM_NUM && elem_num > 0);
+
+    if (elem_num == VEC_ELEM_NUM) {
+      vst1q_s8_x4(p, reg);
+      return;
+    }
+
+    const int full_quadwords = elem_num / 16;
+    const int remaining_bytes = elem_num % 16;
+
+    for (int i = 0; i < full_quadwords; ++i) {
+      vst1q_s8(p + 16 * i, reg.val[i]);
+    }
+
+    if (remaining_bytes) {
+      const int8x16_t v = reg.val[full_quadwords];
+      int8_t* tail = p + 16 * full_quadwords;
+      switch (remaining_bytes) {
+        case 15:
+          tail[14] = vgetq_lane_s8(v, 14);
+          [[fallthrough]];
+        case 14:
+          tail[13] = vgetq_lane_s8(v, 13);
+          [[fallthrough]];
+        case 13:
+          tail[12] = vgetq_lane_s8(v, 12);
+          [[fallthrough]];
+        case 12:
+          tail[11] = vgetq_lane_s8(v, 11);
+          [[fallthrough]];
+        case 11:
+          tail[10] = vgetq_lane_s8(v, 10);
+          [[fallthrough]];
+        case 10:
+          tail[9] = vgetq_lane_s8(v, 9);
+          [[fallthrough]];
+        case 9:
+          tail[8] = vgetq_lane_s8(v, 8);
+          [[fallthrough]];
+        case 8:
+          tail[7] = vgetq_lane_s8(v, 7);
+          [[fallthrough]];
+        case 7:
+          tail[6] = vgetq_lane_s8(v, 6);
+          [[fallthrough]];
+        case 6:
+          tail[5] = vgetq_lane_s8(v, 5);
+          [[fallthrough]];
+        case 5:
+          tail[4] = vgetq_lane_s8(v, 4);
+          [[fallthrough]];
+        case 4:
+          tail[3] = vgetq_lane_s8(v, 3);
+          [[fallthrough]];
+        case 3:
+          tail[2] = vgetq_lane_s8(v, 2);
+          [[fallthrough]];
+        case 2:
+          tail[1] = vgetq_lane_s8(v, 1);
+          [[fallthrough]];
+        case 1:
+          tail[0] = vgetq_lane_s8(v, 0);
+          break;
+        default:
+          break;
+      }
+    }
+  }
+
+  // ASIMD does not support non-temporal stores
+  void nt_save(int8_t* ptr) const { save(ptr); }
+};  // INT8Vec64
+
 template <typename T>
 struct VecType {
   using vec_type = void;
diff --git a/csrc/cpu/shm.cpp b/csrc/cpu/shm.cpp
index e43aa2037..01ca50109 100644
--- a/csrc/cpu/shm.cpp
+++ b/csrc/cpu/shm.cpp
@@ -5,6 +5,10 @@
 #include
 #include
 
+#ifdef __aarch64__
+  #include <atomic>
+#endif
+
 namespace {
 #define MAX_SHM_RANK_NUM 8
 #define PER_THREAD_SHM_BUFFER_BYTES (4 * 1024 * 1024)
@@ -34,8 +38,17 @@ struct KernelVecType {
 };
 
 struct ThreadSHMContext {
+#ifdef __aarch64__
+  // memory model is weaker on AArch64, so we use atomic variables for
+  // consumer (load-acquire) and producer (store-release) to make sure
+  // that a stamp cannot be ready before the corresponding data is ready.
+  std::atomic<char> _curr_thread_stamp[2];
+  std::atomic<char> _ready_thread_stamp[2];
+  static_assert(std::atomic<char>::is_always_lock_free);
+#else
   volatile char _curr_thread_stamp[2];
   volatile char _ready_thread_stamp[2];
+#endif  // __aarch64__
   int local_stamp_buffer_idx;
   int remote_stamp_buffer_idx;
   int thread_id;
@@ -62,10 +75,17 @@
     TORCH_CHECK(group_size <= MAX_SHM_RANK_NUM);
     TORCH_CHECK((size_t)this % 64 == 0);
     TORCH_CHECK((size_t)thread_shm_ptr % 64 == 0);
+#ifdef __aarch64__
+    _curr_thread_stamp[0].store(1, std::memory_order_relaxed);
+    _curr_thread_stamp[1].store(1, std::memory_order_relaxed);
+    _ready_thread_stamp[0].store(0, std::memory_order_relaxed);
+    _ready_thread_stamp[1].store(0, std::memory_order_relaxed);
+#else
     _curr_thread_stamp[0] = 1;
     _curr_thread_stamp[1] = 1;
     _ready_thread_stamp[0] = 0;
     _ready_thread_stamp[1] = 0;
+#endif  // __aarch64__
     _thread_buffer_mask[0] = 0;
     _thread_buffer_mask[1] = 0;
     for (int i = 0; i < MAX_SHM_RANK_NUM; ++i) {
@@ -103,19 +123,43 @@ struct ThreadSHMContext {
     _thread_buffer_mask[local_stamp_buffer_idx] ^= 0xFFFFFFFFFFFFFFFF;
   }
 
-  char get_curr_stamp(int idx) const { return _curr_thread_stamp[idx]; }
+  char get_curr_stamp(int idx) const {
+#ifdef __aarch64__
+    return _curr_thread_stamp[idx].load(std::memory_order_acquire);
+#else
+    return _curr_thread_stamp[idx];
+#endif  // __aarch64__
+  }
 
-  char get_ready_stamp(int idx) const { return _ready_thread_stamp[idx]; }
+  char get_ready_stamp(int idx) const {
+#ifdef __aarch64__
+    return _ready_thread_stamp[idx].load(std::memory_order_acquire);
+#else
+    return _ready_thread_stamp[idx];
+#endif  // __aarch64__
+  }
 
   void next_stamp() {
+#ifdef __aarch64__
+    _curr_thread_stamp[local_stamp_buffer_idx].fetch_add(
+        1, std::memory_order_release);
+#else
     _mm_mfence();
     _curr_thread_stamp[local_stamp_buffer_idx] += 1;
+#endif  // __aarch64__
   }
 
   void commit_ready_stamp() {
+#ifdef __aarch64__
+    _ready_thread_stamp[local_stamp_buffer_idx].store(
+        _curr_thread_stamp[local_stamp_buffer_idx].load(
+            std::memory_order_relaxed),
+        std::memory_order_release);
+#else
     _mm_mfence();
     _ready_thread_stamp[local_stamp_buffer_idx] =
         _curr_thread_stamp[local_stamp_buffer_idx];
+#endif  // __aarch64__
   }
 
   int get_swizzled_rank(int idx) { return swizzled_ranks[idx]; }
@@ -142,7 +186,11 @@ struct ThreadSHMContext {
         break;
       }
       ++_spinning_count;
+#ifdef __aarch64__
+      __asm__ __volatile__("yield");
+#else
       _mm_pause();
+#endif  // __aarch64__
     }
   }
 
diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp
index dd419405c..b93a08c01 100644
--- a/csrc/cpu/torch_bindings.cpp
+++ b/csrc/cpu/torch_bindings.cpp
@@ -230,7 +230,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
 #endif
 
   // SHM CCL
-#ifdef __AVX512F__
+#if defined(__AVX512F__) || defined(__aarch64__)
   ops.def("init_shm_manager(str name, int group_size, int rank) -> int",
           &init_shm_manager);
   ops.def("join_shm_manager(int handle, str name) -> str", &join_shm_manager);
@@ -250,7 +250,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("shm_send_tensor_list", torch::kCPU, &shm_send_tensor_list);
   ops.def("shm_recv_tensor_list(int handle, int src) -> Tensor[](a)",
           &shm_recv_tensor_list);
-#endif
+#endif  // #if defined(__AVX512F__) || defined(__aarch64__)
 
   // sgl-kernels
 #if defined(__AVX512BF16__) && defined(__AVX512F__) && defined(__AVX512VNNI__)
diff --git a/vllm/distributed/device_communicators/cpu_communicator.py b/vllm/distributed/device_communicators/cpu_communicator.py
index 2a84418c1..25db6a2eb 100644
--- a/vllm/distributed/device_communicators/cpu_communicator.py
+++ b/vllm/distributed/device_communicators/cpu_communicator.py
@@ -29,7 +29,10 @@ class CpuCommunicator(DeviceCommunicatorBase):
         self.dist_module = torch.distributed
 
         if (
-            (current_platform.get_cpu_architecture() == CpuArchEnum.X86)
+            (
+                current_platform.get_cpu_architecture() == CpuArchEnum.X86
+                or current_platform.get_cpu_architecture() == CpuArchEnum.ARM
+            )
             and hasattr(torch.ops._C, "init_shm_manager")
             and (unique_name.startswith("tp") or unique_name.startswith("pp"))
         ):
diff --git a/vllm/v1/worker/cpu_worker.py b/vllm/v1/worker/cpu_worker.py
index 654f58834..b8b6f266d 100644
--- a/vllm/v1/worker/cpu_worker.py
+++ b/vllm/v1/worker/cpu_worker.py
@@ -66,6 +66,9 @@ class CPUWorker(Worker):
                 self.local_omp_cpuid = self._get_autobind_cpu_ids(
                     lambda cpus: cpus[-1:]
                 )
+            elif cpu_arch == CpuArchEnum.ARM:
+                # For AArch64, no SMT
+                self.local_omp_cpuid = self._get_autobind_cpu_ids(lambda cpus: cpus)
             else:
                 self.local_omp_cpuid = "nobind"
         elif omp_cpuids == "nobind":
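
Not part of the patch, just background for the `std::atomic` stamp change in `csrc/cpu/shm.cpp`: the comment in the diff states that a ready stamp must not be observed before the data it covers. The standalone sketch below (illustrative names only, not vLLM APIs) shows the same load-acquire/store-release pairing that replaces the `_mm_mfence()` calls on AArch64.

```cpp
#include <atomic>
#include <cstring>
#include <thread>

// Minimal producer/consumer stamp handshake, assuming one writer and one
// reader per slot. The store-release on the stamp pairs with the
// load-acquire in the reader, so the payload written before the release is
// guaranteed to be visible once the stamp is observed, without a full fence.
struct Slot {
  alignas(64) char data[64];
  std::atomic<char> ready_stamp{0};
};

void producer(Slot& slot, char stamp) {
  std::memset(slot.data, stamp, sizeof(slot.data));          // write payload
  slot.ready_stamp.store(stamp, std::memory_order_release);  // then publish
}

void consumer(Slot& slot, char expected) {
  while (slot.ready_stamp.load(std::memory_order_acquire) != expected) {
#if defined(__aarch64__)
    __asm__ __volatile__("yield");  // same spin hint the patch uses
#endif
  }
  // slot.data now holds the payload written before the matching release.
}

int main() {
  Slot slot;
  std::thread reader(consumer, std::ref(slot), static_cast<char>(1));
  producer(slot, 1);
  reader.join();
  return 0;
}
```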
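Also not part of the patch: the masked `INT8Vec64::save(int8_t* p, int elem_num)` writes the tail through a fallthrough `switch` because `vgetq_lane_s8` requires a compile-time constant lane index, so the lane number cannot come from a loop variable. A scalar equivalent, shown only for comparison, would spill the vector once and copy the needed prefix:

```cpp
#include <arm_neon.h>
#include <cstdint>

// Reference-only scalar tail store with the same effect as the fallthrough
// switch: store the whole vector to a scratch buffer, then copy only the
// remaining_bytes prefix to the destination.
inline void store_tail_scalar(int8_t* tail, int8x16_t v, int remaining_bytes) {
  int8_t tmp[16];
  vst1q_s8(tmp, v);
  for (int i = 0; i < remaining_bytes; ++i) {
    tail[i] = tmp[i];
  }
}
```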