Improve setup script & Add a guard for bfloat16 kernels (#130)

This commit is contained in:
Woosuk Kwon
2023-05-27 00:59:32 -07:00
committed by GitHub
parent 4a151dd453
commit d721168449
4 changed files with 90 additions and 16 deletions

View File

@@ -458,10 +458,8 @@ void single_query_cached_kv_attention(
// TODO(woosuk): Support FP32.
if (query.dtype() == at::ScalarType::Half) {
CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t);
#ifdef ENABLE_BF16
} else if (query.dtype() == at::ScalarType::BFloat16) {
CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
#endif
} else {
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
}