Make key optional for rotary embedding (#17566)
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
This commit is contained in:
@@ -21,6 +21,7 @@ SEEDS = [0]
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
USE_KEY = [True, False]
|
||||
|
||||
|
||||
def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
|
||||
@@ -46,6 +47,7 @@ TENSORS_SHAPES_FN = [_get_batch_tensor_shape, _get_flat_tensor_shape]
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("use_key", USE_KEY)
|
||||
@torch.inference_mode()
|
||||
def test_rotary_embedding(
|
||||
is_neox_style: bool,
|
||||
@@ -58,6 +60,7 @@ def test_rotary_embedding(
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
use_key: bool,
|
||||
max_position: int = 8192,
|
||||
base: int = 10000,
|
||||
) -> None:
|
||||
@@ -74,7 +77,7 @@ def test_rotary_embedding(
|
||||
positions = torch.randint(0, max_position, (batch_size, seq_len))
|
||||
query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
|
||||
query = torch.randn(query_shape, dtype=dtype)
|
||||
key = torch.randn_like(query)
|
||||
key = torch.randn_like(query) if use_key else None
|
||||
|
||||
# NOTE(woosuk): The reference implementation should be executed first
|
||||
# because the custom kernel is in-place.
|
||||
@@ -85,10 +88,14 @@ def test_rotary_embedding(
|
||||
ref_query,
|
||||
atol=get_default_atol(out_query),
|
||||
rtol=get_default_rtol(out_query))
|
||||
torch.testing.assert_close(out_key,
|
||||
ref_key,
|
||||
atol=get_default_atol(out_key),
|
||||
rtol=get_default_rtol(out_key))
|
||||
if use_key:
|
||||
torch.testing.assert_close(out_key,
|
||||
ref_key,
|
||||
atol=get_default_atol(out_key),
|
||||
rtol=get_default_rtol(out_key))
|
||||
else:
|
||||
assert ref_key is None and out_key is None, \
|
||||
"expected returned key to be None"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
|
||||
@@ -101,6 +108,7 @@ def test_rotary_embedding(
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("use_key", USE_KEY)
|
||||
@torch.inference_mode()
|
||||
def test_batched_rotary_embedding(
|
||||
is_neox_style: bool,
|
||||
@@ -113,6 +121,7 @@ def test_batched_rotary_embedding(
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
use_key: bool,
|
||||
max_position: int = 8192,
|
||||
base: int = 10000,
|
||||
) -> None:
|
||||
@@ -129,7 +138,7 @@ def test_batched_rotary_embedding(
|
||||
positions = torch.randint(0, max_position, (batch_size, seq_len))
|
||||
query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
|
||||
query = torch.randn(query_shape, dtype=dtype)
|
||||
key = torch.randn_like(query)
|
||||
key = torch.randn_like(query) if use_key else None
|
||||
|
||||
# NOTE(woosuk): The reference implementation should be executed first
|
||||
# because the custom kernel is in-place.
|
||||
@@ -145,10 +154,14 @@ def test_batched_rotary_embedding(
|
||||
ref_query,
|
||||
atol=get_default_atol(out_query),
|
||||
rtol=get_default_rtol(out_query))
|
||||
torch.testing.assert_close(out_key,
|
||||
ref_key,
|
||||
atol=get_default_atol(out_key),
|
||||
rtol=get_default_rtol(out_key))
|
||||
if use_key:
|
||||
torch.testing.assert_close(out_key,
|
||||
ref_key,
|
||||
atol=get_default_atol(out_key),
|
||||
rtol=get_default_rtol(out_key))
|
||||
else:
|
||||
assert ref_key is None and out_key is None, \
|
||||
"expected returned key to be None"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
|
||||
@@ -160,6 +173,7 @@ def test_batched_rotary_embedding(
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("use_key", USE_KEY)
|
||||
@torch.inference_mode()
|
||||
def test_batched_rotary_embedding_multi_lora(
|
||||
is_neox_style: bool,
|
||||
@@ -171,6 +185,7 @@ def test_batched_rotary_embedding_multi_lora(
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
use_key: bool,
|
||||
max_position: int = 8192,
|
||||
base: int = 10000,
|
||||
) -> None:
|
||||
@@ -190,7 +205,7 @@ def test_batched_rotary_embedding_multi_lora(
|
||||
seq_len,
|
||||
num_heads * head_size,
|
||||
dtype=dtype)
|
||||
key = torch.randn_like(query)
|
||||
key = torch.randn_like(query) if use_key else None
|
||||
|
||||
offset_map = torch.tensor(
|
||||
list(
|
||||
@@ -214,10 +229,14 @@ def test_batched_rotary_embedding_multi_lora(
|
||||
ref_query,
|
||||
atol=get_default_atol(out_query),
|
||||
rtol=get_default_rtol(out_query))
|
||||
torch.testing.assert_close(out_key,
|
||||
ref_key,
|
||||
atol=get_default_atol(out_key),
|
||||
rtol=get_default_rtol(out_key))
|
||||
if use_key:
|
||||
torch.testing.assert_close(out_key,
|
||||
ref_key,
|
||||
atol=get_default_atol(out_key),
|
||||
rtol=get_default_rtol(out_key))
|
||||
else:
|
||||
assert ref_key is None and out_key is None, \
|
||||
"expected returned key to be None"
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
|
||||
Reference in New Issue
Block a user