[perf][cpu] Accelerate paged attention GEMMs (QK, PV) on Arm CPUs with NEON (#29193)
Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com>
This commit is contained in:
@@ -13,6 +13,18 @@
|
||||
#define AMX_DISPATCH(...) case cpu_attention::ISA::AMX:
|
||||
#endif
|
||||
|
||||
#ifdef __aarch64__
|
||||
#include "cpu_attn_neon.hpp"
|
||||
#define NEON_DISPATCH(...) \
|
||||
case cpu_attention::ISA::NEON: { \
|
||||
using attn_impl = cpu_attention::AttentionImpl<cpu_attention::ISA::NEON, \
|
||||
scalar_t, head_dim>; \
|
||||
return __VA_ARGS__(); \
|
||||
}
|
||||
#else
|
||||
#define NEON_DISPATCH(...) case cpu_attention::ISA::NEON:
|
||||
#endif // #ifdef __aarch64__
|
||||
|
||||
#define CPU_ATTN_DISPATCH_CASE(HEAD_DIM, ...) \
|
||||
case HEAD_DIM: { \
|
||||
constexpr size_t head_dim = HEAD_DIM; \
|
||||
@@ -41,6 +53,7 @@
|
||||
[&] { \
|
||||
switch (ISA_TYPE) { \
|
||||
AMX_DISPATCH(__VA_ARGS__) \
|
||||
NEON_DISPATCH(__VA_ARGS__) \
|
||||
case cpu_attention::ISA::VEC: { \
|
||||
using attn_impl = \
|
||||
cpu_attention::AttentionImpl<cpu_attention::ISA::VEC, scalar_t, \
|
||||
@@ -73,6 +86,8 @@ torch::Tensor get_scheduler_metadata(
|
||||
isa = cpu_attention::ISA::VEC;
|
||||
} else if (isa_hint == "vec16") {
|
||||
isa = cpu_attention::ISA::VEC16;
|
||||
} else if (isa_hint == "neon") {
|
||||
isa = cpu_attention::ISA::NEON;
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported CPU attention ISA hint: " + isa_hint);
|
||||
}
|
||||
@@ -158,6 +173,8 @@ void cpu_attn_reshape_and_cache(
|
||||
return cpu_attention::ISA::VEC;
|
||||
} else if (isa == "vec16") {
|
||||
return cpu_attention::ISA::VEC16;
|
||||
} else if (isa == "neon") {
|
||||
return cpu_attention::ISA::NEON;
|
||||
} else {
|
||||
TORCH_CHECK(false, "Invalid ISA type: " + isa);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user