Support bfloat16 data type (#54)

This commit is contained in:
Woosuk Kwon
2023-05-03 14:09:44 -07:00
committed by GitHub
parent 436e523bf1
commit e070829ae8
12 changed files with 455 additions and 53 deletions

View File

@@ -14,14 +14,16 @@ void swap_blocks(
torch::Device dst_device = dst.device();
cudaMemcpyKind memcpy_type;
if (src_device.is_cuda() && dst_device.is_cuda()) {
assert(src_device.index() == dst_device.index());
TORCH_CHECK(
src_device.index() == dst_device.index(),
"src and dst must be on the same GPU");
memcpy_type = cudaMemcpyDeviceToDevice;
} else if (src_device.is_cuda() && dst_device.is_cpu()) {
memcpy_type = cudaMemcpyDeviceToHost;
} else if (src_device.is_cpu() && dst_device.is_cuda()) {
memcpy_type = cudaMemcpyHostToDevice;
} else {
assert(false);
TORCH_CHECK(false, "Invalid device combination");
}
void *src_ptr = src.data_ptr();
@@ -29,6 +31,7 @@ void swap_blocks(
const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// NOTE(woosuk): This can be slow if the number of blocks is large.
for (const auto& pair : block_mapping) {
int64_t src_block_number = pair.first;
int64_t dst_block_number = pair.second;
@@ -122,7 +125,9 @@ void copy_blocks(
dim3 grid(num_layers, num_pairs);
dim3 block(std::min(1024, numel_per_block));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
cacheflow::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
key_cache_ptrs_tensor.data_ptr<int64_t>(),
@@ -176,6 +181,50 @@ __global__ void reshape_and_cache_kernel(
}
}
} // namespace cacheflow
void reshape_and_cache(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& slot_mapping) // [num_tokens]
{
int num_tokens = key.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
int block_size = key_cache.size(3);
int x = key_cache.size(4);
int key_stride = key.stride(0);
int value_stride = value.stride(0);
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
key.scalar_type(),
"reshape_and_cache_kernel",
[&] {
cacheflow::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int>(),
key_stride,
value_stride,
num_heads,
head_size,
block_size,
x);
});
}
namespace cacheflow {
// Grid: (num_blocks, block_size).
template<typename scalar_t>
__global__ void gather_cached_kv_kernel(
@@ -296,45 +345,6 @@ __global__ void gather_cached_kv_kernel_optimized(
} // namespace cacheflow
void reshape_and_cache(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& slot_mapping) // [num_tokens]
{
int num_tokens = key.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
int block_size = key_cache.size(3);
int x = key_cache.size(4);
int key_stride = key.stride(0);
int value_stride = value.stride(0);
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
key.scalar_type(),
"reshape_and_cache_kernel",
[&] {
cacheflow::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int>(),
key_stride,
value_stride,
num_heads,
head_size,
block_size,
x);
});
}
void gather_cached_kv(
torch::Tensor& key, // [out] [num_tokens, num_heads, head_size]
torch::Tensor& value, // [out] [num_tokens, num_heads, head_size]
@@ -354,7 +364,9 @@ void gather_cached_kv(
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
key.scalar_type(),
"gather_cached_kv_kernel_optimized",
[&] {