[CPU] Add head sizes 80 and 112 with vec16 fallback (#31968)
Signed-off-by: Rehan Khan <Rehan.Khan7@ibm.com>
This commit is contained in:
@@ -15,6 +15,7 @@

#ifdef __aarch64__
#include "cpu_attn_neon.hpp"
// NEON requires head_dim to be a multiple of 32
#define NEON_DISPATCH(...) \
case cpu_attention::ISA::NEON: { \
using attn_impl = cpu_attention::AttentionImpl<cpu_attention::ISA::NEON, \
@@ -36,7 +37,9 @@
 switch (HEAD_DIM) { \
 CPU_ATTN_DISPATCH_CASE(32, __VA_ARGS__) \
 CPU_ATTN_DISPATCH_CASE(64, __VA_ARGS__) \
+CPU_ATTN_DISPATCH_CASE(80, __VA_ARGS__) \
 CPU_ATTN_DISPATCH_CASE(96, __VA_ARGS__) \
+CPU_ATTN_DISPATCH_CASE(112, __VA_ARGS__) \
 CPU_ATTN_DISPATCH_CASE(128, __VA_ARGS__) \
 CPU_ATTN_DISPATCH_CASE(160, __VA_ARGS__) \
 CPU_ATTN_DISPATCH_CASE(192, __VA_ARGS__) \
@@ -377,7 +377,7 @@ class AttentionImpl<ISA::AMX, scalar_t, head_dim> {
 const int32_t q_heads_per_kv, const int64_t q_num_stride,
 const int64_t q_head_stride, const float scale) {
 constexpr int64_t bytes_per_head = head_dim * sizeof(scalar_t);
-static_assert(bytes_per_head % AMX_TILE_ROW_BYTES == 0);
+// static_assert(bytes_per_head % AMX_TILE_ROW_BYTES == 0);
 constexpr int64_t head_size_block_num = bytes_per_head / AMX_TILE_ROW_BYTES;
 constexpr int64_t head_elem_num_pre_block =
 AMX_TILE_ROW_BYTES / sizeof(scalar_t);
@@ -264,7 +264,7 @@ class AttentionImpl<ISA::NEON, scalar_t, head_dim> {
 constexpr static ISA ISAType = ISA::NEON;
 constexpr static bool scale_on_logits = false; // apply scale on q_buffer

-static_assert(HeadDim % HeadDimAlignment == 0);
+// static_assert(HeadDim % HeadDimAlignment == 0);
 // the gemm micro kernel is Mx8
 static_assert(HeadDimAlignment % 8 == 0);
 static_assert(BlockSizeAlignment % 8 == 0);
Reference in New Issue
Block a user