Change scheduler & input tensor shape (#1381)

Author: Woosuk Kwon
Date: 2023-10-16 17:48:42 -07:00
Committed by: GitHub
Parent: 651c614aa4
Commit: c1376e0f82
13 changed files with 180 additions and 178 deletions

@@ -37,9 +37,9 @@ inline __device__ void apply_rotary_embedding(
 template<typename scalar_t, bool IS_NEOX>
 __global__ void rotary_embedding_kernel(
-  const int64_t* __restrict__ positions,        // [num_tokens]
-  scalar_t* __restrict__ query,                 // [num_tokens, num_heads, head_size]
-  scalar_t* __restrict__ key,                   // [num_tokens, num_kv_heads, head_size]
+  const int64_t* __restrict__ positions,        // [batch_size, seq_len] or [num_tokens]
+  scalar_t* __restrict__ query,                 // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
+  scalar_t* __restrict__ key,                   // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
   const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
   const int rot_dim,
   const int query_stride,
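Why one kernel can serve both layouts: the kernel indexes tokens through query_stride and key_stride rather than through explicit batch and sequence indices, so a contiguous [batch_size, seq_len, ...] tensor is byte-for-byte identical to its [num_tokens, ...] flattening. A minimal libtorch sketch of that equivalence (the shapes below are illustrative assumptions, not taken from the commit):

#include <torch/torch.h>
#include <cassert>

int main() {
  const int batch_size = 4, seq_len = 8, num_heads = 32, head_size = 128;
  // Batched layout: [batch_size, seq_len, num_heads * head_size]
  auto q3 = torch::randn({batch_size, seq_len, num_heads * head_size});
  // Flattened layout: [num_tokens, num_heads * head_size], same storage
  auto q2 = q3.view({batch_size * seq_len, num_heads * head_size});

  // Same base pointer and same per-token stride: the kernel cannot
  // distinguish the two layouts.
  assert(q2.data_ptr() == q3.data_ptr());
  assert(q2.stride(0) == q3.stride(-2));
  return 0;
}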
@@ -78,18 +78,18 @@ __global__ void rotary_embedding_kernel(
 } // namespace vllm

 void rotary_embedding(
-  torch::Tensor& positions,         // [num_tokens]
-  torch::Tensor& query,             // [num_tokens, num_heads * head_size]
-  torch::Tensor& key,               // [num_tokens, num_kv_heads * head_size]
+  torch::Tensor& positions,         // [batch_size, seq_len] or [num_tokens]
+  torch::Tensor& query,             // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
+  torch::Tensor& key,               // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
   int head_size,
   torch::Tensor& cos_sin_cache,     // [max_position, rot_dim]
   bool is_neox) {
-  int num_tokens = query.size(0);
+  int num_tokens = query.numel() / query.size(-1);
   int rot_dim = cos_sin_cache.size(1);
-  int num_heads = query.size(1) / head_size;
-  int num_kv_heads = key.size(1) / head_size;
-  int query_stride = query.stride(0);
-  int key_stride = key.stride(0);
+  int num_heads = query.size(-1) / head_size;
+  int num_kv_heads = key.size(-1) / head_size;
+  int query_stride = query.stride(-2);
+  int key_stride = key.stride(-2);

   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * rot_dim / 2, 512));
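The new host-side arithmetic resolves to the same values for either layout: numel() / size(-1) counts tokens no matter how the leading dimensions are factored, size(-1) / head_size recovers the head count, and stride(-2) is the per-token stride in both the 2D and 3D cases. A hedged libtorch sketch (shapes are assumptions for illustration):

#include <torch/torch.h>
#include <iostream>

int main() {
  const int head_size = 128;
  // 2D: [num_tokens, num_heads * head_size]
  auto q2 = torch::randn({32, 8 * head_size});
  // 3D: [batch_size, seq_len, num_heads * head_size]
  auto q3 = torch::randn({4, 8, 8 * head_size});

  for (const auto& q : {q2, q3}) {
    int num_tokens   = q.numel() / q.size(-1);  // 32 in both cases
    int num_heads    = q.size(-1) / head_size;  // 8 in both cases
    int query_stride = q.stride(-2);            // 8 * head_size in both cases
    std::cout << num_tokens << " " << num_heads << " " << query_stride << "\n";
  }
  return 0;
}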