[Kernel] Make rotary_embedding ops more flexible with input shape (#12777)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from itertools import accumulate, product
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -24,7 +24,21 @@ CUDA_DEVICES = [
|
||||
]
|
||||
|
||||
|
||||
def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
|
||||
head_size: int) -> tuple[int, ...]:
|
||||
return (batch_size, seq_len, num_heads * head_size)
|
||||
|
||||
|
||||
def _get_batch_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
|
||||
head_size: int) -> tuple[int, ...]:
|
||||
return (batch_size, seq_len, num_heads, head_size)
|
||||
|
||||
|
||||
TENSORS_SHAPES_FN = [_get_batch_tensor_shape, _get_flat_tensor_shape]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
|
||||
@pytest.mark.parametrize("tensor_shape_fn", TENSORS_SHAPES_FN)
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||
@pytest.mark.parametrize("seq_len", SEQ_LENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@@ -36,6 +50,7 @@ CUDA_DEVICES = [
|
||||
@torch.inference_mode()
|
||||
def test_rotary_embedding(
|
||||
is_neox_style: bool,
|
||||
tensor_shape_fn: Callable[[int, int, int, int], tuple[int]],
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
num_heads: int,
|
||||
@@ -58,10 +73,8 @@ def test_rotary_embedding(
|
||||
rope = rope.to(dtype=dtype)
|
||||
|
||||
positions = torch.randint(0, max_position, (batch_size, seq_len))
|
||||
query = torch.randn(batch_size,
|
||||
seq_len,
|
||||
num_heads * head_size,
|
||||
dtype=dtype)
|
||||
query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
|
||||
query = torch.randn(query_shape, dtype=dtype)
|
||||
key = torch.randn_like(query)
|
||||
|
||||
# NOTE(woosuk): The reference implementation should be executed first
|
||||
@@ -80,6 +93,7 @@ def test_rotary_embedding(
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
|
||||
@pytest.mark.parametrize("tensor_shape_fn", TENSORS_SHAPES_FN)
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||
@pytest.mark.parametrize("seq_len", SEQ_LENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@@ -91,6 +105,7 @@ def test_rotary_embedding(
|
||||
@torch.inference_mode()
|
||||
def test_batched_rotary_embedding(
|
||||
is_neox_style: bool,
|
||||
tensor_shape_fn: Callable[[int, int, int, int], tuple[int]],
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
num_heads: int,
|
||||
@@ -113,10 +128,8 @@ def test_batched_rotary_embedding(
|
||||
rope = rope.to(dtype=dtype)
|
||||
|
||||
positions = torch.randint(0, max_position, (batch_size, seq_len))
|
||||
query = torch.randn(batch_size,
|
||||
seq_len,
|
||||
num_heads * head_size,
|
||||
dtype=dtype)
|
||||
query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
|
||||
query = torch.randn(query_shape, dtype=dtype)
|
||||
key = torch.randn_like(query)
|
||||
|
||||
# NOTE(woosuk): The reference implementation should be executed first
|
||||
|
||||
Reference in New Issue
Block a user