2026-02-04 09:07:15 +05:30
|
|
|
#include "cpu_attn_dispatch_generated.h"
|
2025-11-12 09:43:06 +08:00
|
|
|
|
|
|
|
|
// Builds the scheduling metadata tensor consumed by cpu_attention_with_kv_cache.
//
// Parameters:
//   num_req, num_heads_q, num_heads_kv, head_dim: problem dimensions forwarded
//     verbatim to the scheduler.
//   seq_lens, query_start_loc: expected to be int32 tensors; only their raw
//     data pointers are handed to the scheduler.
//   dtype: floating-point element type; used only to size the per-kernel
//     scratch buffers (elem_size and the attn_impl buffer types).
//   casual: causal-attention flag (identifier spelling kept for API
//     compatibility); when set, the right window extent is forced to 0.
//   window_size: sliding window length, or -1 for no sliding window.
//   isa_hint: one of "amx", "vec", "vec16", "neon", "vxe".
//   enable_kv_split: forwarded to the scheduler unchanged.
// Returns: the tensor produced by AttentionScheduler::schedule.
// Throws (via TORCH_CHECK) when isa_hint is not recognized.
torch::Tensor get_scheduler_metadata(
    const int64_t num_req, const int64_t num_heads_q,
    const int64_t num_heads_kv, const int64_t head_dim,
    const torch::Tensor& seq_lens, at::ScalarType dtype,
    const torch::Tensor& query_start_loc, const bool casual,
    const int64_t window_size, const std::string& isa_hint,
    const bool enable_kv_split) {
  // Translate the textual hint into the kernel-family enum.
  cpu_attention::ISA isa = [&]() {
    if (isa_hint == "amx") {
      return cpu_attention::ISA::AMX;
    }
    if (isa_hint == "vec") {
      return cpu_attention::ISA::VEC;
    }
    if (isa_hint == "vec16") {
      return cpu_attention::ISA::VEC16;
    }
    if (isa_hint == "neon") {
      return cpu_attention::ISA::NEON;
    }
    if (isa_hint == "vxe") {
      return cpu_attention::ISA::VXE;
    }
    TORCH_CHECK(false, "Unsupported CPU attention ISA hint: " + isa_hint);
  }();

  cpu_attention::AttentionScheduler::ScheduleInput input;
  input.num_reqs = num_req;
  input.num_heads_q = num_heads_q;
  input.num_heads_kv = num_heads_kv;
  input.head_dim = head_dim;
  input.query_start_loc = query_start_loc.data_ptr<int32_t>();
  input.seq_lens = seq_lens.data_ptr<int32_t>();

  // window_size == -1 means "no sliding window", which is encoded as a -1
  // extent. Otherwise a window of N covers N - 1 neighbours. In causal mode
  // the right extent is always 0; otherwise it mirrors the left extent.
  const int64_t window_extent = (window_size == -1) ? -1 : window_size - 1;
  input.left_sliding_window_size = window_extent;
  input.right_sliding_window_size = casual ? 0 : window_extent;

  input.casual = casual;
  input.isa = isa;
  input.enable_kv_split = enable_kv_split;

  // Buffer element sizes and tiling constants depend on both the element type
  // and the (head_dim, ISA) kernel specialization.
  VLLM_DISPATCH_FLOATING_TYPES(dtype, "get_scheduler_metadata", [&]() {
    CPU_ATTN_DISPATCH(head_dim, isa, [&]() {
      input.elem_size = sizeof(scalar_t);
      input.q_buffer_elem_size = sizeof(attn_impl::q_buffer_t);
      input.logits_buffer_elem_size = sizeof(attn_impl::logits_buffer_t);
      input.output_buffer_elem_size =
          sizeof(attn_impl::partial_output_buffer_t);
      input.max_num_q_per_iter = attn_impl::MaxQHeadNumPerIteration;
      input.kv_block_alignment = attn_impl::BlockSizeAlignment;
    });
  });

  cpu_attention::AttentionScheduler scheduler;
  return scheduler.schedule(input);
}
|
|
|
|
|
|
|
|
|
|
void cpu_attn_reshape_and_cache(
|
|
|
|
|
const torch::Tensor& key, // [token_num, head_num, head_size]
|
|
|
|
|
const torch::Tensor& value, // [token_num, head_num, head_size]
|
|
|
|
|
torch::Tensor&
|
|
|
|
|
key_cache, // [num_blocks, num_kv_heads, block_size, head_size]
|
|
|
|
|
torch::Tensor&
|
|
|
|
|
value_cache, // [num_blocks, num_kv_heads, block_size, head_size]
|
|
|
|
|
const torch::Tensor& slot_mapping, const std::string& isa) {
|
|
|
|
|
TORCH_CHECK_EQ(key.dim(), 3);
|
|
|
|
|
TORCH_CHECK_EQ(value.dim(), 3);
|
|
|
|
|
TORCH_CHECK_EQ(key_cache.dim(), 4);
|
|
|
|
|
TORCH_CHECK_EQ(value_cache.dim(), 4);
|
|
|
|
|
TORCH_CHECK_EQ(key.stride(2), 1);
|
|
|
|
|
TORCH_CHECK_EQ(value.stride(2), 1);
|
|
|
|
|
|
|
|
|
|
const int64_t token_num = key.size(0);
|
|
|
|
|
const int64_t key_token_num_stride = key.stride(0);
|
|
|
|
|
const int64_t value_token_num_stride = value.stride(0);
|
|
|
|
|
const int64_t head_num = value.size(1);
|
|
|
|
|
const int64_t key_head_num_stride = key.stride(1);
|
|
|
|
|
const int64_t value_head_num_stride = value.stride(1);
|
|
|
|
|
const int64_t num_blocks = key_cache.size(0);
|
|
|
|
|
const int64_t num_blocks_stride = key_cache.stride(0);
|
|
|
|
|
const int64_t cache_head_num_stride = key_cache.stride(1);
|
|
|
|
|
const int64_t block_size = key_cache.size(2);
|
|
|
|
|
const int64_t block_size_stride = key_cache.stride(2);
|
|
|
|
|
const int64_t head_dim = key.size(-1);
|
|
|
|
|
|
|
|
|
|
cpu_attention::ISA isa_tag = [&]() {
|
|
|
|
|
if (isa == "amx") {
|
|
|
|
|
return cpu_attention::ISA::AMX;
|
|
|
|
|
} else if (isa == "vec") {
|
|
|
|
|
return cpu_attention::ISA::VEC;
|
|
|
|
|
} else if (isa == "vec16") {
|
|
|
|
|
return cpu_attention::ISA::VEC16;
|
2025-11-22 17:04:36 +00:00
|
|
|
} else if (isa == "neon") {
|
|
|
|
|
return cpu_attention::ISA::NEON;
|
2026-02-24 20:55:39 +05:30
|
|
|
} else if (isa == "vxe") {
|
|
|
|
|
return cpu_attention::ISA::VXE;
|
2025-11-12 09:43:06 +08:00
|
|
|
} else {
|
|
|
|
|
TORCH_CHECK(false, "Invalid ISA type: " + isa);
|
|
|
|
|
}
|
|
|
|
|
}();
|
|
|
|
|
|
|
|
|
|
VLLM_DISPATCH_FLOATING_TYPES(
|
|
|
|
|
key.scalar_type(), "cpu_attn_reshape_and_cache", [&]() {
|
2026-02-04 09:07:15 +05:30
|
|
|
CPU_ATTN_DISPATCH(head_dim, isa_tag, [&]() {
|
|
|
|
|
attn_impl::reshape_and_cache(
|
|
|
|
|
key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
|
|
|
|
|
key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
|
|
|
|
|
slot_mapping.data_ptr<int64_t>(), token_num, key_token_num_stride,
|
|
|
|
|
value_token_num_stride, head_num, key_head_num_stride,
|
|
|
|
|
value_head_num_stride, num_blocks, num_blocks_stride,
|
|
|
|
|
cache_head_num_stride, block_size, block_size_stride);
|
2025-11-12 09:43:06 +08:00
|
|
|
});
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void cpu_attention_with_kv_cache(
|
|
|
|
|
const torch::Tensor& query, // [num_tokens, num_heads, head_size]
|
|
|
|
|
const torch::Tensor&
|
|
|
|
|
key_cache, // [num_blocks, num_kv_heads, block_size, head_size]
|
|
|
|
|
const torch::Tensor&
|
|
|
|
|
value_cache, // [num_blocks, num_kv_heads, block_size, head_size]
|
|
|
|
|
torch::Tensor& output, // [num_tokens, num_heads, head_size]
|
|
|
|
|
const torch::Tensor& query_start_loc, // [num_tokens + 1]
|
|
|
|
|
const torch::Tensor& seq_lens, // [num_tokens]
|
|
|
|
|
const double scale, const bool causal,
|
|
|
|
|
const std::optional<torch::Tensor>& alibi_slopes, // [num_heads]
|
|
|
|
|
const int64_t sliding_window_left, const int64_t sliding_window_right,
|
|
|
|
|
const torch::Tensor& block_table, // [num_tokens, max_block_num]
|
|
|
|
|
const double softcap, const torch::Tensor& scheduler_metadata,
|
|
|
|
|
const std::optional<torch::Tensor>& s_aux // [num_heads]
|
|
|
|
|
) {
|
|
|
|
|
TORCH_CHECK_EQ(query.dim(), 3);
|
|
|
|
|
TORCH_CHECK_EQ(query.stride(2), 1);
|
|
|
|
|
TORCH_CHECK_EQ(key_cache.dim(), 4);
|
|
|
|
|
TORCH_CHECK_EQ(value_cache.dim(), 4);
|
|
|
|
|
|
|
|
|
|
cpu_attention::AttentionInput input;
|
|
|
|
|
input.metadata = reinterpret_cast<cpu_attention::AttentionMetadata*>(
|
|
|
|
|
scheduler_metadata.data_ptr());
|
|
|
|
|
input.num_tokens = query.size(0);
|
|
|
|
|
input.num_heads = query.size(1);
|
|
|
|
|
input.num_kv_heads = key_cache.size(1);
|
|
|
|
|
input.block_size = key_cache.size(2);
|
|
|
|
|
input.query = query.data_ptr();
|
|
|
|
|
input.query_num_tokens_stride = query.stride(0);
|
|
|
|
|
input.query_num_heads_stride = query.stride(1);
|
|
|
|
|
input.cache_num_blocks_stride = key_cache.stride(0);
|
|
|
|
|
input.cache_num_kv_heads_stride = key_cache.stride(1);
|
|
|
|
|
input.blt_num_tokens_stride = block_table.stride(0);
|
|
|
|
|
input.key_cache = key_cache.data_ptr();
|
|
|
|
|
input.value_cache = value_cache.data_ptr();
|
|
|
|
|
input.output = output.data_ptr();
|
|
|
|
|
input.query_start_loc = query_start_loc.data_ptr<int32_t>();
|
|
|
|
|
input.seq_lens = seq_lens.data_ptr<int32_t>();
|
|
|
|
|
input.block_table = block_table.data_ptr<int32_t>();
|
|
|
|
|
input.alibi_slopes =
|
|
|
|
|
alibi_slopes.has_value() ? alibi_slopes->data_ptr<float>() : nullptr;
|
|
|
|
|
// For now sink must be bf16
|
|
|
|
|
input.s_aux = s_aux.has_value() ? s_aux->data_ptr<c10::BFloat16>() : nullptr;
|
|
|
|
|
input.scale = scale;
|
|
|
|
|
input.causal = causal;
|
|
|
|
|
input.sliding_window_left = sliding_window_left;
|
|
|
|
|
input.sliding_window_right = sliding_window_right;
|
|
|
|
|
if (input.causal) {
|
|
|
|
|
// to make boundary calculation easier
|
|
|
|
|
input.sliding_window_right = 0;
|
|
|
|
|
}
|
|
|
|
|
float softcap_fp32 = softcap;
|
|
|
|
|
input.softcap = softcap_fp32;
|
|
|
|
|
|
|
|
|
|
VLLM_DISPATCH_FLOATING_TYPES(
|
|
|
|
|
query.scalar_type(), "cpu_attention_with_kv_cache", [&]() {
|
2026-02-04 09:07:15 +05:30
|
|
|
CPU_ATTN_DISPATCH(query.size(2), input.metadata->isa, [&]() {
|
|
|
|
|
TORCH_CHECK_EQ(input.block_size % attn_impl::BlockSizeAlignment, 0);
|
|
|
|
|
cpu_attention::AttentionMainLoop<attn_impl> mainloop;
|
|
|
|
|
mainloop(&input);
|
2025-11-12 09:43:06 +08:00
|
|
|
});
|
|
|
|
|
});
|
|
|
|
|
}
|