Add Falcon support (new) (#592)
This commit is contained in:
@@ -10,7 +10,8 @@ __global__ void rotary_embedding_neox_kernel(
|
||||
scalar_t* __restrict__ key, // [num_tokens, num_kv_heads, head_size]
|
||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
|
||||
const int rot_dim,
|
||||
const int stride,
|
||||
const int query_stride,
|
||||
const int key_stride,
|
||||
const int num_heads,
|
||||
const int num_kv_heads,
|
||||
const int head_size) {
|
||||
@@ -23,14 +24,14 @@ __global__ void rotary_embedding_neox_kernel(
|
||||
const int nq = num_heads * embed_dim;
|
||||
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
|
||||
const int head_idx = i / embed_dim;
|
||||
const int token_head = token_idx * stride + head_idx * head_size;
|
||||
const int token_head = token_idx * query_stride + head_idx * head_size;
|
||||
|
||||
const int rot_offset = i % embed_dim;
|
||||
const int x_index = rot_offset;
|
||||
const int y_index = embed_dim + rot_offset;
|
||||
|
||||
const int out_x = token_idx * stride + head_idx * head_size + x_index;
|
||||
const int out_y = token_idx * stride + head_idx * head_size + y_index;
|
||||
const int out_x = token_idx * query_stride + head_idx * head_size + x_index;
|
||||
const int out_y = token_idx * query_stride + head_idx * head_size + y_index;
|
||||
|
||||
const scalar_t cos = __ldg(cache_ptr + x_index);
|
||||
const scalar_t sin = __ldg(cache_ptr + y_index);
|
||||
@@ -39,13 +40,27 @@ __global__ void rotary_embedding_neox_kernel(
|
||||
const scalar_t q_y = query[token_head + y_index];
|
||||
query[out_x] = q_x * cos - q_y * sin;
|
||||
query[out_y] = q_y * cos + q_x * sin;
|
||||
}
|
||||
|
||||
if (head_idx < num_kv_heads) {
|
||||
const scalar_t k_x = key[token_head + x_index];
|
||||
const scalar_t k_y = key[token_head + y_index];
|
||||
key[out_x] = k_x * cos - k_y * sin;
|
||||
key[out_y] = k_y * cos + k_x * sin;
|
||||
}
|
||||
const int nk = num_kv_heads * embed_dim;
|
||||
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
|
||||
const int head_idx = i / embed_dim;
|
||||
const int token_head = token_idx * key_stride + head_idx * head_size;
|
||||
|
||||
const int rot_offset = i % embed_dim;
|
||||
const int x_index = rot_offset;
|
||||
const int y_index = embed_dim + rot_offset;
|
||||
|
||||
const int out_x = token_idx * key_stride + head_idx * head_size + x_index;
|
||||
const int out_y = token_idx * key_stride + head_idx * head_size + y_index;
|
||||
|
||||
const scalar_t cos = __ldg(cache_ptr + x_index);
|
||||
const scalar_t sin = __ldg(cache_ptr + y_index);
|
||||
|
||||
const scalar_t k_x = key[token_head + x_index];
|
||||
const scalar_t k_y = key[token_head + y_index];
|
||||
key[out_x] = k_x * cos - k_y * sin;
|
||||
key[out_y] = k_y * cos + k_x * sin;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,8 +77,8 @@ void rotary_embedding_neox(
|
||||
int rot_dim = cos_sin_cache.size(1);
|
||||
int num_heads = query.size(1) / head_size;
|
||||
int num_kv_heads = key.size(1) / head_size;
|
||||
int stride = query.stride(0);
|
||||
TORCH_CHECK(stride == key.stride(0));
|
||||
int query_stride = query.stride(0);
|
||||
int key_stride = key.stride(0);
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(num_heads * rot_dim / 2, 512));
|
||||
@@ -80,7 +95,8 @@ void rotary_embedding_neox(
|
||||
key.data_ptr<scalar_t>(),
|
||||
cos_sin_cache.data_ptr<scalar_t>(),
|
||||
rot_dim,
|
||||
stride,
|
||||
query_stride,
|
||||
key_stride,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_size);
|
||||
|
||||
Reference in New Issue
Block a user