[Kernel][Misc] Use TORCH_LIBRARY instead of PYBIND11_MODULE for custom ops (#5047)

This commit is contained in:
bnellnm
2024-06-09 16:23:30 -04:00
committed by GitHub
parent 5d7e3d0176
commit 5467ac3196
55 changed files with 833 additions and 451 deletions

View File

@@ -168,7 +168,7 @@ void rotary_embedding_gptj_impl(
}; // namespace
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& key, int head_size,
torch::Tensor& key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox) {
int num_tokens = query.numel() / query.size(-1);
int rot_dim = cos_sin_cache.size(1);