Optimize data movement (#20)
This commit is contained in:
@@ -25,7 +25,8 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int* __restrict__ context_lens, // [num_seqs]
|
||||
const int max_num_blocks_per_seq) {
|
||||
const int max_num_blocks_per_seq,
|
||||
const int q_stride) {
|
||||
constexpr int THREAD_GROUP_SIZE = WARP_SIZE / BLOCK_SIZE;
|
||||
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||
const int thread_idx = threadIdx.x;
|
||||
@@ -56,7 +57,8 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
// For example, if the the thread group size is 4, then the first thread in the group
|
||||
// has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
|
||||
// th vectors of the query, and so on.
|
||||
const scalar_t* q_ptr = q + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
|
||||
// NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
|
||||
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
|
||||
Q_vec q_vecs[NUM_VECS_PER_THREAD];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
|
||||
@@ -264,7 +266,8 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
scale, \
|
||||
block_tables_ptr, \
|
||||
context_lens_ptr, \
|
||||
max_num_blocks_per_seq);
|
||||
max_num_blocks_per_seq, \
|
||||
query_stride);
|
||||
|
||||
// TODO(woosuk): Tune NUM_THREADS.
|
||||
template<
|
||||
@@ -284,6 +287,7 @@ void single_query_cached_kv_attention_launcher(
|
||||
int num_heads = query.size(1);
|
||||
int head_size = query.size(2);
|
||||
int max_num_blocks_per_seq = block_tables.size(1);
|
||||
int query_stride = query.stride(0);
|
||||
|
||||
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
|
||||
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
|
||||
@@ -333,13 +337,13 @@ void single_query_cached_kv_attention_launcher(
|
||||
}
|
||||
|
||||
void single_query_cached_kv_attention(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
torch::Tensor& out, // [num_seqs, num_heads, head_size]
|
||||
torch::Tensor& query, // [num_seqs, num_heads, head_size]
|
||||
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
float scale,
|
||||
torch::Tensor& block_tables,
|
||||
torch::Tensor& context_lens,
|
||||
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
torch::Tensor& context_lens, // [num_seqs]
|
||||
int block_size,
|
||||
int max_context_len) {
|
||||
// TODO(woosuk): Support BF16.
|
||||
|
||||
Reference in New Issue
Block a user