[CPU] Refactor CPU attention backend (#27954)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
This commit is contained in:
@@ -1,798 +0,0 @@
|
||||
#include "cpu_types.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
// Maps a scalar element type to the SIMD vector types used by the attention
// kernels (query/key load + compute vectors, QK accumulator, value load).
// The primary template is intentionally unusable (`void` members); only the
// explicit specializations below provide valid vector types.
template <typename scalar_t>
struct KernelVecType {
  using q_load_vec_type = void;
  using q_vec_type = void;
  using k_load_vec_type = void;
  using k_vec_type = void;
  using qk_acc_vec_type = void;
  using v_load_vec_type = void;
};
|
||||
|
||||
// FP32 path: everything is loaded and computed in fp32 directly.
template <>
struct KernelVecType<float> {
  using q_load_vec_type = vec_op::FP32Vec4;
  using q_vec_type = vec_op::FP32Vec16;
  using k_load_vec_type = vec_op::FP32Vec16;
  using k_vec_type = vec_op::FP32Vec16;
  using qk_acc_vec_type = vec_op::FP32Vec16;
  using v_load_vec_type = vec_op::FP32Vec16;
};
|
||||
|
||||
// FP16 path: loads differ per architecture, arithmetic is always fp32.
template <>
struct KernelVecType<c10::Half> {
#if defined(__powerpc64__) || defined(__s390x__)
  // Power and s390x architecture-specific vector types: these targets load
  // half-precision data already widened to fp32 vectors.
  using q_load_vec_type = vec_op::FP32Vec8;
  using k_load_vec_type = vec_op::FP32Vec16;
  using v_load_vec_type = vec_op::FP32Vec16;
#else
  // Fallback for other architectures, including x86: load native fp16 and
  // widen on conversion to the fp32 compute vectors below.
  using q_load_vec_type = vec_op::FP16Vec8;
  using k_load_vec_type = vec_op::FP16Vec16;
  using v_load_vec_type = vec_op::FP16Vec16;
#endif
  using q_vec_type = vec_op::FP32Vec16;
  using k_vec_type = vec_op::FP32Vec16;
  using qk_acc_vec_type = vec_op::FP32Vec16;
};
|
||||
|
||||
#ifdef __AVX512BF16__
|
||||
template <>
|
||||
struct KernelVecType<c10::BFloat16> {
|
||||
using q_load_vec_type = vec_op::BF16Vec8;
|
||||
using q_vec_type = vec_op::BF16Vec32;
|
||||
using k_load_vec_type = vec_op::BF16Vec32;
|
||||
using k_vec_type = vec_op::BF16Vec32;
|
||||
using qk_acc_vec_type = vec_op::FP32Vec16;
|
||||
using v_load_vec_type = vec_op::BF16Vec16;
|
||||
};
|
||||
#else
|
||||
#ifdef __aarch64__
|
||||
#ifndef ARM_BF16_SUPPORT
|
||||
// pass
|
||||
#else
|
||||
template <>
|
||||
struct KernelVecType<c10::BFloat16> {
|
||||
using q_load_vec_type = vec_op::BF16Vec8;
|
||||
using q_vec_type = vec_op::FP32Vec16;
|
||||
using k_load_vec_type = vec_op::BF16Vec16;
|
||||
using k_vec_type = vec_op::FP32Vec16;
|
||||
using qk_acc_vec_type = vec_op::FP32Vec16;
|
||||
using v_load_vec_type = vec_op::BF16Vec16;
|
||||
};
|
||||
#endif
|
||||
#else
|
||||
template <>
|
||||
struct KernelVecType<c10::BFloat16> {
|
||||
using q_load_vec_type = vec_op::BF16Vec8;
|
||||
using q_vec_type = vec_op::FP32Vec16;
|
||||
using k_load_vec_type = vec_op::BF16Vec16;
|
||||
using k_vec_type = vec_op::FP32Vec16;
|
||||
using qk_acc_vec_type = vec_op::FP32Vec16;
|
||||
using v_load_vec_type = vec_op::BF16Vec16;
|
||||
};
|
||||
#endif
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
FORCE_INLINE std::pair<T, T> reduceSoftmax(T* data, const int size,
|
||||
const int capacity) {
|
||||
T max = data[0];
|
||||
for (int i = 1; i < size; ++i) {
|
||||
max = max >= data[i] ? max : data[i];
|
||||
}
|
||||
|
||||
T sum = 0;
|
||||
for (int i = 0; i < size; ++i) {
|
||||
data[i] = std::exp(data[i] - max);
|
||||
sum += data[i];
|
||||
}
|
||||
|
||||
int i = 0;
|
||||
for (; i < size; ++i) {
|
||||
data[i] /= sum;
|
||||
}
|
||||
|
||||
for (; i < capacity; ++i) {
|
||||
data[i] = 0;
|
||||
}
|
||||
|
||||
return {max, sum};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
FORCE_INLINE std::pair<T, T> reduceSoftmaxAlibi(T* data, const int size,
|
||||
const int capacity,
|
||||
const float alibi_slope,
|
||||
const int start_index,
|
||||
const int seq_len) {
|
||||
data[0] += alibi_slope * (start_index - seq_len + 1);
|
||||
T max = data[0];
|
||||
for (int i = 1; i < size; ++i) {
|
||||
T qk = data[i] + alibi_slope * (start_index + i - seq_len + 1);
|
||||
data[i] = qk;
|
||||
max = max >= qk ? max : qk;
|
||||
}
|
||||
|
||||
T sum = 0;
|
||||
for (int i = 0; i < size; ++i) {
|
||||
data[i] = std::exp(data[i] - max);
|
||||
sum += data[i];
|
||||
}
|
||||
|
||||
int i = 0;
|
||||
for (; i < size; ++i) {
|
||||
data[i] /= sum;
|
||||
}
|
||||
|
||||
for (; i < capacity; ++i) {
|
||||
data[i] = 0;
|
||||
}
|
||||
|
||||
return {max, sum};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
FORCE_INLINE void reducePartitionSoftmax(const T* max_data, T* sum_data,
|
||||
const int size) {
|
||||
T max = max_data[0];
|
||||
for (int i = 1; i < size; ++i) {
|
||||
max = max >= max_data[i] ? max : max_data[i];
|
||||
}
|
||||
|
||||
T rescaled_sum = 0;
|
||||
for (int i = 0; i < size; ++i) {
|
||||
T rescale_factor = std::exp(max_data[i] - max);
|
||||
rescaled_sum += rescale_factor * sum_data[i];
|
||||
sum_data[i] *= rescale_factor;
|
||||
}
|
||||
for (int i = 0; i < size; ++i) {
|
||||
sum_data[i] /= rescaled_sum + 1e-8;
|
||||
}
|
||||
}
|
||||
|
||||
// Computes scaled Q*K^T logits for one key block.
// The key cache block is laid out as [HEAD_SIZE/x, BLOCK_SIZE, x], so x
// consecutive elements of x TOKEN_PER_GROUP tokens are loaded per k vector.
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int x>
struct reduceQKBlockKernel {
  using q_load_vec_type = typename KernelVecType<scalar_t>::q_load_vec_type;
  using q_vec_type = typename KernelVecType<scalar_t>::q_vec_type;
  using k_load_vec_type = typename KernelVecType<scalar_t>::k_load_vec_type;
  using k_vec_type = typename KernelVecType<scalar_t>::k_vec_type;
  using qk_acc_vec_type = typename KernelVecType<scalar_t>::qk_acc_vec_type;

  // Tokens covered by one k vector load, and the group count for a full
  // 16-token block.
  constexpr static int TOKEN_PER_GROUP = k_load_vec_type::get_elem_num() / x;
  constexpr static int MAX_GROUP_NUM = 16 / TOKEN_PER_GROUP;
  constexpr static int UNROLL_GROUP_NUM = MAX_GROUP_NUM / 4;

  static_assert(MAX_GROUP_NUM == 8 || MAX_GROUP_NUM == 4);
  static_assert(k_load_vec_type::get_elem_num() % x == 0);
  // One q load must cover exactly 16 bytes (one x-wide chunk).
  static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16);

  // q:       one head's query vector [HEAD_SIZE]
  // k_block: key cache block [HEAD_SIZE/x, BLOCK_SIZE, x]
  // logits:  output, token_num scaled dot products
  // token_num: valid tokens in this block (<= BLOCK_SIZE)
  FORCE_INLINE static void call(const scalar_t* __restrict__ q,
                                const scalar_t* __restrict__ k_block,
                                float* __restrict__ logits, float scale,
                                const int token_num) {
    const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP;

    // NOTE(review): relies on qk_acc_vec_type's default constructor
    // zero-initializing the accumulator — confirm in cpu_types.hpp.
    qk_acc_vec_type group_accums[MAX_GROUP_NUM];
    if (token_num == BLOCK_SIZE) {
      // Full block: fully unroll over all MAX_GROUP_NUM groups.
      for (int q_offset = 0; q_offset < HEAD_SIZE;
           q_offset += x, k_block += x * BLOCK_SIZE) {
        q_load_vec_type q_load_group_vec(q + q_offset);
        q_vec_type q_group_vec(q_load_group_vec);

        vec_op::unroll_loop<int, MAX_GROUP_NUM>(
            [k_block, &q_group_vec, &group_accums](int token_group_idx) {
              k_load_vec_type k_load_group_vec(k_block + token_group_idx * x *
                                               TOKEN_PER_GROUP);
              k_vec_type k_group_vec(k_load_group_vec);
              vec_op::fma(group_accums[token_group_idx], q_group_vec,
                          k_group_vec);
              // Prefetch the same group of the next x-chunk.
              vec_op::prefetch(k_block + x * BLOCK_SIZE +
                               token_group_idx * x * TOKEN_PER_GROUP);
            });
      }
    } else {
      // Partial block: only iterate the groups that contain valid tokens,
      // unrolling UNROLL_GROUP_NUM at a time.
      for (int q_offset = 0; q_offset < HEAD_SIZE;
           q_offset += x, k_block += x * BLOCK_SIZE) {
        q_load_vec_type q_load_group_vec(q + q_offset);
        q_vec_type q_group_vec(q_load_group_vec);
        for (int token_group_start = 0; token_group_start < group_num;
             token_group_start += UNROLL_GROUP_NUM) {
          vec_op::unroll_loop<int, UNROLL_GROUP_NUM>(
              [token_group_start, k_block, &q_group_vec,
               &group_accums](int token_group_idx) {
                token_group_idx += token_group_start;
                k_load_vec_type k_load_group_vec(k_block + token_group_idx * x *
                                                 TOKEN_PER_GROUP);
                k_vec_type k_group_vec(k_load_group_vec);
                vec_op::fma(group_accums[token_group_idx], q_group_vec,
                            k_group_vec);
                vec_op::prefetch(k_block + x * BLOCK_SIZE +
                                 token_group_idx * x * TOKEN_PER_GROUP);
              });
        }
      }
    }

    // Horizontal reduction: each token's partial products live in a
    // sub-lane of its group accumulator; reduce and scale into logits.
    for (int token_group_idx = 0; token_group_idx < group_num;
         ++token_group_idx) {
      vec_op::unroll_loop<int, TOKEN_PER_GROUP>(
          [&group_accums, logits, scale, token_group_idx](int token_idx) {
            float dot_v =
                group_accums[token_group_idx]
                    .template reduce_sub_sum<qk_acc_vec_type::get_elem_num() /
                                             TOKEN_PER_GROUP>(token_idx);
            logits[token_group_idx * TOKEN_PER_GROUP + token_idx] =
                dot_v * scale;
          });
    }
  }
};
|
||||
|
||||
// Accumulates prob * V for one value-cache block into `acc`, an array of
// HEAD_PARTITION_SIZE fp32 accumulators (one per head element in this
// partition). v_block is laid out [head_size, block_size], so one load
// fetches all BLOCK_SIZE token values of a single head element.
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE,
          int HEAD_PARTITION_SIZE, typename acc_t>
FORCE_INLINE void reduceValueBlock(const float* prob, const scalar_t* v_block,
                                   acc_t&& acc) {
  using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
  constexpr int ELEM_NUM = v_load_vec_type::get_elem_num();
  // One vector load must cover exactly one block of tokens.
  static_assert(BLOCK_SIZE == ELEM_NUM);
  vec_op::FP32Vec16 prob_vec(prob);

  vec_op::unroll_loop<int, HEAD_PARTITION_SIZE>([&](int head_elem_idx) {
    v_load_vec_type v_vec(v_block + BLOCK_SIZE * head_elem_idx);
    vec_op::FP32Vec16 fp32_v_vec(v_vec);
    acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec;
  });
}
|
||||
}; // namespace
|
||||
|
||||
// Paged attention v1
|
||||
namespace {
|
||||
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE>
|
||||
struct paged_attention_v1_impl {
|
||||
static void call(
|
||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
|
||||
// head_size/x, block_size, x]
|
||||
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
|
||||
// head_size, block_size]
|
||||
const int num_kv_heads, const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs,
|
||||
// max_num_blocks_per_seq]
|
||||
const int* __restrict__ seq_lens, // [num_seqs]
|
||||
const int max_num_blocks_per_seq,
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
const int q_stride, const int kv_block_stride, const int kv_head_stride,
|
||||
const int num_seqs, const int num_heads) {
|
||||
constexpr int x = 16 / sizeof(scalar_t);
|
||||
const int num_queries_per_kv = num_heads / num_kv_heads;
|
||||
|
||||
static_assert(BLOCK_SIZE == 16);
|
||||
|
||||
int max_seq_len = max_num_blocks_per_seq * BLOCK_SIZE;
|
||||
int max_seq_len_padded = (max_seq_len + 15) & 0xFFFFFFF0;
|
||||
TORCH_CHECK((max_seq_len_padded * sizeof(float)) % 64 == 0);
|
||||
|
||||
const int parallel_work_item_num = omp_get_max_threads();
|
||||
|
||||
size_t logits_bytes =
|
||||
parallel_work_item_num * max_seq_len_padded * sizeof(float);
|
||||
float* logits = (float*)std::aligned_alloc(
|
||||
64, logits_bytes); // Cacheline alignment for each context token.
|
||||
// [parallel_work_item_num, max_seq_len_padded]
|
||||
|
||||
#pragma omp parallel for collapse(2) schedule(dynamic, 1)
|
||||
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
|
||||
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
|
||||
int seq_len = seq_lens[seq_idx];
|
||||
const int* seq_block_table =
|
||||
block_tables + max_num_blocks_per_seq * seq_idx;
|
||||
const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
const int64_t kv_head_idx = head_idx / num_queries_per_kv;
|
||||
const scalar_t* __restrict__ q_vec_ptr =
|
||||
q + seq_idx * q_stride + head_idx * HEAD_SIZE;
|
||||
const int last_block_token_num = seq_len - (block_num - 1) * BLOCK_SIZE;
|
||||
float* __restrict__ thread_block_logits =
|
||||
logits + omp_get_thread_num() * max_seq_len_padded;
|
||||
|
||||
// Compute logits
|
||||
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
|
||||
const int64_t physical_block_idx = seq_block_table[block_idx];
|
||||
const scalar_t* __restrict__ k_block_cache_ptr =
|
||||
k_cache + physical_block_idx * kv_block_stride +
|
||||
kv_head_idx * kv_head_stride;
|
||||
float* __restrict__ head_block_logits =
|
||||
thread_block_logits + block_idx * BLOCK_SIZE;
|
||||
|
||||
reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call(
|
||||
q_vec_ptr, k_block_cache_ptr, head_block_logits, scale,
|
||||
block_idx == block_num - 1 ? last_block_token_num : BLOCK_SIZE);
|
||||
}
|
||||
|
||||
// Compute softmax
|
||||
if (alibi_slopes) {
|
||||
reduceSoftmaxAlibi(thread_block_logits, seq_len,
|
||||
block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0,
|
||||
seq_len);
|
||||
} else {
|
||||
reduceSoftmax(thread_block_logits, seq_len, block_num * BLOCK_SIZE);
|
||||
}
|
||||
|
||||
// Compute value
|
||||
constexpr int head_elem_num_per_partition = 16;
|
||||
constexpr int head_partition_num =
|
||||
HEAD_SIZE / head_elem_num_per_partition;
|
||||
for (int head_part_idx = 0; head_part_idx < head_partition_num;
|
||||
++head_part_idx) {
|
||||
vec_op::FP32Vec16 accums[head_elem_num_per_partition];
|
||||
scalar_t* __restrict__ out_ptr =
|
||||
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
|
||||
head_part_idx * head_elem_num_per_partition;
|
||||
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
|
||||
const int64_t physical_block_idx = seq_block_table[block_idx];
|
||||
const float* __restrict__ prob_vec_ptr =
|
||||
thread_block_logits + block_idx * BLOCK_SIZE;
|
||||
const scalar_t* __restrict__ v_block_cache_ptr =
|
||||
v_cache + physical_block_idx * kv_block_stride +
|
||||
kv_head_idx * kv_head_stride +
|
||||
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
|
||||
reduceValueBlock<scalar_t, HEAD_SIZE, BLOCK_SIZE,
|
||||
head_elem_num_per_partition>(
|
||||
prob_vec_ptr, v_block_cache_ptr, accums);
|
||||
|
||||
if (block_idx != block_num - 1) {
|
||||
const int64_t next_physical_block_idx =
|
||||
seq_block_table[block_idx + 1];
|
||||
const scalar_t* __restrict__ next_v_block_cache_ptr =
|
||||
v_cache + next_physical_block_idx * kv_block_stride +
|
||||
kv_head_idx * kv_head_stride +
|
||||
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
|
||||
vec_op::unroll_loop<int, head_elem_num_per_partition>(
|
||||
[&](int head_elem_idx) {
|
||||
if (head_elem_idx % 2 == 0) {
|
||||
vec_op::prefetch(next_v_block_cache_ptr +
|
||||
BLOCK_SIZE * head_elem_idx);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
vec_op::unroll_loop<int, head_elem_num_per_partition>(
|
||||
[&](int head_elem_idx) {
|
||||
float value = accums[head_elem_idx].reduce_sum();
|
||||
vec_op::storeFP32(value, out_ptr + head_elem_idx);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
std::free(logits);
|
||||
}
|
||||
};
|
||||
|
||||
// Instantiates and calls the v1 kernel for a concrete (dtype, head size,
// block size) triple. Expands inside paged_attention_v1_impl_launcher and
// relies on that function's local variable names.
#define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE)                  \
  paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call(                    \
      out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
      block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq,                 \
      alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs,  \
      num_heads);
|
||||
|
||||
// Extracts raw pointers/strides from the tensors and dispatches to the
// compile-time HEAD_SIZE instantiation of paged_attention_v1_impl.
// NOTE: the local variable names below are referenced textually by
// LAUNCH_V1_ATTENTION_KERNEL — do not rename them.
template <typename T, int BLOCK_SIZE>
void paged_attention_v1_impl_launcher(
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int num_kv_heads, float scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
    const std::optional<torch::Tensor>& alibi_slopes) {
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
  int max_num_blocks_per_seq = block_tables.size(1);
  int q_stride = query.stride(0);
  int kv_block_stride = key_cache.stride(0);
  int kv_head_stride = key_cache.stride(1);

  // NOTE: alibi_slopes is optional.
  const float* alibi_slopes_ptr =
      alibi_slopes
          ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
          : nullptr;

  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
  T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* seq_lens_ptr = seq_lens.data_ptr<int>();

  // Runtime head size -> compile-time template parameter.
  switch (head_size) {
    case 32:
      LAUNCH_V1_ATTENTION_KERNEL(T, 32, BLOCK_SIZE);
      break;
    case 64:
      LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
      break;
    case 80:
      LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
      break;
    case 96:
      LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
      break;
    case 112:
      LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
      break;
    case 128:
      LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
      break;
    case 192:
      LAUNCH_V1_ATTENTION_KERNEL(T, 192, BLOCK_SIZE);
      break;
    case 256:
      LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
      break;
    default:
      TORCH_CHECK(false, "Unsupported head size: ", head_size);
      break;
  }
}
|
||||
|
||||
// Forwards the v1 entry-point arguments to the typed launcher.
#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE)                               \
  paged_attention_v1_impl_launcher<T, BLOCK_SIZE>(                           \
      out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
      seq_lens, max_seq_len, alibi_slopes);

// Maps the runtime block_size to a compile-time instantiation; only 16 is
// supported by the CPU backend.
#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T)                     \
  switch (block_size) {                                           \
    case 16:                                                      \
      CALL_V1_KERNEL_LAUNCHER(T, 16);                             \
      break;                                                      \
    default:                                                      \
      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
      break;                                                      \
  }
|
||||
} // namespace
|
||||
|
||||
// Public v1 entry point: rejects unsupported blocksparse configs, then
// dispatches on the query dtype to the templated launcher. The kv-cache
// dtype, scales, and blocksparse parameters beyond the check are accepted
// for ABI parity with the GPU backend but unused here.
void paged_attention_v1(
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
    torch::Tensor& v_scale, const int64_t tp_rank,
    const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step) {
  TORCH_CHECK(blocksparse_vert_stride <= 1,
              "CPU backend does not support blocksparse attention yet.");
  VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
                               [&] {
                                 CPU_KERNEL_GUARD_IN(paged_attention_v1_impl)
                                 CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t);
                                 CPU_KERNEL_GUARD_OUT(paged_attention_v1_impl)
                               });
}
|
||||
|
||||
// Paged attention v2
|
||||
namespace {
|
||||
// Two-pass (partitioned) paged attention for long sequences. Phase 1:
// each (seq, partition, head) computes its partial attention into tmp_out
// plus per-partition {max logit, exp-sum}. Phase 2: merge the per-partition
// softmax statistics. Phase 3: combine the partial outputs using the merged
// weights. Sequences that fit in one partition write directly to `out`.
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int PARTITION_SIZE>
struct paged_attention_v2_impl {
  static void call(
      scalar_t* __restrict__ out,      // [num_seqs, num_heads, head_size]
      float* __restrict__ exp_sums,    // [num_seqs, num_heads,
                                       // max_num_partitions]
      float* __restrict__ max_logits,  // [num_seqs, num_heads,
                                       // max_num_partitions]
      scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
                                       // max_num_partitions, head_size]
      const scalar_t* __restrict__ q,  // [num_seqs, num_heads, head_size]
      const scalar_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
                                             // head_size/x, block_size, x]
      const scalar_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
                                             // head_size, block_size]
      const int num_kv_heads, const float scale,
      const int* __restrict__ block_tables,  // [num_seqs,
                                             // max_num_blocks_per_seq]
      const int* __restrict__ seq_lens,      // [num_seqs]
      const int max_num_blocks_per_seq,
      const float* __restrict__ alibi_slopes,  // [num_heads]
      const int q_stride, const int kv_block_stride, const int kv_head_stride,
      const int num_seqs, const int num_heads, const int max_num_partitions) {
    // x = elements per 16-byte chunk of the key-cache innermost dim.
    constexpr int x = 16 / sizeof(scalar_t);
    const int num_queries_per_kv = num_heads / num_kv_heads;

    static_assert(BLOCK_SIZE == 16);
    // Keep the stack logits buffer cacheline-sized and block-aligned.
    static_assert(PARTITION_SIZE * sizeof(float) % 64 == 0);
    static_assert(PARTITION_SIZE % BLOCK_SIZE == 0);

    // Phase 1: partial attention per (seq, partition, head).
#pragma omp parallel for collapse(3) schedule(static, 1)
    for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
      for (int partition_idx = 0; partition_idx < max_num_partitions;
           ++partition_idx) {
        for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
          const int seq_len = seq_lens[seq_idx];
          const int start_token_idx = partition_idx * PARTITION_SIZE;

          // Partition lies entirely past this sequence's end.
          if (start_token_idx >= seq_len) continue;

          const int partition_num =
              (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
          // One partition covers the whole sequence: skip phases 2/3 and
          // write directly to the final output.
          const bool no_reduce = (partition_num == 1);
          const int token_num =
              (std::min(seq_len, start_token_idx + PARTITION_SIZE) -
               start_token_idx);
          const int block_num = (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
          const int last_block_token_num =
              token_num - (block_num - 1) * BLOCK_SIZE;
          const int* seq_block_table = block_tables +
                                       max_num_blocks_per_seq * seq_idx +
                                       start_token_idx / BLOCK_SIZE;
          const int64_t kv_head_idx = head_idx / num_queries_per_kv;
          const scalar_t* __restrict__ q_vec_ptr =
              q + seq_idx * q_stride + head_idx * HEAD_SIZE;

          // Per-partition logits live on the stack, cacheline-aligned.
          float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0};

          // Compute logits
          for (int block_idx = 0; block_idx < block_num; ++block_idx) {
            const int64_t physical_block_idx = seq_block_table[block_idx];
            const scalar_t* __restrict__ k_block_cache_ptr =
                k_cache + physical_block_idx * kv_block_stride +
                kv_head_idx * kv_head_stride;
            float* __restrict__ head_block_logits =
                logits + block_idx * BLOCK_SIZE;

            reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call(
                q_vec_ptr, k_block_cache_ptr, head_block_logits, scale,
                block_idx == block_num - 1 ? last_block_token_num : BLOCK_SIZE);
          }

          // Partition-local softmax; keep {max, exp-sum} for phase 2.
          std::pair<float, float> max_and_sum;
          if (alibi_slopes) {
            max_and_sum = reduceSoftmaxAlibi(
                logits, token_num, block_num * BLOCK_SIZE,
                alibi_slopes[head_idx], start_token_idx, seq_len);
          } else {
            max_and_sum =
                reduceSoftmax(logits, token_num, block_num * BLOCK_SIZE);
          }

          auto&& [max_logit, exp_sum] = max_and_sum;

          // Choose destination: tmp_out for multi-partition sequences
          // (stats saved for the reduction), or `out` directly otherwise.
          scalar_t* __restrict__ output_buffer = nullptr;
          if (!no_reduce) {
            auto idx = seq_idx * num_heads * max_num_partitions +
                       head_idx * max_num_partitions + partition_idx;
            max_logits[idx] = max_logit;
            exp_sums[idx] = exp_sum;
            output_buffer =
                tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
                head_idx * max_num_partitions * HEAD_SIZE +
                partition_idx * HEAD_SIZE;
          } else {
            output_buffer =
                out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
          }

          // Compute value (same partitioned prob*V reduction as v1).
          constexpr int head_elem_num_per_partition = 16;
          constexpr int head_partition_num =
              HEAD_SIZE / head_elem_num_per_partition;
          for (int head_part_idx = 0; head_part_idx < head_partition_num;
               ++head_part_idx) {
            vec_op::FP32Vec16 accums[head_elem_num_per_partition];
            scalar_t* __restrict__ out_ptr =
                output_buffer + head_part_idx * head_elem_num_per_partition;
            for (int block_idx = 0; block_idx < block_num; ++block_idx) {
              const int64_t physical_block_idx = seq_block_table[block_idx];
              const float* __restrict__ prob_vec_ptr =
                  logits + block_idx * BLOCK_SIZE;
              const scalar_t* __restrict__ v_block_cache_ptr =
                  v_cache + physical_block_idx * kv_block_stride +
                  kv_head_idx * kv_head_stride +
                  BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
              reduceValueBlock<scalar_t, HEAD_SIZE, BLOCK_SIZE,
                               head_elem_num_per_partition>(
                  prob_vec_ptr, v_block_cache_ptr, accums);

              // Prefetch every other cacheline of the next value block.
              if (block_idx != block_num - 1) {
                const int64_t next_physical_block_idx =
                    seq_block_table[block_idx + 1];
                const scalar_t* __restrict__ next_v_block_cache_ptr =
                    v_cache + next_physical_block_idx * kv_block_stride +
                    kv_head_idx * kv_head_stride +
                    BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
                vec_op::unroll_loop<int, head_elem_num_per_partition>(
                    [&](int head_elem_idx) {
                      if (head_elem_idx % 2 == 0) {
                        vec_op::prefetch(next_v_block_cache_ptr +
                                         BLOCK_SIZE * head_elem_idx);
                      }
                    });
              }
            }

            vec_op::unroll_loop<int, head_elem_num_per_partition>(
                [&](int head_elem_idx) {
                  float value = accums[head_elem_idx].reduce_sum();
                  vec_op::storeFP32(value, out_ptr + head_elem_idx);
                });
          }
        }
      }
    }

    // Phase 2: rescale partition softmax and store the factors to exp_sums
#pragma omp parallel for collapse(2) schedule(static, 1)
    for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
      for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
        const int seq_len = seq_lens[seq_idx];
        const int partition_num =
            (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;

        if (partition_num == 1) continue;

        reducePartitionSoftmax(
            max_logits + seq_idx * num_heads * max_num_partitions +
                head_idx * max_num_partitions,
            exp_sums + seq_idx * num_heads * max_num_partitions +
                head_idx * max_num_partitions,
            partition_num);
      }
    }

    // Phase 3: reduce values — weighted sum of partition outputs using the
    // mixing factors phase 2 left in exp_sums.
    using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
    static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE);
    constexpr int head_elem_num_per_group =
        16;  // Note: didn't align with the cacheline size, due to some
             // HEAD_SIZE didn't align with 64 bytes
    static_assert(HEAD_SIZE % head_elem_num_per_group == 0);
    constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group;
    const float* __restrict__ rescale_factors = exp_sums;
#pragma omp parallel for collapse(3) schedule(static, 1)
    for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
      for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
        for (int group_idx = 0; group_idx < head_group_num; ++group_idx) {
          const int seq_len = seq_lens[seq_idx];
          const int partition_num =
              (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;

          // Single-partition sequences already wrote `out` in phase 1.
          if (partition_num == 1) continue;

          const float* __restrict__ seq_head_rescale_factors =
              rescale_factors + seq_idx * num_heads * max_num_partitions +
              head_idx * max_num_partitions;
          const scalar_t* __restrict__ seq_head_tmp_out =
              tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
              head_idx * max_num_partitions * HEAD_SIZE +
              group_idx * head_elem_num_per_group;
          scalar_t* __restrict__ seq_head_output =
              out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
              group_idx * head_elem_num_per_group;

          vec_op::FP32Vec16 acc;
          for (int i = 0; i < partition_num; ++i) {
            vec_op::FP32Vec16 rescale_factor(seq_head_rescale_factors[i]);
            v_load_vec_type value(seq_head_tmp_out + i * HEAD_SIZE);
            vec_op::FP32Vec16 fp32_value(value);
            acc = acc + fp32_value * rescale_factor;
          }
          v_load_vec_type cast_acc(acc);
          cast_acc.save(seq_head_output);
        }
      }
    }
  }
};
|
||||
|
||||
// Instantiates and calls the v2 kernel for a concrete (dtype, head size,
// block size) triple. Expands inside paged_attention_v2_impl_launcher and
// relies on that function's local variable names (including PARTITION_SIZE).
#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE)                  \
  paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call(    \
      out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr,          \
      key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr,  \
      seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride,       \
      kv_block_stride, kv_head_stride, num_seqs, num_heads,                   \
      max_num_partitions);
|
||||
|
||||
// Extracts raw pointers/strides from the tensors and dispatches to the
// compile-time HEAD_SIZE instantiation of paged_attention_v2_impl.
// NOTE: the local variable names below are referenced textually by
// LAUNCH_V2_ATTENTION_KERNEL — do not rename them.
template <typename T, int BLOCK_SIZE, int PARTITION_SIZE = 512>
void paged_attention_v2_impl_launcher(
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int num_kv_heads, float scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
    int max_seq_len, const std::optional<torch::Tensor>& alibi_slopes) {
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
  int max_num_blocks_per_seq = block_tables.size(1);
  int q_stride = query.stride(0);
  int kv_block_stride = key_cache.stride(0);
  int kv_head_stride = key_cache.stride(1);
  int max_num_partitions = exp_sums.size(-1);

  // NOTE: alibi_slopes is optional.
  const float* alibi_slopes_ptr =
      alibi_slopes
          ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
          : nullptr;

  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
  float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
  T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
  T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* seq_lens_ptr = seq_lens.data_ptr<int>();

  // Runtime head size -> compile-time template parameter.
  switch (head_size) {
    case 32:
      LAUNCH_V2_ATTENTION_KERNEL(T, 32, BLOCK_SIZE);
      break;
    case 64:
      LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
      break;
    case 80:
      LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
      break;
    case 96:
      LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
      break;
    case 112:
      LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
      break;
    case 128:
      LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
      break;
    case 192:
      LAUNCH_V2_ATTENTION_KERNEL(T, 192, BLOCK_SIZE);
      break;
    case 256:
      LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
      break;
    default:
      TORCH_CHECK(false, "Unsupported head size: ", head_size);
      break;
  }
}
|
||||
|
||||
// Forwards the v2 entry-point arguments to the typed launcher.
#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE)                             \
  paged_attention_v2_impl_launcher<T, BLOCK_SIZE>(                         \
      out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,   \
      num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, \
      alibi_slopes);

// Maps the runtime block_size to a compile-time instantiation; only 16 is
// supported by the CPU backend.
#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T)                     \
  switch (block_size) {                                           \
    case 16:                                                      \
      CALL_V2_KERNEL_LAUNCHER(T, 16);                             \
      break;                                                      \
    default:                                                      \
      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
      break;                                                      \
  }
|
||||
} // namespace
|
||||
|
||||
// Public v2 entry point: rejects unsupported blocksparse configs, then
// dispatches on the query dtype to the templated launcher. exp_sums,
// max_logits and tmp_out are caller-provided scratch for the partitioned
// reduction; kv-cache dtype/scales and remaining blocksparse parameters are
// accepted for ABI parity with the GPU backend but unused here.
void paged_attention_v2(
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
    torch::Tensor& v_scale, const int64_t tp_rank,
    const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step) {
  TORCH_CHECK(blocksparse_vert_stride <= 1,
              "CPU backend does not support blocksparse attention yet.");
  VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
                               [&] {
                                 CPU_KERNEL_GUARD_IN(paged_attention_v2_impl)
                                 CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t);
                                 CPU_KERNEL_GUARD_OUT(paged_attention_v2_impl)
                               });
}
|
||||
@@ -1,214 +0,0 @@
|
||||
#include <cstring>
#include <map>
#include <utility>
#include <vector>

#include "cpu_types.hpp"
|
||||
|
||||
#if defined(__x86_64__)
|
||||
#define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES_WITH_E5M2
|
||||
#else
|
||||
#define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
template <typename scalar_t>
|
||||
void copy_blocks_cpu_impl(std::vector<torch::Tensor> const& key_caches,
|
||||
std::vector<torch::Tensor> const& value_caches,
|
||||
const torch::Tensor& mapping_pairs,
|
||||
const int element_num_per_block,
|
||||
const int layer_num) {
|
||||
const size_t pair_num = mapping_pairs.size(0);
|
||||
const size_t block_bytes = sizeof(scalar_t) * element_num_per_block;
|
||||
#pragma omp parallel for collapse(2)
|
||||
for (int layer = 0; layer < layer_num; ++layer) {
|
||||
for (size_t pair = 0; pair < pair_num; ++pair) {
|
||||
int64_t source_offset =
|
||||
element_num_per_block * mapping_pairs[pair][0].item<int64_t>();
|
||||
int64_t target_offset =
|
||||
element_num_per_block * mapping_pairs[pair][1].item<int64_t>();
|
||||
scalar_t* key_cache_ptr = key_caches[layer].data_ptr<scalar_t>();
|
||||
scalar_t* source_ptr = key_cache_ptr + source_offset;
|
||||
scalar_t* target_ptr = key_cache_ptr + target_offset;
|
||||
std::memcpy(target_ptr, source_ptr, block_bytes);
|
||||
|
||||
scalar_t* value_cache_ptr = value_caches[layer].data_ptr<scalar_t>();
|
||||
source_ptr = value_cache_ptr + source_offset;
|
||||
target_ptr = value_cache_ptr + target_offset;
|
||||
std::memcpy(target_ptr, source_ptr, block_bytes);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Scatters per-token K/V head vectors into the paged CPU KV cache.
// Key cache layout:   [num_blocks, num_heads, head_size/x, block_size, x]
// Value cache layout: [num_blocks, num_heads, head_size, block_size]
// (x is the innermost key packing factor; the caller reads it from
// key_cache.size(4).)
template <typename scalar_t>
void reshape_and_cache_cpu_impl(
    const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
    scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
    const int64_t* __restrict__ slot_mapping, const int num_tokens,
    const int key_stride, const int value_stride, const int num_heads,
    const int head_size, const int block_size, const int x) {
  // Elements in one cache block (all heads of one block of tokens).
  const int block_elem_num = num_heads * head_size * block_size;

#pragma omp parallel for collapse(2)
  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
    for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
      const int64_t slot_idx = slot_mapping[token_idx];
      // Negative slots mark padded tokens with no cache slot assigned.
      if (slot_idx >= 0) {
        int src_key_head_idx = token_idx * key_stride + head_idx * head_size;
        int src_value_head_idx =
            token_idx * value_stride + head_idx * head_size;
        const scalar_t* src_key_head_ptr = key + src_key_head_idx;
        const scalar_t* src_value_head_ptr = value + src_value_head_idx;
        const int64_t block_index = slot_idx / block_size;
        const int64_t block_offset = slot_idx % block_size;
        scalar_t* target_key_head_ptr = key_cache +
                                        block_elem_num * block_index +
                                        head_idx * block_size * head_size;
        scalar_t* target_value_head_ptr = value_cache +
                                          block_elem_num * block_index +
                                          head_idx * block_size * head_size;

        // Key: copy in x-sized groups; within a group the token's x
        // elements are contiguous (..., block_size, x layout).
        for (int src_key_idx = 0; src_key_idx < head_size; src_key_idx += x) {
          const int64_t target_offset =
              src_key_idx * block_size + block_offset * x;
          for (int i = 0; i < x; ++i) {
            target_key_head_ptr[target_offset + i] =
                src_key_head_ptr[src_key_idx + i];
          }
        }

        // Value: one element per head dim, strided by block_size
        // (..., head_size, block_size layout).
        for (int src_value_idx = 0; src_value_idx < head_size;
             ++src_value_idx) {
          const int64_t target_offset =
              src_value_idx * block_size + block_offset;
          target_value_head_ptr[target_offset] =
              src_value_head_ptr[src_value_idx];
        }
      }
    }
  }
}
|
||||
}; // namespace
|
||||
|
||||
// Writes each token's MLA latent vector (kv_c) and rotary positional part
// (k_pe) into consecutive positions of one paged-cache entry:
//   entry = [ kv_c (kv_lora_rank elems) | k_pe (pe_dim elems) ]
// Tokens whose slot is negative are padding and are skipped.
template <typename scalar_t>
void concat_and_cache_mla_cpu_impl(
    const scalar_t* __restrict__ kv_c,  // [num_tokens, kv_lora_rank]
    const scalar_t* __restrict__ k_pe,  // [num_tokens, pe_dim]
    scalar_t* __restrict__ kv_cache,    // [num_blocks, block_size,
                                        //  (kv_lora_rank + pe_dim)]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int num_tokens,    //
    const int block_stride,  //
    const int entry_stride,  //
    const int kv_c_stride,   //
    const int k_pe_stride,   //
    const int kv_lora_rank,  //
    const int pe_dim,        //
    const int block_size     //
) {
#pragma omp parallel for
  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
    const int64_t slot = slot_mapping[token_idx];
    if (slot < 0) {
      continue;  // padded token, no cache slot assigned
    }
    // Base element offset of this token's cache entry.
    const int64_t entry_base = (slot / block_size) * block_stride +
                               (slot % block_size) * entry_stride;
    scalar_t* __restrict__ entry = kv_cache + entry_base;
    const scalar_t* __restrict__ c_row = kv_c + token_idx * kv_c_stride;
    const scalar_t* __restrict__ pe_row = k_pe + token_idx * k_pe_stride;

    // Latent part first, rotary part appended right after it.
    for (int i = 0; i < kv_lora_rank; ++i) {
      entry[i] = c_row[i];
    }
    for (int i = 0; i < pe_dim; ++i) {
      entry[kv_lora_rank + i] = pe_row[i];
    }
  }
}
|
||||
|
||||
// Note: the key_caches and value_caches vectors are constant but
// not the Tensors they contain. The vectors need to be const refs
// in order to satisfy pytorch's C++ operator registration code.
//
// Registered op: copies (src, dst) cache blocks given by block_mapping
// within every layer's key and value cache. No-op when there are no
// layers.
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
                 std::vector<torch::Tensor> const& value_caches,
                 const torch::Tensor& block_mapping) {
  unsigned num_layers = key_caches.size();
  TORCH_CHECK(num_layers == value_caches.size());
  if (num_layers == 0) {
    return;
  }

  // Elements per block = numel of one block slice of the first layer's
  // key cache (all layers are assumed to share this shape).
  const int element_num_per_block = key_caches[0][0].numel();
  DISPATCH_MACRO(key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
    CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
    copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping,
                                   element_num_per_block, num_layers);
    CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
  });
}
|
||||
|
||||
// Registered op: scatters a batch of per-token K/V tensors into the paged
// CPU cache. kv_cache_dtype, k_scale and v_scale are accepted for schema
// compatibility but not used on this path (no quantized cache here).
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
                       torch::Tensor& key_cache, torch::Tensor& value_cache,
                       torch::Tensor& slot_mapping,
                       const std::string& kv_cache_dtype,
                       torch::Tensor& k_scale, torch::Tensor& v_scale) {
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  // x: innermost packing factor of the 5-D key cache layout
  // [num_blocks, num_heads, head_size/x, block_size, x].
  int x = key_cache.size(4);

  // Outer strides are passed explicitly so sliced inputs still work.
  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  DISPATCH_MACRO(key.scalar_type(), "reshape_and_cache_cpu_impl", [&] {
    CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl)
    reshape_and_cache_cpu_impl<scalar_t>(
        key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
        key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
        slot_mapping.data_ptr<int64_t>(), num_tokens, key_stride, value_stride,
        num_heads, head_size, block_size, x);
    CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl)
  });
}
|
||||
|
||||
// Registered op for MLA caching: concatenates each token's compressed KV
// (kv_c) and rotary positional part (k_pe) into a single
// (kv_lora_rank + pe_dim)-wide cache entry. fp8 caches are explicitly
// rejected; `scale` is accepted for schema compatibility but unused here.
void concat_and_cache_mla(
    torch::Tensor& kv_c,          // [num_tokens, kv_lora_rank]
    torch::Tensor& k_pe,          // [num_tokens, pe_dim]
    torch::Tensor& kv_cache,      // [num_blocks, block_size, (kv_lora_rank +
                                  // pe_dim)]
    torch::Tensor& slot_mapping,  // [num_tokens] or [num_actual_tokens]
    const std::string& kv_cache_dtype, torch::Tensor& scale) {
  // Token count follows slot_mapping, which may be shorter than kv_c when
  // only the first num_actual_tokens entries should be cached.
  int num_tokens = slot_mapping.size(0);
  int kv_lora_rank = kv_c.size(1);
  int pe_dim = k_pe.size(1);
  int block_size = kv_cache.size(1);

  TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
  TORCH_CHECK(kv_cache_dtype != "fp8");

  int kv_c_stride = kv_c.stride(0);
  int k_pe_stride = k_pe.stride(0);
  int block_stride = kv_cache.stride(0);
  int entry_stride = kv_cache.stride(1);

  VLLM_DISPATCH_FLOATING_TYPES(
      kv_c.scalar_type(), "concat_and_cache_mla_cpu_impl", [&] {
        CPU_KERNEL_GUARD_IN(concat_and_cache_mla_cpu_impl)
        concat_and_cache_mla_cpu_impl<scalar_t>(
            kv_c.data_ptr<scalar_t>(), k_pe.data_ptr<scalar_t>(),
            kv_cache.data_ptr<scalar_t>(), slot_mapping.data_ptr<int64_t>(),
            num_tokens, block_stride, entry_stride, kv_c_stride, k_pe_stride,
            kv_lora_rank, pe_dim, block_size);
        CPU_KERNEL_GUARD_OUT(concat_and_cache_mla_cpu_impl)
      });
}
|
||||
|
||||
// Registered op stub: block swapping is not implemented for the CPU
// backend, so this always fails loudly. It exists only to satisfy the
// shared operator registration.
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
                 const torch::Tensor& block_mapping) {
  // Fixed: the original omitted the statement-terminating semicolon and
  // only compiled because of how the TORCH_CHECK macro expands.
  TORCH_CHECK(false, "swap_blocks is unsupported on CPU.");
}
|
||||
249
csrc/cpu/cpu_attn.cpp
Normal file
249
csrc/cpu/cpu_attn.cpp
Normal file
@@ -0,0 +1,249 @@
|
||||
#include "cpu_attn_vec.hpp"
|
||||
#include "cpu_attn_vec16.hpp"
|
||||
|
||||
#ifdef CPU_CAPABILITY_AMXBF16
#include "cpu_attn_amx.hpp"
// Switch case for the AMX implementation: binds the matching
// AttentionImpl specialization as `attn_impl` and invokes the dispatch
// body. Expects `scalar_t` and `head_dim` to be bound by the enclosing
// dtype/head-dim dispatch macros.
#define AMX_DISPATCH(...)                                                  \
  case cpu_attention::ISA::AMX: {                                          \
    using attn_impl = cpu_attention::AttentionImpl<cpu_attention::ISA::AMX, \
                                                   scalar_t, head_dim>;    \
    return __VA_ARGS__();                                                  \
  }
#else
// Without AMX support the case label is still emitted (keeping the switch
// well-formed) and falls through to the default error branch.
#define AMX_DISPATCH(...) case cpu_attention::ISA::AMX:
#endif

// One switch case that fixes HEAD_DIM as a compile-time `head_dim`
// constant before running the dispatch body.
#define CPU_ATTN_DISPATCH_CASE(HEAD_DIM, ...) \
  case HEAD_DIM: {                            \
    constexpr size_t head_dim = HEAD_DIM;     \
    return __VA_ARGS__();                     \
  }
|
||||
|
||||
// Immediately-invoked lambda mapping a runtime head_dim to one of the
// supported compile-time head sizes (32..256) and returning the dispatch
// body's result; unsupported sizes fail with TORCH_CHECK.
#define CPU_ATTN_DISPATCH_CASE_HEADDIM(HEAD_DIM, ...)           \
  [&] {                                                         \
    switch (HEAD_DIM) {                                         \
      CPU_ATTN_DISPATCH_CASE(32, __VA_ARGS__)                   \
      CPU_ATTN_DISPATCH_CASE(64, __VA_ARGS__)                   \
      CPU_ATTN_DISPATCH_CASE(96, __VA_ARGS__)                   \
      CPU_ATTN_DISPATCH_CASE(128, __VA_ARGS__)                  \
      CPU_ATTN_DISPATCH_CASE(160, __VA_ARGS__)                  \
      CPU_ATTN_DISPATCH_CASE(192, __VA_ARGS__)                  \
      CPU_ATTN_DISPATCH_CASE(224, __VA_ARGS__)                  \
      CPU_ATTN_DISPATCH_CASE(256, __VA_ARGS__)                  \
      default: {                                                \
        TORCH_CHECK(false, "Invalid CPU attention head_dim: " + \
                               std::to_string(HEAD_DIM));       \
      }                                                         \
    }                                                           \
  }()
|
||||
|
||||
// Immediately-invoked lambda mapping the runtime ISA enum to the matching
// AttentionImpl specialization, bound as `attn_impl` for the dispatch
// body. Requires `scalar_t` and `head_dim` to already be in scope (from
// the dtype and head-dim dispatch macros); the AMX case is a hard error
// unless CPU_CAPABILITY_AMXBF16 was compiled in (see AMX_DISPATCH).
#define CPU_ATTN_DISPATCH_IMPL(ISA_TYPE, ...)                               \
  [&] {                                                                     \
    switch (ISA_TYPE) {                                                     \
      AMX_DISPATCH(__VA_ARGS__)                                             \
      case cpu_attention::ISA::VEC: {                                       \
        using attn_impl =                                                   \
            cpu_attention::AttentionImpl<cpu_attention::ISA::VEC, scalar_t, \
                                         head_dim>;                         \
        return __VA_ARGS__();                                               \
      }                                                                     \
      case cpu_attention::ISA::VEC16: {                                     \
        using attn_impl =                                                   \
            cpu_attention::AttentionImpl<cpu_attention::ISA::VEC16,         \
                                         scalar_t, head_dim>;               \
        return __VA_ARGS__();                                               \
      }                                                                     \
      default: {                                                            \
        TORCH_CHECK(false, "Invalid CPU attention ISA type.");              \
      }                                                                     \
    }                                                                       \
  }()
|
||||
|
||||
// Builds the CPU-attention scheduler metadata tensor later consumed by
// cpu_attention_with_kv_cache. Maps `isa_hint` ("amx" / "vec" / "vec16")
// to an ISA enum, fills a ScheduleInput with the problem shape and the
// selected implementation's buffer element sizes / tiling constraints,
// and returns the packed schedule from AttentionScheduler::schedule().
// NOTE(review): "casual" is a misspelling of "causal" carried throughout
// this API surface; renaming would touch callers, so it is kept here.
torch::Tensor get_scheduler_metadata(
    const int64_t num_req, const int64_t num_heads_q,
    const int64_t num_heads_kv, const int64_t head_dim,
    const torch::Tensor& seq_lens, at::ScalarType dtype,
    const torch::Tensor& query_start_loc, const bool casual,
    const int64_t window_size, const std::string& isa_hint,
    const bool enable_kv_split) {
  // Map the textual ISA hint to the enum used for template dispatch.
  cpu_attention::ISA isa;
  if (isa_hint == "amx") {
    isa = cpu_attention::ISA::AMX;
  } else if (isa_hint == "vec") {
    isa = cpu_attention::ISA::VEC;
  } else if (isa_hint == "vec16") {
    isa = cpu_attention::ISA::VEC16;
  } else {
    TORCH_CHECK(false, "Unsupported CPU attention ISA hint: " + isa_hint);
  }

  cpu_attention::AttentionScheduler::ScheduleInput input;
  input.num_reqs = num_req;
  input.num_heads_q = num_heads_q;
  input.num_heads_kv = num_heads_kv;
  input.head_dim = head_dim;
  input.query_start_loc = query_start_loc.data_ptr<int32_t>();
  input.seq_lens = seq_lens.data_ptr<int32_t>();
  // Translate the single window_size into left/right sliding-window
  // widths: a causal mask clamps the right side to 0, and -1 means
  // "unlimited" on that side.
  if (window_size != -1) {
    input.left_sliding_window_size = window_size - 1;
    if (casual) {
      input.right_sliding_window_size = 0;
    } else {
      input.right_sliding_window_size = window_size - 1;
    }
  } else {
    input.left_sliding_window_size = -1;
    if (casual) {
      input.right_sliding_window_size = 0;
    } else {
      input.right_sliding_window_size = -1;
    }
  }
  input.casual = casual;
  input.isa = isa;
  input.enable_kv_split = enable_kv_split;
  // NOTE(review): this rejects casual == false, which makes the non-causal
  // branches above currently unreachable; they are kept for when the
  // restriction is lifted.
  TORCH_CHECK(casual, "Only supports casual mask for now.");

  // Resolve (dtype, head_dim, isa) to a concrete AttentionImpl so the
  // scheduler knows the implementation's buffer element sizes, per-iter
  // query-head limit, and KV block alignment.
  VLLM_DISPATCH_FLOATING_TYPES(dtype, "get_scheduler_metadata", [&]() {
    CPU_ATTN_DISPATCH_CASE_HEADDIM(head_dim, [&] {
      CPU_ATTN_DISPATCH_IMPL(isa, [&]() {
        input.elem_size = sizeof(scalar_t);
        input.q_buffer_elem_size = sizeof(attn_impl::q_buffer_t);
        input.logits_buffer_elem_size = sizeof(attn_impl::logits_buffer_t);
        input.output_buffer_elem_size =
            sizeof(attn_impl::partial_output_buffer_t);
        input.max_num_q_per_iter = attn_impl::MaxQHeadNumPerIteration;
        input.kv_block_alignment = attn_impl::BlockSizeAlignment;
      });
    });
  });

  cpu_attention::AttentionScheduler scheduler;
  torch::Tensor metadata = scheduler.schedule(input);
  return metadata;
}
|
||||
|
||||
// Scatters per-token K/V into the ISA-specific (possibly prepacked) CPU
// KV cache via the selected AttentionImpl::reshape_and_cache.
void cpu_attn_reshape_and_cache(
    const torch::Tensor& key,    // [token_num, head_num, head_size]
    const torch::Tensor& value,  // [token_num, head_num, head_size]
    torch::Tensor&
        key_cache,  // [num_blocks, num_kv_heads, block_size, head_size]
    torch::Tensor&
        value_cache,  // [num_blocks, num_kv_heads, block_size, head_size]
    const torch::Tensor& slot_mapping, const std::string& isa) {
  // Shapes must match the documented layouts; the innermost (head_size)
  // dim of key/value must be contiguous since rows are copied directly.
  TORCH_CHECK_EQ(key.dim(), 3);
  TORCH_CHECK_EQ(value.dim(), 3);
  TORCH_CHECK_EQ(key_cache.dim(), 4);
  TORCH_CHECK_EQ(value_cache.dim(), 4);
  TORCH_CHECK_EQ(key.stride(2), 1);
  TORCH_CHECK_EQ(value.stride(2), 1);

  // Outer strides are captured explicitly so non-contiguous (sliced)
  // key/value inputs still work.
  const int64_t token_num = key.size(0);
  const int64_t key_token_num_stride = key.stride(0);
  const int64_t value_token_num_stride = value.stride(0);
  const int64_t head_num = value.size(1);
  const int64_t key_head_num_stride = key.stride(1);
  const int64_t value_head_num_stride = value.stride(1);
  const int64_t num_blocks = key_cache.size(0);
  const int64_t num_blocks_stride = key_cache.stride(0);
  const int64_t cache_head_num_stride = key_cache.stride(1);
  const int64_t block_size = key_cache.size(2);
  const int64_t block_size_stride = key_cache.stride(2);
  const int64_t head_dim = key.size(-1);

  // Map the textual ISA tag to the dispatch enum.
  cpu_attention::ISA isa_tag = [&]() {
    if (isa == "amx") {
      return cpu_attention::ISA::AMX;
    } else if (isa == "vec") {
      return cpu_attention::ISA::VEC;
    } else if (isa == "vec16") {
      return cpu_attention::ISA::VEC16;
    } else {
      TORCH_CHECK(false, "Invalid ISA type: " + isa);
    }
  }();

  VLLM_DISPATCH_FLOATING_TYPES(
      key.scalar_type(), "cpu_attn_reshape_and_cache", [&]() {
        CPU_ATTN_DISPATCH_CASE_HEADDIM(head_dim, [&] {
          CPU_ATTN_DISPATCH_IMPL(isa_tag, [&]() {
            attn_impl::reshape_and_cache(
                key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
                key_cache.data_ptr<scalar_t>(),
                value_cache.data_ptr<scalar_t>(),
                slot_mapping.data_ptr<int64_t>(), token_num,
                key_token_num_stride, value_token_num_stride, head_num,
                key_head_num_stride, value_head_num_stride, num_blocks,
                num_blocks_stride, cache_head_num_stride, block_size,
                block_size_stride);
          });
        });
      });
}
|
||||
|
||||
// Runs varlen/paged CPU attention over the KV cache using the execution
// plan produced by get_scheduler_metadata(); results are written to
// `output`. Dispatches on dtype, head_dim and the ISA recorded in the
// scheduler metadata.
void cpu_attention_with_kv_cache(
    const torch::Tensor& query,  // [num_tokens, num_heads, head_size]
    const torch::Tensor&
        key_cache,  // [num_blocks, num_kv_heads, block_size, head_size]
    const torch::Tensor&
        value_cache,  // [num_blocks, num_kv_heads, block_size, head_size]
    torch::Tensor& output,                 // [num_tokens, num_heads, head_size]
    const torch::Tensor& query_start_loc,  // [num_tokens + 1]
    const torch::Tensor& seq_lens,         // [num_tokens]
    const double scale, const bool causal,
    const std::optional<torch::Tensor>& alibi_slopes,  // [num_heads]
    const int64_t sliding_window_left, const int64_t sliding_window_right,
    const torch::Tensor& block_table,  // [num_tokens, max_block_num]
    const double softcap, const torch::Tensor& scheduler_metadata,
    const std::optional<torch::Tensor>& s_aux  // [num_heads]
) {
  TORCH_CHECK_EQ(query.dim(), 3);
  TORCH_CHECK_EQ(query.stride(2), 1);  // head_size dim must be contiguous
  TORCH_CHECK_EQ(key_cache.dim(), 4);
  TORCH_CHECK_EQ(value_cache.dim(), 4);

  cpu_attention::AttentionInput input;
  // The metadata tensor is an opaque byte buffer written by
  // AttentionScheduler::schedule(); reinterpret it back into the struct.
  input.metadata = reinterpret_cast<cpu_attention::AttentionMetadata*>(
      scheduler_metadata.data_ptr());
  input.num_tokens = query.size(0);
  input.num_heads = query.size(1);
  input.num_kv_heads = key_cache.size(1);
  input.block_size = key_cache.size(2);
  input.query = query.data_ptr();
  input.query_num_tokens_stride = query.stride(0);
  input.query_num_heads_stride = query.stride(1);
  input.cache_num_blocks_stride = key_cache.stride(0);
  input.cache_num_kv_heads_stride = key_cache.stride(1);
  input.blt_num_tokens_stride = block_table.stride(0);
  input.key_cache = key_cache.data_ptr();
  input.value_cache = value_cache.data_ptr();
  input.output = output.data_ptr();
  input.query_start_loc = query_start_loc.data_ptr<int32_t>();
  input.seq_lens = seq_lens.data_ptr<int32_t>();
  input.block_table = block_table.data_ptr<int32_t>();
  // Optional features: null pointers mean "disabled".
  input.alibi_slopes =
      alibi_slopes.has_value() ? alibi_slopes->data_ptr<float>() : nullptr;
  // For now sink must be bf16
  input.s_aux = s_aux.has_value() ? s_aux->data_ptr<c10::BFloat16>() : nullptr;
  input.scale = scale;
  input.causal = causal;
  input.sliding_window_left = sliding_window_left;
  input.sliding_window_right = sliding_window_right;
  if (input.causal) {
    // to make boundary calculation easier
    input.sliding_window_right = 0;
  }
  // Narrow to float once; the kernels work in fp32 logits.
  float softcap_fp32 = softcap;
  input.softcap = softcap_fp32;

  VLLM_DISPATCH_FLOATING_TYPES(
      query.scalar_type(), "cpu_attention_with_kv_cache", [&]() {
        CPU_ATTN_DISPATCH_CASE_HEADDIM(query.size(2), [&] {
          CPU_ATTN_DISPATCH_IMPL(input.metadata->isa, [&]() {
            // The cache block size must be a multiple of the impl's
            // per-iteration KV tiling unit.
            TORCH_CHECK_EQ(input.block_size % attn_impl::BlockSizeAlignment, 0);
            cpu_attention::AttentionMainLoop<attn_impl> mainloop;
            mainloop(&input);
          });
        });
      });
}
|
||||
511
csrc/cpu/cpu_attn_amx.hpp
Normal file
511
csrc/cpu/cpu_attn_amx.hpp
Normal file
@@ -0,0 +1,511 @@
|
||||
#ifndef CPU_ATTN_AMX_HPP
|
||||
#define CPU_ATTN_AMX_HPP
|
||||
|
||||
#include "cpu_attn_impl.hpp"
|
||||
|
||||
namespace cpu_attention {
|
||||
namespace {
|
||||
// AMX specific
// An AMX tile register holds 16 rows x 64 bytes (1 KiB total).
constexpr static int64_t AMX_TILE_ROW_BYTES = 64;
constexpr static int64_t AMX_TILE_ROW_NUM = 16;
constexpr static int64_t AMX_TILE_BYTES = AMX_TILE_ROW_BYTES * AMX_TILE_ROW_NUM;

// In-memory layout consumed by _tile_loadconfig (palette 1):
// colsb[i] is the byte width of each row of tile i, rows[i] its row
// count. Field names/shape follow the Intel AMX configuration format.
typedef struct __tile_config {
  uint8_t palette_id = 1;
  uint8_t start_row = 0;
  uint8_t reserved_0[14] = {0};
  uint16_t colsb[16] = {0};
  uint8_t rows[16] = {0};
} __tilecfg;
|
||||
|
||||
// 2-2-4 pattern, for 16 < m <= 32
// TILE 0, 1: load A matrix, row num should be 16, m - 16
// TILE 2, 3: load B matrix, row num should be 16
// TILE 4, 5, 6, 7: store results C matrix, row num should be 16, 16, m - 16, m
// - 16
//
// Primary template: a hard runtime failure for KV-cache element types
// without an AMX specialization (only c10::BFloat16 is specialized below).
template <typename kv_cache_t>
class TileGemm224 {
 public:
  template <AttentionGemmPhase phase, int32_t k_size>
  FORCE_INLINE static void gemm(const int32_t m_size, void* __restrict__ a_tile,
                                void* __restrict__ b_tile,
                                float* __restrict__ c_tile, const int64_t lda,
                                const int64_t ldb, const int64_t ldc,
                                const int32_t block_size,
                                const int32_t dynamic_k_size,
                                const bool accum_c) {
    TORCH_CHECK(false, "Unsupported kv cache type for TileGemm224");
  }

  FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
    TORCH_CHECK(false, "Unsupported kv cache type for TileGemm224");
  }
};
|
||||
|
||||
// BF16 2-2-4 AMX kernel: accumulates C(m x 32 floats) += A(m x k) *
// B(k x 32) for 16 < m <= 32. A is split row-wise across tiles 0/1, the
// two 16-column halves of B go to tiles 2/3, and C is held in quadrant
// tiles 4/5/6/7 (see the pattern comment above the primary template).
template <>
class TileGemm224<c10::BFloat16> {
 public:
  template <AttentionGemmPhase phase, int32_t k_size>
  FORCE_INLINE static void gemm(const int32_t m_size,
                                c10::BFloat16* __restrict__ a_tile,
                                c10::BFloat16* __restrict__ b_tile,
                                float* __restrict__ c_tile, const int64_t lda,
                                const int64_t ldb, const int64_t ldc,
                                const int32_t block_size,
                                const int32_t dynamic_k_size,
                                const bool accum_c) {
    // One BF16 AMX tile covers 16 rows x 4 bytes = 32 k-elements, so this
    // is the number of k-steps in the main loop.
    const int32_t k_times =
        dynamic_k_size / (AMX_TILE_ROW_NUM * 4 / sizeof(c10::BFloat16));
    c10::BFloat16* __restrict__ a_tile_0 = a_tile;
    c10::BFloat16* __restrict__ a_tile_1 = a_tile + lda * AMX_TILE_ROW_NUM;
    // Row stride of A depends on the phase-specific buffer layout.
    const int64_t a_tile_stride = [&]() {
      if constexpr (phase == AttentionGemmPhase::QK) {
        // q_buffer is prepacked
        return AMX_TILE_ROW_BYTES;
      } else if constexpr (phase == AttentionGemmPhase::PV) {
        // logits_buffer is row-major
        return lda * sizeof(c10::BFloat16);
      } else {
        TORCH_CHECK(false, "Unreachable");
      }
    }();

    c10::BFloat16* __restrict__ b_tile_2 = b_tile;
    // Tile 3 holds the second 16-column half of B; its offset depends on
    // the phase-specific prepacked layout.
    c10::BFloat16* __restrict__ b_tile_3 = [&]() {
      if constexpr (phase == AttentionGemmPhase::QK) {
        // k_cache is prepacked
        return b_tile + (k_size * AMX_TILE_ROW_BYTES / 4);
      } else if constexpr (phase == AttentionGemmPhase::PV) {
        // v_cache is prepacked
        return b_tile + (block_size * AMX_TILE_ROW_BYTES / 4);
      } else {
        TORCH_CHECK(false, "Unreachable");
      }
    }();
    // k_cache, v_cache are prepacked
    const int32_t b_tile_stride = AMX_TILE_ROW_BYTES;

    // logits_buffer, output_buffer are not prepacked
    float* __restrict__ c_tile_4 = c_tile;
    float* __restrict__ c_tile_5 =
        c_tile_4 + AMX_TILE_ROW_BYTES / sizeof(float);
    float* __restrict__ c_tile_6 = c_tile + AMX_TILE_ROW_NUM * ldc;
    float* __restrict__ c_tile_7 =
        c_tile_6 + AMX_TILE_ROW_BYTES / sizeof(float);
    const int32_t c_tile_stride = ldc * sizeof(float);

    // Either continue accumulating into the existing C or start from 0.
    if (accum_c) {
      _tile_loadd(4, c_tile_4, c_tile_stride);
      _tile_loadd(5, c_tile_5, c_tile_stride);
      _tile_loadd(6, c_tile_6, c_tile_stride);
      _tile_loadd(7, c_tile_7, c_tile_stride);
    } else {
      _tile_zero(4);
      _tile_zero(5);
      _tile_zero(6);
      _tile_zero(7);
    }

    // Main k-loop: loads are interleaved with dot-products so they can
    // overlap; B uses streaming loads since cache data is not reused.
    for (int32_t k = 0; k < k_times; ++k) {
      _tile_loadd(0, a_tile_0, a_tile_stride);
      _tile_stream_loadd(2, b_tile_2, b_tile_stride);
      _tile_dpbf16ps(4, 0, 2);
      _tile_stream_loadd(3, b_tile_3, b_tile_stride);
      _tile_dpbf16ps(5, 0, 3);
      _tile_loadd(1, a_tile_1, a_tile_stride);
      _tile_dpbf16ps(6, 1, 2);
      _tile_dpbf16ps(7, 1, 3);

      // update ptrs
      if constexpr (phase == AttentionGemmPhase::QK) {
        // Q buffer is prepacked
        a_tile_0 += AMX_TILE_BYTES / sizeof(c10::BFloat16);
        a_tile_1 += AMX_TILE_BYTES / sizeof(c10::BFloat16);
      } else if constexpr (phase == AttentionGemmPhase::PV) {
        // P buffer is not prepacked
        a_tile_0 += AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
        a_tile_1 += AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
      } else {
        TORCH_CHECK(false, "Unreachable");
      }
      b_tile_2 += AMX_TILE_BYTES / sizeof(c10::BFloat16);
      b_tile_3 += AMX_TILE_BYTES / sizeof(c10::BFloat16);
    }

    _tile_stored(4, c_tile_4, c_tile_stride);
    _tile_stored(5, c_tile_5, c_tile_stride);
    _tile_stored(6, c_tile_6, c_tile_stride);
    _tile_stored(7, c_tile_7, c_tile_stride);
  }

  // Sets per-tile row counts for this m and loads the tile config.
  // Column widths (colsb) are configured once elsewhere (AttentionImpl
  // constructor sets all tiles to full 64-byte rows).
  FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
    const int32_t m_0 = AMX_TILE_ROW_NUM;
    const int32_t m_1 = m - AMX_TILE_ROW_NUM;
    config.rows[0] = m_0;
    config.rows[1] = m_1;
    config.rows[2] = AMX_TILE_ROW_NUM;
    config.rows[3] = AMX_TILE_ROW_NUM;
    config.rows[4] = m_0;
    config.rows[5] = m_0;
    config.rows[6] = m_1;
    config.rows[7] = m_1;
    _tile_loadconfig(&config);
  }
};
|
||||
|
||||
// 1-2-2 pattern, for 0 < m <= 16
// TILE 0, (1): load A matrix, use extra 1 tile for prefetch, row num should be
// m, m
// TILE 2, 3, (4, 5): load B matrix, use extra 2 tiles for prefetch, row
// num should be 16
// TILE 6, 7, (6, 7): store results C matrix, row num should be
// m
//
// Primary template: a hard runtime failure for KV-cache element types
// without an AMX specialization (only c10::BFloat16 is specialized below).
template <typename kv_cache_t>
class TileGemm122 {
 public:
  template <AttentionGemmPhase phase, int32_t k_size>
  FORCE_INLINE static void gemm(const int32_t m_size, void* __restrict__ a_tile,
                                void* __restrict__ b_tile,
                                float* __restrict__ c_tile, const int64_t lda,
                                const int64_t ldb, const int64_t ldc,
                                const int32_t block_size,
                                const int32_t dynamic_k_size,
                                const bool accum_c) {
    TORCH_CHECK(false, "Unsupported kv cache type for TileGemm122");
  }

  FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
    TORCH_CHECK(false, "Unsupported kv cache type for TileGemm122");
  }
};
|
||||
|
||||
// BF16 1-2-2 AMX kernel: accumulates C(m x 32 floats) += A(m x k) *
// B(k x 32) for m <= 16. The k-loop is unrolled by two, alternating tile
// sets (0, 2, 3) and (1, 4, 5) so loads of the next k-step can overlap
// the current dot-products; C stays resident in tiles 6/7 throughout.
template <>
class TileGemm122<c10::BFloat16> {
 public:
  template <AttentionGemmPhase phase, int32_t k_size>
  FORCE_INLINE static void gemm(const int32_t m_size,
                                c10::BFloat16* __restrict__ a_tile,
                                c10::BFloat16* __restrict__ b_tile,
                                float* __restrict__ c_tile, const int64_t lda,
                                const int64_t ldb, const int64_t ldc,
                                const int32_t block_size,
                                const int32_t dynamic_k_size,
                                const bool accum_c) {
    c10::BFloat16* __restrict__ a_tile_0 = a_tile;
    // a_tile_1 points at the next k-step of A (offset depends on layout).
    c10::BFloat16* __restrict__ a_tile_1 = [&]() {
      if constexpr (phase == AttentionGemmPhase::QK) {
        // q_buffer is prepacked
        return a_tile + AMX_TILE_BYTES / sizeof(c10::BFloat16);
      } else if constexpr (phase == AttentionGemmPhase::PV) {
        // logits_buffer is row-major
        return a_tile + AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
      } else {
        TORCH_CHECK(false, "Unreachable");
      }
    }();
    // Row stride of A depends on the phase-specific buffer layout.
    const int64_t a_tile_stride = [&]() {
      if constexpr (phase == AttentionGemmPhase::QK) {
        // q_buffer is prepacked
        return AMX_TILE_ROW_BYTES;
      } else if constexpr (phase == AttentionGemmPhase::PV) {
        // logits_buffer is row-major
        return lda * sizeof(c10::BFloat16);
      } else {
        TORCH_CHECK(false, "Unreachable");
      }
    }();

    c10::BFloat16* __restrict__ b_tile_2 = b_tile;
    // Second 16-column half of B; offset depends on the prepacked layout.
    c10::BFloat16* __restrict__ b_tile_3 = [&]() {
      if constexpr (phase == AttentionGemmPhase::QK) {
        // k_cache is prepacked
        return b_tile + (k_size * AMX_TILE_ROW_BYTES / 4);
      } else if constexpr (phase == AttentionGemmPhase::PV) {
        // v_cache is prepacked
        return b_tile + (block_size * AMX_TILE_ROW_BYTES / 4);
      } else {
        TORCH_CHECK(false, "Unreachable");
      }
    }();
    // Next k-step of both B halves, loaded into tiles 4/5.
    c10::BFloat16* __restrict__ b_tile_4 =
        b_tile_2 + AMX_TILE_BYTES / sizeof(c10::BFloat16);
    c10::BFloat16* __restrict__ b_tile_5 =
        b_tile_3 + AMX_TILE_BYTES / sizeof(c10::BFloat16);
    int64_t b_stride = AMX_TILE_ROW_BYTES;

    float* __restrict__ c_tile_6 = c_tile;
    float* __restrict__ c_tile_7 = c_tile + AMX_TILE_ROW_BYTES / sizeof(float);
    int64_t c_stride = ldc * sizeof(float);

    // One BF16 tile covers 32 k-elements; iterate k-steps in pairs, with
    // an optional single-step tail when k_times is odd.
    const int32_t k_times =
        dynamic_k_size / (AMX_TILE_ROW_NUM * 4 / sizeof(c10::BFloat16));
    const int32_t k_group_times = k_times / 2;
    const bool has_tail = (k_times % 2 == 1);

    // Either continue accumulating into the existing C or start from 0.
    if (accum_c) {
      _tile_loadd(6, c_tile_6, c_stride);
      _tile_loadd(7, c_tile_7, c_stride);
    } else {
      _tile_zero(6);
      _tile_zero(7);
    }

    for (int32_t k = 0; k < k_group_times; ++k) {
      _tile_loadd(0, a_tile_0, a_tile_stride);
      _tile_stream_loadd(2, b_tile_2, b_stride);
      _tile_dpbf16ps(6, 0, 2);
      _tile_stream_loadd(3, b_tile_3, b_stride);
      _tile_dpbf16ps(7, 0, 3);
      _tile_loadd(1, a_tile_1, a_tile_stride);
      _tile_stream_loadd(4, b_tile_4, b_stride);
      _tile_dpbf16ps(6, 1, 4);
      _tile_stream_loadd(5, b_tile_5, b_stride);
      _tile_dpbf16ps(7, 1, 5);

      // update ptrs (advance by two k-steps)
      if constexpr (phase == AttentionGemmPhase::QK) {
        // Q buffer is prepacked
        a_tile_0 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
        a_tile_1 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
      } else if constexpr (phase == AttentionGemmPhase::PV) {
        // P buffer is not prepacked
        a_tile_0 += 2 * AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
        a_tile_1 += 2 * AMX_TILE_ROW_BYTES / sizeof(c10::BFloat16);
      }
      // NOTE(review): unlike TileGemm224::gemm there is no unreachable
      // else-branch guard here; benign, since phase is already restricted
      // to QK/PV by the lambdas above.
      b_tile_2 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
      b_tile_3 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
      b_tile_4 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
      b_tile_5 += 2 * AMX_TILE_BYTES / sizeof(c10::BFloat16);
    }

    // Odd trailing k-step uses only tile set (0, 2, 3).
    if (has_tail) {
      _tile_loadd(0, a_tile_0, a_tile_stride);
      _tile_stream_loadd(2, b_tile_2, b_stride);
      _tile_dpbf16ps(6, 0, 2);
      _tile_stream_loadd(3, b_tile_3, b_stride);
      _tile_dpbf16ps(7, 0, 3);
    }

    _tile_stored(6, c_tile_6, c_stride);
    _tile_stored(7, c_tile_7, c_stride);
  }

  // Sets per-tile row counts for this m (<= 16) and loads the config.
  // Column widths (colsb) are configured once elsewhere.
  FORCE_INLINE static void init_tile_config(int32_t m, __tilecfg& config) {
    config.rows[0] = m;
    config.rows[1] = m;
    config.rows[2] = AMX_TILE_ROW_NUM;
    config.rows[3] = AMX_TILE_ROW_NUM;
    config.rows[4] = AMX_TILE_ROW_NUM;
    config.rows[5] = AMX_TILE_ROW_NUM;
    config.rows[6] = m;
    config.rows[7] = m;
    _tile_loadconfig(&config);
  }
};
|
||||
} // namespace
|
||||
|
||||
template <typename scalar_t, int64_t head_dim>
|
||||
class AttentionImpl<ISA::AMX, scalar_t, head_dim> {
|
||||
public:
|
||||
using query_t = scalar_t;
|
||||
using q_buffer_t = scalar_t;
|
||||
using kv_cache_t = scalar_t;
|
||||
using logits_buffer_t = float;
|
||||
using partial_output_buffer_t = float;
|
||||
using prob_buffer_t = scalar_t;
|
||||
|
||||
constexpr static int64_t BlockSizeAlignment =
|
||||
AMX_TILE_ROW_BYTES /
|
||||
sizeof(kv_cache_t); // KV token num unit of QK and PV phases
|
||||
constexpr static int64_t HeadDimAlignment =
|
||||
2 * (AMX_TILE_ROW_BYTES / 4); // headdim num unit of PV phase
|
||||
constexpr static int64_t MaxQHeadNumPerIteration = 32;
|
||||
constexpr static int64_t HeadDim = head_dim;
|
||||
constexpr static ISA ISAType = ISA::AMX;
|
||||
constexpr static bool scale_on_logits = true;
|
||||
|
||||
public:
|
||||
AttentionImpl() : current_q_head_num_(0) {
|
||||
// Use all columns in AMX tiles
|
||||
vec_op::unroll_loop<int, 8>([&](int i) { amx_tile_config_.colsb[i] = 64; });
|
||||
}
|
||||
|
||||
~AttentionImpl() { _tile_release(); }
|
||||
|
||||
template <template <typename tile_gemm_t> typename attention>
|
||||
FORCE_INLINE void execute_attention(DEFINE_CPU_ATTENTION_PARAMS) {
|
||||
if (q_head_num > AMX_TILE_ROW_NUM) {
|
||||
if (q_head_num != current_q_head_num_) {
|
||||
current_q_head_num_ = q_head_num;
|
||||
TileGemm224<kv_cache_t>::init_tile_config(q_head_num, amx_tile_config_);
|
||||
}
|
||||
attention<TileGemm224<kv_cache_t>> attention_iteration;
|
||||
attention_iteration(CPU_ATTENTION_PARAMS);
|
||||
} else {
|
||||
if (q_head_num != current_q_head_num_) {
|
||||
current_q_head_num_ = q_head_num;
|
||||
TileGemm122<kv_cache_t>::init_tile_config(q_head_num, amx_tile_config_);
|
||||
}
|
||||
attention<TileGemm122<kv_cache_t>> attention_iteration;
|
||||
attention_iteration(CPU_ATTENTION_PARAMS);
|
||||
}
|
||||
}
|
||||
|
||||
// k_cache_token_group_stride: stride of K cache when move to next
|
||||
// BlockSizeAlignment tokens in a block
|
||||
constexpr static int64_t k_cache_token_group_stride(
|
||||
const int32_t block_size) {
|
||||
return BlockSizeAlignment * head_dim;
|
||||
}
|
||||
|
||||
// v_cache_token_group_stride: stride of V cache when move to next
|
||||
// BlockSizeAlignment tokens in a block
|
||||
constexpr static int64_t v_cache_token_group_stride(
|
||||
const int32_t block_size) {
|
||||
return BlockSizeAlignment * (AMX_TILE_ROW_BYTES / 4);
|
||||
}
|
||||
|
||||
// v_cache_head_group_stride: stride of V cache when move to next
|
||||
// HeadDimAlignment head dims in a block
|
||||
constexpr static int64_t v_cache_head_group_stride(const int32_t block_size) {
|
||||
return block_size * HeadDimAlignment;
|
||||
}
|
||||
|
||||
static void copy_q_heads_tile(
|
||||
scalar_t* __restrict__ src, // [q_num, q_heads_per_kv, head_size]
|
||||
scalar_t* __restrict__ q_buffer, const int32_t q_num,
|
||||
const int32_t q_heads_per_kv, const int64_t q_num_stride,
|
||||
const int64_t q_head_stride, const float scale) {
|
||||
constexpr int64_t bytes_per_head = head_dim * sizeof(scalar_t);
|
||||
static_assert(bytes_per_head % AMX_TILE_ROW_BYTES == 0);
|
||||
constexpr int64_t head_size_block_num = bytes_per_head / AMX_TILE_ROW_BYTES;
|
||||
constexpr int64_t head_elem_num_pre_block =
|
||||
AMX_TILE_ROW_BYTES / sizeof(scalar_t);
|
||||
|
||||
int32_t idx = 0;
|
||||
int8_t* __restrict__ q_buffer_iter = reinterpret_cast<int8_t*>(q_buffer);
|
||||
for (int32_t q_num_idx = 0; q_num_idx < q_num;
|
||||
++q_num_idx, src += q_num_stride) {
|
||||
scalar_t* __restrict__ src_iter = src;
|
||||
for (int32_t q_head_idx = 0; q_head_idx < q_heads_per_kv;
|
||||
++q_head_idx, src_iter += q_head_stride) {
|
||||
vec_op::unroll_loop<int32_t, head_size_block_num>(
|
||||
[&](int32_t head_size_block_idx) {
|
||||
// Use INT8Vec64 for 64 bytes block
|
||||
vec_op::INT8Vec64 vec(src_iter + head_size_block_idx *
|
||||
head_elem_num_pre_block);
|
||||
vec.save(q_buffer_iter + head_size_block_idx * AMX_TILE_BYTES);
|
||||
});
|
||||
|
||||
++idx;
|
||||
q_buffer_iter += AMX_TILE_ROW_BYTES;
|
||||
if ((idx & (AMX_TILE_ROW_NUM - 1)) == 0) {
|
||||
// head is in another amx tile
|
||||
q_buffer_iter -= AMX_TILE_ROW_NUM * AMX_TILE_ROW_BYTES;
|
||||
q_buffer_iter += head_size_block_num * AMX_TILE_BYTES;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reshape KV to AMX friendly layout
|
||||
static void reshape_and_cache(
|
||||
const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
|
||||
scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
|
||||
const int64_t* __restrict__ slot_mapping, const int64_t token_num,
|
||||
const int64_t key_token_num_stride, const int64_t value_token_num_stride,
|
||||
const int64_t head_num, const int64_t key_head_num_stride,
|
||||
const int64_t value_head_num_stride, const int64_t num_blocks,
|
||||
const int64_t num_blocks_stride, const int64_t cache_head_num_stride,
|
||||
const int64_t block_size, const int64_t block_size_stride) {
|
||||
// For AMX 2D tiles, size of each line is 64 bytes
|
||||
constexpr int64_t amx_tile_row_size = AMX_TILE_ROW_BYTES;
|
||||
// For AMX B martix, N always is 16
|
||||
constexpr int64_t amx_b_tile_n_size = AMX_TILE_ROW_BYTES / 4;
|
||||
constexpr int64_t amx_b_tile_k_size = amx_tile_row_size / sizeof(scalar_t);
|
||||
// For now suppose block_size is divisible by amx_tile_column_num
|
||||
TORCH_CHECK_EQ(block_size % amx_b_tile_k_size, 0);
|
||||
|
||||
#pragma omp parallel for collapse(2)
|
||||
for (int64_t token_idx = 0; token_idx < token_num; ++token_idx) {
|
||||
for (int64_t head_idx = 0; head_idx < head_num; ++head_idx) {
|
||||
const int64_t pos = slot_mapping[token_idx];
|
||||
if (pos < 0) {
|
||||
// skip
|
||||
continue;
|
||||
}
|
||||
|
||||
const int64_t block_idx = pos / block_size;
|
||||
const int64_t block_offset = pos % block_size;
|
||||
{
|
||||
// Write Key
|
||||
// Head elements should be packed as quand-words and stored in token
|
||||
// groups with (quadword_stride/4) tokens
|
||||
constexpr int64_t token_num_per_group = amx_tile_row_size / 4;
|
||||
static_assert(head_dim % (4 / sizeof(scalar_t)) == 0);
|
||||
constexpr int64_t quadword_num = head_dim / (4 / sizeof(scalar_t));
|
||||
const int32_t* key_start_quadword_ptr =
|
||||
reinterpret_cast<const int32_t*>(
|
||||
key + token_idx * key_token_num_stride +
|
||||
head_idx * key_head_num_stride);
|
||||
const int64_t group_idx = block_offset / token_num_per_group;
|
||||
const int64_t group_offset = block_offset % token_num_per_group;
|
||||
constexpr int64_t quadword_num_per_group =
|
||||
token_num_per_group * quadword_num;
|
||||
int32_t* key_cache_start_ptr =
|
||||
reinterpret_cast<int32_t*>(key_cache +
|
||||
block_idx * num_blocks_stride +
|
||||
head_idx * cache_head_num_stride) +
|
||||
group_idx * quadword_num_per_group + group_offset;
|
||||
|
||||
#pragma GCC unroll 8
|
||||
for (int64_t i = 0, j = 0; j < quadword_num;
|
||||
i += token_num_per_group, ++j) {
|
||||
key_cache_start_ptr[i] = key_start_quadword_ptr[j];
|
||||
}
|
||||
}
|
||||
{
|
||||
// Write Value
|
||||
// Different from Key, block_size dimension is packed rather than
|
||||
// head_size dimension block_size dimension is packed as quand-words;
|
||||
constexpr int64_t token_num_per_sub_group = 4 / sizeof(scalar_t);
|
||||
const int64_t token_num_per_group = block_size;
|
||||
constexpr int64_t head_elems_per_group = amx_b_tile_n_size;
|
||||
const int64_t group_size = token_num_per_group * head_elems_per_group;
|
||||
// For now suppose head_dim is divisible by amx_b_tile_n_size
|
||||
static_assert(head_dim % head_elems_per_group == 0);
|
||||
constexpr int64_t group_num = head_dim / head_elems_per_group;
|
||||
const int64_t sub_group_idx = block_offset / token_num_per_sub_group;
|
||||
const int64_t sub_group_offset =
|
||||
block_offset % token_num_per_sub_group;
|
||||
|
||||
const scalar_t* value_start_ptr = value +
|
||||
token_idx * value_token_num_stride +
|
||||
head_idx * value_head_num_stride;
|
||||
scalar_t* value_cache_start_ptr =
|
||||
value_cache + block_idx * num_blocks_stride +
|
||||
head_idx * cache_head_num_stride +
|
||||
sub_group_idx * token_num_per_sub_group * amx_b_tile_n_size +
|
||||
sub_group_offset;
|
||||
|
||||
for (int64_t i = 0; i < group_num; ++i) {
|
||||
#pragma GCC unroll head_elems_per_group
|
||||
for (int64_t j = 0, k = 0; j < head_elems_per_group;
|
||||
++j, k += token_num_per_sub_group) {
|
||||
value_cache_start_ptr[k] = value_start_ptr[j];
|
||||
}
|
||||
value_start_ptr += head_elems_per_group;
|
||||
value_cache_start_ptr += group_size;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
alignas(64) __tilecfg amx_tile_config_;
|
||||
int32_t current_q_head_num_;
|
||||
};
|
||||
} // namespace cpu_attention
|
||||
|
||||
#endif
|
||||
1977
csrc/cpu/cpu_attn_impl.hpp
Normal file
1977
csrc/cpu/cpu_attn_impl.hpp
Normal file
File diff suppressed because it is too large
Load Diff
63
csrc/cpu/cpu_attn_macros.h
Normal file
63
csrc/cpu/cpu_attn_macros.h
Normal file
@@ -0,0 +1,63 @@
|
||||
#ifndef CPU_ATTN_MACROS_H
|
||||
#define CPU_ATTN_MACROS_H
|
||||
|
||||
// x86_64
|
||||
#ifdef __x86_64__
|
||||
#define FAST_SPINNING _mm_pause();
|
||||
|
||||
#ifdef __AVX512F__
|
||||
#define DEFINE_FAST_EXP \
|
||||
const __m512 vec_factorial_1 = _mm512_set1_ps(0.999999701f); \
|
||||
const __m512 vec_factorial_2 = _mm512_set1_ps(0.499991506f); \
|
||||
const __m512 vec_factorial_3 = _mm512_set1_ps(0.166676521f); \
|
||||
const __m512 vec_factorial_4 = _mm512_set1_ps(0.0418978221f); \
|
||||
const __m512 vec_factorial_5 = _mm512_set1_ps(0.00828929059f); \
|
||||
const __m512 vec_exp_log2ef = \
|
||||
_mm512_castsi512_ps(_mm512_set1_epi32(0x3fb8aa3b)); \
|
||||
const __m512 vec_half = _mm512_set1_ps(0.5f); \
|
||||
const __m512 vec_one = _mm512_set1_ps(1.f); \
|
||||
const __m512 vec_zero = _mm512_set1_ps(0.f); \
|
||||
const __m512 vec_two = _mm512_set1_ps(2.f); \
|
||||
const __m512 vec_ln2f = \
|
||||
_mm512_castsi512_ps(_mm512_set1_epi32(0x3f317218)); \
|
||||
const __m512 vec_ln_flt_min = \
|
||||
_mm512_castsi512_ps(_mm512_set1_epi32(0xc2aeac50)); \
|
||||
const __m512 vec_ln_flt_max = \
|
||||
_mm512_castsi512_ps(_mm512_set1_epi32(0x42b17218)); \
|
||||
const __m512i vec_127 = _mm512_set1_epi32(0x0000007f); \
|
||||
const int n_mantissa_bits = 23; \
|
||||
auto fast_exp = [&](vec_op::FP32Vec16& vec) __attribute__(( \
|
||||
always_inline)) { \
|
||||
__m512 values = vec.reg; \
|
||||
auto less_ln_flt_min_mask = \
|
||||
_mm512_cmp_ps_mask(values, vec_ln_flt_min, 1 /*_CMP_LT_OS*/); \
|
||||
auto vec_src = _mm512_min_ps(values, vec_ln_flt_max); \
|
||||
vec_src = _mm512_max_ps(vec_src, vec_ln_flt_min); \
|
||||
auto vec_fx = _mm512_fmadd_ps(vec_src, vec_exp_log2ef, vec_half); \
|
||||
auto vec_fx_i = _mm512_cvt_roundps_epi32( \
|
||||
vec_fx, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC); \
|
||||
vec_fx = _mm512_cvtepi32_ps(vec_fx_i); \
|
||||
auto vec_exp_poly = _mm512_fnmadd_ps(vec_fx, vec_ln2f, vec_src); \
|
||||
auto vec_res = \
|
||||
_mm512_fmadd_ps(vec_exp_poly, vec_factorial_5, vec_factorial_4); \
|
||||
vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_3); \
|
||||
vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_2); \
|
||||
vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_factorial_1); \
|
||||
vec_res = _mm512_fmadd_ps(vec_exp_poly, vec_res, vec_one); \
|
||||
auto vec_exp_number = _mm512_sub_ps(vec_fx, vec_one); \
|
||||
auto vec_exp_number_i = _mm512_cvtps_epi32(vec_exp_number); \
|
||||
auto vec_two_pow_n_i = _mm512_add_epi32(vec_exp_number_i, vec_127); \
|
||||
vec_two_pow_n_i = _mm512_slli_epi32(vec_two_pow_n_i, n_mantissa_bits); \
|
||||
auto vec_two_pow_n = _mm512_castsi512_ps(vec_two_pow_n_i); \
|
||||
vec_two_pow_n = _mm512_mask_blend_ps(less_ln_flt_min_mask, \
|
||||
vec_two_pow_n, vec_zero); \
|
||||
vec_res = _mm512_mul_ps(vec_res, vec_two_pow_n); \
|
||||
vec_res = _mm512_mul_ps(vec_res, vec_two); \
|
||||
vec_op::FP32Vec16 res(vec_res); \
|
||||
return res; \
|
||||
};
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
||||
248
csrc/cpu/cpu_attn_vec.hpp
Normal file
248
csrc/cpu/cpu_attn_vec.hpp
Normal file
@@ -0,0 +1,248 @@
|
||||
#ifndef CPU_ATTN_VEC_HPP
|
||||
#define CPU_ATTN_VEC_HPP
|
||||
|
||||
#include "cpu_attn_impl.hpp"
|
||||
|
||||
namespace cpu_attention {
|
||||
|
||||
namespace {
|
||||
// 8-2-16 pattern, 8 regs for A, 2 regs for B, 16 regs for C, [8, K] @ [k, 32]
|
||||
template <typename kv_cache_t>
|
||||
class TileGemm82 {
|
||||
public:
|
||||
template <AttentionGemmPhase phase, int32_t k_size>
|
||||
FORCE_INLINE static void gemm(const int32_t m_size,
|
||||
float* __restrict__ a_tile,
|
||||
kv_cache_t* __restrict__ b_tile,
|
||||
float* __restrict__ c_tile, const int64_t lda,
|
||||
const int64_t ldb, const int64_t ldc,
|
||||
const int32_t block_size,
|
||||
const int32_t dynamic_k_size,
|
||||
const bool accum_c) {
|
||||
switch (m_size) {
|
||||
case 1:
|
||||
gemm_micro<1>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
|
||||
dynamic_k_size, accum_c);
|
||||
break;
|
||||
case 2:
|
||||
gemm_micro<2>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
|
||||
dynamic_k_size, accum_c);
|
||||
break;
|
||||
case 3:
|
||||
case 4:
|
||||
gemm_micro<4>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
|
||||
dynamic_k_size, accum_c);
|
||||
break;
|
||||
case 5:
|
||||
case 6:
|
||||
gemm_micro<6>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
|
||||
dynamic_k_size, accum_c);
|
||||
break;
|
||||
case 7:
|
||||
case 8:
|
||||
gemm_micro<8>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
|
||||
dynamic_k_size, accum_c);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template <int32_t M>
|
||||
static void gemm_micro(float* __restrict__ a_tile,
|
||||
kv_cache_t* __restrict__ b_tile,
|
||||
float* __restrict__ c_tile, const int64_t lda,
|
||||
const int64_t ldb, const int64_t ldc,
|
||||
const int32_t block_size, const int32_t dynamic_k_size,
|
||||
const bool accum_c) {
|
||||
static_assert(0 < M <= 8);
|
||||
using load_vec_t = typename VecTypeTrait<kv_cache_t>::vec_t;
|
||||
|
||||
kv_cache_t* __restrict__ curr_b_0 = b_tile;
|
||||
kv_cache_t* __restrict__ curr_b_1 = b_tile + 16;
|
||||
float* __restrict__ curr_c_0 = c_tile;
|
||||
float* __restrict__ curr_c_1 = c_tile + 16;
|
||||
|
||||
vec_op::FP32Vec16 c_regs[M * 2];
|
||||
if (accum_c) {
|
||||
float* __restrict__ curr_m_c_0 = curr_c_0;
|
||||
float* __restrict__ curr_m_c_1 = curr_c_1;
|
||||
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
|
||||
c_regs[i * 2] = vec_op::FP32Vec16(curr_m_c_0);
|
||||
c_regs[i * 2 + 1] = vec_op::FP32Vec16(curr_m_c_1);
|
||||
|
||||
// update
|
||||
curr_m_c_0 += ldc;
|
||||
curr_m_c_1 += ldc;
|
||||
});
|
||||
}
|
||||
|
||||
float* __restrict__ curr_a = a_tile;
|
||||
for (int32_t k = 0; k < dynamic_k_size; ++k) {
|
||||
load_vec_t b_0_reg(curr_b_0);
|
||||
vec_op::FP32Vec16 fp32_b_0_reg(b_0_reg);
|
||||
load_vec_t b_1_reg(curr_b_1);
|
||||
vec_op::FP32Vec16 fp32_b_1_reg(b_1_reg);
|
||||
|
||||
float* __restrict__ curr_m_a = curr_a;
|
||||
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
|
||||
float v = *curr_m_a;
|
||||
vec_op::FP32Vec16 a_reg(v);
|
||||
c_regs[i * 2] = c_regs[i * 2] + a_reg * fp32_b_0_reg;
|
||||
c_regs[i * 2 + 1] = c_regs[i * 2 + 1] + a_reg * fp32_b_1_reg;
|
||||
|
||||
// update
|
||||
curr_m_a += lda;
|
||||
});
|
||||
|
||||
// update
|
||||
curr_a += 1;
|
||||
curr_b_0 += ldb;
|
||||
curr_b_1 += ldb;
|
||||
}
|
||||
|
||||
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
|
||||
c_regs[i * 2].save(curr_c_0);
|
||||
c_regs[i * 2 + 1].save(curr_c_1);
|
||||
|
||||
// update
|
||||
curr_c_0 += ldc;
|
||||
curr_c_1 += ldc;
|
||||
});
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// This is a general but naive implementation based on vector instructions
|
||||
template <typename scalar_t, int64_t head_dim>
|
||||
class AttentionImpl<ISA::VEC, scalar_t, head_dim> {
|
||||
public:
|
||||
using query_t = scalar_t;
|
||||
using q_buffer_t = float;
|
||||
using kv_cache_t = scalar_t;
|
||||
using logits_buffer_t = float;
|
||||
using partial_output_buffer_t = float;
|
||||
using prob_buffer_t = float;
|
||||
|
||||
constexpr static int64_t BlockSizeAlignment =
|
||||
32; // KV token num unit of QK and PV phases
|
||||
constexpr static int64_t HeadDimAlignment =
|
||||
32; // headdim num unit of PV phase
|
||||
constexpr static int64_t MaxQHeadNumPerIteration = 8;
|
||||
constexpr static int64_t HeadDim = head_dim;
|
||||
constexpr static ISA ISAType = ISA::VEC;
|
||||
constexpr static bool scale_on_logits = false; // apply scale on q_buffer
|
||||
|
||||
public:
|
||||
template <template <typename tile_gemm_t> typename attention>
|
||||
FORCE_INLINE void execute_attention(DEFINE_CPU_ATTENTION_PARAMS) {
|
||||
attention<TileGemm82<kv_cache_t>> attention_iteration;
|
||||
attention_iteration(CPU_ATTENTION_PARAMS);
|
||||
}
|
||||
|
||||
// k_cache_token_group_stride: stride of K cache when move to next
|
||||
// BlockSizeAlignment tokens in a block
|
||||
constexpr static int64_t k_cache_token_group_stride(
|
||||
const int32_t block_size) {
|
||||
return BlockSizeAlignment; // layout of k_cache block is [head_dim,
|
||||
// block_size], row-major
|
||||
}
|
||||
|
||||
// v_cache_token_group_stride: stride of V cache when move to next
|
||||
// BlockSizeAlignment tokens in a block
|
||||
constexpr static int64_t v_cache_token_group_stride(
|
||||
const int32_t block_size) {
|
||||
return head_dim * BlockSizeAlignment; // layout of v_cache is [block_size,
|
||||
// head_dim], row-major
|
||||
}
|
||||
|
||||
// v_cache_head_group_stride: stride of V cache when move to next
|
||||
// HeadDimAlignment head dims in a block
|
||||
constexpr static int64_t v_cache_head_group_stride(const int32_t block_size) {
|
||||
return HeadDimAlignment; // layout of v_cache is [block_size, head_dim],
|
||||
// row-major
|
||||
}
|
||||
|
||||
// Copy q to q_buffer and cast it to fp32
|
||||
static void copy_q_heads_tile(
|
||||
scalar_t* __restrict__ src, // [q_num, q_heads_per_kv, head_size]
|
||||
float* __restrict__ q_buffer, const int32_t q_num,
|
||||
const int32_t q_heads_per_kv, const int64_t q_num_stride,
|
||||
const int64_t q_head_stride, float scale) {
|
||||
static_assert(head_dim % 16 == 0);
|
||||
constexpr int32_t unroll_size = head_dim / 16;
|
||||
using load_vec_t = typename VecTypeTrait<scalar_t>::vec_t;
|
||||
|
||||
vec_op::FP32Vec16 scale_vec(scale);
|
||||
for (int32_t q_num_idx = 0; q_num_idx < q_num; ++q_num_idx) {
|
||||
for (int32_t q_head_idx = 0; q_head_idx < q_heads_per_kv; ++q_head_idx) {
|
||||
scalar_t* __restrict__ curr_q =
|
||||
src + q_num_idx * q_num_stride + q_head_idx * q_head_stride;
|
||||
float* __restrict__ curr_q_buffer =
|
||||
q_buffer + q_num_idx * q_heads_per_kv * head_dim +
|
||||
q_head_idx * head_dim;
|
||||
|
||||
vec_op::unroll_loop<int32_t, unroll_size>([&](int32_t i) {
|
||||
load_vec_t vec(curr_q);
|
||||
vec_op::FP32Vec16 fp32_vec(vec);
|
||||
fp32_vec = fp32_vec * scale_vec;
|
||||
fp32_vec.save(curr_q_buffer);
|
||||
|
||||
curr_q += 16;
|
||||
curr_q_buffer += 16;
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reshape K as column-major and V as row-major
|
||||
static void reshape_and_cache(
|
||||
const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
|
||||
scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
|
||||
const int64_t* __restrict__ slot_mapping, const int64_t token_num,
|
||||
const int64_t key_token_num_stride, const int64_t value_token_num_stride,
|
||||
const int64_t head_num, const int64_t key_head_num_stride,
|
||||
const int64_t value_head_num_stride, const int64_t num_blocks,
|
||||
const int64_t num_blocks_stride, const int64_t cache_head_num_stride,
|
||||
const int64_t block_size, const int64_t block_size_stride) {
|
||||
#pragma omp parallel for collapse(2)
|
||||
for (int64_t token_idx = 0; token_idx < token_num; ++token_idx) {
|
||||
for (int64_t head_idx = 0; head_idx < head_num; ++head_idx) {
|
||||
const int64_t pos = slot_mapping[token_idx];
|
||||
if (pos < 0) {
|
||||
// skip
|
||||
continue;
|
||||
}
|
||||
|
||||
const int64_t block_idx = pos / block_size;
|
||||
const int64_t block_offset = pos % block_size;
|
||||
{
|
||||
// Write Key as column-major
|
||||
const scalar_t* key_start_ptr = key +
|
||||
token_idx * key_token_num_stride +
|
||||
head_idx * key_head_num_stride;
|
||||
scalar_t* key_cache_start_ptr =
|
||||
key_cache + block_idx * num_blocks_stride +
|
||||
head_idx * cache_head_num_stride + block_offset;
|
||||
|
||||
#pragma GCC unroll 8
|
||||
for (int64_t i = 0, j = 0; i < head_dim; ++i, j += block_size) {
|
||||
key_cache_start_ptr[j] = key_start_ptr[i];
|
||||
}
|
||||
}
|
||||
{
|
||||
// Write Value as row-major
|
||||
const scalar_t* value_start_ptr = value +
|
||||
token_idx * value_token_num_stride +
|
||||
head_idx * value_head_num_stride;
|
||||
scalar_t* value_cache_start_ptr =
|
||||
value_cache + block_idx * num_blocks_stride +
|
||||
head_idx * cache_head_num_stride + block_offset * head_dim;
|
||||
std::memcpy(value_cache_start_ptr, value_start_ptr,
|
||||
sizeof(scalar_t) * head_dim);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace cpu_attention
|
||||
|
||||
#endif
|
||||
171
csrc/cpu/cpu_attn_vec16.hpp
Normal file
171
csrc/cpu/cpu_attn_vec16.hpp
Normal file
@@ -0,0 +1,171 @@
|
||||
#ifndef CPU_ATTN_VEC16_HPP
|
||||
#define CPU_ATTN_VEC16_HPP
|
||||
|
||||
#include "cpu_attn_vec.hpp"
|
||||
|
||||
namespace cpu_attention {
|
||||
|
||||
namespace {
|
||||
// 16-1-16 pattern, 16 regs for A, 1 regs for B, 16 regs for C, [16, K] @ [k,
|
||||
// 16]
|
||||
template <typename kv_cache_t>
|
||||
class TileGemm161 {
|
||||
public:
|
||||
template <AttentionGemmPhase phase, int32_t k_size>
|
||||
FORCE_INLINE static void gemm(const int32_t m_size,
|
||||
float* __restrict__ a_tile,
|
||||
kv_cache_t* __restrict__ b_tile,
|
||||
float* __restrict__ c_tile, const int64_t lda,
|
||||
const int64_t ldb, const int64_t ldc,
|
||||
const int32_t block_size,
|
||||
const int32_t dynamic_k_size,
|
||||
const bool accum_c) {
|
||||
switch (m_size) {
|
||||
case 1:
|
||||
gemm_micro<1>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
|
||||
dynamic_k_size, accum_c);
|
||||
break;
|
||||
case 2:
|
||||
gemm_micro<2>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
|
||||
dynamic_k_size, accum_c);
|
||||
break;
|
||||
case 3:
|
||||
case 4:
|
||||
gemm_micro<4>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
|
||||
dynamic_k_size, accum_c);
|
||||
break;
|
||||
case 5:
|
||||
case 6:
|
||||
gemm_micro<6>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
|
||||
dynamic_k_size, accum_c);
|
||||
break;
|
||||
case 7:
|
||||
case 8:
|
||||
gemm_micro<8>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
|
||||
dynamic_k_size, accum_c);
|
||||
break;
|
||||
case 9:
|
||||
case 10:
|
||||
case 11:
|
||||
case 12:
|
||||
gemm_micro<12>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
|
||||
dynamic_k_size, accum_c);
|
||||
break;
|
||||
case 13:
|
||||
case 14:
|
||||
case 15:
|
||||
case 16:
|
||||
gemm_micro<16>(a_tile, b_tile, c_tile, lda, ldb, ldc, block_size,
|
||||
dynamic_k_size, accum_c);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template <int32_t M>
|
||||
static void gemm_micro(float* __restrict__ a_tile,
|
||||
kv_cache_t* __restrict__ b_tile,
|
||||
float* __restrict__ c_tile, const int64_t lda,
|
||||
const int64_t ldb, const int64_t ldc,
|
||||
const int32_t block_size, const int32_t dynamic_k_size,
|
||||
const bool accum_c) {
|
||||
static_assert(0 < M <= 16);
|
||||
using load_vec_t = typename VecTypeTrait<kv_cache_t>::vec_t;
|
||||
|
||||
kv_cache_t* __restrict__ curr_b_0 = b_tile;
|
||||
float* __restrict__ curr_c_0 = c_tile;
|
||||
|
||||
vec_op::FP32Vec16 c_regs[M];
|
||||
if (accum_c) {
|
||||
float* __restrict__ curr_m_c_0 = curr_c_0;
|
||||
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
|
||||
c_regs[i] = vec_op::FP32Vec16(curr_m_c_0);
|
||||
|
||||
// update
|
||||
curr_m_c_0 += ldc;
|
||||
});
|
||||
}
|
||||
|
||||
float* __restrict__ curr_a = a_tile;
|
||||
for (int32_t k = 0; k < dynamic_k_size; ++k) {
|
||||
load_vec_t b_0_reg(curr_b_0);
|
||||
vec_op::FP32Vec16 fp32_b_0_reg(b_0_reg);
|
||||
|
||||
float* __restrict__ curr_m_a = curr_a;
|
||||
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
|
||||
float v = *curr_m_a;
|
||||
vec_op::FP32Vec16 a_reg(v);
|
||||
c_regs[i] = c_regs[i] + a_reg * fp32_b_0_reg;
|
||||
|
||||
// update
|
||||
curr_m_a += lda;
|
||||
});
|
||||
|
||||
// update
|
||||
curr_a += 1;
|
||||
curr_b_0 += ldb;
|
||||
}
|
||||
|
||||
vec_op::unroll_loop<int32_t, M>([&](int32_t i) {
|
||||
c_regs[i].save(curr_c_0);
|
||||
|
||||
// update
|
||||
curr_c_0 += ldc;
|
||||
});
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// This is a general but naive implementation based on vector instructions
|
||||
template <typename scalar_t, int64_t head_dim>
|
||||
class AttentionImpl<ISA::VEC16, scalar_t, head_dim>
|
||||
: public AttentionImpl<ISA::VEC, scalar_t, head_dim> {
|
||||
public:
|
||||
using query_t = scalar_t;
|
||||
using q_buffer_t = float;
|
||||
using kv_cache_t = scalar_t;
|
||||
using logits_buffer_t = float;
|
||||
using partial_output_buffer_t = float;
|
||||
using prob_buffer_t = float;
|
||||
|
||||
constexpr static int64_t BlockSizeAlignment =
|
||||
16; // KV token num unit of QK and PV phases
|
||||
constexpr static int64_t HeadDimAlignment =
|
||||
16; // headdim num unit of PV phase
|
||||
constexpr static int64_t MaxQHeadNumPerIteration = 16;
|
||||
constexpr static int64_t HeadDim = head_dim;
|
||||
constexpr static ISA ISAType = ISA::VEC16;
|
||||
constexpr static bool scale_on_logits = false; // apply scale on q_buffer
|
||||
|
||||
public:
|
||||
template <template <typename tile_gemm_t> typename attention>
|
||||
FORCE_INLINE void execute_attention(DEFINE_CPU_ATTENTION_PARAMS) {
|
||||
attention<TileGemm161<kv_cache_t>> attention_iteration;
|
||||
attention_iteration(CPU_ATTENTION_PARAMS);
|
||||
}
|
||||
|
||||
// k_cache_token_group_stride: stride of K cache when move to next
|
||||
// BlockSizeAlignment tokens in a block
|
||||
constexpr static int64_t k_cache_token_group_stride(
|
||||
const int32_t block_size) {
|
||||
return BlockSizeAlignment; // layout of k_cache block is [head_dim,
|
||||
// block_size], row-major
|
||||
}
|
||||
|
||||
// v_cache_token_group_stride: stride of V cache when move to next
|
||||
// BlockSizeAlignment tokens in a block
|
||||
constexpr static int64_t v_cache_token_group_stride(
|
||||
const int32_t block_size) {
|
||||
return head_dim * BlockSizeAlignment; // layout of v_cache is [block_size,
|
||||
// head_dim], row-major
|
||||
}
|
||||
|
||||
// v_cache_head_group_stride: stride of V cache when move to next
|
||||
// HeadDimAlignment head dims in a block
|
||||
constexpr static int64_t v_cache_head_group_stride(const int32_t block_size) {
|
||||
return HeadDimAlignment; // layout of v_cache is [block_size, head_dim],
|
||||
// row-major
|
||||
}
|
||||
};
|
||||
} // namespace cpu_attention
|
||||
|
||||
#endif
|
||||
@@ -40,6 +40,23 @@ namespace vec_op {
|
||||
|
||||
#define FORCE_INLINE __attribute__((always_inline)) inline
|
||||
|
||||
// Function to get the timestamp using RDTSCP
|
||||
FORCE_INLINE uint64_t bench_timestamp() {
|
||||
unsigned int cycles_low, cycles_high;
|
||||
asm volatile(
|
||||
".intel_syntax noprefix\n\t"
|
||||
"CPUID\n\t" // Serialize instruction stream to ensure previous
|
||||
// instructions complete
|
||||
"RDTSCP\n\t" // Read TSC and core ID
|
||||
"mov %0, edx\n\t" // Store high 32 bits of TSC
|
||||
"mov %1, eax\n\t" // Store low 32 bits of TSC
|
||||
".att_syntax"
|
||||
: "=r"(cycles_high), "=r"(cycles_low)::"rax", "rbx", "rcx",
|
||||
"rdx" // Clobbered registers
|
||||
);
|
||||
return (uint64_t)cycles_high << 32 | cycles_low;
|
||||
}
|
||||
|
||||
namespace {
|
||||
template <typename T, T... indexes, typename F>
|
||||
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
|
||||
@@ -407,6 +424,8 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
|
||||
float reduce_min() const { return _mm512_reduce_min_ps(reg); }
|
||||
|
||||
float get_last_elem() const { return _mm512_cvtss_f32(reg); }
|
||||
|
||||
template <int group_size>
|
||||
float reduce_sub_sum(int idx) {
|
||||
static_assert(VEC_ELEM_NUM % group_size == 0);
|
||||
@@ -446,9 +465,6 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
|
||||
explicit FP32Vec16(__m256 low, __m256 high) : reg_low(low), reg_high(high) {}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec16& data)
|
||||
: reg_low(data.reg_low), reg_high(data.reg_high) {}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec4& data)
|
||||
: reg_low((__m256)_mm256_inserti128_si256(
|
||||
_mm256_castsi128_si256((__m128i)data.reg), (__m128i)data.reg, 1)),
|
||||
@@ -504,6 +520,32 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
_mm256_div_ps(reg_high, b.reg_high));
|
||||
}
|
||||
|
||||
FP32Vec16 max(const FP32Vec16& b) const {
|
||||
return FP32Vec16(_mm256_max_ps(reg_low, b.reg_low),
|
||||
_mm256_max_ps(reg_high, b.reg_high));
|
||||
}
|
||||
|
||||
float reduce_max() const {
|
||||
__m256 v = _mm256_max_ps(reg_low, reg_high);
|
||||
// Permute to compare elements within 128-bit lanes
|
||||
__m256 v_shuffled = _mm256_permute_ps(
|
||||
v, 0b00001011); // Swap halves within each 128-bit lane
|
||||
__m256 v_max = _mm256_max_ps(v, v_shuffled);
|
||||
|
||||
v_shuffled = _mm256_permute_ps(
|
||||
v_max, 0b00000001); // Shuffle elements within each 128-bit lane
|
||||
v_max = _mm256_max_ps(v_max, v_shuffled);
|
||||
|
||||
// Permute to compare elements between 128-bit lanes
|
||||
v_shuffled =
|
||||
_mm256_permute2f128_ps(v_max, v_max, 0b00000001); // Swap 128-bit lanes
|
||||
v_max = _mm256_max_ps(v_max, v_shuffled);
|
||||
|
||||
// At this point, the maximum value is present in all elements of v_max.
|
||||
// Extract the first element for the scalar result.
|
||||
return _mm256_cvtss_f32(v_max); // Extract the lowest 32-bit float
|
||||
}
|
||||
|
||||
float reduce_sum() const {
|
||||
FP32Vec8 low = FP32Vec8(reg_low);
|
||||
FP32Vec8 high = FP32Vec8(reg_high);
|
||||
@@ -642,7 +684,7 @@ inline FP16Vec16::FP16Vec16(const FP32Vec16& v)
|
||||
inline FP16Vec16::FP16Vec16(const FP32Vec16& v)
|
||||
: reg(_mm256_insertf128_si256(
|
||||
_mm256_castsi128_si256(FP16Vec8(FP32Vec8(v.reg_low)).reg),
|
||||
FP16Vec8(FP32Vec8(v.reg_low)).reg, 1)) {}
|
||||
FP16Vec8(FP32Vec8(v.reg_high)).reg, 1)) {}
|
||||
#endif
|
||||
|
||||
#ifdef __AVX512BF16__
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include "common/memory.hpp"
|
||||
|
||||
#include "dnnl_helper.h"
|
||||
#include "scratchpad_manager.h"
|
||||
|
||||
static dnnl::engine& default_engine() {
|
||||
static dnnl::engine engine(dnnl::engine::kind::cpu, 0);
|
||||
@@ -22,23 +23,6 @@ void release_dnnl_matmul_handler(int64_t handler) {
|
||||
delete ptr;
|
||||
}
|
||||
|
||||
DNNLScratchPadManager::DNNLScratchPadManager() : size_(0), ptr_(nullptr) {
|
||||
this->realloc(allocation_unit * 128);
|
||||
}
|
||||
|
||||
void DNNLScratchPadManager::realloc(size_t new_size) {
|
||||
new_size = round(new_size);
|
||||
if (new_size > size_) {
|
||||
ptr_ = std::aligned_alloc(64, new_size);
|
||||
size_ = new_size;
|
||||
}
|
||||
}
|
||||
|
||||
DNNLScratchPadManager* DNNLScratchPadManager::get_dnnl_scratchpad_manager() {
|
||||
static DNNLScratchPadManager manager;
|
||||
return &manager;
|
||||
}
|
||||
|
||||
template <typename KT, typename VT>
|
||||
class DNNLPrimitiveCache {
|
||||
public:
|
||||
|
||||
@@ -59,30 +59,6 @@ constexpr inline dnnl::memory::data_type get_dnnl_type() {
|
||||
return DNNLType<std::decay_t<T>>::type;
|
||||
}
|
||||
|
||||
class DNNLScratchPadManager {
|
||||
public:
|
||||
static constexpr size_t allocation_unit = 4 * 1024 * 1024; // 4KB
|
||||
|
||||
static DNNLScratchPadManager* get_dnnl_scratchpad_manager();
|
||||
|
||||
DNNLScratchPadManager();
|
||||
|
||||
template <typename T>
|
||||
T* get_data() {
|
||||
return reinterpret_cast<T*>(ptr_);
|
||||
}
|
||||
|
||||
static size_t round(size_t size) {
|
||||
return ((size + allocation_unit - 1) / allocation_unit) * allocation_unit;
|
||||
}
|
||||
|
||||
void realloc(size_t new_size);
|
||||
|
||||
private:
|
||||
size_t size_;
|
||||
void* ptr_;
|
||||
};
|
||||
|
||||
class DNNLMatMulPrimitiveHandler {
|
||||
public:
|
||||
virtual ~DNNLMatMulPrimitiveHandler() = default;
|
||||
|
||||
23
csrc/cpu/scratchpad_manager.cpp
Normal file
23
csrc/cpu/scratchpad_manager.cpp
Normal file
@@ -0,0 +1,23 @@
|
||||
#include <cstdlib>
|
||||
|
||||
#include "scratchpad_manager.h"
|
||||
|
||||
DNNLScratchPadManager::DNNLScratchPadManager() : size_(0), ptr_(nullptr) {
|
||||
this->realloc(allocation_unit * 128);
|
||||
}
|
||||
|
||||
void DNNLScratchPadManager::realloc(size_t new_size) {
|
||||
new_size = round(new_size);
|
||||
if (new_size > size_) {
|
||||
if (ptr_ != nullptr) {
|
||||
std::free(ptr_);
|
||||
}
|
||||
ptr_ = std::aligned_alloc(64, new_size);
|
||||
size_ = new_size;
|
||||
}
|
||||
}
|
||||
|
||||
DNNLScratchPadManager* DNNLScratchPadManager::get_dnnl_scratchpad_manager() {
|
||||
static DNNLScratchPadManager manager;
|
||||
return &manager;
|
||||
}
|
||||
31
csrc/cpu/scratchpad_manager.h
Normal file
31
csrc/cpu/scratchpad_manager.h
Normal file
@@ -0,0 +1,31 @@
|
||||
#ifndef SCRATCHPAD_MANAGER_H
|
||||
#define SCRATCHPAD_MANAGER_H
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdio>
|
||||
|
||||
class DNNLScratchPadManager {
|
||||
public:
|
||||
static constexpr size_t allocation_unit = 4 * 1024; // 4KB
|
||||
|
||||
static DNNLScratchPadManager* get_dnnl_scratchpad_manager();
|
||||
|
||||
DNNLScratchPadManager();
|
||||
|
||||
template <typename T>
|
||||
T* get_data() {
|
||||
return reinterpret_cast<T*>(ptr_);
|
||||
}
|
||||
|
||||
static size_t round(size_t size) {
|
||||
return ((size + allocation_unit - 1) / allocation_unit) * allocation_unit;
|
||||
}
|
||||
|
||||
void realloc(size_t new_size);
|
||||
|
||||
private:
|
||||
size_t size_;
|
||||
void* ptr_;
|
||||
};
|
||||
|
||||
#endif
|
||||
@@ -192,7 +192,7 @@ class SHMManager {
|
||||
const int group_size)
|
||||
: _rank(rank),
|
||||
_group_size(group_size),
|
||||
_thread_num(torch::get_num_threads()),
|
||||
_thread_num(omp_get_max_threads()),
|
||||
_shm_names({""}),
|
||||
_shared_mem_ptrs({nullptr}),
|
||||
_shm_ctx(nullptr) {
|
||||
|
||||
@@ -74,25 +74,35 @@ at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
at::ScalarType out_dtype, bool is_vnni);
|
||||
|
||||
torch::Tensor get_scheduler_metadata(
|
||||
const int64_t num_req, const int64_t num_heads_q,
|
||||
const int64_t num_heads_kv, const int64_t head_dim,
|
||||
const torch::Tensor& seq_lens, at::ScalarType dtype,
|
||||
const torch::Tensor& query_start_loc, const bool casual,
|
||||
const int64_t window_size, const std::string& isa_hint,
|
||||
const bool enable_kv_split);
|
||||
|
||||
void cpu_attn_reshape_and_cache(const torch::Tensor& key,
|
||||
const torch::Tensor& value,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
const torch::Tensor& slot_mapping,
|
||||
const std::string& isa);
|
||||
|
||||
void cpu_attention_with_kv_cache(
|
||||
const torch::Tensor& query, const torch::Tensor& key_cache,
|
||||
const torch::Tensor& value_cache, torch::Tensor& output,
|
||||
const torch::Tensor& query_start_loc, const torch::Tensor& seq_lens,
|
||||
const double scale, const bool causal,
|
||||
const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const int64_t sliding_window_left, const int64_t sliding_window_right,
|
||||
const torch::Tensor& block_table, const double softcap,
|
||||
const torch::Tensor& scheduler_metadata,
|
||||
const std::optional<torch::Tensor>& s_aux);
|
||||
|
||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// vLLM custom ops
|
||||
|
||||
// Attention ops
|
||||
// Compute the attention between an input query and the cached keys/values
|
||||
// using PagedAttention.
|
||||
ops.def(
|
||||
"paged_attention_v1("
|
||||
" Tensor! out, Tensor query, Tensor key_cache,"
|
||||
" Tensor value_cache, int num_kv_heads, float scale,"
|
||||
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
||||
" int max_seq_len, Tensor? alibi_slopes,"
|
||||
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
|
||||
" int tp_rank, int blocksparse_local_blocks,"
|
||||
" int blocksparse_vert_stride, int blocksparse_block_size,"
|
||||
" int blocksparse_head_sliding_step) -> ()");
|
||||
|
||||
ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1);
|
||||
|
||||
ops.def(
|
||||
"dynamic_4bit_int_moe("
|
||||
"Tensor x, Tensor topk_ids, Tensor topk_weights,"
|
||||
@@ -102,20 +112,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
|
||||
ops.impl("dynamic_4bit_int_moe", torch::kCPU, &dynamic_4bit_int_moe_cpu);
|
||||
|
||||
// PagedAttention V2.
|
||||
ops.def(
|
||||
"paged_attention_v2("
|
||||
" Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
|
||||
" Tensor! tmp_out, Tensor query, Tensor key_cache,"
|
||||
" Tensor value_cache, int num_kv_heads, float scale,"
|
||||
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
||||
" int max_seq_len, Tensor? alibi_slopes,"
|
||||
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
|
||||
" int tp_rank, int blocksparse_local_blocks,"
|
||||
" int blocksparse_vert_stride, int blocksparse_block_size,"
|
||||
" int blocksparse_head_sliding_step) -> ()");
|
||||
ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2);
|
||||
|
||||
// Activation ops
|
||||
|
||||
// Activation function used in SwiGLU.
|
||||
@@ -259,37 +255,26 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.impl("int8_scaled_mm_with_quant", torch::kCPU,
|
||||
&int8_scaled_mm_with_quant);
|
||||
#endif
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
// Cache ops
|
||||
// Swap in (out) the cache blocks from src to dst.
|
||||
cache_ops.def(
|
||||
"swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
|
||||
cache_ops.impl("swap_blocks", torch::kCPU, &swap_blocks);
|
||||
|
||||
// Copy the cache blocks from src to dst.
|
||||
cache_ops.def(
|
||||
"copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
|
||||
"Tensor block_mapping) -> ()");
|
||||
cache_ops.impl("copy_blocks", torch::kCPU, ©_blocks);
|
||||
|
||||
// Reshape the key and value tensors and cache them.
|
||||
cache_ops.def(
|
||||
"reshape_and_cache(Tensor key, Tensor value,"
|
||||
" Tensor! key_cache, Tensor! value_cache,"
|
||||
" Tensor slot_mapping,"
|
||||
" str kv_cache_dtype,"
|
||||
" Tensor k_scale, Tensor v_scale) -> ()");
|
||||
cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache);
|
||||
|
||||
cache_ops.def(
|
||||
"concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
|
||||
" Tensor! kv_cache,"
|
||||
" Tensor slot_mapping,"
|
||||
" str kv_cache_dtype,"
|
||||
" Tensor scale) -> ()");
|
||||
cache_ops.impl("concat_and_cache_mla", torch::kCPU, &concat_and_cache_mla);
|
||||
// CPU attention kernels
|
||||
ops.def(
|
||||
"get_scheduler_metadata(int num_req, int num_heads_q, int num_heads_kv, "
|
||||
"int head_dim, Tensor seq_lens, ScalarType dtype, Tensor "
|
||||
"query_start_loc, bool casual, int window_size, str isa_hint, bool "
|
||||
"enable_kv_split) -> Tensor",
|
||||
&get_scheduler_metadata);
|
||||
ops.def(
|
||||
"cpu_attn_reshape_and_cache(Tensor key, Tensor value, Tensor(a2!) "
|
||||
"key_cache, Tensor(a3!) value_cache, Tensor slot_mapping, str "
|
||||
"isa) -> ()",
|
||||
&cpu_attn_reshape_and_cache);
|
||||
ops.def(
|
||||
"cpu_attention_with_kv_cache(Tensor query, Tensor key_cache, Tensor "
|
||||
"value_cache, Tensor(a3!) output, Tensor query_start_loc, Tensor "
|
||||
"seq_lens, float scale, bool causal, Tensor? alibi_slopes, SymInt "
|
||||
"sliding_window_left, SymInt sliding_window_right, Tensor block_table, "
|
||||
"float softcap, Tensor sheduler_metadata, Tensor? s_aux) -> ()",
|
||||
&cpu_attention_with_kv_cache);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
|
||||
|
||||
Reference in New Issue
Block a user