[BugFix] Implement RoPE for GPT-J (#941)

Woosuk Kwon
2023-09-06 11:54:33 +09:00
committed by GitHub
parent c9927c1a6a
commit 320a622ec4
5 changed files with 122 additions and 72 deletions

@@ -1,15 +1,16 @@
 #include <torch/extension.h>
 
-void rotary_embedding_neox(
+void rotary_embedding(
   torch::Tensor& positions,
   torch::Tensor& query,
   torch::Tensor& key,
   int head_size,
-  torch::Tensor& cos_sin_cache);
+  torch::Tensor& cos_sin_cache,
+  bool is_neox);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def(
-    "rotary_embedding_neox",
-    &rotary_embedding_neox,
-    "Apply GPT-NeoX style rotary embedding to query and key");
+    "rotary_embedding",
+    &rotary_embedding,
+    "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
 }
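
For context, here is a minimal C++ sketch (not the kernel from this commit) of what the new is_neox flag selects between: GPT-NeoX style RoPE pairs each element with the one half a rotation span away, while GPT-J style pairs adjacent interleaved elements. The function and variable names below are illustrative assumptions, and the cos/sin inputs stand in for one precomputed row of the cos_sin_cache.

// Illustrative sketch only -- not the CUDA kernel added in this commit.
// Applies rotary embedding to a single head vector of length head_size,
// choosing the element pairing based on is_neox.
#include <vector>
#include <cstddef>

void apply_rotary_embedding_sketch(std::vector<float>& head,       // [head_size]
                                   const std::vector<float>& cos,  // [rot_dim / 2]
                                   const std::vector<float>& sin,  // [rot_dim / 2]
                                   bool is_neox) {
  const std::size_t rot_dim = cos.size() * 2;  // number of rotated dimensions
  for (std::size_t i = 0; i < rot_dim / 2; ++i) {
    // GPT-NeoX style rotates the two halves: pair x[i] with x[i + rot_dim/2].
    // GPT-J style rotates interleaved pairs: pair x[2i] with x[2i + 1].
    const std::size_t x_idx = is_neox ? i : 2 * i;
    const std::size_t y_idx = is_neox ? i + rot_dim / 2 : 2 * i + 1;
    const float x = head[x_idx];
    const float y = head[y_idx];
    head[x_idx] = x * cos[i] - y * sin[i];
    head[y_idx] = y * cos[i] + x * sin[i];
  }
}

In the extension itself, the same flag is threaded through the pybind11 binding above, so the caller can select GPT-NeoX or GPT-J style per model at call time rather than relying on a NeoX-only entry point.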