[Feature][CPU Backend]: Optimize ARM vectorization backend (#30329)

Signed-off-by: Radu Salavat <radu.salavat@arm.com>
This commit is contained in:
Radu Salavat
2026-02-03 04:17:56 +00:00
committed by GitHub
parent 5eac9a1b34
commit e69c990c21
5 changed files with 579 additions and 624 deletions

View File

@@ -816,14 +816,10 @@ struct VecTypeTrait<float> {
using vec_t = vec_op::FP32Vec16;
};
// ARM only supports BF16 with ARMv8.6-A extension
#if (defined(__aarch64__) && !defined(ARM_BF16_SUPPORT))
#else
template <>
struct VecTypeTrait<c10::BFloat16> {
using vec_t = vec_op::BF16Vec16;
};
#endif
#if !defined(__powerpc__) && !defined(__s390x__)
template <>
@@ -1585,17 +1581,10 @@ class AttentionMainLoop {
if (use_sink) {
alignas(64) float s_aux_fp32[16];
#if defined(__aarch64__) && !defined(ARM_BF16_SUPPORT)
// ARM without native BF16 support: manual conversion
for (int i = 0; i < 16; ++i) {
s_aux_fp32[i] = static_cast<float>(curr_s_aux[i]);
}
#else
// All other platforms have BF16Vec16 available
vec_op::BF16Vec16 vec_bf16(curr_s_aux);
vec_op::FP32Vec16 vec_fp32(vec_bf16);
vec_fp32.save(s_aux_fp32);
#endif
float* __restrict__ curr_sum_buffer = sum_buffer;
float* __restrict__ curr_max_buffer = max_buffer;

File diff suppressed because it is too large Load Diff

View File

@@ -14,13 +14,11 @@ struct KernelVecType<float> {
using cvt_vec_type = vec_op::FP32Vec16;
};
#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT)
template <>
struct KernelVecType<c10::BFloat16> {
using load_vec_type = vec_op::BF16Vec16;
using cvt_vec_type = vec_op::FP32Vec16;
};
#endif
template <>
struct KernelVecType<c10::Half> {

View File

@@ -38,9 +38,7 @@ struct KernelVecType<c10::BFloat16> {
using qk_vec_type = vec_op::BF16Vec32;
using v_load_vec_type = vec_op::BF16Vec16;
};
#elif defined(__aarch64__) && !defined(ARM_BF16_SUPPORT)
#elif defined(__aarch64__)
// pass
#else
template <>
struct KernelVecType<c10::BFloat16> {
using qk_load_vec_type = vec_op::BF16Vec16;

View File

@@ -30,12 +30,10 @@ struct VecTypeTrait<float> {
using vec_t = vec_op::FP32Vec16;
};
#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT)
template <>
struct VecTypeTrait<c10::BFloat16> {
using vec_t = vec_op::BF16Vec16;
};
#endif
#if !defined(__powerpc__)
template <>