Make key optional for rotary embedding (#17566)
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
This commit is contained in:
@@ -9,7 +9,8 @@ void rotary_embedding_impl(
|
||||
scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads,
|
||||
/// head_size] or [num_tokens, num_heads,
|
||||
/// head_size]
|
||||
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
|
||||
scalar_t* __restrict__ key, // nullptr (optional) or
|
||||
// [batch_size, seq_len, num_kv_heads,
|
||||
// head_size] or [num_tokens, num_kv_heads,
|
||||
// head_size]
|
||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
|
||||
@@ -85,10 +86,13 @@ void rotary_embedding_impl(
|
||||
compute_loop(token_head, cache_ptr, query);
|
||||
}
|
||||
|
||||
for (int i = 0; i < num_kv_heads; ++i) {
|
||||
const int head_idx = i;
|
||||
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
|
||||
compute_loop(token_head, cache_ptr, key);
|
||||
if (key != nullptr) {
|
||||
for (int i = 0; i < num_kv_heads; ++i) {
|
||||
const int head_idx = i;
|
||||
const int64_t token_head =
|
||||
token_idx * key_stride + head_idx * head_size;
|
||||
compute_loop(token_head, cache_ptr, key);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -100,7 +104,8 @@ void rotary_embedding_gptj_impl(
|
||||
scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads,
|
||||
/// head_size] or [num_tokens, num_heads,
|
||||
/// head_size]
|
||||
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
|
||||
scalar_t* __restrict__ key, // nullptr (optional) or
|
||||
// [batch_size, seq_len, num_kv_heads,
|
||||
// head_size] or [num_tokens, num_kv_heads,
|
||||
// head_size]
|
||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
|
||||
@@ -138,6 +143,10 @@ void rotary_embedding_gptj_impl(
|
||||
}
|
||||
}
|
||||
|
||||
if (key == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
#pragma omp parallel for collapse(2)
|
||||
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||
for (int i = 0; i < num_kv_heads; ++i) {
|
||||
@@ -168,13 +177,13 @@ void rotary_embedding_gptj_impl(
|
||||
}; // namespace
|
||||
|
||||
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
|
||||
torch::Tensor& key, int64_t head_size,
|
||||
std::optional<torch::Tensor> key, int64_t head_size,
|
||||
torch::Tensor& cos_sin_cache, bool is_neox) {
|
||||
int num_tokens = positions.numel();
|
||||
int rot_dim = cos_sin_cache.size(1);
|
||||
int num_heads = query.size(-1) / head_size;
|
||||
int num_kv_heads = key.size(-1) / head_size;
|
||||
int64_t key_stride = key.stride(-2);
|
||||
int num_kv_heads = key.has_value() ? key->size(-1) / head_size : num_heads;
|
||||
int64_t key_stride = key.has_value() ? key->stride(-2) : 0;
|
||||
int64_t query_stride = query.stride(-2);
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
@@ -183,15 +192,15 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
|
||||
if (is_neox) {
|
||||
rotary_embedding_impl(
|
||||
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
|
||||
rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
|
||||
head_size, num_tokens);
|
||||
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
|
||||
cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
|
||||
key_stride, num_heads, num_kv_heads, head_size, num_tokens);
|
||||
} else {
|
||||
rotary_embedding_gptj_impl(
|
||||
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
|
||||
rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
|
||||
head_size, num_tokens);
|
||||
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
|
||||
cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
|
||||
key_stride, num_heads, num_kv_heads, head_size, num_tokens);
|
||||
}
|
||||
|
||||
CPU_KERNEL_GUARD_OUT(rotary_embedding_impl)
|
||||
|
||||
@@ -117,7 +117,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
|
||||
ops.def(
|
||||
"rotary_embedding(Tensor positions, Tensor! query,"
|
||||
" Tensor! key, int head_size,"
|
||||
" Tensor!? key, int head_size,"
|
||||
" Tensor cos_sin_cache, bool is_neox) -> ()");
|
||||
ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);
|
||||
|
||||
|
||||
@@ -86,13 +86,13 @@ void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
|
||||
std::optional<torch::Tensor> residual);
|
||||
|
||||
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
|
||||
torch::Tensor& key, int64_t head_size,
|
||||
std::optional<torch::Tensor> key, int64_t head_size,
|
||||
torch::Tensor& cos_sin_cache, bool is_neox);
|
||||
|
||||
void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
|
||||
torch::Tensor& key, int64_t head_size,
|
||||
torch::Tensor& cos_sin_cache, bool is_neox,
|
||||
int64_t rot_dim,
|
||||
std::optional<torch::Tensor> key,
|
||||
int64_t head_size, torch::Tensor& cos_sin_cache,
|
||||
bool is_neox, int64_t rot_dim,
|
||||
torch::Tensor& cos_sin_cache_offsets);
|
||||
|
||||
void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
|
||||
|
||||
@@ -38,7 +38,8 @@ inline __device__ void apply_rotary_embedding(
|
||||
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
|
||||
// head_size] or [num_tokens, num_heads,
|
||||
// head_size]
|
||||
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
|
||||
scalar_t* __restrict__ key, // nullptr or
|
||||
// [batch_size, seq_len, num_kv_heads,
|
||||
// head_size] or [num_tokens, num_kv_heads,
|
||||
// head_size]
|
||||
const scalar_t* cache_ptr, const int head_size, const int num_heads,
|
||||
@@ -57,13 +58,15 @@ inline __device__ void apply_rotary_embedding(
|
||||
query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
|
||||
}
|
||||
|
||||
const int nk = num_kv_heads * embed_dim;
|
||||
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
|
||||
const int head_idx = i / embed_dim;
|
||||
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
|
||||
const int rot_offset = i % embed_dim;
|
||||
apply_token_rotary_embedding<scalar_t, IS_NEOX>(
|
||||
key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
|
||||
if (key != nullptr) {
|
||||
const int nk = num_kv_heads * embed_dim;
|
||||
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
|
||||
const int head_idx = i / embed_dim;
|
||||
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
|
||||
const int rot_offset = i % embed_dim;
|
||||
apply_token_rotary_embedding<scalar_t, IS_NEOX>(
|
||||
key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,7 +77,8 @@ __global__ void rotary_embedding_kernel(
|
||||
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
|
||||
// head_size] or [num_tokens, num_heads,
|
||||
// head_size]
|
||||
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
|
||||
scalar_t* __restrict__ key, // nullptr or
|
||||
// [batch_size, seq_len, num_kv_heads,
|
||||
// head_size] or [num_tokens, num_kv_heads,
|
||||
// head_size]
|
||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
|
||||
@@ -98,7 +102,8 @@ __global__ void batched_rotary_embedding_kernel(
|
||||
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
|
||||
// head_size] or [num_tokens, num_heads,
|
||||
// head_size]
|
||||
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
|
||||
scalar_t* __restrict__ key, // nullptr or
|
||||
// [batch_size, seq_len, num_kv_heads,
|
||||
// head_size] or [num_tokens, num_kv_heads,
|
||||
// head_size]
|
||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
|
||||
@@ -127,10 +132,12 @@ void rotary_embedding(
|
||||
// [num_tokens, num_heads * head_size] or
|
||||
// [batch_size, seq_len, num_heads, head_size] or
|
||||
// [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
|
||||
// [num_tokens, num_kv_heads * head_size] or
|
||||
// [batch_size, seq_len, num_heads, head_size] or
|
||||
// [num_tokens, num_heads, head_size]
|
||||
std::optional<torch::Tensor> key,
|
||||
// null or
|
||||
// [batch_size, seq_len, num_kv_heads * head_size] or
|
||||
// [num_tokens, num_kv_heads * head_size] or
|
||||
// [batch_size, seq_len, num_heads, head_size] or
|
||||
// [num_tokens, num_heads, head_size]
|
||||
int64_t head_size,
|
||||
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
|
||||
bool is_neox) {
|
||||
@@ -138,40 +145,40 @@ void rotary_embedding(
|
||||
int64_t num_tokens = positions.numel();
|
||||
int positions_ndim = positions.dim();
|
||||
|
||||
// Make sure num_tokens dim is consistent across positions, query, and key.
|
||||
// Make sure num_tokens dim is consistent across positions, query, and key
|
||||
TORCH_CHECK(
|
||||
positions_ndim == 1 || positions_ndim == 2,
|
||||
"positions must have shape [num_tokens] or [batch_size, seq_len]");
|
||||
if (positions_ndim == 1) {
|
||||
TORCH_CHECK(
|
||||
query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
|
||||
"query, key and positions must have the same number of tokens");
|
||||
TORCH_CHECK(query.size(0) == positions.size(0) &&
|
||||
(!key.has_value() || key->size(0) == positions.size(0)),
|
||||
"query, key and positions must have the same number of tokens");
|
||||
}
|
||||
if (positions_ndim == 2) {
|
||||
TORCH_CHECK(
|
||||
query.size(0) == positions.size(0) &&
|
||||
key.size(0) == positions.size(0) &&
|
||||
(!key.has_value() || key->size(0) == positions.size(0)) &&
|
||||
query.size(1) == positions.size(1) &&
|
||||
key.size(1) == positions.size(1),
|
||||
(!key.has_value() || key->size(1) == positions.size(1)),
|
||||
"query, key and positions must have the same batch_size and seq_len");
|
||||
}
|
||||
|
||||
// Make sure head_size is valid for query and key
|
||||
// hidden_size = num_heads * head_size
|
||||
int query_hidden_size = query.numel() / num_tokens;
|
||||
int key_hidden_size = key.numel() / num_tokens;
|
||||
int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0;
|
||||
TORCH_CHECK(query_hidden_size % head_size == 0);
|
||||
TORCH_CHECK(key_hidden_size % head_size == 0);
|
||||
|
||||
// Make sure query and key have consistent number of heads
|
||||
int num_heads = query_hidden_size / head_size;
|
||||
int num_kv_heads = key_hidden_size / head_size;
|
||||
int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads;
|
||||
TORCH_CHECK(num_heads % num_kv_heads == 0);
|
||||
|
||||
int rot_dim = cos_sin_cache.size(1);
|
||||
int seq_dim_idx = positions_ndim - 1;
|
||||
int64_t query_stride = query.stride(seq_dim_idx);
|
||||
int64_t key_stride = key.stride(seq_dim_idx);
|
||||
int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
|
||||
@@ -181,15 +188,16 @@ void rotary_embedding(
|
||||
if (is_neox) {
|
||||
vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
|
||||
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), rot_dim,
|
||||
query_stride, key_stride, num_heads, num_kv_heads, head_size);
|
||||
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
|
||||
cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride, key_stride,
|
||||
num_heads, num_kv_heads, head_size);
|
||||
} else {
|
||||
vllm::rotary_embedding_kernel<scalar_t, false>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
|
||||
rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
|
||||
head_size);
|
||||
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
|
||||
cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
|
||||
key_stride, num_heads, num_kv_heads, head_size);
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -204,10 +212,12 @@ void batched_rotary_embedding(
|
||||
// [num_tokens, num_heads * head_size] or
|
||||
// [batch_size, seq_len, num_heads, head_size] or
|
||||
// [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
|
||||
// [num_tokens, num_kv_heads * head_size] or
|
||||
// [batch_size, seq_len, num_heads, head_size] or
|
||||
// [num_tokens, num_heads, head_size]
|
||||
std::optional<torch::Tensor>
|
||||
key, // null or
|
||||
// [batch_size, seq_len, num_kv_heads * head_size] or
|
||||
// [num_tokens, num_kv_heads * head_size] or
|
||||
// [batch_size, seq_len, num_heads, head_size] or
|
||||
// [num_tokens, num_heads, head_size]
|
||||
int64_t head_size,
|
||||
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
|
||||
bool is_neox, int64_t rot_dim,
|
||||
@@ -221,38 +231,38 @@ void batched_rotary_embedding(
|
||||
"cos_sin_cache_offsets");
|
||||
|
||||
int positions_ndim = positions.dim();
|
||||
// Make sure num_tokens dim is consistent across positions, query, and key.
|
||||
// Make sure num_tokens dim is consistent across positions, query, and key
|
||||
TORCH_CHECK(
|
||||
positions_ndim == 1 || positions_ndim == 2,
|
||||
"positions must have shape [num_tokens] or [batch_size, seq_len]");
|
||||
if (positions_ndim == 1) {
|
||||
TORCH_CHECK(
|
||||
query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
|
||||
"query, key and positions must have the same number of tokens");
|
||||
TORCH_CHECK(query.size(0) == positions.size(0) &&
|
||||
(!key.has_value() || key->size(0) == positions.size(0)),
|
||||
"query, key and positions must have the same number of tokens");
|
||||
}
|
||||
if (positions_ndim == 2) {
|
||||
TORCH_CHECK(
|
||||
query.size(0) == positions.size(0) &&
|
||||
key.size(0) == positions.size(0) &&
|
||||
(!key.has_value() || key->size(0) == positions.size(0)) &&
|
||||
query.size(1) == positions.size(1) &&
|
||||
key.size(1) == positions.size(1),
|
||||
(!key.has_value() || key->size(1) == positions.size(1)),
|
||||
"query, key and positions must have the same batch_size and seq_len");
|
||||
}
|
||||
|
||||
// Make sure head_size is valid for query and key
|
||||
int query_hidden_size = query.numel() / num_tokens;
|
||||
int key_hidden_size = key.numel() / num_tokens;
|
||||
int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0;
|
||||
TORCH_CHECK(query_hidden_size % head_size == 0);
|
||||
TORCH_CHECK(key_hidden_size % head_size == 0);
|
||||
|
||||
// Make sure query and key have concistent number of heads
|
||||
int num_heads = query_hidden_size / head_size;
|
||||
int num_kv_heads = key_hidden_size / head_size;
|
||||
int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads;
|
||||
TORCH_CHECK(num_heads % num_kv_heads == 0);
|
||||
|
||||
int seq_dim_idx = positions_ndim - 1;
|
||||
int64_t query_stride = query.stride(seq_dim_idx);
|
||||
int64_t key_stride = key.stride(seq_dim_idx);
|
||||
int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
|
||||
@@ -263,14 +273,16 @@ void batched_rotary_embedding(
|
||||
vllm::batched_rotary_embedding_kernel<scalar_t, true>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
|
||||
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
|
||||
cos_sin_cache.data_ptr<scalar_t>(),
|
||||
cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
|
||||
key_stride, num_heads, num_kv_heads, head_size);
|
||||
} else {
|
||||
vllm::batched_rotary_embedding_kernel<scalar_t, false>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
|
||||
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
|
||||
cos_sin_cache.data_ptr<scalar_t>(),
|
||||
cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
|
||||
key_stride, num_heads, num_kv_heads, head_size);
|
||||
}
|
||||
|
||||
@@ -176,7 +176,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
|
||||
ops.def(
|
||||
"rotary_embedding(Tensor positions, Tensor! query,"
|
||||
" Tensor! key, int head_size,"
|
||||
" Tensor!? key, int head_size,"
|
||||
" Tensor cos_sin_cache, bool is_neox) -> ()");
|
||||
ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
|
||||
|
||||
@@ -184,7 +184,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// (supports multiple loras).
|
||||
ops.def(
|
||||
"batched_rotary_embedding(Tensor positions, Tensor! query,"
|
||||
" Tensor! key, int head_size,"
|
||||
" Tensor!? key, int head_size,"
|
||||
" Tensor cos_sin_cache, bool is_neox,"
|
||||
" int rot_dim,"
|
||||
" Tensor cos_sin_cache_offsets) -> ()");
|
||||
|
||||
@@ -21,6 +21,7 @@ SEEDS = [0]
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
USE_KEY = [True, False]
|
||||
|
||||
|
||||
def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
|
||||
@@ -46,6 +47,7 @@ TENSORS_SHAPES_FN = [_get_batch_tensor_shape, _get_flat_tensor_shape]
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("use_key", USE_KEY)
|
||||
@torch.inference_mode()
|
||||
def test_rotary_embedding(
|
||||
is_neox_style: bool,
|
||||
@@ -58,6 +60,7 @@ def test_rotary_embedding(
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
use_key: bool,
|
||||
max_position: int = 8192,
|
||||
base: int = 10000,
|
||||
) -> None:
|
||||
@@ -74,7 +77,7 @@ def test_rotary_embedding(
|
||||
positions = torch.randint(0, max_position, (batch_size, seq_len))
|
||||
query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
|
||||
query = torch.randn(query_shape, dtype=dtype)
|
||||
key = torch.randn_like(query)
|
||||
key = torch.randn_like(query) if use_key else None
|
||||
|
||||
# NOTE(woosuk): The reference implementation should be executed first
|
||||
# because the custom kernel is in-place.
|
||||
@@ -85,10 +88,14 @@ def test_rotary_embedding(
|
||||
ref_query,
|
||||
atol=get_default_atol(out_query),
|
||||
rtol=get_default_rtol(out_query))
|
||||
torch.testing.assert_close(out_key,
|
||||
ref_key,
|
||||
atol=get_default_atol(out_key),
|
||||
rtol=get_default_rtol(out_key))
|
||||
if use_key:
|
||||
torch.testing.assert_close(out_key,
|
||||
ref_key,
|
||||
atol=get_default_atol(out_key),
|
||||
rtol=get_default_rtol(out_key))
|
||||
else:
|
||||
assert ref_key is None and out_key is None, \
|
||||
"expected returned key to be None"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
|
||||
@@ -101,6 +108,7 @@ def test_rotary_embedding(
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("use_key", USE_KEY)
|
||||
@torch.inference_mode()
|
||||
def test_batched_rotary_embedding(
|
||||
is_neox_style: bool,
|
||||
@@ -113,6 +121,7 @@ def test_batched_rotary_embedding(
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
use_key: bool,
|
||||
max_position: int = 8192,
|
||||
base: int = 10000,
|
||||
) -> None:
|
||||
@@ -129,7 +138,7 @@ def test_batched_rotary_embedding(
|
||||
positions = torch.randint(0, max_position, (batch_size, seq_len))
|
||||
query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
|
||||
query = torch.randn(query_shape, dtype=dtype)
|
||||
key = torch.randn_like(query)
|
||||
key = torch.randn_like(query) if use_key else None
|
||||
|
||||
# NOTE(woosuk): The reference implementation should be executed first
|
||||
# because the custom kernel is in-place.
|
||||
@@ -145,10 +154,14 @@ def test_batched_rotary_embedding(
|
||||
ref_query,
|
||||
atol=get_default_atol(out_query),
|
||||
rtol=get_default_rtol(out_query))
|
||||
torch.testing.assert_close(out_key,
|
||||
ref_key,
|
||||
atol=get_default_atol(out_key),
|
||||
rtol=get_default_rtol(out_key))
|
||||
if use_key:
|
||||
torch.testing.assert_close(out_key,
|
||||
ref_key,
|
||||
atol=get_default_atol(out_key),
|
||||
rtol=get_default_rtol(out_key))
|
||||
else:
|
||||
assert ref_key is None and out_key is None, \
|
||||
"expected returned key to be None"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
|
||||
@@ -160,6 +173,7 @@ def test_batched_rotary_embedding(
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("use_key", USE_KEY)
|
||||
@torch.inference_mode()
|
||||
def test_batched_rotary_embedding_multi_lora(
|
||||
is_neox_style: bool,
|
||||
@@ -171,6 +185,7 @@ def test_batched_rotary_embedding_multi_lora(
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
use_key: bool,
|
||||
max_position: int = 8192,
|
||||
base: int = 10000,
|
||||
) -> None:
|
||||
@@ -190,7 +205,7 @@ def test_batched_rotary_embedding_multi_lora(
|
||||
seq_len,
|
||||
num_heads * head_size,
|
||||
dtype=dtype)
|
||||
key = torch.randn_like(query)
|
||||
key = torch.randn_like(query) if use_key else None
|
||||
|
||||
offset_map = torch.tensor(
|
||||
list(
|
||||
@@ -214,10 +229,14 @@ def test_batched_rotary_embedding_multi_lora(
|
||||
ref_query,
|
||||
atol=get_default_atol(out_query),
|
||||
rtol=get_default_rtol(out_query))
|
||||
torch.testing.assert_close(out_key,
|
||||
ref_key,
|
||||
atol=get_default_atol(out_key),
|
||||
rtol=get_default_rtol(out_key))
|
||||
if use_key:
|
||||
torch.testing.assert_close(out_key,
|
||||
ref_key,
|
||||
atol=get_default_atol(out_key),
|
||||
rtol=get_default_rtol(out_key))
|
||||
else:
|
||||
assert ref_key is None and out_key is None, \
|
||||
"expected returned key to be None"
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
|
||||
@@ -15,7 +15,7 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
def rotary_embedding_opcheck(rot,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
key: Optional[torch.Tensor] = None,
|
||||
offsets: Optional[torch.Tensor] = None):
|
||||
cos_sin_cache = rot.cos_sin_cache.to(query.device, dtype=query.dtype)
|
||||
|
||||
@@ -37,9 +37,10 @@ def rotary_embedding_opcheck(rot,
|
||||
@pytest.mark.parametrize("rotary_dim", [32])
|
||||
@pytest.mark.parametrize("head_size", [32, 108])
|
||||
@pytest.mark.parametrize("seq_len", [11, 1024])
|
||||
@pytest.mark.parametrize("use_key", [True, False])
|
||||
def test_rotary_embedding_opcheck(dist_init, device, max_position,
|
||||
is_neox_style, rotary_dim, head_size,
|
||||
seq_len):
|
||||
seq_len, use_key):
|
||||
batch_size = 1
|
||||
base = 10000
|
||||
num_heads = 7
|
||||
@@ -54,7 +55,7 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
|
||||
num_heads * head_size,
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
key = torch.randn_like(query)
|
||||
key = torch.randn_like(query) if use_key else None
|
||||
|
||||
rotary_embedding_opcheck(rot, positions, query, key)
|
||||
offsets = torch.zeros(batch_size * seq_len,
|
||||
|
||||
@@ -11,14 +11,16 @@ from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"max_position,is_neox_style,rotary_dim,head_size,seq_len", [
|
||||
(16, False, 32, 32, 1024),
|
||||
(16, False, 32, 128, 1024),
|
||||
(16, True, 32, 32, 1024),
|
||||
(16, True, 32, 128, 1024),
|
||||
"max_position,is_neox_style,rotary_dim,head_size,seq_len,use_key", [
|
||||
(16, False, 32, 32, 1024, True),
|
||||
(16, False, 32, 128, 1024, True),
|
||||
(16, True, 32, 32, 1024, True),
|
||||
(16, True, 32, 128, 1024, True),
|
||||
(16, False, 32, 128, 1024, False),
|
||||
(16, True, 32, 128, 1024, False),
|
||||
])
|
||||
def test_rotary_embedding_opcheck(max_position, is_neox_style, rotary_dim,
|
||||
head_size, seq_len):
|
||||
head_size, seq_len, use_key):
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
device = xm.xla_device()
|
||||
@@ -40,19 +42,26 @@ def test_rotary_embedding_opcheck(max_position, is_neox_style, rotary_dim,
|
||||
num_heads * head_size,
|
||||
dtype=torch.float32,
|
||||
device="cpu")
|
||||
key = torch.randn_like(query)
|
||||
|
||||
key = torch.randn_like(query) if use_key else None
|
||||
assert positions.is_cpu, \
|
||||
"reference input tensor is expected to be CPU tensor."
|
||||
ref_query, ref_key = rot.to(device="cpu").forward_native(
|
||||
positions, query, key)
|
||||
out_query, out_key = rot.to(device=device).forward_neuron(
|
||||
positions.to(device=device), query.to(device=device),
|
||||
key.to(device=device))
|
||||
assert out_query.is_xla and out_key.is_xla, \
|
||||
"output tensor is expected to be XLA tensor"
|
||||
key.to(device=device) if key is not None else None)
|
||||
if use_key:
|
||||
assert out_query.is_xla and out_key.is_xla, \
|
||||
"output tensor is expected to be XLA tensor"
|
||||
torch.testing.assert_close(out_key.cpu(),
|
||||
ref_key,
|
||||
atol=1e-2,
|
||||
rtol=1e-2)
|
||||
else:
|
||||
assert out_key is None, "expected returned key to be None"
|
||||
assert out_query.is_xla, \
|
||||
"output tensor is expected to be XLA tensor"
|
||||
torch.testing.assert_close(out_query.cpu(),
|
||||
ref_query,
|
||||
atol=1e-2,
|
||||
rtol=1e-2)
|
||||
torch.testing.assert_close(out_key.cpu(), ref_key, atol=1e-2, rtol=1e-2)
|
||||
|
||||
@@ -153,34 +153,36 @@ def merge_attn_states(output: torch.Tensor,
|
||||
def rotary_embedding(
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
key: Optional[torch.Tensor],
|
||||
head_size: int,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
is_neox: bool,
|
||||
) -> None:
|
||||
# TODO: Remove this contiguous call when the kernel is updated to support tensor slices
|
||||
query_contiguous = query.contiguous()
|
||||
key_contiguous = key.contiguous()
|
||||
key_contiguous = key.contiguous() if key is not None else None
|
||||
torch.ops._C.rotary_embedding(positions, query_contiguous, key_contiguous,
|
||||
head_size, cos_sin_cache, is_neox)
|
||||
query.copy_(query_contiguous)
|
||||
key.copy_(key_contiguous)
|
||||
if key is not None:
|
||||
key.copy_(key_contiguous)
|
||||
|
||||
|
||||
def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
|
||||
key: torch.Tensor, head_size: int,
|
||||
key: Optional[torch.Tensor], head_size: int,
|
||||
cos_sin_cache: torch.Tensor, is_neox: bool,
|
||||
rot_dim: int,
|
||||
cos_sin_cache_offsets: torch.Tensor) -> None:
|
||||
# TODO: Remove this contiguous call when the kernel is updated to support tensor slices
|
||||
query_contiguous = query.contiguous()
|
||||
key_contiguous = key.contiguous()
|
||||
key_contiguous = key.contiguous() if key is not None else None
|
||||
torch.ops._C.batched_rotary_embedding(positions, query_contiguous,
|
||||
key_contiguous, head_size,
|
||||
cos_sin_cache, is_neox, rot_dim,
|
||||
cos_sin_cache_offsets)
|
||||
query.copy_(query_contiguous)
|
||||
key.copy_(key_contiguous)
|
||||
if key is not None:
|
||||
key.copy_(key_contiguous)
|
||||
|
||||
|
||||
# layer norm ops
|
||||
|
||||
@@ -138,9 +138,9 @@ class RotaryEmbedding(CustomOp):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
key: Optional[torch.Tensor] = None,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""A PyTorch-native implementation of forward()."""
|
||||
if offsets is not None:
|
||||
positions = positions + offsets
|
||||
@@ -157,22 +157,24 @@ class RotaryEmbedding(CustomOp):
|
||||
self.is_neox_style)
|
||||
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
||||
|
||||
key_shape = key.shape
|
||||
key = key.view(num_tokens, -1, self.head_size)
|
||||
key_rot = key[..., :self.rotary_dim]
|
||||
key_pass = key[..., self.rotary_dim:]
|
||||
key_rot = _apply_rotary_emb_torch(key_rot, cos, sin,
|
||||
self.is_neox_style)
|
||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||
# key may be None in some cases, e.g. cross-layer KV sharing
|
||||
if key is not None:
|
||||
key_shape = key.shape
|
||||
key = key.view(num_tokens, -1, self.head_size)
|
||||
key_rot = key[..., :self.rotary_dim]
|
||||
key_pass = key[..., self.rotary_dim:]
|
||||
key_rot = _apply_rotary_emb_torch(key_rot, cos, sin,
|
||||
self.is_neox_style)
|
||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||
return query, key
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
key: Optional[torch.Tensor] = None,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
# __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
|
||||
@@ -198,32 +200,39 @@ class RotaryEmbedding(CustomOp):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
key: Optional[torch.Tensor] = None,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
from vllm._ipex_ops import ipex_ops as ops
|
||||
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
|
||||
dtype=query.dtype)
|
||||
# ops.rotary_embedding()/batched_rotary_embedding()
|
||||
# are in-place operations that update the query and key tensors.
|
||||
if offsets is not None:
|
||||
ops.batched_rotary_embedding(positions, query, key, self.head_size,
|
||||
self.cos_sin_cache,
|
||||
self.is_neox_style, self.rotary_dim,
|
||||
offsets)
|
||||
if key is None:
|
||||
# XPU kernel doesn't support key=None so fall back to native impl
|
||||
# TODO(sarckk): add support for optional key in
|
||||
# ipex.llm.functional.rotary_embedding_batched
|
||||
return self.forward_native(positions, query, key, offsets)
|
||||
else:
|
||||
ops.rotary_embedding(positions, query, key, self.head_size,
|
||||
self.cos_sin_cache, self.is_neox_style)
|
||||
if offsets is not None:
|
||||
ops.batched_rotary_embedding(positions, query, key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
self.is_neox_style,
|
||||
self.rotary_dim, offsets)
|
||||
else:
|
||||
ops.rotary_embedding(positions, query, key, self.head_size,
|
||||
self.cos_sin_cache, self.is_neox_style)
|
||||
return query, key
|
||||
|
||||
def forward_hpu(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
key: Optional[torch.Tensor] = None,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
from habana_frameworks.torch.hpex.kernels import (
|
||||
RotaryPosEmbeddingMode, apply_rotary_pos_emb)
|
||||
if offsets is not None:
|
||||
@@ -265,21 +274,23 @@ class RotaryEmbedding(CustomOp):
|
||||
rope_mode)
|
||||
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
||||
|
||||
key_shape = key.shape
|
||||
key = key.view(num_tokens, -1, self.head_size)
|
||||
key_rot = key[..., :self.rotary_dim]
|
||||
key_pass = key[..., self.rotary_dim:]
|
||||
key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
|
||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||
if key is not None:
|
||||
key_shape = key.shape
|
||||
key = key.view(num_tokens, -1, self.head_size)
|
||||
key_rot = key[..., :self.rotary_dim]
|
||||
key_pass = key[..., self.rotary_dim:]
|
||||
key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0,
|
||||
rope_mode)
|
||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||
return query, key
|
||||
|
||||
def forward_neuron(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
key: Optional[torch.Tensor] = None,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
|
||||
def _apply_rotary_emb_neuron(
|
||||
x: torch.Tensor,
|
||||
@@ -319,14 +330,16 @@ class RotaryEmbedding(CustomOp):
|
||||
|
||||
query_shape = query.shape
|
||||
query = query.view(num_tokens, -1, self.head_size)
|
||||
key_shape = key.shape
|
||||
key = key.view(num_tokens, -1, self.head_size)
|
||||
if key is not None:
|
||||
key_shape = key.shape
|
||||
key = key.view(num_tokens, -1, self.head_size)
|
||||
|
||||
if self.rotary_dim == self.head_size:
|
||||
query = _apply_rotary_emb(query, cos, sin, self.is_neox_style)
|
||||
query = query.reshape(query_shape)
|
||||
key = _apply_rotary_emb(key, cos, sin, self.is_neox_style)
|
||||
key = key.reshape(key_shape)
|
||||
if key is not None:
|
||||
key = _apply_rotary_emb(key, cos, sin, self.is_neox_style)
|
||||
key = key.reshape(key_shape)
|
||||
else:
|
||||
head_size = query.shape[-1]
|
||||
query_reshaped = query.view(-1, head_size)
|
||||
@@ -339,14 +352,15 @@ class RotaryEmbedding(CustomOp):
|
||||
query = torch.cat((query_rot, query_pass),
|
||||
dim=-1).reshape(query_shape)
|
||||
|
||||
key_reshaped = key.view(-1, head_size)
|
||||
key_pass = key_reshaped[:, self.rotary_dim:].view(
|
||||
*key.shape[:-1], head_size - self.rotary_dim)
|
||||
key_rot = key_reshaped[:, :self.rotary_dim].view(
|
||||
*key.shape[:-1], self.rotary_dim)
|
||||
key_rot = _apply_rotary_emb_neuron(key_rot, cos, sin,
|
||||
self.is_neox_style)
|
||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||
if key is not None:
|
||||
key_reshaped = key.view(-1, head_size)
|
||||
key_pass = key_reshaped[:, self.rotary_dim:].view(
|
||||
*key.shape[:-1], head_size - self.rotary_dim)
|
||||
key_rot = key_reshaped[:, :self.rotary_dim].view(
|
||||
*key.shape[:-1], self.rotary_dim)
|
||||
key_rot = _apply_rotary_emb_neuron(key_rot, cos, sin,
|
||||
self.is_neox_style)
|
||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||
return query, key
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
@@ -672,9 +686,10 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
key: Optional[torch.Tensor] = None,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
assert key is not None
|
||||
query = query.view(*query.shape[:-1], -1, self.head_size)
|
||||
key = key.view(*key.shape[:-1], -1, self.head_size)
|
||||
|
||||
@@ -782,10 +797,11 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
key: Optional[torch.Tensor] = None,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
assert key is not None
|
||||
query_rot = query[..., :self.rotary_dim]
|
||||
key_rot = key[..., :self.rotary_dim]
|
||||
if self.rotary_dim < self.head_size:
|
||||
@@ -912,8 +928,9 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
key: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
assert key is not None
|
||||
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
|
||||
query_ = torch.view_as_complex(query.float().reshape(
|
||||
*query.shape[:-1], -1, 2))
|
||||
@@ -957,8 +974,8 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
key: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""PyTorch-native implementation equivalent to forward().
|
||||
|
||||
Args:
|
||||
@@ -969,6 +986,7 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
key: [num_tokens, num_kv_heads * head_size]
|
||||
"""
|
||||
assert positions.ndim == 1 or positions.ndim == 2
|
||||
assert key is not None
|
||||
|
||||
num_tokens = positions.shape[-1]
|
||||
cos_sin = self.cos_sin_cache[positions]
|
||||
|
||||
Reference in New Issue
Block a user