Support bfloat16 data type (#54)
This commit is contained in:
@@ -14,14 +14,16 @@ void swap_blocks(
|
||||
torch::Device dst_device = dst.device();
|
||||
cudaMemcpyKind memcpy_type;
|
||||
if (src_device.is_cuda() && dst_device.is_cuda()) {
|
||||
assert(src_device.index() == dst_device.index());
|
||||
TORCH_CHECK(
|
||||
src_device.index() == dst_device.index(),
|
||||
"src and dst must be on the same GPU");
|
||||
memcpy_type = cudaMemcpyDeviceToDevice;
|
||||
} else if (src_device.is_cuda() && dst_device.is_cpu()) {
|
||||
memcpy_type = cudaMemcpyDeviceToHost;
|
||||
} else if (src_device.is_cpu() && dst_device.is_cuda()) {
|
||||
memcpy_type = cudaMemcpyHostToDevice;
|
||||
} else {
|
||||
assert(false);
|
||||
TORCH_CHECK(false, "Invalid device combination");
|
||||
}
|
||||
|
||||
void *src_ptr = src.data_ptr();
|
||||
@@ -29,6 +31,7 @@ void swap_blocks(
|
||||
|
||||
const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
// NOTE(woosuk): This can be slow if the number of blocks is large.
|
||||
for (const auto& pair : block_mapping) {
|
||||
int64_t src_block_number = pair.first;
|
||||
int64_t dst_block_number = pair.second;
|
||||
@@ -122,7 +125,9 @@ void copy_blocks(
|
||||
dim3 grid(num_layers, num_pairs);
|
||||
dim3 block(std::min(1024, numel_per_block));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
at::ScalarType::BFloat16,
|
||||
key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
|
||||
cacheflow::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
key_cache_ptrs_tensor.data_ptr<int64_t>(),
|
||||
@@ -176,6 +181,50 @@ __global__ void reshape_and_cache_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cacheflow
|
||||
|
||||
void reshape_and_cache(
|
||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& value, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
torch::Tensor& slot_mapping) // [num_tokens]
|
||||
{
|
||||
int num_tokens = key.size(0);
|
||||
int num_heads = key.size(1);
|
||||
int head_size = key.size(2);
|
||||
int block_size = key_cache.size(3);
|
||||
int x = key_cache.size(4);
|
||||
|
||||
int key_stride = key.stride(0);
|
||||
int value_stride = value.stride(0);
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(num_heads * head_size, 512));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
at::ScalarType::BFloat16,
|
||||
key.scalar_type(),
|
||||
"reshape_and_cache_kernel",
|
||||
[&] {
|
||||
cacheflow::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
key.data_ptr<scalar_t>(),
|
||||
value.data_ptr<scalar_t>(),
|
||||
key_cache.data_ptr<scalar_t>(),
|
||||
value_cache.data_ptr<scalar_t>(),
|
||||
slot_mapping.data_ptr<int>(),
|
||||
key_stride,
|
||||
value_stride,
|
||||
num_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
x);
|
||||
});
|
||||
}
|
||||
|
||||
namespace cacheflow {
|
||||
|
||||
// Grid: (num_blocks, block_size).
|
||||
template<typename scalar_t>
|
||||
__global__ void gather_cached_kv_kernel(
|
||||
@@ -296,45 +345,6 @@ __global__ void gather_cached_kv_kernel_optimized(
|
||||
|
||||
} // namespace cacheflow
|
||||
|
||||
void reshape_and_cache(
|
||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& value, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
torch::Tensor& slot_mapping) // [num_tokens]
|
||||
{
|
||||
int num_tokens = key.size(0);
|
||||
int num_heads = key.size(1);
|
||||
int head_size = key.size(2);
|
||||
int block_size = key_cache.size(3);
|
||||
int x = key_cache.size(4);
|
||||
|
||||
int key_stride = key.stride(0);
|
||||
int value_stride = value.stride(0);
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(num_heads * head_size, 512));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
key.scalar_type(),
|
||||
"reshape_and_cache_kernel",
|
||||
[&] {
|
||||
cacheflow::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
key.data_ptr<scalar_t>(),
|
||||
value.data_ptr<scalar_t>(),
|
||||
key_cache.data_ptr<scalar_t>(),
|
||||
value_cache.data_ptr<scalar_t>(),
|
||||
slot_mapping.data_ptr<int>(),
|
||||
key_stride,
|
||||
value_stride,
|
||||
num_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
x);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
void gather_cached_kv(
|
||||
torch::Tensor& key, // [out] [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& value, // [out] [num_tokens, num_heads, head_size]
|
||||
@@ -354,7 +364,9 @@ void gather_cached_kv(
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(num_heads * head_size, 512));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
at::ScalarType::BFloat16,
|
||||
key.scalar_type(),
|
||||
"gather_cached_kv_kernel_optimized",
|
||||
[&] {
|
||||
|
||||
Reference in New Issue
Block a user