[Kernel] Have rotary embeddings support tensors (#18046)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
@@ -29,12 +29,20 @@ def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
|
||||
return (batch_size, seq_len, num_heads * head_size)
|
||||
|
||||
|
||||
# For testing sliced tensors
|
||||
def _get_padded_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
|
||||
head_size: int) -> tuple[int, ...]:
|
||||
return (batch_size, seq_len, num_heads, head_size + 64)
|
||||
|
||||
|
||||
def _get_batch_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
|
||||
head_size: int) -> tuple[int, ...]:
|
||||
return (batch_size, seq_len, num_heads, head_size)
|
||||
|
||||
|
||||
TENSORS_SHAPES_FN = [_get_batch_tensor_shape, _get_flat_tensor_shape]
|
||||
TENSORS_SHAPES_FN = [
|
||||
_get_batch_tensor_shape, _get_flat_tensor_shape, _get_padded_tensor_shape
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
|
||||
@@ -79,6 +87,10 @@ def test_rotary_embedding(
|
||||
query = torch.randn(query_shape, dtype=dtype)
|
||||
key = torch.randn_like(query) if use_key else None
|
||||
|
||||
# slice tensor if required, noop otherwise
|
||||
query = query[..., :head_size]
|
||||
key = key[..., :head_size] if use_key else None
|
||||
|
||||
# NOTE(woosuk): The reference implementation should be executed first
|
||||
# because the custom kernel is in-place.
|
||||
ref_query, ref_key = rope.forward_native(positions, query, key)
|
||||
|
||||
Reference in New Issue
Block a user