Add support for LLaMA-2 (#505)
This commit is contained in:
@@ -7,11 +7,12 @@ template<typename scalar_t>
|
||||
__global__ void rotary_embedding_neox_kernel(
|
||||
const int64_t* __restrict__ positions, // [num_tokens]
|
||||
scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size]
|
||||
scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
||||
scalar_t* __restrict__ key, // [num_tokens, num_kv_heads, head_size]
|
||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
|
||||
const int rot_dim,
|
||||
const int stride,
|
||||
const int num_heads,
|
||||
const int num_kv_heads,
|
||||
const int head_size) {
|
||||
// Each thread block is responsible for one token.
|
||||
const int token_idx = blockIdx.x;
|
||||
@@ -19,8 +20,8 @@ __global__ void rotary_embedding_neox_kernel(
|
||||
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
|
||||
|
||||
const int embed_dim = rot_dim / 2;
|
||||
const int n = num_heads * embed_dim;
|
||||
for (int i = threadIdx.x; i < n; i += blockDim.x) {
|
||||
const int nq = num_heads * embed_dim;
|
||||
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
|
||||
const int head_idx = i / embed_dim;
|
||||
const int token_head = token_idx * stride + head_idx * head_size;
|
||||
|
||||
@@ -39,10 +40,12 @@ __global__ void rotary_embedding_neox_kernel(
|
||||
query[out_x] = q_x * cos - q_y * sin;
|
||||
query[out_y] = q_y * cos + q_x * sin;
|
||||
|
||||
const scalar_t k_x = key[token_head + x_index];
|
||||
const scalar_t k_y = key[token_head + y_index];
|
||||
key[out_x] = k_x * cos - k_y * sin;
|
||||
key[out_y] = k_y * cos + k_x * sin;
|
||||
if (head_idx < num_kv_heads) {
|
||||
const scalar_t k_x = key[token_head + x_index];
|
||||
const scalar_t k_y = key[token_head + y_index];
|
||||
key[out_x] = k_x * cos - k_y * sin;
|
||||
key[out_y] = k_y * cos + k_x * sin;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,13 +54,14 @@ __global__ void rotary_embedding_neox_kernel(
|
||||
void rotary_embedding_neox(
|
||||
torch::Tensor& positions, // [num_tokens]
|
||||
torch::Tensor& query, // [num_tokens, num_heads * head_size]
|
||||
torch::Tensor& key, // [num_tokens, num_heads * head_size]
|
||||
torch::Tensor& key, // [num_tokens, num_kv_heads * head_size]
|
||||
int head_size,
|
||||
torch::Tensor& cos_sin_cache) // [max_position, rot_dim]
|
||||
{
|
||||
int num_tokens = query.size(0);
|
||||
int rot_dim = cos_sin_cache.size(1);
|
||||
int num_heads = query.size(1) / head_size;
|
||||
int num_kv_heads = key.size(1) / head_size;
|
||||
int stride = query.stride(0);
|
||||
TORCH_CHECK(stride == key.stride(0));
|
||||
|
||||
@@ -78,6 +82,7 @@ void rotary_embedding_neox(
|
||||
rot_dim,
|
||||
stride,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_size);
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user