Implement single_query_cached_kv_attention kernel (#3)

This commit is contained in:
Woosuk Kwon
2023-03-01 15:02:19 -08:00
committed by GitHub
parent cbf8779afa
commit 0deacbce6e
12 changed files with 2140 additions and 60 deletions

View File

@@ -48,7 +48,7 @@ __global__ void reshape_and_cache_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, block_size, head_size]
scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
const int* __restrict__ slot_mapping, // [num_tokens]
const int num_heads,
const int head_size,
@@ -73,10 +73,10 @@ __global__ void reshape_and_cache_kernel(
+ x_idx * block_size * x
+ block_offset * x
+ x_offset;
const int tgt_value_idx = block_idx * num_heads * block_size * head_size
+ head_idx * block_size * head_size
+ block_offset * head_size
+ head_offset;
const int tgt_value_idx = block_idx * num_heads * head_size * block_size
+ head_idx * head_size * block_size
+ head_offset * block_size
+ block_offset;
key_cache[tgt_key_idx] = __ldg(&key[src_idx]);
value_cache[tgt_value_idx] = __ldg(&value[src_idx]);
}