[Feature][CPU Backend]: Optimize ARM vectorization backend (#30329)
Signed-off-by: Radu Salavat <radu.salavat@arm.com>
This commit is contained in:
@@ -816,14 +816,10 @@ struct VecTypeTrait<float> {
|
|||||||
using vec_t = vec_op::FP32Vec16;
|
using vec_t = vec_op::FP32Vec16;
|
||||||
};
|
};
|
||||||
|
|
||||||
// ARM only supports BF16 with ARMv8.6-A extension
|
|
||||||
#if (defined(__aarch64__) && !defined(ARM_BF16_SUPPORT))
|
|
||||||
#else
|
|
||||||
template <>
|
template <>
|
||||||
struct VecTypeTrait<c10::BFloat16> {
|
struct VecTypeTrait<c10::BFloat16> {
|
||||||
using vec_t = vec_op::BF16Vec16;
|
using vec_t = vec_op::BF16Vec16;
|
||||||
};
|
};
|
||||||
#endif
|
|
||||||
|
|
||||||
#if !defined(__powerpc__) && !defined(__s390x__)
|
#if !defined(__powerpc__) && !defined(__s390x__)
|
||||||
template <>
|
template <>
|
||||||
@@ -1585,17 +1581,10 @@ class AttentionMainLoop {
|
|||||||
|
|
||||||
if (use_sink) {
|
if (use_sink) {
|
||||||
alignas(64) float s_aux_fp32[16];
|
alignas(64) float s_aux_fp32[16];
|
||||||
#if defined(__aarch64__) && !defined(ARM_BF16_SUPPORT)
|
|
||||||
// ARM without native BF16 support: manual conversion
|
|
||||||
for (int i = 0; i < 16; ++i) {
|
|
||||||
s_aux_fp32[i] = static_cast<float>(curr_s_aux[i]);
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
// All other platforms have BF16Vec16 available
|
// All other platforms have BF16Vec16 available
|
||||||
vec_op::BF16Vec16 vec_bf16(curr_s_aux);
|
vec_op::BF16Vec16 vec_bf16(curr_s_aux);
|
||||||
vec_op::FP32Vec16 vec_fp32(vec_bf16);
|
vec_op::FP32Vec16 vec_fp32(vec_bf16);
|
||||||
vec_fp32.save(s_aux_fp32);
|
vec_fp32.save(s_aux_fp32);
|
||||||
#endif
|
|
||||||
|
|
||||||
float* __restrict__ curr_sum_buffer = sum_buffer;
|
float* __restrict__ curr_sum_buffer = sum_buffer;
|
||||||
float* __restrict__ curr_max_buffer = max_buffer;
|
float* __restrict__ curr_max_buffer = max_buffer;
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -14,13 +14,11 @@ struct KernelVecType<float> {
|
|||||||
using cvt_vec_type = vec_op::FP32Vec16;
|
using cvt_vec_type = vec_op::FP32Vec16;
|
||||||
};
|
};
|
||||||
|
|
||||||
#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT)
|
|
||||||
template <>
|
template <>
|
||||||
struct KernelVecType<c10::BFloat16> {
|
struct KernelVecType<c10::BFloat16> {
|
||||||
using load_vec_type = vec_op::BF16Vec16;
|
using load_vec_type = vec_op::BF16Vec16;
|
||||||
using cvt_vec_type = vec_op::FP32Vec16;
|
using cvt_vec_type = vec_op::FP32Vec16;
|
||||||
};
|
};
|
||||||
#endif
|
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
struct KernelVecType<c10::Half> {
|
struct KernelVecType<c10::Half> {
|
||||||
|
|||||||
@@ -38,9 +38,7 @@ struct KernelVecType<c10::BFloat16> {
|
|||||||
using qk_vec_type = vec_op::BF16Vec32;
|
using qk_vec_type = vec_op::BF16Vec32;
|
||||||
using v_load_vec_type = vec_op::BF16Vec16;
|
using v_load_vec_type = vec_op::BF16Vec16;
|
||||||
};
|
};
|
||||||
#elif defined(__aarch64__) && !defined(ARM_BF16_SUPPORT)
|
#elif defined(__aarch64__)
|
||||||
// pass
|
|
||||||
#else
|
|
||||||
template <>
|
template <>
|
||||||
struct KernelVecType<c10::BFloat16> {
|
struct KernelVecType<c10::BFloat16> {
|
||||||
using qk_load_vec_type = vec_op::BF16Vec16;
|
using qk_load_vec_type = vec_op::BF16Vec16;
|
||||||
|
|||||||
@@ -30,12 +30,10 @@ struct VecTypeTrait<float> {
|
|||||||
using vec_t = vec_op::FP32Vec16;
|
using vec_t = vec_op::FP32Vec16;
|
||||||
};
|
};
|
||||||
|
|
||||||
#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT)
|
|
||||||
template <>
|
template <>
|
||||||
struct VecTypeTrait<c10::BFloat16> {
|
struct VecTypeTrait<c10::BFloat16> {
|
||||||
using vec_t = vec_op::BF16Vec16;
|
using vec_t = vec_op::BF16Vec16;
|
||||||
};
|
};
|
||||||
#endif
|
|
||||||
|
|
||||||
#if !defined(__powerpc__)
|
#if !defined(__powerpc__)
|
||||||
template <>
|
template <>
|
||||||
|
|||||||
Reference in New Issue
Block a user