Implement single_query_cached_kv_attention kernel (#3)
This commit is contained in:
@@ -48,7 +48,7 @@ __global__ void reshape_and_cache_kernel(
|
||||
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
||||
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
|
||||
scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
- scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, block_size, head_size]
|
||||
+ scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
const int* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int num_heads,
|
||||
const int head_size,
|
||||
@@ -73,10 +73,10 @@ __global__ void reshape_and_cache_kernel(
|
||||
+ x_idx * block_size * x
|
||||
+ block_offset * x
|
||||
+ x_offset;
|
||||
- const int tgt_value_idx = block_idx * num_heads * block_size * head_size
|
||||
-     + head_idx * block_size * head_size
|
||||
-     + block_offset * head_size
|
||||
-     + head_offset;
|
||||
+ const int tgt_value_idx = block_idx * num_heads * head_size * block_size
|
||||
+     + head_idx * head_size * block_size
|
||||
+     + head_offset * block_size
|
||||
+     + block_offset;
|
||||
key_cache[tgt_key_idx] = __ldg(&key[src_idx]);
|
||||
value_cache[tgt_value_idx] = __ldg(&value[src_idx]);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user