Add query stride to multi_query_cached_kv_attention & Add kernel benchmark script (#27)

* Add query stride to multi_query_cached_kv_attention

* Add kernel benchmark script
This commit is contained in:
Woosuk Kwon
2023-04-08 13:36:09 -07:00
committed by GitHub
parent 0f40557af6
commit c267b1a02c
3 changed files with 181 additions and 8 deletions

View File

@@ -271,7 +271,8 @@ __device__ void multi_query_cached_kv_attention_kernel_unoptimized_(
const float scale,
const int* __restrict__ block_table, // [num_seqs, max_num_blocks_per_seq]
const int context_len,
const int max_num_blocks_per_seq) {
const int max_num_blocks_per_seq,
const int q_stride) {
constexpr int THREAD_GROUP_SIZE = WARP_SIZE / BLOCK_SIZE;
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int thread_idx = threadIdx.x;
@@ -302,7 +303,8 @@ __device__ void multi_query_cached_kv_attention_kernel_unoptimized_(
// For example, if the thread group size is 4, then the first thread in the group
// has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
// th vectors of the query, and so on.
const scalar_t* q_ptr = q + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
// NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
Q_vec q_vecs[NUM_VECS_PER_THREAD];
#pragma unroll
for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
@@ -514,7 +516,8 @@ __global__ void multi_query_cached_kv_attention_kernel(
const float scale,
const int* __restrict__ block_tables, // [num_prompts, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_prompts]
const int max_num_blocks_per_seq) {
const int max_num_blocks_per_seq,
const int q_stride) {
const int seq_idx = blockIdx.y;
const int prompt_idx = seq_prompt_mapping[seq_idx];
const int seq_start_idx = cu_query_lens[prompt_idx];
@@ -532,7 +535,8 @@ __global__ void multi_query_cached_kv_attention_kernel(
scale,
block_table,
context_len,
max_num_blocks_per_seq);
max_num_blocks_per_seq,
q_stride);
}
} // namespace cacheflow
@@ -696,7 +700,8 @@ void single_query_cached_kv_attention(
scale, \
block_tables_ptr, \
context_lens_ptr, \
max_num_blocks_per_seq);
max_num_blocks_per_seq, \
query_stride);
// TODO(woosuk): Tune NUM_THREADS.
@@ -719,6 +724,7 @@ void multi_query_cached_kv_attention_launcher(
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
int query_stride = query.stride(0);
int* cu_query_lens_ptr = cu_query_lens.data_ptr<int>();
int* seq_prompt_mapping_ptr = seq_prompt_mapping.data_ptr<int>();