[Feature][CPU Backend]: Optimize ARM vectorization backend (#30329)
Signed-off-by: Radu Salavat <radu.salavat@arm.com>
This commit is contained in:
@@ -816,14 +816,10 @@ struct VecTypeTrait<float> {
|
||||
using vec_t = vec_op::FP32Vec16;
|
||||
};
|
||||
|
||||
// ARM only supports BF16 with ARMv8.6-A extension
|
||||
#if (defined(__aarch64__) && !defined(ARM_BF16_SUPPORT))
|
||||
#else
|
||||
template <>
|
||||
struct VecTypeTrait<c10::BFloat16> {
|
||||
using vec_t = vec_op::BF16Vec16;
|
||||
};
|
||||
#endif
|
||||
|
||||
#if !defined(__powerpc__) && !defined(__s390x__)
|
||||
template <>
|
||||
@@ -1585,17 +1581,10 @@ class AttentionMainLoop {
|
||||
|
||||
if (use_sink) {
|
||||
alignas(64) float s_aux_fp32[16];
|
||||
#if defined(__aarch64__) && !defined(ARM_BF16_SUPPORT)
|
||||
// ARM without native BF16 support: manual conversion
|
||||
for (int i = 0; i < 16; ++i) {
|
||||
s_aux_fp32[i] = static_cast<float>(curr_s_aux[i]);
|
||||
}
|
||||
#else
|
||||
// All other platforms have BF16Vec16 available
|
||||
vec_op::BF16Vec16 vec_bf16(curr_s_aux);
|
||||
vec_op::FP32Vec16 vec_fp32(vec_bf16);
|
||||
vec_fp32.save(s_aux_fp32);
|
||||
#endif
|
||||
|
||||
float* __restrict__ curr_sum_buffer = sum_buffer;
|
||||
float* __restrict__ curr_max_buffer = max_buffer;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -14,13 +14,11 @@ struct KernelVecType<float> {
|
||||
using cvt_vec_type = vec_op::FP32Vec16;
|
||||
};
|
||||
|
||||
#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT)
|
||||
template <>
|
||||
struct KernelVecType<c10::BFloat16> {
|
||||
using load_vec_type = vec_op::BF16Vec16;
|
||||
using cvt_vec_type = vec_op::FP32Vec16;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <>
|
||||
struct KernelVecType<c10::Half> {
|
||||
|
||||
@@ -38,9 +38,7 @@ struct KernelVecType<c10::BFloat16> {
|
||||
using qk_vec_type = vec_op::BF16Vec32;
|
||||
using v_load_vec_type = vec_op::BF16Vec16;
|
||||
};
|
||||
#elif defined(__aarch64__) && !defined(ARM_BF16_SUPPORT)
|
||||
// pass
|
||||
#else
|
||||
#elif defined(__aarch64__)
|
||||
template <>
|
||||
struct KernelVecType<c10::BFloat16> {
|
||||
using qk_load_vec_type = vec_op::BF16Vec16;
|
||||
|
||||
@@ -30,12 +30,10 @@ struct VecTypeTrait<float> {
|
||||
using vec_t = vec_op::FP32Vec16;
|
||||
};
|
||||
|
||||
#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT)
|
||||
template <>
|
||||
struct VecTypeTrait<c10::BFloat16> {
|
||||
using vec_t = vec_op::BF16Vec16;
|
||||
};
|
||||
#endif
|
||||
|
||||
#if !defined(__powerpc__)
|
||||
template <>
|
||||
|
||||
Reference in New Issue
Block a user