[perf][cpu] Accelerate paged attention GEMMs (QK, PV) on Arm CPUs with NEON (#29193)

Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com>
This commit is contained in:
Fadi Arafeh
2025-11-22 17:04:36 +00:00
committed by GitHub
parent f55c76c2b3
commit 730bd35378
5 changed files with 416 additions and 5 deletions

View File

@@ -13,6 +13,18 @@
#define AMX_DISPATCH(...) case cpu_attention::ISA::AMX:
#endif
#ifdef __aarch64__
#include "cpu_attn_neon.hpp"
#define NEON_DISPATCH(...) \
case cpu_attention::ISA::NEON: { \
using attn_impl = cpu_attention::AttentionImpl<cpu_attention::ISA::NEON, \
scalar_t, head_dim>; \
return __VA_ARGS__(); \
}
#else
#define NEON_DISPATCH(...) case cpu_attention::ISA::NEON:
#endif // #ifdef __aarch64__
#define CPU_ATTN_DISPATCH_CASE(HEAD_DIM, ...) \
case HEAD_DIM: { \
constexpr size_t head_dim = HEAD_DIM; \
@@ -41,6 +53,7 @@
[&] { \
switch (ISA_TYPE) { \
AMX_DISPATCH(__VA_ARGS__) \
NEON_DISPATCH(__VA_ARGS__) \
case cpu_attention::ISA::VEC: { \
using attn_impl = \
cpu_attention::AttentionImpl<cpu_attention::ISA::VEC, scalar_t, \
@@ -73,6 +86,8 @@ torch::Tensor get_scheduler_metadata(
isa = cpu_attention::ISA::VEC;
} else if (isa_hint == "vec16") {
isa = cpu_attention::ISA::VEC16;
} else if (isa_hint == "neon") {
isa = cpu_attention::ISA::NEON;
} else {
TORCH_CHECK(false, "Unsupported CPU attention ISA hint: " + isa_hint);
}
@@ -158,6 +173,8 @@ void cpu_attn_reshape_and_cache(
return cpu_attention::ISA::VEC;
} else if (isa == "vec16") {
return cpu_attention::ISA::VEC16;
} else if (isa == "neon") {
return cpu_attention::ISA::NEON;
} else {
TORCH_CHECK(false, "Invalid ISA type: " + isa);
}