Make key optional for rotary embedding (#17566)

Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
This commit is contained in:
Yong Hoon Shin
2025-05-07 00:11:46 -07:00
committed by GitHub
parent 324a3119b0
commit 98c89e16ff
10 changed files with 221 additions and 151 deletions

View File

@@ -21,6 +21,7 @@ SEEDS = [0]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
USE_KEY = [True, False]
def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
@@ -46,6 +47,7 @@ TENSORS_SHAPES_FN = [_get_batch_tensor_shape, _get_flat_tensor_shape]
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_key", USE_KEY)
@torch.inference_mode()
def test_rotary_embedding(
is_neox_style: bool,
@@ -58,6 +60,7 @@ def test_rotary_embedding(
dtype: torch.dtype,
seed: int,
device: str,
use_key: bool,
max_position: int = 8192,
base: int = 10000,
) -> None:
@@ -74,7 +77,7 @@ def test_rotary_embedding(
positions = torch.randint(0, max_position, (batch_size, seq_len))
query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
query = torch.randn(query_shape, dtype=dtype)
key = torch.randn_like(query)
key = torch.randn_like(query) if use_key else None
# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
@@ -85,10 +88,14 @@ def test_rotary_embedding(
ref_query,
atol=get_default_atol(out_query),
rtol=get_default_rtol(out_query))
torch.testing.assert_close(out_key,
ref_key,
atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key))
if use_key:
torch.testing.assert_close(out_key,
ref_key,
atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key))
else:
assert ref_key is None and out_key is None, \
"expected returned key to be None"
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@@ -101,6 +108,7 @@ def test_rotary_embedding(
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_key", USE_KEY)
@torch.inference_mode()
def test_batched_rotary_embedding(
is_neox_style: bool,
@@ -113,6 +121,7 @@ def test_batched_rotary_embedding(
dtype: torch.dtype,
seed: int,
device: str,
use_key: bool,
max_position: int = 8192,
base: int = 10000,
) -> None:
@@ -129,7 +138,7 @@ def test_batched_rotary_embedding(
positions = torch.randint(0, max_position, (batch_size, seq_len))
query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
query = torch.randn(query_shape, dtype=dtype)
key = torch.randn_like(query)
key = torch.randn_like(query) if use_key else None
# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
@@ -145,10 +154,14 @@ def test_batched_rotary_embedding(
ref_query,
atol=get_default_atol(out_query),
rtol=get_default_rtol(out_query))
torch.testing.assert_close(out_key,
ref_key,
atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key))
if use_key:
torch.testing.assert_close(out_key,
ref_key,
atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key))
else:
assert ref_key is None and out_key is None, \
"expected returned key to be None"
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@@ -160,6 +173,7 @@ def test_batched_rotary_embedding(
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("use_key", USE_KEY)
@torch.inference_mode()
def test_batched_rotary_embedding_multi_lora(
is_neox_style: bool,
@@ -171,6 +185,7 @@ def test_batched_rotary_embedding_multi_lora(
dtype: torch.dtype,
seed: int,
device: str,
use_key: bool,
max_position: int = 8192,
base: int = 10000,
) -> None:
@@ -190,7 +205,7 @@ def test_batched_rotary_embedding_multi_lora(
seq_len,
num_heads * head_size,
dtype=dtype)
key = torch.randn_like(query)
key = torch.randn_like(query) if use_key else None
offset_map = torch.tensor(
list(
@@ -214,10 +229,14 @@ def test_batched_rotary_embedding_multi_lora(
ref_query,
atol=get_default_atol(out_query),
rtol=get_default_rtol(out_query))
torch.testing.assert_close(out_key,
ref_key,
atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key))
if use_key:
torch.testing.assert_close(out_key,
ref_key,
atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key))
else:
assert ref_key is None and out_key is None, \
"expected returned key to be None"
@torch.inference_mode()