Optimize data movement (#20)

This commit is contained in:
Woosuk Kwon
2023-04-02 00:30:17 -07:00
committed by GitHub
parent 1f01a18d39
commit 897cb2ae28
17 changed files with 275 additions and 135 deletions

View File

@@ -81,6 +81,8 @@ __global__ void reshape_and_cache_kernel(
scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
const int* __restrict__ slot_mapping, // [num_tokens]
const int key_stride,
const int value_stride,
const int num_heads,
const int head_size,
const int block_size,
@@ -92,7 +94,8 @@ __global__ void reshape_and_cache_kernel(
const int n = num_heads * head_size;
for (int i = threadIdx.x; i < n; i += blockDim.x) {
const int src_idx = token_idx * n + i;
const int src_key_idx = token_idx * key_stride + i;
const int src_value_idx = token_idx * value_stride + i;
const int head_idx = i / head_size;
const int head_offset = i % head_size;
@@ -108,25 +111,29 @@ __global__ void reshape_and_cache_kernel(
+ head_idx * head_size * block_size
+ head_offset * block_size
+ block_offset;
key_cache[tgt_key_idx] = __ldg(&key[src_idx]);
value_cache[tgt_value_idx] = __ldg(&value[src_idx]);
key_cache[tgt_key_idx] = __ldg(&key[src_key_idx]);
value_cache[tgt_value_idx] = __ldg(&value[src_value_idx]);
}
}
} // namespace cacheflow
void reshape_and_cache(
torch::Tensor& key,
torch::Tensor& value,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& slot_mapping) {
torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& slot_mapping) // [num_tokens]
{
int num_tokens = key.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
int block_size = key_cache.size(3);
int x = key_cache.size(4);
int key_stride = key.stride(0);
int value_stride = value.stride(0);
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -140,6 +147,8 @@ void reshape_and_cache(
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int>(),
key_stride,
value_stride,
num_heads,
head_size,
block_size,