[CPU] Split attention dispatch by head_dim alignment (#32161)

Signed-off-by: Rehan Khan <Rehan.Khan7@ibm.com>
This commit is contained in:
R3hankhan
2026-02-04 09:07:15 +05:30
committed by GitHub
parent e1bf04b6c2
commit 4dffc5e044
6 changed files with 241 additions and 107 deletions

View File

@@ -1,79 +1,4 @@
#include "cpu_attn_vec.hpp"
#include "cpu_attn_vec16.hpp"
#ifdef CPU_CAPABILITY_AMXBF16
#include "cpu_attn_amx.hpp"
#define AMX_DISPATCH(...) \
case cpu_attention::ISA::AMX: { \
using attn_impl = cpu_attention::AttentionImpl<cpu_attention::ISA::AMX, \
scalar_t, head_dim>; \
return __VA_ARGS__(); \
}
#else
#define AMX_DISPATCH(...) case cpu_attention::ISA::AMX:
#endif
#ifdef __aarch64__
#include "cpu_attn_neon.hpp"
// NEON requires head_dim to be a multiple of 32
#define NEON_DISPATCH(...) \
case cpu_attention::ISA::NEON: { \
using attn_impl = cpu_attention::AttentionImpl<cpu_attention::ISA::NEON, \
scalar_t, head_dim>; \
return __VA_ARGS__(); \
}
#else
#define NEON_DISPATCH(...) case cpu_attention::ISA::NEON:
#endif // #ifdef __aarch64__
#define CPU_ATTN_DISPATCH_CASE(HEAD_DIM, ...) \
case HEAD_DIM: { \
constexpr size_t head_dim = HEAD_DIM; \
return __VA_ARGS__(); \
}
#define CPU_ATTN_DISPATCH_CASE_HEADDIM(HEAD_DIM, ...) \
[&] { \
switch (HEAD_DIM) { \
CPU_ATTN_DISPATCH_CASE(32, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(64, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(80, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(96, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(112, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(128, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(160, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(192, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(224, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(256, __VA_ARGS__) \
default: { \
TORCH_CHECK(false, "Invalid CPU attention head_dim: " + \
std::to_string(HEAD_DIM)); \
} \
} \
}()
#define CPU_ATTN_DISPATCH_IMPL(ISA_TYPE, ...) \
[&] { \
switch (ISA_TYPE) { \
AMX_DISPATCH(__VA_ARGS__) \
NEON_DISPATCH(__VA_ARGS__) \
case cpu_attention::ISA::VEC: { \
using attn_impl = \
cpu_attention::AttentionImpl<cpu_attention::ISA::VEC, scalar_t, \
head_dim>; \
return __VA_ARGS__(); \
} \
case cpu_attention::ISA::VEC16: { \
using attn_impl = \
cpu_attention::AttentionImpl<cpu_attention::ISA::VEC16, scalar_t, \
head_dim>; \
return __VA_ARGS__(); \
} \
default: { \
TORCH_CHECK(false, "Invalid CPU attention ISA type."); \
} \
} \
}()
#include "cpu_attn_dispatch_generated.h"
torch::Tensor get_scheduler_metadata(
const int64_t num_req, const int64_t num_heads_q,
@@ -122,16 +47,14 @@ torch::Tensor get_scheduler_metadata(
input.enable_kv_split = enable_kv_split;
VLLM_DISPATCH_FLOATING_TYPES(dtype, "get_scheduler_metadata", [&]() {
CPU_ATTN_DISPATCH_CASE_HEADDIM(head_dim, [&] {
CPU_ATTN_DISPATCH_IMPL(isa, [&]() {
input.elem_size = sizeof(scalar_t);
input.q_buffer_elem_size = sizeof(attn_impl::q_buffer_t);
input.logits_buffer_elem_size = sizeof(attn_impl::logits_buffer_t);
input.output_buffer_elem_size =
sizeof(attn_impl::partial_output_buffer_t);
input.max_num_q_per_iter = attn_impl::MaxQHeadNumPerIteration;
input.kv_block_alignment = attn_impl::BlockSizeAlignment;
});
CPU_ATTN_DISPATCH(head_dim, isa, [&]() {
input.elem_size = sizeof(scalar_t);
input.q_buffer_elem_size = sizeof(attn_impl::q_buffer_t);
input.logits_buffer_elem_size = sizeof(attn_impl::logits_buffer_t);
input.output_buffer_elem_size =
sizeof(attn_impl::partial_output_buffer_t);
input.max_num_q_per_iter = attn_impl::MaxQHeadNumPerIteration;
input.kv_block_alignment = attn_impl::BlockSizeAlignment;
});
});
@@ -184,18 +107,14 @@ void cpu_attn_reshape_and_cache(
VLLM_DISPATCH_FLOATING_TYPES(
key.scalar_type(), "cpu_attn_reshape_and_cache", [&]() {
CPU_ATTN_DISPATCH_CASE_HEADDIM(head_dim, [&] {
CPU_ATTN_DISPATCH_IMPL(isa_tag, [&]() {
attn_impl::reshape_and_cache(
key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int64_t>(), token_num,
key_token_num_stride, value_token_num_stride, head_num,
key_head_num_stride, value_head_num_stride, num_blocks,
num_blocks_stride, cache_head_num_stride, block_size,
block_size_stride);
});
CPU_ATTN_DISPATCH(head_dim, isa_tag, [&]() {
attn_impl::reshape_and_cache(
key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int64_t>(), token_num, key_token_num_stride,
value_token_num_stride, head_num, key_head_num_stride,
value_head_num_stride, num_blocks, num_blocks_stride,
cache_head_num_stride, block_size, block_size_stride);
});
});
}
@@ -257,12 +176,10 @@ void cpu_attention_with_kv_cache(
VLLM_DISPATCH_FLOATING_TYPES(
query.scalar_type(), "cpu_attention_with_kv_cache", [&]() {
CPU_ATTN_DISPATCH_CASE_HEADDIM(query.size(2), [&] {
CPU_ATTN_DISPATCH_IMPL(input.metadata->isa, [&]() {
TORCH_CHECK_EQ(input.block_size % attn_impl::BlockSizeAlignment, 0);
cpu_attention::AttentionMainLoop<attn_impl> mainloop;
mainloop(&input);
});
CPU_ATTN_DISPATCH(query.size(2), input.metadata->isa, [&]() {
TORCH_CHECK_EQ(input.block_size % attn_impl::BlockSizeAlignment, 0);
cpu_attention::AttentionMainLoop<attn_impl> mainloop;
mainloop(&input);
});
});
}