[CI/Build] Enforce style for C++ and CUDA code with clang-format (#4722)

This commit is contained in:
Michael Goin
2024-05-22 03:18:41 -04:00
committed by GitHub
parent 9b9a10d6cb
commit 5f6d10c14c
64 changed files with 6398 additions and 6790 deletions

View File

@@ -2,7 +2,8 @@
namespace {
template <typename scalar_t> struct KernelVecType {
template <typename scalar_t>
struct KernelVecType {
using q_load_vec_type = void;
using q_vec_type = void;
using k_load_vec_type = void;
@@ -11,7 +12,8 @@ template <typename scalar_t> struct KernelVecType {
using v_load_vec_type = void;
};
template <> struct KernelVecType<float> {
template <>
struct KernelVecType<float> {
using q_load_vec_type = vec_op::FP32Vec4;
using q_vec_type = vec_op::FP32Vec16;
using k_load_vec_type = vec_op::FP32Vec16;
@@ -21,7 +23,8 @@ template <> struct KernelVecType<float> {
};
#ifdef __AVX512BF16__
template <> struct KernelVecType<c10::BFloat16> {
template <>
struct KernelVecType<c10::BFloat16> {
using q_load_vec_type = vec_op::BF16Vec8;
using q_vec_type = vec_op::BF16Vec32;
using k_load_vec_type = vec_op::BF16Vec32;
@@ -30,7 +33,8 @@ template <> struct KernelVecType<c10::BFloat16> {
using v_load_vec_type = vec_op::BF16Vec16;
};
#else
template <> struct KernelVecType<c10::BFloat16> {
template <>
struct KernelVecType<c10::BFloat16> {
using q_load_vec_type = vec_op::BF16Vec8;
using q_vec_type = vec_op::FP32Vec16;
using k_load_vec_type = vec_op::BF16Vec16;
@@ -41,7 +45,7 @@ template <> struct KernelVecType<c10::BFloat16> {
#endif
template <typename T>
FORCE_INLINE std::pair<T, T> reduceSoftmax(T *data, const int size,
FORCE_INLINE std::pair<T, T> reduceSoftmax(T* data, const int size,
const int capacity) {
T max = data[0];
for (int i = 1; i < size; ++i) {
@@ -67,10 +71,11 @@ FORCE_INLINE std::pair<T, T> reduceSoftmax(T *data, const int size,
}
template <typename T>
FORCE_INLINE std::pair<T, T>
reduceSoftmaxAlibi(T *data, const int size, const int capacity,
const float alibi_slope, const int start_index,
const int seq_len) {
FORCE_INLINE std::pair<T, T> reduceSoftmaxAlibi(T* data, const int size,
const int capacity,
const float alibi_slope,
const int start_index,
const int seq_len) {
data[0] += alibi_slope * (start_index - seq_len + 1);
T max = data[0];
for (int i = 1; i < size; ++i) {
@@ -98,7 +103,7 @@ reduceSoftmaxAlibi(T *data, const int size, const int capacity,
}
template <typename T>
FORCE_INLINE void reducePartitonSoftmax(const T *max_data, T *sum_data,
FORCE_INLINE void reducePartitonSoftmax(const T* max_data, T* sum_data,
const int size) {
T max = max_data[0];
for (int i = 1; i < size; ++i) {
@@ -132,9 +137,9 @@ struct reduceQKBlockKernel {
static_assert(k_load_vec_type::get_elem_num() % x == 0);
static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16);
FORCE_INLINE static void call(const scalar_t *__restrict__ q,
const scalar_t *__restrict__ k_block,
float *__restrict__ logits, float scale,
FORCE_INLINE static void call(const scalar_t* __restrict__ q,
const scalar_t* __restrict__ k_block,
float* __restrict__ logits, float scale,
const int token_num) {
const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP;
@@ -196,8 +201,8 @@ struct reduceQKBlockKernel {
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE,
int HEAD_PARTITION_SIZE, typename acc_t>
FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block,
acc_t &&acc) {
FORCE_INLINE void reduceValueBlock(const float* prob, const scalar_t* v_block,
acc_t&& acc) {
using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
constexpr int ELEM_NUM = v_load_vec_type::get_elem_num();
static_assert(BLOCK_SIZE == ELEM_NUM);
@@ -209,27 +214,27 @@ FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block,
acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec;
});
}
}; // namespace
}; // namespace
// Paged attention v1
namespace {
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE>
struct paged_attention_v1_impl {
static void
call(scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads,
static void call(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads,
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, const float scale,
const int
*__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int *__restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float *__restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const int num_seqs, const int num_heads) {
const int num_kv_heads, const float scale,
const int* __restrict__ block_tables, // [num_seqs,
// max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const int num_seqs, const int num_heads) {
constexpr int x = 16 / sizeof(scalar_t);
const int num_queries_per_kv = num_heads / num_kv_heads;
@@ -243,32 +248,31 @@ struct paged_attention_v1_impl {
size_t logits_bytes =
parallel_work_item_num * max_seq_len_padded * sizeof(float);
float *logits = (float *)std::aligned_alloc(
64, logits_bytes); // Cacheline alignment for each context token.
// [parallel_work_item_num, max_seq_len_padded]
float* logits = (float*)std::aligned_alloc(
64, logits_bytes); // Cacheline alignment for each context token.
// [parallel_work_item_num, max_seq_len_padded]
#pragma omp parallel for collapse(2) schedule(dynamic, 1)
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
int seq_len = seq_lens[seq_idx];
const int *seq_block_table =
const int* seq_block_table =
block_tables + max_num_blocks_per_seq * seq_idx;
const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int64_t kv_head_idx = head_idx / num_queries_per_kv;
const scalar_t *__restrict__ q_vec_ptr =
const scalar_t* __restrict__ q_vec_ptr =
q + seq_idx * q_stride + head_idx * HEAD_SIZE;
const int last_block_token_num =
seq_len - (block_num - 1) * BLOCK_SIZE;
float *__restrict__ thread_block_logits =
const int last_block_token_num = seq_len - (block_num - 1) * BLOCK_SIZE;
float* __restrict__ thread_block_logits =
logits + omp_get_thread_num() * max_seq_len_padded;
// Compute logits
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const int64_t physical_block_idx = seq_block_table[block_idx];
const scalar_t *__restrict__ k_block_cache_ptr =
const scalar_t* __restrict__ k_block_cache_ptr =
k_cache + physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride;
float *__restrict__ head_block_logits =
float* __restrict__ head_block_logits =
thread_block_logits + block_idx * BLOCK_SIZE;
reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call(
@@ -282,8 +286,7 @@ struct paged_attention_v1_impl {
block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0,
seq_len);
} else {
reduceSoftmax(thread_block_logits, seq_len,
block_num * BLOCK_SIZE);
reduceSoftmax(thread_block_logits, seq_len, block_num * BLOCK_SIZE);
}
// Compute value
@@ -293,14 +296,14 @@ struct paged_attention_v1_impl {
for (int head_part_idx = 0; head_part_idx < head_partition_num;
++head_part_idx) {
vec_op::FP32Vec16 accums[head_elem_num_per_partition];
scalar_t *__restrict__ out_ptr =
scalar_t* __restrict__ out_ptr =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
head_part_idx * head_elem_num_per_partition;
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const int64_t physical_block_idx = seq_block_table[block_idx];
const float *__restrict__ prob_vec_ptr =
const float* __restrict__ prob_vec_ptr =
thread_block_logits + block_idx * BLOCK_SIZE;
const scalar_t *__restrict__ v_block_cache_ptr =
const scalar_t* __restrict__ v_block_cache_ptr =
v_cache + physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride +
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
@@ -311,7 +314,7 @@ struct paged_attention_v1_impl {
if (block_idx != block_num - 1) {
const int64_t next_physical_block_idx =
seq_block_table[block_idx + 1];
const scalar_t *__restrict__ next_v_block_cache_ptr =
const scalar_t* __restrict__ next_v_block_cache_ptr =
v_cache + next_physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride +
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
@@ -340,16 +343,16 @@ struct paged_attention_v1_impl {
#define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call( \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \
num_heads);
template <typename T, int BLOCK_SIZE>
void paged_attention_v1_impl_launcher(
torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache,
torch::Tensor &value_cache, int num_kv_heads, float scale,
torch::Tensor &block_tables, torch::Tensor &seq_lens,
int max_seq_len, const c10::optional<torch::Tensor> &alibi_slopes) {
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@@ -359,67 +362,66 @@ void paged_attention_v1_impl_launcher(
int kv_head_stride = key_cache.stride(1);
// NOTE: alibi_slopes is optional.
const float *alibi_slopes_ptr =
const float* alibi_slopes_ptr =
alibi_slopes
? reinterpret_cast<const float *>(alibi_slopes.value().data_ptr())
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
T *out_ptr = reinterpret_cast<T *>(out.data_ptr());
T *query_ptr = reinterpret_cast<T *>(query.data_ptr());
T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr());
T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr());
int *block_tables_ptr = block_tables.data_ptr<int>();
int *seq_lens_ptr = seq_lens.data_ptr<int>();
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
switch (head_size) {
case 64:
LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
break;
case 80:
LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
break;
case 96:
LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
break;
case 112:
LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
break;
case 128:
LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
break;
case 256:
LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
case 64:
LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
break;
case 80:
LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
break;
case 96:
LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
break;
case 112:
LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
break;
case 128:
LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
break;
case 256:
LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
}
}
#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v1_impl_launcher<T, BLOCK_SIZE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v1_impl_launcher<T, BLOCK_SIZE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes);
#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \
case 16: \
CALL_V1_KERNEL_LAUNCHER(T, 16); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \
case 16: \
CALL_V1_KERNEL_LAUNCHER(T, 16); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
} // namespace
} // namespace
void paged_attention_v1(torch::Tensor &out, torch::Tensor &query,
torch::Tensor &key_cache, torch::Tensor &value_cache,
void paged_attention_v1(torch::Tensor& out, torch::Tensor& query,
torch::Tensor& key_cache, torch::Tensor& value_cache,
int num_kv_heads, float scale,
torch::Tensor &block_tables,
torch::Tensor &seq_lens, int block_size,
int max_seq_len,
const c10::optional<torch::Tensor> &alibi_slopes,
const std::string &kv_cache_dtype, float kv_scale) {
torch::Tensor& block_tables, torch::Tensor& seq_lens,
int block_size, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale) {
TORCH_CHECK(kv_scale == 1.0f);
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
[&] {
@@ -434,23 +436,24 @@ namespace {
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int PARTITION_SIZE>
struct paged_attention_v2_impl {
static void call(
scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size]
float *__restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float
*__restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
scalar_t *__restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
float* __restrict__ exp_sums, // [num_seqs, num_heads,
// max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, const float scale,
const int
*__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int *__restrict__ seq_lens, // [num_seqs]
const int* __restrict__ block_tables, // [num_seqs,
// max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float *__restrict__ alibi_slopes, // [num_heads]
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const int num_seqs, const int num_heads, const int max_num_partitions) {
constexpr int x = 16 / sizeof(scalar_t);
@@ -468,8 +471,7 @@ struct paged_attention_v2_impl {
const int seq_len = seq_lens[seq_idx];
const int start_token_idx = partition_idx * PARTITION_SIZE;
if (start_token_idx >= seq_len)
continue;
if (start_token_idx >= seq_len) continue;
const int partition_num =
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
@@ -477,15 +479,14 @@ struct paged_attention_v2_impl {
const int token_num =
(std::min(seq_len, start_token_idx + PARTITION_SIZE) -
start_token_idx);
const int block_num =
(token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int block_num = (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int last_block_token_num =
token_num - (block_num - 1) * BLOCK_SIZE;
const int *seq_block_table = block_tables +
const int* seq_block_table = block_tables +
max_num_blocks_per_seq * seq_idx +
start_token_idx / BLOCK_SIZE;
const int64_t kv_head_idx = head_idx / num_queries_per_kv;
const scalar_t *__restrict__ q_vec_ptr =
const scalar_t* __restrict__ q_vec_ptr =
q + seq_idx * q_stride + head_idx * HEAD_SIZE;
float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0};
@@ -493,10 +494,10 @@ struct paged_attention_v2_impl {
// Compute logits
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const int64_t physical_block_idx = seq_block_table[block_idx];
const scalar_t *__restrict__ k_block_cache_ptr =
const scalar_t* __restrict__ k_block_cache_ptr =
k_cache + physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride;
float *__restrict__ head_block_logits =
float* __restrict__ head_block_logits =
logits + block_idx * BLOCK_SIZE;
reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call(
@@ -510,13 +511,13 @@ struct paged_attention_v2_impl {
logits, token_num, block_num * BLOCK_SIZE,
alibi_slopes[head_idx], start_token_idx, seq_len);
} else {
max_and_sum = reduceSoftmax(logits, token_num,
block_num * BLOCK_SIZE);
max_and_sum =
reduceSoftmax(logits, token_num, block_num * BLOCK_SIZE);
}
auto &&[max_logit, exp_sum] = max_and_sum;
auto&& [max_logit, exp_sum] = max_and_sum;
scalar_t *__restrict__ output_buffer = nullptr;
scalar_t* __restrict__ output_buffer = nullptr;
if (!no_reduce) {
auto idx = seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx;
@@ -538,13 +539,13 @@ struct paged_attention_v2_impl {
for (int head_part_idx = 0; head_part_idx < head_partition_num;
++head_part_idx) {
vec_op::FP32Vec16 accums[head_elem_num_per_partition];
scalar_t *__restrict__ out_ptr =
scalar_t* __restrict__ out_ptr =
output_buffer + head_part_idx * head_elem_num_per_partition;
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const int64_t physical_block_idx = seq_block_table[block_idx];
const float *__restrict__ prob_vec_ptr =
const float* __restrict__ prob_vec_ptr =
logits + block_idx * BLOCK_SIZE;
const scalar_t *__restrict__ v_block_cache_ptr =
const scalar_t* __restrict__ v_block_cache_ptr =
v_cache + physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride +
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
@@ -555,7 +556,7 @@ struct paged_attention_v2_impl {
if (block_idx != block_num - 1) {
const int64_t next_physical_block_idx =
seq_block_table[block_idx + 1];
const scalar_t *__restrict__ next_v_block_cache_ptr =
const scalar_t* __restrict__ next_v_block_cache_ptr =
v_cache + next_physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride +
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
@@ -587,8 +588,7 @@ struct paged_attention_v2_impl {
const int partition_num =
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
if (partition_num == 1)
continue;
if (partition_num == 1) continue;
reducePartitonSoftmax(
max_logits + seq_idx * num_heads * max_num_partitions +
@@ -603,11 +603,11 @@ struct paged_attention_v2_impl {
using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE);
constexpr int head_elem_num_per_group =
16; // Note: didn't align with the cacheline size, due to some HEAD_SIZE
// didn't align with 64 bytes
16; // Note: didn't align with the cacheline size, due to some
// HEAD_SIZE didn't align with 64 bytes
static_assert(HEAD_SIZE % head_elem_num_per_group == 0);
constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group;
const float *__restrict__ rescale_factors = exp_sums;
const float* __restrict__ rescale_factors = exp_sums;
#pragma omp parallel for collapse(3) schedule(static, 1)
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
@@ -616,17 +616,16 @@ struct paged_attention_v2_impl {
const int partition_num =
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
if (partition_num == 1)
continue;
if (partition_num == 1) continue;
const float *__restrict__ seq_head_rescale_factors =
const float* __restrict__ seq_head_rescale_factors =
rescale_factors + seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions;
const scalar_t *__restrict__ seq_head_tmp_out =
const scalar_t* __restrict__ seq_head_tmp_out =
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE +
group_idx * head_elem_num_per_group;
scalar_t *__restrict__ seq_head_output =
scalar_t* __restrict__ seq_head_output =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
group_idx * head_elem_num_per_group;
@@ -645,21 +644,21 @@ struct paged_attention_v2_impl {
}
};
#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \
key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, num_seqs, num_heads, \
#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \
key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, num_seqs, num_heads, \
max_num_partitions);
template <typename T, int BLOCK_SIZE, int PARTITION_SIZE = 512>
void paged_attention_v2_impl_launcher(
torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits,
torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache,
torch::Tensor &value_cache, int num_kv_heads, float scale,
torch::Tensor &block_tables, torch::Tensor &seq_lens, int block_size,
int max_seq_len, const c10::optional<torch::Tensor> &alibi_slopes) {
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@@ -670,72 +669,72 @@ void paged_attention_v2_impl_launcher(
int max_num_partitions = exp_sums.size(-1);
// NOTE: alibi_slopes is optional.
const float *alibi_slopes_ptr =
const float* alibi_slopes_ptr =
alibi_slopes
? reinterpret_cast<const float *>(alibi_slopes.value().data_ptr())
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
T *out_ptr = reinterpret_cast<T *>(out.data_ptr());
float *exp_sums_ptr = reinterpret_cast<float *>(exp_sums.data_ptr());
float *max_logits_ptr = reinterpret_cast<float *>(max_logits.data_ptr());
T *tmp_out_ptr = reinterpret_cast<T *>(tmp_out.data_ptr());
T *query_ptr = reinterpret_cast<T *>(query.data_ptr());
T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr());
T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr());
int *block_tables_ptr = block_tables.data_ptr<int>();
int *seq_lens_ptr = seq_lens.data_ptr<int>();
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
switch (head_size) {
case 64:
LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
break;
case 80:
LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
break;
case 96:
LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
break;
case 112:
LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
break;
case 128:
LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
break;
case 256:
LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
case 64:
LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
break;
case 80:
LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
break;
case 96:
LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
break;
case 112:
LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
break;
case 128:
LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
break;
case 256:
LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
}
}
#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v2_impl_launcher<T, BLOCK_SIZE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, seq_lens, block_size, \
max_seq_len, alibi_slopes);
#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v2_impl_launcher<T, BLOCK_SIZE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, \
alibi_slopes);
#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \
case 16: \
CALL_V2_KERNEL_LAUNCHER(T, 16); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \
case 16: \
CALL_V2_KERNEL_LAUNCHER(T, 16); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
} // namespace
} // namespace
void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums,
torch::Tensor &max_logits, torch::Tensor &tmp_out,
torch::Tensor &query, torch::Tensor &key_cache,
torch::Tensor &value_cache, int num_kv_heads,
float scale, torch::Tensor &block_tables,
torch::Tensor &seq_lens, int block_size,
void paged_attention_v2(torch::Tensor& out, torch::Tensor& exp_sums,
torch::Tensor& max_logits, torch::Tensor& tmp_out,
torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads,
float scale, torch::Tensor& block_tables,
torch::Tensor& seq_lens, int block_size,
int max_seq_len,
const c10::optional<torch::Tensor> &alibi_slopes,
const std::string &kv_cache_dtype, float kv_scale) {
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale) {
TORCH_CHECK(kv_scale == 1.0f);
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
[&] {