Add support for GPT-NeoX (Pythia) (#50)
This commit is contained in:
@@ -4,6 +4,7 @@ void rotary_embedding_neox(
|
||||
torch::Tensor& positions,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key,
|
||||
int head_size,
|
||||
torch::Tensor& cos_sin_cache);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
|
||||
@@ -8,16 +8,17 @@ __global__ void rotary_embedding_neox_kernel(
|
||||
const int64_t* __restrict__ positions, // [num_tokens]
|
||||
scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size]
|
||||
scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, head_size // 2]
|
||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
|
||||
const int rot_dim,
|
||||
const int stride,
|
||||
const int num_heads,
|
||||
const int head_size) {
|
||||
// Each thread block is responsible for one token.
|
||||
const int token_idx = blockIdx.x;
|
||||
int64_t pos = positions[token_idx];
|
||||
const scalar_t* cache_ptr = cos_sin_cache + pos * head_size;
|
||||
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
|
||||
|
||||
const int embed_dim = head_size / 2;
|
||||
const int embed_dim = rot_dim / 2;
|
||||
const int n = num_heads * embed_dim;
|
||||
for (int i = threadIdx.x; i < n; i += blockDim.x) {
|
||||
const int head_idx = i / embed_dim;
|
||||
@@ -51,16 +52,17 @@ void rotary_embedding_neox(
|
||||
torch::Tensor& positions, // [num_tokens]
|
||||
torch::Tensor& query, // [num_tokens, num_heads * head_size]
|
||||
torch::Tensor& key, // [num_tokens, num_heads * head_size]
|
||||
torch::Tensor& cos_sin_cache) // [max_position, head_size]
|
||||
int head_size,
|
||||
torch::Tensor& cos_sin_cache) // [max_position, rot_dim]
|
||||
{
|
||||
int num_tokens = query.size(0);
|
||||
int head_size = cos_sin_cache.size(1);
|
||||
int rot_dim = cos_sin_cache.size(1);
|
||||
int num_heads = query.size(1) / head_size;
|
||||
int stride = query.stride(0);
|
||||
TORCH_CHECK(stride == key.stride(0));
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(num_heads * head_size / 2, 512));
|
||||
dim3 block(std::min(num_heads * rot_dim / 2, 512));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
query.scalar_type(),
|
||||
@@ -71,6 +73,7 @@ void rotary_embedding_neox(
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
cos_sin_cache.data_ptr<scalar_t>(),
|
||||
rot_dim,
|
||||
stride,
|
||||
num_heads,
|
||||
head_size);
|
||||
|
||||
Reference in New Issue
Block a user