Support bfloat16 data type (#54)

This commit is contained in:
Woosuk Kwon
2023-05-03 14:09:44 -07:00
committed by GitHub
parent 436e523bf1
commit e070829ae8
12 changed files with 455 additions and 53 deletions

View File

@@ -34,7 +34,9 @@ void silu_and_mul(
   dim3 grid(num_tokens);
   dim3 block(std::min(d, 1024));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    at::ScalarType::Half,
+    at::ScalarType::BFloat16,
     input.scalar_type(),
     "silu_and_mul_kernel",
     [&] {