[Kernel] Make rotary_embedding ops more flexible with input shape (#12777)

This commit is contained in:
Isotr0py
2025-02-07 00:46:13 +08:00
committed by GitHub
parent 1e57b1ee63
commit 85ac82d228
4 changed files with 115 additions and 57 deletions

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
from itertools import accumulate, product
from typing import Dict, List, Optional
from typing import Callable, Dict, List, Optional
import pytest
import torch
@@ -24,7 +24,21 @@ CUDA_DEVICES = [
]
def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
head_size: int) -> tuple[int, ...]:
return (batch_size, seq_len, num_heads * head_size)
def _get_batch_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
head_size: int) -> tuple[int, ...]:
return (batch_size, seq_len, num_heads, head_size)
TENSORS_SHAPES_FN = [_get_batch_tensor_shape, _get_flat_tensor_shape]
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("tensor_shape_fn", TENSORS_SHAPES_FN)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@@ -36,6 +50,7 @@ CUDA_DEVICES = [
@torch.inference_mode()
def test_rotary_embedding(
is_neox_style: bool,
tensor_shape_fn: Callable[[int, int, int, int], tuple[int]],
batch_size: int,
seq_len: int,
num_heads: int,
@@ -58,10 +73,8 @@ def test_rotary_embedding(
rope = rope.to(dtype=dtype)
positions = torch.randint(0, max_position, (batch_size, seq_len))
query = torch.randn(batch_size,
seq_len,
num_heads * head_size,
dtype=dtype)
query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
query = torch.randn(query_shape, dtype=dtype)
key = torch.randn_like(query)
# NOTE(woosuk): The reference implementation should be executed first
@@ -80,6 +93,7 @@ def test_rotary_embedding(
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("tensor_shape_fn", TENSORS_SHAPES_FN)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@@ -91,6 +105,7 @@ def test_rotary_embedding(
@torch.inference_mode()
def test_batched_rotary_embedding(
is_neox_style: bool,
tensor_shape_fn: Callable[[int, int, int, int], tuple[int]],
batch_size: int,
seq_len: int,
num_heads: int,
@@ -113,10 +128,8 @@ def test_batched_rotary_embedding(
rope = rope.to(dtype=dtype)
positions = torch.randint(0, max_position, (batch_size, seq_len))
query = torch.randn(batch_size,
seq_len,
num_heads * head_size,
dtype=dtype)
query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
query = torch.randn(query_shape, dtype=dtype)
key = torch.randn_like(query)
# NOTE(woosuk): The reference implementation should be executed first