Improve setup script & Add a guard for bfloat16 kernels (#130)
This commit is contained in:
@@ -458,10 +458,8 @@ void single_query_cached_kv_attention(
|
||||
// TODO(woosuk): Support FP32.
|
||||
if (query.dtype() == at::ScalarType::Half) {
|
||||
CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t);
|
||||
#ifdef ENABLE_BF16
|
||||
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
|
||||
#endif
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user