Make key optional for rotary embedding (#17566)

Signed-off-by: Yong Hoon Shin <yhshin@meta.com>

Author: Yong Hoon Shin
Date: 2025-05-07 00:11:46 -07:00
Committed by: GitHub
Commit: 98c89e16ff (parent: 324a3119b0)

10 changed files with 221 additions and 151 deletions
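This change threads an optional key through the whole stack: the op schemas gain a Tensor!? annotation, the CPU and CUDA kernels skip key-side work when the pointer is null, and the Python RotaryEmbedding layers accept key=None (useful e.g. for cross-layer KV sharing, where only the query needs rotation). A minimal sketch of the new calling convention, assuming an installed vLLM and illustrative shapes:

import torch
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding

# Illustrative sizes: 8 query heads, head_size 64, full-width rotation.
rope = RotaryEmbedding(head_size=64, rotary_dim=64,
                       max_position_embeddings=8192, base=10000,
                       is_neox_style=True, dtype=torch.float32)
positions = torch.arange(4)
query = torch.randn(4, 8 * 64)

# key may now be omitted; only the query is rotated and the returned
# key is None.
query, key = rope.forward_native(positions, query, key=None)
assert key is None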

View File

@@ -9,7 +9,8 @@ void rotary_embedding_impl(
scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
                               // head_size] or [num_tokens, num_heads,
                               // head_size]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
scalar_t* __restrict__ key, // nullptr (optional) or
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
@@ -85,10 +86,13 @@ void rotary_embedding_impl(
compute_loop(token_head, cache_ptr, query);
}
for (int i = 0; i < num_kv_heads; ++i) {
const int head_idx = i;
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
compute_loop(token_head, cache_ptr, key);
if (key != nullptr) {
for (int i = 0; i < num_kv_heads; ++i) {
const int head_idx = i;
const int64_t token_head =
token_idx * key_stride + head_idx * head_size;
compute_loop(token_head, cache_ptr, key);
}
}
}
}
@@ -100,7 +104,8 @@ void rotary_embedding_gptj_impl(
scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
                               // head_size] or [num_tokens, num_heads,
                               // head_size]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
scalar_t* __restrict__ key, // nullptr (optional) or
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
@@ -138,6 +143,10 @@ void rotary_embedding_gptj_impl(
}
}
if (key == nullptr) {
return;
}
#pragma omp parallel for collapse(2)
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
for (int i = 0; i < num_kv_heads; ++i) {
@@ -168,13 +177,13 @@ void rotary_embedding_gptj_impl(
}; // namespace
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& key, int64_t head_size,
std::optional<torch::Tensor> key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox) {
int num_tokens = positions.numel();
int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(-1) / head_size;
int num_kv_heads = key.size(-1) / head_size;
int64_t key_stride = key.stride(-2);
int num_kv_heads = key.has_value() ? key->size(-1) / head_size : num_heads;
int64_t key_stride = key.has_value() ? key->stride(-2) : 0;
int64_t query_stride = query.stride(-2);
VLLM_DISPATCH_FLOATING_TYPES(
@@ -183,15 +192,15 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
if (is_neox) {
rotary_embedding_impl(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
head_size, num_tokens);
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
key_stride, num_heads, num_kv_heads, head_size, num_tokens);
} else {
rotary_embedding_gptj_impl(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
head_size, num_tokens);
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
key_stride, num_heads, num_kv_heads, head_size, num_tokens);
}
CPU_KERNEL_GUARD_OUT(rotary_embedding_impl)
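The host wrapper above no longer dereferences key unconditionally: when key is absent, num_kv_heads falls back to num_heads (the value goes unused because the kernel skips the key loop on a null pointer) and key_stride becomes 0 (never read). A Python paraphrase of that fallback; the helper name is assumed:

from typing import Optional, Tuple
import torch

def kv_layout(query: torch.Tensor, key: Optional[torch.Tensor],
              head_size: int) -> Tuple[int, int]:
    # Mirrors the C++ fallback: without a key, reuse the query's head
    # count and report a stride of 0; neither value is read downstream.
    num_heads = query.size(-1) // head_size
    if key is None:
        return num_heads, 0
    return key.size(-1) // head_size, key.stride(-2)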

View File

@@ -117,7 +117,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def(
"rotary_embedding(Tensor positions, Tensor! query,"
" Tensor! key, int head_size,"
" Tensor!? key, int head_size,"
" Tensor cos_sin_cache, bool is_neox) -> ()");
ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);
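In the schema string, Tensor!? marks an argument that is mutated in place but may be passed as None. For illustration only, the same schema syntax can be registered from Python with torch.library.define; the mylib namespace and the toy body below are hypothetical:

import torch

# Hypothetical namespace; the schema string mirrors the binding above.
torch.library.define(
    "mylib::rotary_embedding",
    "(Tensor positions, Tensor! query, Tensor!? key, int head_size,"
    " Tensor cos_sin_cache, bool is_neox) -> ()")

@torch.library.impl("mylib::rotary_embedding", "CPU")
def _rotary_embedding_cpu(positions, query, key, head_size,
                          cos_sin_cache, is_neox):
    query.add_(0)        # toy in-place update standing in for the kernel
    if key is not None:  # key arrives as None when the caller omits it
        key.add_(0)

q = torch.randn(4, 128)
torch.ops.mylib.rotary_embedding(
    torch.arange(4), q, None, 32, torch.randn(16, 32), True)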

View File

@@ -86,13 +86,13 @@ void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
std::optional<torch::Tensor> residual);
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& key, int64_t head_size,
std::optional<torch::Tensor> key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox);
void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox,
int64_t rot_dim,
std::optional<torch::Tensor> key,
int64_t head_size, torch::Tensor& cos_sin_cache,
bool is_neox, int64_t rot_dim,
torch::Tensor& cos_sin_cache_offsets);
void silu_and_mul(torch::Tensor& out, torch::Tensor& input);

View File

@@ -38,7 +38,8 @@ inline __device__ void apply_rotary_embedding(
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
// head_size] or [num_tokens, num_heads,
// head_size]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
scalar_t* __restrict__ key, // nullptr or
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const scalar_t* cache_ptr, const int head_size, const int num_heads,
@@ -57,13 +58,15 @@ inline __device__ void apply_rotary_embedding(
query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
}
const int nk = num_kv_heads * embed_dim;
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
const int rot_offset = i % embed_dim;
apply_token_rotary_embedding<scalar_t, IS_NEOX>(
key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
if (key != nullptr) {
const int nk = num_kv_heads * embed_dim;
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
const int rot_offset = i % embed_dim;
apply_token_rotary_embedding<scalar_t, IS_NEOX>(
key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
}
}
}
@@ -74,7 +77,8 @@ __global__ void rotary_embedding_kernel(
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
// head_size] or [num_tokens, num_heads,
// head_size]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
scalar_t* __restrict__ key, // nullptr or
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
@@ -98,7 +102,8 @@ __global__ void batched_rotary_embedding_kernel(
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
// head_size] or [num_tokens, num_heads,
// head_size]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
scalar_t* __restrict__ key, // nullptr or
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
@@ -127,10 +132,12 @@ void rotary_embedding(
// [num_tokens, num_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
std::optional<torch::Tensor> key,
// null or
// [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
int64_t head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox) {
@@ -138,40 +145,40 @@ void rotary_embedding(
int64_t num_tokens = positions.numel();
int positions_ndim = positions.dim();
// Make sure num_tokens dim is consistent across positions, query, and key.
// Make sure num_tokens dim is consistent across positions, query, and key
TORCH_CHECK(
positions_ndim == 1 || positions_ndim == 2,
"positions must have shape [num_tokens] or [batch_size, seq_len]");
if (positions_ndim == 1) {
TORCH_CHECK(
query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
"query, key and positions must have the same number of tokens");
TORCH_CHECK(query.size(0) == positions.size(0) &&
(!key.has_value() || key->size(0) == positions.size(0)),
"query, key and positions must have the same number of tokens");
}
if (positions_ndim == 2) {
TORCH_CHECK(
query.size(0) == positions.size(0) &&
key.size(0) == positions.size(0) &&
(!key.has_value() || key->size(0) == positions.size(0)) &&
query.size(1) == positions.size(1) &&
key.size(1) == positions.size(1),
(!key.has_value() || key->size(1) == positions.size(1)),
"query, key and positions must have the same batch_size and seq_len");
}
// Make sure head_size is valid for query and key
// hidden_size = num_heads * head_size
int query_hidden_size = query.numel() / num_tokens;
int key_hidden_size = key.numel() / num_tokens;
int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0;
TORCH_CHECK(query_hidden_size % head_size == 0);
TORCH_CHECK(key_hidden_size % head_size == 0);
// Make sure query and key have consistent number of heads
int num_heads = query_hidden_size / head_size;
int num_kv_heads = key_hidden_size / head_size;
int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads;
TORCH_CHECK(num_heads % num_kv_heads == 0);
int rot_dim = cos_sin_cache.size(1);
int seq_dim_idx = positions_ndim - 1;
int64_t query_stride = query.stride(seq_dim_idx);
int64_t key_stride = key.stride(seq_dim_idx);
int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
dim3 grid(num_tokens);
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
@@ -181,15 +188,16 @@ void rotary_embedding(
if (is_neox) {
vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), rot_dim,
query_stride, key_stride, num_heads, num_kv_heads, head_size);
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride, key_stride,
num_heads, num_kv_heads, head_size);
} else {
vllm::rotary_embedding_kernel<scalar_t, false>
<<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
head_size);
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
key_stride, num_heads, num_kv_heads, head_size);
}
});
}
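The reworked TORCH_CHECKs reduce to simple shape arithmetic; a standalone sketch of the equivalent validation (function name assumed):

from typing import Optional
import torch

def check_rope_inputs(positions: torch.Tensor, query: torch.Tensor,
                      key: Optional[torch.Tensor], head_size: int) -> None:
    # Token counts must line up, hidden sizes must divide by head_size,
    # and query heads must be a multiple of kv heads (grouped-query attn).
    assert positions.dim() in (1, 2), \
        "positions must be [num_tokens] or [batch_size, seq_len]"
    assert query.shape[:positions.dim()] == positions.shape
    if key is not None:
        assert key.shape[:positions.dim()] == positions.shape
    num_tokens = positions.numel()
    query_hidden = query.numel() // num_tokens
    assert query_hidden % head_size == 0
    num_heads = query_hidden // head_size
    if key is not None:
        key_hidden = key.numel() // num_tokens
        assert key_hidden % head_size == 0
        assert num_heads % (key_hidden // head_size) == 0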
@@ -204,10 +212,12 @@ void batched_rotary_embedding(
// [num_tokens, num_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
std::optional<torch::Tensor>
key, // null or
// [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
int64_t head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox, int64_t rot_dim,
@@ -221,38 +231,38 @@ void batched_rotary_embedding(
"cos_sin_cache_offsets");
int positions_ndim = positions.dim();
// Make sure num_tokens dim is consistent across positions, query, and key.
// Make sure num_tokens dim is consistent across positions, query, and key
TORCH_CHECK(
positions_ndim == 1 || positions_ndim == 2,
"positions must have shape [num_tokens] or [batch_size, seq_len]");
if (positions_ndim == 1) {
TORCH_CHECK(
query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
"query, key and positions must have the same number of tokens");
TORCH_CHECK(query.size(0) == positions.size(0) &&
(!key.has_value() || key->size(0) == positions.size(0)),
"query, key and positions must have the same number of tokens");
}
if (positions_ndim == 2) {
TORCH_CHECK(
query.size(0) == positions.size(0) &&
key.size(0) == positions.size(0) &&
(!key.has_value() || key->size(0) == positions.size(0)) &&
query.size(1) == positions.size(1) &&
key.size(1) == positions.size(1),
(!key.has_value() || key->size(1) == positions.size(1)),
"query, key and positions must have the same batch_size and seq_len");
}
// Make sure head_size is valid for query and key
int query_hidden_size = query.numel() / num_tokens;
int key_hidden_size = key.numel() / num_tokens;
int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0;
TORCH_CHECK(query_hidden_size % head_size == 0);
TORCH_CHECK(key_hidden_size % head_size == 0);
// Make sure query and key have consistent number of heads
int num_heads = query_hidden_size / head_size;
int num_kv_heads = key_hidden_size / head_size;
int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads;
TORCH_CHECK(num_heads % num_kv_heads == 0);
int seq_dim_idx = positions_ndim - 1;
int64_t query_stride = query.stride(seq_dim_idx);
int64_t key_stride = key.stride(seq_dim_idx);
int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
dim3 grid(num_tokens);
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
@@ -263,14 +273,16 @@ void batched_rotary_embedding(
vllm::batched_rotary_embedding_kernel<scalar_t, true>
<<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
cos_sin_cache.data_ptr<scalar_t>(),
cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
key_stride, num_heads, num_kv_heads, head_size);
} else {
vllm::batched_rotary_embedding_kernel<scalar_t, false>
<<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
cos_sin_cache.data_ptr<scalar_t>(),
cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
key_stride, num_heads, num_kv_heads, head_size);
}

View File

@@ -176,7 +176,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def(
"rotary_embedding(Tensor positions, Tensor! query,"
" Tensor! key, int head_size,"
" Tensor!? key, int head_size,"
" Tensor cos_sin_cache, bool is_neox) -> ()");
ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
@@ -184,7 +184,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// (supports multiple loras).
ops.def(
"batched_rotary_embedding(Tensor positions, Tensor! query,"
" Tensor! key, int head_size,"
" Tensor!? key, int head_size,"
" Tensor cos_sin_cache, bool is_neox,"
" int rot_dim,"
" Tensor cos_sin_cache_offsets) -> ()");

View File

@@ -21,6 +21,7 @@ SEEDS = [0]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
USE_KEY = [True, False]
def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
@@ -46,6 +47,7 @@ TENSORS_SHAPES_FN = [_get_batch_tensor_shape, _get_flat_tensor_shape]
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_key", USE_KEY)
@torch.inference_mode()
def test_rotary_embedding(
is_neox_style: bool,
@@ -58,6 +60,7 @@ def test_rotary_embedding(
dtype: torch.dtype,
seed: int,
device: str,
use_key: bool,
max_position: int = 8192,
base: int = 10000,
) -> None:
@@ -74,7 +77,7 @@ def test_rotary_embedding(
positions = torch.randint(0, max_position, (batch_size, seq_len))
query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
query = torch.randn(query_shape, dtype=dtype)
key = torch.randn_like(query)
key = torch.randn_like(query) if use_key else None
# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
@@ -85,10 +88,14 @@ def test_rotary_embedding(
ref_query,
atol=get_default_atol(out_query),
rtol=get_default_rtol(out_query))
torch.testing.assert_close(out_key,
ref_key,
atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key))
if use_key:
torch.testing.assert_close(out_key,
ref_key,
atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key))
else:
assert ref_key is None and out_key is None, \
"expected returned key to be None"
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@@ -101,6 +108,7 @@ def test_rotary_embedding(
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_key", USE_KEY)
@torch.inference_mode()
def test_batched_rotary_embedding(
is_neox_style: bool,
@@ -113,6 +121,7 @@ def test_batched_rotary_embedding(
dtype: torch.dtype,
seed: int,
device: str,
use_key: bool,
max_position: int = 8192,
base: int = 10000,
) -> None:
@@ -129,7 +138,7 @@ def test_batched_rotary_embedding(
positions = torch.randint(0, max_position, (batch_size, seq_len))
query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
query = torch.randn(query_shape, dtype=dtype)
key = torch.randn_like(query)
key = torch.randn_like(query) if use_key else None
# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
@@ -145,10 +154,14 @@ def test_batched_rotary_embedding(
ref_query,
atol=get_default_atol(out_query),
rtol=get_default_rtol(out_query))
torch.testing.assert_close(out_key,
ref_key,
atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key))
if use_key:
torch.testing.assert_close(out_key,
ref_key,
atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key))
else:
assert ref_key is None and out_key is None, \
"expected returned key to be None"
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@@ -160,6 +173,7 @@ def test_batched_rotary_embedding(
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_key", USE_KEY)
@torch.inference_mode()
def test_batched_rotary_embedding_multi_lora(
is_neox_style: bool,
@@ -171,6 +185,7 @@ def test_batched_rotary_embedding_multi_lora(
dtype: torch.dtype,
seed: int,
device: str,
use_key: bool,
max_position: int = 8192,
base: int = 10000,
) -> None:
@@ -190,7 +205,7 @@ def test_batched_rotary_embedding_multi_lora(
seq_len,
num_heads * head_size,
dtype=dtype)
key = torch.randn_like(query)
key = torch.randn_like(query) if use_key else None
offset_map = torch.tensor(
list(
@@ -214,10 +229,14 @@ def test_batched_rotary_embedding_multi_lora(
ref_query,
atol=get_default_atol(out_query),
rtol=get_default_rtol(out_query))
torch.testing.assert_close(out_key,
ref_key,
atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key))
if use_key:
torch.testing.assert_close(out_key,
ref_key,
atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key))
else:
assert ref_key is None and out_key is None, \
"expected returned key to be None"
@torch.inference_mode()

View File

@@ -15,7 +15,7 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
def rotary_embedding_opcheck(rot,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None):
cos_sin_cache = rot.cos_sin_cache.to(query.device, dtype=query.dtype)
@@ -37,9 +37,10 @@ def rotary_embedding_opcheck(rot,
@pytest.mark.parametrize("rotary_dim", [32])
@pytest.mark.parametrize("head_size", [32, 108])
@pytest.mark.parametrize("seq_len", [11, 1024])
@pytest.mark.parametrize("use_key", [True, False])
def test_rotary_embedding_opcheck(dist_init, device, max_position,
is_neox_style, rotary_dim, head_size,
seq_len):
seq_len, use_key):
batch_size = 1
base = 10000
num_heads = 7
@@ -54,7 +55,7 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
num_heads * head_size,
dtype=torch.float32,
device=device)
key = torch.randn_like(query)
key = torch.randn_like(query) if use_key else None
rotary_embedding_opcheck(rot, positions, query, key)
offsets = torch.zeros(batch_size * seq_len,

View File

@@ -11,14 +11,16 @@ from vllm.platforms import current_platform
@pytest.mark.parametrize(
"max_position,is_neox_style,rotary_dim,head_size,seq_len", [
(16, False, 32, 32, 1024),
(16, False, 32, 128, 1024),
(16, True, 32, 32, 1024),
(16, True, 32, 128, 1024),
"max_position,is_neox_style,rotary_dim,head_size,seq_len,use_key", [
(16, False, 32, 32, 1024, True),
(16, False, 32, 128, 1024, True),
(16, True, 32, 32, 1024, True),
(16, True, 32, 128, 1024, True),
(16, False, 32, 128, 1024, False),
(16, True, 32, 128, 1024, False),
])
def test_rotary_embedding_opcheck(max_position, is_neox_style, rotary_dim,
head_size, seq_len):
head_size, seq_len, use_key):
import torch_xla.core.xla_model as xm
device = xm.xla_device()
@@ -40,19 +42,26 @@ def test_rotary_embedding_opcheck(max_position, is_neox_style, rotary_dim,
num_heads * head_size,
dtype=torch.float32,
device="cpu")
key = torch.randn_like(query)
key = torch.randn_like(query) if use_key else None
assert positions.is_cpu, \
"reference input tensor is expected to be CPU tensor."
ref_query, ref_key = rot.to(device="cpu").forward_native(
positions, query, key)
out_query, out_key = rot.to(device=device).forward_neuron(
positions.to(device=device), query.to(device=device),
key.to(device=device))
assert out_query.is_xla and out_key.is_xla, \
"output tensor is expected to be XLA tensor"
key.to(device=device) if key is not None else None)
if use_key:
assert out_query.is_xla and out_key.is_xla, \
"output tensor is expected to be XLA tensor"
torch.testing.assert_close(out_key.cpu(),
ref_key,
atol=1e-2,
rtol=1e-2)
else:
assert out_key is None, "expected returned key to be None"
assert out_query.is_xla, \
"output tensor is expected to be XLA tensor"
torch.testing.assert_close(out_query.cpu(),
ref_query,
atol=1e-2,
rtol=1e-2)
torch.testing.assert_close(out_key.cpu(), ref_key, atol=1e-2, rtol=1e-2)

View File

@@ -153,34 +153,36 @@ def merge_attn_states(output: torch.Tensor,
def rotary_embedding(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
key: Optional[torch.Tensor],
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox: bool,
) -> None:
# TODO: Remove this contiguous call when the kernel is updated to support tensor slices
query_contiguous = query.contiguous()
key_contiguous = key.contiguous()
key_contiguous = key.contiguous() if key is not None else None
torch.ops._C.rotary_embedding(positions, query_contiguous, key_contiguous,
head_size, cos_sin_cache, is_neox)
query.copy_(query_contiguous)
key.copy_(key_contiguous)
if key is not None:
key.copy_(key_contiguous)
def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, head_size: int,
key: Optional[torch.Tensor], head_size: int,
cos_sin_cache: torch.Tensor, is_neox: bool,
rot_dim: int,
cos_sin_cache_offsets: torch.Tensor) -> None:
# TODO: Remove this contiguous call when the kernel is updated to support tensor slices
query_contiguous = query.contiguous()
key_contiguous = key.contiguous()
key_contiguous = key.contiguous() if key is not None else None
torch.ops._C.batched_rotary_embedding(positions, query_contiguous,
key_contiguous, head_size,
cos_sin_cache, is_neox, rot_dim,
cos_sin_cache_offsets)
query.copy_(query_contiguous)
key.copy_(key_contiguous)
if key is not None:
key.copy_(key_contiguous)
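Both wrappers share the contiguous-copy-then-write-back idiom: the op mutates its inputs in place but currently requires contiguous memory, so a non-contiguous caller tensor is copied, rotated, and copied back. A self-contained sketch of the idiom with a stand-in inplace_kernel:

import torch

def inplace_kernel(t: torch.Tensor) -> None:
    t.mul_(2)  # stand-in for a kernel that needs contiguous input

x = torch.randn(4, 16)
view = x[:, :8]          # non-contiguous slice of the caller's tensor
buf = view.contiguous()  # kernel-safe copy
inplace_kernel(buf)
view.copy_(buf)          # propagate results back into the slice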
# layer norm ops

View File

@@ -138,9 +138,9 @@ class RotaryEmbedding(CustomOp):
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""A PyTorch-native implementation of forward()."""
if offsets is not None:
positions = positions + offsets
@@ -157,22 +157,24 @@ class RotaryEmbedding(CustomOp):
self.is_neox_style)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., :self.rotary_dim]
key_pass = key[..., self.rotary_dim:]
key_rot = _apply_rotary_emb_torch(key_rot, cos, sin,
self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
# key may be None in some cases, e.g. cross-layer KV sharing
if key is not None:
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., :self.rotary_dim]
key_pass = key[..., self.rotary_dim:]
key_rot = _apply_rotary_emb_torch(key_rot, cos, sin,
self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
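Stripped of the class machinery, the native path with an optional key reduces to the following self-contained reference (a sketch assuming rotary_dim == head_size and NeoX-style rotation):

from typing import Optional, Tuple
import torch

def apply_rope_native(
    positions: torch.Tensor,      # [num_tokens]
    query: torch.Tensor,          # [num_tokens, num_heads * head_size]
    key: Optional[torch.Tensor],  # same layout as query, or None
    cos_sin_cache: torch.Tensor,  # [max_position, rot_dim] = [cos | sin]
    head_size: int,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    cos, sin = cos_sin_cache[positions].chunk(2, dim=-1)
    cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2)  # broadcast over heads

    def rotate(x: torch.Tensor) -> torch.Tensor:
        shape = x.shape
        x = x.view(x.shape[0], -1, head_size)
        x1, x2 = x.chunk(2, dim=-1)  # NeoX style rotates the two halves
        out = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
        return out.reshape(shape)

    query = rotate(query)
    key = rotate(key) if key is not None else None  # key stays optional
    return query, key

q, k = apply_rope_native(torch.arange(4), torch.randn(4, 8 * 64), None,
                         torch.randn(8192, 64), head_size=64)
assert k is None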
def forward_cuda(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
from vllm import _custom_ops as ops
# __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
@@ -198,32 +200,39 @@ class RotaryEmbedding(CustomOp):
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
from vllm._ipex_ops import ipex_ops as ops
self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
dtype=query.dtype)
# ops.rotary_embedding()/batched_rotary_embedding()
# are in-place operations that update the query and key tensors.
if offsets is not None:
ops.batched_rotary_embedding(positions, query, key, self.head_size,
self.cos_sin_cache,
self.is_neox_style, self.rotary_dim,
offsets)
if key is None:
# XPU kernel doesn't support key=None so fall back to native impl
# TODO(sarckk): add support for optional key in
# ipex.llm.functional.rotary_embedding_batched
return self.forward_native(positions, query, key, offsets)
else:
ops.rotary_embedding(positions, query, key, self.head_size,
self.cos_sin_cache, self.is_neox_style)
if offsets is not None:
ops.batched_rotary_embedding(positions, query, key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
self.rotary_dim, offsets)
else:
ops.rotary_embedding(positions, query, key, self.head_size,
self.cos_sin_cache, self.is_neox_style)
return query, key
def forward_hpu(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
from habana_frameworks.torch.hpex.kernels import (
RotaryPosEmbeddingMode, apply_rotary_pos_emb)
if offsets is not None:
@@ -265,21 +274,23 @@ class RotaryEmbedding(CustomOp):
rope_mode)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., :self.rotary_dim]
key_pass = key[..., self.rotary_dim:]
key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
if key is not None:
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., :self.rotary_dim]
key_pass = key[..., self.rotary_dim:]
key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0,
rope_mode)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
def forward_neuron(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
def _apply_rotary_emb_neuron(
x: torch.Tensor,
@@ -319,14 +330,16 @@ class RotaryEmbedding(CustomOp):
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
if key is not None:
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
if self.rotary_dim == self.head_size:
query = _apply_rotary_emb(query, cos, sin, self.is_neox_style)
query = query.reshape(query_shape)
key = _apply_rotary_emb(key, cos, sin, self.is_neox_style)
key = key.reshape(key_shape)
if key is not None:
key = _apply_rotary_emb(key, cos, sin, self.is_neox_style)
key = key.reshape(key_shape)
else:
head_size = query.shape[-1]
query_reshaped = query.view(-1, head_size)
@@ -339,14 +352,15 @@ class RotaryEmbedding(CustomOp):
query = torch.cat((query_rot, query_pass),
dim=-1).reshape(query_shape)
key_reshaped = key.view(-1, head_size)
key_pass = key_reshaped[:, self.rotary_dim:].view(
*key.shape[:-1], head_size - self.rotary_dim)
key_rot = key_reshaped[:, :self.rotary_dim].view(
*key.shape[:-1], self.rotary_dim)
key_rot = _apply_rotary_emb_neuron(key_rot, cos, sin,
self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
if key is not None:
key_reshaped = key.view(-1, head_size)
key_pass = key_reshaped[:, self.rotary_dim:].view(
*key.shape[:-1], head_size - self.rotary_dim)
key_rot = key_reshaped[:, :self.rotary_dim].view(
*key.shape[:-1], self.rotary_dim)
key_rot = _apply_rotary_emb_neuron(key_rot, cos, sin,
self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
def extra_repr(self) -> str:
@@ -672,9 +686,10 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
assert key is not None
query = query.view(*query.shape[:-1], -1, self.head_size)
key = key.view(*key.shape[:-1], -1, self.head_size)
@@ -782,10 +797,11 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward()."""
assert key is not None
query_rot = query[..., :self.rotary_dim]
key_rot = key[..., :self.rotary_dim]
if self.rotary_dim < self.head_size:
@@ -912,8 +928,9 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
key: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
assert key is not None
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
query_ = torch.view_as_complex(query.float().reshape(
*query.shape[:-1], -1, 2))
@@ -957,8 +974,8 @@ class MRotaryEmbedding(RotaryEmbedding):
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
key: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward().
Args:
@@ -969,6 +986,7 @@ class MRotaryEmbedding(RotaryEmbedding):
key: [num_tokens, num_kv_heads * head_size]
"""
assert positions.ndim == 1 or positions.ndim == 2
assert key is not None
num_tokens = positions.shape[-1]
cos_sin = self.cos_sin_cache[positions]