Support bfloat16 data type (#54)
This commit is contained in:
@@ -64,7 +64,9 @@ void rotary_embedding_neox(
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(num_heads * rot_dim / 2, 512));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
+ AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
+ at::ScalarType::Half,
|
||||
+ at::ScalarType::BFloat16,
|
||||
query.scalar_type(),
|
||||
"rotary_embedding_neox",
|
||||
[&] {
|
||||
|
||||
Reference in New Issue
Block a user