Add support for GPT-NeoX (Pythia) (#50)

Author:     Woosuk Kwon
Date:       2023-04-28 00:32:10 -07:00
Committed:  GitHub
Parent:     aa50b17ca7
Commit:     a96d63c21d

9 changed files with 436 additions and 71 deletions


@@ -4,6 +4,7 @@ void rotary_embedding_neox(
   torch::Tensor& positions,
   torch::Tensor& query,
   torch::Tensor& key,
+  int head_size,
   torch::Tensor& cos_sin_cache);

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
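Context for the new argument: once the rotary dimension can be smaller than the head size, head_size can no longer be recovered from cos_sin_cache.size(1), so the op takes it explicitly. A minimal caller-side sketch, assuming a LibTorch build against this extension; the example() wrapper and all sizes are illustrative assumptions:

#include <torch/extension.h>

// Forward declaration matching the updated binding above.
void rotary_embedding_neox(
    torch::Tensor& positions, torch::Tensor& query, torch::Tensor& key,
    int head_size, torch::Tensor& cos_sin_cache);

void example() {
  // Illustrative sizes (assumptions); rot_dim < head_size is the case
  // that motivates passing head_size explicitly.
  const int num_tokens = 4, num_heads = 8, head_size = 64, rot_dim = 16;
  auto opts = torch::dtype(torch::kFloat16).device(torch::kCUDA);
  auto positions = torch::arange(
      num_tokens, torch::dtype(torch::kInt64).device(torch::kCUDA));
  auto query = torch::randn({num_tokens, num_heads * head_size}, opts);
  auto key = torch::randn({num_tokens, num_heads * head_size}, opts);
  auto cos_sin_cache = torch::randn({1024, rot_dim}, opts);  // stand-in values
  // head_size can no longer be inferred from the cache, so it is passed in:
  rotary_embedding_neox(positions, query, key, head_size, cos_sin_cache);
}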


@@ -8,16 +8,17 @@ __global__ void rotary_embedding_neox_kernel(
   const int64_t* __restrict__ positions,       // [num_tokens]
   scalar_t* __restrict__ query,                // [num_tokens, num_heads, head_size]
   scalar_t* __restrict__ key,                  // [num_tokens, num_heads, head_size]
-  const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, head_size // 2]
+  const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim // 2]
+  const int rot_dim,
   const int stride,
   const int num_heads,
   const int head_size) {
   // Each thread block is responsible for one token.
   const int token_idx = blockIdx.x;
   int64_t pos = positions[token_idx];
-  const scalar_t* cache_ptr = cos_sin_cache + pos * head_size;
-  const int embed_dim = head_size / 2;
+  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
+  const int embed_dim = rot_dim / 2;
   const int n = num_heads * embed_dim;
   for (int i = threadIdx.x; i < n; i += blockDim.x) {
     const int head_idx = i / embed_dim;
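For readers decoding the indexing: each thread rotates one (head, pair) element. Below is a hedged CPU reference of the NeoX-style rotation the kernel applies to query (it applies the same update to key), assuming a contiguous [num_tokens, num_heads, head_size] layout and a cache whose row for each position holds rot_dim/2 cos values followed by rot_dim/2 sin values. The function name is hypothetical.

#include <cstdint>
#include <vector>

// Hedged CPU reference for the kernel's rotation (illustration only).
void rotary_neox_reference(
    const std::vector<int64_t>& positions,    // [num_tokens]
    std::vector<float>& query,                // [num_tokens, num_heads, head_size]
    const std::vector<float>& cos_sin_cache,  // [max_position, rot_dim]
    int num_heads, int head_size, int rot_dim) {
  const int embed_dim = rot_dim / 2;
  for (size_t t = 0; t < positions.size(); ++t) {
    // Same lookup as cache_ptr = cos_sin_cache + pos * rot_dim above.
    const float* cache = cos_sin_cache.data() + positions[t] * rot_dim;
    for (int h = 0; h < num_heads; ++h) {
      float* head = query.data() + (t * num_heads + h) * head_size;
      for (int i = 0; i < embed_dim; ++i) {
        const float c = cache[i];              // cos half of the cache row
        const float s = cache[embed_dim + i];  // sin half of the cache row
        const float x = head[i];
        const float y = head[embed_dim + i];
        head[i]             = x * c - y * s;   // rotate the (x, y) pair
        head[embed_dim + i] = y * c + x * s;
      }
      // head[rot_dim .. head_size) is intentionally left unrotated.
    }
  }
}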
@@ -51,16 +52,17 @@ void rotary_embedding_neox(
   torch::Tensor& positions,      // [num_tokens]
   torch::Tensor& query,          // [num_tokens, num_heads * head_size]
   torch::Tensor& key,            // [num_tokens, num_heads * head_size]
-  torch::Tensor& cos_sin_cache)  // [max_position, head_size]
+  int head_size,
+  torch::Tensor& cos_sin_cache)  // [max_position, rot_dim]
 {
   int num_tokens = query.size(0);
-  int head_size = cos_sin_cache.size(1);
+  int rot_dim = cos_sin_cache.size(1);
   int num_heads = query.size(1) / head_size;
   int stride = query.stride(0);
   TORCH_CHECK(stride == key.stride(0));
   dim3 grid(num_tokens);
-  dim3 block(std::min(num_heads * head_size / 2, 512));
+  dim3 block(std::min(num_heads * rot_dim / 2, 512));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
     query.scalar_type(),
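Why rot_dim and head_size differ at all: GPT-NeoX-style models rotate only a fraction of each head's dimensions (the rotary_pct field in the Hugging Face GPT-NeoX config; Pythia uses 0.25), so the cache's second dimension is rot_dim = head_size * rotary_pct. A hedged sketch of a matching cache builder follows; the real cache is constructed elsewhere in the project, so this function and its defaults are assumptions for illustration:

#include <cmath>
#include <vector>

// Hypothetical builder for a [max_position, rot_dim] cache whose rows are
// concat(cos(pos * inv_freq), sin(pos * inv_freq)) -- the layout the kernel
// above indexes as cache_ptr[i] and cache_ptr[embed_dim + i].
std::vector<float> build_cos_sin_cache(
    int max_position, int head_size, float rotary_pct, float base = 10000.0f) {
  const int rot_dim = static_cast<int>(head_size * rotary_pct);  // e.g. 64 * 0.25 = 16
  const int embed_dim = rot_dim / 2;
  std::vector<float> cache(static_cast<size_t>(max_position) * rot_dim);
  for (int pos = 0; pos < max_position; ++pos) {
    for (int i = 0; i < embed_dim; ++i) {
      // Standard RoPE frequencies, computed over the rotated dims only.
      const float inv_freq = std::pow(base, -2.0f * i / rot_dim);
      cache[pos * rot_dim + i] = std::cos(pos * inv_freq);
      cache[pos * rot_dim + embed_dim + i] = std::sin(pos * inv_freq);
    }
  }
  return cache;
}

Because cos_sin_cache.size(1) now recovers rot_dim rather than head_size, the old `int head_size = cos_sin_cache.size(1);` would be wrong whenever rotary_pct < 1, which is exactly what this hunk fixes.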
@@ -71,6 +73,7 @@ void rotary_embedding_neox(
       query.data_ptr<scalar_t>(),
       key.data_ptr<scalar_t>(),
       cos_sin_cache.data_ptr<scalar_t>(),
+      rot_dim,
       stride,
       num_heads,
       head_size);
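Putting the launch together: one block per token and one thread per (head, rotation-pair) element, with the kernel's grid-stride loop covering the remainder when the pair count exceeds the 512-thread cap. A hedged sketch of the geometry, with illustrative sizes:

#include <algorithm>
#include <cuda_runtime.h>

// Hedged sketch of the launch geometry from the hunks above;
// all sizes are illustrative assumptions.
void launch_geometry_example() {
  const int num_tokens = 4, num_heads = 8, rot_dim = 16;
  const int pairs_per_token = num_heads * (rot_dim / 2);  // n in the kernel: 64
  dim3 grid(num_tokens);                        // blockIdx.x == token_idx
  dim3 block(std::min(pairs_per_token, 512));   // 64 threads: one loop pass each
  (void)grid; (void)block;  // the launch itself is omitted in this sketch
}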