Support bfloat16 data type (#54)
@@ -1,7 +1,7 @@
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
 
-#include "attention_dtypes.cuh"
+#include "attention_dtypes.h"
 #include "attention_utils.cuh"
 
 #include <algorithm>
@@ -438,9 +438,13 @@ void single_query_cached_kv_attention(
   torch::Tensor& context_lens, // [num_seqs]
   int block_size,
   int max_context_len) {
-  // TODO(woosuk): Support FP32 and BF16.
+  // TODO(woosuk): Support FP32.
   if (query.dtype() == at::ScalarType::Half) {
     CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t);
+#ifdef ENABLE_BF16
+  } else if (query.dtype() == at::ScalarType::BFloat16) {
+    CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
+#endif
   } else {
     TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
   }
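For context (illustrative only, not part of the commit): the hunk above routes each supported dtype to a templated kernel launcher, with the BF16 branch compiled only when ENABLE_BF16 is defined. Below is a minimal self-contained sketch of that pattern; launch_attention<T>() is a hypothetical stand-in for the CALL_KERNEL_LAUNCHER_BLOCK_SIZE macro used in the real file.

// Sketch of the dtype-dispatch pattern above. launch_attention<T>() is a
// hypothetical placeholder; the commit's CALL_KERNEL_LAUNCHER_BLOCK_SIZE
// macro performs the actual kernel instantiation and launch.
#include <torch/extension.h>
#ifdef ENABLE_BF16
#include <cuda_bf16.h>  // defines __nv_bfloat16
#endif

template <typename T>
void launch_attention(const torch::Tensor& query) {
  // Would instantiate and launch the attention kernel for element type T.
}

void dispatch_attention(const torch::Tensor& query) {
  if (query.dtype() == at::ScalarType::Half) {
    // FP16 tensors are handed to the kernel as raw uint16_t storage,
    // mirroring the uint16_t instantiation in the diff.
    launch_attention<uint16_t>(query);
#ifdef ENABLE_BF16
  } else if (query.dtype() == at::ScalarType::BFloat16) {
    // BF16 support is opt-in: compiled only when ENABLE_BF16 is defined,
    // since __nv_bfloat16 requires a sufficiently new CUDA toolkit.
    launch_attention<__nv_bfloat16>(query);
#endif
  } else {
    TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
  }
}

Guarding the BF16 branch with the preprocessor rather than a runtime check keeps the extension building on toolchains without bfloat16 support, at the cost of a runtime TORCH_CHECK failure when a BF16 tensor reaches a binary compiled without it.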