Add query stride to multi_query_cached_kv_attention & Add kernel benchmark script (#27)
* Add query stride to multi_query_cached_kv_attention * Add kernel benchmark script
This commit is contained in:
@@ -271,7 +271,8 @@ __device__ void multi_query_cached_kv_attention_kernel_unoptimized_(
|
||||
const float scale,
|
||||
const int* __restrict__ block_table, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int context_len,
|
||||
const int max_num_blocks_per_seq) {
|
||||
const int max_num_blocks_per_seq,
|
||||
const int q_stride) {
|
||||
constexpr int THREAD_GROUP_SIZE = WARP_SIZE / BLOCK_SIZE;
|
||||
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||
const int thread_idx = threadIdx.x;
|
||||
@@ -302,7 +303,8 @@ __device__ void multi_query_cached_kv_attention_kernel_unoptimized_(
|
||||
// For example, if the the thread group size is 4, then the first thread in the group
|
||||
// has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
|
||||
// th vectors of the query, and so on.
|
||||
const scalar_t* q_ptr = q + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
|
||||
// NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
|
||||
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
|
||||
Q_vec q_vecs[NUM_VECS_PER_THREAD];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
|
||||
@@ -514,7 +516,8 @@ __global__ void multi_query_cached_kv_attention_kernel(
|
||||
const float scale,
|
||||
const int* __restrict__ block_tables, // [num_prompts, max_num_blocks_per_seq]
|
||||
const int* __restrict__ context_lens, // [num_prompts]
|
||||
const int max_num_blocks_per_seq) {
|
||||
const int max_num_blocks_per_seq,
|
||||
const int q_stride) {
|
||||
const int seq_idx = blockIdx.y;
|
||||
const int prompt_idx = seq_prompt_mapping[seq_idx];
|
||||
const int seq_start_idx = cu_query_lens[prompt_idx];
|
||||
@@ -532,7 +535,8 @@ __global__ void multi_query_cached_kv_attention_kernel(
|
||||
scale,
|
||||
block_table,
|
||||
context_len,
|
||||
max_num_blocks_per_seq);
|
||||
max_num_blocks_per_seq,
|
||||
q_stride);
|
||||
}
|
||||
|
||||
} // namespace cacheflow
|
||||
@@ -696,7 +700,8 @@ void single_query_cached_kv_attention(
|
||||
scale, \
|
||||
block_tables_ptr, \
|
||||
context_lens_ptr, \
|
||||
max_num_blocks_per_seq);
|
||||
max_num_blocks_per_seq, \
|
||||
query_stride);
|
||||
|
||||
|
||||
// TODO(woosuk): Tune NUM_THREADS.
|
||||
@@ -719,6 +724,7 @@ void multi_query_cached_kv_attention_launcher(
|
||||
int num_heads = query.size(1);
|
||||
int head_size = query.size(2);
|
||||
int max_num_blocks_per_seq = block_tables.size(1);
|
||||
int query_stride = query.stride(0);
|
||||
|
||||
int* cu_query_lens_ptr = cu_query_lens.data_ptr<int>();
|
||||
int* seq_prompt_mapping_ptr = seq_prompt_mapping.data_ptr<int>();
|
||||
|
||||
Reference in New Issue
Block a user