Support bfloat16 data type (#54)

This commit is contained in:
Woosuk Kwon
2023-05-03 14:09:44 -07:00
committed by GitHub
parent 436e523bf1
commit e070829ae8
12 changed files with 455 additions and 53 deletions

View File

@@ -64,7 +64,9 @@ void rotary_embedding_neox(
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * rot_dim / 2, 512));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+AT_DISPATCH_FLOATING_TYPES_AND2(
+at::ScalarType::Half,
+at::ScalarType::BFloat16,
query.scalar_type(),
"rotary_embedding_neox",
[&] {